package de.unidue.ltl.escrito.examples.local.models;

import java.io.File;
import java.util.HashMap;
import java.util.Map;

import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.lab.Lab;
import org.dkpro.lab.task.Dimension;
import org.dkpro.lab.task.ParameterSpace;
import org.dkpro.lab.task.BatchTask.ExecutionPolicy;
import org.dkpro.tc.api.features.TcFeatureFactory;
import org.dkpro.tc.api.features.TcFeatureSet;
import org.dkpro.tc.features.ngram.CharacterNGram;
import org.dkpro.tc.features.ngram.WordNGram;
import org.dkpro.tc.ml.experiment.ExperimentSaveModel;
import org.dkpro.tc.ml.weka.WekaAdapter;

import de.unidue.ltl.escrito.examples.basics.Experiments_ImplBase;
import weka.classifiers.functions.SMO;

public class TrainAndSaveModel extends Experiments_ImplBase{
	// some constants
	public static final String TARGET_ANSWER_PREFIX = "TA";
	public static final String CORPUS_NAME = "Mitose";
	public static final String LANG_CODE = "de";
	public static final String OUTPUT_DIR = System.getenv("DKPRO_HOME") + "/models/";
	
	private String pathToDatasets;
	private String outputDir;
	
	TrainAndSaveModel(String pathToDatasets, String outputDir) {
		this.pathToDatasets = pathToDatasets;
		this.outputDir = outputDir;
	}
	
	TrainAndSaveModel(String pathToDatasets) {
		this.pathToDatasets = pathToDatasets;
		this.outputDir = OUTPUT_DIR;
	}
	
	public void runModelTraining(TrainConfig trainConfig) {
		File fullDataPath = new File(this.pathToDatasets, trainConfig.datasetName);
		String[] kernelName = trainConfig.kernelClass.split("\\.");
		String kernelClass = kernelName[kernelName.length - 1]; // strip package prefix from name of kernel class
		String experimentName = String.join("-", trainConfig.prompt, "SMO", "C", Double.toString(trainConfig.C), 
				kernelClass, "E", Double.toString(trainConfig.E));
		File modelFolder = new File(this.outputDir, experimentName);
		try {
			// create model folder if not existent
			if(!modelFolder.exists()) {
				modelFolder.mkdirs();
			}
			// pass full kernel class (including package prefix) to getParameterSpaceSigleLabel()
			ParameterSpace pspace = getParameterSpaceSingleLabel(fullDataPath.toString(), trainConfig.prompt, trainConfig.nReferences, 
					CORPUS_NAME, trainConfig.nGrams, trainConfig.kernelClass, Double.toString(trainConfig.C), Double.toString(trainConfig.E));
			documentWriteModel(pspace, modelFolder, experimentName);
			System.out.println("Trained model stroed at: " + modelFolder.getAbsolutePath());
		} catch (ResourceInitializationException e) {
			System.out.println("Error while building parameter spcae.");
			e.printStackTrace();
		} catch (Exception e) {
			System.out.println("Error while training/saving model.");
			e.printStackTrace();
		}
		
	}
	
