/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.classify.ScoredClassifier;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class TfIdfClassifierTrainer<E>
implements ObjectHandler<Classified<E>>,
Compilable,
Serializable {
    /** Serialization version identifier for this trainer class. */
    static final long serialVersionUID = -2793388723202924633L;
    /** Converts each input object into a feature-name-to-value map. */
    final FeatureExtractor<? super E> mFeatureExtractor;
    /** For each feature id, the accumulated feature value keyed by category id. */
    final Map<Integer, ObjectToDoubleMap<Integer>> mFeatureToCategoryCount;
    /** Symbol table assigning integer ids to feature names. */
    final MapSymbolTable mFeatureSymbolTable;
    /** Symbol table assigning integer ids to category names. */
    final MapSymbolTable mCategorySymbolTable;

    /**
     * Constructs a trainer that uses the specified feature extractor and
     * starts from a newly created count map and newly created feature and
     * category symbol tables.
     *
     * @param featureExtractor extractor applied to every training input
     */
    public TfIdfClassifierTrainer(FeatureExtractor<? super E> featureExtractor) {
        this(featureExtractor, new HashMap<Integer, ObjectToDoubleMap<Integer>>(), new MapSymbolTable(), new MapSymbolTable());
    }

    /**
     * Constructs a trainer directly from its four state components. The
     * arguments are stored as given, without copying. Used by the public
     * constructor and by {@code Serializer.read} when restoring a trainer.
     *
     * @param featureExtractor extractor applied to every training input
     * @param featureToCategoryCount per-feature map from category id to accumulated value
     * @param featureSymbolTable symbol table for feature names
     * @param categorySymbolTable symbol table for category names
     */
    TfIdfClassifierTrainer(FeatureExtractor<? super E> featureExtractor, Map<Integer, ObjectToDoubleMap<Integer>> featureToCategoryCount, MapSymbolTable featureSymbolTable, MapSymbolTable categorySymbolTable) {
        this.mFeatureExtractor = featureExtractor;
        this.mFeatureToCategoryCount = featureToCategoryCount;
        this.mFeatureSymbolTable = featureSymbolTable;
        this.mCategorySymbolTable = categorySymbolTable;
    }

    /**
     * Returns the inverse document frequency of the specified feature, where
     * each category plays the role of a document.
     *
     * <p>The value is {@code log(numCategories / numCategoriesWithFeature)},
     * with both counts taken from the training data seen so far.
     *
     * @param feature feature name to look up
     * @return the feature's IDF, or {@code 0.0} if the feature is not in the
     *     feature symbol table
     */
    public double idf(String feature) {
        Integer featureKey = this.mFeatureSymbolTable.symbolToIDInteger(feature);
        if (featureKey != null) {
            // Number of categories holding an entry for this feature.
            int categoriesWithFeature = this.mFeatureToCategoryCount.get(featureKey).size();
            int totalCategories = this.mCategorySymbolTable.numSymbols();
            return TfIdfClassifierTrainer.idf(categoriesWithFeature, totalCategories);
        }
        return 0.0;
    }

    /**
     * Returns the (unnormalized) TF-IDF weight of the specified feature in
     * the specified category, i.e. {@code tf(count) * idf(feature)}.
     *
     * @param feature feature name
     * @param category category name
     * @return the TF-IDF weight, or {@code 0.0} if the feature is unknown,
     *     the category is unknown, or the stored count for the pair is zero
     */
    public double tfIdf(String feature, String category) {
        Integer featureKey = this.mFeatureSymbolTable.symbolToIDInteger(feature);
        if (featureKey == null) {
            return 0.0;
        }
        ObjectToDoubleMap<Integer> perCategory = this.mFeatureToCategoryCount.get(featureKey);
        Integer categoryKey = this.mCategorySymbolTable.symbolToIDInteger(category);
        // An unknown category is treated exactly like a zero count.
        double rawCount = categoryKey == null ? 0.0 : perCategory.getValue(categoryKey);
        if (rawCount == 0.0) {
            return 0.0;
        }
        double inverseDocFreq = TfIdfClassifierTrainer.idf(
                perCategory.size(), this.mCategorySymbolTable.numSymbols());
        return TfIdfClassifierTrainer.tf(rawCount) * inverseDocFreq;
    }

    /**
     * Returns the term-frequency weight of the specified feature in the
     * specified category, which is the square root of the stored count.
     *
     * @param feature feature name
     * @param category category name
     * @return the TF weight, or {@code 0.0} if either the feature or the
     *     category is unknown
     */
    public double tf(String feature, String category) {
        Integer featureKey = this.mFeatureSymbolTable.symbolToIDInteger(feature);
        if (featureKey == null) {
            return 0.0;
        }
        ObjectToDoubleMap<Integer> perCategory = this.mFeatureToCategoryCount.get(featureKey);
        Integer categoryKey = this.mCategorySymbolTable.symbolToIDInteger(category);
        return categoryKey == null
                ? 0.0
                : TfIdfClassifierTrainer.tf(perCategory.getValue(categoryKey));
    }

    /**
     * Returns the feature extractor this trainer was constructed with.
     *
     * @return the feature extractor
     */
    public FeatureExtractor<? super E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    /**
     * Returns the category names currently held in the category symbol
     * table, as reported by its {@code symbolSet()} method.
     *
     * @return the set of category names
     */
    public Set<String> categories() {
        return this.mCategorySymbolTable.symbolSet();
    }

    /**
     * Adds one training instance to the model.
     *
     * <p>The classification's best category is registered in the category
     * symbol table, features are extracted from the input, and for every
     * extracted feature the per-category count is incremented by the
     * feature's value.
     *
     * @param input object to extract features from
     * @param classification classification whose best category labels the input
     */
    void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        int categoryId = this.mCategorySymbolTable.getOrAddSymbol(category);
        // The wildcard accepts any Number subtype in the extractor's result.
        Map<String, ? extends Number> featureVector = this.mFeatureExtractor.features(input);
        for (Map.Entry<String, ? extends Number> entry : featureVector.entrySet()) {
            String feature = entry.getKey();
            double value = entry.getValue().doubleValue();
            int featureId = this.mFeatureSymbolTable.getOrAddSymbol(feature);
            ObjectToDoubleMap<Integer> categoryCounts = this.mFeatureToCategoryCount.get(featureId);
            if (categoryCounts == null) {
                // First time this feature is seen: create its (typed) count map.
                categoryCounts = new ObjectToDoubleMap<Integer>();
                this.mFeatureToCategoryCount.put(featureId, categoryCounts);
            }
            categoryCounts.increment(categoryId, value);
        }
    }

    /**
     * Trains on a classified object by unpacking it into its object and
     * classification and delegating to the two-argument handler.
     *
     * @param classified training object paired with its classification
     */
    @Override
    public void handle(Classified<E> classified) {
        E trainingInput = classified.getObject();
        Classification label = classified.getClassification();
        this.handle(trainingInput, label);
    }

    /**
     * Writes a compiled form of this trainer to the specified output by
     * serializing an {@code Externalizer} wrapped around it. The
     * externalizer's {@code read} method builds a {@code TfIdfClassifier}.
     *
     * @param out object output to write to
     * @throws IOException if the underlying write fails
     */
    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        Externalizer<E> compiledForm = new Externalizer<E>(this);
        out.writeObject(compiledForm);
    }

    /**
     * Java serialization hook: substitutes a {@code Serializer} wrapping
     * this trainer as the object actually written to the stream.
     *
     * @return the serialization proxy for this trainer
     */
    Object writeReplace() {
        return new Serializer(this);
    }

    /**
     * Computes inverse document frequency as the natural logarithm of
     * {@code numDocs / docFrequency}.
     *
     * @param docFrequency number of documents containing the feature
     * @param numDocs total number of documents
     * @return {@code Math.log(numDocs / docFrequency)}
     */
    static double idf(double docFrequency, double numDocs) {
        double inverseFrequency = numDocs / docFrequency;
        return Math.log(inverseFrequency);
    }

    /**
     * Computes the term-frequency weight of a raw count, defined as its
     * square root.
     *
     * @param count raw feature count
     * @return {@code Math.sqrt(count)}
     */
    static double tf(double count) {
        double dampened = Math.sqrt(count);
        return dampened;
    }

    /**
     * Serialization proxy for a trainer. Writes the trainer's feature
     * extractor, count map and both symbol tables, and on reading rebuilds
     * an equivalent {@code TfIdfClassifierTrainer} from those four parts.
     *
     * @param <F> type of objects handled by the wrapped trainer
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = -4757808688956812832L;
        /** Trainer to write; {@code null} when created via the no-arg constructor. */
        final TfIdfClassifierTrainer<F> mTrainer;

        /** No-argument constructor; wraps no trainer. */
        public Serializer() {
            this(null);
        }

        /**
         * Wraps the specified trainer for writing.
         *
         * @param trainer trainer whose state is to be written
         */
        public Serializer(TfIdfClassifierTrainer<F> trainer) {
            this.mTrainer = trainer;
        }

        /**
         * Writes the trainer state in the order: feature extractor, count
         * map, feature symbol table, category symbol table.
         *
         * @param out output to write to
         * @throws IOException if a write fails
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            TfIdfClassifierTrainer<F> source = this.mTrainer;
            AbstractExternalizable.serializeOrCompile(source.mFeatureExtractor, out);
            out.writeObject(source.mFeatureToCategoryCount);
            out.writeObject(source.mFeatureSymbolTable);
            out.writeObject(source.mCategorySymbolTable);
        }

        /**
         * Reads the four state components in the order they were written
         * and returns a trainer built from them.
         *
         * @param objIn input to read from
         * @return the reconstructed trainer
         * @throws ClassNotFoundException if a class in the stream cannot be resolved
         * @throws IOException if a read fails
         */
        @Override
        public Object read(ObjectInput objIn) throws ClassNotFoundException, IOException {
            @SuppressWarnings("unchecked") // element types are erased in the stream
            FeatureExtractor<? super F> extractor = (FeatureExtractor<? super F>) objIn.readObject();
            @SuppressWarnings("unchecked") // element types are erased in the stream
            Map<Integer, ObjectToDoubleMap<Integer>> counts =
                    (Map<Integer, ObjectToDoubleMap<Integer>>) objIn.readObject();
            MapSymbolTable featureTable = (MapSymbolTable) objIn.readObject();
            MapSymbolTable categoryTable = (MapSymbolTable) objIn.readObject();
            return new TfIdfClassifierTrainer<F>(extractor, counts, featureTable, categoryTable);
        }
    }

    /**
     * Compiled classifier that scores each category by the dot product of
     * the input's TF-IDF vector with the category's stored TF-IDF weights,
     * divided by the length of the input vector.
     *
     * <p>Category weights are stored sparsely: for feature {@code f}, the
     * entries at positions {@code mFeatureOffsets[f]} (inclusive) through
     * {@code mFeatureOffsets[f + 1]} (exclusive) of {@code mCategoryIds} and
     * {@code mTfIdfs} hold the category ids and their weights.
     *
     * @param <G> type of objects classified
     */
    static class TfIdfClassifier<G>
    implements ScoredClassifier<G> {
        /** Extracts the feature vector of an input. */
        final FeatureExtractor<? super G> mFeatureExtractor;
        /** Maps feature names to feature ids. */
        final MapSymbolTable mFeatureSymbolTable;
        /** Category names indexed by category id. */
        final String[] mCategories;
        /** IDF value per feature id. */
        final float[] mFeatureIdfs;
        /** Start offset per feature id into the sparse arrays, plus one end sentinel. */
        final int[] mFeatureOffsets;
        /** Category id of each sparse entry. */
        final int[] mCategoryIds;
        /** Stored TF-IDF weight of each sparse entry. */
        final float[] mTfIdfs;

        /**
         * Constructs a classifier from its precomputed components; the
         * arrays are stored as given, without copying.
         */
        TfIdfClassifier(FeatureExtractor<? super G> featureExtractor, MapSymbolTable featureSymbolTable, String[] categories, float[] featureIdfs, int[] featureOffsets, int[] categoryIds, float[] tfIdfs) {
            this.mFeatureExtractor = featureExtractor;
            this.mFeatureSymbolTable = featureSymbolTable;
            this.mCategories = categories;
            this.mFeatureIdfs = featureIdfs;
            this.mFeatureOffsets = featureOffsets;
            this.mCategoryIds = categoryIds;
            this.mTfIdfs = tfIdfs;
        }

        /**
         * Returns a multi-line dump of the feature symbol table, the
         * categories, the per-feature IDFs and offsets, and the sparse
         * category-id/weight entries.
         *
         * @return debugging representation of this classifier
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("TfIdfClassifierTrainer.TfIdfClassifier\n");
            sb.append("Feature Symbol Table\n  ");
            sb.append(this.mFeatureSymbolTable.toString());
            sb.append("\n");
            sb.append("Categories\n");
            for (int i = 0; i < this.mCategories.length; ++i) {
                sb.append("  ").append(i).append("=").append(this.mCategories[i]).append("\n");
            }
            sb.append("Index  Feature IDF  offset\n");
            for (int i = 0; i < this.mFeatureIdfs.length; ++i) {
                sb.append("  ").append(i)
                  .append("  ").append(this.mFeatureSymbolTable.idToSymbol(i))
                  .append("   ").append(this.mFeatureIdfs[i])
                  .append("   ").append(this.mFeatureOffsets[i])
                  .append("\n");
            }
            sb.append("Index  CategoryID  TF-IDF\n");
            for (int i = 0; i < this.mCategoryIds.length; ++i) {
                sb.append("  ").append(i)
                  .append("   ").append(this.mCategoryIds[i])
                  .append("    ").append(this.mTfIdfs[i])
                  .append("\n");
            }
            return sb.toString();
        }

        /**
         * Scores the input against every category.
         *
         * <p>Features missing from the symbol table are ignored. If the
         * input's TF-IDF vector has zero length (no known features, or only
         * features with zero IDF), every category receives a score of
         * {@code 0.0} rather than the {@code NaN} that dividing by the zero
         * length would produce.
         *
         * @param in object to classify
         * @return scored classification over all categories
         */
        @Override
        public ScoredClassification classify(G in) {
            Map<String, ? extends Number> featureVector = this.mFeatureExtractor.features(in);
            double[] scores = new double[this.mCategories.length];
            double inputLengthSquared = 0.0;
            for (Map.Entry<String, ? extends Number> featureValue : featureVector.entrySet()) {
                int featureId = this.mFeatureSymbolTable.symbolToID(featureValue.getKey());
                if (featureId == -1) {
                    continue; // feature never seen in training
                }
                double inputTf = TfIdfClassifierTrainer.tf(featureValue.getValue().doubleValue());
                double inputTfIdf = inputTf * this.mFeatureIdfs[featureId];
                inputLengthSquared += inputTfIdf * inputTfIdf;
                int end = this.mFeatureOffsets[featureId + 1];
                for (int offset = this.mFeatureOffsets[featureId]; offset < end; ++offset) {
                    scores[this.mCategoryIds[offset]] += this.mTfIdfs[offset] * inputTfIdf;
                }
            }
            double inputLength = Math.sqrt(inputLengthSquared);
            ArrayList<ScoredObject<String>> catScores = new ArrayList<ScoredObject<String>>(this.mCategories.length);
            for (int i = 0; i < scores.length; ++i) {
                // Guard only the exact-zero case so other values divide as before.
                double score = inputLength == 0.0 ? 0.0 : scores[i] / inputLength;
                catScores.add(new ScoredObject<String>(this.mCategories[i], score));
            }
            return ScoredClassification.create(catScores);
        }
    }

    /**
     * Writes a trainer in compiled form and reads that form back as a
     * {@code TfIdfClassifier}.
     *
     * <p>Stream layout, in order: the feature extractor; the feature symbol
     * table; the category count followed by each category name (UTF); one
     * float IDF per feature; one int start offset per feature followed by
     * the total entry count; then, per feature, one (int category id, float
     * length-normalized TF-IDF) pair for each category having that feature.
     *
     * @param <F> type of objects handled by the wrapped trainer
     */
    static class Externalizer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = 5578122239615646843L;
        /** Trainer to compile; {@code null} when created via the no-arg constructor. */
        final TfIdfClassifierTrainer<F> mTrainer;

        /** No-argument constructor; wraps no trainer. */
        public Externalizer() {
            this(null);
        }

        /**
         * Wraps the specified trainer for compilation.
         *
         * @param trainer trainer whose model is to be written
         */
        public Externalizer(TfIdfClassifierTrainer<F> trainer) {
            this.mTrainer = trainer;
        }

        /**
         * Writes the compiled model in the layout described on this class.
         *
         * <p>Each stored weight is {@code tf(count) * idf} divided by the
         * Euclidean length of its category's TF-IDF vector. When that length
         * is zero (every feature of the category has zero TF-IDF, e.g. only
         * one category was trained), {@code 0.0} is written instead of the
         * {@code NaN} that the division would yield.
         *
         * @param out output to write to
         * @throws IOException if a write fails
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            TfIdfClassifierTrainer<F> trainer = this.mTrainer;
            Map<Integer, ObjectToDoubleMap<Integer>> counts = trainer.mFeatureToCategoryCount;

            AbstractExternalizable.compileOrSerialize(trainer.mFeatureExtractor, out);
            int numFeatures = trainer.mFeatureSymbolTable.numSymbols();
            out.writeObject(trainer.mFeatureSymbolTable);

            int numCats = trainer.mCategorySymbolTable.numSymbols();
            double numCatsD = numCats;
            out.writeInt(numCats);
            for (int catId = 0; catId < numCats; ++catId) {
                out.writeUTF(trainer.mCategorySymbolTable.idToSymbol(catId));
            }

            // Per-feature IDF, indexed by feature id.
            for (int featureId = 0; featureId < numFeatures; ++featureId) {
                int docFrequency = counts.get(featureId).size();
                out.writeFloat((float) TfIdfClassifierTrainer.idf(docFrequency, numCatsD));
            }

            // Start offset of each feature's entries, then the total as end sentinel.
            int nextFeatureOffset = 0;
            for (int featureId = 0; featureId < numFeatures; ++featureId) {
                out.writeInt(nextFeatureOffset);
                nextFeatureOffset += counts.get(featureId).size();
            }
            out.writeInt(nextFeatureOffset);

            // Euclidean length of each category's TF-IDF vector.
            double[] catLengths = new double[numCats];
            for (Map.Entry<Integer, ObjectToDoubleMap<Integer>> entry : counts.entrySet()) {
                ObjectToDoubleMap<Integer> categoryCounts = entry.getValue();
                double idf = TfIdfClassifierTrainer.idf(categoryCounts.size(), numCatsD);
                for (Map.Entry<Integer, Double> categoryCount : categoryCounts.entrySet()) {
                    int catId = categoryCount.getKey();
                    double tfIdf = TfIdfClassifierTrainer.tf(categoryCount.getValue()) * idf;
                    catLengths[catId] += tfIdf * tfIdf;
                }
            }
            for (int catId = 0; catId < catLengths.length; ++catId) {
                catLengths[catId] = Math.sqrt(catLengths[catId]);
            }

            // Sparse (category id, normalized weight) pairs, grouped by feature id.
            for (int featureId = 0; featureId < numFeatures; ++featureId) {
                ObjectToDoubleMap<Integer> categoryCounts = counts.get(featureId);
                double idf = TfIdfClassifierTrainer.idf(categoryCounts.size(), numCatsD);
                for (Map.Entry<Integer, Double> categoryCount : categoryCounts.entrySet()) {
                    int catId = categoryCount.getKey();
                    double count = categoryCount.getValue();
                    double length = catLengths[catId];
                    // Guard only the exact-zero length; anything else divides as before.
                    double normed = length == 0.0
                            ? 0.0
                            : TfIdfClassifierTrainer.tf(count) * idf / length;
                    out.writeInt(catId);
                    out.writeFloat((float) normed);
                }
            }
        }

        /**
         * Reads a compiled model in the layout described on this class and
         * returns the corresponding classifier.
         *
         * @param objIn input to read from
         * @return a {@code TfIdfClassifier} built from the stream contents
         * @throws ClassNotFoundException if a class in the stream cannot be resolved
         * @throws IOException if a read fails
         */
        @Override
        public Object read(ObjectInput objIn) throws ClassNotFoundException, IOException {
            @SuppressWarnings("unchecked") // type argument is erased in the stream
            FeatureExtractor<? super F> featureExtractor = (FeatureExtractor<? super F>) objIn.readObject();
            MapSymbolTable featureSymbolTable = (MapSymbolTable) objIn.readObject();
            int numFeatures = featureSymbolTable.numSymbols();

            int numCategories = objIn.readInt();
            String[] categories = new String[numCategories];
            for (int i = 0; i < numCategories; ++i) {
                categories[i] = objIn.readUTF();
            }

            float[] featureIdfs = new float[numFeatures];
            for (int i = 0; i < numFeatures; ++i) {
                featureIdfs[i] = objIn.readFloat();
            }

            // One start offset per feature plus the trailing total entry count.
            int[] featureOffsets = new int[numFeatures + 1];
            for (int i = 0; i < numFeatures; ++i) {
                featureOffsets[i] = objIn.readInt();
            }
            int numEntries = objIn.readInt();
            featureOffsets[numFeatures] = numEntries;

            int[] catIds = new int[numEntries];
            float[] normedTfIdfs = new float[numEntries];
            for (int i = 0; i < numEntries; ++i) {
                catIds[i] = objIn.readInt();
                normedTfIdfs[i] = objIn.readFloat();
            }
            return new TfIdfClassifier<F>(featureExtractor, featureSymbolTable, categories, featureIdfs, featureOffsets, catIds, normedTfIdfs);
        }
    }
}

