/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.df.split;

import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.Split;

/**
 * Information-gain split finder that scores each attribute against a bounded set of candidate
 * split points.
 *
 * <p>Numerical attributes are tested against at most {@link #MAX_NUMERIC_SPLITS} thresholds:
 * midpoints between successive sorted values for small inputs, approximate percentiles
 * otherwise. Categorical attributes are scored by the information gain of partitioning the data
 * on every distinct value of the attribute.
 */
public class OptIgSplit extends IgSplit {

    /** Upper bound on the number of candidate thresholds evaluated for a numerical attribute. */
    private static final int MAX_NUMERIC_SPLITS = 16;

    /**
     * Computes the best split for {@code attr}, dispatching on whether the dataset declares the
     * attribute numerical or categorical.
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        if (data.getDataset().isNumerical(attr)) {
            return numericalSplit(data, attr);
        }
        return categoricalSplit(data, attr);
    }

    /**
     * Computes the information gain obtained by partitioning {@code data} on every distinct value
     * of the categorical attribute {@code attr}.
     *
     * @return a {@link Split} carrying {@code attr} and the gain {@code H(Y) - H(Y|X)}
     */
    private static Split categoricalSplit(Data data, int attr) {
        // chooseCategoricalSplitPoints only reads the array, so no defensive copy is needed
        double[] categories = chooseCategoricalSplitPoints(data.values(attr));
        int numLabels = data.getDataset().nblabels();
        int[][] counts = new int[categories.length][numLabels];
        int[] countAll = new int[numLabels];
        computeFrequencies(data, attr, categories, counts, countAll);

        int dataSize = data.size();
        double hy = entropy(countAll, dataSize);
        double invDataSize = 1.0 / dataSize;

        // conditional entropy H(Y|X): each branch's entropy weighted by its share of the data
        double hyx = 0.0;
        for (int[] branchCounts : counts) {
            int branchSize = DataUtils.sum(branchCounts);
            hyx += branchSize * invDataSize * entropy(branchCounts, branchSize);
        }
        return new Split(attr, hy - hyx);
    }

    /**
     * Tallies label frequencies per split bucket.
     *
     * <p>Each instance is assigned to the first bucket {@code i} whose split point satisfies
     * {@code value <= splitPoints[i]}. Instances whose value exceeds every split point are not
     * added to {@code counts}. Every instance is added to {@code countAll}. Both arrays are
     * incremented in place, so they are expected to be zero-initialised by the caller.
     *
     * @param data        instances to tally
     * @param attr        attribute whose value selects the bucket
     * @param splitPoints bucket upper bounds; the callers in this class pass them in ascending order
     * @param counts      {@code counts[bucket][label]} accumulator, one row per split point
     * @param countAll    per-label accumulator over all instances
     */
    static void computeFrequencies(Data data, int attr, double[] splitPoints, int[][] counts, int[] countAll) {
        Dataset dataset = data.getDataset();
        for (int index = 0; index < data.size(); index++) {
            Instance instance = data.get(index);
            int label = (int) dataset.getLabel(instance);
            double value = instance.get(attr);

            // linear scan for the first split point that is >= value
            int split = 0;
            while (split < splitPoints.length && value > splitPoints[split]) {
                split++;
            }
            if (split < splitPoints.length) {
                counts[split][label]++;
            }
            countAll[label]++;
        }
    }

