/*
 * Decompiled with CFR 0.152.
 */
package org.openhab.ui.habot.nlp.internal;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.regex.Pattern;
import opennlp.tools.doccat.DoccatFactory;
import opennlp.tools.doccat.DoccatModel;
import opennlp.tools.doccat.DocumentCategorizerME;
import opennlp.tools.doccat.DocumentSample;
import opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.NameSample;
import opennlp.tools.namefind.NameSampleDataStream;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WhitespaceTokenizer;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ObjectStreamUtils;
import opennlp.tools.util.Span;
import opennlp.tools.util.TrainingParameters;
import org.openhab.ui.habot.nlp.Intent;
import org.openhab.ui.habot.nlp.Skill;
import org.openhab.ui.habot.nlp.UnsupportedLanguageException;
import org.openhab.ui.habot.nlp.internal.AlphaNumericTokenizer;
import org.openhab.ui.habot.nlp.internal.IntentDocumentSampleStream;
import org.openhab.ui.habot.nlp.internal.LowerCasePlainTextByLineStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains and applies the OpenNLP models used to interpret natural-language queries.
 * <p>
 * On construction, two models are trained from the per-language training data of every
 * {@link Skill} that provides it:
 * <ul>
 * <li>a document categorizer that maps a query to an intent id;</li>
 * <li>a token name finder that extracts typed entities from the query.</li>
 * </ul>
 * {@link #interpret(String)} then applies both models to build an {@link Intent}.
 * <p>
 * NOTE(review): the categorizer and name finder instances are shared by all calls and are
 * mutated by {@code find}; presumably one instance should not be used from several threads at
 * once — confirm with callers.
 */
public class IntentTrainer {
    /** Tokenizer id selecting {@link AlphaNumericTokenizer}; any other id selects whitespace tokenization. */
    private static final String ALPHANUMERIC_TOKENIZER_ID = "alphanumeric";

    /** Sentence punctuation (optionally preceded by whitespace) at the end of the last token. */
    private static final Pattern TRAILING_PUNCTUATION = Pattern.compile("\\s*[!?.]+$");

    private final Logger logger = LoggerFactory.getLogger(IntentTrainer.class);
    private final DocumentCategorizerME categorizer;
    private final NameFinderME nameFinder;
    private final Tokenizer tokenizer;

    /**
     * Creates a trainer that uses whitespace tokenization and no additional name samples.
     *
     * @param language the language code to train for
     * @param skills the skills whose training data is used
     * @throws Exception {@link UnsupportedLanguageException} if no skill has training data for
     *             {@code language}, or any error raised while reading data or training
     */
    public IntentTrainer(String language, Collection<Skill> skills) throws Exception {
        this(language, skills, null, null);
    }

    /**
     * Creates a trainer and trains its categorizer and name-finder models.
     *
     * @param language the language code to train for
     * @param skills the skills whose training data is used; skills without data for
     *            {@code language} are skipped
     * @param additionalNameSamples optional extra name-finder training lines (may be {@code null})
     * @param tokenizerId {@code "alphanumeric"} to use {@link AlphaNumericTokenizer}; any other
     *            value (including {@code null}) selects {@link WhitespaceTokenizer}
     * @throws Exception {@link UnsupportedLanguageException} if no skill has training data for
     *             {@code language}, or any error raised while reading data or training
     */
    public IntentTrainer(String language, Collection<Skill> skills, InputStream additionalNameSamples,
            String tokenizerId) throws Exception {
        // Compare by value: '==' only matches when the caller happens to pass an interned literal.
        this.tokenizer = ALPHANUMERIC_TOKENIZER_ID.equals(tokenizerId) ? AlphaNumericTokenizer.INSTANCE
                : WhitespaceTokenizer.INSTANCE;

        TrainingParameters trainingParams = TrainingParameters.defaultParams();
        trainingParams.put("PrintMessages", false);

        DoccatModel doccatModel = trainCategorizer(language, skills, trainingParams);
        TokenNameFinderModel tokenNameFinderModel = trainNameFinder(language, skills, additionalNameSamples,
                trainingParams);

        this.categorizer = new DocumentCategorizerME(doccatModel);
        this.nameFinder = new NameFinderME(tokenNameFinderModel);
    }

    /**
     * Trains the intent categorizer. Each skill's training lines are labelled with its intent id.
     *
     * @throws UnsupportedLanguageException if no skill has training data for {@code language}
     */
    private DoccatModel trainCategorizer(String language, Collection<Skill> skills,
            TrainingParameters trainingParams) throws Exception {
        List<ObjectStream<DocumentSample>> categoryStreams = new ArrayList<>();
        for (Skill skill : skills) {
            String intent = skill.getIntentId();
            try {
                InputStream trainingData = skill.getTrainingData(language);
                if (trainingData == null) {
                    throw new UnsupportedLanguageException(language);
                }
                LowerCasePlainTextByLineStream lineStream = new LowerCasePlainTextByLineStream(trainingData);
                categoryStreams.add(new IntentDocumentSampleStream(intent, lineStream));
            } catch (UnsupportedLanguageException e) {
                logger.warn("Ignoring intent {} because no training data for language {}", skill.getIntentId(),
                        language);
            }
        }
        if (categoryStreams.isEmpty()) {
            throw new UnsupportedLanguageException(language);
        }

        ObjectStream<DocumentSample> samples = ObjectStreamUtils.concatenateObjectStream(categoryStreams);
        try {
            return DocumentCategorizerME.train(language, samples, trainingParams, new DoccatFactory());
        } finally {
            // Release the underlying training streams even when training fails.
            samples.close();
        }
    }

    /**
     * Trains the entity name finder from every skill's training data plus the optional
     * {@code additionalNameSamples}.
     */
    private TokenNameFinderModel trainNameFinder(String language, Collection<Skill> skills,
            InputStream additionalNameSamples, TrainingParameters trainingParams) throws Exception {
        List<ObjectStream<NameSample>> nameStreams = new ArrayList<>();
        for (Skill skill : skills) {
            try {
                InputStream trainingData = skill.getTrainingData(language);
                if (trainingData == null) {
                    throw new UnsupportedLanguageException(language);
                }
                nameStreams.add(new NameSampleDataStream(new LowerCasePlainTextByLineStream(trainingData)));
            } catch (UnsupportedLanguageException e) {
                // Presumably already warned about while training the categorizer; avoid duplicate warnings.
                logger.debug("Ignoring intent {} because no training data for language {}", skill.getIntentId(),
                        language);
            }
        }
        if (additionalNameSamples != null) {
            nameStreams.add(new NameSampleDataStream(new LowerCasePlainTextByLineStream(additionalNameSamples)));
        }

        ObjectStream<NameSample> samples = ObjectStreamUtils.concatenateObjectStream(nameStreams);
        try {
            return NameFinderME.train(language, null, samples, trainingParams, new TokenNameFinderFactory());
        } finally {
            // Release the underlying training streams even when training fails.
            samples.close();
        }
    }

    /**
     * Lower-cases and tokenizes {@code query}, then strips trailing sentence punctuation
     * ({@code ! ? .}) from the last token. An empty token array is returned unchanged.
     * <p>
     * NOTE(review): {@code toLowerCase()} uses the default locale; confirm this matches how
     * {@code LowerCasePlainTextByLineStream} lower-cases the training data.
     */
    private String[] tokenize(String query) {
        String[] tokens = tokenizer.tokenize(query.toLowerCase());
        if (tokens.length > 0) {
            int last = tokens.length - 1;
            tokens[last] = TRAILING_PUNCTUATION.matcher(tokens[last]).replaceAll("");
        }
        return tokens;
    }

    /**
     * Interprets a natural-language query.
     *
     * @param query the user's query (must not be {@code null})
     * @return an {@link Intent} named after the best-scoring category. Its entities map each
     *         detected entity type to the matched text.
     */
    public Intent interpret(String query) {
        String[] tokens = tokenize(query);

        double[] outcome = categorizer.categorize(tokens);
        logger.debug("{}", categorizer.getAllResults(outcome));
        Intent intent = new Intent(categorizer.getBestCategory(outcome));

        Span[] spans = nameFinder.find(tokens);
        // Each query is an independent document. Drop adaptive data so that earlier queries
        // do not bias entity detection for later ones.
        nameFinder.clearAdaptiveData();
        String[] names = Span.spansToStrings(spans, tokens);
        for (int i = 0; i < spans.length; i++) {
            intent.getEntities().put(spans[i].getType(), names[i]);
        }

        logger.debug("{}", intent);
        return intent;
    }

    /**
     * Scores {@code query} against every intent category. It uses the same tokenization as
     * {@link #interpret(String)}.
     *
     * @param query the user's query (must not be {@code null})
     * @return the categorizer's scores, mapped to the categories that received them
     */
    public SortedMap<Double, Set<String>> getScoreMap(String query) {
        return categorizer.sortedScoreMap(tokenize(query));
    }
}

