package edu.stanford.nlp.CLoptimization;

import edu.stanford.nlp.CLmath.ArrayMath;
import edu.stanford.nlp.CLoptimization.AbstractStochasticCachingDiffFunction;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/CLoptimization/StochasticDiffFunctionTester.class */
/**
 * Diagnostic harness for an {@link AbstractStochasticCachingDiffFunction}.
 *
 * <p>It splits the function's data into equally sized batches and offers checks that
 * compare per-batch results against full-data results ({@link #testSumOfBatches}),
 * compare Hessian-vector products from the configured calculation method against
 * {@code ExternalFiniteDifference} ({@link #testDerivatives}), sample {@code v.Hv}
 * along random directions ({@link #testConditionNumber}), and measure how batch
 * estimates scatter around the full-data result ({@link #getVariance(double[], int)}).
 *
 * <p>All progress and diagnostics go to {@code System.err}. Several failure paths call
 * {@code System.exit(1)} rather than throwing. The tester mutates public state of the
 * function under test ({@code sampleMethod}, {@code method}, {@code returnPreviousValues},
 * {@code recalculatePrevBatch}).
 *
 * <p>NOTE(review): the field {@code v} is never assigned inside this class but is passed
 * to the function by {@link #testSumOfBatches} and {@link #testDerivatives}; presumably
 * code in the same package sets it beforehand - confirm.
 */
public class StochasticDiffFunctionTester {
    /** Number of random batches drawn per batch size in {@link #getVariance(double[], int)}. */
    private static final int VARIANCE_SAMPLES = 100;

    /** Batch size used by the batch-wise tests; chosen so that it divides the data size. */
    protected int testBatchSize;
    /** Number of batches of {@link #testBatchSize} needed to cover the whole data set. */
    protected int numBatches;
    /** The function under test. */
    protected AbstractStochasticCachingDiffFunction thisFunc;
    double[] approxGrad;
    double[] fullGrad;
    double[] diff;
    double[] Hv;
    double[] HvFD;
    double[] v;
    double[] curGrad;
    double[] gradFD;
    double diffNorm;
    double diffValue;
    double fullValue;
    double approxValue;
    double diffGrad;
    double maxGradDiff = 0.0d;
    double maxHvDiff = 0.0d;
    Random generator;
    private static double EPS = 1.0E-8d;
    private static boolean quiet = false;
    private static NumberFormat nf = new DecimalFormat("00.0");