    /**
     * Finds the threshold on the numerical attribute {@code attr} that maximises information gain.
     *
     * <p>For candidate {@code splitPoints[k]}, the lower side holds the instances tallied into
     * buckets {@code 0..k} (values at or below the split point). The upper side holds everything
     * else.
     *
     * @return a {@link Split} carrying {@code attr}, the best gain and its threshold
     * @throws IllegalStateException if no candidate split point was produced, e.g. when the
     *         attribute has no values
     */
    static Split numericalSplit(Data data, int attr) {
        // sorting mutates the array, so work on a copy
        double[] values = data.values(attr).clone();
        Arrays.sort(values);
        double[] splitPoints = chooseNumericSplitPoints(values);

        int numLabels = data.getDataset().nblabels();
        int[][] counts = new int[splitPoints.length][numLabels];
        int[] countAll = new int[numLabels];
        int[] countLess = new int[numLabels];
        computeFrequencies(data, attr, splitPoints, counts, countAll);

        int dataSize = data.size();
        double hy = entropy(countAll, dataSize);
        double invDataSize = 1.0 / dataSize;

        int best = -1;
        double bestIg = -1.0;
        for (int index = 0; index < splitPoints.length; index++) {
            // move bucket 'index' from the upper side (countAll) to the lower side (countLess);
            // DataUtils.add/dec are presumably element-wise += / -= -- confirm against DataUtils
            DataUtils.add(countLess, counts[index]);
            DataUtils.dec(countAll, counts[index]);

            double ig = hy;
            int lowerSize = DataUtils.sum(countLess);
            ig -= lowerSize * invDataSize * entropy(countLess, lowerSize);
            int upperSize = DataUtils.sum(countAll);
            ig -= upperSize * invDataSize * entropy(countAll, upperSize);

            if (ig > bestIg) {
                bestIg = ig;
                best = index;
            }
        }
        if (best == -1) {
            throw new IllegalStateException("no best split found !");
        }
        return new Split(attr, bestIg, splitPoints[best]);
    }

    /**
     * Chooses candidate thresholds for a numerical attribute.
     *
     * <p>With at most one value, the input array itself is returned. With at most
     * {@code MAX_NUMERIC_SPLITS + 1} values, the midpoints between successive values are returned.
     * Otherwise {@code MAX_NUMERIC_SPLITS} evenly spaced percentiles of the data are returned.
     *
     * @param values attribute values, already sorted ascending by {@link #numericalSplit}
     */
    private static double[] chooseNumericSplitPoints(double[] values) {
        if (values.length <= 1) {
            return values;
        }
        if (values.length <= MAX_NUMERIC_SPLITS + 1) {
            double[] splitPoints = new double[values.length - 1];
            for (int i = 1; i < values.length; i++) {
                splitPoints[i - 1] = (values[i] + values[i - 1]) / 2.0;
            }
            return splitPoints;
        }
        Percentile distribution = new Percentile();
        distribution.setData(values);
        double[] percentiles = new double[MAX_NUMERIC_SPLITS];
        for (int i = 0; i < percentiles.length; i++) {
            // percentiles 1/(N+1), 2/(N+1), ..., N/(N+1) of the data, expressed in (0, 100)
            double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
            percentiles[i] = distribution.evaluate(p);
        }
        return percentiles;
    }

    /**
     * Returns the distinct values of a categorical attribute in ascending order.
     *
     * <p>The ordering is not semantically meaningful for categories. It only gives each distinct
     * value its own bucket in {@link #computeFrequencies}.
     */
    private static double[] chooseCategoricalSplitPoints(double[] values) {
        TreeSet<Double> uniqueOrderedCategories = new TreeSet<Double>();
        for (double v : values) {
            uniqueOrderedCategories.add(v);
        }
        double[] uniqueValues = new double[uniqueOrderedCategories.size()];
        Iterator<Double> it = uniqueOrderedCategories.iterator();
        for (int i = 0; i < uniqueValues.length; i++) {
            uniqueValues[i] = it.next();
        }
        return uniqueValues;
    }

    /**
     * Computes the Shannon entropy of a label distribution.
     *
     * <p>The natural-log entropy is divided by {@code LOG2}, which is inherited from
     * {@code IgSplit}. {@code LOG2} is presumably {@code Math.log(2)}, giving a result in bits;
     * confirm in {@code IgSplit}.
     *
     * @param counts   per-label counts; non-positive entries are skipped
     * @param dataSize total number of instances the counts describe
     * @return the entropy, or {@code 0.0} when {@code dataSize} is zero
     */
    private static double entropy(int[] counts, int dataSize) {
        if (dataSize == 0) {
            return 0.0;
        }
        double entropy = 0.0;
        for (int count : counts) {
            if (count > 0) {
                double p = (double) count / dataSize;
                entropy -= p * Math.log(p);
            }
        }
        return entropy / LOG2;
    }
}

