package org.nd4j.linalg.api.ops.executioner;

import java.util.Arrays;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.complex.LinearViewComplexNDArray;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.LinearViewNDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.class */
public class DefaultOpExecutioner implements OpExecutioner {
    /**
     * Execution mode returned by {@link #executionMode()} and replaced by
     * {@link #setExecutionMode(OpExecutioner.ExecutionMode)}; initialised to {@code JAVA}.
     */
    protected OpExecutioner.ExecutionMode executionMode = OpExecutioner.ExecutionMode.JAVA;

    /**
     * Executes {@code op} over all of its elements and returns the same op instance.
     *
     * <p>Pass-through ops are delegated to {@link Op#exec()}. Otherwise the op is dispatched
     * on its kind ({@link TransformOp}, {@link Accumulation}, {@link ScalarOp}); an op of any
     * other kind is returned untouched.
     *
     * @param op the op to execute
     * @return {@code op}, after execution
     * @throws IllegalArgumentException if a transform's {@code x} and {@code z} have different
     *         classes and neither is a {@link LinearViewNDArray}
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Op exec(Op op) {
        checkOp(op);
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }
        if (op instanceof TransformOp) {
            execTransform((TransformOp) op);
        } else if (op instanceof Accumulation) {
            execAccumulation((Accumulation) op);
        } else if (op instanceof ScalarOp) {
            // The pass-through case was already handled above, so no second check is needed.
            execScalar((ScalarOp) op);
        }
        return op;
    }

    /** Applies a transform element-wise, writing every result into the op's {@code z}. */
    private void execTransform(final TransformOp transformOp) {
        if (!transformOp.x().getClass().equals(transformOp.z().getClass())
                && !(transformOp.x() instanceof LinearViewNDArray)
                && !(transformOp.z() instanceof LinearViewNDArray)) {
            throw new IllegalArgumentException("Illegal operation. Origin and output ndarray must be same types. op.x was " + transformOp.x().getClass().getName() + " while t.z was " + transformOp.z().getClass().getName());
        }

        if (Shape.opIsWholeBufferWithMatchingStrides(transformOp)) {
            // Fast path: operate directly on the backing buffers by linear index.
            // A second "strided" pairwise branch used to follow the pairwise loop, but it was
            // guarded by this exact same condition and therefore could never run; it was removed.
            if (transformOp.y() != null) {
                for (int i = 0; i < transformOp.n(); i++) {
                    transformOp.z().data().put(i, transformOp.op(transformOp.x().data().getDouble(i), transformOp.y().data().getDouble(i)));
                }
            } else {
                for (int i = 0; i < transformOp.n(); i++) {
                    transformOp.z().data().put(i, transformOp.op(transformOp.x().data().getDouble(i)));
                }
            }
            return;
        }

        if (transformOp.y() == null) {
            // Unary transform: walk every coordinate of x's shape.
            NdIndexIterator coordinates = new NdIndexIterator(transformOp.x().shape());
            for (int i = 0; i < transformOp.n(); i++) {
                apply(transformOp, coordinates.next());
            }
        } else if (Arrays.equals(transformOp.x().shape(), transformOp.y().shape())) {
            // Same shape: one coordinate addresses both x and y.
            Shape.iterate(transformOp.x(), new CoordinateFunction() {
                @Override // org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction
                public void process(int[]... coords) {
                    DefaultOpExecutioner.this.apply(transformOp, coords[0], coords[0]);
                }
            });
        } else {
            // Different shapes: x and y are iterated together, each with its own coordinate.
            Shape.iterate(transformOp.x(), transformOp.y(), new CoordinateFunction() {
                @Override // org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction
                public void process(int[]... coords) {
                    DefaultOpExecutioner.this.apply(transformOp, coords[0], coords[1]);
                }
            });
        }
    }

