package edu.stanford.nlp.maxent;

import edu.stanford.nlp.maxent.iis.LambdaSolve;
import edu.stanford.nlp.optimization.CGMinimizer;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.ReflectionLoading;
import java.util.Arrays;

/* loaded from: input_file:edu/stanford/nlp/maxent/CGRunner.class */
/**
 * Fits the weights ({@code lambda}) of a {@link LambdaSolve} model by minimizing the model's
 * negative conditional log-likelihood, optionally penalized by a Gaussian prior
 * {@code sum_i lambda_i^2 / (2 * sigma_i^2)}.
 *
 * <p>Three solvers are provided: quasi-Newton ({@link #solveQN()}, also used by {@link #solve()}),
 * conjugate gradient ({@link #solveCG()}), and an L1-regularized minimizer loaded by reflection
 * ({@link #solveL1(double)}). Every solver starts from all-zero weights and stores the optimum in
 * {@code prob.lambda}.
 */
public class CGRunner {
    /** Not read anywhere in this class; presumably a switch for saving weights to {@code filename}. */
    private static final boolean SAVE_LAMBDAS_REGULARLY = false;
    private final LambdaSolve prob;
    /** Handed to the monitor, which stores it but does not currently read it. */
    private final String filename;
    /** Minimizer convergence tolerance; also used to compare weight vectors in derivativeAt. */
    private final double tol;
    private final boolean useGaussianPrior;
    /** Prior variance shared by all weights; -1 when per-weight variances are supplied instead. */
    private final double priorSigmaS;
    /** Optional per-weight prior variances; {@code null} means {@link #priorSigmaS} is used. */
    private final double[] sigmaSquareds;
    private static final double DEFAULT_TOLERANCE = 1.0E-4d;
    private static final double DEFAULT_SIGMASQUARED = 0.5d;

    /**
     * The optimization objective: the model's negative conditional log-likelihood plus, when
     * enabled, a Gaussian penalty on each weight.
     *
     * <p>Evaluating the objective overwrites {@code model.lambda} with the argument (no copy).
     */
    // NOTE: the decompiler reports this class was originally private.
    public static final class LikelihoodFunction implements DiffFunction {
        private final LambdaSolve model;
        private final double tol;
        private final boolean useGaussianPrior;
        /** Per-weight prior variances, or {@code null} when the prior is disabled. */
        private final double[] sigmaSquareds;
        private int valueAtCalls;
        private double likelihood;

        /**
         * @param model the model whose {@code lambda} array is optimized (and overwritten)
         * @param tol tolerance used by {@link #derivativeAt} to decide whether its argument
         *     matches the weights most recently installed in the model
         * @param useGaussianPrior whether to add the Gaussian penalty
         * @param sigmaSquared variance applied to every weight when {@code sigmaSquareds} is null
         * @param sigmaSquareds optional per-weight variances, one entry per weight
         * @throws IllegalArgumentException if the prior is enabled and {@code sigmaSquareds} is
         *     non-null but its length differs from the number of weights
         */
        public LikelihoodFunction(LambdaSolve model, double tol, boolean useGaussianPrior,
                                  double sigmaSquared, double[] sigmaSquareds) {
            this.model = model;
            this.tol = tol;
            this.useGaussianPrior = useGaussianPrior;
            if (!useGaussianPrior) {
                this.sigmaSquareds = null;
            } else {
                int dimension = model.lambda.length;
                double[] variances;
                if (sigmaSquareds == null) {
                    variances = new double[dimension];
                    Arrays.fill(variances, sigmaSquared);
                } else if (sigmaSquareds.length != dimension) {
                    // A shorter array would leave zero variances (division by zero -> Inf/NaN);
                    // a longer one would overflow the copy. Reject both explicitly.
                    throw new IllegalArgumentException("sigmaSquareds has length " + sigmaSquareds.length
                            + " but the model has " + dimension + " weights");
                } else {
                    variances = sigmaSquareds.clone();
                }
                this.sigmaSquareds = variances;
            }
        }

        @Override // edu.stanford.nlp.optimization.Function
        public int domainDimension() {
            return this.model.lambda.length;
        }

        /** Returns the objective value computed by the most recent {@link #valueAt} call. */
        public double likelihood() {
            return this.likelihood;
        }

        /** Returns how many times {@link #valueAt} has been called. */
        public int numCalls() {
            return this.valueAtCalls;
        }

