package ciir.umass.edu.learning;

import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/CoorAscent.class */
/**
 * Linear ranker whose per-feature weights are tuned one coordinate at a time
 * ("Coordinate Ascent", see {@link #name()}).
 *
 * <p>A document's score is the weighted sum of its feature values (see
 * {@link #eval(DataPoint)}). During training the score of every training
 * document is kept in the data point's cache so that changing a single weight
 * only requires an incremental update instead of a full recomputation; the
 * fields {@link #current_feature} and {@link #weight_change} drive that
 * incremental path in {@link #rank(RankList)}.
 *
 * <p>All tuning parameters are public static fields and are therefore shared
 * by every instance.
 */
public class CoorAscent extends Ranker {
    /** Number of restarts performed by {@link #learn()}; every restart begins from uniform weights. */
    public static int nRestart = 2;
    /** Number of step attempts made in each search direction for a single feature. */
    public static int nMaxIteration = 25;
    /** Initial step size, relative to the absolute weight, used when the default step is too coarse. */
    public static double stepBase = 0.05d;
    /** Factor by which the step grows after every attempt. */
    public static double stepScale = 2.0d;
    /** Minimum gain over a restart's starting score required to run another pass over the features. */
    public static double tolerance = 0.001d;
    /** Whether candidate scores are penalised by their distance from the initial weight vector. */
    public static boolean regularized = false;
    /** Multiplier applied to the distance penalty when {@link #regularized} is set. */
    public static double slack = 0.001d;

    /** Weight of each feature; {@code weight[i]} belongs to {@code features[i]}. */
    protected double[] weight = null;
    /**
     * Index (into {@code features}) of the feature currently being tuned, or
     * {@code -1} to make {@link #rank(RankList)} recompute scores from scratch.
     */
    protected int current_feature = -1;
    /** Delta most recently applied to the weight of {@link #current_feature}. */
    protected double weight_change = -1.0d;

    /** Creates an untrained ranker without training data. */
    public CoorAscent() {
    }

    /**
     * Creates an untrained ranker; both arguments are passed straight to the
     * {@code Ranker} constructor.
     *
     * @param list training samples
     * @param iArr feature ids to use
     */
    public CoorAscent(List<RankList> list, int[] iArr) {
        super(list, iArr);
    }

    /** Allocates the weight vector and sets every weight to {@code 1 / features.length}. */
    @Override // ciir.umass.edu.learning.Ranker
    public void init() {
        PRINT("Initializing... ");
        this.weight = new double[this.features.length];
        resetToUniformWeights();
        PRINTLN("[Done]");
    }

