package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RANKER_TYPE;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.RankerFactory;
import ciir.umass.edu.learning.Sampler;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/tree/RFRanker.class */
/**
 * Random Forests ranker: trains {@link #nBag} tree ensembles, each on a separate sample of the
 * training rank lists. A data point is scored by averaging the outputs of all ensembles.
 *
 * <p>All hyper-parameters are public static fields and are therefore shared by every instance.
 * {@link #init()} copies them into the corresponding static fields of {@code LambdaMART} and
 * {@code FeatureHistogram}.
 */
public class RFRanker extends Ranker {
    /** Number of bags, i.e. independently trained ensembles. */
    public static int nBag = 300;
    /** Rate passed to {@code Sampler.doSampling} when drawing each bag's training sample. */
    public static float subSamplingRate = 1.0f;
    /** Value pushed into {@code FeatureHistogram.samplingRate} by {@link #init()}. */
    public static float featureSamplingRate = 0.3f;
    /** Ranker type created for each bag; the result is cast to {@code LambdaMART} in {@link #learn()}. */
    public static RANKER_TYPE rType = RANKER_TYPE.MART;
    /** Trees per bag (copied to {@code LambdaMART.nTrees}). */
    public static int nTrees = 1;
    /** Leaves per tree (copied to {@code LambdaMART.nTreeLeaves}). */
    public static int nTreeLeaves = 100;
    /** Learning rate (copied to {@code LambdaMART.learningRate}). */
    public static float learningRate = 0.1f;
    /** Number of threshold candidates (copied to {@code LambdaMART.nThreshold}). */
    public static int nThreshold = 256;
    /** Minimum leaf support (copied to {@code LambdaMART.minLeafSupport}). */
    public static int minLeafSupport = 1;
    /** One ensemble per bag; {@code null} until {@link #init()} or {@link #load(InputStream)}. */
    protected Ensemble[] ensembles;

    public RFRanker() {
        this.ensembles = null;
    }

    public RFRanker(List<RankList> list, int[] iArr) {
        super(list, iArr);
        this.ensembles = null;
    }

    /**
     * Allocates storage for {@link #nBag} ensembles and copies this class's static
     * hyper-parameters into the static configuration of {@code LambdaMART} and
     * {@code FeatureHistogram}.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void init() {
        PRINT("Initializing... ");
        this.ensembles = new Ensemble[nBag];
        LambdaMART.nTrees = nTrees;
        LambdaMART.nTreeLeaves = nTreeLeaves;
        LambdaMART.learningRate = learningRate;
        LambdaMART.nThreshold = nThreshold;
        LambdaMART.minLeafSupport = minLeafSupport;
        // NOTE(review): -1 presumably disables early stopping in LambdaMART; confirm there.
        LambdaMART.nRoundToStopEarly = -1;
        FeatureHistogram.samplingRate = featureSamplingRate;
        PRINTLN("[Done]");
    }

    /**
     * Trains {@link #nBag} bags and keeps each bag's ensemble.
     *
     * <p>Each bag trains a ranker of type {@link #rType} on a sample drawn by
     * {@code Sampler.doSampling(samples, subSamplingRate, true)}. After training, the score of the
     * averaged model is computed on the training data and, if present, on the validation data.
     * While a bag is training, {@code Ranker.verbose} is turned off to silence the inner learner.
     * The caller's previous value is restored afterwards, even if training fails.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void learn() {
        RankerFactory rankerFactory = new RankerFactory();
        // Guard against nBag having been changed after init() (or init() never being called).
        if (this.ensembles == null || this.ensembles.length != nBag) {
            this.ensembles = new Ensemble[nBag];
        }
        PRINTLN("------------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("------------------------------------");
        // NOTE(review): an "-OOB" column header is printed, but only the per-bag training score is reported.
        PRINTLN(new int[]{9, 9, 11}, new String[]{"bag", this.scorer.name() + "-B", this.scorer.name() + "-OOB"});
        PRINTLN("------------------------------------");
        for (int i = 0; i < nBag; i++) {
            System.gc();
            // The third argument to doSampling is presumably "with replacement"; confirm in Sampler.
            LambdaMART lambdaMART = (LambdaMART) rankerFactory.createRanker(rType, new Sampler().doSampling(this.samples, subSamplingRate, true), this.features);
            boolean previousVerbose = Ranker.verbose;
            Ranker.verbose = false;
            try {
                lambdaMART.init();
                lambdaMART.set(this.scorer);
                lambdaMART.learn();
            } finally {
                // Restore the caller's setting instead of forcing it to true.
                Ranker.verbose = previousVerbose;
            }
            PRINTLN(new int[]{9, 9}, new String[]{"b[" + (i + 1) + "]", String.valueOf(SimpleMath.round(lambdaMART.getScoreOnTrainingData(), 4))});
            this.ensembles[i] = lambdaMART.getEnsemble();
        }
        this.scoreOnTrainingData = this.scorer.score(rank(this.samples));
        PRINTLN("------------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(this.scorer.name() + " on training data: " + SimpleMath.round(this.scoreOnTrainingData, 4));
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("------------------------------------");
    }

    /**
     * Scores a data point as the arithmetic mean of all ensembles' outputs.
     *
     * @param dataPoint the point to score
     * @return the average ensemble score; NaN if there are no ensembles (0.0 / 0)
     */
    @Override // ciir.umass.edu.learning.Ranker
    public double eval(DataPoint dataPoint) {
        double sum = 0.0d;
        for (Ensemble ensemble : this.ensembles) {
            sum += ensemble.eval(dataPoint);
        }
        return sum / this.ensembles.length;
    }

