Basic Data Preprocessing

Data preprocessing is a huge part of any ML project, even a relatively simple one like the one I want to build in the future.

There are three main data-preparation tasks I want to be able to do in Java.

Handling missing values, converting categorical data to numerical data, and normalizing or standardizing data.

Handling Missing Values

Remove: Simply remove rows with missing values. This is straightforward but can lead to loss of data. I probably don’t want to go this route.
Impute: Replace missing values with a specific value, like the mean, median, or mode of the column. This seems valid enough.
Predict: Use algorithms or models to predict and fill missing values based on other columns. This just leads to complexity that we don’t have time for.

import java.util.ArrayList;
import java.util.Arrays;

public class DataPreprocessing {

    /**
     * Demonstrates listwise deletion: every row that contains at least one missing
     * ({@code null}) cell is dropped, and the surviving rows are printed.
     */
    public static void main(String[] args) {
        // Small hand-built table; a null cell stands for a missing value.
        ArrayList<ArrayList<Double>> table = new ArrayList<>();
        table.add(row(1.0, 2.0, 3.0));
        table.add(row(4.0, null, 6.0));
        table.add(row(7.0, 8.0, 9.0));
        table.add(row(10.0, 11.0, null));

        // Keep only the rows with no null cell at all.
        ArrayList<ArrayList<Double>> complete = new ArrayList<>();
        for (ArrayList<Double> candidate : table) {
            if (!candidate.contains(null)) {
                complete.add(candidate);
            }
        }

        System.out.println("Data after removing rows with missing values or null sum: " + complete);
    }

    /** Wraps the given cells (nulls allowed) in a new mutable row list. */
    private static ArrayList<Double> row(Double... cells) {
        return new ArrayList<>(Arrays.asList(cells));
    }
}

DataPreprocessing.main(null);
Data after removing rows with missing values or null sum: [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]
import java.util.ArrayList;
import java.util.Arrays;

public class DataPreprocessing {

    /** Demonstrates filling missing ({@code null}) cells with the mean of their own row. */
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)));
        data.add(new ArrayList<>(Arrays.asList(4.0, null, 6.0)));
        data.add(new ArrayList<>(Arrays.asList(7.0, 8.0, 9.0)));
        data.add(new ArrayList<>(Arrays.asList(10.0, 11.0, null)));

        for (ArrayList<Double> row : data) {
            imputeMissingWithRowMean(row);
        }

        System.out.println("Data after imputing missing values with row mean: " + data);
    }

    /**
     * Replaces, in place, every {@code null} entry of {@code row} with the arithmetic mean of the
     * row's non-null entries.
     *
     * <p>If the row has no non-null entries (it is empty or entirely missing), there is no mean to
     * use. In that case the row is left unchanged rather than filled with {@code NaN}, which is
     * what {@code 0.0 / 0} would silently produce. Callers can still detect those nulls and drop
     * the row.
     *
     * <p>Note: this averages across a row; imputing with the mean of a column (feature) is the
     * more common choice for tabular data.
     *
     * @param row the row to modify; must be non-null and mutable
     */
    public static void imputeMissingWithRowMean(ArrayList<Double> row) {
        double sum = 0;
        int count = 0;
        for (Double value : row) {
            if (value != null) {
                sum += value;
                count++;
            }
        }
        if (count == 0) {
            // Nothing to average: keep the nulls instead of injecting NaN.
            return;
        }
        double mean = sum / count;

        row.replaceAll(value -> value == null ? mean : value);
    }
}

DataPreprocessing.main(null);
Data after imputing missing values with row mean: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 10.5]]

Categorical Data to Numerical

