package com.aliasi.classify;

import com.aliasi.coref.Matcher;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.Features;
import com.aliasi.lm.CompiledNGramProcessLM;
import com.aliasi.matrix.KernelFunction;
import com.aliasi.matrix.Vector;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Arrays;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/* loaded from: input_file:com/aliasi/classify/PerceptronClassifier.class */
/**
 * A binary classifier based on an averaged kernel perceptron.
 *
 * <p>Training makes a fixed number of passes over a corpus. Each
 * misclassified training instance adds that instance, signed by its label,
 * to the current perceptron. After training, every perceptron is weighted
 * by how many consecutive instances it classified correctly, and the
 * perceptrons are folded into a single weighted set of basis vectors.
 *
 * <p>An input is scored by summing {@code weight * kernel(basis, input)}
 * over the basis. A strictly positive score is labeled with the accept
 * category; any other score is labeled with the reject category.
 *
 * <p>Serialization is delegated to {@link Externalizer} via
 * {@code writeReplace}.
 *
 * @param <E> the type of object being classified
 */
public class PerceptronClassifier<E> implements ScoredClassifier<E>, Serializable {
    static final long serialVersionUID = 8752291174601085455L;

    final FeatureExtractor<? super E> mFeatureExtractor;
    final MapSymbolTable mSymbolTable;
    final KernelFunction mKernelFunction;
    // Parallel arrays: basis vector i is scaled by integer weight i.
    final Vector[] mBasisVectors;
    final int[] mBasisWeights;
    final String mAcceptCategory;
    final String mRejectCategory;

    static final Vector[] EMPTY_SPARSE_FLOAT_VECTOR_ARRAY = new Vector[0];
    static final int INITIAL_BASIS_SIZE = 32768;

    /**
     * Corpus visitor that converts each training instance to a feature
     * vector, adding unseen feature names to the symbol table. It also
     * records whether the instance's best category equals the corpus
     * accept category.
     */
    class CorpusCollector implements ObjectHandler<Classified<E>> {
        final List<Vector> mInputFeatureVectorList = new ArrayList<Vector>();
        final List<Boolean> mInputAcceptList = new ArrayList<Boolean>();
        final String mCorpusAcceptCategory;

        /**
         * @param corpusAcceptCategory training-data category treated as
         *     the positive class
         */
        CorpusCollector(String corpusAcceptCategory) {
            this.mCorpusAcceptCategory = corpusAcceptCategory;
        }

        @Override
        public void handle(Classified<E> classified) {
            E object = classified.getObject();
            Classification classification = classified.getClassification();
            mInputFeatureVectorList.add(
                Features.toVectorAddSymbols(mFeatureExtractor.features(object),
                                            mSymbolTable, Integer.MAX_VALUE, false));
            mInputAcceptList.add(
                Boolean.valueOf(mCorpusAcceptCategory.equals(classification.bestCategory())));
        }

        /** Returns the collected feature vectors, in corpus order. */
        Vector[] featureVectors() {
            return mInputFeatureVectorList.toArray(EMPTY_SPARSE_FLOAT_VECTOR_ARRAY);
        }

        /**
         * Returns the collected labels in corpus order: {@code true} for
         * positive (accept) instances.
         */
        boolean[] polarities() {
            boolean[] result = new boolean[mInputAcceptList.size()];
            for (int i = 0; i < result.length; ++i)
                result[i] = mInputAcceptList.get(i).booleanValue();
            return result;
        }
    }

    /**
     * Serialization proxy for a {@link PerceptronClassifier}.
     *
     * <p>It writes, in order: the feature extractor, the kernel function,
     * the symbol table, the basis size, the basis vectors, the basis
     * weights, and the accept and reject categories.
     * {@link #read(ObjectInput)} consumes them in the same order.
     */
    static class Externalizer<F> extends AbstractExternalizable {
        static final long serialVersionUID = -1901362811305741506L;
        final PerceptronClassifier<F> mClassifier;

