Stratified K-Fold cross-validation for imbalanced datasets

In machine learning, cross-validation is a common technique for training a robust and generalizable model. It divides the dataset into multiple subsets, where typically one of them serves as the validation set and the remainder make up the training set.

This process is repeated multiple times with different combinations of subsets, giving a more comprehensive evaluation of the model's performance.
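As a concrete illustration, here is a minimal sketch of 5-fold cross-validation with scikit-learn. The dataset, model, and parameter values are hypothetical, chosen only to keep the example self-contained:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Hypothetical imbalanced dataset: roughly 90% class 0, 10% class 1.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)

# Five folds: each iteration trains on four subsets and validates on the fifth.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=kf)
print(scores)         # one score per fold
print(scores.mean())  # averaged performance estimate
```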

Among the different methods of cross-validation, Stratified K-Fold is an enhancement of traditional K-Fold designed to handle imbalanced datasets. In short, Stratified K-Fold ensures that each fold preserves the same class distribution as the original dataset.

In standard K-Fold cross-validation, the subsets are formed by random sampling. If the original dataset is imbalanced, there is a significant risk that the resulting subsets won't retain the class distribution and certain classes will be badly under-represented in some folds. Stratified sampling averts this problem: by preserving the class distribution in each fold, it reduces the risk of a biased model evaluation.
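To make the difference visible, the following sketch compares the per-fold share of the minority class under KFold and StratifiedKFold; the 90/10 label split is an assumed toy example:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced labels: 90 samples of class 0, 10 of class 1.
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((100, 1))  # placeholder features; only the labels matter for the split

splitters = {
    "KFold": KFold(n_splits=5, shuffle=True, random_state=0),
    "StratifiedKFold": StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
}

for name, splitter in splitters.items():
    print(name)
    for fold, (train_idx, val_idx) in enumerate(splitter.split(X, y)):
        # Share of minority-class samples in each validation fold.
        print(f"  fold {fold}: minority share = {y[val_idx].mean():.2f}")
```

With stratification, every validation fold holds exactly 10% minority samples, whereas plain KFold lets that share drift from fold to fold.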

Figure 1. K-Fold vs Stratified K-Fold cross-validation. Source: scikit-learn