// escrito-docker/local/models/Reader.java
//
// 193 lines
// 6.4 KiB
// Java

package de.unidue.ltl.escrito.examples.local.models;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.apache.uima.UimaContext;
import org.apache.uima.collection.CollectionException;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.component.JCasCollectionReader_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.pipeline.JCasIterable;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.StringArray;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.util.Progress;
import org.apache.uima.util.ProgressImpl;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.api.type.TextClassificationTarget;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils;
import de.unidue.ltl.escrito.core.types.LearnerAnswerWithReferenceAnswer;
import de.unidue.ltl.escrito.io.generic.GenericDatasetItem;
import de.unidue.ltl.escrito.io.generic.GenericDatasetReader;
import de.unidue.ltl.escrito.io.util.Utils;
/**
 * Collection reader for a tab-separated file of learner answers with optional target answers.
 *
 * <p>The input is decoded as UTF-16. Every line that splits into at least four tab-separated
 * columns is read as {@code promptId, answerId, text, score}; each further column is stored as
 * a target answer under the key {@code promptId + "_" + n}, with {@code n} counting from 1
 * within the line. Lines with fewer than four columns are reported on stdout and skipped.
 *
 * <p>Each item becomes one CAS with language {@code "de"}, a
 * {@link LearnerAnswerWithReferenceAnswer}, a {@link TextClassificationTarget} and a
 * {@link TextClassificationOutcome}, all spanning the whole document text.
 */
public class Reader extends JCasCollectionReader_ImplBase {

    /** Location of the input file; resolved through {@code ResourceUtils.resolveLocation}. */
    public static final String PARAM_INPUT_FILE = "InputFile";
    @ConfigurationParameter(name = PARAM_INPUT_FILE, mandatory = true)
    protected String inputFileString;
    /** Resolved form of {@link #inputFileString}; set in {@link #initialize(UimaContext)}. */
    protected URL inputFileURL;

    /** Prefix handed to {@code Utils.preprocessConnectedTexts} for the target answers. */
    public static final String PARAM_TARGET_ANSWER_PREFIX = "TargetAnswerPrefix";
    @ConfigurationParameter(name = PARAM_TARGET_ANSWER_PREFIX, mandatory = false, defaultValue = "TA")
    private String targetAnswerPrefix;

    /** Corpus name handed to {@code Utils.preprocessConnectedTexts}. */
    public static final String PARAM_CORPUSNAME = "corpusName";
    @ConfigurationParameter(name = PARAM_CORPUSNAME, mandatory = true)
    protected String corpusName;

    /** Number of items already delivered by {@link #getNext(JCas)}. */
    protected int currentIndex;
    /** Items read from the input file that have not been delivered yet. */
    protected Queue<GenericDatasetItem> items;
    /** Number of items read during initialization; used as the progress total. */
    private int totalItems;

    // questions and promptAnswerIds are allocated in initialize() but not used in this class.
    private Map<String, String> questions;
    /** Target answers keyed by {@code promptId + "_" + n}. */
    private Map<String, String> targetAnswers;
    private Set<String> promptAnswerIds;