One-Hot Encoding: Convert each category value into a new column and assign a 1 or 0 (True/False) value. This is what I normally do.
Label Encoding: Assign each category a unique integer. For instance, ‘Red’ might be 1, ‘Blue’ 2, and ‘Green’ 3.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class OneHotEncoding {

    /** Encodes a tiny single-column colour table and prints the one-hot result. */
    public static void main(String[] args) {
        ArrayList<ArrayList<String>> rows = new ArrayList<>();
        for (String colour : new String[] {"Red", "Blue", "Green", "Red"}) {
            rows.add(new ArrayList<>(Arrays.asList(colour)));
        }

        System.out.println("Data after One-Hot Encoding: " + oneHotEncode(rows));
    }

    /**
     * One-hot encodes string rows against a single shared vocabulary.
     *
     * <p>Every distinct value across all rows gets its own output column. Columns are numbered in
     * the order in which values first appear, reading row by row and left to right. Each output
     * row has one slot per vocabulary entry, set to 1 for every value present in the input row and
     * 0 otherwise.
     *
     * @param data input rows of category strings
     * @return a new list of 0/1 rows, one per input row
     */
    public static ArrayList<ArrayList<Integer>> oneHotEncode(ArrayList<ArrayList<String>> data) {
        // Pass 1: build the vocabulary. The current map size is the next free column index.
        Map<String, Integer> columnOf = new HashMap<>();
        for (ArrayList<String> row : data) {
            for (String value : row) {
                columnOf.putIfAbsent(value, columnOf.size());
            }
        }

        // Pass 2: turn each row into its indicator vector.
        int width = columnOf.size();
        ArrayList<ArrayList<Integer>> encoded = new ArrayList<>();
        for (ArrayList<String> row : data) {
            Integer[] bits = new Integer[width];
            Arrays.fill(bits, 0);
            for (String value : row) {
                bits[columnOf.get(value)] = 1;
            }
            encoded.add(new ArrayList<>(Arrays.asList(bits)));
        }
        return encoded;
    }
}

OneHotEncoding.main(null);

Data after One-Hot Encoding: [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class LabelEncoding {

    /** Label-encodes a tiny single-column colour table and prints the result. */
    public static void main(String[] args) {
        ArrayList<ArrayList<String>> rows = new ArrayList<>();
        for (String colour : new String[] {"Red", "Blue", "Green", "Red"}) {
            rows.add(new ArrayList<>(Arrays.asList(colour)));
        }

        ArrayList<ArrayList<Integer>> labelled = labelEncode(rows);
        System.out.println("Data after Label Encoding: " + labelled);
    }

    /**
     * Replaces each string with an integer label from one shared label space.
     *
     * <p>Labels start at 1. They are handed out in the order in which distinct values first
     * appear, reading row by row and left to right. A repeated value reuses its existing label.
     * The output has exactly the same shape as the input.
     *
     * @param data input rows of category strings
     * @return a new list of label rows, one per input row
     */
    public static ArrayList<ArrayList<Integer>> labelEncode(ArrayList<ArrayList<String>> data) {
        Map<String, Integer> labelOf = new HashMap<>();
        ArrayList<ArrayList<Integer>> result = new ArrayList<>();

        // Single pass: a value is labelled the first time it is seen, so labels come out in
        // first-appearance order exactly as a separate vocabulary pass would assign them.
        for (ArrayList<String> row : data) {
            ArrayList<Integer> labels = new ArrayList<>();
            for (String value : row) {
                Integer label = labelOf.get(value);
                if (label == null) {
                    label = labelOf.size() + 1;
                    labelOf.put(value, label);
                }
                labels.add(label);
            }
            result.add(labels);
        }
        return result;
    }
}

LabelEncoding.main(null);

Data after Label Encoding: [[1], [2], [3], [1]]

Normalizing or Standardizing Data

I haven’t normalized much in the past, as I just hoped the model’s weights would adjust for the differences in scale.

Normalizing: This process scales the data between 0 and 1. The formula for normalization is: normalized value = (value − min)/(max − min)
Standardization: This process scales the data based on the mean (μ) and standard deviation (σ) so that the new data has a mean of 0 and a standard deviation of 1. The formula for standardization is: standardized value = (value - μ)/σ

import java.util.ArrayList;
import java.util.Arrays;

public class RowNormalization {

    /** Min-max normalizes a small sample table row by row and prints the result. */
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(10.0, 20.0, 30.0)));
        data.add(new ArrayList<>(Arrays.asList(40.0, 50.0, 60.0)));
        data.add(new ArrayList<>(Arrays.asList(70.0, 80.0, 90.0)));

        ArrayList<ArrayList<Double>> normalizedData = normalizeByRow(data);

        System.out.println("Row Normalized Data: " + normalizedData);
    }

    /**
     * Min-max scales each row independently into [0, 1] using
     * {@code (value - min) / (max - min)}, where min and max are taken over that row only.
     *
     * <p>A row whose values are all equal has no spread. Every value in such a row maps to 0.0
     * instead of the {@code NaN} that {@code 0 / 0} would produce. The input lists are not
     * modified.
     *
     * @param data rows of non-null values; a null cell causes a NullPointerException on unboxing
     * @return a new list of new rows with the same shape as {@code data}
     */
    public static ArrayList<ArrayList<Double>> normalizeByRow(ArrayList<ArrayList<Double>> data) {
        ArrayList<ArrayList<Double>> normalizedData = new ArrayList<>(data.size());

        for (ArrayList<Double> row : data) {
            // Seed with the infinities so any finite value replaces them. Double.MIN_VALUE is the
            // smallest *positive* double, not the most negative one; seeding max with it broke
            // rows whose values are all negative.
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;

            // Find min and max values for the row
            for (double value : row) {
                if (value < min) min = value;
                if (value > max) max = value;
            }

            double range = max - min;
            ArrayList<Double> normalizedRow = new ArrayList<>(row.size());
            for (double value : row) {
                // Constant row: no spread to scale by, so pin it to the bottom of the range.
                normalizedRow.add(range == 0 ? 0.0 : (value - min) / range);
            }
            normalizedData.add(normalizedRow);
        }

        return normalizedData;
    }
}

