package com.aliasi.classify;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.matrix.EuclideanDistance;
import com.aliasi.matrix.Vector;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Compilable;
import com.aliasi.util.Distance;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Proximity;
import com.aliasi.util.ScoredObject;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/aliasi/classify/KnnClassifier.class */
/**
 * A k-nearest-neighbor classifier over feature vectors.
 *
 * <p>Training instances arrive through {@link #handle(Classified)}. Each
 * instance's features are converted to a {@link Vector} against the feature
 * symbol table. The vector is stored in parallel with the integer id of the
 * instance's best category.
 *
 * <p>{@link #classify(Object)} works as follows:
 * <ol>
 * <li>Score the input against every stored training vector with the configured
 * {@link Proximity}.</li>
 * <li>Keep the {@code k} highest-scoring neighbors.</li>
 * <li>Add one vote per neighbor to that neighbor's category. A vote is
 * {@code 1.0}, or the neighbor's proximity when {@link #weightByProximity()}
 * is {@code true}.</li>
 * </ol>
 *
 * <p>Both Java serialization and {@link #compileTo(ObjectOutput)} go through
 * {@link Serializer} via {@link #writeReplace()}.
 *
 * <p>NOTE(review): no synchronization is visible in this class.
 * {@link #handle(Classified)} mutates the training lists and symbol tables in
 * place, so mixing training with concurrent classification presumably
 * requires external locking. Confirm before sharing instances across threads.
 *
 * @param <E> the type of object being classified
 */
public class KnnClassifier<E> implements ScoredClassifier<E>, ObjectHandler<Classified<E>>, Compilable, Serializable {
    static final long serialVersionUID = 5692985587478284405L;

    /**
     * Dimensionality argument passed to {@link Features} when vectorizing
     * feature maps. It is kept at {@code Integer.MAX_VALUE - 1} (2147483646),
     * the value this class has always used. It is presumably meant as
     * "effectively unbounded"; confirm against the {@code Features} contract.
     */
    private static final int MAX_VECTOR_DIMENSIONS = Integer.MAX_VALUE - 1;

    final FeatureExtractor<? super E> mFeatureExtractor;
    final int mK;
    final Proximity<Vector> mProximity;
    final boolean mWeightByProximity;
    /** Category id of each training instance; parallel to {@link #mTrainingVectors}. */
    final List<Integer> mTrainingCategories;
    /** Feature vector of each training instance; parallel to {@link #mTrainingCategories}. */
    final List<Vector> mTrainingVectors;
    final MapSymbolTable mFeatureSymbolTable;
    final MapSymbolTable mCategorySymbolTable;

    /**
     * Adapts a {@link Distance} into a {@link Proximity} using
     * {@code 1 / (1 + distance)}, so smaller distances give larger proximities.
     * A negative distance is mapped to {@link Double#MAX_VALUE}.
     */
    static class ProximityWrapper implements Proximity<Vector>, Serializable {
        static final long serialVersionUID = -1410999733708772109L;
        Distance<Vector> mDistance;

        /**
         * Creates a wrapper with no distance. Calling {@link #proximity} before
         * {@link #mDistance} is assigned throws {@code NullPointerException}.
         */
        public ProximityWrapper() {
        }

        /**
         * Creates a wrapper around the given distance.
         *
         * @param distance the distance to convert into a proximity
         */
        public ProximityWrapper(Distance<Vector> distance) {
            this.mDistance = distance;
        }

        /**
         * Returns {@code 1 / (1 + d)}, where {@code d} is the wrapped distance
         * between the two vectors. Returns {@link Double#MAX_VALUE} if
         * {@code d} is negative.
         */
        @Override
        public double proximity(Vector vector, Vector vector2) {
            double distance = this.mDistance.distance(vector, vector2);
            if (distance < 0.0d) {
                return Double.MAX_VALUE;
            }
            return 1.0d / (1.0d + distance);
        }
    }

    /**
     * Externalizable serialized form of a {@link KnnClassifier}.
     *
     * <p>The wire format, in order, is:
     * <ol>
     * <li>feature extractor</li>
     * <li>{@code k} (int)</li>
     * <li>proximity</li>
     * <li>weight-by-proximity flag (boolean)</li>
     * <li>instance count {@code n} (int)</li>
     * <li>{@code n} category ids (ints)</li>
     * <li>{@code n} training vectors</li>
     * <li>feature symbol table</li>
     * <li>category symbol table</li>
     * </ol>
     * Objects are written with {@link AbstractExternalizable#serializeOrCompile}.
     *
     * @param <F> the classified-object type of the wrapped classifier
     */
    public static class Serializer<F> extends AbstractExternalizable {
        static final long serialVersionUID = 4951969636521202268L;
        final KnnClassifier<F> mClassifier;

