Stratified train-test-split

Categories: code, julia, data cleaning, data engineering
Author

Dr Alun ap Rhisiart

Published

July 4, 2024

A stratified train-test-split function

I previously wrote some Julia code to do a train-test split, with the training dataset further split into train/validation. However, I wanted to improve on this by making a stratified split, which takes account of the different sample sizes per label, so that each split contains the same proportion of every label.

The data was available in a parquet file with two columns, a ‘text’ column of strings (the corpus), and a ‘label’ column of integers (the labels).
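For illustration, a toy DataFrame in the same shape (with made-up values, not the real data) would look like this:

```{julia}
using DataFrames

# Hypothetical stand-in for the real corpus: a 'text' column of strings
# and a 'label' column of integers starting at 0 (0 = false positive).
toy = DataFrame(
    text  = ["fox at night", "empty frame", "badger", "blank", "cat"],
    label = [2, 0, 1, 0, 1],
)
```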

```{julia}
using StatsBase, DataFramesMeta, Random
using Parquet2: writefile, Dataset
using DrWatson # provides `datadir`

ds = Dataset(datadir("exp_pro", "captures_training.parquet"))
captures = DataFrame(ds)
```

Here is the function. As the docstring says, splitting the training set further into train and validation sets is optional. Note that in my case the labels start at zero (which means false positive). If your labels start at 1 you will need to change the loop range.

```{julia}
"""
    stratified_train_test_split(df::DataFrame, split_col::Symbol, train_split::Float64[, valid_split::Float64])

Carry out a stratified train-valid-test split based on the split column name. The basic
split is train/test, according to the ratio `train_split`, which should be between 0 and 1.
If `valid_split` is supplied, the training set is further split into training and validation:
a fraction `valid_split` of the training rows is kept for training and the remainder becomes
the validation set (this does not affect the test split).

Returns `train_df`, `valid_df` (empty if `valid_split` is not supplied), `test_df`.
"""
function stratified_train_test_split(df::DataFrame, split_col::Symbol, train_split::Float64,
                                     valid_split::Union{Float64,Nothing}=nothing)
    shuffled_df = DataFrame(shuffle(eachrow(df)))
    train = similar(df, 0)
    valid = similar(df, 0)
    test = similar(df, 0)

    for i = 0:maximum(shuffled_df[!, split_col])
        label_df = @rsubset(shuffled_df, $split_col == i) # select rows for each label in turn
        train_size = round(Int, nrow(label_df) * train_split)
        @info "label $i: $train_size training rows"
        mh_train = label_df[1:train_size, :]
        test_df = label_df[train_size+1:end, :]

        if isnothing(valid_split)
            train_df = mh_train
            valid_df = similar(df, 0) # nothing to add to the validation set
        else
            # split the training set again into training/validation
            # (Huggingface doesn't do cross-validation).
            # `valid_split` is the fraction of training rows *kept* for training.
            keep_size = round(Int, nrow(mh_train) * valid_split)
            train_df = mh_train[1:keep_size, :]
            valid_df = mh_train[keep_size+1:end, :]
        end

        train = vcat(train, train_df)
        test = vcat(test, test_df)
        valid = vcat(valid, valid_df)
    end
    train = DataFrame(shuffle(eachrow(train)))
    valid = DataFrame(shuffle(eachrow(valid)))
    test = DataFrame(shuffle(eachrow(test)))
    return train, valid, test
end
```

Now we can use the function to split the data into 80% training and 20% test rows (independently for each label), and further to keep 85% of the training rows, splitting the remaining 15% off into a validation set. Finally, I write each split out to a Parquet file.

```{julia}
train, valid, test = stratified_train_test_split(captures, :label, 0.8, 0.85)

writefile(datadir("exp_pro", "train.parquet"), train)
writefile(datadir("exp_pro", "test.parquet"), test)
writefile(datadir("exp_pro", "valid.parquet"), valid)
```