escrito-docker/local/models/TrainConfig.java

78 lines
2.2 KiB
Java
Raw Permalink Normal View History

2023-10-07 15:11:50 +02:00
package de.unidue.ltl.escrito.examples.local.models;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
/**
 * One training configuration, i.e. one data row of the training-configuration
 * CSV file read by {@link #createTrainConfigsFromCSV(String)}.
 */
public class TrainConfig {
    /*
     * each declared field must correspond exactly to one column in
     * the CSV-file where the training configurations are specified otherwise
     * reflection cannot be used as intended in TrainAndSaveModel.java
     *
     * NOTE(review): the reflection code lives outside this file. If it walks
     * getDeclaredFields() (which also reports static fields), any extra field --
     * even a constant -- would break it; that is why this class declares no other
     * fields and keeps these package-private and non-final. Confirm before changing.
     */
    String datasetName;
    String prompt;
    int nGrams;
    int nReferences;
    String kernelClass;
    double C;
    double E;

    /**
     * Creates a configuration from already-parsed column values.
     *
     * @param datasetName value of the datasetName column
     * @param prompt      value of the prompt column
     * @param nGrams      value of the nGrams column
     * @param nReferences value of the nReferences column
     * @param kernelClass value of the kernelClass column
     * @param C           value of the C column
     * @param E           value of the E column
     */
    public TrainConfig(String datasetName, String prompt, int nGrams, int nReferences,
            String kernelClass, double C, double E) {
        this.datasetName = datasetName;
        this.prompt = prompt;
        this.nGrams = nGrams;
        this.nReferences = nReferences;
        this.kernelClass = kernelClass;
        this.C = C;
        this.E = E;
    }

    /**
     * Returns a one-line, human-readable summary of all column values.
     */
    @Override
    public String toString() {
        return "[dataset: " + this.datasetName + "; prompt: " + this.prompt + "; nGrams: "
                + this.nGrams + "; nReferences: " + this.nReferences + "; kernelClass: "
                + this.kernelClass + "; C: " + C + "; E: " + E + "]";
    }

    /**
     * Reads all training configurations from a comma-separated file.
     *
     * <p>The first line is treated as a header and always skipped. Blank lines and
     * lines whose first non-whitespace character is {@code #} are ignored. Every
     * other line must have at least seven comma-separated columns in the order
     * datasetName, prompt, nGrams, nReferences, kernelClass, C, E. Whitespace
     * around each column is ignored, and so are additional columns. Columns are
     * split on plain commas, so values cannot contain commas (no quoting).
     *
     * <p>If the file cannot be read, an error message and stack trace are printed
     * and the configurations parsed up to that point (possibly none) are returned.
     *
     * @param pathToConfigCSVFile path of the CSV file to read
     * @return the parsed configurations in file order; never {@code null}
     * @throws IllegalArgumentException if a data line has fewer than seven
     *         columns or a numeric column cannot be parsed; the message names
     *         the file and the 1-based line number
     */
    public static TrainConfig[] createTrainConfigsFromCSV(String pathToConfigCSVFile) {
        ArrayList<TrainConfig> trainConfigs = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(pathToConfigCSVFile))) {
            // the first line is the header; read and discard it
            br.readLine();
            int lineNumber = 1;
            String line;
            while ((line = br.readLine()) != null) {
                lineNumber++;
                String trimmed = line.trim();
                // skip blank lines (e.g. a trailing newline) and comment lines
                if (trimmed.isEmpty() || trimmed.startsWith("#")) {
                    continue;
                }
                trainConfigs.add(parseLine(trimmed, lineNumber, pathToConfigCSVFile));
            }
        } catch (IOException e) {
            // best effort: report and fall through with whatever was read so far
            System.out.println("Error while reading " + pathToConfigCSVFile);
            e.printStackTrace();
        }
        return trainConfigs.toArray(new TrainConfig[0]);
    }

    /**
     * Parses one non-blank, non-comment data line into a {@code TrainConfig}.
     *
     * @param line       the data line, already stripped of surrounding whitespace
     * @param lineNumber 1-based line number in the file, used for error messages
     * @param path       path of the file being read, used for error messages
     * @return the configuration described by {@code line}
     * @throws IllegalArgumentException if the line is malformed
     */
    private static TrainConfig parseLine(String line, int lineNumber, String path) {
        // column count must match the number of declared fields (see field comment)
        final int expectedColumns = 7;
        String[] configParts = line.split(",");
        if (configParts.length < expectedColumns) {
            throw new IllegalArgumentException("Expected " + expectedColumns
                    + " columns but found " + configParts.length + " in line " + lineNumber
                    + " of " + path + ": \"" + line + "\"");
        }
        for (int i = 0; i < configParts.length; i++) {
            // Integer.parseInt rejects surrounding whitespace, so normalise every column
            configParts[i] = configParts[i].trim();
        }
        try {
            return new TrainConfig(
                    configParts[0],                       // datasetName
                    configParts[1],                       // prompt
                    Integer.parseInt(configParts[2]),     // nGrams
                    Integer.parseInt(configParts[3]),     // nReferences
                    configParts[4],                       // kernelClass
                    Double.parseDouble(configParts[5]),   // C
                    Double.parseDouble(configParts[6]));  // E
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Malformed numeric value in line " + lineNumber
                    + " of " + path + ": \"" + line + "\"", e);
        }
    }
}