package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/tree/Split.class */
/**
 * A node of a binary regression tree.
 *
 * <p>A node whose {@code featureID} is {@value #LEAF_FEATURE_ID} is a leaf and evaluates to its
 * stored output ({@link #getOutput()}). Any other node is an internal split. It routes a data point
 * to the left child when the point's value for {@code featureID} is {@code <= threshold}, and to
 * the right child otherwise.
 *
 * <p>During training a node may also carry the samples that reach it, either as
 * {@code sortedSampleIDs} or as {@code samples} plus a {@code hist}. It also carries label
 * statistics ({@code sumLabel}, {@code sqSumLabel}). {@link #clearSamples()} drops the sample
 * references.
 */
public class Split {
    /** Sentinel {@code featureID} marking a leaf node. */
    private static final int LEAF_FEATURE_ID = -1;

    private int featureID;
    private float threshold;
    /** Leaf output; initialised to sumLabel / sample count by the training constructors. */
    private float avgLabel;
    private double sumLabel;
    private double sqSumLabel;
    private Split left;
    private Split right;
    // NOTE(review): presumably the node's squared-error deviance; only stored/returned here.
    private float deviance;
    /** Sample index arrays; row 0 is what {@link #getSamples()} returns and what avgLabel is averaged over. */
    private int[][] sortedSampleIDs;
    public int[] samples;
    public FeatureHistogram hist;

    /** Creates an empty leaf with all statistics zeroed and no samples attached. */
    public Split() {
        this.featureID = LEAF_FEATURE_ID;
        this.threshold = 0.0f;
        this.avgLabel = 0.0f;
        this.sumLabel = 0.0d;
        this.sqSumLabel = 0.0d;
        this.left = null;
        this.right = null;
        this.deviance = 0.0f;
        this.sortedSampleIDs = null;
        this.samples = null;
        this.hist = null;
    }

    /**
     * Creates a node that splits on a feature.
     *
     * @param i feature id to test ({@value #LEAF_FEATURE_ID} makes this a leaf)
     * @param f threshold; values {@code <= f} go left
     * @param f2 deviance to record for this node
     */
    public Split(int i, float f, float f2) {
        this();
        set(i, f, f2);
    }

    /**
     * Creates a leaf carrying sample index arrays and label statistics.
     *
     * @param iArr sample index arrays; {@code iArr[0].length} is used as the sample count
     * @param f deviance of this node
     * @param d sum of labels of the samples
     * @param d2 sum of squared labels of the samples
     */
    public Split(int[][] iArr, float f, double d, double d2) {
        this();
        this.sortedSampleIDs = iArr;
        this.deviance = f;
        this.sumLabel = d;
        this.sqSumLabel = d2;
        this.avgLabel = (float) (d / iArr[0].length);
    }

    /**
     * Creates a leaf carrying a sample list, its feature histogram and label statistics.
     *
     * @param iArr ids of the samples reaching this node; its length is used as the sample count
     * @param featureHistogram histogram built over {@code iArr}
     * @param f deviance of this node
     * @param d sum of labels of the samples
     */
    public Split(int[] iArr, FeatureHistogram featureHistogram, float f, double d) {
        this();
        this.samples = iArr;
        this.hist = featureHistogram;
        this.deviance = f;
        this.sumLabel = d;
        this.avgLabel = (float) (d / iArr.length);
    }

    /**
     * Turns this node into a split on feature {@code i} (or a leaf if {@code i} is the sentinel).
     *
     * @param i feature id to test
     * @param f threshold; values {@code <= f} go left
     * @param f2 deviance to record
     */
    public void set(int i, float f, float f2) {
        this.featureID = i;
        this.threshold = f;
        this.deviance = f2;
    }

    /** Sets the child taken when the feature value is {@code <= threshold}. */
    public void setLeft(Split split) {
        this.left = split;
    }

    /** Sets the child taken when the feature value is {@code > threshold}. */
    public void setRight(Split split) {
        this.right = split;
    }

    /** Overrides the value this node returns when evaluated as a leaf. */
    public void setOutput(float f) {
        this.avgLabel = f;
    }

