/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.matrix.Vector;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;

/**
 * Abstract base for priors on regression coefficients, plus static factory
 * methods for the supported prior families: noninformative, Gaussian, Laplace,
 * Cauchy, log-interpolated, elastic net and mean-shifted priors.
 *
 * <p>A prior is evaluated one coefficient at a time.
 * {@link #log2Prior(double, int)} returns the base-2 log density of the
 * coefficient in a given dimension. {@link #gradient(double, int)} returns the
 * matching gradient term. For the informative priors defined in this class,
 * the gradient is the derivative of the negative <em>natural</em> log density
 * with respect to the coefficient (for example {@code beta / variance} for a
 * Gaussian).
 *
 * <p>The constructor is private, so every implementation is one of the nested
 * classes below.
 */
public abstract class RegressionPrior
implements Serializable {
    static final long serialVersionUID = 2955531646832969891L;
    static final double SQRT_2 = java.lang.Math.sqrt(2.0);
    private static final RegressionPrior NONINFORMATIVE_PRIOR = new NoninformativeRegressionPrior();
    // Same value as SQRT_2; both names are kept because they are package-visible.
    static final double sqrt2 = java.lang.Math.sqrt(2.0);
    static final double log2Sqrt2Over2 = Math.log2(sqrt2 / 2.0);
    static final double log2Sqrt2Pi = Math.log2(java.lang.Math.sqrt(java.lang.Math.PI * 2));
    static final double log21OverPi = -Math.log2(java.lang.Math.PI);
    /**
     * log2(e) = 1 / ln(2). Multiplying a natural-log quantity by this
     * converts it to base-2 units, e.g. log2(exp(-x)) = -x * log2E.
     */
    static final double log2E = 1.0 / java.lang.Math.log(2.0);

    private RegressionPrior() {
    }

    /**
     * Returns {@code true} if this prior is flat, meaning it contributes
     * nothing to the log density or the gradient.
     *
     * @return {@code false} in this base implementation
     */
    public boolean isUniform() {
        return false;
    }

    /**
     * Returns the mode of the prior for the specified dimension.
     *
     * @param dimension coefficient index
     * @return {@code 0.0} in this base implementation
     */
    public double mode(int dimension) {
        return 0.0;
    }

    /**
     * Returns the gradient term of this prior for a coefficient value.
     *
     * @param beta coefficient value
     * @param dimension coefficient index
     * @return gradient contribution of the prior at {@code beta}
     */
    public abstract double gradient(double beta, int dimension);

    /**
     * Returns the base-2 log prior density of a coefficient value.
     *
     * @param beta coefficient value
     * @param dimension coefficient index
     * @return log2 of the prior density at {@code beta}
     */
    public abstract double log2Prior(double beta, int dimension);

    /**
     * Returns the base-2 log prior of a whole coefficient vector. This is the
     * sum of {@link #log2Prior(double, int)} over its dimensions.
     *
     * @param beta coefficient vector
     * @return summed log2 prior
     * @throws IllegalArgumentException if the vector's dimensionality is not
     *     accepted by this prior
     */
    public double log2Prior(Vector beta) {
        int numDimensions = beta.numDimensions();
        this.verifyNumberOfDimensions(numDimensions);
        double log2Prior = 0.0;
        for (int i = 0; i < numDimensions; ++i) {
            log2Prior += this.log2Prior(beta.value(i), i);
        }
        return log2Prior;
    }

    /**
     * Returns the summed base-2 log prior of several coefficient vectors.
     *
     * @param betas coefficient vectors
     * @return sum of {@link #log2Prior(Vector)} over {@code betas}
     */
    public double log2Prior(Vector[] betas) {
        double log2Prior = 0.0;
        for (Vector beta : betas) {
            log2Prior += this.log2Prior(beta);
        }
        return log2Prior;
    }

    /**
     * Hook that lets fixed-dimension priors reject vectors of the wrong size.
     * The base implementation accepts any dimensionality.
     *
     * @param ignoreMeNumDimensions dimensionality of the vector being evaluated
     */
    void verifyNumberOfDimensions(int ignoreMeNumDimensions) {
    }

    /**
     * Returns the shared noninformative (uniform) prior.
     *
     * @return the noninformative prior singleton
     */
    public static RegressionPrior noninformative() {
        return NONINFORMATIVE_PRIOR;
    }

    /**
     * Returns a zero-mean Gaussian prior with the same variance in every
     * dimension.
     *
     * @param priorVariance variance of each coefficient
     * @param noninformativeIntercept if {@code true}, dimension 0 gets no prior
     * @return the Gaussian prior
     * @throws IllegalArgumentException if the variance is negative, NaN or
     *     negative infinity
     */
    public static RegressionPrior gaussian(double priorVariance, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorVariance);
        return new VariableGaussianRegressionPrior(priorVariance, noninformativeIntercept);
    }

    /**
     * Returns a zero-mean Gaussian prior with a separate variance per
     * dimension. The array is used directly, not copied.
     *
     * @param priorVariances per-dimension variances
     * @return the Gaussian prior
     * @throws IllegalArgumentException if any variance is negative, NaN or
     *     negative infinity
     */
    public static RegressionPrior gaussian(double[] priorVariances) {
        RegressionPrior.verifyPriorVariances(priorVariances);
        return new GaussianRegressionPrior(priorVariances);
    }

    /**
     * Returns a zero-mean Laplace prior with the same variance in every
     * dimension.
     *
     * @param priorVariance variance of each coefficient
     * @param noninformativeIntercept if {@code true}, dimension 0 gets no prior
     * @return the Laplace prior
     * @throws IllegalArgumentException if the variance is negative, NaN or
     *     negative infinity
     */
    public static RegressionPrior laplace(double priorVariance, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorVariance);
        return new VariableLaplaceRegressionPrior(priorVariance, noninformativeIntercept);
    }

    /**
     * Returns a zero-mean Laplace prior with a separate variance per
     * dimension. The array is used directly, not copied.
     *
     * @param priorVariances per-dimension variances
     * @return the Laplace prior
     * @throws IllegalArgumentException if any variance is negative, NaN or
     *     negative infinity
     */
    public static RegressionPrior laplace(double[] priorVariances) {
        RegressionPrior.verifyPriorVariances(priorVariances);
        return new LaplaceRegressionPrior(priorVariances);
    }

    /**
     * Returns a zero-centered Cauchy prior with the same squared scale in
     * every dimension.
     *
     * @param priorSquaredScale square of the Cauchy scale parameter
     * @param noninformativeIntercept if {@code true}, dimension 0 gets no prior
     * @return the Cauchy prior
     * @throws IllegalArgumentException if the squared scale is negative, NaN
     *     or negative infinity
     */
    public static RegressionPrior cauchy(double priorSquaredScale, boolean noninformativeIntercept) {
        RegressionPrior.verifyPriorVariance(priorSquaredScale);
        return new VariableCauchyRegressionPrior(priorSquaredScale, noninformativeIntercept);
    }

    /**
     * Returns a zero-centered Cauchy prior with a separate squared scale per
     * dimension. The array is used directly, not copied.
     *
     * @param priorSquaredScales per-dimension squared scales
     * @return the Cauchy prior
     * @throws IllegalArgumentException if any squared scale is negative, NaN
     *     or negative infinity
     */
    public static RegressionPrior cauchy(double[] priorSquaredScales) {
        RegressionPrior.verifyPriorVariances(priorSquaredScales);
        return new CauchyRegressionPrior(priorSquaredScales);
    }

    /**
     * Returns the log-linear interpolation of two priors. Both the log2 prior
     * and the gradient are {@code alpha * prior1 + (1 - alpha) * prior2}.
     *
     * @param alpha weight of {@code prior1}, in [0, 1]
     * @param prior1 first prior
     * @param prior2 second prior
     * @return the interpolated prior
     * @throws IllegalArgumentException if {@code alpha} is NaN or outside [0, 1]
     */
    public static RegressionPrior logInterpolated(double alpha, RegressionPrior prior1, RegressionPrior prior2) {
        if (Double.isNaN(alpha) || alpha < 0.0 || alpha > 1.0) {
            String msg = "Weight of first prior must be between 0 and 1 inclusive. Found alpha=" + alpha;
            throw new IllegalArgumentException(msg);
        }
        return new LogInterpolatedRegressionPrior(alpha, prior1, prior2);
    }

    /**
     * Returns an elastic-net style prior: a log-interpolation of a Laplace
     * prior (weight {@code laplaceWeight}) and a Gaussian prior (weight
     * {@code 1 - laplaceWeight}). The Laplace component has variance
     * {@code 1 / sqrt(scale)} and the Gaussian component has variance
     * {@code sqrt(2) / scale}.
     *
     * @param laplaceWeight weight of the Laplace component, in [0, 1]
     * @param scale finite positive scale parameter
     * @param noninformativeIntercept if {@code true}, dimension 0 gets no prior
     * @return the combined prior
     * @throws IllegalArgumentException if {@code scale} is not finite and
     *     positive, or {@code laplaceWeight} is outside [0, 1]
     */
    public static RegressionPrior elasticNet(double laplaceWeight, double scale, boolean noninformativeIntercept) {
        if (Double.isInfinite(scale) || !(scale > 0.0)) {
            String msg = "Scale parameter must be finite and positive. Found scale=" + scale;
            throw new IllegalArgumentException(msg);
        }
        return RegressionPrior.logInterpolated(laplaceWeight, RegressionPrior.laplace(1.0 / java.lang.Math.sqrt(scale), noninformativeIntercept), RegressionPrior.gaussian(SQRT_2 / scale, noninformativeIntercept));
    }

    /**
     * Returns a prior that evaluates {@code prior} at {@code beta - shifts[i]}
     * in dimension {@code i}, i.e. {@code prior} re-centered on {@code shifts}.
     * The array is used directly, not copied.
     *
     * @param shifts per-dimension offsets
     * @param prior prior being shifted
     * @return the shifted prior
     */
    public static RegressionPrior shiftMeans(double[] shifts, RegressionPrior prior) {
        return new ShiftMeans(shifts, prior);
    }

    /**
     * Checks that a single prior variance (or squared scale) is usable.
     * Zero and positive infinity are accepted.
     *
     * @param priorVariance value to check
     * @throws IllegalArgumentException if the value is negative, NaN or
     *     negative infinity
     */
    static void verifyPriorVariance(double priorVariance) {
        if (priorVariance < 0.0 || Double.isNaN(priorVariance) || priorVariance == Double.NEGATIVE_INFINITY) {
            String msg = "Prior variance must be a non-negative number. Found priorVariance=" + priorVariance;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Checks every entry of a per-dimension variance (or squared scale) array.
     *
     * @param priorVariances values to check
     * @throws IllegalArgumentException naming the first entry that is
     *     negative, NaN or negative infinity
     */
    static void verifyPriorVariances(double[] priorVariances) {
        for (int i = 0; i < priorVariances.length; ++i) {
            if (!(priorVariances[i] < 0.0) && !Double.isNaN(priorVariances[i]) && priorVariances[i] != Double.NEGATIVE_INFINITY) continue;
            String msg = "Prior variances must be non-negative numbers. Found priorVariances[" + i + "]=" + priorVariances[i];
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Wraps a prior and re-centers it: a coefficient {@code betaI} in
     * dimension {@code i} is evaluated as {@code betaI - means[i]} under the
     * wrapped prior.
     */
    static class ShiftMeans
    extends RegressionPrior {
        static final long serialVersionUID = 5159543505446681732L;
        private final double[] mMeans;
        private final RegressionPrior mPrior;

        ShiftMeans(double[] means, RegressionPrior prior) {
            this.mPrior = prior;
            this.mMeans = means;
        }

        @Override
        public double mode(int i) {
            return this.mMeans[i] + this.mPrior.mode(i);
        }

        @Override
        public boolean isUniform() {
            return this.mPrior.isUniform();
        }

        @Override
        public double log2Prior(double betaI, int i) {
            return this.mPrior.log2Prior(betaI - this.mMeans[i], i);
        }

        @Override
        public double gradient(double betaI, int i) {
            return this.mPrior.gradient(betaI - this.mMeans[i], i);
        }

        // Forward the check so a wrapped fixed-dimension prior still rejects
        // mismatched vectors with an IllegalArgumentException.
        @Override
        void verifyNumberOfDimensions(int numDimensions) {
            this.mPrior.verifyNumberOfDimensions(numDimensions);
        }

        @Override
        public String toString() {
            return "ShiftMeans(means=...,prior=" + this.mPrior + ")";
        }

        // Serialize through the Serializer proxy, like the other priors in
        // this file. Without this method the proxy below was never used.
        Object writeReplace() {
            return new Serializer(this);
        }

        static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = -777157399350907424L;
            final ShiftMeans mPrior;

            public Serializer() {
                this(null);
            }

            public Serializer(ShiftMeans prior) {
                this.mPrior = prior;
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                Serializer.writeDoubles(this.mPrior.mMeans, out);
                out.writeObject(this.mPrior.mPrior);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double[] means = Serializer.readDoubles(in);
                RegressionPrior prior = (RegressionPrior)in.readObject();
                return new ShiftMeans(means, prior);
            }
        }
    }

    /**
     * Log-linear interpolation of two priors. The log2 prior and gradient are
     * {@code alpha * prior1 + (1 - alpha) * prior2}.
     */
    static class LogInterpolatedRegressionPrior
    extends RegressionPrior {
        static final long serialVersionUID = 1052451778773339516L;
        private final double mAlpha;
        private final RegressionPrior mPrior1;
        private final RegressionPrior mPrior2;

        LogInterpolatedRegressionPrior(double alpha, RegressionPrior prior1, RegressionPrior prior2) {
            this.mAlpha = alpha;
            this.mPrior1 = prior1;
            this.mPrior2 = prior2;
        }

        @Override
        public double gradient(double beta, int dimension) {
            return this.mAlpha * this.mPrior1.gradient(beta, dimension) + (1.0 - this.mAlpha) * this.mPrior2.gradient(beta, dimension);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return this.mAlpha * this.mPrior1.log2Prior(beta, dimension) + (1.0 - this.mAlpha) * this.mPrior2.log2Prior(beta, dimension);
        }

        // A weighted sum of two flat log priors is itself flat.
        @Override
        public boolean isUniform() {
            return this.mPrior1.isUniform() && this.mPrior2.isUniform();
        }

        // Both components are evaluated on the same vector, so both must accept it.
        @Override
        void verifyNumberOfDimensions(int numDimensions) {
            this.mPrior1.verifyNumberOfDimensions(numDimensions);
            this.mPrior2.verifyNumberOfDimensions(numDimensions);
        }

        @Override
        public String toString() {
            return "LogInterpolatedRegressionPrior(alpha=" + this.mAlpha + ", prior1=" + this.mPrior1 + ", prior2=" + this.mPrior2 + ")";
        }

        Object writeReplace() {
            return new Serializer(this);
        }

        static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = 1071183663202516816L;
            final LogInterpolatedRegressionPrior mPrior;

            public Serializer() {
                this(null);
            }

            public Serializer(LogInterpolatedRegressionPrior prior) {
                this.mPrior = prior;
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeDouble(this.mPrior.mAlpha);
                out.writeObject(this.mPrior.mPrior1);
                out.writeObject(this.mPrior.mPrior2);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double alpha = in.readDouble();
                RegressionPrior prior1 = (RegressionPrior)in.readObject();
                RegressionPrior prior2 = (RegressionPrior)in.readObject();
                return new LogInterpolatedRegressionPrior(alpha, prior1, prior2);
            }
        }
    }

    /**
     * Zero-centered Cauchy prior with one squared scale {@code s2} for all
     * dimensions. The density is {@code (1/pi) * sqrt(s2) / (beta^2 + s2)}.
     */
    static class VariableCauchyRegressionPrior
    extends VariableRegressionPrior {
        static final long serialVersionUID = 3368658136325392652L;

        VariableCauchyRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return dimension == 0 && this.mNoninformativeIntercept ? 0.0 : 2.0 * beta / (beta * beta + this.mPriorVariance);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            if (dimension == 0 && this.mNoninformativeIntercept) {
                return 0.0;
            }
            return log21OverPi + 0.5 * Math.log2(this.mPriorVariance) - Math.log2(beta * beta + this.mPriorVariance);
        }

        @Override
        public String toString() {
            return this.toString("CauchyRegressionPrior", "Scale");
        }

        public Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = -7209096281888148303L;
            final VariableCauchyRegressionPrior mPrior;

            public Serializer(VariableCauchyRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeDouble(this.mPrior.mPriorVariance);
                out.writeBoolean(this.mPrior.mNoninformativeIntercept);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double priorScale = in.readDouble();
                boolean noninformativeIntercept = in.readBoolean();
                return new VariableCauchyRegressionPrior(priorScale, noninformativeIntercept);
            }
        }
    }

    /**
     * Zero-mean Laplace prior with one variance {@code v} for all dimensions.
     * The density is {@code (1/sqrt(2v)) * exp(-sqrt(2/v) * |beta|)}.
     */
    static class VariableLaplaceRegressionPrior
    extends VariableRegressionPrior
    implements Serializable {
        static final long serialVersionUID = -4286001162222250623L;
        final double mPositiveGradient;
        final double mNegativeGradient;
        // log2 of the normalizer 1/sqrt(2v).
        final double mPriorIntercept;
        // Natural-log coefficient of |beta| in the exponent: -sqrt(2/v).
        final double mPriorCoefficient;

        VariableLaplaceRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
            this.mPositiveGradient = java.lang.Math.sqrt(2.0 / priorVariance);
            this.mNegativeGradient = -this.mPositiveGradient;
            this.mPriorIntercept = log2Sqrt2Over2 - 0.5 * Math.log2(priorVariance);
            this.mPriorCoefficient = -sqrt2 / java.lang.Math.sqrt(priorVariance);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return dimension == 0 && this.mNoninformativeIntercept || beta == 0.0 ? 0.0 : (beta > 0.0 ? this.mPositiveGradient : this.mNegativeGradient);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            if (dimension == 0 && this.mNoninformativeIntercept) {
                return 0.0;
            }
            // The exponent is in natural-log units; convert it to base 2 so it
            // matches the base-2 normalizer.
            return this.mPriorIntercept + log2E * this.mPriorCoefficient * java.lang.Math.abs(beta);
        }

        @Override
        public String toString() {
            return this.toString("LaplaceRegressionPrior", "Variance");
        }

        private Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = 2321796089407881776L;
            final VariableLaplaceRegressionPrior mPrior;

            public Serializer(VariableLaplaceRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeDouble(this.mPrior.mPriorVariance);
                out.writeBoolean(this.mPrior.mNoninformativeIntercept);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double priorVariance = in.readDouble();
                boolean noninformativeIntercept = in.readBoolean();
                return new VariableLaplaceRegressionPrior(priorVariance, noninformativeIntercept);
            }
        }
    }

    /**
     * Zero-mean Gaussian prior with one variance {@code v} for all dimensions.
     */
    static class VariableGaussianRegressionPrior
    extends VariableRegressionPrior
    implements Serializable {
        static final long serialVersionUID = -7527207309328127863L;

        VariableGaussianRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            super(priorVariance, noninformativeIntercept);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return dimension == 0 && this.mNoninformativeIntercept ? 0.0 : beta / this.mPriorVariance;
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            if (dimension == 0 && this.mNoninformativeIntercept) {
                return 0.0;
            }
            // -beta^2/(2v) is the natural-log exponent; log2E converts it to base 2.
            return -log2Sqrt2Pi - 0.5 * Math.log2(this.mPriorVariance) - log2E * beta * beta / (2.0 * this.mPriorVariance);
        }

        @Override
        public String toString() {
            return this.toString("GaussianRegressionPrior", "Variance");
        }

        private Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = 5979483825025936160L;
            final VariableGaussianRegressionPrior mPrior;

            public Serializer(VariableGaussianRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeDouble(this.mPrior.mPriorVariance);
                out.writeBoolean(this.mPrior.mNoninformativeIntercept);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double priorVariance = in.readDouble();
                boolean noninformativeIntercept = in.readBoolean();
                return new VariableGaussianRegressionPrior(priorVariance, noninformativeIntercept);
            }
        }
    }

    /**
     * Base for priors sharing a single parameter across all dimensions. When
     * {@code mNoninformativeIntercept} is set, subclasses treat dimension 0 as
     * flat.
     */
    static abstract class VariableRegressionPrior
    extends RegressionPrior {
        static final long serialVersionUID = -7527207309328127863L;
        final double mPriorVariance;
        final boolean mNoninformativeIntercept;

        VariableRegressionPrior(double priorVariance, boolean noninformativeIntercept) {
            this.mPriorVariance = priorVariance;
            this.mNoninformativeIntercept = noninformativeIntercept;
        }

        public String toString(String priorName, String paramName) {
            return priorName + "(" + paramName + "=" + this.mPriorVariance + ", noninformativeIntercept=" + this.mNoninformativeIntercept + ")";
        }
    }

    /**
     * Zero-centered Cauchy prior with a squared scale {@code s2[d]} per
     * dimension. The density is {@code (1/pi) * sqrt(s2) / (beta^2 + s2)}.
     */
    static class CauchyRegressionPrior
    extends ArrayRegressionPrior
    implements Serializable {
        static final long serialVersionUID = 2351846943518745614L;

        CauchyRegressionPrior(double[] priorSquaredScales) {
            super(priorSquaredScales);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return 2.0 * beta / (beta * beta + this.mValues[dimension]);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            // mValues already holds squared scales, so it is added as-is.
            // Squaring it again would disagree with gradient() and with the
            // 0.5 * log2(s2) normalizer.
            return log21OverPi + 0.5 * Math.log2(this.mValues[dimension]) - Math.log2(beta * beta + this.mValues[dimension]);
        }

        @Override
        public String toString() {
            return this.toString("CauchyRegressionPrior", "Scale");
        }

        private Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = 5202676106810759907L;
            final CauchyRegressionPrior mPrior;

            public Serializer(CauchyRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeInt(this.mPrior.mValues.length);
                for (int i = 0; i < this.mPrior.mValues.length; ++i) {
                    out.writeDouble(this.mPrior.mValues[i]);
                }
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                int numDimensions = in.readInt();
                double[] priorScales = new double[numDimensions];
                for (int i = 0; i < numDimensions; ++i) {
                    priorScales[i] = in.readDouble();
                }
                return new CauchyRegressionPrior(priorScales);
            }
        }
    }

    /**
     * Zero-mean Laplace prior with a variance {@code v[d]} per dimension.
     */
    static class LaplaceRegressionPrior
    extends ArrayRegressionPrior
    implements Serializable {
        static final long serialVersionUID = 9120480132502062861L;

        LaplaceRegressionPrior(double[] priorVariances) {
            super(priorVariances);
        }

        @Override
        public double gradient(double beta, int dimension) {
            if (beta == 0.0) {
                return 0.0;
            }
            if (beta > 0.0) {
                return java.lang.Math.sqrt(2.0 / this.mValues[dimension]);
            }
            return -java.lang.Math.sqrt(2.0 / this.mValues[dimension]);
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            // The exponent sqrt(2/v)*|beta| is in natural-log units; log2E converts it.
            return log2Sqrt2Over2 - 0.5 * Math.log2(this.mValues[dimension]) - log2E * sqrt2 * java.lang.Math.abs(beta) / java.lang.Math.sqrt(this.mValues[dimension]);
        }

        @Override
        public String toString() {
            return this.toString("LaplaceRegressionPrior", "Variance");
        }

        private Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = 7844951573062416091L;
            final LaplaceRegressionPrior mPrior;

            public Serializer(LaplaceRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeInt(this.mPrior.mValues.length);
                for (int i = 0; i < this.mPrior.mValues.length; ++i) {
                    out.writeDouble(this.mPrior.mValues[i]);
                }
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                int numDimensions = in.readInt();
                double[] priorVariances = new double[numDimensions];
                for (int i = 0; i < numDimensions; ++i) {
                    priorVariances[i] = in.readDouble();
                }
                return new LaplaceRegressionPrior(priorVariances);
            }
        }
    }

    /**
     * Zero-mean Gaussian prior with a variance {@code v[d]} per dimension.
     */
    static class GaussianRegressionPrior
    extends ArrayRegressionPrior
    implements Serializable {
        static final long serialVersionUID = 8257747607648390037L;

        GaussianRegressionPrior(double[] priorVariances) {
            super(priorVariances);
        }

        @Override
        public double gradient(double beta, int dimension) {
            return beta / this.mValues[dimension];
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            // -beta^2/(2v) is the natural-log exponent; log2E converts it to base 2.
            return -log2Sqrt2Pi - 0.5 * Math.log2(this.mValues[dimension]) - log2E * beta * beta / (2.0 * this.mValues[dimension]);
        }

        @Override
        public String toString() {
            return this.toString("GaussianRegressionPrior", "Variance");
        }

        private Object writeReplace() {
            return new Serializer(this);
        }

        private static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = -1129377549371296060L;
            final GaussianRegressionPrior mPrior;

            public Serializer(GaussianRegressionPrior prior) {
                this.mPrior = prior;
            }

            public Serializer() {
                this(null);
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeInt(this.mPrior.mValues.length);
                for (int i = 0; i < this.mPrior.mValues.length; ++i) {
                    out.writeDouble(this.mPrior.mValues[i]);
                }
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                int numDimensions = in.readInt();
                double[] priorVariances = new double[numDimensions];
                for (int i = 0; i < numDimensions; ++i) {
                    priorVariances[i] = in.readDouble();
                }
                return new GaussianRegressionPrior(priorVariances);
            }
        }
    }

    /**
     * Base for priors with one parameter per dimension, held in
     * {@code mValues}. Vectors must have exactly {@code mValues.length}
     * dimensions.
     */
    static abstract class ArrayRegressionPrior
    extends RegressionPrior {
        static final long serialVersionUID = -1887383164794837169L;
        final double[] mValues;

        ArrayRegressionPrior(double[] values) {
            this.mValues = values;
        }

        @Override
        void verifyNumberOfDimensions(int numDimensions) {
            if (this.mValues.length != numDimensions) {
                String msg = "Prior and instances must match in number of dimensions. Found prior numDimensions=" + this.mValues.length + " instance numDimensions=" + numDimensions;
                throw new IllegalArgumentException(msg);
            }
        }

        /**
         * Renders the prior name, its dimensionality and one line per
         * parameter value.
         */
        public String toString(String priorName, String paramName) {
            StringBuilder sb = new StringBuilder();
            sb.append(priorName).append('\n');
            // Terminate this line so the first parameter line starts on its own line.
            sb.append("     dimensionality=").append(this.mValues.length).append('\n');
            for (int i = 0; i < this.mValues.length; ++i) {
                sb.append("     ").append(paramName).append('[').append(i).append("]=")
                  .append(this.mValues[i]).append('\n');
            }
            return sb.toString();
        }
    }

    /**
     * Flat prior: every log2 prior and gradient is zero, for single
     * coefficients and whole vectors alike.
     */
    static class NoninformativeRegressionPrior
    extends RegressionPrior
    implements Serializable {
        static final long serialVersionUID = -582012445093979284L;

        NoninformativeRegressionPrior() {
        }

        @Override
        public double gradient(double beta, int dimension) {
            return 0.0;
        }

        @Override
        public double log2Prior(double beta, int dimension) {
            return 0.0;
        }

        @Override
        public double log2Prior(Vector beta) {
            return 0.0;
        }

        @Override
        public double log2Prior(Vector[] betas) {
            return 0.0;
        }

        @Override
        public String toString() {
            return "NoninformativeRegressionPrior";
        }

        @Override
        public boolean isUniform() {
            return true;
        }
    }
}

