package org.openimaj.pgm.vb.lda.mle;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math.special.Gamma;
import org.openimaj.math.util.MathUtils;
import org.openimaj.pgm.util.Corpus;
import org.openimaj.pgm.util.Document;
import org.openimaj.pgm.vb.lda.mle.LDABetaInitStrategy;
import org.openimaj.util.array.SparseIntArray;

/**
 * Learns a latent Dirichlet allocation (LDA) model from a {@link Corpus} using mean-field
 * variational EM in the style of Blei, Ng &amp; Jordan (2003).
 *
 * <p>Each EM iteration runs, for every document, a variational E-step ({@code performE}) that
 * updates the per-position topic responsibilities {@code phi} and the document's
 * {@code varGamma}, followed by an M-step ({@code performM}) that accumulates those
 * responsibilities into the statistics of a fresh {@link LDAModel}. The model produced by the
 * last call to {@link #estimate(Corpus)} is available from {@link #getModel()}.
 */
public class LDALearner {
    /** Stand-in for log(beta) when a topic has no mass on a word, i.e. for log(0). */
    private static final double LOG_ZERO_BETA = -100.0d;

    /** Iterations up to and including this index are never considered converged. */
    private static final int LAST_FORCED_ITERATION = 2;

    private final int ntopics;
    private final Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();
    /** Result of the most recent {@link #estimate(Corpus)}; {@code null} before the first call. */
    private LDAModel model;

    /** Tunable parameters of the learner, each with the value used when none has been set. */
    public enum LDAConfig {
        /** Upper bound on the number of EM iterations over the corpus (default 10). */
        MAX_ITERATIONS {
            @Override
            public Integer defaultValue() {
                return 10;
            }
        },
        /** Value passed to {@code LDAModel.setAlpha}; symmetric Dirichlet parameter (default 0.3). */
        ALPHA {
            @Override
            public Double defaultValue() {
                return Double.valueOf(0.3d);
            }
        },
        /** Upper bound on variational iterations per document (default 10). */
        VAR_MAX_ITERATIONS {
            @Override
            public Integer defaultValue() {
                return 10;
            }
        },
        /** Strategy used to seed the initial model; defaults to a new {@code RandomBetaInit}. */
        INIT_STRATEGY {
            @Override
            public LDABetaInitStrategy defaultValue() {
                return new LDABetaInitStrategy.RandomBetaInit();
            }
        },
        /** Tolerance on the relative change of the corpus likelihood (default 1e-5). */
        EM_CONVERGED {
            @Override
            public Double defaultValue() {
                return Double.valueOf(1.0E-5d);
            }
        },
        /** Tolerance on the relative change of a document's variational likelihood (default 1e-5). */
        VAR_EM_CONVERGED {
            @Override
            public Double defaultValue() {
                return Double.valueOf(1.0E-5d);
            }
        };

        /**
         * @return the value {@link LDALearner#getConfig(LDAConfig)} falls back to when no value
         *         has been set for this key
         */
        public abstract Object defaultValue();
    }

    /**
     * @param ntopics number of latent topics to learn; must be at least 1
     * @throws IllegalArgumentException if {@code ntopics < 1}
     */
    public LDALearner(int ntopics) {
        if (ntopics < 1) {
            throw new IllegalArgumentException("ntopics must be positive, got " + ntopics);
        }
        this.ntopics = ntopics;
    }

    /**
     * Returns the configured value for {@code key}, or {@code key.defaultValue()} if none is set.
     *
     * @param key the parameter to look up
     * @param <T> the expected value type; a mismatch surfaces as a {@link ClassCastException}
     *            at the call site
     * @return the configured or default value
     */
    @SuppressWarnings("unchecked") // values are only ever stored via setConfig/defaultValue
    public <T> T getConfig(LDAConfig key) {
        final Object value = this.config.get(key);
        return (T) (value == null ? key.defaultValue() : value);
    }

    /**
     * Overrides a parameter. Passing {@code null} restores the default. The value must have the
     * type the learner reads for that key (for example {@link Double} for {@code ALPHA}).
     *
     * @param key   the parameter to set
     * @param value the new value, or {@code null} for the default
     */
    public void setConfig(LDAConfig key, Object value) {
        this.config.put(key, value);
    }

    /**
     * Runs variational EM on {@code corpus}. The resulting model replaces any previously
     * estimated one and can be retrieved via {@link #getModel()}.
     *
     * @param corpus the documents to learn from
     */
    public void estimate(Corpus corpus) {
        this.model = performEM(corpus);
    }

