package com.aliasi.stats;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Locale;

/* loaded from: input_file:com/aliasi/stats/LogisticRegression.class */
/**
 * A multinomial logistic regression model over {@code K} outcomes, represented by
 * {@code K-1} weight vectors of a common dimensionality. The last outcome acts as the
 * reference category: its linear predictor is fixed at zero.
 *
 * <p>Serialization and compilation go through the nested {@link Externalizer}. It writes
 * the number of outcomes, the input dimensionality, and then every weight value densely.
 */
public class LogisticRegression implements Compilable, Serializable {
    static final long serialVersionUID = -8585743596322227589L;

    /** One weight vector per non-reference outcome; all share the same dimensionality. */
    private final Vector[] mWeightVectors;

    /**
     * Serialization proxy for {@link LogisticRegression}.
     *
     * <p>The stream layout is: an {@code int} number of outcomes (weight vectors + 1), an
     * {@code int} dimensionality, then each weight vector's values as {@code double}s.
     */
    static class Externalizer extends AbstractExternalizable {
        static final long serialVersionUID = -2256261505231943102L;
        final LogisticRegression mRegression;

        /** No-argument constructor used when reading a serialized model. */
        public Externalizer() {
            this(null);
        }

        /**
         * Constructs an externalizer that writes the specified regression.
         *
         * @param regression Model to write.
         */
        public Externalizer(LogisticRegression regression) {
            mRegression = regression;
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            Vector[] weightVectors = mRegression.mWeightVectors;
            // Number of outcomes, i.e. one more than the number of weight vectors.
            out.writeInt(weightVectors.length + 1);
            int numDimensions = weightVectors[0].numDimensions();
            out.writeInt(numDimensions);
            for (Vector weightVector : weightVectors) {
                for (int d = 0; d < numDimensions; ++d) {
                    out.writeDouble(weightVector.value(d));
                }
            }
        }

        @Override
        public Object read(ObjectInput in) throws IOException {
            int numOutcomes = in.readInt();
            int numDimensions = in.readInt();
            Vector[] weightVectors = new Vector[numOutcomes - 1];
            for (int k = 0; k < weightVectors.length; ++k) {
                DenseVector weightVector = new DenseVector(numDimensions);
                weightVectors[k] = weightVector;
                for (int d = 0; d < numDimensions; ++d) {
                    weightVector.setValue(d, in.readDouble());
                }
            }
            return new LogisticRegression(weightVectors);
        }
    }

    /**
     * Constructs a multinomial regression from the specified weight vectors. The result
     * has {@code weightVectors.length + 1} outcomes.
     *
     * <p>The array is copied, but the vectors themselves are shared with the caller.
     *
     * @param weightVectors Weight vectors, one per non-reference outcome.
     * @throws IllegalArgumentException If the array is empty, or if the vectors differ in
     *     dimensionality.
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length < 1) {
            throw new IllegalArgumentException("Require at least one weight vector.");
        }
        int numDimensions = weightVectors[0].numDimensions();
        for (int k = 1; k < weightVectors.length; ++k) {
            if (numDimensions != weightVectors[k].numDimensions()) {
                throw new IllegalArgumentException("All weight vectors must be same dimensionality. Found weightVectors[0].numDimensions()=" + numDimensions + " weightVectors[" + k + "]=" + weightVectors[k].numDimensions());
            }
        }
        // Shallow copy: later changes to the caller's array cannot swap out this model's
        // weights. The vector objects themselves remain shared.
        mWeightVectors = weightVectors.clone();
    }

    /**
     * Constructs a binary regression (two outcomes) from a single weight vector.
     *
     * @param weightVector Weight vector for the first outcome.
     */
    public LogisticRegression(Vector weightVector) {
        mWeightVectors = new Vector[] { weightVector };
    }

    /**
     * Returns the dimensionality of inputs accepted by this model.
     *
     * @return Number of input dimensions.
     */
    public int numInputDimensions() {
        return mWeightVectors[0].numDimensions();
    }

    /**
     * Returns the number of outcomes, which is one more than the number of weight vectors.
     *
     * @return Number of outcomes.
     */
    public int numOutcomes() {
        return mWeightVectors.length + 1;
    }