        /** No-arg constructor used when deserializing. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(PerceptronClassifier<F> classifier) {
            this.mClassifier = classifier;
        }

        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            // Assumes the stream was produced by writeExternal below, so this
            // object is the extractor written for a PerceptronClassifier<F>.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super F> featureExtractor
                = (FeatureExtractor<? super F>) in.readObject();
            KernelFunction kernelFunction = (KernelFunction) in.readObject();
            MapSymbolTable symbolTable = (MapSymbolTable) in.readObject();
            int numBasis = in.readInt();
            Vector[] basisVectors = new Vector[numBasis];
            for (int i = 0; i < numBasis; ++i)
                basisVectors[i] = (Vector) in.readObject();
            // Weights are written without a separate count; they share numBasis.
            int[] basisWeights = new int[numBasis];
            for (int i = 0; i < numBasis; ++i)
                basisWeights[i] = in.readInt();
            String acceptCategory = in.readUTF();
            String rejectCategory = in.readUTF();
            return new PerceptronClassifier<F>(featureExtractor, kernelFunction, symbolTable,
                                               basisVectors, basisWeights,
                                               acceptCategory, rejectCategory);
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            AbstractExternalizable.compileOrSerialize(mClassifier.mFeatureExtractor, out);
            AbstractExternalizable.compileOrSerialize(mClassifier.mKernelFunction, out);
            out.writeObject(mClassifier.mSymbolTable);
            out.writeInt(mClassifier.mBasisVectors.length);
            for (int i = 0; i < mClassifier.mBasisVectors.length; ++i)
                out.writeObject(mClassifier.mBasisVectors[i]);
            for (int i = 0; i < mClassifier.mBasisWeights.length; ++i)
                out.writeInt(mClassifier.mBasisWeights[i]);
            out.writeUTF(mClassifier.mAcceptCategory);
            out.writeUTF(mClassifier.mRejectCategory);
        }
    }

    /**
     * Constructs a classifier directly from an already-trained basis.
     * The arrays are stored without copying.
     */
    PerceptronClassifier(FeatureExtractor<? super E> featureExtractor,
                         KernelFunction kernelFunction,
                         MapSymbolTable symbolTable,
                         Vector[] basisVectors,
                         int[] basisWeights,
                         String acceptCategory,
                         String rejectCategory) {
        this.mFeatureExtractor = featureExtractor;
        this.mKernelFunction = kernelFunction;
        this.mBasisVectors = basisVectors;
        this.mBasisWeights = basisWeights;
        this.mAcceptCategory = acceptCategory;
        this.mRejectCategory = rejectCategory;
        this.mSymbolTable = symbolTable;
    }

    /**
     * Trains an averaged kernel perceptron on the specified corpus.
     *
     * @param corpus training corpus; every instance is visited once to
     *     collect feature vectors and labels
     * @param featureExtractor converts objects to feature maps
     * @param kernelFunction kernel used to compare feature vectors
     * @param corpusAcceptCategory category in the training data treated as
     *     positive. Instances whose best category differs are negative.
     * @param numIterations number of training passes over the corpus
     * @param outputAcceptCategory category reported for inputs with a
     *     positive score
     * @param outputRejectCategory category reported for all other inputs
     * @throws IOException if visiting the corpus throws one
     */
    public PerceptronClassifier(Corpus<ObjectHandler<Classified<E>>> corpus,
                                FeatureExtractor<? super E> featureExtractor,
                                KernelFunction kernelFunction,
                                String corpusAcceptCategory,
                                int numIterations,
                                String outputAcceptCategory,
                                String outputRejectCategory) throws IOException {
        this.mFeatureExtractor = featureExtractor;
        this.mKernelFunction = kernelFunction;
        this.mAcceptCategory = outputAcceptCategory;
        this.mRejectCategory = outputRejectCategory;
        this.mSymbolTable = new MapSymbolTable();

        // Labels are derived from the *corpus* accept category, not the
        // output labels.
        CorpusCollector collector = new CorpusCollector(corpusAcceptCategory);
        corpus.visitCorpus(collector);
        Vector[] featureVectors = collector.featureVectors();
        boolean[] polarities = collector.polarities();

        // Each mistake creates a new perceptron k. basisIndexes[k] is the
        // training instance it added. survivalCounts[k] counts how many
        // instances perceptron k saw before the next mistake. An index of -1
        // means the initial all-zero perceptron, whose survival is not counted.
        int currentPerceptronIndex = -1;
        int[] survivalCounts = new int[INITIAL_BASIS_SIZE];
        int[] basisIndexes = new int[INITIAL_BASIS_SIZE];
        for (int iteration = 0; iteration < numIterations; ++iteration) {
            for (int i = 0; i < featureVectors.length; ++i) {
                double yHat = prediction(featureVectors[i], featureVectors, polarities,
                                         survivalCounts, basisIndexes, currentPerceptronIndex);
                boolean accept = yHat > 0.0;
                if (accept == polarities[i]) {
                    if (currentPerceptronIndex >= 0)
                        ++survivalCounts[currentPerceptronIndex];
                } else {
                    ++currentPerceptronIndex;
                    if (currentPerceptronIndex >= survivalCounts.length) {
                        survivalCounts = Arrays.reallocate(survivalCounts);
                        basisIndexes = Arrays.reallocate(basisIndexes);
                    }
                    basisIndexes[currentPerceptronIndex] = i;
                    survivalCounts[currentPerceptronIndex] = 1;
                }
            }
        }

        // Map each distinct training index that caused a mistake to a dense
        // basis index, in order of first mistake.
        HashMap<Integer, Integer> renumbering = new HashMap<Integer, Integer>();
        for (int k = 0; k <= currentPerceptronIndex; ++k) {
            Integer trainingIndex = Integer.valueOf(basisIndexes[k]);
            if (!renumbering.containsKey(trainingIndex))
                renumbering.put(trainingIndex, Integer.valueOf(renumbering.size()));
        }

        // Mistake k belongs to every perceptron j >= k. Its total averaged
        // weight is therefore the suffix sum of survivalCounts[j] for j >= k,
        // signed by the label of the training instance it added.
        mBasisVectors = new Vector[renumbering.size()];
        mBasisWeights = new int[renumbering.size()];
        int weightSum = 0;
        for (int k = currentPerceptronIndex; k >= 0; --k) {
            int trainingIndex = basisIndexes[k];
            int basisIndex = renumbering.get(Integer.valueOf(trainingIndex)).intValue();
            mBasisVectors[basisIndex] = featureVectors[trainingIndex];
            weightSum += survivalCounts[k];
            if (polarities[trainingIndex])
                mBasisWeights[basisIndex] += weightSum;
            else
                mBasisWeights[basisIndex] -= weightSum;
        }
    }