	// from de.unidue.ltl.escrito.examples.io.WriteAndReadModelApplicationExample
	private static void documentWriteModel(ParameterSpace paramSpace, File modelFolder, String experimentName)
			throws Exception
	{
		ExperimentSaveModel batch;
		batch = new ExperimentSaveModel();
		batch.setPreprocessing(Experiments_ImplBase.getPreprocessing(LANG_CODE));
		batch.setParameterSpace(paramSpace);
		batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN);
		batch.setExperimentName(experimentName);
		batch.setOutputFolder(modelFolder);
		Lab.getInstance().run(batch);
	}
	
	// from de.unidue.ltl.escrito.examples.local.gridsearch
	public static ParameterSpace getParameterSpaceSingleLabel(
			String dataPath,
			String promptName,
			int numTargetAnswers,
			String corpusName,
			int nGrams,
			String kernelClass,
			String C,
			String E
			)
			throws ResourceInitializationException {
		System.out.println("Starting to build parameter space ...");
		Map<String, Object> dimReaders = new HashMap<String, Object>();
		CollectionReaderDescription readerTrain = CollectionReaderFactory.createReaderDescription(
				Reader.class,
				Reader.PARAM_INPUT_FILE, dataPath,
				Reader.PARAM_TARGET_ANSWER_PREFIX, TARGET_ANSWER_PREFIX,
				Reader.PARAM_CORPUSNAME, corpusName
				);
		dimReaders.put(DIM_READER_TRAIN, readerTrain);
		
		Dimension<String> learningDims = Dimension.create(DIM_LEARNING_MODE, LM_SINGLE_LABEL);
		Map<String, Object> config = new HashMap<>();
		// have a look at javadoc for SMO class to see valid parameters for SMO class and PolyKernel class
		config.put(DIM_CLASSIFICATION_ARGS, new Object[] { new WekaAdapter(), SMO.class.getName(), "-C", C,
				"-K", kernelClass + " " + "-C -1 -E " + E});
		config.put(DIM_DATA_WRITER, new WekaAdapter().getDataWriterClass());
		config.put(DIM_FEATURE_USE_SPARSE, new WekaAdapter().useSparseFeatures());
		Dimension<Map<String, Object>> learningsArgsDims = Dimension.createBundle("config", config);
		
		// copied form BaseExperimentCV.java
		Dimension<TcFeatureSet> dimFeatureSets = Dimension.create(
				DIM_FEATURE_SET,
				new TcFeatureSet(
						/*
						TcFeatureFactory.create(
								PairwiseFeatureWrapper.class,
								PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, WordOverlapFeatureExtractor.class.getName(),
								PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
								PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
								PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
								PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
								PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
								),
						TcFeatureFactory.create(
								PairwiseFeatureWrapper.class,
								PairwiseFeatureWrapper.PARAM_PAIRWISE_FEATURE_EXTRACTOR, StringSimilarityFeatureExtractor.class.getName(),
								StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MIN, "2",
								StringSimilarityFeatureExtractor.PARAM_STRING_TILING_MAX, "5",
								PairwiseFeatureWrapper.PARAM_TARGET_ANSWER_PREFIX, "TA",
								PairwiseFeatureWrapper.PARAM_PROMPTNAME, promptName,
								PairwiseFeatureWrapper.PARAM_NUMBER_TARGETANSWERS, numTargetAnswers,
								PairwiseFeatureWrapper.PARAM_AGGREGATION_METHOD, PairwiseFeatureWrapper.AggregationMethod.INDIVIDUAL_FEATURES,
								PairwiseFeatureWrapper.PARAM_ADDITIONAL_TEXTS_LOCATION, System.getenv("DKPRO_HOME")+"/processedData/"+corpusName
								),
						*/
						TcFeatureFactory.create(
								WordNGram.class,
								WordNGram.PARAM_NGRAM_MIN_N, 1,
								WordNGram.PARAM_NGRAM_MAX_N, 3,
								WordNGram.PARAM_NGRAM_USE_TOP_K, nGrams
								),
						TcFeatureFactory.create(
								CharacterNGram.class,
								CharacterNGram.PARAM_NGRAM_MIN_N, 2,
								CharacterNGram.PARAM_NGRAM_MAX_N, 5,
								CharacterNGram.PARAM_NGRAM_USE_TOP_K, nGrams
								)
						)
				);
		// create parameter space
		ParameterSpace pSpace = null;
		pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders),
				learningDims,
				Dimension.create(DIM_FEATURE_MODE, FM_UNIT), 
				dimFeatureSets,
				learningsArgsDims);
		System.out.println("Finished building parameter space.");
		return pSpace;
	}

}