/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

/**
 * Online bilinear-learning experiment over the {@code user_vsr_for_polls_SINA}
 * MATLAB data set.
 *
 * <p>The experiment first trains a {@link BilinearSparseOnlineLearner} on up to
 * {@link #INITIAL_TRAIN_NUMBER} items. It then repeats the following round until
 * the data generator is exhausted: draw up to {@link #TEST_BATCH_SIZE} unseen
 * items, evaluate the learner on them, write the learner to disk, log
 * diagnostics, and finally train the learner on those same items.
 *
 * <p>Casts that change the static type of an argument passed to an external API
 * (for example {@code (Object)} on logger calls) are kept on purpose so that the
 * same overloads are selected as before.
 */
public class BillAustrianExperimentsNormalised
extends BilinearExperiment {
    /** Upper bound on the number of items consumed by the initial training pass. */
    private static final int INITIAL_TRAIN_NUMBER = 48;

    /** Upper bound on the number of items drawn for each test-then-train round. */
    private static final int TEST_BATCH_SIZE = 5;

    /**
     * Command-line entry point: builds the experiment and runs it.
     *
     * @param args ignored
     * @throws IOException propagated from {@link #performExperiment()}
     */
    public static void main(String[] args) throws IOException {
        BillAustrianExperimentsNormalised exp = new BillAustrianExperimentsNormalised();
        exp.performExperiment();
    }

    /**
     * Configures the learner, runs the initial training pass and then the
     * test-then-train rounds described in the class documentation.
     *
     * @throws IOException if the experiment log, the data files or a saved
     *         learner cannot be read or written by the called helpers
     */
    @Override
    public void performExperiment() throws IOException {
        BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", 5.0);
        params.put("eta0w", 5.0);
        params.put("lambda_u", 5.0E-6);
        params.put("lambda_w", 5.0E-4);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("biaseta0", 0.1);
        params.put("winitstrat", new SparseZerosInitStrategy());
        params.put("uinitstrat", new SparseZerosInitStrategy());
        params.put("loss", new MatSquareLossFunction());

        // NOTE(review): the trailing 98 and false are passed through unchanged; their
        // meaning is defined by BillMatlabFileDataGenerator, which is not visible here.
        BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(
                new File(this.MATLAB_DATA("%s/user_vsr_for_polls_SINA.mat")),
                "user_vsr_for_polls_SINA",
                new File(this.MATLAB_DATA()),
                98,
                false);
        this.prepareExperimentLog(params);

        BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
        learner.reinitParams();
        // NOTE(review): presumably (-1, null) selects the whole data set rather than a
        // particular fold -- confirm against BillMatlabFileDataGenerator.setFold.
        bmfdg.setFold(-1, null);

        // Running count of items the learner has been trained on; it carries over from
        // the initial pass into the rounds below and is used in log lines and file names.
        int item = 0;

        this.logger.debug((Object)("... training initial " + INITIAL_TRAIN_NUMBER + " items"));
        Pair<Matrix> next;
        // The bound is checked first so no item is drawn (and lost) once the limit is hit.
        while (item < INITIAL_TRAIN_NUMBER && (next = bmfdg.generate()) != null) {
            this.trainOnItem(learner, next, item);
            ++item;
        }

        this.logger.debug((Object)("... testing " + TEST_BATCH_SIZE + ", training " + TEST_BATCH_SIZE + "..."));
        int fold = 0;
        while (true) {
            ArrayList<Pair<Matrix>> testPairs = new ArrayList<Pair<Matrix>>();
            Pair<Matrix> candidate;
            while (testPairs.size() < TEST_BATCH_SIZE && (candidate = bmfdg.generate()) != null) {
                testPairs.add(candidate);
            }
            if (testPairs.isEmpty()) {
                // The generator produced nothing for this round: the experiment is over.
                break;
            }

            this.evaluateAndSaveFold(learner, testPairs, fold, item);

            // Only after evaluation are the test items fed to the learner.
            for (Pair<Matrix> pair : testPairs) {
                this.logger.debug((Object)"...training with tests");
                this.trainOnItem(learner, pair, item);
                ++item;
            }
            ++fold;
        }
    }

    /**
     * Feeds a single item to the learner, logging before and after.
     *
     * @param learner the learner to update
     * @param pair the item; its first and second objects are passed to
     *        {@code learner.process} in that order
     * @param item index of the item, used only in the log messages
     */
    private void trainOnItem(BilinearSparseOnlineLearner learner, Pair<Matrix> pair, int item) {
        this.logger.debug((Object)("...trying item " + item));
        learner.process(pair.firstObject(), pair.secondObject());
        this.logger.debug((Object)("...done processing item " + item));
    }

    /**
     * Evaluates the learner on a batch of held-out items, writes the learner to
     * {@code learner_<item>} under the fold's root directory and logs sparsity,
     * bias (when bias mode is enabled) and loss.
     *
     * @param learner the learner to evaluate and persist
     * @param testPairs the non-empty batch to evaluate on
     * @param fold index of the current round, used to pick the output directory
     * @param item number of items processed so far, used in the output file name
     * @throws IOException if the learner cannot be written
     */
    private void evaluateAndSaveFold(BilinearSparseOnlineLearner learner, ArrayList<Pair<Matrix>> testPairs, int fold, int item) throws IOException {
        // Fetched before evaluation and saving, matching the original statement order.
        Matrix u = learner.getU();
        Matrix w = learner.getW();
        Matrix bias = MatrixFactory.getDenseDefault().copyMatrix(learner.getBias());

        RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        eval.setLearner(learner);
        double loss = ((BilinearEvaluator)eval).evaluate(testPairs);

        this.logger.debug((Object)String.format("Saving learner, Fold %d, Item %d", fold, item));
        File learnerOut = new File(this.FOLD_ROOT(fold), String.format("learner_%d", item));
        IOUtils.writeBinary(learnerOut, (WriteableBinary)learner);

        this.logger.debug((Object)("W row sparcity: " + CFMatrixUtils.rowSparsity(w)));
        this.logger.debug((Object)("U row sparcity: " + CFMatrixUtils.rowSparsity(u)));
        Boolean biasMode = (Boolean)learner.getParams().getTyped("bias");
        // Null-safe: a missing "bias" parameter is treated as "bias disabled".
        if (Boolean.TRUE.equals(biasMode)) {
            this.logger.debug((Object)("Bias: " + CFMatrixUtils.diag(bias)));
        }
        this.logger.debug((Object)String.format("... loss: %f", loss));
    }
}