    /**
     * Returns unmodifiable views of this model's weight vectors, obtained via
     * {@code Matrices.unmodifiableVector}.
     *
     * @return Fresh array of weight-vector views.
     */
    public Vector[] weightVectors() {
        Vector[] views = new Vector[mWeightVectors.length];
        for (int k = 0; k < views.length; ++k) {
            views[k] = Matrices.unmodifiableVector(mWeightVectors[k]);
        }
        return views;
    }

    /**
     * Returns the conditional probability of each outcome given the input.
     *
     * @param x Input vector.
     * @return Array of length {@link #numOutcomes()} whose entries sum to one.
     * @throws IllegalArgumentException If the input has the wrong dimensionality.
     */
    public double[] classify(Vector x) {
        double[] ysHat = new double[numOutcomes()];
        classify(x, ysHat);
        return ysHat;
    }

    /**
     * Writes the conditional probability of each outcome given the input into the
     * specified array.
     *
     * <p>Uses a softmax over the linear predictors. The largest predictor (which is at
     * least zero, the reference outcome's value) is subtracted before exponentiating, to
     * avoid overflow.
     *
     * @param x Input vector.
     * @param ysHat Output array of length {@link #numOutcomes()}.
     * @throws IllegalArgumentException If the input has the wrong dimensionality or the
     *     output array has the wrong length.
     */
    public void classify(Vector x, double[] ysHat) {
        if (numInputDimensions() != x.numDimensions()) {
            throw new IllegalArgumentException("Vector and classifer must be of same dimensionality. Regression model this.numInputDimensions()=" + numInputDimensions() + " Vector x.numDimensions()=" + x.numDimensions());
        }
        if (ysHat.length != numOutcomes()) {
            throw new IllegalArgumentException("Output array must have one entry per outcome. Found ysHat.length=" + ysHat.length + " this.numOutcomes()=" + numOutcomes());
        }
        int referenceOutcome = ysHat.length - 1;
        ysHat[referenceOutcome] = 0.0;
        double max = 0.0;
        for (int k = 0; k < referenceOutcome; ++k) {
            ysHat[k] = x.dotProduct(mWeightVectors[k]);
            if (ysHat[k] > max) {
                max = ysHat[k];
            }
        }
        double total = 0.0;
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] = Math.exp(ysHat[k] - max);
            total += ysHat[k];
        }
        for (int k = 0; k < ysHat.length; ++k) {
            ysHat[k] /= total;
        }
    }

    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /** Serialization hook; substitutes the {@link Externalizer} proxy for this object. */
    Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Estimates a regression, sending progress to a print writer at debug level.
     *
     * @deprecated Use the overload of {@code estimate} that takes a {@link Reporter}.
     */
    @Deprecated
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, PrintWriter progressWriter) {
        Reporter reporter = progressWriter == null ? null : Reporters.writer(progressWriter).setLevel(LogLevel.DEBUG);
        return estimate(xs, cs, prior, annealingSchedule, reporter, minImprovement, minEpochs, maxEpochs);
    }

    /**
     * Estimates a regression using default settings for the remaining options.
     *
     * <p>The defaults are:
     * <ul>
     * <li>a prior block size of {@code max(1, cs.length / 50)};</li>
     * <li>no hot start;</li>
     * <li>a rolling average over 10 epochs;</li>
     * <li>no per-epoch handler.</li>
     * </ul>
     * See the full overload for the meaning of each parameter.
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        return estimate(xs, cs, prior, Math.max(1, cs.length / 50), null, annealingSchedule, minImprovement, 10, minEpochs, maxEpochs, null, reporter);
    }

    /**
     * Estimates a regression model from training data using per-instance gradient updates.
     *
     * <p>For each instance, every weight vector is moved by
     * {@code -learningRate * (p(k|x) - [k == c]) * x}. Prior updates are applied every
     * {@code blockSize} instances, scaled by {@code learningRate * blockSize / n}, and
     * caught up at the end of each epoch.
     *
     * @param xs Training inputs, all of the same dimensionality.
     * @param cs Outcome index of each input. The number of outcomes is {@code max(cs)+1}.
     * @param prior Prior on the weights. It is dereferenced unconditionally, so it must be
     *     non-null.
     * @param blockSize Number of instances between prior updates; must be at least 1.
     * @param hotStart Model whose weights initialize training, or {@code null} to start
     *     from zero weights.
     * @param annealingSchedule Supplies the learning rate for each epoch and may reject an
     *     epoch's update. A rejected update restores the previous weights.
     * @param minImprovement {@code Double.NaN} to disable convergence monitoring; otherwise
     *     a non-negative threshold on the rolling average relative difference.
     * @param rollingAverageSize Number of epochs in the rolling average. Must be at least 1
     *     when convergence is monitored.
     * @param minEpochs Minimum number of epochs to run before convergence can end training.
     * @param maxEpochs Maximum number of epochs to run.
     * @param handler If non-null, receives the model after each epoch. This is the same
     *     instance being trained, so later epochs keep updating it.
     * @param reporter Progress reporter, or {@code null} for silent operation.
     * @return The trained model.
     * @throws IllegalArgumentException If the arguments are inconsistent or out of range.
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, int blockSize, LogisticRegression hotStart, AnnealingSchedule annealingSchedule, double minImprovement, int rollingAverageSize, int minEpochs, int maxEpochs, ObjectHandler<LogisticRegression> handler, Reporter reporter) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("Logistic Regression Estimation");
        boolean monitoringConvergence = !Double.isNaN(minImprovement);
        reporter.info("Monitoring convergence=" + monitoringConvergence);
        if (minImprovement < 0.0) {
            String msg = "Min improvement should be Double.NaN to turn off convergence or >= 0.0 otherwise. Found minImprovement=" + minImprovement;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (blockSize < 1) {
            // blockSize is used as a modulus below, so it must be positive.
            String msg = "Block size must be at least 1. Found blockSize=" + blockSize;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (monitoringConvergence && rollingAverageSize < 1) {
            String msg = "Rolling average size must be at least 1 when monitoring convergence. Found rollingAverageSize=" + rollingAverageSize;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (xs.length < 1) {
            reporter.fatal("Require at least one training instance.");
            throw new IllegalArgumentException("Require at least one training instance.");
        }
        if (xs.length != cs.length) {
            String msg = "Require same number of training instances as outcomes. Found xs.length=" + xs.length + " cs.length=" + cs.length;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = xs.length;
        int numOutcomesMinus1 = Math.max(cs);
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        prior.verifyNumberOfDimensions(numDimensions);
        for (int j = 1; j < xs.length; ++j) {
            if (xs[j].numDimensions() != numDimensions) {
                String msg = "Number of dimensions must match for all input vectors. Found xs[0].numDimensions()=" + numDimensions + " xs[" + j + "].numDimensions()=" + xs[j].numDimensions();
                reporter.fatal(msg);
                throw new IllegalArgumentException(msg);
            }
        }
        DenseVector[] weightVectors = initialWeightVectors(hotStart, numOutcomesMinus1, numDimensions, reporter);
        LogisticRegression regression = new LogisticRegression(weightVectors);
        boolean hasPrior = prior != null && !prior.isUniform();
        reporter.info("Number of dimensions=" + numDimensions);
        reporter.info("Number of Outcomes=" + numOutcomes);
        reporter.info("Number of Parameters=" + ((numOutcomes - 1) * numDimensions));
        reporter.info("Number of Training Instances=" + cs.length);
        reporter.info("Prior=" + prior);
        reporter.info("Annealing Schedule=" + annealingSchedule);
        reporter.info("Minimum Epochs=" + minEpochs);
        reporter.info("Maximum Epochs=" + maxEpochs);
        reporter.info("Minimum Improvement Per Period=" + minImprovement);
        reporter.info("Has Informative Prior=" + hasPrior);

        double lastLog2LikelihoodAndPrior = -Double.MAX_VALUE / 2.0;
        // Start with +infinity so the rolling mean cannot fall below the threshold
        // until the window has been filled with real differences.
        double[] rollingAbsDiffs = new double[rollingAverageSize];
        Arrays.fill(rollingAbsDiffs, Double.POSITIVE_INFINITY);
        int rollingAbsDiffsIndex = 0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        // Report progress about every tenth of an epoch. Use 1 for fewer than 10
        // instances, so the modulus below is never zero.
        int progressInterval = numTrainingInstances < 10 ? 1 : numTrainingInstances / 10;
        // classify() overwrites every entry, so one buffer can be reused across epochs.
        double[] conditionalProbs = new double[numOutcomes];

        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            DenseVector[] lastWeightVectors = copy(weightVectors);
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int j = 0; j < numTrainingInstances; ++j) {
                if (j % progressInterval == 0 && reporter.isDebugEnabled()) {
                    reporter.debug("          epoch " + epoch + " is " + ((100 * j) / numTrainingInstances) + "% complete");
                }
                Vector x = xs[j];
                int c = cs[j];
                if (hasPrior && j > 0 && j % blockSize == 0) {
                    adjustWeightsWithPrior(weightVectors, prior, (learningRate * blockSize) / numTrainingInstances);
                }
                regression.classify(x, conditionalProbs);
                for (int k = 0; k < numOutcomesMinus1; ++k) {
                    adjustWeightsWithConditionalProbs(weightVectors[k], conditionalProbs[k], learningRate, x, k, c);
                }
            }
            reporter.debug("catching up regularizations at end of epoch");
            // Apply the prior for the instances since the last in-loop prior update.
            int remainingInstances = numTrainingInstances % blockSize;
            if (remainingInstances == 0) {
                remainingInstances = blockSize;
            }
            if (hasPrior) {
                adjustWeightsWithPrior(weightVectors, prior, (learningRate * remainingInstances) / numTrainingInstances);
            }
            if (handler != null) {
                reporter.debug("handling regression for epoch");
                handler.handle(regression);
            }
            if (!monitoringConvergence) {
                reporter.info("Unmonitored Epoch=" + epoch);
                continue;
            }
            reporter.debug("computing log likelihood");
            double log2Likelihood = log2Likelihood(xs, cs, regression);
            double log2Prior = prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            if (reporter.isInfoEnabled()) {
                reportEpoch(reporter, epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
            }
            if (annealingSchedule.receivedError(epoch, learningRate, -log2LikelihoodAndPrior)) {
                double relativeAbsDiff = Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
                rollingAbsDiffs[rollingAbsDiffsIndex] = relativeAbsDiff;
                rollingAbsDiffsIndex = (rollingAbsDiffsIndex + 1) % rollingAbsDiffs.length;
                double rollingAvgAbsDiff = Statistics.mean(rollingAbsDiffs);
                reporter.debug("relativeAbsDiff=" + relativeAbsDiff + " rollingAvg=" + rollingAvgAbsDiff);
                lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
                // Allow early stopping only after at least minEpochs epochs have completed.
                if (epoch + 1 >= minEpochs && rollingAvgAbsDiff < minImprovement) {
                    reporter.info("Converged with Rolling Average Absolute Difference=" + rollingAvgAbsDiff);
                    break;
                }
            } else {
                reporter.info("Annealing rejected update at learningRate=" + learningRate + " error=" + (-log2LikelihoodAndPrior));
                weightVectors = lastWeightVectors;
                regression = new LogisticRegression(weightVectors);
            }
        }
        return regression;
    }

    /**
     * Returns the sum over the inputs of the log (base 2) probability the model assigns to
     * each input's outcome.
     *
     * @param xs Inputs.
     * @param cs Outcome index for each input.
     * @param regression Model to evaluate.
     * @return Total log2 likelihood.
     * @throws IllegalArgumentException If the arrays differ in length.
     */
    public static double log2Likelihood(Vector[] xs, int[] cs, LogisticRegression regression) {
        if (xs.length != cs.length) {
            throw new IllegalArgumentException("Inputs and categories must be same length. Found inputs.length=" + xs.length + " cats.length=" + cs.length);
        }
        double total = 0.0;
        double[] probs = new double[regression.numOutcomes()];
        for (int j = 0; j < xs.length; ++j) {
            regression.classify(xs[j], probs);
            total += Math.log2(probs[cs[j]]);
        }
        return total;
    }

    /**
     * Builds the initial weight vectors, either zeros or deep copies of the hot start's
     * weights.
     *
     * @throws IllegalArgumentException If the hot start has the wrong dimensionality or too
     *     few outcomes.
     */
    private static DenseVector[] initialWeightVectors(LogisticRegression hotStart, int numOutcomesMinus1, int numDimensions, Reporter reporter) {
        DenseVector[] weightVectors = new DenseVector[numOutcomesMinus1];
        if (hotStart == null) {
            for (int k = 0; k < numOutcomesMinus1; ++k) {
                weightVectors[k] = new DenseVector(numDimensions);
            }
            return weightVectors;
        }
        if (hotStart.numInputDimensions() != numDimensions || hotStart.mWeightVectors.length < numOutcomesMinus1) {
            String msg = "Hot start regression incompatible with training data. Require hotStart.numInputDimensions()=" + numDimensions + " and hotStart.numOutcomes()>=" + (numOutcomesMinus1 + 1) + ". Found hotStart.numInputDimensions()=" + hotStart.numInputDimensions() + " hotStart.numOutcomes()=" + hotStart.numOutcomes();
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        Vector[] hotStartVectors = hotStart.weightVectors();
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            weightVectors[k] = new DenseVector(hotStartVectors[k]);
        }
        return weightVectors;
    }

    /** Logs one line of per-epoch statistics at info level. */
    private static void reportEpoch(Reporter reporter, int epoch, double learningRate, double log2Likelihood, double log2Prior, double log2LikelihoodAndPrior, double bestLog2LikelihoodAndPrior) {
        Formatter formatter = new Formatter(Locale.ENGLISH);
        try {
            formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", Integer.valueOf(epoch), Double.valueOf(learningRate), Double.valueOf(log2Likelihood), Double.valueOf(log2Prior), Double.valueOf(log2LikelihoodAndPrior), Double.valueOf(bestLog2LikelihoodAndPrior));
            reporter.info(formatter.toString());
        } catch (IllegalFormatException e) {
            reporter.warn("Illegal format in Logistic Regression");
        } finally {
            formatter.close();
        }
    }

    /**
     * Moves each weight toward its prior mode by {@code gradient * learningRate}.
     * The result is clipped so a weight never crosses the mode.
     */
    private static void adjustWeightsWithPrior(DenseVector[] weightVectors, RegressionPrior prior, double learningRate) {
        for (DenseVector weightVector : weightVectors) {
            int numDimensions = weightVector.numDimensions();
            for (int d = 0; d < numDimensions; ++d) {
                double value = weightVector.value(d);
                double mode = prior.mode(d);
                if (value == mode) {
                    continue;
                }
                double delta = prior.gradient(value, d) * learningRate;
                if (delta == 0.0) {
                    continue;
                }
                double adjusted = value - delta;
                if (value > mode ? adjusted < mode : adjusted > mode) {
                    adjusted = mode;
                }
                weightVector.setValue(d, adjusted);
            }
        }
    }

    /**
     * Applies the likelihood-gradient update for one instance to outcome {@code k}'s
     * weights. The error term is {@code p(k|x) - 1} when {@code k} is the true outcome and
     * {@code p(k|x)} otherwise.
     */
    private static void adjustWeightsWithConditionalProbs(DenseVector weightVector, double conditionalProb, double learningRate, Vector x, int k, int c) {
        double error = k == c ? conditionalProb - 1.0 : conditionalProb;
        if (error == 0.0) {
            return;
        }
        weightVector.increment(-learningRate * error, x);
    }

    /** Returns a deep copy of the specified weight vectors. */
    private static DenseVector[] copy(DenseVector[] weightVectors) {
        DenseVector[] result = new DenseVector[weightVectors.length];
        for (int k = 0; k < weightVectors.length; ++k) {
            result[k] = new DenseVector(weightVectors[k]);
        }
        return result;
    }
}