    /** Feeds every element (or element pair) of the op's input into {@code accumulation.update}. */
    private void execAccumulation(Accumulation accumulation) {
        if (Shape.opIsWholeBufferWithMatchingStrides(accumulation)) {
            // Fast path: read straight from the backing buffers by linear index.
            if (accumulation.y() != null) {
                for (int i = 0; i < accumulation.n(); i++) {
                    accumulation.update(Double.valueOf(accumulation.op(accumulation.x().data().getDouble(i), accumulation.y().data().getDouble(i))));
                }
            } else {
                for (int i = 0; i < accumulation.n(); i++) {
                    accumulation.update(Double.valueOf(accumulation.op(accumulation.x().data().getDouble(i))));
                }
            }
        } else if (accumulation.x() instanceof IComplexNDArray) {
            for (int i = 0; i < accumulation.n(); i++) {
                apply(accumulation, i);
            }
        } else {
            // NOTE(review): this fallback only reads x; a non-null y is ignored here, unlike the
            // buffer fast path above. Confirm whether pairwise accumulations can reach this branch.
            INDArray flattened = accumulation.x().reshape(1, accumulation.x().length());
            for (int i = 0; i < accumulation.n(); i++) {
                accumulation.update(Double.valueOf(accumulation.op(flattened.getDouble(0, i))));
            }
        }
    }

    /** Applies a scalar op to every element of {@code x}, writing the results into {@code z}. */
    private void execScalar(ScalarOp scalarOp) {
        INDArray z = scalarOp.z();
        INDArray x = scalarOp.x();
        if (Shape.opIsWholeBufferWithMatchingStrides(scalarOp)) {
            for (int i = 0; i < scalarOp.n(); i++) {
                z.data().put(i, scalarOp.op(x.data().getDouble(i)));
            }
        } else if (scalarOp.x() instanceof IComplexNDArray) {
            IComplexNDArray complexZ = (IComplexNDArray) scalarOp.z();
            for (int i = 0; i < scalarOp.n(); i++) {
                complexZ.putScalar(i, scalarOp.op(((IComplexNDArray) scalarOp.x()).getComplex(i)));
            }
        } else {
            for (int i = 0; i < scalarOp.n(); i++) {
                z.putScalar(i, scalarOp.op(x.getDouble(i)));
            }
        }
    }

