78 lines
2.2 KiB
Java
78 lines
2.2 KiB
Java
package de.unidue.ltl.escrito.examples.local.models;
|
|
|
|
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
|
|
|
|
/**
 * One training configuration, i.e. one data row of the training-configuration CSV file.
 *
 * <p>{@link #createTrainConfigsFromCSV(String)} maps the CSV columns by position, in the order
 * the fields are declared below.
 */
public class TrainConfig {

    /*
     * Each declared field must correspond exactly to one column in the CSV file where the
     * training configurations are specified; otherwise reflection cannot be used as intended
     * in TrainAndSaveModel.java. For that reason, do not add further fields (not even static
     * constants) to this class unless the CSV format changes accordingly.
     */

    /** Dataset name (CSV column 0). */
    String datasetName;

    /** Prompt identifier (CSV column 1). */
    String prompt;

    /** Value of the nGrams column (CSV column 2). */
    int nGrams;

    /** Value of the nReferences column (CSV column 3). */
    int nReferences;

    /** Kernel class name (CSV column 4); presumably a fully-qualified class name - TODO confirm. */
    String kernelClass;

    /** Value of the C column (CSV column 5). */
    double C;

    /** Value of the E column (CSV column 6). */
    double E;

    /**
     * Creates a configuration from already-parsed column values.
     *
     * @param datasetName dataset name
     * @param prompt prompt identifier
     * @param nGrams value of the nGrams column
     * @param nReferences value of the nReferences column
     * @param kernelClass kernel class name
     * @param C value of the C column
     * @param E value of the E column
     */
    public TrainConfig(String datasetName, String prompt, int nGrams, int nReferences,
            String kernelClass, double C, double E) {
        this.datasetName = datasetName;
        this.prompt = prompt;
        this.nGrams = nGrams;
        this.nReferences = nReferences;
        this.kernelClass = kernelClass;
        this.C = C;
        this.E = E;
    }

    /**
     * Returns a one-line, human-readable summary of all configuration values.
     *
     * @return a string of the form
     *     {@code [dataset: ...; prompt: ...; nGrams: ...; nReferences: ...; kernelClass: ...; C: ...; E: ...]}
     */
    @Override
    public String toString() {
        return "[dataset: " + this.datasetName + "; prompt: " + this.prompt + "; nGrams: "
                + this.nGrams + "; nReferences: " + this.nReferences + "; kernelClass: "
                + this.kernelClass + "; C: " + this.C + "; E: " + this.E + "]";
    }

    /**
     * Reads all training configurations from a CSV file.
     *
     * <p>The first line is treated as a header and skipped. Blank lines and lines whose first
     * non-whitespace character is {@code #} are ignored. Every other line must contain at least
     * seven comma-separated values, in this order:
     * <ol>
     *   <li>datasetName</li>
     *   <li>prompt</li>
     *   <li>nGrams ({@code int})</li>
     *   <li>nReferences ({@code int})</li>
     *   <li>kernelClass</li>
     *   <li>C ({@code double})</li>
     *   <li>E ({@code double})</li>
     * </ol>
     * Surrounding whitespace of each value is trimmed, and columns after the seventh are
     * ignored. Values are split on every comma, so quoted values containing commas are not
     * supported. The file is decoded as UTF-8.
     *
     * @param pathToConfigCSVFile path of the CSV file to read
     * @return the parsed configurations in file order. If the file cannot be read, an error is
     *     printed to {@code System.err} and the configurations parsed before the failure
     *     (possibly none) are returned.
     * @throws IllegalArgumentException if a data line has fewer than seven columns or a numeric
     *     column cannot be parsed; the message names the offending line number
     */
    public static TrainConfig[] createTrainConfigsFromCSV(String pathToConfigCSVFile) {
        // Kept as a local rather than a static constant: a static field would be an extra
        // declared field and break the field-per-column contract described above.
        final int expectedColumns = 7;
        List<TrainConfig> trainConfigs = new ArrayList<>();

        // Use FileInputStream with an explicit charset instead of FileReader, whose charset is
        // the platform default before Java 18. A missing file still surfaces as
        // FileNotFoundException (an IOException), exactly as with FileReader.
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(pathToConfigCSVFile), StandardCharsets.UTF_8))) {
            // The first line is the header.
            br.readLine();
            int lineNumber = 1;

            String line;
            while ((line = br.readLine()) != null) {
                lineNumber++;
                String trimmed = line.trim();
                // Blank lines (e.g. a trailing newline) used to crash with an
                // ArrayIndexOutOfBoundsException; comment lines may be indented.
                if (trimmed.isEmpty() || trimmed.startsWith("#")) {
                    continue;
                }

                // Limit -1 keeps trailing empty values, so a missing last value is reported
                // instead of silently shortening the row.
                String[] configParts = line.split(",", -1);
                if (configParts.length < expectedColumns) {
                    throw new IllegalArgumentException("Expected at least " + expectedColumns
                            + " comma-separated values in line " + lineNumber + " of "
                            + pathToConfigCSVFile + " but found " + configParts.length + ": "
                            + line);
                }
                // Integer.parseInt rejects surrounding whitespace, so trim every value.
                for (int i = 0; i < configParts.length; i++) {
                    configParts[i] = configParts[i].trim();
                }

                try {
                    trainConfigs.add(new TrainConfig(
                            configParts[0],                     // datasetName
                            configParts[1],                     // prompt
                            Integer.parseInt(configParts[2]),   // nGrams
                            Integer.parseInt(configParts[3]),   // nReferences
                            configParts[4],                     // kernelClass
                            Double.parseDouble(configParts[5]), // C
                            Double.parseDouble(configParts[6])  // E
                    ));
                } catch (NumberFormatException e) {
                    throw new IllegalArgumentException("Invalid number in line " + lineNumber
                            + " of " + pathToConfigCSVFile + ": " + line, e);
                }
            }
        } catch (IOException e) {
            // Best-effort: report the problem and return whatever was parsed so far.
            // The message goes to System.err so it stays next to the stack trace.
            System.err.println("Error while reading " + pathToConfigCSVFile);
            e.printStackTrace();
        }

        return trainConfigs.toArray(new TrainConfig[0]);
    }
}
|