package ciir.umass.edu.learning.neuralnet;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/neuralnet/ListNet.class */
public class ListNet extends RankNet {
    public static int nIteration = 1500;
    public static double learningRate = 1.0E-5d;
    public static int nHiddenLayer = 0;

    public ListNet() {
    }

    public ListNet(List<RankList> list, int[] iArr) {
        super(list, iArr);
    }

    protected float[] feedForward(RankList rankList) {
        float[] fArr = new float[rankList.size()];
        for (int i = 0; i < rankList.size(); i++) {
            addInput(rankList.get(i));
            propagate(i);
            fArr[i] = rankList.get(i).getLabel();
        }
        return fArr;
    }

    protected void backPropagate(float[] fArr) {
        PropParameter propParameter = new PropParameter(fArr);
        this.outputLayer.computeDelta(propParameter);
        this.outputLayer.updateWeight(propParameter);
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet
    protected void estimateLoss() {
        this.error = 0.0d;
        double d = 0.0d;
        double d2 = 0.0d;
        for (int i = 0; i < this.samples.size(); i++) {
            RankList rankList = this.samples.get(i);
            double[] dArr = new double[rankList.size()];
            double d3 = 0.0d;
            for (int i2 = 0; i2 < rankList.size(); i2++) {
                dArr[i2] = eval(rankList.get(i2));
                d += Math.exp(rankList.get(i2).getLabel());
                d2 += Math.exp(dArr[i2]);
            }
            for (int i3 = 0; i3 < rankList.size(); i3++) {
                d3 += (-(Math.exp(rankList.get(i3).getLabel()) / d)) * SimpleMath.logBase2(Math.exp(dArr[i3]) / d2);
            }
            this.error += d3 / rankList.size();
        }
        this.lastError = this.error;
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public void init() {
        PRINT("Initializing... ");
        setInputOutput(this.features.length, 1, 1);
        wire();
        if (this.validationSamples != null) {
            for (int i = 0; i < this.layers.size(); i++) {
                this.bestModelOnValidation.add(new ArrayList());
            }
        }
        Neuron.learningRate = learningRate;
        PRINTLN("[Done]");
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public void learn() {
        PRINTLN("-----------------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("--------------------------------------------------");
        PRINTLN(new int[]{7, 14, 9, 9}, new String[]{"#epoch", "C.E. Loss", String.valueOf(this.scorer.name()) + "-T", String.valueOf(this.scorer.name()) + "-V"});
        PRINTLN("--------------------------------------------------");
        for (int i = 1; i <= nIteration; i++) {
            for (int i2 = 0; i2 < this.samples.size(); i2++) {
                backPropagate(feedForward(this.samples.get(i2)));
                clearNeuronOutputs();
            }
            PRINT(new int[]{7, 14}, new String[]{new StringBuilder(String.valueOf(i)).toString(), new StringBuilder(String.valueOf(SimpleMath.round(this.error, 6))).toString()});
            if (i % 1 == 0) {
                this.scoreOnTrainingData = this.scorer.score(rank(this.samples));
                PRINT(new int[]{9}, new String[]{new StringBuilder(String.valueOf(SimpleMath.round(this.scoreOnTrainingData, 4))).toString()});
                if (this.validationSamples != null) {
                    double score = this.scorer.score(rank(this.validationSamples));
                    if (score > this.bestScoreOnValidationData) {
                        this.bestScoreOnValidationData = score;
                        saveBestModelOnValidation();
                    }
                    PRINT(new int[]{9}, new String[]{new StringBuilder(String.valueOf(SimpleMath.round(score, 4))).toString()});
                }
            }
            PRINTLN("");
        }
        if (this.validationSamples != null) {
            restoreBestModelOnValidation();
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(rank(this.samples)), 4);
        PRINTLN("--------------------------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(String.valueOf(this.scorer.name()) + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(String.valueOf(this.scorer.name()) + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public double eval(DataPoint dataPoint) {
        return super.eval(dataPoint);
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    /* renamed from: clone */
    public Ranker m5clone() {
        return new ListNet();
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public String toString() {
        return super.toString();
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public String model() {
        String str = String.valueOf(String.valueOf("## " + name() + "\n") + "## Epochs = " + nIteration + "\n") + "## No. of features = " + this.features.length + "\n";
        int i = 0;
        while (i < this.features.length) {
            str = String.valueOf(str) + this.features[i] + (i == this.features.length - 1 ? "" : " ");
            i++;
        }
        return String.valueOf(String.valueOf(String.valueOf(str) + "\n") + "0\n") + toString();
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public void load(InputStream inputStream) {
        try {
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream, "ASCII"));
            ArrayList arrayList = new ArrayList();
            while (true) {
                String readLine = bufferedReader.readLine();
                if (readLine == null) {
                    break;
                }
                String trim = readLine.trim();
                if (trim.length() != 0 && trim.indexOf("##") != 0) {
                    arrayList.add(trim);
                }
            }
            bufferedReader.close();
            String[] split = ((String) arrayList.get(0)).split(" ");
            this.features = new int[split.length];
            for (int i = 0; i < split.length; i++) {
                this.features[i] = Integer.parseInt(split[i]);
            }
            int parseInt = Integer.parseInt((String) arrayList.get(1));
            int[] iArr = new int[parseInt];
            int i2 = 2;
            while (i2 < 2 + parseInt) {
                iArr[i2 - 2] = Integer.parseInt((String) arrayList.get(i2));
                i2++;
            }
            setInputOutput(this.features.length, 1);
            for (int i3 = 0; i3 < parseInt; i3++) {
                addHiddenLayer(iArr[i3]);
            }
            wire();
            while (i2 < arrayList.size()) {
                String[] split2 = ((String) arrayList.get(i2)).split(" ");
                Neuron neuron = this.layers.get(Integer.parseInt(split2[0])).get(Integer.parseInt(split2[1]));
                for (int i4 = 0; i4 < neuron.getOutLinks().size(); i4++) {
                    neuron.getOutLinks().get(i4).setWeight(Double.parseDouble(split2[i4 + 2]));
                }
                i2++;
            }
        } catch (Exception e) {
            System.out.println("Error in ListNet::load(): " + e.toString());
        }
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public void printParameters() {
        PRINTLN("No. of epochs: " + nIteration);
        PRINTLN("Learning rate: " + learningRate);
    }

    @Override // ciir.umass.edu.learning.neuralnet.RankNet, ciir.umass.edu.learning.Ranker
    public String name() {
        return "ListNet";
    }
}
