81 lines
3.4 KiB
Java
81 lines
3.4 KiB
Java
package de.unidue.ltl.escrito.examples.local.models;
|
|
|
|
import java.io.File;
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
|
|
import org.apache.uima.analysis_engine.AnalysisEngine;
|
|
import org.apache.uima.fit.factory.AnalysisEngineFactory;
|
|
import org.apache.uima.fit.factory.JCasFactory;
|
|
import org.apache.uima.fit.util.JCasUtil;
|
|
import org.apache.uima.jcas.JCas;
|
|
import org.dkpro.tc.api.type.TextClassificationOutcome;
|
|
import org.dkpro.tc.ml.model.PreTrainedModelProviderUnitMode;
|
|
|
|
import de.unidue.ltl.escrito.core.types.LearnerAnswer;
|
|
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
|
|
|
|
public class StoredModelPredictor extends Experiments_ImplBase{
|
|
public static final String LANG_CODE = "de";
|
|
|
|
public static void main(String[] args) {
|
|
|
|
//String experimentName = "Me_n-SMO-C-1.0-NormalizedPolyKernel-E-3.0";
|
|
File modelOutputFolder = new File(TrainAndSaveModel.OUTPUT_DIR, args[0]);
|
|
//String exampleAnswer = "Die Chromosomen, welche aus zwei Chromatiden bestehen, bewegen sich zum zentralen Äquator."; // GT = 1
|
|
|
|
//for debugging:
|
|
System.out.println("Total number of arguments passed: " + args.length);
|
|
for (int i = 0; i < args.length; i++) {
|
|
System.out.println("Argument " + i + ": " + args[i]);
|
|
}
|
|
|
|
try {
|
|
documentLoadModelSingleLabel(LANG_CODE, modelOutputFolder, args[1]);
|
|
} catch (Exception e) {
|
|
System.out.println("Exception while processing answer. Please verify the following:");
|
|
System.out.println("--> a) The correct name of an existing directory in " + TrainAndSaveModel.OUTPUT_DIR
|
|
+ " is passed as first argument.");
|
|
System.out.println("--> b) A string representing the learner's answer to be classified is passed as second argument.");
|
|
System.out.println("--> c) Both passed arguments are wrapped in double quotes: \"...\".");
|
|
e.printStackTrace();
|
|
|
|
}
|
|
}
|
|
|
|
// from de.unidue.ltl.escrito.examples.io.StoredModelApplicationExample
|
|
private static void documentLoadModelSingleLabel(String languageCode, File modelOutputFile, String exampleAnswer)
|
|
throws Exception
|
|
{
|
|
|
|
//System.out.println("Path to model: " + modelOutputFile.getAbsolutePath());
|
|
AnalysisEngine preprocessing = AnalysisEngineFactory.createEngine(Experiments_ImplBase.getPreprocessing(languageCode));
|
|
AnalysisEngine tcAnno = AnalysisEngineFactory.createEngine(PreTrainedModelProviderUnitMode.class,
|
|
PreTrainedModelProviderUnitMode.PARAM_NAME_TARGET_ANNOTATION, LearnerAnswer.class,
|
|
// Achtung: It seems like you MAY NOT use the class TextClassificationTarget (as we do in the reader)
|
|
// to indicate the unit to be considered
|
|
// as far as I can see, a TextClassifcationTarget is produced by the classifier and we only want to have one in the end!
|
|
PreTrainedModelProviderUnitMode.PARAM_TC_MODEL_LOCATION, modelOutputFile.getAbsolutePath());
|
|
|
|
JCas jcas = JCasFactory.createJCas();
|
|
jcas.setDocumentText(exampleAnswer);
|
|
jcas.setDocumentLanguage(languageCode);
|
|
|
|
LearnerAnswer unit = new LearnerAnswer(jcas, 0, jcas.getDocumentText().length());
|
|
unit.addToIndexes();
|
|
|
|
|
|
// redo the preprocessing
|
|
preprocessing.process(jcas);
|
|
tcAnno.process(jcas);
|
|
|
|
// redo the processing done by the classifier
|
|
|
|
List<TextClassificationOutcome> outcomes = new ArrayList<>(
|
|
JCasUtil.select(jcas, TextClassificationOutcome.class));
|
|
//System.out.println(jcas.getDocumentText()+"\nOutcome: "+outcomes.get(0).getOutcome());
|
|
System.out.println(outcomes.get(0).getOutcome()); // only print (binary) outcome
|
|
}
|
|
|
|
}
|