package org.deeplearning4j.spark.text.functions;

import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.spark.text.accumulators.MaxPerPartitionAccumulator;

/* loaded from: input_file:org/deeplearning4j/spark/text/functions/CountCumSum.class */
/**
 * Derives a {@code JavaRDD<Long>} ("cumSumRDD") from an RDD of per-element
 * {@link AtomicLong} counts in two passes over the partitions.
 *
 * <ol>
 *   <li>{@link #cumSumWithinPartition()} maps every partition through
 *       {@code FoldWithinPartitionFunction}, collects a {@code Counter<Integer>}
 *       through an accumulator and broadcasts the accumulated value.</li>
 *   <li>{@link #cumSumBetweenPartition()} maps the intermediate RDD through
 *       {@code FoldBetweenPartitionFunction}, handing it that broadcast.</li>
 * </ol>
 *
 * <p>{@link #buildCumSum()} runs both passes in the required order. The two
 * partition functions are defined elsewhere; presumably the first computes a
 * running sum inside each partition and the second adds the offsets of the
 * preceding partitions (TODO confirm against those classes).
 */
public class CountCumSum {
    /** Java-API context wrapped around the SparkContext of the input RDD. */
    private final JavaSparkContext sc;
    /** Input RDD supplied by the caller. */
    private final JavaRDD<AtomicLong> sentenceCountRDD;
    /** Result of the first pass; {@code null} until {@link #cumSumWithinPartition()} has run. */
    private JavaRDD<AtomicLong> foldWithinPartitionRDD;
    /** Broadcast of the accumulator value gathered during the first pass. */
    private Broadcast<Counter<Integer>> broadcastedMaxPerPartitionCounter;
    /** Result of the second pass; {@code null} until {@link #cumSumBetweenPartition()} has run. */
    private JavaRDD<Long> cumSumRDD;

    /**
     * Creates a builder for the given RDD and derives the Spark context from it.
     *
     * @param javaRDD input RDD; must not be {@code null}
     */
    public CountCumSum(JavaRDD<AtomicLong> javaRDD) {
        this.sentenceCountRDD = javaRDD;
        this.sc = new JavaSparkContext(javaRDD.context());
    }

    /**
     * Returns the RDD produced by {@link #cumSumBetweenPartition()}.
     *
     * @return the result RDD
     * @throws IllegalAccessError if the result has not been built yet. The
     *         throwable type is retained for backward compatibility; note that it
     *         is a {@link Error}, so {@code catch (Exception e)} does not see it.
     */
    public JavaRDD<Long> getCumSumRDD() {
        if (this.cumSumRDD != null) {
            return this.cumSumRDD;
        }
        throw new IllegalAccessError("Cumulative Sum list not defined. Call buildCumSum() first.");
    }

    /**
     * Runs a {@code foreachPartition} action with a
     * {@code MapPerPartitionVoidFunction} over the given RDD. Used by
     * {@link #cumSumWithinPartition()} to force evaluation of the first pass so
     * that the accumulator holds a value before it is read.
     *
     * @param javaRDD RDD to evaluate; the parameter stays a raw type so existing
     *        callers and the call below keep compiling unchanged
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void actionForMapPartition(JavaRDD javaRDD) {
        javaRDD.foreachPartition(new MapPerPartitionVoidFunction());
    }

    /**
     * First pass: builds and caches {@code foldWithinPartitionRDD}, triggers an
     * action on it, then broadcasts the accumulator value collected meanwhile.
     */
    public void cumSumWithinPartition() {
        // Typed accumulator: with a raw Accumulator, value() is just Object and the
        // resulting broadcast cannot be stored as Broadcast<Counter<Integer>>.
        Accumulator<Counter<Integer>> maxPerPartitionAcc =
                this.sc.accumulator(new Counter<Integer>(), new MaxPerPartitionAccumulator());

        this.foldWithinPartitionRDD = this.sentenceCountRDD
                .mapPartitionsWithIndex(new FoldWithinPartitionFunction(maxPerPartitionAcc), true)
                .cache();

        // Transformations are lazy; the action is what actually fills the accumulator.
        actionForMapPartition(this.foldWithinPartitionRDD);

        this.broadcastedMaxPerPartitionCounter = this.sc.broadcast(maxPerPartitionAcc.value());
    }

    /**
     * Second pass: builds {@code cumSumRDD} (named {@code "cumSumRDD"}, marked for
     * caching) from the first-pass RDD and the broadcast counter, then unpersists
     * the first-pass RDD.
     *
     * @throws IllegalStateException if {@link #cumSumWithinPartition()} has not
     *         been called first
     */
    public void cumSumBetweenPartition() {
        if (this.foldWithinPartitionRDD == null) {
            // Previously surfaced as a bare NullPointerException on the call below.
            throw new IllegalStateException(
                    "Within-partition fold not computed. Call cumSumWithinPartition() first.");
        }
        this.cumSumRDD = this.foldWithinPartitionRDD
                .mapPartitionsWithIndex(
                        new FoldBetweenPartitionFunction(this.broadcastedMaxPerPartitionCounter), true)
                .setName("cumSumRDD")
                .cache();
        // NOTE(review): no action has run on cumSumRDD at this point, so the cached
        // first-pass data is released before cumSumRDD is materialized and will
        // presumably be recomputed on first use -- confirm this is intended.
        this.foldWithinPartitionRDD.unpersist();
    }

    /**
     * Runs both passes in order and returns the result.
     *
     * @return the RDD produced by the second pass
     */
    public JavaRDD<Long> buildCumSum() {
        cumSumWithinPartition();
        cumSumBetweenPartition();
        return getCumSumRDD();
    }
}
