package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/indexing/ShapeOffsetResolution.class */
/**
 * Computes the shape, strides, per-dimension offsets and a single linear offset that result from
 * applying a sequence of {@link INDArrayIndex} objects to an {@link INDArray}.
 *
 * <p>Usage: construct with the source array, call {@link #exec(INDArrayIndex...)}, then read the
 * results through {@link #getShapes()}, {@link #getStrides()}, {@link #getOffsets()} and
 * {@link #getOffset()}. Presumably callers use these values to build a sub-array view — TODO
 * confirm against callers.
 *
 * <p>NOTE(review): the local names in {@code exec} look decompiler-generated and the lists are
 * raw {@code ArrayList}s; every element added to them in this class is an {@link Integer}.
 */
public class ShapeOffsetResolution implements Serializable {
    // The array whose indexing is being resolved.
    private INDArray arr;
    // Per-dimension offsets of the result; assigned by exec.
    private int[] offsets;
    // Shape of the result; assigned by exec.
    private int[] shapes;
    // Strides of the result; assigned by exec.
    private int[] strides;
    // Linear offset of the result; initialized to -1 and assigned by exec.
    private int offset = -1;

    /**
     * Creates a resolver for the given array. No computation happens until
     * {@link #exec(INDArrayIndex...)} is called.
     *
     * @param iNDArray the array to be indexed
     */
    public ShapeOffsetResolution(INDArray iNDArray) {
        this.arr = iNDArray;
    }

