/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.pgm.vb.lda.mle;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math.special.Gamma;
import org.openimaj.math.util.MathUtils;
import org.openimaj.pgm.util.Corpus;
import org.openimaj.pgm.util.Document;
import org.openimaj.pgm.vb.lda.mle.LDABetaInitStrategy;
import org.openimaj.pgm.vb.lda.mle.LDAModel;
import org.openimaj.pgm.vb.lda.mle.LDAVariationlState;
import org.openimaj.util.array.SparseIntArray;

/**
 * Learns a Latent Dirichlet Allocation topic model from a {@link Corpus} using variational EM.
 * Each EM iteration runs a per-document variational E step (updating {@code phi} and
 * {@code varGamma}), then an M step that accumulates phi-weighted topic/word counts into a
 * fresh {@link LDAModel}.
 */
public class LDALearner {
    /**
     * Value used for log(beta) when a topic has no mass on a word. It keeps phi finite instead of
     * collapsing to log(0) = -infinity.
     */
    private static final double LOG_ZERO_BETA = -100.0;

    private final int ntopics;
    private final Map<LDAConfig, Object> config = new HashMap<LDAConfig, Object>();

    /**
     * @param ntopics number of latent topics to learn
     * @throws IllegalArgumentException if {@code ntopics <= 0}
     */
    public LDALearner(int ntopics) {
        if (ntopics <= 0) {
            throw new IllegalArgumentException("ntopics must be positive: " + ntopics);
        }
        this.ntopics = ntopics;
    }

    /**
     * Returns the value configured for {@code key}, or {@link LDAConfig#defaultValue()} when no
     * (or a {@code null}) value was set. The caller picks {@code T}; a mismatch surfaces as a
     * {@link ClassCastException} where the result is used.
     *
     * @param key the configuration entry to read
     * @return the configured or default value
     */
    @SuppressWarnings("unchecked")
    public <T> T getConfig(LDAConfig key) {
        final Object val = this.config.get(key);
        return (T) (val == null ? key.defaultValue() : val);
    }

    /**
     * Overrides a configuration entry. Setting {@code null} makes {@link #getConfig} fall back to
     * the key's default again.
     *
     * @param key the configuration entry to set
     * @param value the new value; its runtime type must match what the learner reads for that key
     */
    public void setConfig(LDAConfig key, Object value) {
        this.config.put(key, value);
    }

    /**
     * Fits a model to {@code corpus} with variational EM.
     * NOTE(review): the fitted {@link LDAModel} lives only inside {@link #performEM} and is not
     * exposed. TODO confirm whether it should be returned or stored.
     *
     * @param corpus documents to learn from
     */
    public void estimate(Corpus corpus) {
        this.performEM(corpus);
    }

    /** Runs EM iterations until {@link #modelConverged} reports convergence. */
    private void performEM(Corpus corpus) {
        final double initialAlpha = this.<Double>getConfig(LDAConfig.ALPHA);
        final LDABetaInitStrategy initStrat = this.<LDABetaInitStrategy>getConfig(LDAConfig.INIT_STRATEGY);

        final LDAModel state = new LDAModel(this.ntopics);
        state.prepare(corpus);
        state.setAlpha(initialAlpha);
        initStrat.initModel(state, corpus);

        final LDAVariationlState vstate = new LDAVariationlState(state);
        while (!this.modelConverged(vstate.state)) {
            final LDAModel current = vstate.state;
            final LDAModel nextState = current.newInstance();
            nextState.setAlpha(initialAlpha);

            double corpusLikelihood = 0.0;
            for (Document doc : corpus.getDocuments()) {
                // performE prepares vstate for this document itself.
                this.performE(doc, vstate);
                this.performM(doc, vstate, nextState);
                corpusLikelihood += vstate.likelihood;
            }
            // Set the bookkeeping explicitly rather than relying on what newInstance() copies.
            nextState.likelihood = corpusLikelihood;
            nextState.oldLikelihood = current.likelihood;
            nextState.iteration = current.iteration + 1;
            vstate.state = nextState;
        }
    }