    /** Returns the kernel function used by this classifier. */
    public KernelFunction kernelFunction() {
        return mKernelFunction;
    }

    /** Returns the feature extractor used by this classifier. */
    public FeatureExtractor<? super E> featureExtractor() {
        return mFeatureExtractor;
    }

    /**
     * Returns a multi-line description of the kernel and each basis vector
     * with its weight.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Averaged Perceptron");
        sb.append("  Kernel Function=").append(mKernelFunction).append("\n");
        for (int i = 0; i < mBasisVectors.length; ++i) {
            sb.append("  idx=").append(i).append(Strings.SINGLE_SPACE_STRING)
              .append("vec=").append(mBasisVectors[i])
              .append(" wgt=").append(mBasisWeights[i]).append("\n");
        }
        return sb.toString();
    }

    /**
     * Classifies the specified object.
     *
     * <p>The object's features are converted to a vector with the trained
     * symbol table. This uses {@code Features.toVector} rather than the
     * symbol-adding variant used in training (presumably unseen features
     * are dropped; not verified here). The score is
     * {@code sum_i weight_i * kernel(basis_i, input)}. A strictly positive
     * score ranks the accept category first with score {@code s} and the
     * reject category second with {@code -s}. Otherwise the order is
     * reversed.
     *
     * @param in object to classify
     * @return scored classification over the accept and reject categories
     */
    @Override
    public ScoredClassification classify(E in) {
        Vector inputVector = Features.toVector(mFeatureExtractor.features(in),
                                               mSymbolTable, Integer.MAX_VALUE, false);
        double score = 0.0;
        // Summed from the last basis vector down, preserving the original
        // floating-point summation order.
        for (int i = mBasisVectors.length - 1; i >= 0; --i)
            score += mBasisWeights[i] * mKernelFunction.proximity(mBasisVectors[i], inputVector);
        return score > 0.0
            ? new ScoredClassification(new String[] { mAcceptCategory, mRejectCategory },
                                       new double[] { score, -score })
            : new ScoredClassification(new String[] { mRejectCategory, mAcceptCategory },
                                       new double[] { -score, score });
    }

    /**
     * Scores an input against the un-averaged current perceptron during
     * training.
     *
     * @param inputVector vector to score
     * @param featureVectors all training vectors, indexed by training index
     * @param polarities training labels, indexed by training index
     * @param ignoredSurvivalCounts unused; retained for signature compatibility
     * @param basisIndexes training index added at each mistake
     * @param currentPerceptronIndex index of the last mistake, or -1 if none
     * @return the sum, over mistakes 0..current, of label sign times kernel proximity
     */
    double prediction(Vector inputVector, Vector[] featureVectors, boolean[] polarities,
                      int[] ignoredSurvivalCounts, int[] basisIndexes,
                      int currentPerceptronIndex) {
        double sum = 0.0;
        for (int k = currentPerceptronIndex; k >= 0; --k) {
            int trainingIndex = basisIndexes[k];
            // The sign comes from the label of the instance added at mistake
            // k, not from position k itself.
            double sign = polarities[trainingIndex] ? 1.0 : -1.0;
            sum += sign * mKernelFunction.proximity(inputVector, featureVectors[trainingIndex]);
        }
        return sum;
    }

    /**
     * Returns {@code base} raised to {@code exponent}. Exponents 0 to 4 use
     * repeated multiplication; all others use {@link Math#pow}.
     */
    static double power(double base, int exponent) {
        switch (exponent) {
            case 0:
                return 1.0;
            case 1:
                return base;
            case 2:
                return base * base;
            case 3:
                return base * base * base;
            case 4:
                return base * base * base * base;
            default:
                return Math.pow(base, exponent);
        }
    }

    /** Serializes through the {@link Externalizer} proxy. */
    private Object writeReplace() {
        return new Externalizer<E>(this);
    }
}