    /** Returns the left child, or {@code null} if none has been set. */
    public Split getLeft() {
        return this.left;
    }

    /** Returns the right child, or {@code null} if none has been set. */
    public Split getRight() {
        return this.right;
    }

    /** Returns the deviance recorded for this node. */
    public float getDeviance() {
        return this.deviance;
    }

    /** Returns the leaf output value of this node. */
    public float getOutput() {
        return this.avgLabel;
    }

    /**
     * Collects the leaves of the subtree rooted here, in left-to-right (depth-first) order.
     *
     * @return a new mutable list of leaf nodes
     */
    public List<Split> leaves() {
        List<Split> result = new ArrayList<>();
        leaves(result);
        return result;
    }

    /** Depth-first helper for {@link #leaves()}; assumes internal nodes have both children set. */
    private void leaves(List<Split> list) {
        if (isLeaf()) {
            list.add(this);
            return;
        }
        this.left.leaves(list);
        this.right.leaves(list);
    }

    /**
     * Evaluates the subtree rooted here on a data point.
     *
     * @param dataPoint point whose feature values drive the routing
     * @return the output of the leaf the point is routed to
     */
    public float eval(DataPoint dataPoint) {
        if (isLeaf()) {
            return this.avgLabel;
        }
        // Ties (value == threshold) go left.
        Split next = dataPoint.getFeatureValue(this.featureID) <= this.threshold ? this.left : this.right;
        return next.eval(dataPoint);
    }

    @Override
    public String toString() {
        return toString("");
    }

    /**
     * Renders the subtree as XML wrapped in a {@code <split>} element.
     *
     * @param str indentation prefix for the outer element; nested lines get extra tabs
     */
    public String toString(String str) {
        StringBuilder sb = new StringBuilder();
        sb.append(str).append("<split>\n");
        appendTo(sb, str + "\t");
        sb.append(str).append("</split>\n");
        return sb.toString();
    }

    /**
     * Renders the contents of this node (without the enclosing {@code <split>} element) as XML.
     *
     * @param str indentation prefix for every emitted line
     */
    public String getString(String str) {
        StringBuilder sb = new StringBuilder();
        appendTo(sb, str);
        return sb.toString();
    }

    /**
     * Appends this node's XML body to {@code sb}. A single builder is shared across the whole
     * recursion so the output is not re-copied at every tree level.
     */
    private void appendTo(StringBuilder sb, String indent) {
        if (isLeaf()) {
            sb.append(indent).append("<output> ").append(this.avgLabel).append(" </output>\n");
            return;
        }
        String childIndent = indent + "\t";
        sb.append(indent).append("<feature> ").append(this.featureID).append(" </feature>\n");
        sb.append(indent).append("<threshold> ").append(this.threshold).append(" </threshold>\n");
        sb.append(indent).append("<split pos=\"left\">\n");
        this.left.appendTo(sb, childIndent);
        sb.append(indent).append("</split>\n");
        sb.append(indent).append("<split pos=\"right\">\n");
        this.right.appendTo(sb, childIndent);
        sb.append(indent).append("</split>\n");
    }

    /** Returns true when this node has no split feature assigned. */
    private boolean isLeaf() {
        return this.featureID == LEAF_FEATURE_ID;
    }

    /**
     * Returns the samples reaching this node: row 0 of the sample index arrays when present,
     * otherwise the plain sample list (which may be {@code null}).
     */
    public int[] getSamples() {
        return this.sortedSampleIDs != null ? this.sortedSampleIDs[0] : this.samples;
    }

    /** Returns the sample index arrays (not copied), or {@code null}. */
    public int[][] getSampleSortedIndex() {
        return this.sortedSampleIDs;
    }

    /** Returns the stored sum of labels. */
    public double getSumLabel() {
        return this.sumLabel;
    }

    /** Returns the stored sum of squared labels. */
    public double getSqSumLabel() {
        return this.sqSumLabel;
    }

    /** Drops references to samples and histogram (label statistics and output are kept). */
    public void clearSamples() {
        this.sortedSampleIDs = null;
        this.samples = null;
        this.hist = null;
    }
}