    /**
     * Resolves {@code iNDArrayIndexArr} against {@link #arr} and stores the outcome in
     * {@link #shapes}, {@link #strides}, {@link #offsets} and {@link #offset}.
     *
     * <p>The indexes are first normalized with {@link NDArrayIndex#resolve}. Each resolved index
     * is then handled by type:
     * <ul>
     *   <li>{@link PointIndex}: the dimension is dropped from the result shape. Its offset and the
     *       array's stride for that dimension are collected separately and feed the final linear
     *       offset.</li>
     *   <li>{@link NewAxis}: a size-1, stride-0 dimension is added. New axes that occur before any
     *       {@link NDArrayIndexAll} are prepended; later ones are inserted at their recorded
     *       position.</li>
     *   <li>a non-"all" {@link IntervalIndex}, or a {@link SpecifiedIndex}: the dimension length is
     *       {@code index.length()}. For intervals, the array stride is multiplied by the index's own
     *       {@code stride()}.</li>
     *   <li>anything else (including {@link NDArrayIndexAll}): the original dimension size and
     *       stride are kept.</li>
     * </ul>
     * Dimensions not covered by any index are then appended. The shape is padded to at least two
     * dimensions, and the stride and offset lists are padded or trimmed to line up with it.
     *
     * @param iNDArrayIndexArr the indexes to apply to {@link #arr}
     */
    public void exec(INDArrayIndex... iNDArrayIndexArr) {
        INDArrayIndex[] resolve = NDArrayIndex.resolve(this.arr.shape(), iNDArrayIndexArr);
        int[] shape = this.arr.shape();
        // Number of indexes that took the IntervalIndex path in the final else branch (non-"all" intervals).
        int i = 0;
        // Number of NewAxis indexes seen before any NDArrayIndexAll; these get prepended later.
        int i2 = 0;
        // Whether an NDArrayIndexAll has been seen so far.
        boolean z = false;
        // Positions of NDArrayIndexAll on size-1 dimensions. NOTE(review): filled but never read.
        ArrayList arrayList = new ArrayList();
        // Accumulated result shape.
        ArrayList arrayList2 = new ArrayList();
        // Accumulated result strides.
        ArrayList arrayList3 = new ArrayList();
        // Accumulated result offsets.
        ArrayList arrayList4 = new ArrayList();
        // Own stride() of each non-"all" interval index; only checked for emptiness at the end.
        ArrayList arrayList5 = new ArrayList();
        // Array strides of point-indexed dimensions.
        ArrayList arrayList6 = new ArrayList();
        // Offsets of point indexes.
        ArrayList arrayList7 = new ArrayList();
        // Number of PointIndex entries.
        int i3 = 0;
        // Cursor into the source shape.
        int i4 = 0;
        // Cursor into the source strides.
        int i5 = 0;
        // Index positions of NewAxis entries seen after an NDArrayIndexAll.
        ArrayList arrayList8 = new ArrayList();
        for (int i6 = 0; i6 < resolve.length; i6++) {
            INDArrayIndex iNDArrayIndex = resolve[i6];
            if (iNDArrayIndex instanceof NDArrayIndexAll) {
                z = true;
                if (i6 < this.arr.rank() && this.arr.size(i6) == 1) {
                    arrayList.add(Integer.valueOf(i6));
                }
            }
            if (iNDArrayIndex instanceof PointIndex) {
                // The dimension is removed from the result; remember its offset/stride for the linear offset.
                arrayList7.add(Integer.valueOf(iNDArrayIndex.offset()));
                arrayList6.add(Integer.valueOf(this.arr.stride(i5)));
                i3++;
                i4++;
                i5++;
            } else if (iNDArrayIndex instanceof NewAxis) {
                // New axes consume no source dimension.
                if (z) {
                    arrayList8.add(Integer.valueOf(i6));
                } else {
                    i2++;
                }
            } else if ((!(iNDArrayIndex instanceof IntervalIndex) || (iNDArrayIndex instanceof NDArrayIndexAll)) && !(iNDArrayIndex instanceof SpecifiedIndex)) {
                // Neither a non-"all" IntervalIndex nor a SpecifiedIndex: keep the original size and stride.
                // (The explicit NDArrayIndexAll test suggests it is an IntervalIndex subclass — confirm.)
                int i7 = i4;
                i4++;
                arrayList2.add(Integer.valueOf(shape[i7]));
                int i8 = i5;
                i5++;
                arrayList3.add(Integer.valueOf(this.arr.stride(i8)));
                arrayList4.add(Integer.valueOf(iNDArrayIndex.offset()));
            } else {
                // Non-"all" interval or specified index: the result length comes from the index itself.
                if (iNDArrayIndex instanceof IntervalIndex) {
                    arrayList3.add(Integer.valueOf(this.arr.stride(i5) * iNDArrayIndex.stride()));
                    arrayList5.add(Integer.valueOf(iNDArrayIndex.stride()));
                    i++;
                } else {
                    arrayList3.add(Integer.valueOf(this.arr.stride(i5)));
                }
                arrayList2.add(Integer.valueOf(iNDArrayIndex.length()));
                // NOTE(review): both branches below are identical.
                if (iNDArrayIndex instanceof IntervalIndex) {
                    arrayList4.add(Integer.valueOf(iNDArrayIndex.offset()));
                } else {
                    arrayList4.add(Integer.valueOf(iNDArrayIndex.offset()));
                }
                i4++;
                i5++;
            }
        }
        // Append source dimensions not covered by any index; for vector inputs each is appended as 1.
        while (i4 < shape.length) {
            if (Shape.isVector(shape)) {
                arrayList2.add(1);
                i4++;
            } else {
                int i9 = i4;
                i4++;
                arrayList2.add(Integer.valueOf(shape[i9]));
            }
        }
        // Expected offset count: full rank for rank <= 2, otherwise rank minus point-indexed dimensions.
        int length = shape.length <= 2 ? shape.length : shape.length - i3;
        // True only when the shape count differs from both the stride count and the offset count.
        boolean z2 = (arrayList2.size() == arrayList3.size() || arrayList4.size() == arrayList2.size()) ? false : true;
        while (arrayList4.size() < length && z2) {
            arrayList4.add(0);
        }
        // Guarantee at least two result dimensions: prepend 1 for row-vector inputs, append 1 otherwise.
        while (arrayList2.size() < 2) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                arrayList2.add(0, 1);
            } else {
                arrayList2.add(1);
            }
        }
        // Take strides for remaining result dimensions from the source array.
        // NOTE(review): assumes i5 stays a valid dimension of arr here — confirm.
        while (i5 < arrayList2.size()) {
            int i10 = i5;
            i5++;
            arrayList3.add(Integer.valueOf(this.arr.stride(i10)));
        }
        // Prepend leading new axes: size 1, stride 0, offset 0.
        if (i2 > 0) {
            for (int i11 = 0; i11 < i2; i11++) {
                arrayList2.add(0, 1);
                arrayList3.add(0, 0);
                arrayList4.add(0, 0);
            }
        }
        // Insert new axes that followed an NDArrayIndexAll at (recorded position - number already inserted).
        // Only shape (1) and stride (0) are inserted here; offsets are left untouched.
        int i12 = 0;
        for (int i13 = 0; i13 < arrayList8.size(); i13++) {
            arrayList2.add(((Integer) arrayList8.get(i13)).intValue() - i12, 1);
            arrayList3.add(((Integer) arrayList8.get(i13)).intValue() - i12, 0);
            i12++;
        }
        // Trim surplus offsets so there are no more offsets than shape entries.
        // NOTE(review): the zero test reads position `size` but always removes the LAST element, and if
        // no zero is found before `size` drops below 0, get(-1) throws IndexOutOfBoundsException.
        int size = arrayList4.size() - 1;
        while (arrayList4.size() > arrayList2.size()) {
            if (((Integer) arrayList4.get(size)).intValue() == 0) {
                arrayList4.remove(arrayList4.size() - 1);
            }
            size--;
        }
        // If strides are still shorter than offsets, append the point-index strides.
        if (arrayList3.size() < arrayList4.size()) {
            arrayList3.addAll(arrayList6);
        }
        // Pad offsets with zeros up to the shape length (prepended for row-vector inputs).
        while (arrayList4.size() < arrayList2.size()) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                arrayList4.add(0, 0);
            } else {
                arrayList4.add(0);
            }
        }
        // Matrix indexed as [point, all]: reverse the resulting shape.
        if (Shape.isMatrix(shape) && (resolve[0] instanceof PointIndex) && (resolve[1] instanceof NDArrayIndexAll)) {
            Collections.reverse(arrayList2);
        }
        this.shapes = Ints.toArray(arrayList2);
        // Pad strides up to the offset count with the array's element stride:
        // appended when the result is a column vector, prepended otherwise.
        boolean isColumnVectorShape = Shape.isColumnVectorShape(this.shapes);
        while (arrayList3.size() < arrayList4.size()) {
            if (isColumnVectorShape) {
                arrayList3.add(Integer.valueOf(this.arr.elementStride()));
            } else {
                arrayList3.add(0, Integer.valueOf(this.arr.elementStride()));
            }
        }
        this.strides = Ints.toArray(arrayList3);
        this.offsets = Ints.toArray(arrayList4);
        // No point indexes (i3 always equals arrayList6.size() here): derive the linear offset from the
        // accumulated shape/offsets/strides.
        if (i3 <= 0 || arrayList6.isEmpty()) {
            if (i <= 0 || this.arr.rank() <= 2) {
                this.offset = ArrayUtil.calcOffset(arrayList2, arrayList4, arrayList3);
                return;
            } else if (!z || this.arr.size(0) == 1) {
                // NOTE(review): integer division by the interval count; intent not evident from this code.
                this.offset = ArrayUtil.dotProduct(arrayList4, arrayList3) / i;
                return;
            } else {
                this.offset = ArrayUtil.dotProduct(arrayList4, arrayList3);
                return;
            }
        }
        // Point indexes present: the linear offset is built from the point offsets and point strides.
        // With leading new axes, extend point strides with 1s, then zero the entries whose result stride is 0.
        // NOTE(review): set(i14, 0) can throw IndexOutOfBoundsException when arrayList3 is longer than
        // arrayList4, because arrayList6 is only padded to arrayList4's size.
        if (i2 >= 1) {
            while (arrayList6.size() < arrayList4.size()) {
                arrayList6.add(1);
            }
            for (int i14 = 0; i14 < arrayList3.size(); i14++) {
                if (((Integer) arrayList3.get(i14)).intValue() == 0) {
                    arrayList6.set(i14, 0);
                }
            }
        }
        // Pad point offsets with zeros to match the point strides.
        while (arrayList7.size() < arrayList6.size()) {
            arrayList7.add(0);
        }
        // Row-vector input with interval indexes and a zero first point offset: use the second resolved
        // index's offset directly. NOTE(review): assumes resolve has at least two entries.
        if (this.arr.isRowVector() && !arrayList5.isEmpty() && ((Integer) arrayList7.get(0)).intValue() == 0) {
            this.offset = resolve[1].offset();
        } else {
            this.offset = ArrayUtil.dotProduct(arrayList7, arrayList6);
        }
    }

    /**
     * @return the array being indexed
     */
    public INDArray getArr() {
        return this.arr;
    }

    /**
     * @param iNDArray the array to be indexed by subsequent calls to exec
     */
    public void setArr(INDArray iNDArray) {
        this.arr = iNDArray;
    }

    /**
     * @return the per-dimension offsets computed by exec (the internal array, not a copy)
     */
    public int[] getOffsets() {
        return this.offsets;
    }

    /**
     * @param iArr per-dimension offsets to store (kept by reference, not copied)
     */
    public void setOffsets(int[] iArr) {
        this.offsets = iArr;
    }

    /**
     * @return the result shape computed by exec (the internal array, not a copy)
     */
    public int[] getShapes() {
        return this.shapes;
    }

    /**
     * @param iArr result shape to store (kept by reference, not copied)
     */
    public void setShapes(int[] iArr) {
        this.shapes = iArr;
    }

    /**
     * @return the result strides computed by exec (the internal array, not a copy)
     */
    public int[] getStrides() {
        return this.strides;
    }

    /**
     * @param iArr result strides to store (kept by reference, not copied)
     */
    public void setStrides(int[] iArr) {
        this.strides = iArr;
    }

    /**
     * @return the linear offset computed by exec, or -1 if it has not been set
     */
    public int getOffset() {
        return this.offset;
    }

    /**
     * @param i linear offset to store
     */
    public void setOffset(int i) {
        this.offset = i;
    }
}