    /**
     * Reads the whole input file into memory and preprocesses the target answers.
     *
     * @param aContext the UIMA context used to resolve the input location
     * @throws ResourceInitializationException if the input cannot be resolved, opened or read,
     *         or if any other exception occurs while parsing it
     */
    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        items = new LinkedList<GenericDatasetItem>();
        questions = new HashMap<String, String>();
        targetAnswers = new HashMap<String, String>();
        promptAnswerIds = new HashSet<String>();
        try {
            inputFileURL = ResourceUtils.resolveLocation(inputFileString, this, aContext);
            // try-with-resources: the stream is closed even when parsing fails.
            try (BufferedReader reader = new BufferedReader(
                    new InputStreamReader(inputFileURL.openStream(), "UTF-16"))) {
                String nextLine;
                int lineCounter = 1;
                while ((nextLine = reader.readLine()) != null) {
                    String[] nextItem = nextLine.split("\t");
                    if (nextItem.length >= 4) {
                        String promptId = nextItem[0];
                        String answerId = nextItem[1];
                        String text = Utils.cleanString(nextItem[2]);
                        String score = nextItem[3];
                        // Columns from index 4 onwards are target answers, numbered from 1.
                        int counter = 1;
                        for (int i = 4; i < nextItem.length; i++) {
                            targetAnswers.put(promptId + "_" + counter, Utils.cleanString(nextItem[i]));
                            counter++;
                        }
                        items.add(new GenericDatasetItem(promptId, answerId, text, score, promptId));
                    }
                    else {
                        // Print the line itself; concatenating the String[] would only print
                        // the array's identity string.
                        System.out.println("Could not read lineNumber: " + lineCounter + ", " + nextLine + " item length is: " + nextItem.length);
                    }
                    lineCounter++;
                }
            }
        }
        catch (Exception e) {
            // The cause is preserved in the rethrown exception, so it is not printed here too.
            throw new ResourceInitializationException(e);
        }
        currentIndex = 0;
        totalItems = items.size();
        if (!targetAnswers.isEmpty()) {
            Utils.preprocessConnectedTexts(targetAnswers, corpusName, targetAnswerPrefix, "de");
        }
        System.out.println("read " + items.size() + " items.");
    }

    /**
     * @return {@code true} while there are items that have not been delivered yet
     */
    @Override
    public boolean hasNext()
        throws IOException
    {
        return !items.isEmpty();
    }

    /**
     * Fills the given CAS with the next item and its annotations.
     *
     * @param jcas the CAS to populate
     * @throws CollectionException if no items remain or the input URL cannot be converted
     *         to a URI
     */
    @Override
    public void getNext(JCas jcas)
        throws IOException, CollectionException
    {
        GenericDatasetItem item = items.poll();
        if (item == null) {
            // Fail with a clear message instead of a NullPointerException further down.
            throw new CollectionException(
                    new IllegalStateException("getNext() called but no items remain"));
        }
        getLogger().debug(item);
        String itemId = item.getPromptId() + "_" + item.getAnswerId();
        try {
            jcas.setDocumentLanguage("de");
            jcas.setDocumentText(item.getText());
            DocumentMetaData dmd = DocumentMetaData.create(jcas);
            dmd.setDocumentId(itemId);
            dmd.setDocumentTitle(item.getText());
            dmd.setDocumentUri(inputFileURL.toURI().toString());
            dmd.setCollectionId(itemId);
        }
        catch (URISyntaxException e) {
            throw new CollectionException(e);
        }

        int textLength = jcas.getDocumentText().length();

        LearnerAnswerWithReferenceAnswer learnerAnswer =
                new LearnerAnswerWithReferenceAnswer(jcas, 0, textLength);
        learnerAnswer.setPromptId(item.getPromptId());
        // Attach the ids of ALL target answers read from the file, regardless of prompt.
        // NOTE(review): if the input can contain several prompts, this probably should be
        // restricted to keys starting with this item's promptId + "_" - confirm with consumers.
        StringArray ids = new StringArray(jcas, targetAnswers.size());
        int counter = 0;
        for (String taId : targetAnswers.keySet()) {
            ids.set(counter, taId);
            counter++;
        }
        learnerAnswer.setReferenceAnswerIds(ids);
        learnerAnswer.addToIndexes();

        TextClassificationTarget unit = new TextClassificationTarget(jcas, 0, textLength);
        // The suffix of the unit is the item id (promptId_answerId).
        unit.setSuffix(itemId);
        unit.addToIndexes();

        TextClassificationOutcome outcome = new TextClassificationOutcome(jcas, 0, textLength);
        outcome.setOutcome(item.getGrade());
        outcome.addToIndexes();
        currentIndex++;
    }

    /**
     * @return a single progress entry counting delivered items against the number of items
     *         read during initialization
     */
    @Override
    public Progress[] getProgress()
    {
        return new Progress[] { new ProgressImpl(currentIndex, totalItems, Progress.ENTITIES) };
    }

    /**
     * Small manual driver that prints the text of every document the reader produces.
     *
     * @param args optional: {@code args[0]} input file, {@code args[1]} corpus name; the
     *        original hard-coded values are used for whichever is missing
     * @throws ResourceInitializationException if the reader description cannot be created
     */
    public static void main(String[] args) throws ResourceInitializationException {
        String inputFile = args.length > 0
                ? args[0]
                : "/Users/andrea/dkpro/datasets/KatharinaFleig/Inter_g_Lernerantworten_Original_N=15_w4Refs.tsv";
        String corpus = args.length > 1 ? args[1] : "Mitose";
        CollectionReaderDescription reader = CollectionReaderFactory.createReaderDescription(
                Reader.class,
                Reader.PARAM_INPUT_FILE, inputFile,
                Reader.PARAM_TARGET_ANSWER_PREFIX, "TA",
                Reader.PARAM_CORPUSNAME, corpus
                );
        for (JCas jcas : new JCasIterable(reader)) {
            System.out.println(jcas.getDocumentText());
        }
    }
}