    /** @return the model from the latest {@link #estimate(Corpus)}, or {@code null} if none */
    public LDAModel getModel() {
        return this.model;
    }

    /** Outer EM loop: alternate per-document E/M steps until the corpus likelihood settles. */
    private LDAModel performEM(Corpus corpus) {
        final double alpha = this.<Double>getConfig(LDAConfig.ALPHA);
        final LDABetaInitStrategy initStrategy = getConfig(LDAConfig.INIT_STRATEGY);

        final LDAModel initial = new LDAModel(this.ntopics);
        initial.prepare(corpus);
        initial.setAlpha(alpha);
        initStrategy.initModel(initial, corpus);

        final LDAVariationlState vstate = new LDAVariationlState(initial);
        // Keep iterating until converged (the original looped *while* converged).
        while (!modelConverged(vstate.state)) {
            final LDAModel current = vstate.state;
            final LDAModel next = current.newInstance();
            next.setAlpha(alpha);
            // NOTE(review): convergence relies on newInstance() carrying current.likelihood
            // into next.oldLikelihood and starting next.likelihood at 0 -- confirm in LDAModel.
            for (Document document : corpus.getDocuments()) {
                performE(document, vstate); // performE calls vstate.prepare(document) itself
                performM(document, vstate, next);
                next.likelihood += vstate.likelihood;
            }
            // Set explicitly so termination does not depend on newInstance() copying the count.
            next.iteration = current.iteration + 1;
            vstate.state = next;
        }
        return vstate.state;
    }

    /**
     * Variational E-step for one document: iterates the phi/gamma updates against the fixed
     * model {@code vstate.state} until the document's variational likelihood settles.
     */
    private LDAVariationlState performE(Document document, LDAVariationlState vstate) {
        // NOTE(review): assumes prepare() resets vstate.iteration/likelihood/oldLikelihood and
        // sizes phi to [distinct words][ntopics] and oldphi to [ntopics] -- TODO confirm.
        vstate.prepare(document);
        final LDAModel current = vstate.state;
        while (!variationalStateConverged(vstate)) {
            int position = 0;
            for (SparseIntArray.Entry entry : document.getVector().entries()) {
                final int word = entry.index;
                final int count = entry.value;
                final double[] wordPhi = vstate.phi[position];

                // Unnormalised log phi, accumulating its log-sum-exp for normalisation.
                double logPhiSum = 0.0d;
                for (int topic = 0; topic < this.ntopics; topic++) {
                    vstate.oldphi[topic] = wordPhi[topic];
                    wordPhi[topic] = logBeta(current, topic, word)
                            + Gamma.digamma(vstate.varGamma[topic]);
                    logPhiSum = topic == 0 ? wordPhi[topic] : MathUtils.logSum(logPhiSum, wordPhi[topic]);
                }

                // Normalise phi and update gamma incrementally: remove the old responsibility,
                // add the new one, weighted by how often the word occurs in the document.
                for (int topic = 0; topic < this.ntopics; topic++) {
                    wordPhi[topic] = Math.exp(wordPhi[topic] - logPhiSum);
                    vstate.varGamma[topic] += count * (wordPhi[topic] - vstate.oldphi[topic]);
                }
                position++;
            }
            vstate.oldLikelihood = vstate.likelihood;
            vstate.likelihood = computeLikelihood(document, vstate);
            vstate.iteration++;
        }
        return vstate;
    }

    /** @return true once the corpus-level EM loop should stop */
    private boolean modelConverged(LDAModel lDAModel) {
        return hasConverged(lDAModel.likelihood, lDAModel.oldLikelihood, lDAModel.iteration,
                this.<Double>getConfig(LDAConfig.EM_CONVERGED),
                this.<Integer>getConfig(LDAConfig.MAX_ITERATIONS));
    }

    /** @return true once the per-document variational loop should stop */
    private boolean variationalStateConverged(LDAVariationlState lDAVariationlState) {
        return hasConverged(lDAVariationlState.likelihood, lDAVariationlState.oldLikelihood,
                lDAVariationlState.iteration,
                this.<Double>getConfig(LDAConfig.VAR_EM_CONVERGED),
                this.<Integer>getConfig(LDAConfig.VAR_MAX_ITERATIONS));
    }

