escrito-docker/local/models/TrainAndSaveModel.java

168 lines
7.0 KiB
Java
Raw Permalink Normal View History

2023-10-07 15:11:50 +02:00
package de.unidue.ltl.escrito.examples.local.models;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.lab.Lab;
import org.dkpro.lab.task.Dimension;
import org.dkpro.lab.task.ParameterSpace;
import org.dkpro.lab.task.BatchTask.ExecutionPolicy;
import org.dkpro.tc.api.features.TcFeatureFactory;
import org.dkpro.tc.api.features.TcFeatureSet;
import org.dkpro.tc.features.ngram.CharacterNGram;
import org.dkpro.tc.features.ngram.WordNGram;
import org.dkpro.tc.ml.experiment.ExperimentSaveModel;
import org.dkpro.tc.ml.weka.WekaAdapter;
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
import weka.classifiers.functions.SMO;
/**
 * Trains a Weka SMO classifier on one dataset and saves the resulting model via DKPro TC's
 * {@link ExperimentSaveModel}.
 *
 * <p>Features are word n-grams (n = 1..3) and character n-grams (n = 2..5), each restricted to the
 * top-k entries where k is taken from the training configuration.
 */
public class TrainAndSaveModel extends Experiments_ImplBase {

    /** Prefix that marks target (reference) answers in the input data; handed to the reader. */
    public static final String TARGET_ANSWER_PREFIX = "TA";
    /** Corpus name handed to the reader. */
    public static final String CORPUS_NAME = "Mitose";
    /** Language code used to select the preprocessing pipeline. */
    public static final String LANG_CODE = "de";
    /**
     * Default output directory for trained models.
     * NOTE(review): if the DKPRO_HOME environment variable is unset this evaluates to
     * "null/models/" — confirm the environment always provides it.
     */
    public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/";

    private final String pathToDatasets;
    private final String outputDir;

    /**
     * @param pathToDatasets directory that contains the dataset files
     * @param outputDir      directory under which one sub-folder per trained model is created
     */
    TrainAndSaveModel(String pathToDatasets, String outputDir) {
        this.pathToDatasets = pathToDatasets;
        this.outputDir = outputDir;
    }

    /**
     * Creates a trainer that writes models to {@link #OUTPUT_DIR}.
     *
     * @param pathToDatasets directory that contains the dataset files
     */
    TrainAndSaveModel(String pathToDatasets) {
        this(pathToDatasets, OUTPUT_DIR);
    }