    /**
     * Trains the weight vector.
     *
     * <p>For each of {@link #nRestart} restarts the weights are reset to the
     * uniform vector, then the features are visited in shuffled order. For
     * each feature the weight is moved in the positive direction with
     * geometrically growing steps ({@link #nMaxIteration} attempts); if no
     * attempt beats the best score of the restart the negative direction is
     * tried as well. An improving weight is kept and the vector is rescaled
     * to unit L1 norm. Passes over the features repeat until too many
     * consecutive features fail to improve or the gain over the restart's
     * starting score drops below {@link #tolerance}. The best vector over all
     * restarts becomes the final model.
     *
     * <p>If {@link #nRestart} is not positive no search is performed and the
     * current weights are kept as the model.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void learn() {
        // Reference point of the regularisation penalty: the weights as they are on entry.
        double[] regVector = new double[this.weight.length];
        copy(this.weight, regVector);

        double[] bestModel = null;
        double bestModelScore = 0.0d;
        final int[] directions = {1, -1};

        PRINTLN("---------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------");
        for (int restart = 0; restart < nRestart; restart++) {
            PRINTLN("[+] Random restart #" + (restart + 1) + "/" + nRestart + "...");
            int consecutiveFails = 0;
            resetToUniformWeights();

            // current_feature == -1 makes rank() recompute and cache every score from scratch.
            this.current_feature = -1;
            final double startScore = this.scorer.score(rank(this.samples));

            double localBestScore = startScore;
            double[] localBestWeights = new double[this.weight.length];
            copy(this.weight, localBestWeights);

            while (shouldKeepSearching(consecutiveFails)) {
                PRINTLN("Shuffling features' order... [Done.]");
                PRINTLN("Optimizing weight vector... ");
                PRINTLN("------------------------------");
                PRINTLN(new int[]{7, 8, 7}, new String[]{"Feature", "weight", this.scorer.name()});
                PRINTLN("------------------------------");

                int[] order = getShuffledFeatures();
                for (int k = 0; k < order.length; k++) {
                    final int fid = order[k];
                    // From here on rank() only applies weight_change for this feature to the cached scores.
                    this.current_feature = fid;

                    final double origWeight = this.weight[fid];
                    double bestWeight = origWeight;
                    boolean improved = false;

                    for (int sign : directions) {
                        double step = 0.001d;
                        if (origWeight != 0.0d && step > 0.5d * Math.abs(origWeight)) {
                            // The default step is large relative to the weight: scale it down.
                            step = stepBase * Math.abs(origWeight);
                        }
                        double totalStep = step;
                        for (int attempt = 0; attempt < nMaxIteration; attempt++) {
                            double candidate = origWeight + (totalStep * sign);
                            this.weight_change = candidate - this.weight[fid];
                            this.weight[fid] = candidate;
                            double candidateScore = this.scorer.score(rank(this.samples));
                            if (regularized) {
                                candidateScore -= slack * getDistance(this.weight, regVector);
                            }
                            if (candidateScore > localBestScore) {
                                localBestScore = candidateScore;
                                bestWeight = this.weight[fid];
                                improved = true;
                                PRINTLN(new int[]{7, 8, 7}, new String[]{
                                    String.valueOf(this.features[fid]),
                                    (bestWeight > 0.0d ? "+" : "") + SimpleMath.round(bestWeight, 4),
                                    String.valueOf(SimpleMath.round(localBestScore, 4))});
                            }
                            step *= stepScale;
                            totalStep += step;
                        }
                        if (improved) {
                            // No need to search the opposite direction.
                            break;
                        }
                        // Undo the last attempt so the next direction starts from the original weight.
                        moveCurrentWeightTo(origWeight);
                    }

                    if (improved) {
                        moveCurrentWeightTo(bestWeight);
                        consecutiveFails = 0;
                        // Rescale the weights to unit L1 norm and the cached scores by the same factor.
                        scaleCached(normalize(this.weight));
                        copy(this.weight, localBestWeights);
                    } else {
                        consecutiveFails++;
                        moveCurrentWeightTo(origWeight);
                    }
                }
                PRINTLN("------------------------------");

                // Stop unless the gain over the restart's starting score reaches the tolerance.
                if (!(localBestScore - startScore >= tolerance)) {
                    break;
                }
            }

            if (bestModel == null || localBestScore > bestModelScore) {
                bestModelScore = localBestScore;
                bestModel = localBestWeights;
            }
        }

        // bestModel stays null when nRestart <= 0; keep the current weights in that case.
        if (bestModel != null) {
            copy(bestModel, this.weight);
        }
        this.current_feature = -1;
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(rank(this.samples)), 4);
        PRINTLN("---------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    /**
     * Scores every entry of the list, stores each score in the entry's cache
     * and returns a new list ordered by {@code MergeSorter.sort(scores, false)}.
     *
     * <p>With {@code current_feature == -1} the scores are recomputed from all
     * weights. Otherwise each cached score is adjusted by
     * {@code weight_change * value(current_feature)}, which assumes the cache
     * holds the score from before the latest weight change.
     *
     * @param rankList list to score and reorder
     * @return a new {@code RankList} built from {@code rankList} and the sorted index order
     */
    @Override // ciir.umass.edu.learning.Ranker
    public RankList rank(RankList rankList) {
        final int size = rankList.size();
        double[] scores = new double[size];
        if (this.current_feature == -1) {
            for (int i = 0; i < size; i++) {
                double total = 0.0d;
                for (int j = 0; j < this.features.length; j++) {
                    total += this.weight[j] * rankList.get(i).getFeatureValue(this.features[j]);
                }
                scores[i] = total;
                rankList.get(i).setCached(total);
            }
        } else {
            final int tunedFeatureId = this.features[this.current_feature];
            for (int i = 0; i < size; i++) {
                scores[i] = rankList.get(i).getCached()
                        + (this.weight_change * rankList.get(i).getFeatureValue(tunedFeatureId));
                rankList.get(i).setCached(scores[i]);
            }
        }
        return new RankList(rankList, MergeSorter.sort(scores, false));
    }