    /**
     * Variational E step for one document. It alternates phi and varGamma updates against the
     * current model ({@code vstate.state}) until {@link #variationalStateConverged} is satisfied.
     * Row {@code i} of {@code vstate.phi} corresponds to the i-th entry of
     * {@code doc.getVector().entries()}.
     * NOTE(review): assumes {@code vstate.prepare(doc)} resets the iteration counter and
     * likelihoods for the new document. TODO confirm.
     *
     * @return {@code vstate}, updated in place
     */
    private LDAVariationlState performE(Document doc, LDAVariationlState vstate) {
        vstate.prepare(doc);
        final LDAModel model = vstate.state;
        while (!this.variationalStateConverged(vstate)) {
            int docWordIndex = 0;
            for (SparseIntArray.Entry wordCount : doc.getVector().entries()) {
                final int word = wordCount.index;
                final int count = wordCount.value;
                final double[] wordPhi = vstate.phi[docWordIndex];

                // Unnormalised log phi per topic; the log-normaliser is accumulated alongside.
                double logPhiSum = 0.0;
                for (int k = 0; k < this.ntopics; ++k) {
                    vstate.oldphi[k] = wordPhi[k];
                    wordPhi[k] = logBeta(model, k, word) + Gamma.digamma(vstate.varGamma[k]);
                    logPhiSum = k == 0 ? wordPhi[k] : MathUtils.logSum(logPhiSum, wordPhi[k]);
                }
                // Normalise phi and shift gamma by this word's change in expected topic count.
                for (int k = 0; k < this.ntopics; ++k) {
                    wordPhi[k] = Math.exp(wordPhi[k] - logPhiSum);
                    vstate.varGamma[k] += count * (wordPhi[k] - vstate.oldphi[k]);
                }
                ++docWordIndex;
            }
            vstate.oldLikelihood = vstate.likelihood;
            vstate.likelihood = this.computeLikelihood(doc, vstate);
            ++vstate.iteration;
        }
        return vstate;
    }

    /** @return true when the outer EM loop should stop */
    private boolean modelConverged(LDAModel model) {
        return hasConverged(model.likelihood, model.oldLikelihood, model.iteration,
                this.<Double>getConfig(LDAConfig.EM_CONVERGED),
                this.<Integer>getConfig(LDAConfig.MAX_ITERATIONS));
    }

    /** @return true when the per-document variational loop should stop */
    private boolean variationalStateConverged(LDAVariationlState vstate) {
        return hasConverged(vstate.likelihood, vstate.oldLikelihood, vstate.iteration,
                this.<Double>getConfig(LDAConfig.VAR_EM_CONVERGED),
                this.<Integer>getConfig(LDAConfig.VAR_MAX_ITERATIONS));
    }

    /**
     * Shared stopping rule. It stops once {@code iteration} exceeds {@code maxIterations}. It
     * always continues while {@code iteration <= 2}, since the likelihoods are not yet
     * meaningful. Otherwise it stops when the relative change in likelihood falls below
     * {@code threshold}. A zero {@code oldLikelihood} gives an infinite or NaN ratio, which never
     * counts as converged, so the iteration cap still ends the loop.
     */
    private static boolean hasConverged(double likelihood, double oldLikelihood, int iteration,
            double threshold, int maxIterations) {
        if (iteration > maxIterations) {
            return true;
        }
        if (iteration <= 2) {
            return false;
        }
        final double relativeChange = Math.abs((likelihood - oldLikelihood) / oldLikelihood);
        return relativeChange < threshold;
    }

    /**
     * M step for one document: adds the phi-weighted word counts to {@code nextState}. The
     * per-topic totals receive the same weights, so they stay the normalisers of the topic/word
     * weights.
     */
    private void performM(Document d, LDAVariationlState vstate, LDAModel nextState) {
        // Iterate in the same order as performE so phi rows line up with document entries.
        int docWordIndex = 0;
        for (SparseIntArray.Entry entry : d.getVector().entries()) {
            final int word = entry.index;
            final int count = entry.value;
            for (int k = 0; k < this.ntopics; ++k) {
                final double expectedCount = count * vstate.phi[docWordIndex][k];
                nextState.incTopicWord(k, word, expectedCount);
                nextState.incTopicTotal(k, expectedCount);
            }
            ++docWordIndex;
        }
    }

