/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import java.util.Random;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.sgd.Gradient;
import org.apache.mahout.classifier.sgd.RankingGradient;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;

/**
 * Gradient that randomly mixes ranking updates with ordinary updates.
 *
 * <p>On each call to {@link #apply}, a uniform random draw is compared against {@code alpha}:
 * <ul>
 *   <li>if the draw is below {@code alpha}, the update is delegated to a {@link RankingGradient};</li>
 *   <li>otherwise the example is recorded in that ranking gradient's history and the update is
 *       delegated to the ranking gradient's base gradient.</li>
 * </ul>
 * A ranking update is only permitted once the basic path has seen at least one example with
 * {@code actual == 0} and one with {@code actual == 1}.
 *
 * <p>This class keeps mutable state (the class-seen flags and the ranking history) and performs no
 * synchronization.
 */
public class MixedGradient
implements Gradient {
    /** Threshold for the random draw; a draw below this value selects a ranking update. */
    private final double alpha;
    /** Produces the ranking updates and holds the history of basic-path examples. */
    private final RankingGradient rank;
    /** The base gradient of {@link #rank}, used for non-ranking updates. */
    private final Gradient basic;
    private final Random random = RandomUtils.getRandom();
    /** Whether an example with {@code actual == 0} has passed through the basic path. */
    private boolean hasZero;
    /** Whether an example with {@code actual == 1} has passed through the basic path. */
    private boolean hasOne;

    /**
     * Creates a mixed gradient.
     *
     * @param alpha  threshold for the uniform draw in {@code [0, 1)}; values in {@code [0, 1]}
     *               act as the probability of choosing a ranking update on each call
     * @param window passed unchanged to the {@link RankingGradient} constructor
     *               (presumably the size of its example history — confirm in RankingGradient)
     */
    public MixedGradient(double alpha, int window) {
        this.alpha = alpha;
        this.rank = new RankingGradient(window);
        this.basic = this.rank.getBaseGradient();
    }

    /**
     * Computes a gradient for one training example, choosing either a ranking or a basic update.
     *
     * @param groupKey   passed through unchanged to the delegate gradient
     * @param actual     the target category of the example
     * @param instance   the feature vector of the example
     * @param classifier the model the gradient is computed for
     * @return the gradient produced by the delegate that was chosen
     * @throws IllegalStateException if a ranking update is selected before the basic path has seen
     *                               examples with both {@code actual == 0} and {@code actual == 1}
     */
    @Override
    public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
        if (this.random.nextDouble() < this.alpha) {
            // A ranking update compares against history, so it needs examples of both classes.
            // Only the basic path below fills that history and sets these flags.
            if (!this.hasZero || !this.hasOne) {
                throw new IllegalStateException(
                    "Ranking update selected before examples of both classes were seen (hasZero="
                        + this.hasZero + ", hasOne=" + this.hasOne
                        + "); train with more examples on the basic path first or lower alpha");
            }
            // Note: examples that take the ranking path are not added to the history.
            return this.rank.apply(groupKey, actual, instance, classifier);
        }
        this.hasZero |= actual == 0;
        this.hasOne |= actual == 1;
        // Record the example so that later ranking updates can compare against it.
        this.rank.addToHistory(actual, instance);
        return this.basic.apply(groupKey, actual, instance, classifier);
    }
}