RowNormalization.main(null);
Row Normalized Data: [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]
import java.util.ArrayList;
import java.util.Arrays;

public class RowStandardization {

    /** Standardizes a small sample table row by row and prints the result. */
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(10.0, 20.0, 30.0)));
        data.add(new ArrayList<>(Arrays.asList(40.0, 50.0, 60.0)));
        data.add(new ArrayList<>(Arrays.asList(70.0, 80.0, 90.0)));

        ArrayList<ArrayList<Double>> standardizedData = standardizeByRow(data);

        System.out.println("Row Standardized Data: " + standardizedData);
    }

    /**
     * Standardizes each row independently to zero mean and unit standard deviation using
     * {@code (value - mean) / stdDev}, with mean and stdDev computed over that row only.
     *
     * <p>The standard deviation is the population form (the sum of squared deviations is divided
     * by the row length {@code n}, not {@code n - 1}). A row with zero variance (all values
     * equal) maps to all 0.0 instead of the {@code NaN} that dividing by zero would give. The
     * input lists are not modified.
     *
     * @param data rows of non-null values; a null cell causes a NullPointerException on unboxing
     * @return a new list of new rows with the same shape as {@code data}
     */
    public static ArrayList<ArrayList<Double>> standardizeByRow(ArrayList<ArrayList<Double>> data) {
        ArrayList<ArrayList<Double>> standardizedData = new ArrayList<>(data.size());

        for (ArrayList<Double> row : data) {
            int n = row.size();

            // Calculate mean for the row
            double sum = 0;
            for (double value : row) {
                sum += value;
            }
            double mean = sum / n;

            // Calculate population standard deviation for the row
            double squaredDeviations = 0;
            for (double value : row) {
                double deviation = value - mean;
                squaredDeviations += deviation * deviation;
            }
            double stdDev = Math.sqrt(squaredDeviations / n);

            ArrayList<Double> standardizedRow = new ArrayList<>(n);
            for (double value : row) {
                // Zero variance: every value equals the mean, so its z-score is taken as 0.
                standardizedRow.add(stdDev == 0 ? 0.0 : (value - mean) / stdDev);
            }
            standardizedData.add(standardizedRow);
        }

        return standardizedData;
    }
}

RowStandardization.main(null);
Row Standardized Data: [[-1.224744871391589, 0.0, 1.224744871391589], [-1.224744871391589, 0.0, 1.224744871391589], [-1.224744871391589, 0.0, 1.224744871391589]]