    /**
     * Computes the variational lower bound on the log-likelihood of {@code doc} under the current
     * model, using the document's {@code varGamma} and {@code phi}. As a side effect it refreshes
     * {@code vstate.digamma} with digamma(varGamma[k]).
     *
     * @return the document's likelihood bound
     */
    public double computeLikelihood(Document doc, LDAVariationlState vstate) {
        final double alpha = vstate.state.alpha;

        double sumVarGamma = 0.0;
        for (int k = 0; k < this.ntopics; ++k) {
            sumVarGamma += vstate.varGamma[k];
            vstate.digamma[k] = Gamma.digamma(vstate.varGamma[k]);
        }
        final double digammaOfSum = Gamma.digamma(sumVarGamma);

        double likelihood = Gamma.logGamma(alpha * this.ntopics)
                - this.ntopics * Gamma.logGamma(alpha)
                - Gamma.logGamma(sumVarGamma);

        for (int k = 0; k < this.ntopics; ++k) {
            // E_q[log theta_k] under the document's Dirichlet(varGamma).
            final double expectedLogTheta = vstate.digamma[k] - digammaOfSum;
            likelihood += (alpha - 1.0) * expectedLogTheta
                    + Gamma.logGamma(vstate.varGamma[k])
                    - (vstate.varGamma[k] - 1.0) * expectedLogTheta;

            int docWordIndex = 0;
            for (SparseIntArray.Entry wordCount : doc.getVector().entries()) {
                final double phi = vstate.phi[docWordIndex][k];
                // Skip phi == 0: its contribution is 0, but 0 * log(0) would evaluate to NaN.
                if (phi > 0.0) {
                    likelihood += wordCount.value * phi
                            * (expectedLogTheta - Math.log(phi) + logBeta(vstate.state, k, wordCount.index));
                }
                ++docWordIndex;
            }
        }
        return likelihood;
    }

    /**
     * Log of the normalised topic/word weight, {@code log(topicWord[topic][word] / topicTotal[topic])}.
     * Returns {@link #LOG_ZERO_BETA} when the topic has no weight on the word.
     */
    private static double logBeta(LDAModel model, int topic, int word) {
        final double weight = model.topicWord[topic][word];
        if (weight > 0.0) {
            return Math.log(weight) - Math.log(model.topicTotal[topic]);
        }
        return LOG_ZERO_BETA;
    }

    /** Configuration keys for {@link LDALearner}, each carrying its default value. */
    enum LDAConfig {
        /** Maximum number of outer EM iterations. */
        MAX_ITERATIONS {
            @Override
            public Integer defaultValue() {
                return 10;
            }
        },
        /** Dirichlet hyperparameter alpha; it is held fixed during learning. */
        ALPHA {
            @Override
            public Double defaultValue() {
                return 0.3;
            }
        },
        /** Maximum number of variational iterations per document. */
        VAR_MAX_ITERATIONS {
            @Override
            public Integer defaultValue() {
                return 10;
            }
        },
        /** Strategy used to initialise the topic/word weights before EM. */
        INIT_STRATEGY {
            @Override
            public LDABetaInitStrategy defaultValue() {
                return new LDABetaInitStrategy.RandomBetaInit();
            }
        },
        /** Relative-likelihood-change threshold for stopping the EM loop. */
        EM_CONVERGED {
            @Override
            public Double defaultValue() {
                return 1.0E-5;
            }
        },
        /** Relative-likelihood-change threshold for stopping the per-document variational loop. */
        VAR_EM_CONVERGED {
            @Override
            public Double defaultValue() {
                return 1.0E-5;
            }
        };

        /** @return the value {@link LDALearner#getConfig} returns when none has been set */
        public abstract Object defaultValue();
    }
}

