/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.projection;

import Jama.Matrix;
import gnu.trove.list.array.TDoubleArrayList;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.matrix.algorithm.pca.ThinSvdPrincipalComponentAnalysis;

public class LargeMarginDimensionalityReduction {
    /** Number of output dimensions; {@link #initialise} builds a {@link #W} with this many rows. */
    protected int ndims;
    /** Step size applied to {@link #W} by {@link #step(double[], double[], boolean)}. */
    protected double wLearnRate = 0.25;
    /** Step size applied to {@link #b} by {@link #step(double[], double[], boolean)}. */
    protected double bLearnRate = 1.0;
    /** Projection matrix; row {@code r} maps an input (difference) vector to output dimension {@code r}. */
    protected Matrix W;
    /** Bias / distance threshold: a pair is "same" when its squared projected distance is {@code <= b}. */
    protected double b;

    /**
     * Creates an untrained model with the default learning rates
     * ({@code wLearnRate = 0.25}, {@code bLearnRate = 1.0}).
     *
     * @param ndims number of dimensions to project onto
     */
    public LargeMarginDimensionalityReduction(int ndims) {
        this.ndims = ndims;
    }

    /**
     * Creates an untrained model with explicit learning rates.
     *
     * @param ndims      number of dimensions to project onto
     * @param wLearnRate step size for updates to the projection matrix
     * @param bLearnRate step size for updates to the bias
     */
    public LargeMarginDimensionalityReduction(int ndims, double wLearnRate, double bLearnRate) {
        this.ndims = ndims;
        this.wLearnRate = wLearnRate;
        this.bLearnRate = bLearnRate;
    }

    /**
     * Initialises the projection from a PCA of all training vectors and then picks the bias that
     * maximises training accuracy (see {@link #recomputeBias}).
     * <p>
     * Both members of every pair are pooled into one PCA training set. {@code W} is set to
     * {@code diag(1 / sqrt(eigenvalue_i)) * basis^T} over the first {@code ndims} components.
     *
     * @param datai first member of each training pair
     * @param dataj second member of each training pair
     * @param same  {@code same[i]} is {@code true} iff {@code datai[i]} and {@code dataj[i]} are a positive pair
     * @throws IllegalArgumentException if the three arrays do not have the same length
     */
    public void initialise(double[][] datai, double[][] dataj, boolean[] same) {
        checkPairs(datai, dataj, same);

        // Pool both sides of every pair (interleaved) as the PCA training set.
        final double[][] data = new double[2 * datai.length][];
        for (int i = 0; i < datai.length; ++i) {
            data[2 * i] = datai[i];
            data[2 * i + 1] = dataj[i];
        }

        final ThinSvdPrincipalComponentAnalysis pca = new ThinSvdPrincipalComponentAnalysis(this.ndims);
        pca.learnBasis(data);

        // Scale each principal axis by 1/sqrt(eigenvalue): PCA whitening, assuming the eigenvalues
        // are the component variances.
        // NOTE(review): a zero eigenvalue yields an infinite scale factor; confirm inputs avoid that.
        final double[] evs = pca.getEigenValues();
        final double[] invStdDev = new double[this.ndims];
        for (int i = 0; i < this.ndims; ++i) {
            invStdDev[i] = 1.0 / Math.sqrt(evs[i]);
        }
        this.W = MatrixUtils.diag(invStdDev).times(pca.getBasis().transpose());
        this.recomputeBias(datai, dataj, same);
    }

    /**
     * Re-selects the bias {@link #b} for the current projection {@link #W}.
     * <p>
     * The candidate thresholds are the squared projected distances of the given pairs. The candidate
     * that maximises the accuracy of {@link #classify} on those pairs is chosen.
     *
     * @param datai first member of each pair
     * @param dataj second member of each pair
     * @param same  {@code same[i]} is {@code true} iff pair {@code i} is a positive pair
     * @throws IllegalArgumentException if the three arrays do not have the same length
     */
    public void recomputeBias(double[][] datai, double[][] dataj, boolean[] same) {
        checkPairs(datai, dataj, same);
        final TDoubleArrayList posDistances = new TDoubleArrayList();
        final TDoubleArrayList negDistances = new TDoubleArrayList();
        for (int i = 0; i < datai.length; ++i) {
            final double dist = this.sumsq(this.W.times(this.diff(datai[i], dataj[i])));
            if (same[i]) {
                posDistances.add(dist);
            } else {
                negDistances.add(dist);
            }
        }
        this.b = this.computeOptimal(posDistances, negDistances);
    }

    /** Throws if the three parallel pair arrays do not all have the same length. */
    private static void checkPairs(double[][] datai, double[][] dataj, boolean[] same) {
        if (dataj.length != datai.length || same.length != datai.length) {
            throw new IllegalArgumentException("Pair arrays must have equal lengths: datai=" + datai.length
                    + ", dataj=" + dataj.length + ", same=" + same.length);
        }
    }

    /**
     * Returns the observed distance that, used as the threshold, gives the highest accuracy.
     * <p>
     * Ties keep the earliest candidate, with positive distances tried before negative ones.
     * Returns {@code -Double.MAX_VALUE} when no candidate reaches an accuracy above zero, for example
     * when there are no pairs at all.
     */
    private double computeOptimal(TDoubleArrayList posDistances, TDoubleArrayList negDistances) {
        double bestAcc = 0.0;
        double bestThresh = -Double.MAX_VALUE;
        for (final TDoubleArrayList candidates : new TDoubleArrayList[] { posDistances, negDistances }) {
            for (int i = 0; i < candidates.size(); ++i) {
                final double thresh = candidates.get(i);
                final double acc = this.computeAccuracy(posDistances, negDistances, thresh);
                if (acc > bestAcc) {
                    bestAcc = acc;
                    bestThresh = thresh;
                }
            }
        }
        return bestThresh;
    }

