package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/learning/Nesterovs.class */
/**
 * Momentum-based {@link GradientUpdater} that keeps a running velocity array
 * between calls and derives each returned update from the previous and the
 * freshly computed velocity.
 *
 * <p>Per call to {@link #getGradient(INDArray, int)}:
 * <pre>
 *   vNew   = momentum * vPrev - lr * gradient
 *   result = momentum * vPrev + (-momentum - 1) * vNew
 * </pre>
 *
 * <p>The instance is stateful (it stores the velocity) and performs no
 * synchronization of its own.
 */
public class Nesterovs implements Serializable, GradientUpdater {
    /** Learning rate used by {@link #Nesterovs(double)}. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    /** Momentum coefficient applied to the previous velocity. */
    private double momentum;
    /** Velocity carried over from the previous call; {@code null} until the first call. */
    private INDArray v;
    /** Learning rate applied to the incoming gradient. */
    private double lr;

    /**
     * Creates an updater with an explicit momentum and learning rate.
     *
     * @param momentum coefficient applied to the previous velocity
     * @param lr       learning rate applied to the gradient
     */
    public Nesterovs(double momentum, double lr) {
        this.momentum = momentum;
        this.lr = lr;
    }

    /**
     * Creates an updater with the given momentum and a learning rate of {@code 0.1}.
     *
     * @param momentum coefficient applied to the previous velocity
     */
    public Nesterovs(double momentum) {
        this(momentum, DEFAULT_LEARNING_RATE);
    }

    /** @return the current momentum coefficient */
    public double getMomentum() {
        return momentum;
    }

    /** @param momentum the momentum coefficient to use from the next update on */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** @return the current learning rate */
    public double getLr() {
        return lr;
    }

    /** @param lr the learning rate to use from the next update on */
    public void setLr(double lr) {
        this.lr = lr;
    }

    /**
     * Advances the stored velocity by one step and returns the resulting update.
     *
     * <p>On the first call the velocity is initialised to zeros with the shape of
     * {@code gradient}. {@code gradient} itself is only read through the copying
     * {@code mul}; the returned array is the former velocity array, reused via the
     * in-place {@code muli}/{@code addi} operations.
     *
     * <p>NOTE(review): the sign of the result suggests the caller subtracts it from
     * the parameters — confirm against the call site.
     *
     * @param gradient  raw gradient for the current step
     * @param iteration iteration index; not used by this updater
     * @return {@code momentum * vPrev + (-momentum - 1) * vNew}
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (v == null) {
            v = Nd4j.zeros(gradient.shape());
        }

        // Boxed explicitly so the same Number-typed overloads are selected as before.
        final Double momentumFactor = Double.valueOf(momentum);
        final Double learningRate = Double.valueOf(lr);

        final INDArray previousVelocity = v;

        // vNew = momentum * vPrev - lr * gradient (mul copies, so vPrev stays intact here).
        final INDArray decayedVelocity = previousVelocity.mul(momentumFactor);
        final INDArray scaledGradient = gradient.mul(learningRate);
        v = decayedVelocity.subi(scaledGradient);

        // The old velocity array is no longer referenced by this.v, so it is reused
        // in place as the buffer for the returned update.
        final INDArray update = previousVelocity.muli(momentumFactor);
        final INDArray correction = v.mul(Double.valueOf(-momentum - 1.0d));
        return update.addi(correction);
    }
}