    /**
     * Computes the model score of one data point.
     *
     * @param dataPoint point to score
     * @return the sum over all features of {@code weight[i] * value(features[i])}
     */
    @Override // ciir.umass.edu.learning.Ranker
    public double eval(DataPoint dataPoint) {
        double total = 0.0d;
        for (int i = 0; i < this.features.length; i++) {
            total += this.weight[i] * dataPoint.getFeatureValue(this.features[i]);
        }
        return total;
    }

    /**
     * Returns a new, untrained {@code CoorAscent}; no state of this instance is copied.
     *
     * @return a fresh instance created with the no-argument constructor
     */
    @Override // ciir.umass.edu.learning.Ranker
    /* renamed from: clone */
    public Ranker m5clone() {
        return new CoorAscent();
    }

    /**
     * Renders the model as space-separated {@code featureId:weight} pairs.
     *
     * @return the weight line, in the form that {@link #load(InputStream)} hands to {@code KeyValuePair}
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String toString() {
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < this.weight.length; i++) {
            if (i > 0) {
                line.append(' ');
            }
            line.append(this.features[i]).append(':').append(this.weight[i]);
        }
        return line.toString();
    }

    /**
     * Renders the complete model: a header of {@code "## "} lines carrying the
     * ranker name and the static training parameters, followed by the weight
     * line from {@link #toString()}.
     *
     * @return the textual model
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String model() {
        StringBuilder out = new StringBuilder();
        out.append("## ").append(name()).append("\n");
        out.append("## Restart = ").append(nRestart).append("\n");
        out.append("## MaxIteration = ").append(nMaxIteration).append("\n");
        out.append("## StepBase = ").append(stepBase).append("\n");
        out.append("## StepScale = ").append(stepScale).append("\n");
        out.append("## Tolerance = ").append(tolerance).append("\n");
        out.append("## Regularized = ").append(regularized).append("\n");
        out.append("## Slack = ").append(slack).append("\n");
        out.append(toString());
        return out.toString();
    }

    /**
     * Loads feature ids and weights from a textual model.
     *
     * <p>Blank lines and lines starting with {@code "##"} are skipped; the
     * first remaining line is parsed with {@code KeyValuePair}, its keys as
     * integer feature ids and its values as weights. The stream is always
     * closed. Any failure is reported on standard output and leaves the
     * current model untouched.
     *
     * @param inputStream source of the model text, read as ASCII
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void load(InputStream inputStream) {
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, "ASCII"));
            KeyValuePair pairs = null;
            try {
                String raw;
                while ((raw = reader.readLine()) != null) {
                    String line = raw.trim();
                    if (line.length() != 0 && line.indexOf("##") != 0) {
                        pairs = new KeyValuePair(line);
                        break;
                    }
                }
            } finally {
                // Release the stream even when reading or parsing fails.
                reader.close();
            }
            if (pairs == null) {
                System.out.println("Error in CoorAscent::load(): no weight line found in the model");
                return;
            }
            List<String> keys = pairs.keys();
            List<String> values = pairs.values();
            // Parse into locals first so a malformed entry cannot leave a half-loaded model behind.
            double[] loadedWeights = new double[keys.size()];
            int[] loadedFeatures = new int[keys.size()];
            for (int i = 0; i < keys.size(); i++) {
                loadedFeatures[i] = Integer.parseInt(keys.get(i));
                loadedWeights[i] = Double.parseDouble(values.get(i));
            }
            this.weight = loadedWeights;
            this.features = loadedFeatures;
        } catch (Exception e) {
            System.out.println("Error in CoorAscent::load(): " + e.toString());
        }
    }

    /** Prints the restart count, iteration count, tolerance and regularisation setting. */
    @Override // ciir.umass.edu.learning.Ranker
    public void printParameters() {
        PRINTLN("No. of random restarts: " + nRestart);
        PRINTLN("No. of iterations to search in each direction: " + nMaxIteration);
        PRINTLN("Tolerance: " + tolerance);
        if (regularized) {
            PRINTLN("Reg. param: " + slack);
        } else {
            PRINTLN("Regularization: No");
        }
    }

    /**
     * @return the display name of this ranker
     */
    @Override // ciir.umass.edu.learning.Ranker
    public String name() {
        return "Coordinate Ascent";
    }

    /** Sets every entry of the existing weight vector to {@code 1 / features.length}. */
    private void resetToUniformWeights() {
        for (int i = 0; i < this.weight.length; i++) {
            // Float division on purpose: it reproduces the original initial values exactly.
            this.weight[i] = 1.0f / this.features.length;
        }
    }