        /**
         * Installs {@code lambda} as the model's weights and returns the (penalized) objective.
         *
         * @param lambda candidate weights; stored by reference in {@code model.lambda}
         * @return negative log-likelihood plus {@code lambda_i^2 / (2 sigma_i^2)} per weight
         *     when the Gaussian prior is enabled
         */
        @Override // edu.stanford.nlp.optimization.Function
        public double valueAt(double[] lambda) {
            this.valueAtCalls++;
            // logLikelihoodScratch() presumably reads model.lambda, so install the candidate first.
            this.model.lambda = lambda;
            double value = this.model.logLikelihoodScratch();
            if (this.useGaussianPrior) {
                for (int i = 0; i < lambda.length; i++) {
                    value += (lambda[i] * lambda[i]) / (2.0d * this.sigmaSquareds[i]);
                }
            }
            this.likelihood = value;
            return value;
        }

        /**
         * Returns the gradient of the objective at {@code lambda}.
         *
         * <p>The model's derivatives are only valid for the weights it currently holds, so if
         * {@code lambda} differs from them (beyond {@link #tol}) the objective is re-evaluated first.
         *
         * @param lambda point at which to evaluate the gradient
         * @return the model's derivative array, with the prior's {@code lambda_i / sigma_i^2} added
         */
        @Override // edu.stanford.nlp.optimization.DiffFunction
        public double[] derivativeAt(double[] lambda) {
            if (!matchesCurrentLambda(lambda)) {
                System.err.println("derivativeAt: call with different value");
                valueAt(lambda);
            }
            // NOTE(review): the prior is added in place to the array returned by getDerivatives();
            // confirm that array is freshly computed per call and not a cache reused elsewhere.
            double[] derivatives = this.model.getDerivatives();
            if (this.useGaussianPrior) {
                for (int i = 0; i < lambda.length; i++) {
                    derivatives[i] += lambda[i] / this.sigmaSquareds[i];
                }
            }
            return derivatives;
        }

