package com.aliasi.classify;

import com.aliasi.stats.Statistics;
import com.aliasi.util.Math;
import com.aliasi.util.Pair;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.Arrays;

/* loaded from: input_file:com/aliasi/classify/ConditionalClassification.class */
/**
 * A scored classification that also records, for each ranked category, a
 * conditional probability {@code P(category|input)}.
 *
 * <p>On construction the conditional probabilities are validated: there must be
 * one per category, each must lie in {@code [0.0, 1.0]} (NaN is rejected), and
 * together they must sum to 1.0 within a tolerance (default 0.01).
 */
public class ConditionalClassification extends ScoredClassification {

    /** Conditional probability for each rank, parallel to the categories. */
    private final double[] mConditionalProbs;

    /** Default allowed deviation of the conditional-probability sum from 1.0. */
    private static final double TOLERANCE = 0.01d;

    /**
     * Constructs a classification whose scores double as its conditional
     * probabilities, using the default tolerance.
     *
     * @param strArr categories, indexed by rank
     * @param dArr conditional probabilities, also used as the scores
     * @throws IllegalArgumentException if the probabilities fail validation
     */
    public ConditionalClassification(String[] strArr, double[] dArr) {
        this(strArr, dArr, dArr, TOLERANCE);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities, using the default tolerance.
     *
     * @param strArr categories, indexed by rank
     * @param dArr scores, indexed by rank
     * @param dArr2 conditional probabilities, indexed by rank
     * @throws IllegalArgumentException if the probabilities fail validation
     */
    public ConditionalClassification(String[] strArr, double[] dArr, double[] dArr2) {
        this(strArr, dArr, dArr2, TOLERANCE);
    }

    /**
     * Constructs a classification whose scores double as its conditional
     * probabilities, using the given tolerance for the sum-to-one check.
     *
     * @param strArr categories, indexed by rank
     * @param dArr conditional probabilities, also used as the scores
     * @param d non-negative tolerance on the deviation of the sum from 1.0
     * @throws IllegalArgumentException if the tolerance or probabilities are invalid
     */
    public ConditionalClassification(String[] strArr, double[] dArr, double d) {
        this(strArr, dArr, dArr, d);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities, validating the probabilities against the given tolerance.
     *
     * @param strArr categories, indexed by rank
     * @param dArr scores, indexed by rank
     * @param dArr2 conditional probabilities, indexed by rank; one per category
     * @param d non-negative tolerance on the deviation of the sum from 1.0
     * @throws IllegalArgumentException if the tolerance is negative or NaN, if the
     *     number of conditional probabilities differs from the number of
     *     categories, if any probability is outside {@code [0.0, 1.0]} or NaN, or
     *     if the probabilities do not sum to 1.0 within the tolerance
     */
    public ConditionalClassification(String[] strArr, double[] dArr, double[] dArr2, double d) {
        super(strArr, dArr);
        this.mConditionalProbs = dArr2;
        if (d < 0.0d || Double.isNaN(d)) {
            throw new IllegalArgumentException("Tolerance must be a non-negative number. Found tolerance=" + d);
        }
        if (dArr2.length != strArr.length) {
            throw new IllegalArgumentException("Must supply one conditional probability per category. Found categories.length=" + strArr.length + " conditionalProbs.length=" + dArr2.length);
        }
        for (int i = 0; i < dArr2.length; i++) {
            // Phrased as a positive range test so NaN (for which every comparison is false) is rejected.
            if (!(dArr2[i] >= 0.0d && dArr2[i] <= 1.0d)) {
                throw new IllegalArgumentException("Conditional probabilities must be between 0.0 and 1.0. Found conditionalProbs[" + i + "]=" + dArr2[i]);
            }
        }
        double sum = Math.sum(dArr2);
        // Positive range test again, so a NaN sum can never slip through.
        if (!(sum >= 1.0d - d && sum <= 1.0d + d)) {
            throw new IllegalArgumentException("Conditional probabilities must sum to 1.0. Acceptable tolerance=" + d + " Found sum=" + sum);
        }
    }

    /**
     * Returns the conditional probability of the category at the given rank.
     *
     * @param i rank, in {@code 0 .. number of conditional probabilities - 1}
     * @return the conditional probability at that rank
     * @throws IllegalArgumentException if the rank is out of range
     */
    public double conditionalProbability(int i) {
        if (i < 0 || i >= this.mConditionalProbs.length) {
            throw new IllegalArgumentException("Require rank in range 0.." + (this.mConditionalProbs.length - 1) + " Found rank=" + i);
        }
        return this.mConditionalProbs[i];
    }

    /**
     * Returns the conditional probability of the named category.
     *
     * @param str category to look up (compared with {@code equals})
     * @return the conditional probability of the first rank holding that category
     * @throws IllegalArgumentException if the category is not in this
     *     classification; the message lists the valid categories
     */
    public double conditionalProbability(String str) {
        int n = size();
        for (int rank = 0; rank < n; rank++) {
            if (category(rank).equals(str)) {
                return conditionalProbability(rank);
            }
        }
        StringBuilder msg = new StringBuilder();
        msg.append(str).append(" is not a valid category for this classification.  Valid categories are:");
        for (int rank = 0; rank < n; rank++) {
            if (rank > 0) {
                msg.append(',');
            }
            msg.append(Strings.SINGLE_SPACE_STRING).append(category(rank));
        }
        throw new IllegalArgumentException(msg.toString());
    }

    /**
     * Renders one line per rank: {@code rank=category score conditionalProb},
     * preceded by a header line.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Rank  Category  Score  P(Category|Input)\n");
        for (int i = 0; i < size(); i++) {
            sb.append(i).append('=').append(category(i))
                .append(Strings.SINGLE_SPACE_STRING).append(score(i))
                .append(Strings.SINGLE_SPACE_STRING).append(conditionalProbability(i))
                .append('\n');
        }
        return sb.toString();
    }

    /**
     * Creates a classification from base-2 log (joint) probabilities.
     *
     * <p>The log probabilities are converted to normalized conditional
     * probabilities, categories are sorted by descending probability, and the
     * sorted probabilities serve as both scores and conditional probabilities.
     *
     * @param strArr categories, parallel to {@code dArr}
     * @param dArr base-2 log probabilities; each must be non-positive and not NaN
     * @return the resulting classification
     * @throws IllegalArgumentException if lengths differ or any log probability is invalid
     */
    public static ConditionalClassification createLogProbs(String[] strArr, double[] dArr) {
        verifyLengths(strArr, dArr);
        verifyLogProbs(dArr);
        Pair<String[], double[]> sorted = sort(strArr, logJointToConditional(dArr));
        return new ConditionalClassification(sorted.a(), sorted.b());
    }

    /**
     * Creates a classification from non-negative probability ratios.
     *
     * <p>If every ratio is zero, a uniform distribution is used and categories keep
     * their given order. Otherwise the ratios are converted to base-2 logs and
     * handled by {@link #createLogProbs(String[], double[])}.
     *
     * @param strArr categories, parallel to {@code dArr}
     * @param dArr probability ratios; each must be non-negative and finite
     * @return the resulting classification
     * @throws IllegalArgumentException if lengths differ or any ratio is invalid
     */
    public static ConditionalClassification createProbs(String[] strArr, double[] dArr) {
        verifyLengths(strArr, dArr);
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] < 0.0d || Double.isInfinite(dArr[i]) || Double.isNaN(dArr[i])) {
                throw new IllegalArgumentException("Probability ratios must be non-negative and finite. Found probabilityRatios[" + i + "]=" + dArr[i]);
            }
        }
        if (Math.sum(dArr) == 0.0d) {
            double[] uniform = new double[dArr.length];
            Arrays.fill(uniform, 1.0d / dArr.length);
            return new ConditionalClassification(strArr, uniform);
        }
        double[] log2Ratios = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            log2Ratios[i] = Math.log2(dArr[i]);
        }
        return createLogProbs(strArr, log2Ratios);
    }

    /**
     * Checks that every value is a valid log probability (non-positive, not NaN).
     * Negative infinity is accepted.
     *
     * <p>Originally package-private (per the decompiler note); kept public for
     * compatibility.
     *
     * @param dArr log probabilities to check
     * @throws IllegalArgumentException on the first NaN or positive value
     */
    public static void verifyLogProbs(double[] dArr) {
        for (double d : dArr) {
            if (Double.isNaN(d) || d > 0.0d) {
                throw new IllegalArgumentException("Log probs must be non-positive numbers. Found x=" + d);
            }
        }
    }

    /**
     * Checks that the category and value arrays have the same length.
     *
     * <p>Originally package-private (per the decompiler note); kept public for
     * compatibility.
     *
     * @param strArr categories
     * @param dArr values parallel to the categories
     * @throws IllegalArgumentException if the lengths differ
     */
    public static void verifyLengths(String[] strArr, double[] dArr) {
        if (strArr.length != dArr.length) {
            throw new IllegalArgumentException("Arrays must be same length. Found categories.length=" + strArr.length + " logProbabilities.length=" + dArr.length);
        }
    }

    /**
     * Sorts categories by their paired values, highest value first, and returns
     * the reordered categories and values as new arrays. The inputs are not
     * modified.
     *
     * <p>Originally package-private (per the decompiler note); kept public for
     * compatibility.
     *
     * @param strArr categories
     * @param dArr values parallel to the categories
     * @return pair of (sorted categories, sorted values)
     * @throws IllegalArgumentException if the lengths differ
     */
    public static Pair<String[], double[]> sort(String[] strArr, double[] dArr) {
        verifyLengths(strArr, dArr);
        int n = strArr.length;
        // Raw type kept as in the original; ScoredObject's generic signature is declared elsewhere.
        ScoredObject[] scored = new ScoredObject[n];
        for (int i = 0; i < n; i++) {
            scored[i] = new ScoredObject(strArr[i], dArr[i]);
        }
        Arrays.sort(scored, ScoredObject.reverseComparator());
        String[] sortedCategories = new String[n];
        double[] sortedValues = new double[n];
        for (int i = 0; i < n; i++) {
            sortedCategories[i] = (String) scored[i].getObject();
            sortedValues[i] = scored[i].score();
        }
        return new Pair<>(sortedCategories, sortedValues);
    }

    /**
     * Converts base-2 log joint probabilities to normalized conditional
     * probabilities.
     *
     * <p>Each value {@code x} becomes {@code 2^(x - max)}, where {@code max} is the
     * largest input, and the result is normalized by {@code Statistics.normalize}.
     * Subtracting the maximum keeps the exponentials in {@code [0, 1]} and avoids
     * underflow of the largest terms.
     *
     * <p>NOTE: positive inputs smaller than {@code 1e-10} (presumably
     * floating-point round-off) are snapped to {@code 0.0} <em>in the caller's
     * array</em>. This in-place write is preserved from the original.
     * TODO(review): confirm whether callers depend on it before switching to a copy.
     *
     * <p>NOTE(review): if every input is negative infinity, every ratio becomes 0.0
     * and the outcome depends on how {@code Statistics.normalize} treats an
     * all-zero vector; {@link #createProbs(String[], double[])} special-cases the
     * analogous situation with a uniform distribution.
     *
     * <p>Originally package-private (per the decompiler note); kept public for
     * compatibility.
     *
     * @param dArr base-2 log joint probabilities; must be non-positive and not NaN
     *     (after snapping)
     * @return normalized conditional probabilities, parallel to the input
     * @throws IllegalArgumentException if any value is positive or NaN; the message
     *     lists every input value
     */
    public static double[] logJointToConditional(double[] dArr) {
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > 0.0d && dArr[i] < 1.0E-10d) {
                dArr[i] = 0.0d;
            }
            if (dArr[i] > 0.0d || Double.isNaN(dArr[i])) {
                StringBuilder sb = new StringBuilder();
                sb.append("Joint probs must be zero or negative. Found log2JointProbs[")
                    .append(i).append("]=").append(dArr[i]);
                for (int j = 0; j < dArr.length; j++) {
                    sb.append("\nlogJointProbs[").append(j).append("]=").append(dArr[j]);
                }
                throw new IllegalArgumentException(sb.toString());
            }
        }
        double maximum = Math.maximum(dArr);
        double[] ratios = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            // Fully qualified: the com.aliasi.util.Math import shadows java.lang.Math here.
            double ratio = java.lang.Math.pow(2.0d, dArr[i] - maximum);
            if (ratio == Double.POSITIVE_INFINITY) {
                // Defensive clamp preserved from the original.
                ratio = (double) Float.MAX_VALUE;
            } else if (ratio == Double.NEGATIVE_INFINITY || Double.isNaN(ratio)) {
                // NaN arises when both the value and the maximum are -Infinity.
                ratio = 0.0d;
            }
            ratios[i] = ratio;
        }
        return Statistics.normalize(ratios);
    }
}
