168 lines
7.0 KiB
Java
168 lines
7.0 KiB
Java
package de.unidue.ltl.escrito.examples.local.models;
|
|
|
|
import java.io.File;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
|
|
import org.apache.uima.collection.CollectionReaderDescription;
|
|
import org.apache.uima.fit.factory.CollectionReaderFactory;
|
|
import org.apache.uima.resource.ResourceInitializationException;
|
|
import org.dkpro.lab.Lab;
|
|
import org.dkpro.lab.task.Dimension;
|
|
import org.dkpro.lab.task.ParameterSpace;
|
|
import org.dkpro.lab.task.BatchTask.ExecutionPolicy;
|
|
import org.dkpro.tc.api.features.TcFeatureFactory;
|
|
import org.dkpro.tc.api.features.TcFeatureSet;
|
|
import org.dkpro.tc.features.ngram.CharacterNGram;
|
|
import org.dkpro.tc.features.ngram.WordNGram;
|
|
import org.dkpro.tc.ml.experiment.ExperimentSaveModel;
|
|
import org.dkpro.tc.ml.weka.WekaAdapter;
|
|
|
|
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
|
|
import weka.classifiers.functions.SMO;
|
|
|
|
public class TrainAndSaveModel extends Experiments_ImplBase{
|
|
// some constants
|
|
public static final String TARGET_ANSWER_PREFIX = "TA";
|
|
public static final String CORPUS_NAME = "Mitose";
|
|
public static final String LANG_CODE = "de";
|
|
public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/";
|
|
|
|
private String pathToDatasets;
|
|
private String outputDir;
|
|
|
|
TrainAndSaveModel(String pathToDatasets, String outputDir) {
|
|
this.pathToDatasets = pathToDatasets;
|
|
this.outputDir = outputDir;
|
|
}
|
|
|
|
TrainAndSaveModel(String pathToDatasets) {
|
|
this.pathToDatasets = pathToDatasets;
|
|
this.outputDir = OUTPUT_DIR;
|
|
}
|
|
|
|
public void runModelTraining(TrainConfig trainConfig) {
|
|
File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName);
|
|
String[] kernelName = trainConfig.kernelClass.split("\\.");
|
|
String kernelClass = kernelName[kernelName.length - 1]; // strip package prefix from name of kernel class
|
|
String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", Double.toString(trainConfig.C),
|
|
kernelClass, "E", Double.toString(trainConfig.E));
|
|
File modelFolder = new File(this.outputDir, experimentName);
|
|
try {
|
|
// create model folder if not existent
|
|
if(!modelFolder.exists()) {
|
|
modelFolder.mkdirs();
|
|
}
|
|
// pass full kernel class (including package prefix) to getParameterSpaceSigleLabel()
|
|
ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(), trainConfig.prompt, trainConfig.nReferences,
|
|
CORPUS_NAME, trainConfig.nGrams, trainConfig.kernelClass, Double.toString(trainConfig.C), Double.toString(trainConfig.E));
|
|
documentWriteModel(pspace, modelFolder, experimentName);
|
|
System.out.println("Trained model stroed at: " + modelFolder.getAbsolutePath());
|
|
} catch (ResourceInitializationException e) {
|
|
System.out.println("Error while building parameter spcae.");
|
|
e.printStackTrace();
|
|
} catch (Exception e) {
|
|
System.out.println("Error while training/saving model.");
|
|
e.printStackTrace();
|
|
}
|
|
|
|
}
|
|
|
|
// from de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample
|
|
private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder, String experimentName)
|
|
throws Exception
|
|
{
|
|
ExperimentSaveModel batch;
|
|
batch = new ExperimentSaveModel();
|
|
batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE));
|
|
batch.setParameterSpace(paramSpace);
|
|
batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN);
|
|
batch.setExperimentName(experimentName);
|
|
batch.setOutputFolder(modelFolder);
|
|
Lab.getInstance().run(batch);
|
|
}
|
|
|
|
// from de.unidue.ltl.escrito.examples.local.gridsearch
|
|
public static ParameterSpace getParameterSpaceSingleLabel(
|
|
String dataPath,
|
|
String promptName,
|
|
int numTargetAnswers,
|
|
String corpusName,
|
|
int nGrams,
|
|
String kernelClass,
|
|
String C,
|
|
String E
|
|
)
|
|
throws ResourceInitializationException {
|
|
System.out.println("Starting to build parameter space ...");
|
|
Map<String, Object> dimReaders = new HashMap<String, Object>();
|
|
CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription(
|
|
Reader.class,
|
|
Reader.PARAM_INPUT_FILE, dataPath,
|
|
Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX,
|
|
Reader.PARAM_CORPUSNAME, corpusName
|
|
);
|
|
dimReaders.put(DIM_READER_TRAIN, readerTrain);
|
|
|
|
Dimension<String> learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL);
|
|
Map<String, Object> config = new HashMap<>();
|
|
// have a look at javadoc for SMO class to see valid parameters for SMO class and PolyKernel class
|
|
config.put(DIM_CLASSIFICATION_ARGS, new Object[] { new WekaAdapter(), SMO.class.getName(), "-C", C,
|
|
"-K", kernelClass + " " + "-C -1 -E " + E});
|
|
config.put(DIM_DATA_WRITER, new WekaAdapter().getDataWriterClass());
|
|
config.put(DIM_FEATURE_USE_SPARSE, new WekaAdapter().useSparseFeatures());
|
|
Dimension<Map<String, Object>> learningsArgsDims = Dimension.createBundle("config", config);
|
|
|
|
// copied form BaseExperimentCV.java
|
|
Dimension<TcFeatureSet> dimFeatureSets = Dimension.create(
|
|
DIM_FEATURE_SET,
|
|
new TcFeatureSet(
|
|
/*
|
|
TcFeatureFactory.create(
|
|
PairwiseFeatureWrapper.class,
|
|
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(),
|
|
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
|
|
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
|
|
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
|
|
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
|
|
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
|
|
),
|
|
TcFeatureFactory.create(
|
|
PairwiseFeatureWrapper.class,
|
|
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(),
|
|
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2",
|
|
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5",
|
|
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
|
|
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
|
|
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
|
|
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
|
|
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
|
|
),
|
|
*/
|
|
TcFeatureFactory.create(
|
|
WordNGram.class,
|
|
WordNGram.PARAM_NGRAM_MIN_N, 1,
|
|
WordNGram.PARAM_NGRAM_MAX_N, 3,
|
|
WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams
|
|
),
|
|
TcFeatureFactory.create(
|
|
CharacterNGram.class,
|
|
CharacterNGram.PARAM_NGRAM_MIN_N, 2,
|
|
CharacterNGram.PARAM_NGRAM_MAX_N, 5,
|
|
CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams
|
|
)
|
|
)
|
|
);
|
|
// create parameter space
|
|
ParameterSpace pSpace = null;
|
|
pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders),
|
|
learningDims,
|
|
Dimension.create(DIM_FEATURE_MODE, FM_UNIT),
|
|
dimFeatureSets,
|
|
learningsArgsDims);
|
|
System.out.println("Finished building parameter space.");
|
|
return pSpace;
|
|
}
|
|
|
|
}
|