    /**
     * Creates a tester for the given function and derives the test batch size from the
     * function's data dimension.
     *
     * <p>Terminates the JVM with exit code 1 if no valid batch size can be derived.
     *
     * @param function the function to test; must be an {@link AbstractStochasticCachingDiffFunction}
     * @throws UnsupportedOperationException if {@code function} is not stochastic
     */
    public StochasticDiffFunctionTester(Function function) {
        if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
            String message = "Attempt to test non stochastic function using StochasticDiffFunctionTester";
            System.err.println(message);
            throw new UnsupportedOperationException(message);
        }
        this.thisFunc = (AbstractStochasticCachingDiffFunction) function;
        this.generator = new Random(System.currentTimeMillis());
        this.testBatchSize = (int) getTestBatchSize(this.thisFunc.dataDimension());
        // "<= 0" (rather than "< 0") keeps the modulo below from ever dividing by zero.
        if (this.testBatchSize <= 0 || this.testBatchSize > this.thisFunc.dataDimension() || this.thisFunc.dataDimension() % this.testBatchSize != 0) {
            System.err.println("Invalid testBatchSize found, testing aborted.  Data size: " + this.thisFunc.dataDimension() + " batchSize: " + this.testBatchSize);
            System.exit(1);
        }
        this.numBatches = this.thisFunc.dataDimension() / this.testBatchSize;
        sayln("StochasticDiffFunctionTester created with:");
        sayln("   data dimension  = " + this.thisFunc.dataDimension());
        sayln("   batch size = " + this.testBatchSize);
        sayln("   number of batches = " + this.numBatches);
    }

    /** Prints a line to {@code System.err} unless the class-wide {@code quiet} flag is set. */
    private void sayln(String str) {
        if (!quiet) {
            System.err.println(str);
        }
    }

    /**
     * Factorises {@code |j|} by trial division.
     *
     * @param j the number to factor
     * @return an array of length 64 whose element 0 holds the number of prime factors
     *         (with multiplicity) and whose elements 1..count hold the factors in
     *         non-decreasing order; the count is 0 when {@code Math.abs(j)} is not
     *         positive (i.e. for 0 and for {@code Long.MIN_VALUE}) and for 1
     */
    private static long[] primeFactors(long j) {
        long[] factors = new long[64];
        long remaining = Math.abs(j);
        int count = 0;
        if (remaining > 0) {
            while (remaining % 2 == 0) {
                factors[++count] = 2;
                remaining /= 2;
            }
            while (remaining % 3 == 0) {
                factors[++count] = 3;
                remaining /= 3;
            }
            // Candidates are 6k-1 and 6k+1. The divisor is a long and the bound is written
            // as base <= remaining / base so that neither the counter nor its square can
            // overflow for large inputs.
            for (long base = 5; base <= remaining / base; base += 6) {
                for (long candidate = base; candidate <= base + 2; candidate += 2) {
                    while (remaining % candidate == 0) {
                        factors[++count] = candidate;
                        remaining /= candidate;
                    }
                }
            }
            if (remaining > 1) {
                // Whatever is left after trial division is itself prime.
                factors[++count] = remaining;
            }
        }
        factors[0] = count;
        return factors;
    }

    /**
     * Picks a batch size that divides {@code j}: the product of all prime factors of
     * {@code j} except the largest one.
     *
     * <p>If {@code j} has fewer than two prime factors (it is prime, 0 or 1) the only
     * possible batch size would be 1, so an error is printed and the JVM exits with code 1.
     *
     * @param j the data size
     * @return the chosen batch size
     */
    private static long getTestBatchSize(long j) {
        long[] factors = primeFactors(j);
        long factorCount = factors[0];
        if (factorCount <= 1) {
            System.err.println("Attempt to test function on data of prime dimension.  This would involve a batchSize of 1 and may take a very long time.");
            System.exit(1);
        }
        long batchSize = 1;
        // Factors are sorted ascending, so stopping before the last one drops the largest.
        for (int f = 1; f < factorCount; f++) {
            batchSize *= factors[f];
        }
        return batchSize;
    }

    /**
     * Checks that the values and gradients of all batches, taken in order, add up to the
     * full-data value and gradient at {@code x}.
     *
     * <p>The function's {@code sampleMethod} is switched to {@code Ordered} for the
     * duration of the test; {@code sampleMethod} and {@code method} are restored before
     * returning, even if the function throws.
     *
     * @param x the point at which to evaluate the function
     * @param functionTolerance bound on the infinity norm of the gradient difference and
     *        on the absolute value difference
     * @return {@code true} only if both the gradient check and the value check pass
     */
    public boolean testSumOfBatches(double[] x, double functionTolerance) {
        System.err.println("Making sure that the sum of stochastic gradients equals the full gradient");
        AbstractStochasticCachingDiffFunction.SamplingMethod savedSampling = this.thisFunc.sampleMethod;
        StochasticCalculateMethods savedMethod = this.thisFunc.method;
        try {
            this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
            if (this.thisFunc.method == StochasticCalculateMethods.NoneSpecified) {
                System.err.println("No calculate method has been specified");
            }
            this.approxValue = 0.0d;
            this.approxGrad = new double[x.length];
            this.curGrad = new double[x.length];
            this.fullGrad = new double[x.length];

            // Accumulate value and gradient batch by batch.
            for (int batch = 0; batch < this.numBatches; batch++) {
                this.approxValue += this.thisFunc.valueAt(x, this.v, this.testBatchSize);
                this.thisFunc.returnPreviousValues = true;
                System.arraycopy(this.thisFunc.derivativeAt(x, this.v, this.testBatchSize), 0, this.curGrad, 0, this.curGrad.length);
                this.approxGrad = ArrayMath.pairwiseAdd(this.approxGrad, this.curGrad);
                System.err.printf("%5.1f percent complete  %6.2f \n", Double.valueOf((100.0d * batch) / this.numBatches), Double.valueOf(ArrayMath.norm(this.approxGrad)));
            }

            System.err.println("About to calculate the full derivative and value");
            System.arraycopy(this.thisFunc.derivativeAt(x), 0, this.fullGrad, 0, this.fullGrad.length);
            this.thisFunc.returnPreviousValues = true;
            this.fullValue = this.thisFunc.valueAt(x);

            this.diff = ArrayMath.pairwiseSubtract(this.fullGrad, this.approxGrad);
            boolean gradientsMatch = ArrayMath.norm_inf(this.diff) < functionTolerance;
            sayln("");
            if (gradientsMatch) {
                sayln("Success: sum of batch gradients equals full gradient");
            } else {
                this.diffNorm = ArrayMath.norm(this.diff);
                sayln("Failure: sum of batch gradients minus full gradient has norm " + this.diffNorm);
            }

            double valueGap = Math.abs(this.approxValue - this.fullValue);
            boolean valuesMatch = valueGap < functionTolerance;
            sayln("");
            if (valuesMatch) {
                sayln("Success: sum of batch values equals full value");
            } else {
                sayln("Failure: sum of batch values minus full value has norm " + valueGap);
            }

            // A gradient mismatch must fail the test even when the values agree.
            return gradientsMatch && valuesMatch;
        } finally {
            this.thisFunc.sampleMethod = savedSampling;
            this.thisFunc.method = savedMethod;
        }
    }

    /**
     * Compares, batch by batch, the Hessian-vector product produced by the function's
     * configured calculation method with the one produced by
     * {@code ExternalFiniteDifference}, tracking the largest infinity-norm difference
     * in {@code maxHvDiff}.
     *
     * <p>{@code sampleMethod} is switched to {@code Ordered} during the test;
     * {@code sampleMethod} and {@code method} are restored before returning, even if
     * the function throws.
     *
     * @param x the point at which to evaluate the function
     * @param functionTolerance bound on the largest observed difference
     * @return {@code true} if the largest difference over all batches is below the tolerance
     */
    public boolean testDerivatives(double[] x, double functionTolerance) {
        System.err.println("Making sure that the stochastic derivatives are ok.");
        AbstractStochasticCachingDiffFunction.SamplingMethod savedSampling = this.thisFunc.sampleMethod;
        StochasticCalculateMethods savedMethod = this.thisFunc.method;
        try {
            this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
            if (this.thisFunc.method == StochasticCalculateMethods.NoneSpecified) {
                System.err.println("No calculate method has been specified");
            }
            this.approxValue = 0.0d;
            this.approxGrad = new double[x.length];
            this.curGrad = new double[x.length];
            this.Hv = new double[x.length];
            // These buffers are otherwise only allocated by testConditionNumber; without
            // allocating them here a fresh tester would fail with a NullPointerException.
            this.gradFD = new double[x.length];
            this.HvFD = new double[x.length];
            // Start from zero so the verdict reflects this call only, not earlier ones.
            this.maxHvDiff = 0.0d;

            for (int batch = 0; batch < this.numBatches; batch++) {
                System.err.printf("%5.1f percent complete\n", Double.valueOf((100.0d * batch) / this.numBatches));

                // Product according to the method the caller configured.
                this.thisFunc.method = savedMethod;
                System.arraycopy(this.thisFunc.HdotVAt(x, this.v, this.testBatchSize), 0, this.Hv, 0, this.Hv.length);

                // Reference product from external finite differences on the same batch.
                this.thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;
                System.arraycopy(this.thisFunc.derivativeAt(x, this.v, this.testBatchSize), 0, this.gradFD, 0, this.gradFD.length);
                this.thisFunc.recalculatePrevBatch = true;
                System.arraycopy(this.thisFunc.HdotVAt(x, this.v, this.gradFD, this.testBatchSize), 0, this.HvFD, 0, this.HvFD.length);

                double batchDiff = ArrayMath.norm_inf(ArrayMath.pairwiseSubtract(this.Hv, this.HvFD));
                if (batchDiff > this.maxHvDiff) {
                    this.maxHvDiff = batchDiff;
                }
            }

            boolean passed = this.maxHvDiff < functionTolerance;
            sayln("");
            if (passed) {
                sayln("Success: Hessian approximations lined up");
            } else {
                sayln("Failure: Hessian approximation at somepoint was off by " + this.maxHvDiff);
            }
            return passed;
        } finally {
            this.thisFunc.sampleMethod = savedSampling;
            this.thisFunc.method = savedMethod;
        }
    }

    /**
     * Samples {@code v . Hv} for random points and random directions (components drawn
     * with {@code Random.nextDouble()}), using {@code ExternalFiniteDifference} for the
     * Hessian-vector product, and reports the ratio of the largest to the smallest
     * absolute value observed.
     *
     * <p>Also prints whether negative, positive and exactly-zero products were seen.
     * The function's {@code method} is left set to {@code ExternalFiniteDifference}.
     *
     * @param samples number of random point/direction pairs to evaluate
     * @return largest observed {@code |v . Hv|} divided by the smallest observed one;
     *         infinite if a product of exactly zero was seen
     */
    public double testConditionNumber(int samples) {
        double maxSeen = 0.0d;
        // Must start above every possible sample, otherwise the minimum can never be updated.
        double minSeen = Double.POSITIVE_INFINITY;
        double[] direction = new double[this.thisFunc.domainDimension()];
        double[] point = new double[direction.length];
        this.gradFD = new double[direction.length];
        this.HvFD = new double[direction.length];
        boolean sawNegative = false;
        boolean sawPositive = false;
        boolean sawZero = false;
        this.thisFunc.method = StochasticCalculateMethods.ExternalFiniteDifference;
        for (int sample = 0; sample < samples; sample++) {
            for (int k = 0; k < direction.length; k++) {
                direction[k] = this.generator.nextDouble();
            }
            for (int k = 0; k < point.length; k++) {
                point[k] = this.generator.nextDouble();
            }
            System.err.println("Evaluating Hessian Product");
            System.arraycopy(this.thisFunc.derivativeAt(point, direction, this.testBatchSize), 0, this.gradFD, 0, this.gradFD.length);
            this.thisFunc.recalculatePrevBatch = true;
            System.arraycopy(this.thisFunc.HdotVAt(point, direction, this.gradFD, this.testBatchSize), 0, this.HvFD, 0, this.HvFD.length);

            double vHv = ArrayMath.innerProduct(direction, this.HvFD);
            double magnitude = Math.abs(vHv);
            if (magnitude > maxSeen) {
                maxSeen = magnitude;
            }
            if (magnitude < minSeen) {
                minSeen = magnitude;
            }
            if (vHv < 0.0d) {
                sawNegative = true;
            }
            if (vHv > 0.0d) {
                sawPositive = true;
            }
            if (vHv == 0.0d) {
                sawZero = true;
            }
            System.err.println("It:" + sample + "  C:" + (maxSeen / minSeen) + "N:" + sawNegative + "P:" + sawPositive + "S:" + sawZero);
        }
        System.out.println("Condition Number of: " + (maxSeen / minSeen));
        System.out.println("Is negative: " + sawNegative);
        System.out.println("Is positive: " + sawPositive);
        System.out.println("Is semi:     " + sawZero);
        return maxSeen / minSeen;
    }

    /**
     * Runs {@link #getVariance(double[], int)} with this tester's own batch size.
     *
     * @param x the point at which to evaluate the function
     * @return the four statistics described on {@link #getVariance(double[], int)}
     */
    public double[] getVariance(double[] x) {
        return getVariance(x, this.testBatchSize);
    }

    /**
     * Measures how the result of {@code HdotVAt(x, x, gradient, batchSize)} on randomly
     * sampled batches scatters around the same quantity computed on the full data set.
     *
     * <p>Each of the 100 batch results is scaled by {@code dataDimension / batchSize}
     * and compared with the full result through (a) their inner product divided by the
     * product of their norms and (b) the ratio of their norms. Means and sample
     * variances of both quantities are accumulated incrementally.
     *
     * <p>The function's {@code sampleMethod} is left set to {@code RandomWithReplacment}.
     *
     * @param x the point at which to evaluate the function (also used as the vector argument)
     * @param batchSize number of data items per sampled batch; must be positive
     * @return {@code {mean of (a), variance of (a), mean of (b), variance of (b)}}
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public double[] getVariance(double[] x, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive but was " + batchSize);
        }
        double[] stats = new double[4];
        double[] fullHx = new double[this.thisFunc.domainDimension()];
        double[] batchHx = new double[x.length];
        double[] gradient = new double[x.length];

        // Reference result on the complete data set.
        this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
        System.arraycopy(this.thisFunc.derivativeAt(x, x, this.thisFunc.dataDimension()), 0, gradient, 0, gradient.length);
        System.arraycopy(this.thisFunc.HdotVAt(x, x, gradient, this.thisFunc.dataDimension()), 0, fullHx, 0, fullHx.length);
        double fullNorm = ArrayMath.norm(fullHx);

        // Floating-point division: an integer quotient would under-scale every batch size
        // that does not divide the data size evenly.
        double scale = (double) this.thisFunc.dataDimension() / batchSize;

        this.thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.RandomWithReplacment;
        double simMean = 0.0d;
        double ratioMean = 0.0d;
        double simSumSq = 0.0d;
        double ratioSumSq = 0.0d;
        int count = 0;
        // Debug output of one component; skipped for vectors too short to have index 4.
        if (fullHx.length > 4 && x.length > 4) {
            System.err.println(fullHx[4] + "  " + x[4]);
        }
        for (int sample = 0; sample < VARIANCE_SAMPLES; sample++) {
            System.arraycopy(this.thisFunc.derivativeAt(x, x, batchSize), 0, gradient, 0, gradient.length);
            System.arraycopy(this.thisFunc.HdotVAt(x, x, gradient, batchSize), 0, batchHx, 0, batchHx.length);
            ArrayMath.multiplyInPlace(batchHx, scale);
            double batchNorm = ArrayMath.norm(batchHx);
            double similarity = ArrayMath.innerProduct(batchHx, fullHx) / (batchNorm * fullNorm);
            double normRatio = batchNorm / fullNorm;

            // Incremental mean / sum-of-squared-deviations update for both statistics.
            count++;
            double simDelta = similarity - simMean;
            simMean += simDelta / count;
            simSumSq += simDelta * (similarity - simMean);
            double ratioDelta = normRatio - ratioMean;
            ratioMean += ratioDelta / count;
            ratioSumSq += ratioDelta * (normRatio - ratioMean);
        }
        stats[0] = simMean;
        stats[1] = simSumSq / (count - 1);
        stats[2] = ratioMean;
        stats[3] = ratioSumSq / (count - 1);
        return stats;
    }

    /**
     * Runs {@link #getVariance(double[], int)} for a fixed list of batch sizes and writes
     * one comma-separated line per batch size to the file {@code var.out} in the working
     * directory, echoing the same numbers to {@code System.err}.
     *
     * <p>If the file cannot be opened an error is printed and the JVM exits with code 1.
     *
     * @param x the point at which to evaluate the function
     */
    public void testVariance(double[] x) {
        int[] batchSizes = {10, 20, 35, 50, 75, 150, 300, 500, 750, 1000, 5000, 10000};
        DecimalFormat fmt = new DecimalFormat("0.000E0");
        // try-with-resources closes the file even if a variance computation throws.
        try (PrintWriter out = new PrintWriter((OutputStream) new FileOutputStream("var.out"), true)) {
            for (int batchSize : batchSizes) {
                double[] variance = getVariance(x, batchSize);
                out.println(batchSize + "," + fmt.format(variance[0]) + "," + fmt.format(variance[1]) + "," + fmt.format(variance[2]) + "," + fmt.format(variance[3]));
                System.err.println("Batch size of: " + batchSize + "   " + variance[0] + "," + fmt.format(variance[1]) + "," + fmt.format(variance[2]) + "," + fmt.format(variance[3]));
            }
        } catch (IOException e) {
            System.err.println("Caught IOException outputing List to file: " + e.getMessage());
            System.exit(1);
        }
    }

    /**
     * Writes each array of the list as one line of values formatted with
     * {@code 0.000E0}, each value followed by two spaces.
     *
     * <p>If the file cannot be opened an error is printed and the JVM exits with code 1.
     *
     * @param list the rows to write
     * @param str path of the output file (overwritten if it exists)
     */
    public void listToFile(List<double[]> list, String str) {
        DecimalFormat fmt = new DecimalFormat("0.000E0");
        try (PrintWriter out = new PrintWriter((OutputStream) new FileOutputStream(str), true)) {
            for (double[] row : list) {
                for (double value : row) {
                    out.print(fmt.format(value) + "  ");
                }
                out.println("");
            }
        } catch (IOException e) {
            System.err.println("Caught IOException outputing List to file: " + e.getMessage());
            System.exit(1);
        }
    }

    /**
     * Writes the array on a single line (no trailing newline), each value formatted with
     * {@code 0.000E0} and followed by two spaces.
     *
     * <p>If the file cannot be opened an error is printed and the JVM exits with code 1.
     *
     * @param dArr the values to write
     * @param str path of the output file (overwritten if it exists)
     */
    public void arrayToFile(double[] dArr, String str) {
        DecimalFormat fmt = new DecimalFormat("0.000E0");
        try (PrintWriter out = new PrintWriter((OutputStream) new FileOutputStream(str), true)) {
            for (double value : dArr) {
                out.print(fmt.format(value) + "  ");
            }
        } catch (IOException e) {
            System.err.println("Caught IOException outputing List to file: " + e.getMessage());
            System.exit(1);
        }
    }
}
