package ciir.umass.edu.learning.boosting;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/boosting/AdaRank.class */
public class AdaRank extends Ranker {
    /** Upper bound (inclusive) on the round counter used by {@code learn(int, boolean)}. */
    public static int nIteration = 500;
    /**
     * Slack added to a round's training score before it is compared with the previous
     * round's score; training stops when {@code score + tolerance - previousScore <= 0}.
     */
    public static double tolerance = 0.002d;
    /**
     * When true, {@code learn()} first trains in "enqueue" mode (a feature selected twice in a
     * row is parked in {@link #featureQueue}) and then continues training once per parked
     * feature as each is released; when false a single plain training pass is run.
     */
    public static boolean trainWithEnqueue = true;
    /**
     * In non-enqueue mode: once the same feature has been re-selected this many times in a row
     * without changing the training score, it is put into {@link #usedFeatures}.
     */
    public static int maxSelCount = 5;
    /** Feature ids currently excluded from weak-ranker selection (only the keys are used). */
    protected Hashtable<Integer, Integer> usedFeatures;
    /** Current weight of each training sample, indexed like {@code samples}. */
    protected double[] sweight;
    /** Weak rankers of the ensemble, in the order they were added. */
    protected List<WeakRanker> rankers;
    /** Weight of each weak ranker; parallel to {@link #rankers}. */
    protected List<Double> rweight;
    /** Snapshot of {@link #rankers} taken when the validation score last improved. */
    protected List<WeakRanker> bestModelRankers;
    /** Snapshot of {@link #rweight} taken together with {@link #bestModelRankers}. */
    protected List<Double> bestModelWeights;
    /** Feature id of the most recently selected weak ranker; -1 when none has been selected. */
    int lastFeature;
    /** Consecutive same-feature, unchanged-score selections counted in non-enqueue mode. */
    int lastFeatureConsecutiveCount;
    /**
     * Set by {@code learn(int, boolean)} in non-enqueue mode to whether the last round changed
     * the training score. It is written but never read within this class.
     */
    boolean performanceChanged;
    /** Feature ids parked in enqueue mode; parked features are skipped during selection. */
    List<Integer> featureQueue;
    /** Copy of {@link #sweight} saved before a round in enqueue mode, used to roll back. */
    protected double[] backupSampleWeight;
    /** Value of {@link #lastTrainedScore} saved together with {@link #backupSampleWeight}. */
    protected double backupTrainScore;
    /** Training score of the last accepted round; -1 before any round has been accepted. */
    protected double lastTrainedScore;

    /**
     * Creates an untrained AdaRank with no training data attached. All model and
     * bookkeeping state starts empty; {@link #init()} allocates the working structures.
     */
    public AdaRank() {
        // Feature-selection bookkeeping.
        this.usedFeatures = new Hashtable<>();
        this.featureQueue = null;
        this.lastFeature = -1;
        this.lastFeatureConsecutiveCount = 0;
        this.performanceChanged = false;

        // Current ensemble and the best one seen on validation data.
        this.rankers = null;
        this.rweight = null;
        this.bestModelRankers = null;
        this.bestModelWeights = null;

        // Sample weights, their rollback copy, and the associated training scores.
        this.sweight = null;
        this.backupSampleWeight = null;
        this.backupTrainScore = 0.0d;
        this.lastTrainedScore = -1.0d;
    }

    /**
     * Creates an untrained AdaRank and forwards both arguments to the {@code Ranker}
     * constructor. All model and bookkeeping state starts empty; {@link #init()}
     * allocates the working structures.
     *
     * @param list passed unchanged to the superclass constructor
     * @param iArr passed unchanged to the superclass constructor
     */
    public AdaRank(List<RankList> list, int[] iArr) {
        super(list, iArr);

        // Feature-selection bookkeeping.
        this.usedFeatures = new Hashtable<>();
        this.featureQueue = null;
        this.lastFeature = -1;
        this.lastFeatureConsecutiveCount = 0;
        this.performanceChanged = false;

        // Current ensemble and the best one seen on validation data.
        this.rankers = null;
        this.rweight = null;
        this.bestModelRankers = null;
        this.bestModelWeights = null;

        // Sample weights, their rollback copy, and the associated training scores.
        this.sweight = null;
        this.backupSampleWeight = null;
        this.backupTrainScore = 0.0d;
        this.lastTrainedScore = -1.0d;
    }

