package org.openimaj.ml.annotation.linear;

import de.bwaldvogel.liblinear.DenseLinear;
import de.bwaldvogel.liblinear.DenseProblem;
import de.bwaldvogel.liblinear.Feature;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.AnnotatedObject;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.svm.SVMAnnotator;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;
import org.openimaj.ml.annotation.utils.LiblinearHelper;

/**
 * An annotator backed by LIBLINEAR linear classifiers.
 * <p>
 * In {@link Mode#MULTICLASS} mode a single LIBLINEAR model is trained and every
 * object receives exactly one annotation. In {@link Mode#MULTILABEL} mode one
 * binary model is trained per annotation (instances carrying the annotation
 * versus all other instances) and an object receives every annotation whose
 * model predicts the positive class.
 * <p>
 * Objects are turned into features with the supplied {@link FeatureExtractor}
 * and handed to LIBLINEAR either as sparse {@link Feature} arrays or, when
 * {@code dense} is requested, as {@code double[]} vectors.
 *
 * @param <OBJECT>
 *            type of object being annotated
 * @param <ANNOTATION>
 *            type of annotation
 */
@Reference(type = ReferenceType.Article, author = {"Fan, Rong-En", "Chang, Kai-Wei", "Hsieh, Cho-Jui", "Wang, Xiang-Rui", "Lin, Chih-Jen"}, title = "LIBLINEAR: A Library for Large Linear Classification", year = "2008", journal = "J. Mach. Learn. Res.", pages = {"1871", "", "1874"}, url = "http://dl.acm.org/citation.cfm?id=1390681.1442794", month = "june", publisher = "JMLR.org", volume = "9", customData = {"date", "6/1/2008", "issn", "1532-4435", "numpages", "4", "acmid", "1442794"})
public class LiblinearAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
    /** The mode-specific model that performs training and annotation. */
    InternalModel<OBJECT, ANNOTATION> internal;

    /**
     * State and helpers shared by the {@link Mode}-specific models.
     */
    public static abstract class InternalModel<OBJECT, ANNOTATION> {
        /** Annotations known to the model; populated by {@code train}. */
        ArrayList<ANNOTATION> annotationsList;
        /** Converts objects into feature vectors. */
        FeatureExtractor<? extends FeatureVector, OBJECT> extractor;
        /** Whether the dense ({@code double[]}) LIBLINEAR variant is used. */
        boolean dense;
        /** LIBLINEAR bias term; by LIBLINEAR convention a value {@code >= 0} adds a bias feature. */
        double bias = -1.0d;
        /** Whether raw decision values of non-logistic solvers are mapped to probability-like scores. */
        boolean estimateProbabilities = true;

        InternalModel() {
        }

        /**
         * Trains from a list of annotated objects.
         *
         * @param list
         *            the training data
         */
        public abstract void train(List<? extends Annotated<OBJECT, ANNOTATION>> list);

        /**
         * Trains from a dataset whose groups are the annotations.
         *
         * @param groupedDataset
         *            the training data
         */
        public abstract void train(GroupedDataset<ANNOTATION, ListDataset<OBJECT>, OBJECT> groupedDataset);

        /**
         * Annotates an object with the trained model(s).
         *
         * @param object
         *            the object to annotate
         * @return the predicted annotations with their scores
         */
        public abstract List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object);

        /** Extracts the sparse LIBLINEAR representation of an object, passing {@link #bias} to the conversion. */
        Feature[] computeFeature(OBJECT object) {
            return LiblinearHelper.convert(this.extractor.extractFeature(object), this.bias);
        }

        /** Extracts the dense LIBLINEAR representation of an object, passing {@link #bias} to the conversion. */
        double[] computeFeatureDense(OBJECT object) {
            return LiblinearHelper.convertDense(this.extractor.extractFeature(object), this.bias);
        }

        /**
         * Number of columns ({@code n}) of a LIBLINEAR problem built from
         * features of the given length. LIBLINEAR requires {@code n} to include
         * the extra bias feature whenever {@code bias >= 0}.
         */
        int problemDimension(int featureLength) {
            return this.bias >= 0 ? featureLength + 1 : featureLength;
        }

        /**
         * Legacy overload that assumes {@code annotationsList.size()} classes.
         * Prefer {@link #computeProbabilities(double[], int)} with the model's
         * own class count.
         */
        void computeProbabilities(double[] dArr) {
            computeProbabilities(dArr, this.annotationsList.size());
        }

        /**
         * Converts decision values to probability-like scores in place, using
         * the same logistic mapping as LIBLINEAR's {@code predictProbability}.
         * Does nothing when {@link #estimateProbabilities} is false.
         *
         * @param values
         *            decision values as filled by {@code predictValues}; must
         *            have at least {@code nrClass} entries
         * @param nrClass
         *            number of classes of the model that produced the values
         */
        void computeProbabilities(double[] values, int nrClass) {
            if (!this.estimateProbabilities) {
                return;
            }
            // A two-class model produces a single decision value.
            final int decisionCount = nrClass == 2 ? 1 : nrClass;
            for (int i = 0; i < decisionCount; i++) {
                values[i] = 1.0d / (1.0d + Math.exp(-values[i]));
            }
            if (nrClass == 2) {
                values[1] = 1.0d - values[0];
                return;
            }
            double sum = 0.0d;
            for (int i = 0; i < nrClass; i++) {
                sum += values[i];
            }
            for (int i = 0; i < nrClass; i++) {
                values[i] = values[i] / sum;
            }
        }

        /**
         * Finds where {@code label} sits in {@code model.getLabels()}.
         * LIBLINEAR orders decision values and probability estimates by that
         * array rather than by label value, so this is the index to read a
         * label's score from.
         *
         * @throws IllegalStateException
         *             if the model does not know the label
         */
        static int labelPosition(Model model, double label) {
            final int[] labels = model.getLabels();
            for (int i = 0; i < labels.length; i++) {
                if (labels[i] == label) {
                    return i;
                }
            }
            throw new IllegalStateException("Predicted label " + label + " is not known to the model");
        }

        /** Predicts a sparse feature with {@code model}, returning the label and its score. */
        Prediction predict(Model model, Feature[] x, boolean probabilistic) {
            final double[] values = new double[model.getNrClass()];
            final double label;
            if (probabilistic) {
                label = Linear.predictProbability(model, x, values);
            } else {
                label = Linear.predictValues(model, x, values);
                computeProbabilities(values, values.length);
            }
            return new Prediction(label, values[labelPosition(model, label)]);
        }

        /** Predicts a dense feature with {@code model}, returning the label and its score. */
        Prediction predict(Model model, double[] x, boolean probabilistic) {
            final double[] values = new double[model.getNrClass()];
            final double label;
            if (probabilistic) {
                label = DenseLinear.predictProbability(model, x, values);
            } else {
                label = DenseLinear.predictValues(model, x, values);
                computeProbabilities(values, values.length);
            }
            return new Prediction(label, values[labelPosition(model, label)]);
        }

        /** A label predicted by a LIBLINEAR model together with its score. */
        static final class Prediction {
            final double label;
            final double score;

            Prediction(double label, double score) {
                this.label = label;
                this.score = score;
            }
        }
    }

    /**
     * How annotations are assigned to objects.
     */
    public enum Mode {
        /** Exactly one annotation per object, from a single multiclass model. */
        MULTICLASS,
        /** Any number of annotations per object, from one binary model per annotation. */
        MULTILABEL
    }

    /**
     * Multiclass model: LIBLINEAR label {@code i + 1} stands for
     * {@code annotationsList.get(i)}.
     */
    static class Multiclass<OBJECT, ANNOTATION> extends InternalModel<OBJECT, ANNOTATION> {
        private final Parameter parameter;
        private Model model;

        public Multiclass(SolverType solverType, double c, double eps, double bias, boolean dense) {
            this.parameter = new Parameter(solverType, c, eps);
            this.dense = dense;
            this.bias = bias;
        }

        @Override
        public void train(GroupedDataset<ANNOTATION, ListDataset<OBJECT>, OBJECT> groupedDataset) {
            final int numInstances = groupedDataset.numInstances();
            if (numInstances == 0) {
                throw new IllegalArgumentException("Cannot train on an empty dataset");
            }
            this.annotationsList = new ArrayList<>(groupedDataset.getGroups());

            final List<OBJECT> objects = new ArrayList<>(numInstances);
            final double[] labels = new double[numInstances];
            for (int g = 0; g < this.annotationsList.size(); g++) {
                for (final OBJECT object : groupedDataset.get(this.annotationsList.get(g))) {
                    labels[objects.size()] = g + 1;
                    objects.add(object);
                }
            }
            fit(objects, labels);
        }

        @Override
        public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
            if (list.isEmpty()) {
                throw new IllegalArgumentException("Cannot train on an empty list");
            }
            this.annotationsList = new ArrayList<>(new AnnotatedListHelper<OBJECT, ANNOTATION>(list).getAnnotations());

            final List<OBJECT> objects = new ArrayList<>(list.size());
            final double[] labels = new double[list.size()];
            for (final Annotated<OBJECT, ANNOTATION> item : list) {
                if (item.getAnnotations().size() != 1) {
                    throw new IllegalArgumentException("A multiclass problem cannot have more than one class per instance");
                }
                labels[objects.size()] = this.annotationsList.indexOf(item.getAnnotations().iterator().next()) + 1;
                objects.add(item.getObject());
            }
            fit(objects, labels);
        }

        /**
         * Builds a dense or sparse LIBLINEAR problem from the objects and their
         * 1-based labels, and trains {@link #model} on it.
         */
        private void fit(List<OBJECT> objects, double[] labels) {
            final int count = objects.size();
            final int featureLength = this.extractor.extractFeature(objects.get(0)).length();

            if (this.dense) {
                final DenseProblem problem = new DenseProblem();
                problem.l = count;
                problem.n = problemDimension(featureLength);
                problem.bias = this.bias;
                problem.x = new double[count][];
                problem.y = labels;
                for (int i = 0; i < count; i++) {
                    problem.x[i] = computeFeatureDense(objects.get(i));
                }
                this.model = DenseLinear.train(problem, this.parameter);
            } else {
                final Problem problem = new Problem();
                problem.l = count;
                problem.n = problemDimension(featureLength);
                problem.bias = this.bias;
                problem.x = new Feature[count][];
                problem.y = labels;
                for (int i = 0; i < count; i++) {
                    problem.x[i] = computeFeature(objects.get(i));
                }
                this.model = Linear.train(problem, this.parameter);
            }
        }

        @Override
        public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
            if (this.model == null) {
                throw new IllegalStateException("The annotator has not been trained");
            }
            final boolean probabilistic = this.parameter.getSolverType().isLogisticRegressionSolver();
            final Prediction prediction = this.dense
                    ? predict(this.model, computeFeatureDense(object), probabilistic)
                    : predict(this.model, computeFeature(object), probabilistic);

            final List<ScoredAnnotation<ANNOTATION>> result = new ArrayList<>(1);
            // Labels were assigned as annotationsList index + 1.
            final ANNOTATION annotation = this.annotationsList.get((int) prediction.label - 1);
            result.add(new ScoredAnnotation<ANNOTATION>(annotation, (float) prediction.score));
            return result;
        }
    }

    /**
     * Multilabel model: one binary LIBLINEAR model per annotation, where
     * {@link #POSITIVE_CLASS} marks instances carrying the annotation.
     */
    static class Multilabel<OBJECT, ANNOTATION> extends InternalModel<OBJECT, ANNOTATION> {
        private static final int NEGATIVE_CLASS = 1;
        private static final int POSITIVE_CLASS = 2;

        private final Parameter parameter;
        /** {@code models[i]} decides whether {@code annotationsList.get(i)} applies. */
        private Model[] models;

        public Multilabel(SolverType solverType, double c, double eps, double bias, boolean dense) {
            this.parameter = new Parameter(solverType, c, eps);
            this.dense = dense;
            this.bias = bias;
        }

        @Override
        public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
            if (list.isEmpty()) {
                throw new IllegalArgumentException("Cannot train on an empty list");
            }
            final AnnotatedListHelper<OBJECT, ANNOTATION> helper = new AnnotatedListHelper<OBJECT, ANNOTATION>(list);
            this.annotationsList = new ArrayList<>(helper.getAnnotations());
            final int featureLength = this.extractor.extractFeature(list.get(0).getObject()).length();

            this.models = new Model[this.annotationsList.size()];
            for (int i = 0; i < this.annotationsList.size(); i++) {
                final ANNOTATION annotation = this.annotationsList.get(i);
                final List<? extends FeatureVector> positive = helper.extractFeatures(annotation, this.extractor);
                final List<? extends FeatureVector> negative = helper.extractFeaturesExclude(annotation, this.extractor);
                this.models[i] = trainBinary(positive, negative, featureLength);
            }
        }

        /**
         * Trains one binary model: {@code negative} features get
         * {@link #NEGATIVE_CLASS}, {@code positive} features
         * {@link #POSITIVE_CLASS}.
         */
        private Model trainBinary(List<? extends FeatureVector> positive, List<? extends FeatureVector> negative,
                int featureLength) {
            final int count = positive.size() + negative.size();

            if (this.dense) {
                final DenseProblem problem = new DenseProblem();
                problem.l = count;
                problem.n = problemDimension(featureLength);
                problem.bias = this.bias;
                problem.x = new double[count][];
                problem.y = new double[count];
                int row = 0;
                for (final FeatureVector fv : negative) {
                    problem.x[row] = LiblinearHelper.convertDense(fv, this.bias);
                    problem.y[row++] = NEGATIVE_CLASS;
                }
                for (final FeatureVector fv : positive) {
                    problem.x[row] = LiblinearHelper.convertDense(fv, this.bias);
                    problem.y[row++] = POSITIVE_CLASS;
                }
                return DenseLinear.train(problem, this.parameter);
            }

            final Problem problem = new Problem();
            problem.l = count;
            problem.n = problemDimension(featureLength);
            problem.bias = this.bias;
            problem.x = new Feature[count][];
            problem.y = new double[count];
            int row = 0;
            for (final FeatureVector fv : negative) {
                problem.x[row] = LiblinearHelper.convert(fv, this.bias);
                problem.y[row++] = NEGATIVE_CLASS;
            }
            for (final FeatureVector fv : positive) {
                problem.x[row] = LiblinearHelper.convert(fv, this.bias);
                problem.y[row++] = POSITIVE_CLASS;
            }
            return Linear.train(problem, this.parameter);
        }

        @Override
        public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
            if (this.models == null) {
                throw new IllegalStateException("The annotator has not been trained");
            }
            final boolean probabilistic = this.parameter.getSolverType().isLogisticRegressionSolver();
            // Extract the object's features once and reuse them for every per-annotation model.
            final Feature[] sparseFeature = this.dense ? null : computeFeature(object);
            final double[] denseFeature = this.dense ? computeFeatureDense(object) : null;

            final List<ScoredAnnotation<ANNOTATION>> result = new ArrayList<>();
            for (int i = 0; i < this.annotationsList.size(); i++) {
                final Prediction prediction = this.dense
                        ? predict(this.models[i], denseFeature, probabilistic)
                        : predict(this.models[i], sparseFeature, probabilistic);
                if (prediction.label == POSITIVE_CLASS) {
                    result.add(new ScoredAnnotation<ANNOTATION>(this.annotationsList.get(i), (float) prediction.score));
                }
            }
            return result;
        }

        @Override
        public void train(GroupedDataset<ANNOTATION, ListDataset<OBJECT>, OBJECT> groupedDataset) {
            train(AnnotatedObject.createList(groupedDataset));
        }
    }

    /**
     * Constructs an annotator with no bias term, using sparse features.
     *
     * @param featureExtractor
     *            converts objects into feature vectors
     * @param mode
     *            multiclass or multilabel operation
     * @param solverType
     *            the LIBLINEAR solver
     * @param c
     *            the cost parameter (C) passed to {@link Parameter}
     * @param eps
     *            the stopping tolerance passed to {@link Parameter}
     */
    public LiblinearAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor, Mode mode, SolverType solverType, double c, double eps) {
        this(featureExtractor, mode, solverType, c, eps, -1.0d, false);
    }

    /**
     * Constructs an annotator.
     *
     * @param featureExtractor
     *            converts objects into feature vectors
     * @param mode
     *            multiclass or multilabel operation
     * @param solverType
     *            the LIBLINEAR solver
     * @param c
     *            the cost parameter (C) passed to {@link Parameter}
     * @param eps
     *            the stopping tolerance passed to {@link Parameter}
     * @param bias
     *            LIBLINEAR bias; a value {@code >= 0} adds a bias feature
     * @param dense
     *            whether to use the dense LIBLINEAR variant
     */
    public LiblinearAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> featureExtractor, Mode mode, SolverType solverType, double c, double eps, double bias, boolean dense) {
        switch (mode) {
            case MULTICLASS:
                this.internal = new Multiclass<>(solverType, c, eps, bias, dense);
                break;
            case MULTILABEL:
                this.internal = new Multilabel<>(solverType, c, eps, bias, dense);
                break;
            default:
                throw new RuntimeException("Unhandled mode");
        }
        this.internal.extractor = featureExtractor;
    }

    @Override
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> list) {
        this.internal.train(list);
    }

    /**
     * Returns a fresh set of the annotations known to the trained model, or an
     * empty set before training.
     */
    @Override
    public Set<ANNOTATION> getAnnotations() {
        final List<ANNOTATION> known = this.internal.annotationsList;
        return known == null ? new HashSet<ANNOTATION>() : new HashSet<ANNOTATION>(known);
    }

    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        return this.internal.annotate(object);
    }

    @Override
    public void train(GroupedDataset<ANNOTATION, ListDataset<OBJECT>, OBJECT> groupedDataset) {
        this.internal.train(groupedDataset);
    }
}