    /**
     * Returns the fraction of pairs classified correctly when {@code thresh} is used as the bias.
     * <p>
     * The decision rule mirrors {@link #classify}: {@code b - dist >= 0}, i.e. {@code dist <= thresh}
     * means "same". Using the same rule matters because every candidate threshold is itself one of
     * the observed distances, so the tie case is always exercised.
     * <p>
     * Callers must ensure at least one list is non-empty; otherwise the result is NaN.
     */
    private double computeAccuracy(TDoubleArrayList posDistances, TDoubleArrayList negDistances, double thresh) {
        int correct = 0;
        for (int i = 0; i < posDistances.size(); ++i) {
            if (posDistances.get(i) <= thresh) {
                ++correct;
            }
        }
        for (int i = 0; i < negDistances.size(); ++i) {
            if (negDistances.get(i) > thresh) {
                ++correct;
            }
        }
        return (double) correct / (double) (posDistances.size() + negDistances.size());
    }

    /**
     * Returns {@code phii - phij} as a column vector.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private Matrix diff(double[] phii, double[] phij) {
        if (phii.length != phij.length) {
            throw new IllegalArgumentException(
                    "Vectors must have equal length: " + phii.length + " != " + phij.length);
        }
        final Matrix diff = new Matrix(phii.length, 1);
        final double[][] diffv = diff.getArray();
        for (int i = 0; i < phii.length; ++i) {
            diffv[i][0] = phii[i] - phij[i];
        }
        return diff;
    }

    /** Returns the sum of squares of the first column of {@code diffProj}, i.e. its squared L2 norm. */
    private double sumsq(Matrix diffProj) {
        final double[][] v = diffProj.getArray();
        double sumsq = 0.0;
        for (final double[] row : v) {
            sumsq += row[0] * row[0];
        }
        return sumsq;
    }

    /**
     * Performs one stochastic update on a single training pair.
     * <p>
     * If the pair already satisfies the margin, {@code y * (b - ||W d||^2) > 1} with {@code y = +1}
     * for same pairs and {@code -1} otherwise, nothing changes. Otherwise the model takes a gradient
     * step on the hinge loss {@code 1 - y * (b - ||W d||^2)}:
     * <ul>
     * <li>{@code W -= wLearnRate * y * (W d) d^T} (the derivative's factor of 2 is absorbed into the
     * rate);</li>
     * <li>{@code b += bLearnRate * y}.</li>
     * </ul>
     *
     * @param phii first vector of the pair
     * @param phij second vector of the pair
     * @param same whether the pair is a positive ("same") pair
     * @return {@code true} if the model was updated, {@code false} if the margin was already satisfied
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public boolean step(double[] phii, double[] phij, boolean same) {
        final int yij = same ? 1 : -1;
        final Matrix diff = this.diff(phii, phij);
        final Matrix diffProj = this.W.times(diff);
        final double sumsq = this.sumsq(diffProj);

        if (yij * (this.b - sumsq) > 1.0) {
            return false;
        }
        this.fastUpdate(diffProj, this.wLearnRate * yij, diff);
        this.b += yij * this.bLearnRate;
        return true;
    }

    /**
     * Applies {@code W -= weight * diffProj * diff^T} in place, without allocating the outer product.
     * <p>
     * Every column of {@code W} is updated. {@code W} has one column per input dimension, which is
     * generally not the same as its number of rows.
     */
    private void fastUpdate(Matrix diffProj, double weight, Matrix diff) {
        final double[][] dp = diffProj.getArray();
        final double[][] d = diff.getArray();
        final double[][] wData = this.W.getArray();
        for (int r = 0; r < wData.length; ++r) {
            final double[] row = wData[r];
            final double rowScale = weight * dp[r][0];
            for (int c = 0; c < row.length; ++c) {
                row[c] -= rowScale * d[c][0];
            }
        }
    }

    /** @return the projection matrix itself (not a copy); changes to it affect this model */
    public Matrix getTransform() {
        return this.W;
    }

    /** @return the current bias (distance threshold) */
    public double getBias() {
        return this.b;
    }

    /**
     * Returns {@code b - ||W (phii - phij)||^2}. Non-negative values mean the pair is judged "same".
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public double score(double[] phii, double[] phij) {
        return this.b - this.sumsq(this.W.times(this.diff(phii, phij)));
    }

    /**
     * Returns {@code true} iff {@link #score} is non-negative, i.e. the pair is judged "same".
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public boolean classify(double[] phii, double[] phij) {
        return this.score(phii, phij) >= 0.0;
    }

    /** Projects a single vector with {@code W}, returning {@code W * in} as a flat array. */
    public double[] project(double[] in) {
        return this.W.times(new Matrix(new double[][] { in }).transpose()).getColumnPackedCopy();
    }

    /** Sets the bias (distance threshold). */
    public void setBias(double d) {
        this.b = d;
    }

    /** Sets the projection matrix; the matrix is stored by reference and updated in place by {@link #step}. */
    public void setTransform(Matrix proj) {
        this.W = proj;
    }
}