        /**
         * Returns true when {@code lambda} equals the model's current weights componentwise within
         * {@link #tol}. A length mismatch counts as "different" rather than failing with an index error.
         */
        private boolean matchesCurrentLambda(double[] lambda) {
            double[] current = this.model.lambda;
            if (current == lambda) {
                return true;
            }
            if (current == null || current.length != lambda.length) {
                return false;
            }
            for (int i = 0; i < lambda.length; i++) {
                if (Math.abs(lambda[i] - current[i]) > this.tol) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Progress monitor invoked by the minimizers. It prints the latest objective value reported by
     * the {@link LikelihoodFunction} and periodically calls {@code checkCorrectness()} on the model.
     */
    // NOTE: the decompiler reports this class was originally private.
    public static final class MonitorFunction implements Function {
        /** Every this many iterations (excluding iteration 0), {@code model.checkCorrectness()} runs. */
        private static final int CORRECTNESS_CHECK_INTERVAL = 30;

        private final LambdaSolve model;
        private final LikelihoodFunction lf;
        /** Stored but not read by this class. */
        private final String filename;
        private int iterations;

        public MonitorFunction(LambdaSolve model, LikelihoodFunction lf, String filename) {
            this.model = model;
            this.lf = lf;
            this.filename = filename;
        }

        /**
         * Prints a progress line for the current iteration. The argument is not inspected; the
         * value shown is the one cached by the likelihood function.
         *
         * @return a constant placeholder; presumably ignored by the minimizer — confirm
         */
        @Override // edu.stanford.nlp.optimization.Function
        public double valueAt(double[] lambda) {
            double likelihood = this.lf.likelihood();
            System.err.println();
            System.err.print(reportMonitoring(likelihood));
            if (this.iterations > 0 && this.iterations % CORRECTNESS_CHECK_INTERVAL == 0) {
                this.model.checkCorrectness();
            }
            this.iterations++;
            return 42.0d;
        }

        /** Formats a progress line for objective value {@code d} at the current iteration. */
        public String reportMonitoring(double d) {
            return "Iter. " + this.iterations + ": neg. log cond. likelihood = " + d + " [" + this.lf.numCalls() + " calls to valueAt]";
        }

        @Override // edu.stanford.nlp.optimization.Function
        public int domainDimension() {
            return this.lf.domainDimension();
        }
    }

    /** Creates a runner with the default tolerance and default Gaussian prior variance. */
    public CGRunner(LambdaSolve prob, String filename) {
        this(prob, filename, DEFAULT_SIGMASQUARED);
    }

    /**
     * Creates a runner with the default tolerance and a Gaussian prior of variance
     * {@code sigmaSquared} on every weight (0 or +Infinity disables the prior).
     */
    public CGRunner(LambdaSolve prob, String filename, double sigmaSquared) {
        this(prob, filename, DEFAULT_TOLERANCE, sigmaSquared);
    }

    /**
     * @param prob model to optimize; its {@code lambda} receives the result
     * @param filename passed through to the monitor (currently unused)
     * @param tol convergence tolerance
     * @param sigmaSquared prior variance shared by every weight; 0 or +Infinity disables the prior
     */
    public CGRunner(LambdaSolve prob, String filename, double tol, double sigmaSquared) {
        this.prob = prob;
        this.filename = filename;
        this.tol = tol;
        this.useGaussianPrior = sigmaSquared != 0.0d && sigmaSquared != Double.POSITIVE_INFINITY;
        this.priorSigmaS = sigmaSquared;
        this.sigmaSquareds = null;
    }

    /**
     * @param prob model to optimize; its {@code lambda} receives the result
     * @param filename passed through to the monitor (currently unused)
     * @param tol convergence tolerance
     * @param sigmaSquareds per-weight prior variances (one per weight), or {@code null} for no prior
     */
    public CGRunner(LambdaSolve prob, String filename, double tol, double[] sigmaSquareds) {
        this.prob = prob;
        this.filename = filename;
        this.tol = tol;
        this.useGaussianPrior = sigmaSquareds != null;
        this.sigmaSquareds = sigmaSquareds;
        this.priorSigmaS = -1.0d;
    }

    /** Optimizes the model with the default solver (quasi-Newton). */
    public void solve() {
        solveQN();
    }

    /** Optimizes the model with a quasi-Newton minimizer (memory parameter 10). */
    public void solveQN() {
        LikelihoodFunction likelihoodFunction = newLikelihoodFunction();
        MonitorFunction monitorFunction = new MonitorFunction(this.prob, likelihoodFunction, this.filename);
        double[] result = new QNMinimizer(monitorFunction, 10)
                .minimize(likelihoodFunction, this.tol, zeroWeights(likelihoodFunction));
        installAndReport(likelihoodFunction, monitorFunction, result);
    }

    /** Optimizes the model with a conjugate-gradient minimizer. */
    public void solveCG() {
        LikelihoodFunction likelihoodFunction = newLikelihoodFunction();
        MonitorFunction monitorFunction = new MonitorFunction(this.prob, likelihoodFunction, this.filename);
        double[] result = new CGMinimizer(monitorFunction)
                .minimize(likelihoodFunction, this.tol, zeroWeights(likelihoodFunction));
        installAndReport(likelihoodFunction, monitorFunction, result);
    }

    /**
     * Optimizes the model with an L1-regularized minimizer loaded by reflection, so that this class
     * does not need a compile-time dependency on it.
     *
     * @param d value passed to the {@code OWLQNMinimizer} constructor (presumably the L1 weight)
     */
    public void solveL1(double d) {
        LikelihoodFunction likelihoodFunction = newLikelihoodFunction();
        @SuppressWarnings("unchecked") // the reflectively loaded class is expected to minimize DiffFunctions
        Minimizer<DiffFunction> minimizer = (Minimizer<DiffFunction>) ReflectionLoading.loadByReflection(
                "edu.stanford.nlp.optimization.OWLQNMinimizer", Double.valueOf(d));
        double[] result = minimizer.minimize(likelihoodFunction, this.tol, zeroWeights(likelihoodFunction));
        installAndReport(likelihoodFunction, null, result);
    }

    /** Builds the objective from this runner's model, tolerance, and prior settings. */
    private LikelihoodFunction newLikelihoodFunction() {
        return new LikelihoodFunction(this.prob, this.tol, this.useGaussianPrior, this.priorSigmaS, this.sigmaSquareds);
    }

    /** Returns the all-zero starting point for optimization. */
    private static double[] zeroWeights(LikelihoodFunction likelihoodFunction) {
        return new double[likelihoodFunction.domainDimension()];
    }

    /**
     * Stores the optimized weights in the model and reports the final objective value, which is
     * evaluated exactly once.
     *
     * @param monitorFunction monitor whose report is printed, or {@code null} to skip it
     */
    private void installAndReport(LikelihoodFunction likelihoodFunction, MonitorFunction monitorFunction,
                                  double[] lambda) {
        this.prob.lambda = lambda;
        double value = likelihoodFunction.valueAt(lambda);
        if (monitorFunction != null) {
            // The monitor prints without a trailing newline, so start on a fresh line.
            System.err.println();
            System.err.println(monitorFunction.reportMonitoring(value));
        }
        System.err.println("after optimization value is " + value);
    }
}
