package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLclassify.CrossValidator;
import edu.stanford.nlp.CLclassify.LogPrior;
import edu.stanford.nlp.CLoptimization.CGMinimizer;
import edu.stanford.nlp.CLoptimization.GoldenSectionLineSearch;
import edu.stanford.nlp.CLoptimization.HybridMinimizer;
import edu.stanford.nlp.CLoptimization.LineSearcher;
import edu.stanford.nlp.CLoptimization.Minimizer;
import edu.stanford.nlp.CLoptimization.QNMinimizer;
import edu.stanford.nlp.CLoptimization.SGDMinimizer;
import edu.stanford.nlp.CLoptimization.SGDToQNMinimizer;
import edu.stanford.nlp.CLoptimization.SMDMinimizer;
import edu.stanford.nlp.CLoptimization.SQNMinimizer;
import edu.stanford.nlp.CLoptimization.StochasticCalculateMethods;
import edu.stanford.nlp.CLsequences.SeqClassifierFlags;
import edu.stanford.nlp.CLstats.MultiClassAccuracyStats;
import edu.stanford.nlp.CLstats.Scorer;
import edu.stanford.nlp.CLutil.ArrayUtils;
import edu.stanford.nlp.CLutil.Function;
import edu.stanford.nlp.CLutil.Index;
import edu.stanford.nlp.CLutil.Pair;
import edu.stanford.nlp.CLutil.Timing;
import edu.stanford.nlp.CLutil.Triple;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/LinearClassifierFactory.class */
/**
 * Factory that trains the weight matrix of a {@link LinearClassifier} by handing a
 * {@link LogConditionalObjectiveFunction} (or one of its variants) to a configurable
 * {@link Minimizer}, regularised by a configurable {@link LogPrior}.
 *
 * <p>The prior's sigma can optionally be tuned before training, either on a held-out
 * split ({@link #setTuneSigmaHeldOut()}) or by cross-validation
 * ({@link #setTuneSigmaCV(int)}). Both searches are driven by a {@link LineSearcher}
 * that minimizes the negated score reported by a {@link Scorer}.
 *
 * <p>Instances are mutable: the tuning methods overwrite the prior's sigma and the
 * inherited {@code featureIndex}/{@code labelIndex} fields.
 */
public class LinearClassifierFactory extends AbstractLinearClassifierFactory {
    /** Tolerance passed to {@link Minimizer#minimize} on every training run. */
    private double TOL;
    /** Memory argument used whenever this factory builds a QN/SQN minimizer itself. */
    private int mem;
    /** Only read by {@link #useConjugateGradientAscent()}, which passes its negation to CGMinimizer. */
    private boolean verbose;
    /** Prior handed to the objective functions; its sigma is what the tuning methods adjust. */
    private LogPrior logPrior;
    /** Optimizer used for all training runs. */
    private Minimizer minimizer;
    /** Stored by the constructors and {@link #setUseSum(boolean)}; not read inside this class. */
    private boolean useSum;
    /** When true, {@code trainWeights} first tunes sigma on a held-out split. */
    private boolean tuneSigmaHeldOut;
    /** When true (and held-out tuning is off), {@code trainWeights} first tunes sigma by cross-validation. */
    private boolean tuneSigmaCV;
    /** Stored by {@link #resetWeight()}; not read inside this class. */
    private boolean resetWeight;
    /** Number of folds used when cross-validation tuning is triggered from {@code trainWeights}. */
    private int folds;
    /** Lower bound handed to the default GoldenSectionLineSearch when tuning sigma. */
    private double min;
    /** Upper bound handed to the default GoldenSectionLineSearch when tuning sigma. */
    private double max;
    /** When true, weights found during held-out tuning are not reused as the starting point. */
    private boolean retrainFromScratchAfterSigmaTuning;
    /** Candidate sigma values; not referenced inside this class. */
    protected static double[] sigmasToTry = {0.5d, 1.0d, 2.0d, 4.0d, 10.0d, 20.0d, 100.0d};
    /** Optional line searcher for two-dataset held-out tuning; {@code null} means use the default. */
    private LineSearcher heldOutSearcher;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/stanford/nlp/CLclassify/LinearClassifierFactory$NegativeScorer.class */
    /**
     * Objective for the held-out sigma search: maps a candidate sigma to the negated
     * score of a classifier trained with that sigma, so that a minimizing line search
     * maximizes the score.
     *
     * <p>Weights from the previous evaluation are kept in {@link #weights} and reused
     * as the starting point of the next training run.
     */
    public class NegativeScorer implements Function<Double, Double> {
        /** Flattened weights of the most recent training run; {@code null} before the first call. */
        public double[] weights = null;
        GeneralDataset trainSet;
        GeneralDataset devSet;
        Scorer scorer;
        Timing timer;