    /**
     * Overwrites the "best model" snapshot with the current ensemble: the snapshot lists
     * are emptied and refilled with the contents of {@code rankers} and {@code rweight}.
     */
    private void updateBestModelOnValidation() {
        final List<WeakRanker> rankerSnapshot = this.bestModelRankers;
        rankerSnapshot.clear();
        rankerSnapshot.addAll(this.rankers);

        final List<Double> weightSnapshot = this.bestModelWeights;
        weightSnapshot.clear();
        weightSnapshot.addAll(this.rweight);
    }

    /**
     * Picks the single-feature weak ranker with the highest sample-weighted score.
     * Features present in {@code featureQueue} or {@code usedFeatures} are skipped.
     *
     * @return the best candidate, or {@code null} when no candidate scored above -1
     *         (in particular when every feature is excluded)
     */
    private WeakRanker learnWeakRanker() {
        WeakRanker best = null;
        double bestWeightedScore = -1.0d;
        for (int fid : this.features) {
            Integer key = Integer.valueOf(fid);
            if (this.featureQueue.contains(key) || this.usedFeatures.containsKey(key)) {
                continue; // feature is currently excluded from selection
            }
            WeakRanker candidate = new WeakRanker(fid);
            double weightedScore = 0.0d;
            for (int s = 0; s < this.samples.size(); s++) {
                weightedScore += this.scorer.score(candidate.rank(this.samples.get(s))) * this.sweight[s];
            }
            // Strictly greater: on a tie the earlier feature is kept.
            if (weightedScore > bestWeightedScore) {
                bestWeightedScore = weightedScore;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Runs boosting rounds from {@code startIteration} up to and including {@code nIteration},
     * adding one weak ranker per accepted round and re-weighting the training samples.
     *
     * <p>The loop ends early when no feature is selectable, or when a round fails to improve
     * the training score by more than {@code -tolerance} (that round's ranker is discarded).
     *
     * <p>In enqueue mode, selecting the same feature twice in a row parks that feature in
     * {@code featureQueue}, removes the previously added ranker, restores the sample weights
     * and training score saved before it, and moves on to the next round. The round must be
     * skipped at that point: falling through would re-add the very ranker that was just
     * rolled back.
     *
     * @param startIteration number of the first round to run
     * @param withEnqueue    {@code true} to train in enqueue mode, {@code false} for plain mode
     *                       (which instead tracks repeated selections via {@code usedFeatures})
     * @return the round counter at exit: {@code nIteration + 1} after a full run, otherwise
     *         the number of the round in which the loop stopped
     */
    private int learn(int startIteration, boolean withEnqueue) {
        int t = startIteration;
        for (; t <= nIteration; t++) {
            PRINT(new int[]{7}, new String[]{String.valueOf(t)});
            WeakRanker bestWR = learnWeakRanker();
            if (bestWR == null) {
                break; // nothing left to select
            }
            if (withEnqueue) {
                if (bestWR.getFID() == this.lastFeature) {
                    // Same feature twice in a row: park it and undo the previous round.
                    this.featureQueue.add(Integer.valueOf(this.lastFeature));
                    this.rankers.remove(this.rankers.size() - 1);
                    this.rweight.remove(this.rweight.size() - 1);
                    copy(this.backupSampleWeight, this.sweight);
                    this.bestScoreOnValidationData = 0.0d;
                    this.lastTrainedScore = this.backupTrainScore;
                    PRINTLN(new int[]{8, 9, 9, 9}, new String[]{String.valueOf(bestWR.getFID()), "", "", "ROLLBACK"});
                    // This round is consumed by the rollback; do not add bestWR again.
                    continue;
                }
                this.lastFeature = bestWR.getFID();
                // Save the state needed to undo this round if the feature repeats.
                copy(this.sweight, this.backupSampleWeight);
                this.backupTrainScore = this.lastTrainedScore;
            }

            // Weight of the new weak ranker: 0.5 * ln( sum w*(1+s) / sum w*(1-s) ).
            double num = 0.0d;
            double denom = 0.0d;
            for (int i = 0; i < this.samples.size(); i++) {
                double score = this.scorer.score(bestWR.rank(this.samples.get(i)));
                num += this.sweight[i] * (1.0d + score);
                denom += this.sweight[i] * (1.0d - score);
            }
            this.rankers.add(bestWR);
            double alpha = 0.5d * SimpleMath.ln(num / denom);
            this.rweight.add(Double.valueOf(alpha));

            // Score the updated ensemble and accumulate the normaliser for the re-weighting.
            double scoreSum = 0.0d;
            double total = 0.0d;
            for (int i = 0; i < this.samples.size(); i++) {
                double score = this.scorer.score(rank(this.samples.get(i)));
                total += Math.exp((-alpha) * score);
                scoreSum += score;
            }
            double trainedScore = scoreSum / this.samples.size();
            double delta = (trainedScore + tolerance) - this.lastTrainedScore;
            String status = delta > 0.0d ? "OK" : "DAMN";

            if (!withEnqueue) {
                if (trainedScore != this.lastTrainedScore) {
                    this.performanceChanged = true;
                    this.lastFeatureConsecutiveCount = 0;
                    this.usedFeatures.clear(); // every excluded feature becomes selectable again
                } else {
                    this.performanceChanged = false;
                    if (this.lastFeature == bestWR.getFID()) {
                        this.lastFeatureConsecutiveCount++;
                        if (this.lastFeatureConsecutiveCount == maxSelCount) {
                            status = "F. REM.";
                            this.lastFeatureConsecutiveCount = 0;
                            this.usedFeatures.put(Integer.valueOf(this.lastFeature), 1);
                        }
                    } else {
                        this.lastFeatureConsecutiveCount = 0;
                        this.usedFeatures.clear();
                    }
                }
                this.lastFeature = bestWR.getFID();
            }

            PRINT(new int[]{8, 9}, new String[]{String.valueOf(bestWR.getFID()), String.valueOf(SimpleMath.round(trainedScore, 4))});
            if (this.validationSamples == null) {
                PRINT(new int[]{9, 9}, new String[]{"", status});
            } else {
                double scoreOnValidation = this.scorer.score(rank(this.validationSamples));
                if (scoreOnValidation > this.bestScoreOnValidationData) {
                    this.bestScoreOnValidationData = scoreOnValidation;
                    updateBestModelOnValidation();
                }
                PRINT(new int[]{9, 9}, new String[]{String.valueOf(SimpleMath.round(scoreOnValidation, 4)), status});
            }
            PRINTLN("");

            if (delta <= 0.0d) {
                // Stop criterion met: discard the ranker added in this round.
                this.rankers.remove(this.rankers.size() - 1);
                this.rweight.remove(this.rweight.size() - 1);
                break;
            }
            this.lastTrainedScore = trainedScore;

            // Re-weight the samples for the next round.
            for (int i = 0; i < this.sweight.length; i++) {
                this.sweight[i] *= Math.exp((-alpha) * this.scorer.score(rank(this.samples.get(i)))) / total;
            }
        }
        return t;
    }

    /**
     * Prepares this instance for a training run: uniform sample weights, empty ensemble,
     * empty best-model snapshot, empty feature queue, and all selection bookkeeping reset.
     *
     * <p>The selection bookkeeping ({@code lastFeature}, {@code lastFeatureConsecutiveCount},
     * {@code performanceChanged}, {@code backupTrainScore}) is reset here as well as in the
     * constructors, so that calling {@code init()} again on an already-trained instance does
     * not carry a stale "last selected feature" into the new run.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void init() {
        PRINT("Initializing... ");
        this.usedFeatures.clear();

        // Every sample starts with the same weight. The float literal is kept deliberately
        // so the initial weights are numerically identical to earlier runs.
        this.sweight = new double[this.samples.size()];
        for (int i = 0; i < this.sweight.length; i++) {
            this.sweight[i] = 1.0f / this.samples.size();
        }
        this.backupSampleWeight = new double[this.sweight.length];
        copy(this.sweight, this.backupSampleWeight);
        this.lastTrainedScore = -1.0d;
        this.backupTrainScore = 0.0d;

        // Same starting values as the constructors.
        this.lastFeature = -1;
        this.lastFeatureConsecutiveCount = 0;
        this.performanceChanged = false;

        this.rankers = new ArrayList<>();
        this.rweight = new ArrayList<>();
        this.featureQueue = new ArrayList<>();
        this.bestScoreOnValidationData = 0.0d;
        this.bestModelRankers = new ArrayList<>();
        this.bestModelWeights = new ArrayList<>();
        PRINTLN("[Done]");
    }

    /**
     * Trains the ensemble and prints a progress table.
     *
     * <p>With {@code trainWithEnqueue} set, an enqueue-mode pass runs first; afterwards the
     * parked features are released one at a time (last parked first), each release followed
     * by a plain pass that continues from the current round counter. Otherwise a single plain
     * pass is run. If validation data is present and a best-model snapshot was recorded, the
     * ensemble is replaced by that snapshot before the final scores are computed and printed.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void learn() {
        PRINTLN("---------------------------");
        PRINTLN("Training starts...");
        PRINTLN("--------------------------------------------------------");
        PRINTLN(new int[]{7, 8, 9, 9, 9}, new String[]{"#iter", "Sel. F.", this.scorer.name() + "-T", this.scorer.name() + "-V", "Status"});
        PRINTLN("--------------------------------------------------------");

        if (!trainWithEnqueue) {
            learn(1, false);
        } else {
            int nextRound = learn(1, true);
            for (int q = this.featureQueue.size() - 1; q >= 0; q--) {
                // Release the most recently parked feature, then keep training.
                this.featureQueue.remove(q);
                nextRound = learn(nextRound, false);
            }
        }

        final boolean hasValidation = this.validationSamples != null;
        if (hasValidation && !this.bestModelRankers.isEmpty()) {
            // Switch to the ensemble that scored best on the validation data.
            this.rankers.clear();
            this.rweight.clear();
            this.rankers.addAll(this.bestModelRankers);
            this.rweight.addAll(this.bestModelWeights);
        }

        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(rank(this.samples)), 4);
        PRINTLN("--------------------------------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (hasValidation) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    /**
     * Scores a data point as the weighted sum, over all weak rankers in the ensemble, of the
     * value of each ranker's feature in that data point.
     *
     * @param dataPoint the data point to score
     * @return the sum of {@code weight * featureValue} over the ensemble; 0 when it is empty
     */
    @Override // ciir.umass.edu.learning.Ranker
    public double eval(DataPoint dataPoint) {
        double total = 0.0d;
        final int ensembleSize = this.rankers.size();
        for (int k = 0; k < ensembleSize; k++) {
            double weight = this.rweight.get(k).doubleValue();
            int fid = this.rankers.get(k).getFID();
            total += weight * dataPoint.getFeatureValue(fid);
        }
        return total;
    }

    /**
     * Creates a new AdaRank through the no-argument constructor. Nothing is copied from
     * this instance: the result has no training data and no learned model.
     *
     * @return a fresh, untrained {@code AdaRank}
     */
    @Override // ciir.umass.edu.learning.Ranker
    /* renamed from: clone */
    public Ranker m5clone() {
        final Ranker fresh = new AdaRank();
        return fresh;
    }

    /**
     * Renders the ensemble as space-separated {@code featureId:weight} pairs, in ensemble
     * order, with no trailing space. An empty ensemble yields the empty string.
     *
     * @return the textual form of the ensemble (the last line of {@link #model()})
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String toString() {
        // StringBuilder avoids re-copying the whole string on every append.
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < this.rankers.size(); i++) {
            if (i > 0) {
                out.append(' ');
            }
            out.append(this.rankers.get(i).getFID()).append(':').append(this.rweight.get(i));
        }
        return out.toString();
    }

    /**
     * Builds the textual model: five {@code ##}-prefixed header lines (ranker name and the
     * current values of the static training parameters) followed by {@link #toString()}.
     *
     * @return the model text
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String model() {
        StringBuilder text = new StringBuilder();
        text.append("## ").append(name()).append("\n");
        text.append("## Iteration = ").append(nIteration).append("\n");
        text.append("## Train with enqueue: ").append(trainWithEnqueue ? "Yes" : "No").append("\n");
        text.append("## Tolerance = ").append(tolerance).append("\n");
        text.append("## Max consecutive selection count = ").append(maxSelCount).append("\n");
        text.append(toString());
        return text.toString();
    }

    /**
     * Loads a model from text: the first line that is neither blank nor starts with
     * {@code ##} is parsed as key/value pairs, each key being a feature id and each value
     * the weight of that feature's weak ranker. On success {@code rankers}, {@code rweight}
     * and {@code features} are replaced.
     *
     * <p>Errors are reported on standard output and do not propagate. The new model is
     * parsed completely before any field is assigned, so a failed load leaves the
     * previously loaded state untouched. The stream is closed in all cases.
     *
     * @param inputStream source of the model text, read as ASCII
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void load(InputStream inputStream) {
        try {
            KeyValuePair keyValuePair = null;
            // try-with-resources: the reader is closed even if readLine() throws.
            try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream, "ASCII"))) {
                String line;
                while ((line = bufferedReader.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.length() == 0 || trimmed.indexOf("##") == 0) {
                        continue; // blank line or header comment
                    }
                    keyValuePair = new KeyValuePair(trimmed);
                    break;
                }
            }
            if (keyValuePair == null) {
                // Previously this case surfaced only as a bare NullPointerException message.
                System.out.println("Error in AdaRank::load(): no model line found (input is empty or contains only '##' lines)");
                return;
            }

            List<String> keys = keyValuePair.keys();
            List<String> values = keyValuePair.values();
            // Parse into locals first; commit only when every entry was read successfully.
            List<Double> loadedWeights = new ArrayList<>();
            List<WeakRanker> loadedRankers = new ArrayList<>();
            int[] loadedFeatures = new int[keys.size()];
            for (int i = 0; i < keys.size(); i++) {
                loadedFeatures[i] = Integer.parseInt(keys.get(i));
                loadedRankers.add(new WeakRanker(loadedFeatures[i]));
                loadedWeights.add(Double.valueOf(Double.parseDouble(values.get(i))));
            }
            this.rweight = loadedWeights;
            this.rankers = loadedRankers;
            this.features = loadedFeatures;
        } catch (Exception e) {
            // Best-effort by design: report and return without throwing.
            System.out.println("Error in AdaRank::load(): " + e.toString());
        }
    }

    /**
     * Prints the current values of the four static training parameters, one per line.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void printParameters() {
        final String enqueueLabel;
        if (trainWithEnqueue) {
            enqueueLabel = "Yes";
        } else {
            enqueueLabel = "No";
        }
        PRINTLN("No. of rounds: " + nIteration);
        PRINTLN("Train with 'enequeue': " + enqueueLabel);
        PRINTLN("Tolerance: " + tolerance);
        PRINTLN("Max Sel. Count: " + maxSelCount);
    }

    /**
     * Returns the display name of this ranker, as used in the first header line of
     * {@link #model()}.
     *
     * @return the constant string {@code "AdaRank"}
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String name() {
        final String rankerName = "AdaRank";
        return rankerName;
    }
}