    /**
     * Executes {@code op} and returns its result as an array.
     *
     * <p>Transforms and scalar ops yield their output array; an accumulation yields a scalar
     * array built from its current result.
     *
     * @param op the op to execute
     * @return the resulting array
     * @throws IllegalArgumentException if {@code op} is none of the three supported kinds
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(Op op) {
        final INDArray result;
        if (op instanceof TransformOp) {
            result = execAndReturn((TransformOp) op);
        } else if (op instanceof ScalarOp) {
            result = execAndReturn((ScalarOp) op);
        } else if (op instanceof Accumulation) {
            Accumulation finished = execAndReturn((Accumulation) op);
            result = Nd4j.scalar(finished.currentResult());
        } else {
            throw new IllegalArgumentException("Illegal type of op " + op.getClass());
        }
        return result;
    }

    /**
     * Executes {@code op} once per row of its input.
     *
     * <ul>
     *   <li>Vector input: the op is executed as-is.</li>
     *   <li>Input that is neither a vector nor a matrix: the method recurses into every slice
     *       of {@code x}, {@code z} and (when present) {@code y}.</li>
     *   <li>Matrix input: the op is run on a copy of each row and the row result is assigned
     *       back into the matching row of the original {@code z}.</li>
     * </ul>
     *
     * <p>The op's {@code x}/{@code y}/{@code z} are reassigned while iterating and are not
     * restored afterwards.
     *
     * @param op the op to run row by row
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void iterateOverAllRows(Op op) {
        if (op.x().isVector()) {
            // The setters are invoked with the current values, exactly as before.
            op.setX(op.x());
            if (op.y() != null) {
                op.setY(op.y());
            }
            op.setZ(op.z());
            exec(op);
            return;
        }
        if (!op.x().isMatrix()) {
            INDArray originalX = op.x();
            INDArray originalZ = op.z();
            INDArray originalY = op.y();
            for (int i = 0; i < originalX.slices(); i++) {
                INDArray xSlice = originalX.slice(i);
                INDArray zSlice = originalZ.slice(i);
                op.setX(xSlice);
                op.setZ(zSlice);
                if (originalY != null) {
                    // Previously y was never sliced here, so the recursive call saw either the
                    // full y or the stale row left behind by the previous slice's matrix pass.
                    op.setY(originalY.slice(i));
                }
                iterateOverAllRows(op);
            }
            return;
        }
        if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) op.x();
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            for (int i = 0; i < complexX.rows(); i++) {
                IComplexNDArray xRow = complexX.slice(i);
                IComplexNDArray zRow = complexZ.slice(i);
                op.setX(xRow.dup());
                op.setZ(zRow.dup());
                if (complexY != null) {
                    op.setY(complexY.slice(i));
                }
                exec(op);
                // The op ran on a copy, so copy the row result back into the real output.
                complexZ.slice(i).assign(op.z());
            }
            return;
        }
        INDArray matrixX = op.x();
        INDArray matrixZ = op.z();
        INDArray matrixY = op.y();
        for (int i = 0; i < matrixX.rows(); i++) {
            INDArray xRow = matrixX.getRow(i);
            INDArray zRow = matrixZ.getRow(i);
            op.setX(xRow.dup());
            op.setZ(zRow.dup());
            if (matrixY != null) {
                op.setY(matrixY.getRow(i).dup());
            }
            exec(op);
            // The op ran on a copy, so copy the row result back into the real output.
            zRow.assign(op.z());
        }
    }

    /**
     * Executes {@code op} once per column of its input.
     *
     * <p>A vector input is executed directly; a matrix input is executed along dimension 1.
     * For any other input the op's arrays are replaced column by column and the method
     * recurses.
     *
     * <p>NOTE(review): in the recursive case the loop bound is re-read from {@code op.x()},
     * which the loop body itself reassigns. Confirm that this is intended.
     *
     * @param op the op to run column by column
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void iterateOverAllColumns(Op op) {
        if (op.x().isVector()) {
            exec(op);
        } else if (op.x().isMatrix() || op.x().isColumnVector()) {
            exec(op, 1);
        } else if (op.x() instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) op.x();
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            int column = 0;
            while (column < op.x().slices()) {
                op.setX(complexX.getColumn(column));
                op.setZ(complexZ.getColumn(column));
                if (complexY != null) {
                    op.setY(complexY.getColumn(column));
                }
                iterateOverAllColumns(op);
                column++;
            }
        } else {
            INDArray originalX = op.x();
            INDArray originalZ = op.z();
            INDArray originalY = op.y();
            int column = 0;
            while (column < op.x().slices()) {
                op.setX(originalX.getColumn(column));
                op.setZ(originalZ.getColumn(column));
                if (originalY != null) {
                    op.setY(originalY.getColumn(column));
                }
                iterateOverAllColumns(op);
                column++;
            }
        }
    }

    /**
     * Executes a transform and returns its output array.
     *
     * @param transformOp the transform to execute
     * @return the {@code z} array of the executed op
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(TransformOp transformOp) {
        Op executed = exec(transformOp);
        TransformOp executedTransform = (TransformOp) executed;
        return executedTransform.z();
    }

    /**
     * Executes an accumulation over its whole input and returns the executed op.
     *
     * @param accumulation the accumulation to execute
     * @return the op returned by {@link #exec(Op)}, cast back to {@link Accumulation}
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Accumulation execAndReturn(Accumulation accumulation) {
        Op executed = exec(accumulation);
        return (Accumulation) executed;
    }

    /**
     * Executes a scalar op and returns its output array.
     *
     * @param scalarOp the scalar op to execute
     * @return the {@code z} array of the executed op
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(ScalarOp scalarOp) {
        Op executed = exec(scalarOp);
        return executed.z();
    }

    /**
     * Executes {@code op} along the given dimensions and returns the same op instance.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, they are replaced by the
     * single marker {@code Integer.MAX_VALUE}. A single dimension is forwarded to
     * {@link #exec(Op, int)}. Accumulations are executed through {@link #exec(Op)}, i.e.
     * without the dimensions. For every other op, one sub-op is executed per tensor along the
     * dimensions and, for transforms, its output is assigned into the matching tensor of
     * {@code z}.
     *
     * @param op   the op to execute
     * @param iArr the dimensions to execute along
     * @return {@code op}, after execution
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public Op exec(Op op, int... iArr) {
        final int[] dimensions =
                iArr.length == op.x().rank() ? new int[]{Integer.MAX_VALUE} : iArr;

        if (op.isPassThrough()) {
            op.exec(dimensions);
            return op;
        }
        if (dimensions.length == 1) {
            return exec(op, dimensions[0]);
        }
        if (op instanceof Accumulation) {
            return exec(op);
        }

        int tensor = 0;
        while (tensor < op.x().tensorssAlongDimension(dimensions)) {
            Op tensorOp = op.opForDimension(tensor, dimensions);
            exec(tensorOp);
            if (op instanceof TransformOp) {
                TransformOp whole = (TransformOp) op;
                whole.z().tensorAlongDimension(tensor, dimensions).assign(((TransformOp) tensorOp).z());
            }
            tensor++;
        }
        return op;
    }

    /**
     * Executes {@code op} vector by vector along a single dimension.
     *
     * <p>Pass-through ops are delegated to {@link Op#exec()} (without the dimension) and
     * accumulations to {@link #exec(Op)}. For every other op, one sub-op is executed per vector
     * along {@code i}; for transforms, the sub-op's output is assigned into the matching
     * vector of {@code z}.
     *
     * @param op the op to execute
     * @param i  the dimension to execute along
     * @return {@code op}, after execution
     */
    protected Op exec(Op op, int i) {
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }
        if (op instanceof Accumulation) {
            return exec(op);
        }