        /**
         * @param generalDataset  dataset the classifier is trained on
         * @param generalDataset2 dataset the trained classifier is scored on
         * @param scorer          scoring function applied to the trained classifier
         * @param timing          timer used to report the time spent per evaluation
         */
        public NegativeScorer(GeneralDataset generalDataset, GeneralDataset generalDataset2, Scorer scorer, Timing timing) {
            this.trainSet = generalDataset;
            this.devSet = generalDataset2;
            this.scorer = scorer;
            this.timer = timing;
        }

        /**
         * Sets the enclosing factory's sigma to {@code d}, trains on the training set
         * (without nested sigma tuning), scores on the dev set and logs the result.
         *
         * @param d candidate sigma
         * @return the negated score
         */
        @Override // edu.stanford.nlp.CLutil.Function
        public Double apply(Double d) {
            LinearClassifierFactory.this.setSigma(d.doubleValue());
            // "true" bypasses sigma tuning inside trainWeights; we are already inside the tuning loop.
            double[][] trained = LinearClassifierFactory.this.trainWeights(this.trainSet, this.weights, true);
            this.weights = ArrayUtils.flatten(trained);
            LinearClassifier candidate = new LinearClassifier(trained, this.trainSet.featureIndex, this.trainSet.labelIndex);
            double score = this.scorer.score(candidate, this.devSet);
            System.err.print("##sigma = " + LinearClassifierFactory.this.getSigma() + " ");
            System.err.println("-> average Score: " + score);
            System.err.println("##time elapsed: " + this.timer.stop() + " milliseconds.");
            this.timer.restart();
            return Double.valueOf(-score);
        }
    }

    /**
     * Re-trains weights on {@code generalDataset} using an
     * {@link AdaptedGaussianPriorObjectiveFunction} built from the given original weights.
     *
     * @param dArr           original weight matrix (first dimension is copied into the new matrix)
     * @param generalDataset dataset to adapt to; its index sizes determine the new matrix shape
     * @return the adapted weights in 2D form
     */
    public double[][] adaptWeights(double[][] dArr, GeneralDataset generalDataset) {
        System.err.println("adaptWeights in LinearClassifierFactory. increase weight dim only");
        double[][] widened = new double[generalDataset.featureIndex.size()][generalDataset.labelIndex.size()];
        // NOTE(review): this copies row *references*, so the copied rows alias the caller's
        // rows and keep their original length. Left unchanged to preserve behaviour - confirm
        // whether a per-row value copy is wanted.
        System.arraycopy(dArr, 0, widened, 0, dArr.length);
        AdaptedGaussianPriorObjectiveFunction objective =
                new AdaptedGaussianPriorObjectiveFunction(generalDataset, this.logPrior, widened);
        double[] solution = this.minimizer.minimize(objective, this.TOL, objective.initial());
        return objective.to2D(solution);
    }

    /**
     * Trains weights on the dataset with no explicit starting point.
     *
     * @param generalDataset training data
     * @return the trained weight matrix
     */
    @Override // edu.stanford.nlp.CLclassify.AbstractLinearClassifierFactory
    public double[][] trainWeights(GeneralDataset generalDataset) {
        return trainWeights(generalDataset, null);
    }

    /**
     * Trains weights on the dataset, with sigma tuning enabled if configured.
     *
     * @param generalDataset training data
     * @param dArr           flattened initial weights, or {@code null}
     * @return the trained weight matrix
     */
    public double[][] trainWeights(GeneralDataset generalDataset, double[] dArr) {
        return trainWeights(generalDataset, dArr, false);
    }