        /** Public no-arg constructor, as required of {@code Externalizable} classes. */
        public Serializer() {
            this(null);
        }

        /**
         * Creates a serializer for the given classifier.
         *
         * @param knnClassifier the classifier to write
         */
        public Serializer(KnnClassifier<F> knnClassifier) {
            this.mClassifier = knnClassifier;
        }

        /** Writes the wrapped classifier in the wire format described on this class. */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            KnnClassifier<F> classifier = this.mClassifier;
            AbstractExternalizable.serializeOrCompile(classifier.mFeatureExtractor, out);
            out.writeInt(classifier.mK);
            AbstractExternalizable.serializeOrCompile(classifier.mProximity, out);
            out.writeBoolean(classifier.mWeightByProximity);
            int numInstances = classifier.mTrainingCategories.size();
            out.writeInt(numInstances);
            for (int i = 0; i < numInstances; i++) {
                out.writeInt(classifier.mTrainingCategories.get(i).intValue());
            }
            for (int i = 0; i < numInstances; i++) {
                AbstractExternalizable.serializeOrCompile(classifier.mTrainingVectors.get(i), out);
            }
            AbstractExternalizable.serializeOrCompile(classifier.mFeatureSymbolTable, out);
            AbstractExternalizable.serializeOrCompile(classifier.mCategorySymbolTable, out);
        }