        int vector = 0;
        while (vector < op.x().vectorsAlongDimension(i)) {
            Op vectorOp = op.opForDimension(vector, i);
            exec(vectorOp);
            if (op instanceof TransformOp) {
                TransformOp whole = (TransformOp) op;
                whole.z().vectorAlongDimension(vector, i).assign(((TransformOp) vectorOp).z());
            }
            vector++;
        }
        return op;
    }

    /**
     * Hook invoked by {@link #exec(Op)} before an op is run.
     *
     * <p>As written, this method inspects the op's {@code x}, {@code y} and {@code z} arrays
     * (linear-view types and scalar-ness) but takes no action on any outcome: it never throws
     * and never modifies the op.
     *
     * @param op the op about to be executed
     */
    protected void checkOp(Op op) {
        // Linear views of any operand end the inspection immediately.
        if (op.x() instanceof LinearViewNDArray) {
            return;
        }
        if (op.y() != null && op.y() instanceof LinearViewNDArray) {
            return;
        }
        if (op.z() != null && op.z() instanceof LinearViewNDArray) {
            return;
        }
        if (op.x() != null && op.x() instanceof LinearViewComplexNDArray) {
            return;
        }
        if (op.y() != null && op.y() instanceof LinearViewComplexNDArray) {
            return;
        }
        if (op.z() != null && op.z() instanceof LinearViewComplexNDArray) {
            return;
        }
        // Scalar operands are inspected last, in x, y, z order.
        if (op.x() != null && op.x().isScalar()) {
            return;
        }
        if (op.y() != null && op.y().isScalar()) {
            return;
        }
        if (op.z() != null && op.z().isScalar()) {
            // Nothing is done for a scalar z either; the body is empty in the original as well.
        }
    }

    /**
     * Executes an accumulation along the given dimensions and returns the reduced array.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, they are replaced by the
     * single marker {@code Integer.MAX_VALUE}, which selects a reduction over the whole input
     * into a scalar array. Pass-through ops are delegated to the op itself.
     *
     * <p>Otherwise the result shape is {@code x}'s shape with the given dimensions removed,
     * padded to two dimensions. Complex inputs are reduced one tensor at a time into a complex
     * result; real inputs are reduced either one tensor at a time (matrix, dimension 0) or by
     * re-pointing the op at consecutive {@code 1 x size} windows of {@code x}'s data buffer.
     *
     * <p>NOTE(review): the real-valued window path only looks at {@code iArr[0]} and steps
     * through {@code x.data()} in buffer order even when the reduced dimension is not the last
     * one; it also leaves the op's {@code x} pointing at the last window. Verify against
     * callers that use rank &gt; 2 inputs or several dimensions.
     *
     * @param accumulation the accumulation to execute
     * @param iArr         the dimensions to reduce along; must not be empty
     * @return the array holding one accumulated value per reduced tensor
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray exec(Accumulation accumulation, int... iArr) {
        INDArray iNDArray;
        if (iArr.length == accumulation.x().rank()) {
            iArr = new int[]{Integer.MAX_VALUE};
        }
        if (accumulation.isPassThrough()) {
            accumulation.exec(iArr);
            return accumulation.z();
        }
        if (iArr[0] == Integer.MAX_VALUE) {
            return accumulation.x() instanceof IComplexNDArray ? Nd4j.scalar(execAndReturn(accumulation).currentResultComplex()) : Nd4j.scalar(execAndReturn(accumulation).currentResult().doubleValue());
        }

        // Result shape: drop the reduced dimensions, then pad to a 2-d shape.
        int[] removeIndex = ArrayUtil.removeIndex(accumulation.x().shape(), iArr);
        if (removeIndex.length == 1) {
            removeIndex = iArr[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
        } else if (removeIndex.length == 0) {
            removeIndex = new int[]{1, 1};
        }

        // The complex check must look at the op's input. It used to test the op object itself,
        // which made this branch unreachable and sent complex inputs down the real-valued path.
        if (accumulation.x() instanceof IComplexNDArray) {
            IComplexNDArray createComplex = Nd4j.createComplex(removeIndex);
            for (int i = 0; i < accumulation.x().tensorssAlongDimension(iArr); i++) {
                createComplex.putScalar(i, execAndReturn((Accumulation) accumulation.opForDimension(i, iArr)).currentResultComplex());
            }
            if (createComplex.ordering() == 'c') {
                createComplex.setStride(ArrayUtil.reverseCopy(createComplex.stride()));
            }
            return createComplex;
        }

        INDArray create = Nd4j.create(removeIndex);
        INDArray x = accumulation.x();
        int i2 = iArr[0];
        if (iArr[0] != x.rank() - 1) {
            // Build a permutation that moves the reduced dimension to the last position.
            int[] iArr2 = new int[x.rank()];
            for (int i3 = 0; i3 < i2; i3++) {
                iArr2[i3] = i3;
            }
            for (int i4 = i2; i4 < x.rank() - 1; i4++) {
                iArr2[i4] = i4 + 1;
            }
            iArr2[iArr2.length - 1] = i2;
            iNDArray = x.permute(iArr2);
        } else {
            iNDArray = x;
        }
        int size = iNDArray.size(-1);
        int[] iArr3 = {1, size};
        int[] strides = Nd4j.getStrides(iArr3, accumulation.x().ordering());
        if (iNDArray.isMatrix() && iArr.length == 1 && iArr[0] == 0) {
            for (int i5 = 0; i5 < accumulation.x().tensorssAlongDimension(iArr); i5++) {
                create.putScalar(i5, execAndReturn((Accumulation) accumulation.opForDimension(i5, iArr)).currentResult().doubleValue());
            }
            return create;
        }

        // Slide a 1 x size window over x's buffer, accumulating one output value per window.
        int i6 = 0;
        int i7 = 0;
        while (i7 < create.length()) {
            accumulation.setX(Nd4j.create(x.data(), iArr3, strides, i6));
            create.putScalar(i7, execAndReturn(accumulation).currentResult().doubleValue());
            // Reset the running result before the next window is accumulated.
            accumulation.setCurrentResult(accumulation.zero());
            i7++;
            i6 += size;
        }
        return create;
    }

    /**
     * Executes a transform along the given dimensions and returns its output array.
     *
     * <p>When as many dimensions are given as {@code x} has ranks, they are replaced by the
     * single marker {@code Integer.MAX_VALUE}. A single dimension is handled by
     * {@link #execAndReturnVector(TransformOp, int)}; otherwise the transform is applied to
     * every coordinate of {@code x}, using the same coordinate for {@code x} and {@code y}.
     *
     * @param transformOp the transform to execute
     * @param iArr        the dimensions to execute along
     * @return the {@code z} array of {@code transformOp}
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(final TransformOp transformOp, int... iArr) {
        final int[] dimensions =
                iArr.length == transformOp.x().rank() ? new int[]{Integer.MAX_VALUE} : iArr;

        if (dimensions.length != 1) {
            Shape.iterate(transformOp.x(), new CoordinateFunction() {
                @Override // org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction
                public void process(int[]... coords) {
                    DefaultOpExecutioner.this.apply(transformOp, coords[0], coords[0]);
                }
            });
            return transformOp.z();
        }
        return execAndReturnVector(transformOp, dimensions[0]);
    }

    /**
     * Executes a transform vector by vector along one dimension and returns its output array.
     *
     * <p>Pass-through ops are delegated to the op itself. Otherwise one sub-op is executed per
     * vector along {@code i} and its output is assigned into the matching vector of {@code z}.
     *
     * @param transformOp the transform to execute
     * @param i           the dimension to execute along
     * @return the {@code z} array of {@code transformOp}
     */
    protected INDArray execAndReturnVector(TransformOp transformOp, int i) {
        if (!transformOp.isPassThrough()) {
            int vector = 0;
            while (vector < transformOp.x().vectorsAlongDimension(i)) {
                Op vectorOp = transformOp.opForDimension(vector, i);
                exec(vectorOp);
                transformOp.z().vectorAlongDimension(vector, i).assign(vectorOp.z());
                vector++;
            }
        } else {
            transformOp.exec(i);
        }
        return transformOp.z();
    }

    /**
     * Executes a scalar op along the given dimensions and returns its output array.
     *
     * @param scalarOp the scalar op to execute
     * @param iArr     the dimensions to execute along
     * @return the {@code z} array of the executed op
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public INDArray execAndReturn(ScalarOp scalarOp, int... iArr) {
        Op executed = exec(scalarOp, iArr);
        return executed.z();
    }

    /**
     * Returns the execution mode currently held by this executioner.
     *
     * @return the current execution mode
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public OpExecutioner.ExecutionMode executionMode() {
        return executionMode;
    }

    /**
     * Replaces the execution mode held by this executioner.
     *
     * @param executionMode the new execution mode
     */
    @Override // org.nd4j.linalg.api.ops.executioner.OpExecutioner
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }

    /**
     * Applies a transform to a single element and stores the result in {@code z}.
     *
     * <p>Does nothing for pass-through ops. The {@code x} element and the output element are
     * both addressed by {@code iArr}; the {@code y} element, when the op has a {@code y}, is
     * addressed by {@code iArr2}. Complex inputs use the op's complex overloads.
     *
     * @param transformOp the transform to apply
     * @param iArr        coordinate of the element in {@code x} and {@code z}
     * @param iArr2       coordinate of the element in {@code y}; unused when {@code y} is null
     */
    /* JADX INFO: Access modifiers changed from: private */
    public void apply(TransformOp transformOp, int[] iArr, int[] iArr2) {
        if (transformOp.isPassThrough()) {
            return;
        }
        final boolean complexInput = transformOp.x() instanceof IComplexNDArray;

        if (transformOp.y() == null) {
            if (!complexInput) {
                transformOp.z().putScalar(iArr, transformOp.op(transformOp.x().getDouble(iArr)));
                return;
            }
            // A null y can never be complex, so the former instanceof check on y (whose two
            // arms were identical anyway) has been collapsed into this single statement.
            IComplexNDArray complexX = (IComplexNDArray) transformOp.x();
            IComplexNDArray complexZ = (IComplexNDArray) transformOp.z();
            complexZ.putScalar(iArr, transformOp.op(complexX.getComplex(iArr)));
            return;
        }

        if (!complexInput) {
            transformOp.z().putScalar(iArr, transformOp.op(transformOp.x().getDouble(iArr), transformOp.y().getDouble(iArr2)));
            return;
        }

        IComplexNDArray complexX = (IComplexNDArray) transformOp.x();
        IComplexNDArray complexZ = (IComplexNDArray) transformOp.z();
        IComplexNumber xValue = complexX.getComplex(iArr);
        // y is read at its own coordinate (iArr2), matching the real-valued branch above.
        // The complex branch used to read y at x's coordinate instead.
        if (transformOp.y() instanceof IComplexNDArray) {
            complexZ.putScalar(iArr, transformOp.op(xValue, ((IComplexNDArray) transformOp.y()).getComplex(iArr2)));
        } else {
            complexZ.putScalar(iArr, transformOp.op(xValue, transformOp.y().getDouble(iArr2)));
        }
    }

    /**
     * Applies a transform to the single element at {@code iArr}, using that same coordinate
     * for {@code x}, {@code y} (when present) and the output {@code z}.
     *
     * <p>This used to be a line-for-line copy of
     * {@link #apply(TransformOp, int[], int[])}, including an unreachable complex-{@code y}
     * check on a null {@code y}; it now delegates with the coordinate passed for both roles.
     *
     * @param transformOp the transform to apply
     * @param iArr        coordinate of the element to transform
     */
    private void apply(TransformOp transformOp, int[] iArr) {
        apply(transformOp, iArr, iArr);
    }

    /**
     * Feeds the element at linear index {@code i} into the accumulation's running result.
     *
     * <p>Does nothing for pass-through ops. With a {@code y} present, the op's two-argument
     * form is applied to the elements of {@code x} and {@code y} at {@code i}; otherwise the
     * one-argument form is applied to {@code x} alone. Complex inputs use the complex
     * overloads.
     *
     * @param accumulation the accumulation to update
     * @param i            linear index of the element to accumulate
     */
    private void apply(Accumulation accumulation, int i) {
        if (accumulation.isPassThrough()) {
            return;
        }
        final boolean complexInput = accumulation.x() instanceof IComplexNDArray;

        if (accumulation.y() == null) {
            if (complexInput) {
                accumulation.update(accumulation.op(((IComplexNDArray) accumulation.x()).getComplex(i)));
            } else {
                accumulation.update(Double.valueOf(accumulation.op(accumulation.x().getDouble(i))));
            }
            return;
        }

        if (!complexInput) {
            accumulation.update(Double.valueOf(accumulation.op(accumulation.x().getDouble(i), accumulation.y().getDouble(i))));
            return;
        }

        IComplexNumber xValue = ((IComplexNDArray) accumulation.x()).getComplex(i);
        if (accumulation.y() instanceof IComplexNDArray) {
            // Cast only after the type check. The cast used to be done unconditionally up
            // front, so a real-valued y threw ClassCastException before reaching the else arm.
            IComplexNDArray complexY = (IComplexNDArray) accumulation.y();
            accumulation.update(accumulation.op(xValue, complexY.getComplex(i)));
        } else {
            accumulation.update(accumulation.op(xValue, accumulation.y().getDouble(i)));
        }
    }
}