    /**
     * Trains weights on the dataset.
     *
     * <p>Unless {@code z} is true, sigma is first tuned on a held-out split (if
     * {@code tuneSigmaHeldOut}) or else by cross-validation (if {@code tuneSigmaCV}).
     * The starting point is, in order of preference: {@code dArr}; the weights returned
     * by held-out tuning (unless retraining from scratch was requested); the objective's
     * own initial point.
     *
     * @param generalDataset training data
     * @param dArr           flattened initial weights, or {@code null}
     * @param z              when true, skip any configured sigma tuning
     * @return the trained weight matrix
     */
    public double[][] trainWeights(GeneralDataset generalDataset, double[] dArr, boolean z) {
        double[] tunedWeights = null;
        if (!z) {
            if (this.tuneSigmaHeldOut) {
                tunedWeights = heldOutSetSigma(generalDataset);
            } else if (this.tuneSigmaCV) {
                crossValidateSetSigma(generalDataset, this.folds);
            }
        }
        // Built after tuning, as in the original ordering.
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(generalDataset, this.logPrior);
        double[] start = dArr;
        if (start == null && tunedWeights != null && !this.retrainFromScratchAfterSigmaTuning) {
            start = tunedWeights;
        }
        if (start == null) {
            start = objective.initial();
        }
        return objective.to2D(this.minimizer.minimize(objective, this.TOL, start));
    }

