// Source path: escrito-docker/local/models/StoredModelPredictor.java
// Listing metadata: 81 lines, 3.4 KiB, Java
// Last change: 2023-10-07 15:11:50 +02:00
package de.unidue.ltl.escrito.examples.local.models;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.ml.model.PreTrainedModelProviderUnitMode;
import de.unidue.ltl.escrito.core.types.LearnerAnswer;
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
/**
 * Command-line tool that classifies a single learner answer with a previously stored model.
 *
 * <p>Usage: {@code StoredModelPredictor "<model directory name>" "<learner answer>"}
 * <ul>
 *   <li>First argument: name of a directory below {@code TrainAndSaveModel.OUTPUT_DIR}, e.g.
 *       {@code "Me_n-SMO-C-1.0-NormalizedPolyKernel-E-3.0"}.</li>
 *   <li>Second argument: the answer text to classify, e.g. the German sentence
 *       {@code "Die Chromosomen, welche aus zwei Chromatiden bestehen, bewegen sich zum zentralen Äquator."}
 *       (noted by the original author as having gold label 1).</li>
 * </ul>
 *
 * <p>The predicted outcome is printed to standard output as the last line of a successful run;
 * the received arguments are echoed before it for debugging.
 */
public class StoredModelPredictor extends Experiments_ImplBase {

    /** Language code used to build the preprocessing pipeline and set as the document language. */
    public static final String LANG_CODE = "de";

    /**
     * Echoes the received arguments, validates that both required arguments are present and
     * prints the predicted outcome for the given answer.
     *
     * <p>Failures are reported on standard output together with usage hints; exceptions raised
     * during classification additionally get their stack trace printed. The method itself does
     * not terminate the JVM.
     *
     * @param args {@code args[0]} is the model directory name (resolved against
     *             {@code TrainAndSaveModel.OUTPUT_DIR}), {@code args[1]} is the learner answer
     */
    public static void main(String[] args) {
        // Debug output: show exactly what was received, which helps with shell-quoting problems.
        System.out.println("Total number of arguments passed: " + args.length);
        for (int i = 0; i < args.length; i++) {
            System.out.println("Argument " + i + ": " + args[i]);
        }

        // Check the argument count up front so that a missing argument yields the usage hints
        // instead of an uncaught ArrayIndexOutOfBoundsException.
        if (args.length < 2) {
            System.out.println("Expected 2 arguments but received " + args.length
                    + ". Please verify the following:");
            printUsageHints();
            return;
        }

        File modelOutputFolder = new File(TrainAndSaveModel.OUTPUT_DIR, args[0]);
        try {
            documentLoadModelSingleLabel(LANG_CODE, modelOutputFolder, args[1]);
        } catch (Exception e) {
            // Top-level boundary: report every failure together with the usage hints.
            System.out.println("Exception while processing answer. Please verify the following:");
            printUsageHints();
            e.printStackTrace();
        }
    }

    /** Prints the checklist of things to verify when the tool was invoked incorrectly. */
    private static void printUsageHints() {
        System.out.println("--> a) The correct name of an existing directory in " + TrainAndSaveModel.OUTPUT_DIR
                + " is passed as first argument.");
        System.out.println("--> b) A string representing the learner's answer to be classified is passed as second argument.");
        System.out.println("--> c) Both passed arguments are wrapped in double quotes: \"...\".");
    }

    /**
     * Runs preprocessing and the stored classifier on one answer and prints the first outcome.
     *
     * <p>Adapted from {@code de.unidue.ltl.escrito.examples.io.StoredModelApplicationExample}.
     *
     * @param languageCode    language code for the preprocessing pipeline and the document
     * @param modelOutputFile directory containing the stored model
     * @param exampleAnswer   answer text to classify; the whole text is treated as one unit
     * @throws IllegalStateException if the classifier did not add any outcome annotation
     * @throws Exception             if an engine or the CAS cannot be created, or processing fails
     */
    private static void documentLoadModelSingleLabel(String languageCode, File modelOutputFile, String exampleAnswer)
            throws Exception {
        AnalysisEngine preprocessing = AnalysisEngineFactory
                .createEngine(Experiments_ImplBase.getPreprocessing(languageCode));
        try {
            // Caution (from the original author): it seems TextClassificationTarget (as used in
            // the reader) may NOT be used to mark the unit here. As far as they could see, the
            // classifier produces a TextClassificationTarget itself and only one should exist
            // in the end, hence LearnerAnswer is used as the target annotation.
            AnalysisEngine tcAnno = AnalysisEngineFactory.createEngine(PreTrainedModelProviderUnitMode.class,
                    PreTrainedModelProviderUnitMode.PARAM_NAME_TARGET_ANNOTATION, LearnerAnswer.class,
                    PreTrainedModelProviderUnitMode.PARAM_TC_MODEL_LOCATION, modelOutputFile.getAbsolutePath());
            try {
                JCas jcas = JCasFactory.createJCas();
                jcas.setDocumentText(exampleAnswer);
                jcas.setDocumentLanguage(languageCode);

                // The classification unit spans the complete answer text.
                LearnerAnswer unit = new LearnerAnswer(jcas, 0, jcas.getDocumentText().length());
                unit.addToIndexes();

                // Redo the preprocessing, then apply the stored classifier.
                preprocessing.process(jcas);
                tcAnno.process(jcas);

                List<TextClassificationOutcome> outcomes = new ArrayList<>(
                        JCasUtil.select(jcas, TextClassificationOutcome.class));
                if (outcomes.isEmpty()) {
                    throw new IllegalStateException("The classifier produced no TextClassificationOutcome"
                            + " for the given answer (model location: " + modelOutputFile.getAbsolutePath() + ")");
                }

                // Only the outcome label is printed (binary, according to the original author).
                System.out.println(outcomes.get(0).getOutcome());
            } finally {
                tcAnno.destroy();
            }
        } finally {
            preprocessing.destroy();
        }
    }
}