Basic Data Preprocessing
Data preprocessing is a huge part of any ML project, even a relatively simple one like the one I want to build in the future.
I feel like there are three main preprocessing tasks I want to be able to do in Java:
handling missing values, converting categorical data to numerical data, and normalizing or standardizing data.
Handling Missing Values
Remove: Simply remove rows with missing values. This is straightforward but can lead to loss of data. I probably don’t want to go this route.
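Impute: Replace missing values with a specific value, like the mean, median, or mode of the column. This seems valid enough. (My imputation example below uses the row mean for simplicity; in practice, imputing per column, i.e. per feature, is the more common choice.)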
Predict: Use algorithms or models to predict and fill missing values based on other columns. This just leads to complexity that we don’t have time for.
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Demonstrates the "remove" strategy for missing values: every row that contains
 * at least one {@code null} entry is dropped from the data set.
 */
public class DataPreprocessing {
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>(); // Making a sample 2d array
        data.add(new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)));
        data.add(new ArrayList<>(Arrays.asList(4.0, null, 6.0)));
        data.add(new ArrayList<>(Arrays.asList(7.0, 8.0, 9.0)));
        data.add(new ArrayList<>(Arrays.asList(10.0, 11.0, null)));
        // Drop every row that has at least one null cell.
        removeRowsWithMissing(data);
        System.out.println("Data after removing rows with missing values or null sum: " + data);
    }

    /**
     * Removes, in place, every row that is {@code null} or contains a {@code null} value.
     *
     * @param data the rows to filter; modified in place
     * @return {@code true} if at least one row was removed
     */
    public static boolean removeRowsWithMissing(ArrayList<ArrayList<Double>> data) {
        // A missing row is treated the same as a row with a missing cell instead of
        // throwing a NullPointerException from row.contains(...).
        return data.removeIf(row -> row == null || row.contains(null));
    }
}
// Run the missing-value removal demo defined above.
DataPreprocessing.main(null);
Data after removing rows with missing values or null sum: [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Demonstrates the "impute" strategy for missing values: each {@code null} in a row is
 * replaced by the mean of that row's non-null values.
 */
public class DataPreprocessing {
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)));
        data.add(new ArrayList<>(Arrays.asList(4.0, null, 6.0)));
        data.add(new ArrayList<>(Arrays.asList(7.0, 8.0, 9.0)));
        data.add(new ArrayList<>(Arrays.asList(10.0, 11.0, null)));
        data.forEach(DataPreprocessing::imputeMissingWithRowMean);
        System.out.println("Data after imputing missing values with row mean: " + data);
    }

    /**
     * Replaces, in place, every {@code null} entry of {@code row} with the arithmetic mean
     * of the row's non-null entries.
     *
     * <p>If the row has no non-null entries (including an empty row), there is nothing to
     * average. The row is then left unchanged rather than being filled with NaN from a
     * 0/0 division.
     *
     * @param row the values to impute; modified in place
     */
    public static void imputeMissingWithRowMean(ArrayList<Double> row) {
        double sum = 0;
        int count = 0;
        for (Double value : row) {
            if (value != null) {
                sum += value;
                count++;
            }
        }
        if (count == 0) {
            return; // no observed values to impute from
        }
        // Boxed once so the conditional below stays a reference expression (no unboxing of null).
        final Double fill = sum / count;
        row.replaceAll(value -> value == null ? fill : value);
    }
}
// Run the row-mean imputation demo defined above.
DataPreprocessing.main(null);
Data after imputing missing values with row mean: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 10.5]]
Categorical Data to Numerical
One-Hot Encoding: Convert each category value into a new column and assign a 1 or 0 (True/False) value. This is what I normally do.
Label Encoding: Assign each category a unique integer. For instance, ‘Red’ might be 1, ‘Blue’ 2, and ‘Green’ 3.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * Demo of one-hot encoding: every distinct category string gets its own 0/1 column.
 */
public class OneHotEncoding {
    public static void main(String[] args) {
        ArrayList<ArrayList<String>> colors = new ArrayList<>();
        for (String color : new String[] {"Red", "Blue", "Green", "Red"}) {
            colors.add(new ArrayList<>(Arrays.asList(color)));
        }
        System.out.println("Data after One-Hot Encoding: " + oneHotEncode(colors));
    }

    /**
     * Encodes each row as a 0/1 vector with one slot per distinct category seen anywhere in
     * {@code data}. A slot is 1 when the row contains that category.
     *
     * <p>Slots are numbered in first-seen order, scanning rows top to bottom and values left
     * to right. NOTE(review): a single vocabulary is shared by all columns, so equal strings
     * in different columns map to the same slot. For multi-column data, consider encoding
     * each column separately.
     *
     * @param data rows of category strings
     * @return one encoded row per input row, each with length equal to the number of distinct categories
     */
    public static ArrayList<ArrayList<Integer>> oneHotEncode(ArrayList<ArrayList<String>> data) {
        Map<String, Integer> columnOf = new HashMap<>();
        for (ArrayList<String> row : data) {
            for (String category : row) {
                // size() before insertion is exactly the next unused slot number.
                columnOf.putIfAbsent(category, columnOf.size());
            }
        }
        int width = columnOf.size();
        ArrayList<ArrayList<Integer>> result = new ArrayList<>();
        for (ArrayList<String> row : data) {
            Integer[] bits = new Integer[width];
            Arrays.fill(bits, 0);
            for (String category : row) {
                bits[columnOf.get(category)] = 1;
            }
            result.add(new ArrayList<>(Arrays.asList(bits)));
        }
        return result;
    }
}
// Run the one-hot encoding demo defined above.
OneHotEncoding.main(null);
Data after One-Hot Encoding: [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * Demo of label encoding: every distinct category string is replaced by an integer code.
 */
public class LabelEncoding {
    public static void main(String[] args) {
        ArrayList<ArrayList<String>> colors = new ArrayList<>();
        for (String color : new String[] {"Red", "Blue", "Green", "Red"}) {
            colors.add(new ArrayList<>(Arrays.asList(color)));
        }
        System.out.println("Data after Label Encoding: " + labelEncode(colors));
    }

    /**
     * Replaces each category string with an integer code. Codes start at 1 and are assigned
     * in first-seen order, scanning rows top to bottom and values left to right. One code
     * table is shared by all columns.
     *
     * @param data rows of category strings
     * @return rows of the same shape holding the integer codes
     */
    public static ArrayList<ArrayList<Integer>> labelEncode(ArrayList<ArrayList<String>> data) {
        Map<String, Integer> codeOf = new HashMap<>();
        for (ArrayList<String> row : data) {
            for (String category : row) {
                // size() + 1 before insertion is the next unused 1-based code.
                codeOf.putIfAbsent(category, codeOf.size() + 1);
            }
        }
        ArrayList<ArrayList<Integer>> result = new ArrayList<>(data.size());
        for (int r = 0; r < data.size(); r++) {
            ArrayList<String> row = data.get(r);
            ArrayList<Integer> codes = new ArrayList<>(row.size());
            for (int c = 0; c < row.size(); c++) {
                codes.add(codeOf.get(row.get(c)));
            }
            result.add(codes);
        }
        return result;
    }
}
// Run the label encoding demo defined above.
LabelEncoding.main(null);
Data after Label Encoding: [[1], [2], [3], [1]]
Normalizing or Standardizing Data
I haven’t normalized much in the past, as I just hoped the model’s weights would adjust for the discrepancies.
Normalizing: This process scales the data between 0 and 1. The formula for normalization is: normalized value = (value − min)/(max − min)
Standardization: This process scales the data based on the mean (μ) and standard deviation (σ) so that the new data has a mean of 0 and a standard deviation of 1. The formula for standardization is: standardized value = (value - μ)/σ
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Demo of min-max normalization applied independently to each row.
 */
public class RowNormalization {
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(10.0, 20.0, 30.0)));
        data.add(new ArrayList<>(Arrays.asList(40.0, 50.0, 60.0)));
        data.add(new ArrayList<>(Arrays.asList(70.0, 80.0, 90.0)));
        ArrayList<ArrayList<Double>> normalizedData = normalizeByRow(data);
        System.out.println("Row Normalized Data: " + normalizedData);
    }

    /**
     * Scales every row to the range [0, 1] with {@code (value - min) / (max - min)}, where
     * min and max are taken from that row only.
     *
     * <p>A row whose values are all equal has no spread. Each of its values maps to 0.0
     * instead of NaN.
     *
     * @param data the rows to normalize; not modified
     * @return a new list of normalized rows, one per input row
     */
    public static ArrayList<ArrayList<Double>> normalizeByRow(ArrayList<ArrayList<Double>> data) {
        ArrayList<ArrayList<Double>> normalizedData = new ArrayList<>(data.size());
        for (ArrayList<Double> row : data) {
            // Infinities are the correct identities for min/max. Double.MIN_VALUE is the
            // smallest *positive* double, so using it for max broke rows of negative values.
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double value : row) {
                if (value < min) min = value;
                if (value > max) max = value;
            }
            double range = max - min;
            ArrayList<Double> normalizedRow = new ArrayList<>(row.size());
            for (double value : row) {
                normalizedRow.add(range == 0 ? 0.0 : (value - min) / range);
            }
            normalizedData.add(normalizedRow);
        }
        return normalizedData;
    }
}
// Run the row min-max normalization demo defined above.
RowNormalization.main(null);
Row Normalized Data: [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Demo of z-score standardization applied independently to each row.
 */
public class RowStandardization {
    public static void main(String[] args) {
        ArrayList<ArrayList<Double>> data = new ArrayList<>();
        data.add(new ArrayList<>(Arrays.asList(10.0, 20.0, 30.0)));
        data.add(new ArrayList<>(Arrays.asList(40.0, 50.0, 60.0)));
        data.add(new ArrayList<>(Arrays.asList(70.0, 80.0, 90.0)));
        ArrayList<ArrayList<Double>> standardizedData = standardizeByRow(data);
        System.out.println("Row Standardized Data: " + standardizedData);
    }

    /**
     * Standardizes every row with {@code (value - mean) / stdDev}. The mean and the
     * population standard deviation (divide by n) are computed from that row only.
     *
     * <p>A row whose values are all equal has a standard deviation of 0. Each of its values
     * maps to 0.0 instead of NaN.
     *
     * @param data the rows to standardize; not modified
     * @return a new list of standardized rows, one per input row
     */
    public static ArrayList<ArrayList<Double>> standardizeByRow(ArrayList<ArrayList<Double>> data) {
        ArrayList<ArrayList<Double>> standardizedData = new ArrayList<>(data.size());
        for (ArrayList<Double> row : data) {
            int n = row.size();

            double sum = 0;
            for (double value : row) {
                sum += value;
            }
            double mean = sum / n;

            double squaredDeviations = 0;
            for (double value : row) {
                squaredDeviations += Math.pow(value - mean, 2);
            }
            double stdDev = Math.sqrt(squaredDeviations / n);

            ArrayList<Double> standardizedRow = new ArrayList<>(n);
            for (double value : row) {
                // A zero spread means every value equals the mean, so its z-score is 0.
                standardizedRow.add(stdDev == 0 ? 0.0 : (value - mean) / stdDev);
            }
            standardizedData.add(standardizedRow);
        }
        return standardizedData;
    }
}
// Run the row standardization demo defined above.
RowStandardization.main(null);
Row Standardized Data: [[-1.224744871391589, 0.0, 1.224744871391589], [-1.224744871391589, 0.0, 1.224744871391589], [-1.224744871391589, 0.0, 1.224744871391589]]