package de.unidue.ltl.escrito.examples.local.models; import java.io.File; import java.util.HashMap; import java.util.Map; import org.apache.uima.collection.CollectionReaderDescription; import org.apache.uima.fit.factory.CollectionReaderFactory; import org.apache.uima.resource.ResourceInitializationException; import org.dkpro.lab.Lab; import org.dkpro.lab.task.Dimension; import org.dkpro.lab.task.ParameterSpace; import org.dkpro.lab.task.BatchTask.ExecutionPolicy; import org.dkpro.tc.api.features.TcFeatureFactory; import org.dkpro.tc.api.features.TcFeatureSet; import org.dkpro.tc.features.ngram.CharacterNGram; import org.dkpro.tc.features.ngram.WordNGram; import org.dkpro.tc.ml.experiment.ExperimentSaveModel; import org.dkpro.tc.ml.weka.WekaAdapter; import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase; import weka.classifiers.functions.SMO; public class TrainAndSaveModel extends Experiments_ImplBase{ // some constants public static final String TARGET_ANSWER_PREFIX = "TA"; public static final String CORPUS_NAME = "Mitose"; public static final String LANG_CODE = "de"; public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/"; private String pathToDatasets; private String outputDir; TrainAndSaveModel(String pathToDatasets, String outputDir) { this.pathToDatasets = pathToDatasets; this.outputDir = outputDir; } TrainAndSaveModel(String pathToDatasets) { this.pathToDatasets = pathToDatasets; this.outputDir = OUTPUT_DIR; } public void runModelTraining(TrainConfig trainConfig) { File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName); String[] kernelName = trainConfig.kernelClass.split("\\."); String kernelClass = kernelName[kernelName.length - 1]; // strip package prefix from name of kernel class String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", Double.toString(trainConfig.C), kernelClass, "E", Double.toString(trainConfig.E)); File modelFolder = new File(this.outputDir, experimentName); try { // create model folder if not existent if(!modelFolder.exists()) { modelFolder.mkdirs(); } // pass full kernel class (including package prefix) to getParameterSpaceSigleLabel() ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(), trainConfig.prompt, trainConfig.nReferences, CORPUS_NAME, trainConfig.nGrams, trainConfig.kernelClass, Double.toString(trainConfig.C), Double.toString(trainConfig.E)); documentWriteModel(pspace, modelFolder, experimentName); System.out.println("Trained model stroed at: " + modelFolder.getAbsolutePath()); } catch (ResourceInitializationException e) { System.out.println("Error while building parameter spcae."); e.printStackTrace(); } catch (Exception e) { System.out.println("Error while training/saving model."); e.printStackTrace(); } } // from de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder, String experimentName) throws Exception { ExperimentSaveModel batch; batch = new ExperimentSaveModel(); batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE)); batch.setParameterSpace(paramSpace); batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN); batch.setExperimentName(experimentName); batch.setOutputFolder(modelFolder); Lab.getInstance().run(batch); } // from de.unidue.ltl.escrito.examples.local.gridsearch public static ParameterSpace getParameterSpaceSingleLabel( String dataPath, String promptName, int numTargetAnswers, String corpusName, int nGrams, String kernelClass, String C, String E ) throws ResourceInitializationException { System.out.println("Starting to build parameter space ..."); Map dimReaders = new HashMap(); CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription( Reader.class, Reader.PARAM_INPUT_FILE, dataPath, Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX, Reader.PARAM_CORPUSNAME, corpusName ); dimReaders.put(DIM_READER_TRAIN, readerTrain); Dimension learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL); Map config = new HashMap<>(); // have a look at javadoc for SMO class to see valid parameters for SMO class and PolyKernel class config.put(DIM_CLASSIFICATION_ARGS, new Object[] { new WekaAdapter(), SMO.class.getName(), "-C", C, "-K", kernelClass + " " + "-C -1 -E " + E}); config.put(DIM_DATA_WRITER, new WekaAdapter().getDataWriterClass()); config.put(DIM_FEATURE_USE_SPARSE, new WekaAdapter().useSparseFeatures()); Dimension> learningsArgsDims = Dimension.createBundle("config", config); // copied form BaseExperimentCV.java Dimension dimFeatureSets = Dimension.create( DIM_FEATURE_SET, new TcFeatureSet( /* TcFeatureFactory.create( PairwiseFeatureWrapper.class, PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(), PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA", PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName, PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers, PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES, PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName ), TcFeatureFactory.create( PairwiseFeatureWrapper.class, PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(), StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2", StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5", PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA", PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName, PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers, PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES, PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName ), */ TcFeatureFactory.create( WordNGram.class, WordNGram.PARAM_NGRAM_MIN_N, 1, WordNGram.PARAM_NGRAM_MAX_N, 3, WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams ), TcFeatureFactory.create( CharacterNGram.class, CharacterNGram.PARAM_NGRAM_MIN_N, 2, CharacterNGram.PARAM_NGRAM_MAX_N, 5, CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams ) ) ); // create parameter space ParameterSpace pSpace = null; pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders), learningDims, Dimension.create(DIM_FEATURE_MODE, FM_UNIT), dimFeatureSets, learningsArgsDims); System.out.println("Finished building parameter space."); return pSpace; } }