first commit
This commit is contained in:
@@ -0,0 +1,192 @@
|
||||
package de.unidue.ltl.escrito.examples.local.models;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URISyntaxException;
|
||||
import java.net.URL;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.uima.UimaContext;
|
||||
import org.apache.uima.collection.CollectionException;
|
||||
import org.apache.uima.collection.CollectionReaderDescription;
|
||||
import org.apache.uima.fit.component.JCasCollectionReader_ImplBase;
|
||||
import org.apache.uima.fit.descriptor.ConfigurationParameter;
|
||||
import org.apache.uima.fit.factory.CollectionReaderFactory;
|
||||
import org.apache.uima.fit.pipeline.JCasIterable;
|
||||
import org.apache.uima.jcas.JCas;
|
||||
import org.apache.uima.jcas.cas.StringArray;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.apache.uima.util.Progress;
|
||||
import org.apache.uima.util.ProgressImpl;
|
||||
import org.dkpro.tc.api.type.TextClassificationOutcome;
|
||||
import org.dkpro.tc.api.type.TextClassificationTarget;
|
||||
|
||||
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
|
||||
import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils;
|
||||
import de.unidue.ltl.escrito.core.types.LearnerAnswerWithReferenceAnswer;
|
||||
import de.unidue.ltl.escrito.io.generic.GenericDatasetItem;
|
||||
import de.unidue.ltl.escrito.io.generic.GenericDatasetReader;
|
||||
import de.unidue.ltl.escrito.io.util.Utils;
|
||||
|
||||
public class Reader extends JCasCollectionReader_ImplBase{
|
||||
|
||||
|
||||
public static final String PARAM_INPUT_FILE = "InputFile";
|
||||
@ConfigurationParameter(name = PARAM_INPUT_FILE, mandatory = true)
|
||||
protected String inputFileString;
|
||||
protected URL inputFileURL;
|
||||
|
||||
public static final String PARAM_TARGET_ANSWER_PREFIX = "TargetAnswerPrefix";
|
||||
@ConfigurationParameter(name = PARAM_TARGET_ANSWER_PREFIX, mandatory = false, defaultValue = "TA")
|
||||
private String targetAnswerPrefix;
|
||||
|
||||
public static final String PARAM_CORPUSNAME = "corpusName";
|
||||
@ConfigurationParameter(name = PARAM_CORPUSNAME, mandatory = true)
|
||||
protected String corpusName;
|
||||
|
||||
protected int currentIndex;
|
||||
|
||||
protected Queue<GenericDatasetItem> items;
|
||||
|
||||
private Map<String, String> questions;
|
||||
private Map<String, String> targetAnswers;
|
||||
|
||||
private Set<String> promptAnswerIds;
|
||||
@Override
|
||||
public void initialize(UimaContext aContext)
|
||||
throws ResourceInitializationException
|
||||
{
|
||||
items = new LinkedList<GenericDatasetItem>();
|
||||
questions = new HashMap<String, String>();
|
||||
targetAnswers = new HashMap<String, String>();
|
||||
promptAnswerIds = new HashSet<String>();
|
||||
try {
|
||||
inputFileURL = ResourceUtils.resolveLocation(inputFileString, this, aContext);
|
||||
BufferedReader reader = new BufferedReader(
|
||||
new InputStreamReader(
|
||||
inputFileURL.openStream(),
|
||||
"UTF-16"
|
||||
)
|
||||
);
|
||||
String nextLine;
|
||||
int lineCounter = 1;
|
||||
while ((nextLine = reader.readLine()) != null) {
|
||||
// System.out.println("line: "+nextLine);
|
||||
String[] nextItem = nextLine.split("\t");
|
||||
String promptId = null;
|
||||
String answerId = null;
|
||||
String text = null;
|
||||
String score = "-1";
|
||||
// System.out.println(nextItem.length);
|
||||
|
||||
if (nextItem.length>=4) {
|
||||
GenericDatasetItem newItem = null ;
|
||||
promptId = nextItem[0];
|
||||
answerId = nextItem[1];
|
||||
text = nextItem[2];
|
||||
score = nextItem[3];
|
||||
text = Utils.cleanString(text);
|
||||
int counter = 1;
|
||||
for (int i = 4; i< nextItem.length; i++){
|
||||
targetAnswers.put(promptId+"_"+counter, Utils.cleanString(nextItem[i]));
|
||||
counter++;
|
||||
}
|
||||
newItem = new GenericDatasetItem(promptId, answerId, text, score, promptId);
|
||||
items.add(newItem);
|
||||
}
|
||||
else {
|
||||
System.out.println("Could not read lineNumber: " + lineCounter + ", " + nextItem +" item length is: " +nextItem.length);
|
||||
}
|
||||
lineCounter++;
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
throw new ResourceInitializationException(e);
|
||||
}
|
||||
currentIndex = 0;
|
||||
if (!targetAnswers.isEmpty()){
|
||||
Utils.preprocessConnectedTexts(targetAnswers, corpusName, targetAnswerPrefix, "de");
|
||||
}
|
||||
System.out.println("read "+items.size()+" items.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext()
|
||||
throws IOException
|
||||
{
|
||||
return !items.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getNext(JCas jcas)
|
||||
throws IOException, CollectionException
|
||||
{
|
||||
GenericDatasetItem item = items.poll();
|
||||
getLogger().debug(item);
|
||||
String itemId = String.valueOf(item.getPromptId()+"_"+item.getAnswerId());
|
||||
try
|
||||
{
|
||||
jcas.setDocumentLanguage("de");
|
||||
jcas.setDocumentText(item.getText());
|
||||
DocumentMetaData dmd = DocumentMetaData.create(jcas);
|
||||
dmd.setDocumentId(itemId);
|
||||
dmd.setDocumentTitle(item.getText());
|
||||
dmd.setDocumentUri(inputFileURL.toURI().toString());
|
||||
dmd.setCollectionId(itemId);
|
||||
}
|
||||
|
||||
catch (URISyntaxException e) {
|
||||
throw new CollectionException(e);
|
||||
}
|
||||
|
||||
LearnerAnswerWithReferenceAnswer learnerAnswer = new LearnerAnswerWithReferenceAnswer(jcas, 0, jcas.getDocumentText().length());
|
||||
learnerAnswer.setPromptId(item.getPromptId());
|
||||
StringArray ids = new StringArray(jcas, targetAnswers.size());
|
||||
// We only have one exactly target answer per learner, so we use the same id as for the prompt
|
||||
int counter = 0;
|
||||
for (String taId : targetAnswers.keySet()) {
|
||||
ids.set(counter, String.valueOf(taId));
|
||||
counter++;
|
||||
}
|
||||
learnerAnswer.setReferenceAnswerIds(ids);
|
||||
learnerAnswer.addToIndexes();
|
||||
|
||||
TextClassificationTarget unit = new TextClassificationTarget(jcas, 0, jcas.getDocumentText().length());
|
||||
// will add the token content as a suffix to the ID of this unit
|
||||
unit.setSuffix(itemId);
|
||||
unit.addToIndexes();
|
||||
TextClassificationOutcome outcome = new TextClassificationOutcome(jcas, 0, jcas.getDocumentText().length());
|
||||
outcome.setOutcome(item.getGrade());
|
||||
outcome.addToIndexes();
|
||||
currentIndex++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Progress[] getProgress()
|
||||
{
|
||||
return new Progress[] { new ProgressImpl(currentIndex, currentIndex, Progress.ENTITIES) };
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws ResourceInitializationException {
|
||||
CollectionReaderDescription reader = CollectionReaderFactory.createReaderDescription(
|
||||
Reader.class,
|
||||
Reader.PARAM_INPUT_FILE, "/Users/andrea/dkpro/datasets/KatharinaFleig/Inter_g_Lernerantworten_Original_N=15_w4Refs.tsv",
|
||||
Reader.PARAM_TARGET_ANSWER_PREFIX, "TA",
|
||||
Reader.PARAM_CORPUSNAME, "Mitose"
|
||||
);
|
||||
int i=0;
|
||||
for (JCas jcas : new JCasIterable(reader)) {
|
||||
System.out.println(jcas.getDocumentText());
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package de.unidue.ltl.escrito.examples.local.models;
|
||||
|
||||
|
||||
public class RunModelTraining {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
String pathToTrainConfigCSVFile = "/home/felix/Documents/work/hiwi/iwm-tuebingen/escrito-stuff/model_training/config_files/config.csv";
|
||||
String pathToDatasets = "/home/felix/Documents/work/hiwi/iwm-tuebingen/escrito-stuff/dkpro_target/datasets/N=704_10Ref_ChatGPT_augmented";
|
||||
|
||||
TrainConfig[] trainConfigs = TrainConfig.createTrainConfigsFromCSV(pathToTrainConfigCSVFile);
|
||||
|
||||
/*
|
||||
for (TrainConfig trainCfg : trainConfigs) {
|
||||
System.out.println(trainCfg);
|
||||
}
|
||||
*/
|
||||
|
||||
TrainAndSaveModel modelTrainer = new TrainAndSaveModel(pathToDatasets);
|
||||
for (TrainConfig trainCfg : trainConfigs) {
|
||||
modelTrainer.runModelTraining(trainCfg);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package de.unidue.ltl.escrito.examples.local.models;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.uima.analysis_engine.AnalysisEngine;
|
||||
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
||||
import org.apache.uima.fit.factory.JCasFactory;
|
||||
import org.apache.uima.fit.util.JCasUtil;
|
||||
import org.apache.uima.jcas.JCas;
|
||||
import org.dkpro.tc.api.type.TextClassificationOutcome;
|
||||
import org.dkpro.tc.ml.model.PreTrainedModelProviderUnitMode;
|
||||
|
||||
import de.unidue.ltl.escrito.core.types.LearnerAnswer;
|
||||
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
|
||||
|
||||
public class StoredModelPredictor extends Experiments_ImplBase{
|
||||
public static final String LANG_CODE = "de";
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
//String experimentName = "Me_n-SMO-C-1.0-NormalizedPolyKernel-E-3.0";
|
||||
File modelOutputFolder = new File(TrainAndSaveModel.OUTPUT_DIR, args[0]);
|
||||
//String exampleAnswer = "Die Chromosomen, welche aus zwei Chromatiden bestehen, bewegen sich zum zentralen Äquator."; // GT = 1
|
||||
|
||||
//for debugging:
|
||||
System.out.println("Total number of arguments passed: " + args.length);
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
System.out.println("Argument " + i + ": " + args[i]);
|
||||
}
|
||||
|
||||
try {
|
||||
documentLoadModelSingleLabel(LANG_CODE, modelOutputFolder, args[1]);
|
||||
} catch (Exception e) {
|
||||
System.out.println("Exception while processing answer. Please verify the following:");
|
||||
System.out.println("--> a) The correct name of an existing directory in " + TrainAndSaveModel.OUTPUT_DIR
|
||||
+ " is passed as first argument.");
|
||||
System.out.println("--> b) A string representing the learner's answer to be classified is passed as second argument.");
|
||||
System.out.println("--> c) Both passed arguments are wrapped in double quotes: \"...\".");
|
||||
e.printStackTrace();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// from de.unidue.ltl.escrito.examples.io.StoredModelApplicationExample
|
||||
private static void documentLoadModelSingleLabel(String languageCode, File modelOutputFile, String exampleAnswer)
|
||||
throws Exception
|
||||
{
|
||||
|
||||
//System.out.println("Path to model: " + modelOutputFile.getAbsolutePath());
|
||||
AnalysisEngine preprocessing = AnalysisEngineFactory.createEngine(Experiments_ImplBase.getPreprocessing(languageCode));
|
||||
AnalysisEngine tcAnno = AnalysisEngineFactory.createEngine(PreTrainedModelProviderUnitMode.class,
|
||||
PreTrainedModelProviderUnitMode.PARAM_NAME_TARGET_ANNOTATION, LearnerAnswer.class,
|
||||
// Achtung: It seems like you MAY NOT use the class TextClassificationTarget (as we do in the reader)
|
||||
// to indicate the unit to be considered
|
||||
// as far as I can see, a TextClassifcationTarget is produced by the classifier and we only want to have one in the end!
|
||||
PreTrainedModelProviderUnitMode.PARAM_TC_MODEL_LOCATION, modelOutputFile.getAbsolutePath());
|
||||
|
||||
JCas jcas = JCasFactory.createJCas();
|
||||
jcas.setDocumentText(exampleAnswer);
|
||||
jcas.setDocumentLanguage(languageCode);
|
||||
|
||||
LearnerAnswer unit = new LearnerAnswer(jcas, 0, jcas.getDocumentText().length());
|
||||
unit.addToIndexes();
|
||||
|
||||
|
||||
// redo the preprocessing
|
||||
preprocessing.process(jcas);
|
||||
tcAnno.process(jcas);
|
||||
|
||||
// redo the processing done by the classifier
|
||||
|
||||
List<TextClassificationOutcome> outcomes = new ArrayList<>(
|
||||
JCasUtil.select(jcas, TextClassificationOutcome.class));
|
||||
//System.out.println(jcas.getDocumentText()+"\nOutcome: "+outcomes.get(0).getOutcome());
|
||||
System.out.println(outcomes.get(0).getOutcome()); // only print (binary) outcome
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package de.unidue.ltl.escrito.examples.local.models;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.uima.collection.CollectionReaderDescription;
|
||||
import org.apache.uima.fit.factory.CollectionReaderFactory;
|
||||
import org.apache.uima.resource.ResourceInitializationException;
|
||||
import org.dkpro.lab.Lab;
|
||||
import org.dkpro.lab.task.Dimension;
|
||||
import org.dkpro.lab.task.ParameterSpace;
|
||||
import org.dkpro.lab.task.BatchTask.ExecutionPolicy;
|
||||
import org.dkpro.tc.api.features.TcFeatureFactory;
|
||||
import org.dkpro.tc.api.features.TcFeatureSet;
|
||||
import org.dkpro.tc.features.ngram.CharacterNGram;
|
||||
import org.dkpro.tc.features.ngram.WordNGram;
|
||||
import org.dkpro.tc.ml.experiment.ExperimentSaveModel;
|
||||
import org.dkpro.tc.ml.weka.WekaAdapter;
|
||||
|
||||
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
|
||||
import weka.classifiers.functions.SMO;
|
||||
|
||||
public class TrainAndSaveModel extends Experiments_ImplBase{
|
||||
// some constants
|
||||
public static final String TARGET_ANSWER_PREFIX = "TA";
|
||||
public static final String CORPUS_NAME = "Mitose";
|
||||
public static final String LANG_CODE = "de";
|
||||
public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/";
|
||||
|
||||
private String pathToDatasets;
|
||||
private String outputDir;
|
||||
|
||||
TrainAndSaveModel(String pathToDatasets, String outputDir) {
|
||||
this.pathToDatasets = pathToDatasets;
|
||||
this.outputDir = outputDir;
|
||||
}
|
||||
|
||||
TrainAndSaveModel(String pathToDatasets) {
|
||||
this.pathToDatasets = pathToDatasets;
|
||||
this.outputDir = OUTPUT_DIR;
|
||||
}
|
||||
|
||||
public void runModelTraining(TrainConfig trainConfig) {
|
||||
File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName);
|
||||
String[] kernelName = trainConfig.kernelClass.split("\\.");
|
||||
String kernelClass = kernelName[kernelName.length - 1]; // strip package prefix from name of kernel class
|
||||
String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", Double.toString(trainConfig.C),
|
||||
kernelClass, "E", Double.toString(trainConfig.E));
|
||||
File modelFolder = new File(this.outputDir, experimentName);
|
||||
try {
|
||||
// create model folder if not existent
|
||||
if(!modelFolder.exists()) {
|
||||
modelFolder.mkdirs();
|
||||
}
|
||||
// pass full kernel class (including package prefix) to getParameterSpaceSigleLabel()
|
||||
ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(), trainConfig.prompt, trainConfig.nReferences,
|
||||
CORPUS_NAME, trainConfig.nGrams, trainConfig.kernelClass, Double.toString(trainConfig.C), Double.toString(trainConfig.E));
|
||||
documentWriteModel(pspace, modelFolder, experimentName);
|
||||
System.out.println("Trained model stroed at: " + modelFolder.getAbsolutePath());
|
||||
} catch (ResourceInitializationException e) {
|
||||
System.out.println("Error while building parameter spcae.");
|
||||
e.printStackTrace();
|
||||
} catch (Exception e) {
|
||||
System.out.println("Error while training/saving model.");
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// from de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample
|
||||
private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder, String experimentName)
|
||||
throws Exception
|
||||
{
|
||||
ExperimentSaveModel batch;
|
||||
batch = new ExperimentSaveModel();
|
||||
batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE));
|
||||
batch.setParameterSpace(paramSpace);
|
||||
batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN);
|
||||
batch.setExperimentName(experimentName);
|
||||
batch.setOutputFolder(modelFolder);
|
||||
Lab.getInstance().run(batch);
|
||||
}
|
||||
|
||||
// from de.unidue.ltl.escrito.examples.local.gridsearch
|
||||
public static ParameterSpace getParameterSpaceSingleLabel(
|
||||
String dataPath,
|
||||
String promptName,
|
||||
int numTargetAnswers,
|
||||
String corpusName,
|
||||
int nGrams,
|
||||
String kernelClass,
|
||||
String C,
|
||||
String E
|
||||
)
|
||||
throws ResourceInitializationException {
|
||||
System.out.println("Starting to build parameter space ...");
|
||||
Map<String, Object> dimReaders = new HashMap<String, Object>();
|
||||
CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription(
|
||||
Reader.class,
|
||||
Reader.PARAM_INPUT_FILE, dataPath,
|
||||
Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX,
|
||||
Reader.PARAM_CORPUSNAME, corpusName
|
||||
);
|
||||
dimReaders.put(DIM_READER_TRAIN, readerTrain);
|
||||
|
||||
Dimension<String> learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL);
|
||||
Map<String, Object> config = new HashMap<>();
|
||||
// have a look at javadoc for SMO class to see valid parameters for SMO class and PolyKernel class
|
||||
config.put(DIM_CLASSIFICATION_ARGS, new Object[] { new WekaAdapter(), SMO.class.getName(), "-C", C,
|
||||
"-K", kernelClass + " " + "-C -1 -E " + E});
|
||||
config.put(DIM_DATA_WRITER, new WekaAdapter().getDataWriterClass());
|
||||
config.put(DIM_FEATURE_USE_SPARSE, new WekaAdapter().useSparseFeatures());
|
||||
Dimension<Map<String, Object>> learningsArgsDims = Dimension.createBundle("config", config);
|
||||
|
||||
// copied form BaseExperimentCV.java
|
||||
Dimension<TcFeatureSet> dimFeatureSets = Dimension.create(
|
||||
DIM_FEATURE_SET,
|
||||
new TcFeatureSet(
|
||||
/*
|
||||
TcFeatureFactory.create(
|
||||
PairwiseFeatureWrapper.class,
|
||||
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(),
|
||||
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
|
||||
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
|
||||
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
|
||||
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
|
||||
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
|
||||
),
|
||||
TcFeatureFactory.create(
|
||||
PairwiseFeatureWrapper.class,
|
||||
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(),
|
||||
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2",
|
||||
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5",
|
||||
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
|
||||
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
|
||||
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
|
||||
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
|
||||
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
|
||||
),
|
||||
*/
|
||||
TcFeatureFactory.create(
|
||||
WordNGram.class,
|
||||
WordNGram.PARAM_NGRAM_MIN_N, 1,
|
||||
WordNGram.PARAM_NGRAM_MAX_N, 3,
|
||||
WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams
|
||||
),
|
||||
TcFeatureFactory.create(
|
||||
CharacterNGram.class,
|
||||
CharacterNGram.PARAM_NGRAM_MIN_N, 2,
|
||||
CharacterNGram.PARAM_NGRAM_MAX_N, 5,
|
||||
CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams
|
||||
)
|
||||
)
|
||||
);
|
||||
// create parameter space
|
||||
ParameterSpace pSpace = null;
|
||||
pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders),
|
||||
learningDims,
|
||||
Dimension.create(DIM_FEATURE_MODE, FM_UNIT),
|
||||
dimFeatureSets,
|
||||
learningsArgsDims);
|
||||
System.out.println("Finished building parameter space.");
|
||||
return pSpace;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package de.unidue.ltl.escrito.examples.local.models;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
 * One training configuration, i.e. one data row of the configuration CSV file.
 *
 * <p>The fields correspond, in declaration order, to the CSV columns: datasetName, prompt,
 * nGrams, nReferences, kernelClass, C, E.
 *
 * <p>NOTE(review): the original comment stated that this one-to-one mapping is required so
 * that reflection can be used in TrainAndSaveModel.java - confirm whether that is still
 * relevant.
 */
public class TrainConfig {

    /** Minimum number of comma-separated columns a configuration row must contain. */
    private static final int EXPECTED_COLUMNS = 7;

    String datasetName;
    String prompt;
    int nGrams;
    int nReferences;
    String kernelClass;
    double C;
    double E;

    /**
     * @param datasetName name of the dataset file
     * @param prompt      prompt name
     * @param nGrams      n-gram setting of this configuration
     * @param nReferences number of reference answers
     * @param kernelClass kernel class name
     * @param C           value of the C hyperparameter
     * @param E           value of the E hyperparameter
     */
    public TrainConfig(String datasetName, String prompt, int nGrams, int nReferences,
            String kernelClass, double C, double E) {
        this.datasetName = datasetName;
        this.prompt = prompt;
        this.nGrams = nGrams;
        this.nReferences = nReferences;
        this.kernelClass = kernelClass;
        this.C = C;
        this.E = E;
    }

    /**
     * @return all fields in the form {@code [dataset: ...; prompt: ...; ...; E: ...]}
     */
    @Override
    public String toString() {
        return "[dataset: " + this.datasetName + "; prompt: " + this.prompt + "; nGrams: "
                + this.nGrams + "; nReferences: " + this.nReferences + "; kernelClass: "
                + this.kernelClass + "; C: " + C + "; E: " + E + "]";
    }

    /**
     * Reads training configurations from a comma-separated file.
     *
     * <p>The first line is treated as header and skipped. Blank lines and lines whose first
     * non-whitespace character is {@code #} are ignored. Surrounding whitespace of every
     * field is removed. Columns beyond the seventh are ignored.
     *
     * <p>If the file cannot be read, an error is printed and the configurations parsed up
     * to that point (possibly none) are returned.
     *
     * @param pathToConfigCSVFile path of the CSV file
     * @return the parsed configurations in file order; never {@code null}
     * @throws IllegalArgumentException if a data row has fewer than seven columns or a
     *                                  numeric column cannot be parsed
     */
    public static TrainConfig[] createTrainConfigsFromCSV(String pathToConfigCSVFile) {
        ArrayList<TrainConfig> trainConfigs = new ArrayList<>();

        try (BufferedReader br = new BufferedReader(new FileReader(pathToConfigCSVFile))) {
            // read first line, which is the header
            br.readLine();
            int lineNumber = 1;

            String line;
            while ((line = br.readLine()) != null) {
                lineNumber++;
                String content = line.trim();
                if (content.isEmpty() || content.startsWith("#")) {
                    continue; // ignore blank lines and lines starting with #
                }
                trainConfigs.add(parseRow(content, pathToConfigCSVFile, lineNumber));
            }
        } catch (IOException e) {
            System.out.println("Error while reading " + pathToConfigCSVFile);
            e.printStackTrace();
        }
        // return array containing TrainConfig objects
        return trainConfigs.toArray(new TrainConfig[trainConfigs.size()]);
    }

    /** Converts one non-empty, non-comment CSV row into a configuration object. */
    private static TrainConfig parseRow(String row, String sourceFile, int lineNumber) {
        String[] configParts = row.split(",");
        if (configParts.length < EXPECTED_COLUMNS) {
            throw new IllegalArgumentException("Malformed training configuration in " + sourceFile
                    + " at line " + lineNumber + ": expected at least " + EXPECTED_COLUMNS
                    + " comma-separated columns but found " + configParts.length);
        }
        try {
            return new TrainConfig(
                    configParts[0].trim(),                      // datasetName
                    configParts[1].trim(),                      // prompt
                    Integer.parseInt(configParts[2].trim()),    // nGrams
                    Integer.parseInt(configParts[3].trim()),    // nReferences
                    configParts[4].trim(),                      // kernelClass
                    Double.parseDouble(configParts[5].trim()),  // C
                    Double.parseDouble(configParts[6].trim())   // E
            );
        } catch (NumberFormatException e) {
            // add file and line so the offending row can be found; keep the cause
            throw new IllegalArgumentException("Malformed number in " + sourceFile
                    + " at line " + lineNumber + ": " + e.getMessage(), e);
        }
    }

}
|
||||
Reference in New Issue
Block a user