    /**
     * Trains an SMO model for the given configuration and stores it in
     * {@code <outputDir>/<prompt>-SMO-C-<C>-<KernelSimpleName>-E-<E>}.
     *
     * <p>Errors are reported on standard output (with a stack trace) and are not rethrown; the
     * method returns normally in every case.
     *
     * @param trainConfig dataset name, prompt, number of reference answers, top-k n-gram count,
     *                    fully-qualified kernel class, and the SMO complexity {@code C} and kernel
     *                    exponent {@code E}
     */
    public void runModelTraining(TrainConfig trainConfig) {
        File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName);
        // Format C and E once so the experiment name and the classifier arguments always agree.
        String cValue = Double.toString(trainConfig.C);
        String eValue = Double.toString(trainConfig.E);
        // Strip the package prefix from the kernel class to keep the experiment name short.
        String kernelSimpleName =
                trainConfig.kernelClass.substring(trainConfig.kernelClass.lastIndexOf('.') + 1);
        String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", cValue,
                kernelSimpleName, "E", eValue);
        File modelFolder = new File(this.outputDir, experimentName);
        try {
            // isDirectory() instead of exists(): a plain file at this path is not a usable folder.
            // The mkdirs() result is checked so we do not train into a folder that was never created.
            if (!modelFolder.isDirectory() && !modelFolder.mkdirs()) {
                System.out.println("Could not create model folder: " + modelFolder.getAbsolutePath());
                return;
            }
            // The full kernel class name (with package prefix) is passed on because it ends up in
            // SMO's "-K" option in getParameterSpaceSingleLabel().
            ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(),
                    trainConfig.prompt, trainConfig.nReferences, CORPUS_NAME, trainConfig.nGrams,
                    trainConfig.kernelClass, cValue, eValue);
            documentWriteModel(pspace, modelFolder, experimentName);
            System.out.println("Trained model stored at: " + modelFolder.getAbsolutePath());
        } catch (ResourceInitializationException e) {
            System.out.println("Error while building parameter space.");
            e.printStackTrace();
        } catch (Exception e) {
            System.out.println("Error while training/saving model.");
            e.printStackTrace();
        }
    }

    /**
     * Runs an {@link ExperimentSaveModel} batch over {@code paramSpace} and writes its output into
     * {@code modelFolder}. Adapted from
     * de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample.
     *
     * @param paramSpace     parameter space describing reader, features and classifier
     * @param modelFolder    destination folder for the saved model
     * @param experimentName name given to the DKPro Lab experiment
     * @throws Exception anything raised while setting up or running the DKPro Lab batch
     */
    private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder,
            String experimentName) throws Exception {
        ExperimentSaveModel batch = new ExperimentSaveModel();
        batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE));
        batch.setParameterSpace(paramSpace);
        // NOTE(review): RUN_AGAIN presumably forces re-execution instead of reusing earlier results — confirm.
        batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN);
        batch.setExperimentName(experimentName);
        batch.setOutputFolder(modelFolder);
        Lab.getInstance().run(batch);
    }

    /**
     * Builds a single-label DKPro TC parameter space for training a Weka SMO classifier.
     * Adapted from de.unidue.ltl.escrito.examples.local.gridsearch.
     *
     * @param dataPath         path of the input file handed to {@code Reader}
     * @param promptName       prompt name; currently only used by the disabled pairwise features
     * @param numTargetAnswers number of target answers; currently only used by the disabled
     *                         pairwise features
     * @param corpusName       corpus name handed to {@code Reader}
     * @param nGrams           top-k limit for both the word and the character n-gram features
     * @param kernelClass      fully-qualified Weka kernel class name, used for SMO's "-K" option
     * @param C                SMO complexity constant, as a string
     * @param E                value passed as the kernel's "-E" option, as a string
     * @return the assembled parameter space
     * @throws ResourceInitializationException if the reader description cannot be created
     */
    public static ParameterSpace getParameterSpaceSingleLabel(
            String dataPath,
            String promptName,
            int numTargetAnswers,
            String corpusName,
            int nGrams,
            String kernelClass,
            String C,
            String E)
            throws ResourceInitializationException {
        System.out.println("Starting to build parameter space ...");

        Map<String, Object> dimReaders = new HashMap<String, Object>();
        CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription(
                Reader.class,
                Reader.PARAM_INPUT_FILE, dataPath,
                Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX,
                Reader.PARAM_CORPUSNAME, corpusName);
        dimReaders.put(DIM_READER_TRAIN, readerTrain);

        Dimension<String> learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL);

        // One adapter instance serves as the classifier backend and answers the writer/sparsity queries.
        WekaAdapter wekaAdapter = new WekaAdapter();
        Map<String, Object> config = new HashMap<>();
        // See the Javadoc of Weka's SMO and kernel classes for valid options.
        // NOTE(review): "-E" is a PolyKernel option; kernels without it (e.g. RBFKernel) may reject it.
        config.put(DIM_CLASSIFICATION_ARGS, new Object[] { wekaAdapter, SMO.class.getName(), "-C", C,
                "-K", kernelClass + " -C -1 -E " + E });
        config.put(DIM_DATA_WRITER, wekaAdapter.getDataWriterClass());
        config.put(DIM_FEATURE_USE_SPARSE, wekaAdapter.useSparseFeatures());
        Dimension<Map<String, Object>> learningArgsDims = Dimension.createBundle("config", config);

        // Copied from BaseExperimentCV.java.
        // The pairwise-similarity features below are disabled; re-enabling them also requires
        // imports for PairwiseFeatureWrapper, WordOverlapFeatureExtractor and
        // StringSimilarityFeatureExtractor, which this file does not have.
        Dimension<TcFeatureSet> dimFeatureSets = Dimension.create(
                DIM_FEATURE_SET,
                new TcFeatureSet(
                        /*
                        TcFeatureFactory.create(
                            PairwiseFeatureWrapper.class,
                            PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(),
                            PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
                            PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
                            PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
                            PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
                            PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
                        ),
                        TcFeatureFactory.create(
                            PairwiseFeatureWrapper.class,
                            PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(),
                            StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2",
                            StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5",
                            PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
                            PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
                            PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
                            PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
                            PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
                        ),
                        */
                        TcFeatureFactory.create(
                                WordNGram.class,
                                WordNGram.PARAM_NGRAM_MIN_N, 1,
                                WordNGram.PARAM_NGRAM_MAX_N, 3,
                                WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams),
                        TcFeatureFactory.create(
                                CharacterNGram.class,
                                CharacterNGram.PARAM_NGRAM_MIN_N, 2,
                                CharacterNGram.PARAM_NGRAM_MAX_N, 5,
                                CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams)));

        ParameterSpace pSpace = new ParameterSpace(
                Dimension.createBundle("readers", dimReaders),
                learningDims,
                Dimension.create(DIM_FEATURE_MODE, FM_UNIT),
                dimFeatureSets,
                learningArgsDims);
        System.out.println("Finished building parameter space.");
        return pSpace;
    }
}