    /**
     * Trains a classifier with {@link #trainWeightsSemiSup} and wraps the weights using
     * the indices of the first dataset.
     *
     * @param generalDataset  first dataset; also supplies the feature and label indices
     * @param generalDataset2 second dataset, used for the biased objective
     * @param dArr            matrix passed to the biased objective
     * @param dArr2           flattened initial weights, or {@code null}
     * @return the trained classifier
     */
    public Classifier trainClassifierSemiSup(GeneralDataset generalDataset, GeneralDataset generalDataset2, double[][] dArr, double[] dArr2) {
        double[][] weights = trainWeightsSemiSup(generalDataset, generalDataset2, dArr, dArr2);
        return new LinearClassifier(weights, generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Minimizes a {@link SemiSupervisedLogConditionalObjectiveFunction} combining a
     * prior-free objective on the first dataset, a prior-free biased objective on the
     * second dataset, and this factory's prior.
     *
     * @param generalDataset  first dataset
     * @param generalDataset2 second dataset, used for the biased objective
     * @param dArr            matrix passed to the biased objective
     * @param dArr2           flattened initial weights, or {@code null} to use the first
     *                        objective's initial point
     * @return the trained weight matrix
     */
    public double[][] trainWeightsSemiSup(GeneralDataset generalDataset, GeneralDataset generalDataset2, double[][] dArr, double[] dArr2) {
        LogConditionalObjectiveFunction supervised =
                new LogConditionalObjectiveFunction(generalDataset, new LogPrior(LogPrior.LogPriorType.NULL));
        BiasedLogConditionalObjectiveFunction biased =
                new BiasedLogConditionalObjectiveFunction(generalDataset2, dArr, new LogPrior(LogPrior.LogPriorType.NULL));
        SemiSupervisedLogConditionalObjectiveFunction combined =
                new SemiSupervisedLogConditionalObjectiveFunction(supervised, biased, this.logPrior);
        double[] start = (dArr2 != null) ? dArr2 : supervised.initial();
        return supervised.to2D(this.minimizer.minimize(combined, this.TOL, start));
    }

    /**
     * Tunes sigma on an explicit train/held-out pair within {@code [d, d2]}, then trains
     * a classifier on the training set.
     *
     * @param generalDataset  training data
     * @param generalDataset2 held-out data used to score candidate sigmas
     * @param d               lower search bound, stored in {@code min}
     * @param d2              upper search bound, stored in {@code max}
     * @param z               not used by this method
     * @return the trained classifier
     */
    public Classifier trainClassifierV(GeneralDataset generalDataset, GeneralDataset generalDataset2, double d, double d2, boolean z) {
        this.labelIndex = generalDataset.labelIndex();
        this.featureIndex = generalDataset.featureIndex();
        this.min = d;
        this.max = d2;
        heldOutSetSigma(generalDataset, generalDataset2);
        double[][] weights = trainWeights(generalDataset);
        return new LinearClassifier(weights, generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Enables held-out sigma tuning, tunes sigma on a split of the dataset within
     * {@code [d, d2]}, then trains a classifier on the full dataset.
     *
     * @param generalDataset training data
     * @param d              lower search bound, stored in {@code min}
     * @param d2             upper search bound, stored in {@code max}
     * @param z              not used by this method
     * @return the trained classifier
     */
    public Classifier trainClassifierV(GeneralDataset generalDataset, double d, double d2, boolean z) {
        this.labelIndex = generalDataset.labelIndex();
        this.featureIndex = generalDataset.featureIndex();
        this.tuneSigmaHeldOut = true;
        this.min = d;
        this.max = d2;
        heldOutSetSigma(generalDataset);
        // NOTE(review): tuneSigmaHeldOut is now true, so trainWeights(...) runs the held-out
        // search a second time. Left unchanged to preserve behaviour.
        double[][] weights = trainWeights(generalDataset);
        return new LinearClassifier(weights, generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /** Creates a factory using {@code new QNMinimizer(15)} and the remaining defaults. */
    public LinearClassifierFactory() {
        this(new QNMinimizer(15));
    }

    /**
     * @param minimizer optimizer to use; {@code useSum} defaults to false
     */
    public LinearClassifierFactory(Minimizer minimizer) {
        this(minimizer, false);
    }

    /**
     * @param z value for {@code useSum}; the minimizer is {@code new QNMinimizer(15)}
     */
    public LinearClassifierFactory(boolean z) {
        this(new QNMinimizer(15), z);
    }

    /**
     * @param d minimizer tolerance; the minimizer is {@code new QNMinimizer(15)}
     */
    public LinearClassifierFactory(double d) {
        this((Minimizer) new QNMinimizer(15), d, false);
    }

    /**
     * @param minimizer optimizer to use
     * @param z         value for {@code useSum}; tolerance defaults to 1.0E-4
     */
    public LinearClassifierFactory(Minimizer minimizer, boolean z) {
        this(minimizer, 1.0E-4d, z);
    }

    /**
     * @param minimizer optimizer to use
     * @param d         minimizer tolerance
     * @param z         value for {@code useSum}; the second prior argument defaults to 1.0
     */
    public LinearClassifierFactory(Minimizer minimizer, double d, boolean z) {
        this(minimizer, d, z, 1.0d);
    }

    /**
     * @param d  minimizer tolerance
     * @param z  value for {@code useSum}
     * @param d2 second argument of the {@link LogPrior} constructor (presumably sigma - confirm)
     */
    public LinearClassifierFactory(double d, boolean z, double d2) {
        this(new QNMinimizer(15), d, z, d2);
    }

    /**
     * Uses the ordinal of {@code LogPrior.LogPriorType.QUADRATIC} as the prior type.
     *
     * @param minimizer optimizer to use
     * @param d         minimizer tolerance
     * @param z         value for {@code useSum}
     * @param d2        second argument of the {@link LogPrior} constructor (presumably sigma - confirm)
     */
    public LinearClassifierFactory(Minimizer minimizer, double d, boolean z, double d2) {
        this(minimizer, d, z, LogPrior.LogPriorType.QUADRATIC.ordinal(), d2);
    }

    /**
     * @param minimizer optimizer to use
     * @param d         minimizer tolerance
     * @param z         value for {@code useSum}
     * @param i         prior type, as passed to the {@link LogPrior} constructor
     * @param d2        second argument of the {@link LogPrior} constructor; the third defaults to 0.0
     */
    public LinearClassifierFactory(Minimizer minimizer, double d, boolean z, int i, double d2) {
        this(minimizer, d, z, i, d2, 0.0d);
    }

    /**
     * @param d  minimizer tolerance
     * @param z  value for {@code useSum}
     * @param i  prior type, as passed to the {@link LogPrior} constructor
     * @param d2 second argument of the {@link LogPrior} constructor
     * @param d3 third argument of the {@link LogPrior} constructor
     */
    public LinearClassifierFactory(double d, boolean z, int i, double d2, double d3) {
        this(new QNMinimizer(15), d, z, new LogPrior(i, d2, d3));
    }

    /**
     * @param d  minimizer tolerance
     * @param z  value for {@code useSum}
     * @param i  prior type, as passed to the {@link LogPrior} constructor
     * @param d2 second argument of the {@link LogPrior} constructor
     * @param d3 third argument of the {@link LogPrior} constructor
     * @param i2 argument of the {@link QNMinimizer} constructor
     */
    public LinearClassifierFactory(double d, boolean z, int i, double d2, double d3, int i2) {
        this(new QNMinimizer(i2), d, z, new LogPrior(i, d2, d3));
    }

    /**
     * @param minimizer optimizer to use
     * @param d         minimizer tolerance
     * @param z         value for {@code useSum}
     * @param i         prior type, as passed to the {@link LogPrior} constructor
     * @param d2        second argument of the {@link LogPrior} constructor
     * @param d3        third argument of the {@link LogPrior} constructor
     */
    public LinearClassifierFactory(Minimizer minimizer, double d, boolean z, int i, double d2, double d3) {
        this(minimizer, d, z, new LogPrior(i, d2, d3));
    }

    /**
     * Primary constructor: every other constructor delegates here.
     *
     * @param minimizer optimizer to use
     * @param d         minimizer tolerance
     * @param z         value for {@code useSum}
     * @param logPrior  prior handed to the objective functions
     */
    public LinearClassifierFactory(Minimizer minimizer, double d, boolean z, LogPrior logPrior) {
        // Defaults for everything that is not a constructor argument.
        this.mem = 15;
        this.verbose = false;
        this.tuneSigmaHeldOut = false;
        this.tuneSigmaCV = false;
        this.resetWeight = true;
        this.min = 0.1d;
        this.max = 10.0d;
        this.retrainFromScratchAfterSigmaTuning = false;
        this.heldOutSearcher = null;

        this.minimizer = minimizer;
        this.TOL = d;
        this.useSum = z;
        this.logPrior = logPrior;
    }

    /** Sets the tolerance passed to the minimizer. */
    public void setTol(double d) {
        this.TOL = d;
    }

    /** Replaces the prior used by subsequently built objective functions. */
    public void setPrior(LogPrior logPrior) {
        this.logPrior = logPrior;
    }

    /** Sets the verbose flag (consulted by {@link #useConjugateGradientAscent()}). */
    public void setVerbose(boolean z) {
        this.verbose = z;
    }

    /** Replaces the minimizer used for training. */
    public void setMinimizer(Minimizer minimizer) {
        this.minimizer = minimizer;
    }

    /** Forwards to {@code LogPrior.setEpsilon} on the current prior. */
    public void setEpsilon(double d) {
        this.logPrior.setEpsilon(d);
    }

    /** Forwards to {@code LogPrior.setSigma} on the current prior. */
    public void setSigma(double d) {
        this.logPrior.setSigma(d);
    }

    /** @return the current prior's sigma */
    public double getSigma() {
        return this.logPrior.getSigma();
    }

    /** Switches to a {@link QNMinimizer} built with the current {@code mem}. */
    public void useQuasiNewton() {
        this.minimizer = new QNMinimizer(this.mem);
    }

    /**
     * Switches to a {@link QNMinimizer} built with the current {@code mem}.
     *
     * @param z second argument of the {@link QNMinimizer} constructor
     */
    public void useQuasiNewton(boolean z) {
        this.minimizer = new QNMinimizer(this.mem, z);
    }

    /**
     * Switches to an {@link SQNMinimizer} built from the current {@code mem}, the given
     * arguments and {@code false}.
     */
    public void useStochasticQN(double d, int i) {
        this.minimizer = new SQNMinimizer(this.mem, d, i, false);
    }

    /** Switches to an {@link SMDMinimizer} with this class's default arguments. */
    public void useStochasticMetaDescent() {
        useStochasticMetaDescent(0.1d, 15, StochasticCalculateMethods.ExternalFiniteDifference, 20);
    }

    /** Switches to an {@link SMDMinimizer} constructed from the given arguments. */
    public void useStochasticMetaDescent(double d, int i, StochasticCalculateMethods stochasticCalculateMethods, int i2) {
        this.minimizer = new SMDMinimizer(d, i, stochasticCalculateMethods, i2);
    }

    /** Switches to an {@link SGDMinimizer} with this class's default arguments. */
    public void useStochasticGradientDescent() {
        useStochasticGradientDescent(0.1d, 15);
    }

    /** Switches to an {@link SGDMinimizer} constructed from the given arguments. */
    public void useStochasticGradientDescent(double d, int i) {
        this.minimizer = new SGDMinimizer(d, i);
    }

    /** Switches to an {@link SGDToQNMinimizer} configured from the given flags. */
    public void useStochasticGradientDescentToQuasiNewton(SeqClassifierFlags seqClassifierFlags) {
        this.minimizer = new SGDToQNMinimizer(seqClassifierFlags);
    }

    /** Switches to a {@link HybridMinimizer} with this class's default arguments. */
    public void useHybridMinimizer() {
        useHybridMinimizer(0.1d, 15, StochasticCalculateMethods.ExternalFiniteDifference, 0);
    }

    /**
     * Switches to a {@link HybridMinimizer} combining an {@link SMDMinimizer} built from
     * the arguments with a {@link QNMinimizer} built from the current {@code mem}.
     * {@code i2} is passed both to the SMD minimizer and to the hybrid wrapper.
     */
    public void useHybridMinimizer(double d, int i, StochasticCalculateMethods stochasticCalculateMethods, int i2) {
        SMDMinimizer stochasticPart = new SMDMinimizer(d, i, stochasticCalculateMethods, i2);
        QNMinimizer quasiNewtonPart = new QNMinimizer(this.mem);
        this.minimizer = new HybridMinimizer(stochasticPart, quasiNewtonPart, i2);
    }

    /** Sets the memory argument used by minimizers this factory builds afterwards. */
    public void setMem(int i) {
        this.mem = i;
    }

    /** Sets the verbose flag, then switches to a {@link CGMinimizer}. */
    public void useConjugateGradientAscent(boolean z) {
        this.verbose = z;
        useConjugateGradientAscent();
    }

    /** Switches to a {@link CGMinimizer}, passing it the negation of the verbose flag. */
    public void useConjugateGradientAscent() {
        this.minimizer = new CGMinimizer(!this.verbose);
    }

    /** Stores the {@code useSum} flag. */
    public void setUseSum(boolean z) {
        this.useSum = z;
    }

    /** Enables held-out sigma tuning and disables cross-validation tuning. */
    public void setTuneSigmaHeldOut() {
        this.tuneSigmaHeldOut = true;
        this.tuneSigmaCV = false;
    }

    /**
     * Enables cross-validation sigma tuning and disables held-out tuning.
     *
     * @param i number of folds
     */
    public void setTuneSigmaCV(int i) {
        this.tuneSigmaCV = true;
        this.tuneSigmaHeldOut = false;
        this.folds = i;
    }

    /** Sets the {@code resetWeight} flag to true. */
    public void resetWeight() {
        this.resetWeight = true;
    }

    /** Cross-validates sigma with 5 folds and the default scorer and line search. */
    public void crossValidateSetSigma(GeneralDataset generalDataset) {
        crossValidateSetSigma(generalDataset, 5);
    }

    /**
     * Cross-validates sigma with {@code i} folds, a {@code MultiClassAccuracyStats(2)}
     * scorer and a golden-section line search over {@code [min, max]}.
     */
    public void crossValidateSetSigma(GeneralDataset generalDataset, int i) {
        // NOTE(review): looks like a leftover debug trace; kept because it is runtime output.
        System.err.println("##you are here.");
        crossValidateSetSigma(generalDataset, i, new MultiClassAccuracyStats(2), new GoldenSectionLineSearch(true, 0.01d, this.min, this.max));
    }

    /** Cross-validates sigma with the given scorer and a golden-section line search over {@code [min, max]}. */
    public void crossValidateSetSigma(GeneralDataset generalDataset, int i, Scorer scorer) {
        crossValidateSetSigma(generalDataset, i, scorer, new GoldenSectionLineSearch(true, 0.01d, this.min, this.max));
    }

    /** Cross-validates sigma with a {@code MultiClassAccuracyStats(2)} scorer and the given line searcher. */
    public void crossValidateSetSigma(GeneralDataset generalDataset, int i, LineSearcher lineSearcher) {
        crossValidateSetSigma(generalDataset, i, new MultiClassAccuracyStats(2), lineSearcher);
    }

    /**
     * Chooses sigma by cross-validation and stores it in the prior.
     *
     * <p>For each candidate sigma proposed by {@code lineSearcher}, a classifier is
     * trained and scored per fold via {@link CrossValidator#computeAverage}; the line
     * search minimizes the negated average score. Each fold's flattened weights are
     * stashed in the fold's saved state and reused as the next starting point.
     *
     * @param generalDataset data to cross-validate on; also sets the inherited indices
     * @param i              number of folds
     * @param scorer         scoring function applied to each fold's classifier
     * @param lineSearcher   one-dimensional search over sigma
     */
    public void crossValidateSetSigma(GeneralDataset generalDataset, int i, final Scorer scorer, LineSearcher lineSearcher) {
        System.err.println("##in Cross Validate, folds = " + i);
        System.err.println("##Scorer is " + scorer);
        this.featureIndex = generalDataset.featureIndex;
        this.labelIndex = generalDataset.labelIndex;
        final CrossValidator crossValidator = new CrossValidator(generalDataset, i);

        // Scores a single fold: triple = (train part, held-out part, per-fold saved state).
        final Function<Triple<GeneralDataset, GeneralDataset, CrossValidator.SavedState>, Double> foldScorer =
                new Function<Triple<GeneralDataset, GeneralDataset, CrossValidator.SavedState>, Double>() {
            @Override // edu.stanford.nlp.CLutil.Function
            public Double apply(Triple<GeneralDataset, GeneralDataset, CrossValidator.SavedState> triple) {
                GeneralDataset foldTrain = triple.first();
                GeneralDataset foldDev = triple.second();
                // "true" bypasses sigma tuning inside trainWeights; we are already inside the tuning loop.
                double[][] trained = LinearClassifierFactory.this.trainWeights(foldTrain, (double[]) triple.third().state, true);
                triple.third().state = ArrayUtils.flatten(trained);
                double score = scorer.score(new LinearClassifier(trained, foldTrain.featureIndex, foldTrain.labelIndex), foldDev);
                System.out.print(".");
                return Double.valueOf(score);
            }
        };

        // Maps a candidate sigma to the negated average fold score.
        Function<Double, Double> negatedAverage = new Function<Double, Double>() {
            @Override // edu.stanford.nlp.CLutil.Function
            public Double apply(Double d) {
                LinearClassifierFactory.this.setSigma(d.doubleValue());
                Double average = Double.valueOf(crossValidator.computeAverage(foldScorer));
                System.err.print("##sigma = " + LinearClassifierFactory.this.getSigma() + " ");
                System.err.println("-> average Score: " + average);
                return Double.valueOf(-average.doubleValue());
            }
        };

        double bestSigma = lineSearcher.minimize(negatedAverage);
        System.err.println("##best sigma: " + bestSigma);
        setSigma(bestSigma);
    }

    /** Sets the line searcher used by {@link #heldOutSetSigma(GeneralDataset, GeneralDataset)}. */
    public void setHeldOutSearcher(LineSearcher lineSearcher) {
        this.heldOutSearcher = lineSearcher;
    }

    /**
     * Splits the dataset with {@code split(0.3)} and tunes sigma using the first part
     * for training and the second part for scoring.
     *
     * @return flattened weights trained with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset) {
        Pair<GeneralDataset, GeneralDataset> parts = generalDataset.split(0.3d);
        return heldOutSetSigma(parts.first(), parts.second());
    }

    /**
     * Same as {@link #heldOutSetSigma(GeneralDataset)} but with an explicit scorer.
     *
     * @return flattened weights trained with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset, Scorer scorer) {
        Pair<GeneralDataset, GeneralDataset> parts = generalDataset.split(0.3d);
        return heldOutSetSigma(parts.first(), parts.second(), scorer);
    }

    /**
     * Tunes sigma on an explicit train/held-out pair with a {@code MultiClassAccuracyStats(2)}
     * scorer, using the configured held-out searcher or, if none was set, a golden-section
     * line search over {@code [min, max]}.
     *
     * @return flattened weights trained with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset, GeneralDataset generalDataset2) {
        LineSearcher searcher = this.heldOutSearcher;
        if (searcher == null) {
            searcher = new GoldenSectionLineSearch(true, 0.01d, this.min, this.max);
        }
        return heldOutSetSigma(generalDataset, generalDataset2, new MultiClassAccuracyStats(2), searcher);
    }

    /**
     * Tunes sigma on an explicit train/held-out pair with the given scorer and a
     * golden-section line search over {@code [min, max]}. The configured held-out
     * searcher is not consulted by this overload.
     *
     * @return flattened weights trained with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset, GeneralDataset generalDataset2, Scorer scorer) {
        return heldOutSetSigma(generalDataset, generalDataset2, scorer, new GoldenSectionLineSearch(true, 0.01d, this.min, this.max));
    }

    /**
     * Tunes sigma on an explicit train/held-out pair with a {@code MultiClassAccuracyStats(2)}
     * scorer and the given line searcher.
     *
     * @return flattened weights trained with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset, GeneralDataset generalDataset2, LineSearcher lineSearcher) {
        return heldOutSetSigma(generalDataset, generalDataset2, new MultiClassAccuracyStats(2), lineSearcher);
    }

    /**
     * Chooses sigma by line search over the negated held-out score, stores it in the
     * prior, and trains once more on the training set with that sigma.
     *
     * @param generalDataset  training data; also sets the inherited indices
     * @param generalDataset2 held-out data used for scoring
     * @param scorer          scoring function
     * @param lineSearcher    one-dimensional search over sigma
     * @return flattened weights of the final training run, which starts from the weights
     *         of the last evaluation made during the search
     */
    public double[] heldOutSetSigma(GeneralDataset generalDataset, GeneralDataset generalDataset2, Scorer scorer, LineSearcher lineSearcher) {
        this.featureIndex = generalDataset.featureIndex;
        this.labelIndex = generalDataset.labelIndex;
        Timing timing = new Timing();
        NegativeScorer objective = new NegativeScorer(generalDataset, generalDataset2, scorer, timing);
        timing.start();
        double bestSigma = lineSearcher.minimize(objective);
        System.err.println("##best sigma: " + bestSigma);
        setSigma(bestSigma);
        double[][] finalWeights = trainWeights(generalDataset, objective.weights, true);
        return ArrayUtils.flatten(finalWeights);
    }

    /** Controls whether weights found during held-out tuning seed the final training run. */
    public void setRetrainFromScratchAfterSigmaTuning(boolean z) {
        this.retrainFromScratchAfterSigmaTuning = z;
    }

    /**
     * Trains a classifier with the per-datum array {@code fArr} passed to the objective.
     *
     * @param generalDataset training data
     * @param fArr           array passed to the {@link LogConditionalObjectiveFunction} constructor
     * @param logPrior       NOTE(review): ignored - the factory's own prior field is used
     *                       instead. Left unchanged to preserve behaviour; confirm intent.
     * @return the trained classifier
     */
    public Classifier trainClassifier(GeneralDataset generalDataset, float[] fArr, LogPrior logPrior) {
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(generalDataset, fArr, this.logPrior);
        double[] solution = this.minimizer.minimize(objective, this.TOL, objective.initial());
        return new LinearClassifier(objective.to2D(solution), generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Trains a classifier on the dataset with no explicit starting point.
     *
     * @param generalDataset training data
     * @return the trained classifier
     */
    @Override // edu.stanford.nlp.CLclassify.AbstractLinearClassifierFactory
    public Classifier trainClassifier(GeneralDataset generalDataset) {
        return trainClassifier(generalDataset, null);
    }

    /**
     * Trains a classifier on the dataset, with sigma tuning enabled if configured.
     *
     * @param generalDataset training data
     * @param dArr           flattened initial weights, or {@code null}
     * @return the trained classifier
     */
    public Classifier trainClassifier(GeneralDataset generalDataset, double[] dArr) {
        double[][] weights = trainWeights(generalDataset, dArr, false);
        return new LinearClassifier(weights, generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Loads a classifier from a text file.
     *
     * <p>The file is read in this order: two indices via {@code Index.loadFromReader}
     * (labels first, then features); weight lines of three delimiter-separated tokens
     * {@code featureIdx, labelIdx, value}, terminated by an empty line or end of file;
     * one line holding a threshold count; then one threshold per line until end of file.
     * The thresholds are parsed but not used in the returned classifier.
     *
     * <p>The reader is always closed before this method returns.
     *
     * @param str path of the file to load
     * @return the loaded classifier, or {@code null} if anything goes wrong (the error
     *         is reported on {@code System.err})
     */
    public Classifier loadFromFilename(String str) {
        try {
            File file = new File(str);
            BufferedReader reader = new BufferedReader(new FileReader(file));
            try {
                Index<String> labels = Index.loadFromReader(reader);
                Index<String> features = Index.loadFromReader(reader);
                double[][] weights = new double[features.size()][labels.size()];

                int lineNumber = 1;
                String line = reader.readLine();
                while (line != null && line.length() > 0) {
                    String[] tokens = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER);
                    if (tokens.length != 3) {
                        throw new Exception("Error: incorrect number of tokens in weight specifier, line=" + lineNumber + " in file " + file.getAbsolutePath());
                    }
                    lineNumber++;
                    weights[Integer.parseInt(tokens[0])][Integer.parseInt(tokens[1])] = Double.parseDouble(tokens[2]);
                    line = reader.readLine();
                }

                // Thresholds are still parsed so that a malformed section keeps producing
                // the same failure (null return) as before, even though they are discarded.
                double[] thresholds = new double[Integer.parseInt(reader.readLine())];
                int next = 0;
                for (String raw = reader.readLine(); raw != null; raw = reader.readLine()) {
                    thresholds[next++] = Double.parseDouble(raw.trim());
                }
                return new LinearClassifier(weights, features, labels);
            } finally {
                // Previously the reader was never closed, leaking a file handle per call.
                reader.close();
            }
        } catch (Exception e) {
            System.err.println("Error in LinearClassifierFactory, loading from file=" + str);
            e.printStackTrace();
            return null;
        }
    }
}
