first commit

This commit is contained in:
Felix S
2023-10-07 15:11:50 +02:00
commit cdffb21cd7
716 changed files with 1183 additions and 0 deletions
+192
View File
@@ -0,0 +1,192 @@
package de.unidue.ltl.escrito.examples.local.models;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.apache.uima.UimaContext;
import org.apache.uima.collection.CollectionException;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.component.JCasCollectionReader_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.pipeline.JCasIterable;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.StringArray;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.Progress;
import org.apache.uima.util.ProgressImpl;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.api.type.TextClassificationTarget;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils;
import de.unidue.ltl.escrito.core.types.LearnerAnswerWithReferenceAnswer;
import de.unidue.ltl.escrito.io.generic.GenericDatasetItem;
import de.unidue.ltl.escrito.io.generic.GenericDatasetReader;
import de.unidue.ltl.escrito.io.util.Utils;
public class Reader extends JCasCollectionReader_ImplBase{
public static final String PARAM_INPUT_FILE = "InputFile";
@ConfigurationParameter(name = PARAM_INPUT_FILE, mandatory = true)
protected String inputFileString;
protected URL inputFileURL;
public static final String PARAM_TARGET_ANSWER_PREFIX = "TargetAnswerPrefix";
@ConfigurationParameter(name = PARAM_TARGET_ANSWER_PREFIX, mandatory = false, defaultValue = "TA")
private String targetAnswerPrefix;
public static final String PARAM_CORPUSNAME = "corpusName";
@ConfigurationParameter(name = PARAM_CORPUSNAME, mandatory = true)
protected String corpusName;
protected int currentIndex;
protected Queue<GenericDatasetItem> items;
private Map<String, String> questions;
private Map<String, String> targetAnswers;
private Set<String> promptAnswerIds;
@Override
public void initialize(UimaContext aContext)
throws ResourceInitializationException
{
items = new LinkedList<GenericDatasetItem>();
questions = new HashMap<String, String>();
targetAnswers = new HashMap<String, String>();
promptAnswerIds = new HashSet<String>();
try {
inputFileURL = ResourceUtils.resolveLocation(inputFileString, this, aContext);
BufferedReader reader = new BufferedReader(
new InputStreamReader(
inputFileURL.openStream(),
"UTF-16"
)
);
String nextLine;
int lineCounter = 1;
while ((nextLine = reader.readLine()) != null) {
// System.out.println("line: "+nextLine);
String[] nextItem = nextLine.split("\t");
String promptId = null;
String answerId = null;
String text = null;
String score = "-1";
// System.out.println(nextItem.length);
if (nextItem.length>=4) {
GenericDatasetItem newItem = null ;
promptId = nextItem[0];
answerId = nextItem[1];
text = nextItem[2];
score = nextItem[3];
text = Utils.cleanString(text);
int counter = 1;
for (int i = 4; i< nextItem.length; i++){
targetAnswers.put(promptId+"_"+counter, Utils.cleanString(nextItem[i]));
counter++;
}
newItem = new GenericDatasetItem(promptId, answerId, text, score, promptId);
items.add(newItem);
}
else {
System.out.println("Could not read lineNumber: " + lineCounter + ", " + nextItem +" item length is: " +nextItem.length);
}
lineCounter++;
}
}
catch (Exception e) {
e.printStackTrace();
throw new ResourceInitializationException(e);
}
currentIndex = 0;
if (!targetAnswers.isEmpty()){
Utils.preprocessConnectedTexts(targetAnswers, corpusName, targetAnswerPrefix, "de");
}
System.out.println("read "+items.size()+" items.");
}
@Override
public boolean hasNext()
throws IOException
{
return !items.isEmpty();
}
@Override
public void getNext(JCas jcas)
throws IOException, CollectionException
{
GenericDatasetItem item = items.poll();
getLogger().debug(item);
String itemId = String.valueOf(item.getPromptId()+"_"+item.getAnswerId());
try
{
jcas.setDocumentLanguage("de");
jcas.setDocumentText(item.getText());
DocumentMetaData dmd = DocumentMetaData.create(jcas);
dmd.setDocumentId(itemId);
dmd.setDocumentTitle(item.getText());
dmd.setDocumentUri(inputFileURL.toURI().toString());
dmd.setCollectionId(itemId);
}
catch (URISyntaxException e) {
throw new CollectionException(e);
}
LearnerAnswerWithReferenceAnswer learnerAnswer = new LearnerAnswerWithReferenceAnswer(jcas, 0, jcas.getDocumentText().length());
learnerAnswer.setPromptId(item.getPromptId());
StringArray ids = new StringArray(jcas, targetAnswers.size());
// We only have one exactly target answer per learner, so we use the same id as for the prompt
int counter = 0;
for (String taId : targetAnswers.keySet()) {
ids.set(counter, String.valueOf(taId));
counter++;
}
learnerAnswer.setReferenceAnswerIds(ids);
learnerAnswer.addToIndexes();
TextClassificationTarget unit = new TextClassificationTarget(jcas, 0, jcas.getDocumentText().length());
// will add the token content as a suffix to the ID of this unit
unit.setSuffix(itemId);
unit.addToIndexes();
TextClassificationOutcome outcome = new TextClassificationOutcome(jcas, 0, jcas.getDocumentText().length());
outcome.setOutcome(item.getGrade());
outcome.addToIndexes();
currentIndex++;
}
@Override
public Progress[] getProgress()
{
return new Progress[] { new ProgressImpl(currentIndex, currentIndex, Progress.ENTITIES) };
}
public static void main(String[] args) throws ResourceInitializationException {
CollectionReaderDescription reader = CollectionReaderFactory.createReaderDescription(
Reader.class,
Reader.PARAM_INPUT_FILE, "/Users/andrea/dkpro/datasets/KatharinaFleig/Inter_g_Lernerantworten_Original_N=15_w4Refs.tsv",
Reader.PARAM_TARGET_ANSWER_PREFIX, "TA",
Reader.PARAM_CORPUSNAME, "Mitose"
);
int i=0;
for (JCas jcas : new JCasIterable(reader)) {
System.out.println(jcas.getDocumentText());
i++;
}
}
}
+25
View File
@@ -0,0 +1,25 @@
package de.unidue.ltl.escrito.examples.local.models;
public class RunModelTraining {
public static void main(String[] args) {
String pathToTrainConfigCSVFile = "/home/felix/Documents/work/hiwi/iwm-tuebingen/escrito-stuff/model_training/config_files/config.csv";
String pathToDatasets = "/home/felix/Documents/work/hiwi/iwm-tuebingen/escrito-stuff/dkpro_target/datasets/N=704_10Ref_ChatGPT_augmented";
TrainConfig[] trainConfigs = TrainConfig.createTrainConfigsFromCSV(pathToTrainConfigCSVFile);
/*
for (TrainConfig trainCfg : trainConfigs) {
System.out.println(trainCfg);
}
*/
TrainAndSaveModel modelTrainer = new TrainAndSaveModel(pathToDatasets);
for (TrainConfig trainCfg : trainConfigs) {
modelTrainer.runModelTraining(trainCfg);
}
}
}
+80
View File
@@ -0,0 +1,80 @@
package de.unidue.ltl.escrito.examples.local.models;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.ml.model.PreTrainedModelProviderUnitMode;
import de.unidue.ltl.escrito.core.types.LearnerAnswer;
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
public class StoredModelPredictor extends Experiments_ImplBase{
public static final String LANG_CODE = "de";
public static void main(String[] args) {
//String experimentName = "Me_n-SMO-C-1.0-NormalizedPolyKernel-E-3.0";
File modelOutputFolder = new File(TrainAndSaveModel.OUTPUT_DIR, args[0]);
//String exampleAnswer = "Die Chromosomen, welche aus zwei Chromatiden bestehen, bewegen sich zum zentralen Äquator."; // GT = 1
//for debugging:
System.out.println("Total number of arguments passed: " + args.length);
for (int i = 0; i < args.length; i++) {
System.out.println("Argument " + i + ": " + args[i]);
}
try {
documentLoadModelSingleLabel(LANG_CODE, modelOutputFolder, args[1]);
} catch (Exception e) {
System.out.println("Exception while processing answer. Please verify the following:");
System.out.println("--> a) The correct name of an existing directory in " + TrainAndSaveModel.OUTPUT_DIR
+ " is passed as first argument.");
System.out.println("--> b) A string representing the learner's answer to be classified is passed as second argument.");
System.out.println("--> c) Both passed arguments are wrapped in double quotes: \"...\".");
e.printStackTrace();
}
}
// from de.unidue.ltl.escrito.examples.io.StoredModelApplicationExample
private static void documentLoadModelSingleLabel(String languageCode, File modelOutputFile, String exampleAnswer)
throws Exception
{
//System.out.println("Path to model: " + modelOutputFile.getAbsolutePath());
AnalysisEngine preprocessing = AnalysisEngineFactory.createEngine(Experiments_ImplBase.getPreprocessing(languageCode));
AnalysisEngine tcAnno = AnalysisEngineFactory.createEngine(PreTrainedModelProviderUnitMode.class,
PreTrainedModelProviderUnitMode.PARAM_NAME_TARGET_ANNOTATION, LearnerAnswer.class,
// Achtung: It seems like you MAY NOT use the class TextClassificationTarget (as we do in the reader)
// to indicate the unit to be considered
// as far as I can see, a TextClassifcationTarget is produced by the classifier and we only want to have one in the end!
PreTrainedModelProviderUnitMode.PARAM_TC_MODEL_LOCATION, modelOutputFile.getAbsolutePath());
JCas jcas = JCasFactory.createJCas();
jcas.setDocumentText(exampleAnswer);
jcas.setDocumentLanguage(languageCode);
LearnerAnswer unit = new LearnerAnswer(jcas, 0, jcas.getDocumentText().length());
unit.addToIndexes();
// redo the preprocessing
preprocessing.process(jcas);
tcAnno.process(jcas);
// redo the processing done by the classifier
List<TextClassificationOutcome> outcomes = new ArrayList<>(
JCasUtil.select(jcas, TextClassificationOutcome.class));
//System.out.println(jcas.getDocumentText()+"\nOutcome: "+outcomes.get(0).getOutcome());
System.out.println(outcomes.get(0).getOutcome()); // only print (binary) outcome
}
}
+167
View File
@@ -0,0 +1,167 @@
package de.unidue.ltl.escrito.examples.local.models;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.lab.Lab;
import org.dkpro.lab.task.Dimension;
import org.dkpro.lab.task.ParameterSpace;
import org.dkpro.lab.task.BatchTask.ExecutionPolicy;
import org.dkpro.tc.api.features.TcFeatureFactory;
import org.dkpro.tc.api.features.TcFeatureSet;
import org.dkpro.tc.features.ngram.CharacterNGram;
import org.dkpro.tc.features.ngram.WordNGram;
import org.dkpro.tc.ml.experiment.ExperimentSaveModel;
import org.dkpro.tc.ml.weka.WekaAdapter;
import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
import weka.classifiers.functions.SMO;
public class TrainAndSaveModel extends Experiments_ImplBase{
// some constants
public static final String TARGET_ANSWER_PREFIX = "TA";
public static final String CORPUS_NAME = "Mitose";
public static final String LANG_CODE = "de";
public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/";
private String pathToDatasets;
private String outputDir;
TrainAndSaveModel(String pathToDatasets, String outputDir) {
this.pathToDatasets = pathToDatasets;
this.outputDir = outputDir;
}
TrainAndSaveModel(String pathToDatasets) {
this.pathToDatasets = pathToDatasets;
this.outputDir = OUTPUT_DIR;
}
public void runModelTraining(TrainConfig trainConfig) {
File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName);
String[] kernelName = trainConfig.kernelClass.split("\\.");
String kernelClass = kernelName[kernelName.length - 1]; // strip package prefix from name of kernel class
String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", Double.toString(trainConfig.C),
kernelClass, "E", Double.toString(trainConfig.E));
File modelFolder = new File(this.outputDir, experimentName);
try {
// create model folder if not existent
if(!modelFolder.exists()) {
modelFolder.mkdirs();
}
// pass full kernel class (including package prefix) to getParameterSpaceSigleLabel()
ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(), trainConfig.prompt, trainConfig.nReferences,
CORPUS_NAME, trainConfig.nGrams, trainConfig.kernelClass, Double.toString(trainConfig.C), Double.toString(trainConfig.E));
documentWriteModel(pspace, modelFolder, experimentName);
System.out.println("Trained model stroed at: " + modelFolder.getAbsolutePath());
} catch (ResourceInitializationException e) {
System.out.println("Error while building parameter spcae.");
e.printStackTrace();
} catch (Exception e) {
System.out.println("Error while training/saving model.");
e.printStackTrace();
}
}
// from de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample
private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder, String experimentName)
throws Exception
{
ExperimentSaveModel batch;
batch = new ExperimentSaveModel();
batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE));
batch.setParameterSpace(paramSpace);
batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN);
batch.setExperimentName(experimentName);
batch.setOutputFolder(modelFolder);
Lab.getInstance().run(batch);
}
// from de.unidue.ltl.escrito.examples.local.gridsearch
public static ParameterSpace getParameterSpaceSingleLabel(
String dataPath,
String promptName,
int numTargetAnswers,
String corpusName,
int nGrams,
String kernelClass,
String C,
String E
)
throws ResourceInitializationException {
System.out.println("Starting to build parameter space ...");
Map<String, Object> dimReaders = new HashMap<String, Object>();
CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription(
Reader.class,
Reader.PARAM_INPUT_FILE, dataPath,
Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX,
Reader.PARAM_CORPUSNAME, corpusName
);
dimReaders.put(DIM_READER_TRAIN, readerTrain);
Dimension<String> learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL);
Map<String, Object> config = new HashMap<>();
// have a look at javadoc for SMO class to see valid parameters for SMO class and PolyKernel class
config.put(DIM_CLASSIFICATION_ARGS, new Object[] { new WekaAdapter(), SMO.class.getName(), "-C", C,
"-K", kernelClass + " " + "-C -1 -E " + E});
config.put(DIM_DATA_WRITER, new WekaAdapter().getDataWriterClass());
config.put(DIM_FEATURE_USE_SPARSE, new WekaAdapter().useSparseFeatures());
Dimension<Map<String, Object>> learningsArgsDims = Dimension.createBundle("config", config);
// copied form BaseExperimentCV.java
Dimension<TcFeatureSet> dimFeatureSets = Dimension.create(
DIM_FEATURE_SET,
new TcFeatureSet(
/*
TcFeatureFactory.create(
PairwiseFeatureWrapper.class,
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(),
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
),
TcFeatureFactory.create(
PairwiseFeatureWrapper.class,
PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(),
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2",
StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5",
PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
),
*/
TcFeatureFactory.create(
WordNGram.class,
WordNGram.PARAM_NGRAM_MIN_N, 1,
WordNGram.PARAM_NGRAM_MAX_N, 3,
WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams
),
TcFeatureFactory.create(
CharacterNGram.class,
CharacterNGram.PARAM_NGRAM_MIN_N, 2,
CharacterNGram.PARAM_NGRAM_MAX_N, 5,
CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams
)
)
);
// create parameter space
ParameterSpace pSpace = null;
pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders),
learningDims,
Dimension.create(DIM_FEATURE_MODE, FM_UNIT),
dimFeatureSets,
learningsArgsDims);
System.out.println("Finished building parameter space.");
return pSpace;
}
}
+77
View File
@@ -0,0 +1,77 @@
package de.unidue.ltl.escrito.examples.local.models;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class TrainConfig {
/*
* each declared field must correspond exactly to one column in
* the CSV-file where the training configurations are specified otherwise
* reflection cannot be used as intended in TrainAndSaveModel.java
*/
String datasetName;
String prompt;
int nGrams;
int nReferences;
String kernelClass;
double C;
double E;
public TrainConfig(String datasetName, String prompt, int nGrams, int nReferences,
String kernelClass, double C, double E) {
this.datasetName = datasetName;
this.prompt = prompt;
this.nGrams = nGrams;
this.nReferences = nReferences;
this.kernelClass = kernelClass;
this.C = C;
this.E = E;
}
public String toString() {
return "[dataset: " + this.datasetName + "; prompt: " + this.prompt + "; nGrams: "
+ this.nGrams + "; nReferences: " + this.nReferences + "; kernelClass: "
+ this.kernelClass + "; C: " + C + "; E: " + E +"]";
}
public static TrainConfig[] createTrainConfigsFromCSV(String pathToConfigCSVFile) {
ArrayList<TrainConfig> trainConfigs = new ArrayList<>();
try(BufferedReader br = new BufferedReader(new FileReader(pathToConfigCSVFile))) {
// read first line, which is the header
String line = br.readLine();
while ((line = br.readLine()) != null) {
if (line.startsWith("#")) {
continue; // ignore lines starting with #
}
String[] configParts = line.split(",");
TrainConfig trainConfig = new TrainConfig(
configParts[0], // datasetName
configParts[1], // prompt
Integer.parseInt(configParts[2]), // nGrams
Integer.parseInt(configParts[3]), // nReferences
configParts[4], // kernelClass
Double.parseDouble(configParts[5]), // C
Double.parseDouble(configParts[6]) // E
);
trainConfigs.add(trainConfig);
}
} catch (IOException e) {
System.out.println("Error while reading " + pathToConfigCSVFile);
e.printStackTrace();
}
// return array containing TrainConfig objects
return trainConfigs.toArray(new TrainConfig[trainConfigs.size()]);
}
}