package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLclassify.LogPrior;
import edu.stanford.nlp.CLling.Datum;
import edu.stanford.nlp.CLling.RVFDatum;
import edu.stanford.nlp.CLoptimization.QNMinimizer;
import edu.stanford.nlp.CLstats.ClassicCounter;
import edu.stanford.nlp.CLstats.Counter;
import edu.stanford.nlp.CLutil.FileLines;
import edu.stanford.nlp.CLutil.Index;
import edu.stanford.nlp.CLutil.StringUtils;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Properties;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/LogisticClassifier.class */
/**
 * Binary logistic-regression classifier.
 *
 * <p>After {@link #train(GeneralDataset)}, an instance holds one weight per feature. The score of
 * an input is the sum of the weights of its known features. For counter inputs, each weight is
 * scaled by the feature's count. A score greater than 0 selects {@code classes[1]}; otherwise
 * {@code classes[0]} is chosen.
 */
public class LogisticClassifier implements Classifier, Serializable, RVFClassifier {
    // NOTE: field names and modifiers are intentionally unchanged. This class has no explicit
    // serialVersionUID, so the default one (which hashes fields) must stay stable for previously
    // serialized models to keep deserializing.

    // Learned weight per feature, addressed via featureIndex; null until train() runs.
    private double[] weights;
    // Feature -> position in weights; taken from the training dataset; null until train() runs.
    private Index featureIndex;
    // The two labels: classes[0] is chosen for scores <= 0, classes[1] for scores > 0.
    private Object[] classes;
    // Prior passed to the objective function during training.
    private LogPrior prior;
    // If true, train with BiasedLogisticObjectiveFunction instead of LogisticObjectiveFunction.
    private boolean biased;

    /**
     * Renders the learned weights, one line per feature, as
     * {@code <classes[1]> / <feature> = <weight>}.
     */
    @Override
    public String toString() {
        if (this.weights == null || this.featureIndex == null) {
            return "LogisticClassifier (untrained)";
        }
        StringBuilder sb = new StringBuilder();
        Iterator it = this.featureIndex.iterator();
        while (it.hasNext()) {
            Object feature = it.next();
            sb.append(this.classes[1]).append(" / ").append(feature)
                .append(" = ").append(this.weights[this.featureIndex.indexOf(feature)])
                .append('\n');
        }
        return sb.toString();
    }

    /**
     * Returns the learned weights as a counter keyed by {@code "<classes[1]> / <feature>"}.
     *
     * @throws IllegalStateException if the classifier has not been trained
     */
    public Counter weightsAsCounter() {
        ensureTrained();
        ClassicCounter counter = new ClassicCounter();
        Iterator it = this.featureIndex.iterator();
        while (it.hasNext()) {
            Object feature = it.next();
            counter.incrementCount(this.classes[1] + " / " + feature,
                this.weights[this.featureIndex.indexOf(feature)]);
        }
        return counter;
    }

    /** Creates an unbiased classifier with a quadratic prior. */
    public LogisticClassifier() {
        this(new LogPrior(LogPrior.LogPriorType.QUADRATIC), false);
    }