        /**
         * Reads a classifier written by {@link #writeExternal(ObjectOutput)}.
         *
         * @return a new {@link KnnClassifier} holding the deserialized state
         */
        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            // Generic type arguments are erased on the wire, so these casts are necessarily unchecked.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super F> featureExtractor = (FeatureExtractor<? super F>) in.readObject();
            int k = in.readInt();
            @SuppressWarnings("unchecked")
            Proximity<Vector> proximity = (Proximity<Vector>) in.readObject();
            boolean weightByProximity = in.readBoolean();
            int numInstances = in.readInt();
            List<Integer> categoryIds = new ArrayList<Integer>(numInstances);
            for (int i = 0; i < numInstances; i++) {
                categoryIds.add(Integer.valueOf(in.readInt()));
            }
            List<Vector> trainingVectors = new ArrayList<Vector>(numInstances);
            for (int i = 0; i < numInstances; i++) {
                trainingVectors.add((Vector) in.readObject());
            }
            // Read order must mirror writeExternal: feature table first, then category table.
            MapSymbolTable featureSymbolTable = (MapSymbolTable) in.readObject();
            MapSymbolTable categorySymbolTable = (MapSymbolTable) in.readObject();
            return new KnnClassifier<F>(featureExtractor, k, proximity, weightByProximity,
                                        categoryIds, trainingVectors,
                                        featureSymbolTable, categorySymbolTable);
        }
    }

    /** A (category, vector) pair. Not referenced elsewhere in this class. */
    static class TrainingInstance {
        final String mCategory;
        final Vector mVector;

        TrainingInstance(String str, Vector vector) {
            this.mCategory = str;
            this.mVector = vector;
        }
    }

    /**
     * Creates an unweighted k-NN classifier using {@link EuclideanDistance#DISTANCE}.
     *
     * @param featureExtractor converts classified objects to feature maps
     * @param k number of neighbors consulted per classification
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k) {
        this(featureExtractor, k, EuclideanDistance.DISTANCE);
    }

    /**
     * Creates an unweighted k-NN classifier from a distance. The distance is
     * converted to a proximity by {@code 1 / (1 + d)}.
     *
     * @param featureExtractor converts classified objects to feature maps
     * @param k number of neighbors consulted per classification
     * @param distance distance between feature vectors
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Distance<Vector> distance) {
        this(featureExtractor, k, new ProximityWrapper(distance), false);
    }

    /**
     * Creates an untrained k-NN classifier with an explicit proximity.
     *
     * @param featureExtractor converts classified objects to feature maps
     * @param k number of neighbors consulted per classification
     * @param proximity proximity between feature vectors; higher means closer
     * @param weightByProximity if {@code true}, each neighbor votes with its
     *     proximity; otherwise each neighbor votes {@code 1.0}
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Proximity<Vector> proximity, boolean weightByProximity) {
        this(featureExtractor, k, proximity, weightByProximity,
             new ArrayList<Integer>(), new ArrayList<Vector>(),
             new MapSymbolTable(), new MapSymbolTable());
    }

    /**
     * Full-state constructor used by the public constructors and by
     * {@link Serializer}. The lists and symbol tables are stored by
     * reference, not copied.
     */
    KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Proximity<Vector> proximity, boolean weightByProximity,
                  List<Integer> trainingCategories, List<Vector> trainingVectors,
                  MapSymbolTable featureSymbolTable, MapSymbolTable categorySymbolTable) {
        this.mFeatureExtractor = featureExtractor;
        this.mK = k;
        this.mProximity = proximity;
        this.mWeightByProximity = weightByProximity;
        this.mTrainingCategories = trainingCategories;
        this.mTrainingVectors = trainingVectors;
        this.mFeatureSymbolTable = featureSymbolTable;
        this.mCategorySymbolTable = categorySymbolTable;
    }

    /** @return the feature extractor used to vectorize objects */
    public FeatureExtractor<? super E> featureExtractor() {
        return this.mFeatureExtractor;
    }

    /** @return the proximity used to compare feature vectors */
    public Proximity<Vector> proximity() {
        return this.mProximity;
    }

    /**
     * Returns a fresh list with the category of each training instance, in
     * training order.
     *
     * <p>NOTE(review): this is one entry per training instance, so duplicates
     * appear. It is not the set of distinct categories. Preserved as-is
     * because callers may depend on it.
     *
     * @return a new mutable list of training-instance categories
     */
    public List<String> categories() {
        List<String> categoryNames = new ArrayList<String>(this.mTrainingCategories.size());
        for (Integer categoryId : this.mTrainingCategories) {
            categoryNames.add(this.mCategorySymbolTable.idToSymbol(categoryId));
        }
        return categoryNames;
    }

    /** @return {@code true} if neighbors vote with their proximity rather than {@code 1.0} */
    public boolean weightByProximity() {
        return this.mWeightByProximity;
    }

    /** @return the number of neighbors consulted per classification */
    public int k() {
        return this.mK;
    }

    /**
     * Adds one training instance labeled with the classification's best category.
     *
     * <p>The object is vectorized with {@code Features.toVectorAddSymbols}
     * against the feature symbol table, presumably adding unseen feature names
     * to it. The category is interned into the category symbol table via
     * {@code getOrAddSymbolInteger}. The category id and vector are then
     * appended to the parallel training lists.
     *
     * @param e the training object
     * @param classification its classification; only {@code bestCategory()} is used
     */
    void handle(E e, Classification classification) {
        String bestCategory = classification.bestCategory();
        Vector trainingVector = Features.toVectorAddSymbols(this.mFeatureExtractor.features(e), this.mFeatureSymbolTable, MAX_VECTOR_DIMENSIONS, false);
        this.mTrainingCategories.add(this.mCategorySymbolTable.getOrAddSymbolInteger(bestCategory));
        this.mTrainingVectors.add(trainingVector);
    }

    /**
     * Adds the classified object as a training instance.
     *
     * @param classified object together with its reference classification
     */
    @Override
    public void handle(Classified<E> classified) {
        handle(classified.getObject(), classified.getClassification());
    }

    /**
     * Classifies an object by voting among its {@code k} nearest training instances.
     *
     * <p>Every category known to the category symbol table appears in the
     * result. Categories with no neighbor among the top {@code k} get a score
     * of {@code 0.0}.
     *
     * @param e the object to classify
     * @return per-category vote totals
     */
    @Override
    public ScoredClassification classify(E e) {
        // Plain toVector (not toVectorAddSymbols): the feature table is not extended at classification time.
        Vector inputVector = Features.toVector(this.mFeatureExtractor.features(e), this.mFeatureSymbolTable, MAX_VECTOR_DIMENSIONS, false);
        BoundedPriorityQueue<ScoredObject<Integer>> nearest
            = new BoundedPriorityQueue<ScoredObject<Integer>>(ScoredObject.comparator(), this.mK);
        for (int i = 0; i < this.mTrainingCategories.size(); i++) {
            double proximity = this.mProximity.proximity(inputVector, this.mTrainingVectors.get(i));
            nearest.offer(new ScoredObject<Integer>(this.mTrainingCategories.get(i), proximity));
        }
        int numCategories = this.mCategorySymbolTable.numSymbols();
        double[] votes = new double[numCategories];
        for (ScoredObject<Integer> neighbor : nearest) {
            int categoryId = neighbor.getObject().intValue();
            votes[categoryId] += this.mWeightByProximity ? neighbor.score() : 1.0d;
        }
        List<ScoredObject<String>> categoryScores = new ArrayList<ScoredObject<String>>(numCategories);
        for (int categoryId = 0; categoryId < numCategories; categoryId++) {
            categoryScores.add(new ScoredObject<String>(this.mCategorySymbolTable.idToSymbol(categoryId), votes[categoryId]));
        }
        return ScoredClassification.create(categoryScores);
    }

    /** Serialization hook: this classifier is always written as a {@link Serializer}. */
    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Writes this classifier's serialized form ({@link Serializer}) to the output.
     *
     * @param objectOutput destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        objectOutput.writeObject(writeReplace());
    }
}