    /** Returns a fresh, untrained {@code RFRanker}; trained ensembles are not copied. */
    @Override // ciir.umass.edu.learning.Ranker
    /* renamed from: clone */
    public Ranker m5clone() {
        return new RFRanker();
    }

    /**
     * Serializes every held ensemble, each followed by a newline.
     *
     * <p>The loop runs over the ensembles actually present, not the static {@link #nBag}. A loaded
     * model may contain any number of ensembles, and {@code nBag} may change after training.
     *
     * @return the concatenated ensembles, or an empty string if none are held
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String toString() {
        if (this.ensembles == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (Ensemble ensemble : this.ensembles) {
            sb.append(ensemble.toString()).append('\n');
        }
        return sb.toString();
    }

    /**
     * Builds the model text. It consists of a {@code ##}-prefixed header listing the
     * hyper-parameters, a blank line, and then {@link #toString()}.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String model() {
        StringBuilder sb = new StringBuilder();
        sb.append("## ").append(name()).append('\n');
        sb.append("## No. of bags = ").append(nBag).append('\n');
        sb.append("## Sub-sampling = ").append(subSamplingRate).append('\n');
        sb.append("## Feature-sampling = ").append(featureSamplingRate).append('\n');
        sb.append("## No. of trees = ").append(nTrees).append('\n');
        sb.append("## No. of leaves = ").append(nTreeLeaves).append('\n');
        sb.append("## No. of threshold candidates = ").append(nThreshold).append('\n');
        sb.append("## Learning rate = ").append(learningRate).append('\n');
        sb.append('\n');
        sb.append(toString());
        return sb.toString();
    }

    /**
     * Loads ensembles from model text read as US-ASCII.
     *
     * <p>Lines are trimmed. Blank lines and lines starting with {@code ##} are skipped. The remaining
     * lines are concatenated until a line containing {@code </ensemble>} is reached; the accumulated
     * text is then parsed as one {@code Ensemble}. The stream is closed on every path. On error, a
     * message is printed and the previously held ensembles are left unchanged.
     *
     * @param inputStream the model source; it is closed by this method
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void load(InputStream inputStream) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.US_ASCII))) {
            List<Ensemble> loaded = new ArrayList<>();
            StringBuilder pending = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty() || trimmed.startsWith("##")) {
                    continue;
                }
                pending.append(trimmed);
                if (trimmed.contains("</ensemble>")) {
                    loaded.add(new Ensemble(pending.toString()));
                    pending.setLength(0);
                }
            }
            this.ensembles = loaded.toArray(new Ensemble[0]);
        } catch (Exception e) {
            System.out.println("Error in RFRanker::load(): " + e.toString());
        }
    }

    /** Prints the static hyper-parameters, one per line. */
    @Override // ciir.umass.edu.learning.Ranker
    public void printParameters() {
        PRINTLN("No. of bags: " + nBag);
        PRINTLN("Sub-sampling: " + subSamplingRate);
        PRINTLN("Feature-sampling: " + featureSamplingRate);
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Learning rate: " + learningRate);
    }

    @Override // ciir.umass.edu.learning.Ranker
    public String name() {
        return "Random Forests";
    }

    /** Returns the internal ensemble array (not a copy); may be {@code null} before training or loading. */
    public Ensemble[] getEnsembles() {
        return this.ensembles;
    }
}