    /**
     * Shared stopping rule. Stops once {@code maxIterations} is exceeded. Never stops during
     * the first iterations (up to {@link #LAST_FORCED_ITERATION}), because the likelihoods
     * are not yet meaningful there. Otherwise stops when |(new - old) / old| drops below
     * {@code tolerance}. NaN or infinite ratios count as "not converged".
     */
    private static boolean hasConverged(double likelihood, double oldLikelihood, long iteration,
            double tolerance, int maxIterations) {
        if (iteration > maxIterations) {
            return true;
        }
        if (iteration <= LAST_FORCED_ITERATION) {
            return false;
        }
        // abs(): for negative log-likelihoods an improvement makes the raw ratio negative.
        return Math.abs((likelihood - oldLikelihood) / oldLikelihood) < tolerance;
    }

    /**
     * M-step contribution of one document: adds count * phi to the topic/word statistics of
     * {@code lDAModel}.
     */
    private void performM(Document document, LDAVariationlState lDAVariationlState, LDAModel lDAModel) {
        // Iterate getVector() so positions line up with the phi rows written in performE.
        int position = 0;
        for (SparseIntArray.Entry entry : document.getVector().entries()) {
            final int word = entry.index;
            final int count = entry.value;
            final double[] wordPhi = lDAVariationlState.phi[position];
            for (int topic = 0; topic < this.ntopics; topic++) {
                lDAModel.incTopicWord(topic, word, count * wordPhi[topic]);
                // NOTE(review): for log(topicWord) - log(topicTotal) to be a normalised log-beta,
                // topicTotal should accumulate count * phi (as in Blei's lda-c), not count.
                lDAModel.incTopicTotal(topic, count);
            }
            position++;
        }
    }

    /**
     * Computes the variational lower bound on the log-likelihood of {@code document} (Blei, Ng
     * &amp; Jordan 2003, eqn. 15). It uses the current {@code varGamma} and {@code phi} of
     * {@code lDAVariationlState}, evaluated against the model in
     * {@code lDAVariationlState.state}. As a side effect it refreshes
     * {@code lDAVariationlState.digamma[k]} to {@code digamma(varGamma[k])}.
     *
     * @param document            the document whose bound is computed
     * @param lDAVariationlState  variational parameters for {@code document}
     * @return the lower bound on the log-likelihood
     */
    public double computeLikelihood(Document document, LDAVariationlState lDAVariationlState) {
        final LDAModel current = lDAVariationlState.state;
        final double alpha = current.alpha;

        double gammaSum = 0.0d;
        for (int topic = 0; topic < this.ntopics; topic++) {
            gammaSum += lDAVariationlState.varGamma[topic];
            lDAVariationlState.digamma[topic] = Gamma.digamma(lDAVariationlState.varGamma[topic]);
        }
        // E[log theta_k] = digamma(gamma_k) - digamma(sum_j gamma_j)
        final double digammaOfSum = Gamma.digamma(gammaSum);

        double likelihood = Gamma.logGamma(alpha * this.ntopics)
                - this.ntopics * Gamma.logGamma(alpha)
                - Gamma.logGamma(gammaSum);
        for (int topic = 0; topic < this.ntopics; topic++) {
            final double eLogTheta = lDAVariationlState.digamma[topic] - digammaOfSum;
            likelihood += (alpha - 1.0d) * eLogTheta
                    + Gamma.logGamma(lDAVariationlState.varGamma[topic])
                    - (lDAVariationlState.varGamma[topic] - 1.0d) * eLogTheta;

            int position = 0;
            for (SparseIntArray.Entry entry : document.getVector().entries()) {
                final double phi = lDAVariationlState.phi[position][topic];
                // phi == 0 contributes 0 (0 * log 0 -> 0); skipping avoids NaN.
                if (phi > 0.0d) {
                    likelihood += entry.value * phi
                            * (eLogTheta + logBeta(current, topic, entry.index) - Math.log(phi));
                }
                position++;
            }
        }
        return likelihood;
    }

    /**
     * Returns log beta[topic][word] = log(topicWord) - log(topicTotal), or
     * {@link #LOG_ZERO_BETA} when the topic has no mass on the word.
     */
    private static double logBeta(LDAModel lDAModel, int topic, int word) {
        if (lDAModel.topicWord[topic][word] > 0.0d) {
            return Math.log(lDAModel.topicWord[topic][word]) - Math.log(lDAModel.topicTotal[topic]);
        }
        return LOG_ZERO_BETA;
    }
}