    /**
     * Creates a classifier with a quadratic prior.
     *
     * @param biased whether to train with the biased objective function
     */
    public LogisticClassifier(boolean biased) {
        this(new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
    }

    /**
     * Creates an unbiased classifier with the given prior.
     *
     * @param logPrior prior used during training
     */
    public LogisticClassifier(LogPrior logPrior) {
        this(logPrior, false);
    }

    /**
     * Creates a classifier with the given prior.
     *
     * @param logPrior prior used during training
     * @param biased whether to train with the biased objective function
     */
    public LogisticClassifier(LogPrior logPrior, boolean biased) {
        this.classes = new Object[2];
        this.prior = logPrior;
        this.biased = biased;
    }

    /** Returns the two labels, {@code classes[0]} first. */
    @Override // edu.stanford.nlp.CLclassify.Classifier
    public Collection labels() {
        LinkedList result = new LinkedList();
        result.add(this.classes[0]);
        result.add(this.classes[1]);
        return result;
    }

    /** Classifies a bag of binary features: {@code classes[1]} iff the score is positive. */
    public Object classOf(Collection collection) {
        return scoreOf(collection) > 0.0d ? this.classes[1] : this.classes[0];
    }

    /**
     * Sums the weights of the features in {@code collection}; unknown features are ignored.
     *
     * @throws IllegalStateException if the classifier has not been trained
     */
    public double scoreOf(Collection collection) {
        ensureTrained();
        double score = 0.0d;
        Iterator it = collection.iterator();
        while (it.hasNext()) {
            int index = this.featureIndex.indexOf(it.next());
            if (index >= 0) {
                score += this.weights[index];
            }
        }
        return score;
    }

    /**
     * Returns per-label scores: {@code -score} for {@code classes[0]} and {@code score} for
     * {@code classes[1]}, so that the argmax agrees with {@link #classOf(Datum)}.
     */
    @Override // edu.stanford.nlp.CLclassify.Classifier
    public ClassicCounter scoresOf(Datum datum) {
        return labelScores(scoreOf(datum.asFeatures()));
    }

    @Override // edu.stanford.nlp.CLclassify.Classifier
    public Object classOf(Datum datum) {
        return classOf(datum.asFeatures());
    }

    /** Classifies a real-valued feature counter: {@code classes[1]} iff the score is positive. */
    public Object classOf(ClassicCounter classicCounter) {
        return scoreOf(classicCounter) > 0.0d ? this.classes[1] : this.classes[0];
    }

    /**
     * Computes the sum over known features of {@code weight * count}; unknown features are ignored.
     *
     * @throws IllegalStateException if the classifier has not been trained
     */
    public double scoreOf(ClassicCounter classicCounter) {
        ensureTrained();
        double score = 0.0d;
        for (Object feature : classicCounter.keySet()) {
            int index = this.featureIndex.indexOf(feature);
            if (index >= 0) {
                score += this.weights[index] * classicCounter.getCount(feature);
            }
        }
        return score;
    }

    @Override // edu.stanford.nlp.CLclassify.RVFClassifier
    public Object classOf(RVFDatum rVFDatum) {
        return classOf(rVFDatum.asFeaturesCounter());
    }

    /**
     * Returns per-label scores for a real-valued datum: {@code -score} for {@code classes[0]} and
     * {@code score} for {@code classes[1]}, consistent with {@link #classOf(RVFDatum)}.
     */
    @Override // edu.stanford.nlp.CLclassify.RVFClassifier
    public ClassicCounter scoresOf(RVFDatum rVFDatum) {
        return labelScores(scoreOf(rVFDatum.asFeaturesCounter()));
    }

    /** Returns P(datum.label() | datum.asFeatures()). */
    public double probabilityOf(Datum datum) {
        return probabilityOf(datum.asFeatures(), datum.label());
    }

    /**
     * Returns the model probability of {@code obj} given binary features. Any label not equal to
     * {@code classes[0]} is treated as {@code classes[1]}.
     */
    public double probabilityOf(Collection collection, Object obj) {
        return probabilityForScore(scoreOf(collection), obj);
    }

    /** Returns P(rVFDatum.label() | rVFDatum.asFeaturesCounter()). */
    public double probabilityOf(RVFDatum rVFDatum) {
        return probabilityOf(rVFDatum.asFeaturesCounter(), rVFDatum.label());
    }

    /**
     * Returns the model probability of {@code obj} given real-valued features. Any label not
     * equal to {@code classes[0]} is treated as {@code classes[1]}.
     */
    public double probabilityOf(ClassicCounter classicCounter, Object obj) {
        return probabilityForScore(scoreOf(classicCounter), obj);
    }

    /**
     * Fits the weights on a two-label dataset with a quasi-Newton minimizer. The dataset's
     * feature index is adopted as-is (not copied).
     *
     * @throws IllegalArgumentException if the dataset does not have exactly two labels
     */
    public void train(GeneralDataset generalDataset) {
        if (generalDataset.labelIndex.size() != 2) {
            throw new IllegalArgumentException("LogisticClassifier is only for binary classification!");
        }
        final double tolerance = 1.0E-4d;
        int numFeatures = generalDataset.numFeatureTypes();
        if (this.biased) {
            BiasedLogisticObjectiveFunction objective = new BiasedLogisticObjectiveFunction(
                numFeatures, generalDataset.getDataArray(), generalDataset.getLabelsArray(), this.prior);
            this.weights = new QNMinimizer(objective).minimize(objective, tolerance, new double[numFeatures]);
        } else {
            LogisticObjectiveFunction objective = new LogisticObjectiveFunction(
                numFeatures, generalDataset.getDataArray(), generalDataset.getLabelsArray(), this.prior);
            this.weights = new QNMinimizer(objective).minimize(objective, tolerance, new double[numFeatures]);
        }
        this.featureIndex = generalDataset.featureIndex;
        this.classes[0] = generalDataset.labelIndex.get(0);
        this.classes[1] = generalDataset.labelIndex.get(1);
    }

    // Builds the per-label score counter shared by both scoresOf overloads.
    private ClassicCounter labelScores(double score) {
        ClassicCounter counter = new ClassicCounter();
        counter.setCount(this.classes[0], -score);
        counter.setCount(this.classes[1], score);
        return counter;
    }

    // Logistic link: P(classes[1]) = 1 / (1 + e^-score), P(classes[0]) = 1 / (1 + e^score).
    private double probabilityForScore(double score, Object label) {
        double exponent = label.equals(this.classes[0]) ? score : -score;
        return 1.0d / (1.0d + Math.exp(exponent));
    }

    // Fails fast with a descriptive error instead of an NPE when used before train().
    private void ensureTrained() {
        if (this.weights == null || this.featureIndex == null) {
            throw new IllegalStateException("LogisticClassifier has not been trained");
        }
    }

    /**
     * Command-line driver.
     *
     * <p>Each line of {@code -trainFile} is {@code <label> <feature> <feature> ...}. Blank lines
     * are skipped. Pass {@code -biased true} to use the biased objective. For each non-blank line
     * of {@code -testFile}, prints the predicted label, the serialization delimiter, and the line.
     */
    public static void main(String[] strArr) throws Exception {
        Properties props = StringUtils.argsToProperties(strArr);
        String trainFile = requiredProperty(props, "trainFile");
        String testFile = requiredProperty(props, "testFile");

        Dataset dataset = new Dataset();
        Iterator<String> trainLines = new FileLines(trainFile).iterator();
        while (trainLines.hasNext()) {
            String[] tokens = tokenize(trainLines.next());
            if (tokens.length == 0) {
                continue;
            }
            dataset.add(featuresOf(tokens), tokens[0]);
        }
        dataset.summaryStatistics();

        LogisticClassifier classifier =
            new LogisticClassifier(props.getProperty("biased", "false").equals("true"));
        classifier.train(dataset);

        Iterator<String> testLines = new FileLines(testFile).iterator();
        while (testLines.hasNext()) {
            String line = testLines.next();
            String[] tokens = tokenize(line);
            if (tokens.length == 0) {
                continue;
            }
            System.out.println(classifier.classOf(featuresOf(tokens))
                + LinearClassifier.TEXT_SERIALIZATION_DELIMITER + line);
        }
    }

    // Returns a required property, failing before any work is done if it is absent.
    private static String requiredProperty(Properties props, String key) {
        String value = props.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required property: -" + key);
        }
        return value;
    }

    // Splits a line on whitespace; leading/trailing space is ignored and blank lines give [].
    private static String[] tokenize(String line) {
        String trimmed = line.trim();
        return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
    }

    // Returns every token after the leading label as the line's feature list.
    private static LinkedList featuresOf(String[] tokens) {
        LinkedList features = new LinkedList();
        for (int i = 1; i < tokens.length; i++) {
            features.add(tokens[i]);
        }
        return features;
    }
}
