/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.ClassificationHandlerCorpusAdapter2;
import com.aliasi.classify.Classified;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.ConditionalClassifier;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class LogisticRegressionClassifier<E>
implements Classifier<E, ConditionalClassification>,
ConditionalClassifier<E>,
Compilable,
Serializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = -400005337034204553L;
    /** Underlying regression model; its outcome count must equal the number of category symbols. */
    private final LogisticRegression mModel;
    /** Converts input objects into feature maps before vectorization. */
    private final FeatureExtractor<? super E> mFeatureExtractor;
    /** Flag forwarded to {@code Features.toVector} when feature maps are converted to vectors. */
    private final boolean mAddInterceptFeature;
    /** Maps feature names to the vector dimensions used by the model. */
    private final SymbolTable mFeatureSymbolTable;
    /** Category names indexed by the model's outcome identifier. */
    private final String[] mCategorySymbols;
    /** Feature name registered in the feature symbol table by {@code train} when an intercept is requested. */
    public static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";
    /** Zero-length array used as the type token for {@code List.toArray} in {@code DataExtractor.inputs()}. */
    static final Vector[] EMPTY_VECTOR_ARRAY = new Vector[0];

    /**
     * Builds a classifier from an estimated model and the symbol tables
     * that tie the model's dimensions and outcomes to names.
     *
     * @param model regression model used to compute conditional probabilities
     * @param featureExtractor converts inputs to feature maps
     * @param addInterceptFeature flag forwarded to {@code Features.toVector}
     * @param featureSymbolTable maps feature names to vector dimensions
     * @param categorySymbols category names indexed by model outcome
     * @throws IllegalArgumentException if the number of model outcomes differs
     *     from the number of category symbols, or if a category symbol is repeated
     */
    LogisticRegressionClassifier(LogisticRegression model, FeatureExtractor<? super E> featureExtractor, boolean addInterceptFeature, SymbolTable featureSymbolTable, String[] categorySymbols) {
        if (model.numOutcomes() != categorySymbols.length) {
            throw new IllegalArgumentException("Number of model outcomes must match category symbols length. Found model.numOutcomes()=" + model.numOutcomes() + " categorySymbols.length=" + categorySymbols.length);
        }
        // Set.add returns false on a repeat, which doubles as the duplicate check.
        HashSet<String> seenCategories = new HashSet<String>();
        int index = 0;
        while (index < categorySymbols.length) {
            String symbol = categorySymbols[index];
            if (!seenCategories.add(symbol)) {
                throw new IllegalArgumentException("Categories must be unique. Found duplicate category categorySymbols[" + index + "]=" + symbol);
            }
            ++index;
        }
        mModel = model;
        mFeatureExtractor = featureExtractor;
        mAddInterceptFeature = addInterceptFeature;
        mFeatureSymbolTable = featureSymbolTable;
        mCategorySymbols = categorySymbols;
    }

    /**
     * Returns a view of the feature symbol table produced by
     * {@code MapSymbolTable.unmodifiableView}.
     *
     * @return unmodifiable view of the feature symbol table
     */
    public SymbolTable featureSymbolTable() {
        SymbolTable readOnlyView = MapSymbolTable.unmodifiableView(mFeatureSymbolTable);
        return readOnlyView;
    }

    /**
     * Returns the category symbols in model-outcome order.
     *
     * <p>The returned list is a read-only view. {@code Arrays.asList} alone is a
     * write-through wrapper, so returning it directly would let a caller's
     * {@code set()} overwrite this classifier's private category array.
     *
     * @return unmodifiable list of category names, indexed by outcome identifier
     */
    public List<String> categorySymbols() {
        return Collections.unmodifiableList(Arrays.asList(this.mCategorySymbols));
    }

    /**
     * Returns the regression model backing this classifier.
     *
     * @return the underlying model instance (not a copy)
     */
    public LogisticRegression model() {
        return mModel;
    }

    /**
     * Reports the intercept flag this classifier was constructed with.
     *
     * @return the value forwarded to {@code Features.toVector} during classification
     */
    public boolean addInterceptFeature() {
        return mAddInterceptFeature;
    }

    /**
     * Returns a feature extractor that forwards every call to this
     * classifier's own extractor.
     *
     * <p>The wrapper exists because the stored extractor is typed
     * {@code FeatureExtractor<? super E>} while callers receive a
     * {@code FeatureExtractor<E>}.
     *
     * @return a delegating extractor over inputs of type {@code E}
     */
    public FeatureExtractor<E> featureExtractor() {
        // The field is final, so capturing it once is equivalent to reading it per call.
        final FeatureExtractor<? super E> delegate = mFeatureExtractor;
        return new FeatureExtractor<E>() {
            @Override
            public Map<String, ? extends Number> features(E in) {
                return delegate.features(in);
            }
        };
    }

    /**
     * Classifies an input that has already been converted to a vector.
     *
     * <p>Each probability returned by the model is paired with its category
     * symbol, the pairs are ordered with {@code ScoredObject.reverseComparator()}
     * (presumably highest score first), and the ordered categories and scores
     * are packaged into the result.
     *
     * @param v input vector in the model's feature space
     * @return the conditional classification over all categories
     */
    public ConditionalClassification classifyVector(Vector v) {
        double[] probs = mModel.classify(v);
        int numCategories = probs.length;

        ScoredObject[] ranked = new ScoredObject[numCategories];
        int outcome = 0;
        while (outcome < numCategories) {
            ranked[outcome] = new ScoredObject<String>(mCategorySymbols[outcome], probs[outcome]);
            ++outcome;
        }
        Arrays.sort(ranked, ScoredObject.reverseComparator());

        // The probability array is reused to hold the scores in ranked order.
        String[] rankedCategories = new String[numCategories];
        for (int rank = 0; rank < numCategories; ++rank) {
            ScoredObject entry = ranked[rank];
            rankedCategories[rank] = (String) entry.getObject();
            probs[rank] = entry.score();
        }
        return new ConditionalClassification(rankedCategories, probs);
    }

    /**
     * Classifies an input given as a feature map.
     *
     * <p>The map is converted to a vector over this classifier's feature
     * symbol table and then handed to {@link #classifyVector(Vector)}.
     *
     * @param featureMap feature names mapped to their values
     * @return the conditional classification over all categories
     */
    public ConditionalClassification classifyFeatures(Map<String, ? extends Number> featureMap) {
        int numDimensions = mFeatureSymbolTable.numSymbols();
        Vector featureVector = Features.toVector(featureMap, mFeatureSymbolTable, numDimensions, mAddInterceptFeature);
        return classifyVector(featureVector);
    }

    /**
     * Classifies an input object by extracting its features and delegating
     * to {@link #classifyFeatures(Map)}.
     *
     * @param in object to classify
     * @return the conditional classification over all categories
     */
    @Override
    public ConditionalClassification classify(E in) {
        Map<String, ? extends Number> extracted = mFeatureExtractor.features(in);
        return classifyFeatures(extracted);
    }

    /**
     * Writes this classifier to the given output by serializing an
     * {@code Externalizer} wrapped around it.
     *
     * @param objOut destination for the compiled form
     * @throws IOException if the underlying write fails
     */
    @Override
    public void compileTo(ObjectOutput objOut) throws IOException {
        Externalizer<E> standIn = new Externalizer<E>(this);
        objOut.writeObject(standIn);
    }

    /**
     * Looks up the outcome identifier of a category by linear scan.
     *
     * @param category category name to find
     * @return index of the category in the symbol array, or {@code -1} if absent
     */
    private int categoryToId(String category) {
        int id = 0;
        while (id < mCategorySymbols.length) {
            if (mCategorySymbols[id].equals(category)) {
                return id;
            }
            ++id;
        }
        return -1;
    }

    /**
     * Returns the model weights for one category, keyed by feature name.
     *
     * <p>The last category in the symbol array has no weight vector of its
     * own, so an empty map is returned for it.
     *
     * @param category category whose weights are requested
     * @return map from each feature symbol to its weight for the category
     * @throws IllegalArgumentException if the category is not known to this classifier
     */
    public ObjectToDoubleMap<String> featureValues(String category) {
        int categoryId = categoryToId(category);
        if (categoryId < 0) {
            throw new IllegalArgumentException("Unknown category=" + category);
        }
        ObjectToDoubleMap<String> weightsByFeature = new ObjectToDoubleMap<String>();
        boolean isLastCategory = categoryId == mCategorySymbols.length - 1;
        if (!isLastCategory) {
            int numFeatures = mFeatureSymbolTable.numSymbols();
            Vector categoryWeights = mModel.weightVectors()[categoryId];
            for (int featureId = 0; featureId < numFeatures; ++featureId) {
                String featureName = mFeatureSymbolTable.idToSymbol(featureId);
                weightsByFeature.set(featureName, categoryWeights.value(featureId));
            }
        }
        return weightsByFeature;
    }

    /**
     * Renders a human-readable dump of the classifier: the category and
     * feature counts, followed by the feature weights of every category
     * except the last one.
     *
     * @return multi-line textual description of the model parameters
     */
    @Override
    public String toString() {
        CharArrayWriter buffer = new CharArrayWriter();
        PrintWriter out = new PrintWriter(buffer);
        List<String> categories = categorySymbols();
        out.println("NUMBER OF CATEGORIES=" + categories.size());
        out.println("NUMBER OF FEATURES=" + mFeatureSymbolTable.numSymbols());

        // The final category is not listed; featureValues() yields no weights for it.
        int index = 0;
        while (index < categories.size() - 1) {
            String category = categories.get(index);
            out.println("\n  CATEGORY=" + category);
            ObjectToDoubleMap<String> weights = featureValues(category);
            for (String feature : weights.keysOrderedByValueList()) {
                out.printf("%20s %15.6f\n", feature, weights.get(feature));
            }
            ++index;
        }
        out.write('\n');
        return buffer.toString();
    }

    /**
     * Serialization hook: substitutes an {@code Externalizer} for this
     * classifier when it is written with Java serialization.
     *
     * @return the object to serialize in place of this instance
     */
    private Object writeReplace() {
        Externalizer<E> standIn = new Externalizer<E>(this);
        return standIn;
    }

    /**
     * Trains a classifier, reporting progress to a writer.
     *
     * <p>A {@code null} writer selects a silent reporter; otherwise a writer
     * reporter at {@code LogLevel.DEBUG} is used. Training is delegated to the
     * {@link Reporter}-based overload.
     *
     * @param featureExtractor converts inputs to feature maps
     * @param corpus source of classified training data
     * @param minFeatureCount pruning threshold applied to feature counts
     * @param addInterceptFeature whether an intercept feature is added
     * @param prior regression prior passed to the estimator
     * @param annealingSchedule learning-rate schedule passed to the estimator
     * @param minImprovement convergence threshold passed to the estimator
     * @param minEpochs minimum number of epochs passed to the estimator
     * @param maxEpochs maximum number of epochs passed to the estimator
     * @param progressWriter destination for progress messages, or {@code null} for none
     * @return the trained classifier
     * @throws IOException if reading the corpus fails
     * @deprecated Prefer the {@code train} overloads that take a corpus of
     *     {@code ObjectHandler<Classified<F>>} and a {@link Reporter}.
     */
    @Deprecated
    public static <F> LogisticRegressionClassifier<F> train(FeatureExtractor<? super F> featureExtractor, Corpus<ClassificationHandler<F, Classification>> corpus, int minFeatureCount, boolean addInterceptFeature, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, PrintWriter progressWriter) throws IOException {
        Reporter reporter;
        if (progressWriter != null) {
            reporter = Reporters.writer(progressWriter).setLevel(LogLevel.DEBUG);
        } else {
            reporter = Reporters.silent();
        }
        return LogisticRegressionClassifier.train(featureExtractor, corpus, minFeatureCount, addInterceptFeature, prior, annealingSchedule, reporter, minImprovement, minEpochs, maxEpochs);
    }

    /**
     * Trains a classifier from a corpus of classification handlers.
     *
     * <p>The corpus is wrapped in a {@code ClassificationHandlerCorpusAdapter2}
     * and training is delegated to the overload that takes a corpus of
     * {@code ObjectHandler<Classified<F>>}.
     *
     * @param featureExtractor converts inputs to feature maps
     * @param corpus source of classified training data
     * @param minFeatureCount pruning threshold applied to feature counts
     * @param addInterceptFeature whether an intercept feature is added
     * @param prior regression prior passed to the estimator
     * @param annealingSchedule learning-rate schedule passed to the estimator
     * @param reporter destination for progress messages
     * @param minImprovement convergence threshold passed to the estimator
     * @param minEpochs minimum number of epochs passed to the estimator
     * @param maxEpochs maximum number of epochs passed to the estimator
     * @return the trained classifier
     * @throws IOException if reading the corpus fails
     * @deprecated Prefer the {@code train} overloads that take a corpus of
     *     {@code ObjectHandler<Classified<F>>} directly.
     */
    @Deprecated
    public static <F> LogisticRegressionClassifier<F> train(FeatureExtractor<? super F> featureExtractor, Corpus<ClassificationHandler<F, Classification>> corpus, int minFeatureCount, boolean addInterceptFeature, RegressionPrior prior, AnnealingSchedule annealingSchedule, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) throws IOException {
        // Adapt the handler-style corpus to the object-handler form the newer API expects.
        ClassificationHandlerCorpusAdapter2 adaptedCorpus = new ClassificationHandlerCorpusAdapter2(corpus);
        return LogisticRegressionClassifier.train(adaptedCorpus, featureExtractor, minFeatureCount, addInterceptFeature, prior, annealingSchedule, minImprovement, minEpochs, maxEpochs, reporter);
    }

    /**
     * Trains a classifier using default settings for the options this
     * overload does not expose: no hot start, no per-epoch classifier
     * handler, a rolling-average size of 5, and a prior block size of
     * {@code -1}, which the full {@code train} overload replaces with a
     * value derived from the number of training instances.
     *
     * @param corpus source of classified training data
     * @param featureExtractor converts inputs to feature maps
     * @param minFeatureCount pruning threshold applied to feature counts
     * @param addInterceptFeature whether an intercept feature is added
     * @param prior regression prior passed to the estimator
     * @param annealingSchedule learning-rate schedule passed to the estimator
     * @param minImprovement convergence threshold passed to the estimator
     * @param minEpochs minimum number of epochs passed to the estimator
     * @param maxEpochs maximum number of epochs passed to the estimator
     * @param reporter destination for progress messages, may be {@code null}
     * @return the trained classifier
     * @throws IOException if reading the corpus fails
     */
    public static <F> LogisticRegressionClassifier<F> train(Corpus<ObjectHandler<Classified<F>>> corpus, FeatureExtractor<? super F> featureExtractor, int minFeatureCount, boolean addInterceptFeature, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, Reporter reporter) throws IOException {
        final int defaultPriorBlockSize = -1;
        final int rollingAverageSize = 5;
        final LogisticRegressionClassifier<F> noHotStart = null;
        final ObjectHandler<LogisticRegressionClassifier<F>> noClassifierHandler = null;
        return LogisticRegressionClassifier.train(corpus, featureExtractor, minFeatureCount, addInterceptFeature, prior, defaultPriorBlockSize, noHotStart, annealingSchedule, minImprovement, rollingAverageSize, minEpochs, maxEpochs, noClassifierHandler, reporter);
    }

    /**
     * Trains a logistic regression classifier from a corpus.
     *
     * <p>The corpus is visited twice: once to count features, so that features
     * can be pruned with {@code minFeatureCount}, and once to convert every
     * training instance into an input vector and a category identifier. The
     * resulting arrays are passed to {@code LogisticRegression.estimate}.
     *
     * @param corpus source of classified training data
     * @param featureExtractor converts inputs to feature maps
     * @param minFeatureCount threshold handed to {@code ObjectToCounterMap.prune}
     * @param addInterceptFeature whether {@link #INTERCEPT_FEATURE_NAME} is
     *     registered as a feature and the flag is forwarded to vectorization
     * @param prior regression prior passed to the estimator
     * @param priorBlockSize prior block size passed to the estimator; {@code -1}
     *     selects {@code max(1, numTrainingInstances / 50)}
     * @param hotStart classifier whose weights seed the estimation, or {@code null}
     *     for a cold start
     * @param annealingSchedule learning-rate schedule passed to the estimator
     * @param minImprovement convergence threshold passed to the estimator
     * @param rollingAverageSize rolling-average window passed to the estimator
     * @param minEpochs minimum number of epochs passed to the estimator
     * @param maxEpochs maximum number of epochs passed to the estimator
     * @param classifierHandler receives a classifier built from each model the
     *     estimator reports, or {@code null} for no callback
     * @param reporter destination for progress messages; {@code null} means silent
     * @return the trained classifier
     * @throws IOException if reading the corpus fails
     * @throws IllegalArgumentException if the corpus yields no training instances
     */
    public static <F> LogisticRegressionClassifier<F> train(Corpus<ObjectHandler<Classified<F>>> corpus, FeatureExtractor<? super F> featureExtractor, int minFeatureCount, boolean addInterceptFeature, RegressionPrior prior, int priorBlockSize, LogisticRegressionClassifier<F> hotStart, AnnealingSchedule annealingSchedule, double minImprovement, int rollingAverageSize, int minEpochs, int maxEpochs, ObjectHandler<LogisticRegressionClassifier<F>> classifierHandler, Reporter reporter) throws IOException {
        MapSymbolTable featureSymbolTable = new MapSymbolTable();
        MapSymbolTable categorySymbolTable = new MapSymbolTable();
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        if (addInterceptFeature) {
            // Registered first so the intercept is the first symbol added to the table.
            featureSymbolTable.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }
        reporter.info("Feature Extractor class=" + featureExtractor.getClass());
        reporter.info("min feature count=" + minFeatureCount);
        reporter.info("Extracting Training Data");

        // Pass 1: count features so rare ones can be pruned before vectorization.
        reporter.debug("  Counting features");
        ObjectToCounterMap<String> featureCounter = new ObjectToCounterMap<String>();
        corpus.visitTrain(new FeatureCounter<F>(featureExtractor, featureCounter));
        reporter.debug("  Pruning features");
        featureCounter.prune(minFeatureCount);
        for (String feature : featureCounter.keySet()) {
            featureSymbolTable.getOrAddSymbol(feature);
        }

        // Pass 2: turn every training instance into (vector, category id).
        reporter.debug("  Extracting vectors");
        DataExtractor<F> dataExtractor = new DataExtractor<F>(featureExtractor, featureSymbolTable, categorySymbolTable, addInterceptFeature, featureSymbolTable.numSymbols());
        corpus.visitTrain(dataExtractor);
        Vector[] inputs = dataExtractor.inputs();
        int[] categories = dataExtractor.categories();
        if (inputs.length == 0) {
            // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on inputs[0].
            throw new IllegalArgumentException("Training corpus must contain at least one training instance.");
        }
        int numInputDimensions = inputs[0].numDimensions();
        String[] categorySymbols = new String[categorySymbolTable.numSymbols()];
        for (int i = 0; i < categorySymbols.length; ++i) {
            categorySymbols[i] = categorySymbolTable.idToSymbol(i);
        }

        LogisticRegression lrHotStart = null;
        if (hotStart != null) {
            reporter.debug("hot starting");
            lrHotStart = hotStartModel(hotStart, categorySymbols, featureSymbolTable, numInputDimensions);
        }
        reporter.info(hotStart != null ? "Hot start" : "Cold start");

        RegressionHandlerAdapter<F> regressionHandler = classifierHandler == null ? null : new RegressionHandlerAdapter<F>(classifierHandler, featureExtractor, addInterceptFeature, featureSymbolTable, categorySymbols);
        reporter.info(regressionHandler != null ? "Regression callback handler class=" + regressionHandler.getClass() : "Regression callback handler=" + null);
        if (priorBlockSize == -1) {
            priorBlockSize = Math.max(1, categories.length / 50);
        }
        LogisticRegression model = LogisticRegression.estimate(inputs, categories, prior, priorBlockSize, lrHotStart, annealingSchedule, minImprovement, rollingAverageSize, minEpochs, maxEpochs, regressionHandler, reporter);
        return new LogisticRegressionClassifier<F>(model, featureExtractor, addInterceptFeature, featureSymbolTable, categorySymbols);
    }

    /**
     * Rebuilds a hot-start classifier's weights in the feature and category
     * numbering of the current training run.
     *
     * <p>One weight vector is created per category except the last. For each
     * of those categories that the hot-start classifier also knows, weights are
     * copied feature by feature, matched by name; other vectors stay as created.
     * The value used for a feature the hot-start classifier lacks is whatever
     * {@code ObjectToDoubleMap.getValue} returns for a missing key (presumably
     * 0.0 - confirm against that class).
     */
    private static <F> LogisticRegression hotStartModel(LogisticRegressionClassifier<F> hotStart, String[] categorySymbols, SymbolTable featureSymbolTable, int numInputDimensions) {
        HashSet<String> hotStartCategorySet = new HashSet<String>(hotStart.categorySymbols());
        Vector[] weightVectors = new Vector[categorySymbols.length - 1];
        for (int k = 0; k < weightVectors.length; ++k) {
            weightVectors[k] = new DenseVector(numInputDimensions);
        }
        // Every weight vector is visited; stopping one short would leave the
        // last one without its hot-start weights.
        for (int k = 0; k < weightVectors.length; ++k) {
            String category = categorySymbols[k];
            if (!hotStartCategorySet.contains(category)) {
                continue;
            }
            ObjectToDoubleMap<String> featureVector = hotStart.featureValues(category);
            for (int i = 0; i < numInputDimensions; ++i) {
                String feature = featureSymbolTable.idToSymbol(i);
                weightVectors[k].setValue(i, featureVector.getValue(feature));
            }
        }
        return new LogisticRegression(weightVectors);
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Corpus handler that converts each classified training instance into an
     * input vector and a category identifier, accumulating both in visit order.
     *
     * @param <F> type of the objects being classified
     */
    static class DataExtractor<F>
    implements ObjectHandler<Classified<F>> {
        final FeatureExtractor<? super F> mFeatureExtractor;
        final SymbolTable mFeatureSymbolTable;
        final SymbolTable mCategorySymbolTable;
        final boolean mAddInterceptFeature;
        final int mNumSymbols;
        final List<Vector> mInputVectorList = new ArrayList<Vector>();
        final List<Integer> mOutputCategoryList = new ArrayList<Integer>();

        /**
         * @param featureExtractor converts inputs to feature maps
         * @param featureSymbolTable maps feature names to vector dimensions
         * @param categorySymbolTable table to which category names are added as they are seen
         * @param addInterceptFeature flag forwarded to {@code Features.toVector}
         * @param numSymbols number of dimensions forwarded to {@code Features.toVector}
         */
        DataExtractor(FeatureExtractor<? super F> featureExtractor, SymbolTable featureSymbolTable, SymbolTable categorySymbolTable, boolean addInterceptFeature, int numSymbols) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureSymbolTable = featureSymbolTable;
            this.mCategorySymbolTable = categorySymbolTable;
            this.mAddInterceptFeature = addInterceptFeature;
            this.mNumSymbols = numSymbols;
        }

        /**
         * Records one training instance: its best category (registered in the
         * category symbol table if new) and its feature vector.
         *
         * @param classified the training instance and its classification
         */
        @Override
        public void handle(Classified<F> classified) {
            F input = classified.getObject();
            Classification output = classified.getClassification();
            String outputCategoryName = output.bestCategory();
            Integer outputCategoryId = this.mCategorySymbolTable.getOrAddSymbol(outputCategoryName);
            // features() returns a wildcard-valued map; Map<String, Number> would not compile.
            Map<String, ? extends Number> featureMap = this.mFeatureExtractor.features(input);
            Vector vector = Features.toVector(featureMap, this.mFeatureSymbolTable, this.mNumSymbols, this.mAddInterceptFeature);
            this.mInputVectorList.add(vector);
            this.mOutputCategoryList.add(outputCategoryId);
        }

        /**
         * Returns the category identifiers of all handled instances, in visit order.
         *
         * @return a new array parallel to {@link #inputs()}
         */
        int[] categories() {
            int[] categoryIds = new int[this.mOutputCategoryList.size()];
            for (int i = 0; i < categoryIds.length; ++i) {
                categoryIds[i] = this.mOutputCategoryList.get(i);
            }
            return categoryIds;
        }

        /**
         * Returns the input vectors of all handled instances, in visit order.
         *
         * @return a new array parallel to {@link #categories()}
         */
        Vector[] inputs() {
            return this.mInputVectorList.toArray(EMPTY_VECTOR_ARRAY);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Stand-in object written in place of a classifier by {@code compileTo}
     * and {@code writeReplace}. It writes the classifier's components and, on
     * read, assembles a new classifier from them.
     *
     * @param <G> type of the objects being classified
     */
    static class Externalizer<G>
    extends AbstractExternalizable {
        static final long serialVersionUID = -2003123148721825458L;
        final LogisticRegressionClassifier<G> mClassifier;

        /** No-argument constructor, leaving the classifier reference {@code null}. */
        public Externalizer() {
            this(null);
        }

        /**
         * @param classifier classifier whose state will be written
         */
        public Externalizer(LogisticRegressionClassifier<G> classifier) {
            mClassifier = classifier;
        }

        /**
         * Writes, in order: the model, the feature extractor, the intercept
         * flag, the feature symbol table, the number of categories, and each
         * category symbol as UTF.
         *
         * @param objOut destination stream
         * @throws IOException if any write fails
         */
        @Override
        public void writeExternal(ObjectOutput objOut) throws IOException {
            final LogisticRegressionClassifier<G> source = mClassifier;
            objOut.writeObject(source.mModel);
            objOut.writeObject(source.mFeatureExtractor);
            objOut.writeBoolean(source.mAddInterceptFeature);
            objOut.writeObject(source.mFeatureSymbolTable);
            String[] categories = source.mCategorySymbols;
            objOut.writeInt(categories.length);
            for (String category : categories) {
                objOut.writeUTF(category);
            }
        }

        /**
         * Reads the components in the order {@link #writeExternal} wrote them
         * and builds a classifier from them.
         *
         * @param objIn source stream
         * @return the reconstructed classifier
         * @throws IOException if any read fails
         * @throws ClassNotFoundException if a serialized component's class is unavailable
         */
        @Override
        public Object read(ObjectInput objIn) throws IOException, ClassNotFoundException {
            LogisticRegression model = (LogisticRegression) objIn.readObject();
            // The extractor's type argument is erased in the stream, so this cast cannot be checked.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super G> featureExtractor = (FeatureExtractor<? super G>) objIn.readObject();
            boolean addInterceptFeature = objIn.readBoolean();
            SymbolTable featureSymbolTable = (SymbolTable) objIn.readObject();
            int numCategories = objIn.readInt();
            String[] categorySymbols = new String[numCategories];
            int index = 0;
            while (index < numCategories) {
                categorySymbols[index] = objIn.readUTF();
                ++index;
            }
            return new LogisticRegressionClassifier<G>(model, featureExtractor, addInterceptFeature, featureSymbolTable, categorySymbols);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Corpus handler that increments a shared counter once for every feature
     * name present in each training instance's feature map.
     *
     * @param <H> type of the objects being classified
     */
    static class FeatureCounter<H>
    implements ObjectHandler<Classified<H>> {
        private final FeatureExtractor<? super H> mFeatureExtractor;
        private final ObjectToCounterMap<String> mFeatureCounter;

        /**
         * @param featureExtractor converts inputs to feature maps
         * @param featureCounter counter updated by {@link #handle}
         */
        FeatureCounter(FeatureExtractor<? super H> featureExtractor, ObjectToCounterMap<String> featureCounter) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureCounter = featureCounter;
        }

        /**
         * Increments the count of each feature name extracted from the instance.
         * Feature values are ignored; only the presence of a key is counted.
         *
         * @param classified the training instance and its classification
         */
        @Override
        public void handle(Classified<H> classified) {
            H h = classified.getObject();
            // features() returns a wildcard-valued map; Map<String, Number> would not compile.
            Map<String, ? extends Number> featureMap = this.mFeatureExtractor.features(h);
            for (String feature : featureMap.keySet()) {
                this.mFeatureCounter.increment(feature);
            }
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Adapts a handler of classifiers to a handler of regression models:
     * every model received is wrapped in a new classifier, built with the
     * extractor, intercept flag and symbol tables supplied at construction,
     * and passed on to the classifier handler.
     *
     * @param <F> type of the objects being classified
     */
    static class RegressionHandlerAdapter<F>
    implements ObjectHandler<LogisticRegression> {
        private final ObjectHandler<LogisticRegressionClassifier<F>> mClassifierHandler;
        private final FeatureExtractor<? super F> mFeatureExtractor;
        private final boolean mAddInterceptFeature;
        private final SymbolTable mFeatureSymbolTable;
        private final String[] mCategorySymbols;

        /**
         * @param handler receives a classifier for each model handled
         * @param featureExtractor extractor given to each classifier
         * @param addInterceptFeature intercept flag given to each classifier
         * @param featureSymbolTable feature symbol table given to each classifier
         * @param categorySymbols category symbols given to each classifier
         */
        public RegressionHandlerAdapter(ObjectHandler<LogisticRegressionClassifier<F>> handler, FeatureExtractor<? super F> featureExtractor, boolean addInterceptFeature, SymbolTable featureSymbolTable, String[] categorySymbols) {
            mClassifierHandler = handler;
            mFeatureExtractor = featureExtractor;
            mAddInterceptFeature = addInterceptFeature;
            mFeatureSymbolTable = featureSymbolTable;
            mCategorySymbols = categorySymbols;
        }

        /**
         * Wraps the model in a classifier and forwards it to the classifier handler.
         *
         * @param regressionModel model reported by the estimator
         */
        @Override
        public void handle(LogisticRegression regressionModel) {
            LogisticRegressionClassifier<F> snapshot = new LogisticRegressionClassifier<F>(regressionModel, mFeatureExtractor, mAddInterceptFeature, mFeatureSymbolTable, mCategorySymbols);
            mClassifierHandler.handle(snapshot);
        }
    }
}

