package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLling.Datum;
import edu.stanford.nlp.CLling.RVFDatum;
import edu.stanford.nlp.CLstats.ClassicCounter;
import edu.stanford.nlp.CLstats.Counters;
import edu.stanford.nlp.CLutil.Pair;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/NaiveBayesClassifier.class */
/**
 * A Naive Bayes classifier over real-valued-feature datums.
 *
 * <p>Scores are additive. The score of a label is its prior plus, for each feature present
 * in the datum, the weight stored for (label, feature, integer value). Weights are looked up
 * in {@link #weights} under keys of the form {@code Pair(Pair(label, feature), Integer value)}.
 * NOTE(review): presumably weights and priors are log-probabilities; confirm against the
 * trainer that builds them.
 *
 * <p>When {@code addZeroValued} is set, features in {@link #features} that are absent from a
 * datum are scored as if present with value 0. The per-label sum of zero-value weights is
 * precomputed into {@link #priorZero}. At scoring time, the zero-value weight of each feature
 * actually present is subtracted again so it is not counted twice.
 */
public class NaiveBayesClassifier implements Classifier, RVFClassifier {
    /** Weights keyed by {@code Pair(Pair(label, feature), Integer value)}. */
    ClassicCounter weights;
    /** Prior scores added to every datum's score counter (presumably keyed by label). */
    ClassicCounter priors;
    /** All known features; only read when {@code addZeroValued} is true. */
    Set features;
    /** Whether features absent from a datum are scored as if they had value 0. */
    private boolean addZeroValued;
    /** Per-label sum of zero-value weights over {@link #features}; null unless addZeroValued. */
    ClassicCounter priorZero;
    /** The labels this classifier scores. */
    Set labels;
    /** Boxed zero used as the value component of zero-value weight keys. */
    private static final Integer ZERO = Integer.valueOf(0);

    /**
     * Creates a classifier.
     *
     * @param weights weights keyed by {@code Pair(Pair(label, feature), Integer value)}
     * @param priors prior scores added to every datum's scores (presumably keyed by label)
     * @param labels the labels to score
     * @param features all known features; must be non-null when {@code addZeroValued} is true
     * @param addZeroValued whether features absent from a datum are scored as value 0
     * @throws NullPointerException if {@code addZeroValued} is true and {@code features} is null
     */
    public NaiveBayesClassifier(ClassicCounter weights, ClassicCounter priors, Set labels,
            Set features, boolean addZeroValued) {
        this.weights = weights;
        this.features = features;
        this.priors = priors;
        this.labels = labels;
        this.addZeroValued = addZeroValued;
        if (addZeroValued) {
            if (features == null) {
                throw new NullPointerException(
                        "features must be non-null when addZeroValued is true");
            }
            initZeros();
        }
    }

    /**
     * Creates a classifier that scores only the features present in each datum.
     *
     * @param weights weights keyed by {@code Pair(Pair(label, feature), Integer value)}
     * @param priors prior scores added to every datum's scores (presumably keyed by label)
     * @param labels the labels to score
     */
    public NaiveBayesClassifier(ClassicCounter weights, ClassicCounter priors, Set labels) {
        this(weights, priors, labels, null, false);
    }

    /** Returns the label set this classifier was constructed with (not a copy). */
    @Override
    public Collection labels() {
        return labels;
    }

    /** Returns the label with the highest score for {@code example}. */
    @Override
    public Object classOf(RVFDatum example) {
        return Counters.argmax(scoresOf(example));
    }

    /**
     * Scores every label for {@code example}.
     *
     * @param example the datum to score
     * @return a new counter holding priors (plus zero-value priors if enabled) plus, per label,
     *     the summed weights of the datum's features
     */
    @Override
    public ClassicCounter scoresOf(RVFDatum example) {
        ClassicCounter scores = new ClassicCounter();
        Counters.addInPlace(scores, priors);
        if (addZeroValued) {
            Counters.addInPlace(scores, priorZero);
        }
        // The feature counter does not depend on the label, so fetch it once.
        ClassicCounter featureValues = example.asFeaturesCounter();
        for (Object label : labels) {
            double score = 0.0d;
            for (Object feature : featureValues.keySet()) {
                // Feature values are truncated to int to form the Integer part of the weight key.
                Integer value = Integer.valueOf((int) featureValues.getCount(feature));
                score += weight(label, feature, value);
                if (addZeroValued) {
                    // priorZero already counted this feature at value 0; remove that contribution.
                    score -= weight(label, feature, ZERO);
                }
            }
            scores.incrementCount(label, score);
        }
        return scores;
    }

    /** Classifies a plain datum by wrapping it in an {@link RVFDatum}. */
    @Override
    public Object classOf(Datum datum) {
        return classOf(new RVFDatum(datum));
    }

    /** Scores a plain datum by wrapping it in an {@link RVFDatum}. */
    @Override
    public ClassicCounter scoresOf(Datum datum) {
        return scoresOf(new RVFDatum(datum));
    }

    /**
     * Computes the fraction of datums whose predicted label equals their gold label, and
     * prints the raw counts to standard error.
     *
     * @param it iterator over {@link RVFDatum} instances
     * @return correct / total in [0, 1], or 0 if the iterator yields nothing
     * @throws ClassCastException if the iterator yields something other than an RVFDatum
     */
    public float accuracy(Iterator it) {
        int correct = 0;
        int total = 0;
        while (it.hasNext()) {
            RVFDatum datum = (RVFDatum) it.next();
            if (classOf(datum).equals(datum.label())) {
                correct++;
            }
            total++;
        }
        System.err.println("correct " + correct + " out of " + total);
        if (total == 0) {
            return 0.0f;
        }
        // Floating-point division: int / int would truncate every non-perfect score to 0.
        return (float) correct / total;
    }

    /** Prints the priors and weights to {@code printStream}. */
    public void print(PrintStream printStream) {
        printStream.println("priors ");
        printStream.println(priors.toString());
        printStream.println("weights ");
        printStream.println(weights.toString());
    }

    /** Prints the priors and weights to standard output. */
    public void print() {
        print(System.out);
    }

    /** Looks up the weight stored for (label, feature, value) in {@link #weights}. */
    private double weight(Object label, Object feature, Object value) {
        return weights.getCount(new Pair(new Pair(label, feature), value));
    }

    /** Precomputes, for each label, the summed weight of every known feature at value 0. */
    private void initZeros() {
        priorZero = new ClassicCounter();
        for (Object label : labels) {
            double total = 0.0d;
            for (Object feature : features) {
                total += weight(label, feature, ZERO);
            }
            priorZero.setCount(label, total);
        }
    }
}
