Confident Learning - Principled Data Cleaning

ML benchmarks have a lot of label errors in them, generally a few percent of the test set (< 5%). And model rankings can change with just a 6% increase in noise prevalence: a model that ranks better on a noisy benchmark may actually be overfitting to the noise.

Confident learning: a framework for finding systematic label errors in data

  • Theoretically grounded: proves realistic sufficient conditions for exactly finding label errors
  • Model agnostic: works with any model/dataset
  • Based on insight:
    • Most label errors are concentrated in a few classes and are not completely random: "A fox is more likely to be mislabeled as a dog than as a cow."
    • Models are "heteroscedastic": they have different performances across different classes/data clusters.
Key Idea

When the predicted probability of an example for some class meets or exceeds that class's threshold, we confidently count that example as actually belonging to that class. The threshold for each class is the average predicted probability over the examples (noisily) labeled as that class.

Assume that there exists a true, correct label whose class alone determines the distribution of the noisy label, independent of the input $\boldsymbol{x}$ (class-conditional noise):

$$ p\left(\tilde{y} \mid y^* ; \boldsymbol{x}\right)=p\left(\tilde{y} \mid y^*\right) $$

where,
$y^*$ - unobserved, latent, correct label
$\tilde{y}$ - observed, noisy label

The entire goal is to estimate the joint distribution of noisy and true labels, $\hat{p}\left(\tilde{y}, y^*\right)$, for example:

$$ \begin{array}{r|c|c|c} \hat{p}\left(\tilde{y}, y^*\right) & y^*=\operatorname{dog} & y^*=\operatorname{fox} & y^*=\operatorname{cow} \\ \hline \tilde{y}=\operatorname{dog} & 0.25 & 0.1 & 0.05 \\ \hline \tilde{y}=\operatorname{fox} & 0.14 & 0.15 & 0 \\ \hline \tilde{y}=\operatorname{cow} & 0.08 & 0.03 & 0.2 \end{array} $$
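Reading the table: the diagonal holds the correctly labeled mass, so the off-diagonal mass is the estimated fraction of mislabeled examples. A quick sanity check on the illustrative joint above:

```python
import numpy as np

# The illustrative joint p-hat(y~, y*) from the table above
joint = np.array([[0.25, 0.10, 0.05],   # y~ = dog
                  [0.14, 0.15, 0.00],   # y~ = fox
                  [0.08, 0.03, 0.20]])  # y~ = cow

# Off-diagonal mass = estimated fraction of mislabeled examples
error_fraction = joint.sum() - np.trace(joint)
print(round(error_fraction, 2))  # 0.4
```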

Need only two things to detect erroneous labels:

  1. Out-of-sample predicted probabilities for the training data, e.g. obtained via cross-validation
  2. The (noisy) labels of the training data

Steps:

First, we find thresholds as a proxy for the machine's self-confidence, on average, for each class $j$, i.e. for each class, average the predicted probabilities of that class over the examples labeled with it.

$$ t_j=\frac{1}{\left|\boldsymbol{X}_{\tilde{y}=j}\right|} \sum_{\boldsymbol{x} \in \boldsymbol{X}_{\tilde{y}=j}} \hat{p}(\tilde{y}=j ; \boldsymbol{x}, \boldsymbol{\theta}) $$
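A minimal NumPy sketch of this threshold computation (the function name and toy arrays are illustrative, not from the paper):

```python
import numpy as np

def class_thresholds(pred_probs, labels):
    """t_j = average predicted probability of class j over examples labeled j."""
    n_classes = pred_probs.shape[1]
    return np.array([
        pred_probs[labels == j, j].mean() for j in range(n_classes)
    ])

# Toy example: 4 examples, 2 classes
pred_probs = np.array([[0.9, 0.1],
                       [0.8, 0.2],
                       [0.3, 0.7],
                       [0.4, 0.6]])
labels = np.array([0, 0, 1, 1])
t = class_thresholds(pred_probs, labels)
# t[0] = mean(0.9, 0.8) = 0.85; t[1] = mean(0.7, 0.6) = 0.65
```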

Then, for each data point, check whether its predicted probability meets the class threshold for any class other than its assigned label. If so, the label is probably an error (when several classes qualify, the class with the highest predicted probability is the natural candidate for the true label).
$$ \hat{\boldsymbol{X}}_{\tilde{y}=i, y^*=j}= \left\{\boldsymbol{x} \in \boldsymbol{X}_{\tilde{y}=i}: \hat{p}(\tilde{y}=j ; \boldsymbol{x}, \boldsymbol{\theta}) \geq t_j\right\} $$
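The detection step can be sketched the same way. This is a simplification of the paper's full confident-joint bookkeeping (which also counts these sets to estimate the joint); all names here are illustrative:

```python
import numpy as np

def find_label_issues(pred_probs, labels, thresholds):
    """Flag example i as a likely error if its predicted probability for some
    class j != labels[i] meets or exceeds that class's threshold t_j.
    Ties are broken by the largest probability among qualifying classes."""
    n, k = pred_probs.shape
    issues = []
    for i in range(n):
        qualifying = [j for j in range(k)
                      if j != labels[i] and pred_probs[i, j] >= thresholds[j]]
        if qualifying:
            suggested = max(qualifying, key=lambda j: pred_probs[i, j])
            issues.append((i, int(labels[i]), suggested))
    return issues

# Toy example: 4 examples, 2 classes
pred_probs = np.array([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.7, 0.3],
                       [0.35, 0.65]])
labels = np.array([0, 0, 1, 1])
thresholds = np.array([pred_probs[labels == j, j].mean() for j in range(2)])
issues = find_label_issues(pred_probs, labels, thresholds)
# Each tuple is (index, noisy label, suggested true label):
# examples 1 and 2 are flagged as likely mislabeled.
```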

That's it! Deceptively simple, but remarkably effective.

Some insights:

  • The joint is quite sparse, because label noise is mostly concentrated in a few class pairs.

References

  1. Helpful: Author's talk video and blog post.