    /**
     * Tells whether another pass over the features should be made.
     *
     * @param consecutiveFails number of features in a row whose weight could not be improved
     * @return {@code true} while fewer than {@code weight.length - 1} features failed in a row,
     *         or, for a single-feature model, while that feature has not failed yet
     */
    private boolean shouldKeepSearching(int consecutiveFails) {
        final int n = this.weight.length;
        return (n > 1 && consecutiveFails < n - 1) || (n == 1 && consecutiveFails == 0);
    }

    /**
     * Sets the weight of {@link #current_feature} to {@code target} and brings
     * the cached training scores in line with the change.
     *
     * @param target new weight of the feature currently being tuned
     */
    private void moveCurrentWeightTo(double target) {
        // Order matters: updateCached() reads weight_change, which is the delta from the old weight.
        this.weight_change = target - this.weight[this.current_feature];
        updateCached();
        this.weight[this.current_feature] = target;
    }

    /** Adds {@code weight_change * value(current_feature)} to the cached score of every training point. */
    private void updateCached() {
        final int tunedFeatureId = this.features[this.current_feature];
        for (int i = 0; i < this.samples.size(); i++) {
            RankList list = this.samples.get(i);
            for (int j = 0; j < list.size(); j++) {
                double delta = this.weight_change * list.get(j).getFeatureValue(tunedFeatureId);
                list.get(j).setCached(list.get(j).getCached() + delta);
            }
        }
    }

    /**
     * Divides the cached score of every training point by {@code d}.
     *
     * @param d divisor, in practice the L1 norm returned by {@link #normalize(double[])}
     */
    private void scaleCached(double d) {
        for (int i = 0; i < this.samples.size(); i++) {
            RankList list = this.samples.get(i);
            for (int j = 0; j < list.size(); j++) {
                list.get(j).setCached(list.get(j).getCached() / d);
            }
        }
    }

    /**
     * @return the indices {@code 0 .. features.length - 1} in random order
     */
    private int[] getShuffledFeatures() {
        List<Integer> indices = new ArrayList<Integer>(this.features.length);
        for (int i = 0; i < this.features.length; i++) {
            indices.add(Integer.valueOf(i));
        }
        Collections.shuffle(indices);
        int[] shuffled = new int[this.features.length];
        for (int i = 0; i < indices.size(); i++) {
            shuffled[i] = indices.get(i).intValue();
        }
        return shuffled;
    }

    /**
     * Euclidean distance between two vectors after each has been scaled to
     * unit L1 norm. The inputs are not modified. A vector whose entries are
     * all zero has an L1 norm of zero, which makes the result NaN.
     *
     * @param dArr first vector
     * @param dArr2 second vector, at least as long as {@code dArr}
     * @return the distance between the two scaled vectors
     */
    private double getDistance(double[] dArr, double[] dArr2) {
        double norm1 = 0.0d;
        double norm2 = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            norm1 += Math.abs(dArr[i]);
            norm2 += Math.abs(dArr2[i]);
        }
        double squaredSum = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            double diff = (dArr[i] / norm1) - (dArr2[i] / norm2);
            squaredSum += diff * diff;
        }
        return Math.sqrt(squaredSum);
    }

    /**
     * Scales the vector in place so that its absolute values sum to one.
     *
     * @param dArr vector to rescale
     * @return the L1 norm the vector had before scaling
     */
    private double normalize(double[] dArr) {
        double norm = 0.0d;
        for (double value : dArr) {
            norm += Math.abs(value);
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = dArr[i] / norm;
        }
        return norm;
    }

    /**
     * Replaces this model's weights with a copy of another model's weights.
     * Prints a message and terminates the JVM with exit status 1 when the
     * two models do not have the same number of features.
     *
     * @param coorAscent model to copy the weights from
     */
    public void copyModel(CoorAscent coorAscent) {
        this.weight = new double[this.features.length];
        if (coorAscent.weight.length != this.weight.length) {
            System.out.println("These two models use different feature set!!");
            System.exit(1);
        }
        copy(coorAscent.weight, this.weight);
        PRINTLN("Model loaded.");
    }

    /**
     * @param coorAscent model to compare with
     * @return the distance between the L1-normalised weight vectors of this model and {@code coorAscent}
     */
    public double distance(CoorAscent coorAscent) {
        return getDistance(this.weight, coorAscent.weight);
    }
}
