/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.df.split;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.Split;

/**
 * {@link IgSplit} for regression problems.
 *
 * <p>The "information gain" of a split is the reduction in the sum of squared deviations of the
 * label values from their mean. It is the sum over the whole data minus the sums within each
 * partition produced by the split.
 *
 * <p>Sums of squared deviations are accumulated incrementally with Welford's update
 * {@code S += (x - oldMean) * (x - newMean)}, and a {@link FullRunningAverage} tracks each mean.
 */
public class RegressionSplit extends IgSplit {

    /**
     * Computes the split of {@code data} on attribute {@code attr}.
     *
     * @param data the instances to split
     * @param attr index of the attribute to split on
     * @return for a categorical {@code attr}, a {@link Split} without a threshold; for a numerical
     *     {@code attr}, a {@link Split} carrying the chosen threshold
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        if (data.getDataset().isNumerical(attr)) {
            return numericalSplit(data, attr);
        }
        return categoricalSplit(data, attr);
    }

    /**
     * Gain of partitioning {@code data} by the values of categorical attribute {@code attr}.
     *
     * <p>The attribute value of each instance, cast to {@code int}, is used directly as an index
     * into arrays of length {@code nbValues(attr)}. It is therefore assumed to lie in
     * {@code [0, nbValues(attr))}.
     *
     * @param data the instances to split
     * @param attr index of a categorical attribute
     * @return a split whose gain is the total sum of squared label deviations minus the sum of the
     *     per-category sums of squared deviations
     */
    private static Split categoricalSplit(Data data, int attr) {
        int nbValues = data.getDataset().nbValues(attr);
        FullRunningAverage[] ra = new FullRunningAverage[nbValues];
        double[] sk = new double[nbValues];
        for (int i = 0; i < nbValues; i++) {
            ra[i] = new FullRunningAverage();
        }
        FullRunningAverage totalRa = new FullRunningAverage();
        double totalSk = 0.0;

        for (int i = 0; i < data.size(); i++) {
            Instance instance = data.get(i);
            int value = (int) instance.get(attr);
            double xk = data.getDataset().getLabel(instance);
            sk[value] = accumulate(ra[value], sk[value], xk);
            totalSk = accumulate(totalRa, totalSk, xk);
        }

        double ig = totalSk;
        for (double aSk : sk) {
            ig -= aSk;
        }
        return new Split(attr, ig);
    }

    /**
     * Best binary split of {@code data} on numerical attribute {@code attr}.
     *
     * <p>Instances are sorted by their value of {@code attr}. Candidate thresholds are the
     * midpoints between consecutive distinct values. Each candidate divides the instances into a
     * "low" side (instances before the threshold in sorted order) and a "high" side (the rest).
     * The candidate that minimizes {@code S_low / n_low + S_high / n_high} is kept, where S is a
     * side's sum of squared label deviations and n is its instance count. The returned gain is the
     * total sum of squared deviations minus {@code S_low + S_high} at that threshold.
     *
     * <p>If no threshold exists (fewer than two distinct values), the split value is {@code NaN}
     * and the gain is 0.
     *
     * @param data the instances to split
     * @param attr index of a numerical attribute
     * @return the split with its gain and threshold
     */
    private static Split numericalSplit(Data data, int attr) {
        // index 0: the "low" side, index 1: the "high" side
        FullRunningAverage[] ra = {new FullRunningAverage(), new FullRunningAverage()};
        double[] sk = new double[2];

        Instance[] instances = new Instance[data.size()];
        for (int i = 0; i < instances.length; i++) {
            instances[i] = data.get(i);
        }
        Arrays.sort(instances, new InstanceComparator(attr));

        // Start with every instance on the high side.
        for (Instance instance : instances) {
            sk[1] = accumulate(ra[1], sk[1], data.getDataset().getLabel(instance));
        }
        double totalSk = sk[1];

        double split = Double.NaN;
        double preSplit = Double.NaN;
        double bestVal = Double.MAX_VALUE;
        // If no threshold is ever evaluated, the "split" leaves the data unchanged, so its gain
        // must be 0. Starting from 0.0 here would report the maximal gain totalSk instead.
        double bestSk = totalSk;

        for (Instance instance : instances) {
            double value = instance.get(attr);
            double xk = data.getDataset().getLabel(instance);

            // Only score a threshold that lies between two distinct values. On the first instance
            // preSplit is NaN, so the comparison is false and the empty low side is never scored.
            if (value > preSplit) {
                double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
                if (curVal < bestVal) {
                    bestVal = curVal;
                    bestSk = sk[0] + sk[1];
                    split = (value + preSplit) / 2.0;
                }
            }

            // Move this instance from the high side to the low side.
            sk[0] = accumulate(ra[0], sk[0], xk);
            double mk = ra[1].getAverage();
            ra[1].removeDatum(xk);
            // reverse Welford update for the removed value
            sk[1] -= (xk - mk) * (xk - ra[1].getAverage());

            preSplit = value;
        }

        double ig = totalSk - bestSk;
        return new Split(attr, ig, split);
    }

    /**
     * Adds {@code x} to {@code mean} and returns the updated sum of squared deviations, using
     * Welford's update. When {@code mean} is still empty, the returned sum is 0.
     *
     * @param mean running mean of the values seen so far; {@code x} is added to it
     * @param sumSq sum of squared deviations of the values seen so far
     * @param x the new value
     * @return the sum of squared deviations including {@code x}
     */
    private static double accumulate(FullRunningAverage mean, double sumSq, double x) {
        if (mean.getCount() == 0) {
            mean.addDatum(x);
            return 0.0;
        }
        double oldMean = mean.getAverage();
        mean.addDatum(x);
        return sumSq + (x - oldMean) * (x - mean.getAverage());
    }

    /** Orders instances by ascending value of a single attribute, using {@link Double#compare}. */
    private static class InstanceComparator implements Comparator<Instance>, Serializable {
        private final int attr;

        InstanceComparator(int attr) {
            this.attr = attr;
        }

        @Override
        public int compare(Instance arg0, Instance arg1) {
            return Double.compare(arg0.get(this.attr), arg1.get(this.attr));
        }
    }
}

