package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.DenseVectorFactoryMTJ;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/experiments/sinabill/AustrianWordExperiments.class */
/**
 * Bilinear experiment that trains (or loads a previously saved)
 * {@link BilinearSparseOnlineLearner} on fold 0 of the data set returned by
 * {@code MATLAB_DATA()} and then prints, for every column of the learnt
 * {@code W} matrix, the highest-ranked vocabulary entries of that column.
 */
public class AustrianWordExperiments extends BilinearExperiment {
    /** Maximum number of vocabulary entries printed per column of {@code W}. */
    private static final int TOP_WORD_COUNT = 20;

    /** Location the fold-0 learner is read from if present, otherwise written to after training. */
    private static final String FOLD_LEARNER_PATH = "/Users/ss/Dropbox/TrendMiner/deliverables/year2-18month/Austrian Data/streamingExperiments/experiment_1365684128359/fold_0_learner";

    /**
     * Command-line entry point; runs the experiment once.
     *
     * @param strArr ignored
     * @throws IOException if reading or writing the learner file fails
     */
    public static void main(String[] strArr) throws IOException {
        new AustrianWordExperiments().performExperiment();
    }

    /**
     * Runs the experiment: configures the learner, trains it on fold 0 unless a
     * saved learner already exists at {@link #FOLD_LEARNER_PATH}, and prints the
     * top-ranked words for each column of {@code W}.
     *
     * @throws IOException if reading or writing the learner file fails
     */
    @Override // org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment
    public void performExperiment() throws IOException {
        BilinearLearnerParameters params = buildParameters();
        BillMatlabFileDataGenerator generator = new BillMatlabFileDataGenerator(new File(MATLAB_DATA()), 98, true);
        prepareExperimentLog(params);
        File learnerFile = new File(FOLD_LEARNER_PATH);
        this.logger.debug("Fold: 0");
        BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
        learner.reinitParams();
        // NOTE(review): the TEST fold is selected and then immediately replaced by
        // TRAINING; kept as in the original because setFold's side effects are not
        // visible from this file.
        generator.setFold(0, BillMatlabFileDataGenerator.Mode.TEST);
        this.logger.debug("...training");
        generator.setFold(0, BillMatlabFileDataGenerator.Mode.TRAINING);
        learner = trainOrLoad(learner, generator, learnerFile);
        printTopWords(learner.getW(), generator);
    }

    /** Builds the fixed learner configuration used by this experiment. */
    private static BilinearLearnerParameters buildParameters() {
        BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put(BilinearLearnerParameters.ETA0_U, 2.0E-4d);
        params.put(BilinearLearnerParameters.ETA0_W, 0.002d);
        params.put(BilinearLearnerParameters.LAMBDA, 0.001d);
        params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.05d);
        params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 5);
        params.put(BilinearLearnerParameters.BIAS, true);
        params.put(BilinearLearnerParameters.ETA0_BIAS, 0.05d);
        params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1d));
        params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
        return params;
    }

    /**
     * Returns the learner stored in {@code learnerFile} if that file exists;
     * otherwise feeds every pair produced by {@code generator} to {@code learner},
     * writes the trained learner to {@code learnerFile} and returns it.
     */
    private BilinearSparseOnlineLearner trainOrLoad(BilinearSparseOnlineLearner learner,
            BillMatlabFileDataGenerator generator, File learnerFile) throws IOException {
        if (learnerFile.exists()) {
            return (BilinearSparseOnlineLearner) IOUtils.read(learnerFile, BilinearSparseOnlineLearner.class);
        }
        int item = 0;
        // The generator signals exhaustion by returning null.
        for (Pair<Matrix> pair = generator.mo9generate(); pair != null; pair = generator.mo9generate()) {
            this.logger.debug("...trying item " + item);
            item++;
            learner.process(pair.firstObject(), pair.secondObject());
        }
        System.out.println("Writing W and U to: " + learnerFile);
        IOUtils.writeBinary(learnerFile, learner);
        return learner;
    }

    /**
     * Prints, for each column of {@code w}, up to {@link #TOP_WORD_COUNT}
     * vocabulary entries taken from the end of the index order produced by
     * {@code ArrayIndexComparator}, last entry first.
     */
    private static void printTopWords(Matrix w, BillMatlabFileDataGenerator generator) {
        int numColumns = w.getNumColumns();
        for (int column = 0; column < numColumns; column++) {
            System.out.println("Top 20 words for: " + generator.getTasks()[column]);
            double[] weights = new DenseVectorFactoryMTJ().copyVector(w.getColumn(column)).getArray();
            Integer[] order = ArrayIndexComparator.integerRange(weights);
            Arrays.sort(order, new ArrayIndexComparator(weights));
            // Clamp at zero: a column with fewer than TOP_WORD_COUNT entries would
            // otherwise drive the index negative and throw.
            int lowest = Math.max(0, weights.length - TOP_WORD_COUNT);
            for (int rank = weights.length - 1; rank >= lowest; rank--) {
                Integer wordIndex = order[rank];
                System.out.printf("%s: %1.5f\n", generator.getVocabulary().get(wordIndex), weights[wordIndex]);
            }
        }
    }
}
