package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLclassify.LogPrior;
import edu.stanford.nlp.CLoptimization.AbstractCachingDiffFunction;
import java.util.Arrays;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/BiasedLogisticObjectiveFunction.class */
/**
 * Logistic-regression objective in which datums not labelled {@code 1} are treated as noisy.
 *
 * <p>For a datum whose active features are {@code data[i]}, let {@code z} be the sum of the
 * weights of those features and {@code g = 1 / (1 + exp(-z))}. In {@link #calculate(double[])},
 * datums labelled {@code 1} contribute {@code -log g}. All other datums contribute
 * {@code -log((1 - probCorrect) * g + probCorrect * (1 - g))}. The value and gradient of the
 * supplied {@link LogPrior} are added at the end.
 *
 * <p>{@link #calculate(double[])} supports only binary features without data weights.
 */
public class BiasedLogisticObjectiveFunction extends AbstractCachingDiffFunction {
    /** Value of {@link #probCorrect} used by constructors that do not take it explicitly. */
    private static final double DEFAULT_PROB_CORRECT = 0.7d;

    /** Number of features, i.e. the length of the weight vector. */
    private final int numFeatures;
    /** Active feature indices of each datum. */
    private final int[][] data;
    /**
     * Per-datum feature values, or {@code null} for binary features. Each {@code dataValues[i]} is
     * assumed to be parallel to {@code data[i]} ({@code dataValues[i][f]} is the value of feature
     * {@code data[i][f]}).
     */
    private final double[][] dataValues;
    /** Label of each datum. */
    private final int[] labels;
    /** Optional per-datum weights, or {@code null}; only honoured by {@link #calculateRVF(double[])}. */
    protected float[] dataweights;
    /** Prior whose value and gradient are added to the objective. */
    private final LogPrior prior;
    /** Mixing weight used for datums not labelled {@code 1}; see the class comment. */
    double probCorrect;

    /**
     * Returns the dimension of the weight vector this function is defined over.
     *
     * @return the number of features
     */
    @Override // edu.stanford.nlp.CLoptimization.AbstractCachingDiffFunction, edu.stanford.nlp.CLoptimization.Function
    public int domainDimension() {
        return this.numFeatures;
    }

    /**
     * Computes the objective value and its gradient at {@code x}, storing them in {@code value}
     * and {@code derivative}.
     *
     * @param x current weight vector, indexed by feature
     * @throws UnsupportedOperationException if real-valued features or data weights were supplied
     */
    @Override // edu.stanford.nlp.CLoptimization.AbstractCachingDiffFunction
    protected void calculate(double[] x) {
        // Validate up front so a rejected call does not leave value/derivative half-updated.
        if (this.dataValues != null) {
            throw new UnsupportedOperationException(
                    "BiasedLogisticObjectiveFunction does not support real-valued features");
        }
        if (this.dataweights != null) {
            throw new UnsupportedOperationException(
                    "BiasedLogisticObjectiveFunction does not support data weights");
        }
        final double p = this.probCorrect;
        this.value = 0.0d;
        Arrays.fill(this.derivative, 0.0d);
        for (int i = 0; i < this.data.length; i++) {
            int[] features = this.data[i];
            double z = 0.0d;
            for (int feature : features) {
                z += x[feature];
            }
            double g = sigmoid(z);
            double increment;
            if (this.labels[i] == 1) {
                // -log(sigmoid(z)) == log(1 + exp(-z)); softplus stays finite where exp(-z) overflows.
                this.value += softplus(-z);
                increment = g - 1.0d;
            } else {
                double e = (1.0d - p) * g + p * (1.0d - g);
                this.value -= Math.log(e);
                // d/dz of -log(e), using sigmoid'(z) = g * (1 - g).
                increment = -(g * (1.0d - g) * (1.0d - 2.0d * p)) / e;
            }
            for (int feature : features) {
                this.derivative[feature] += increment;
            }
        }
        this.value += this.prior.compute(x, this.derivative);
    }

    /**
     * Computes the standard (unbiased) logistic loss and its gradient at {@code x} for real-valued
     * features, scaling each datum's contribution by {@link #dataweights} when those are present.
     *
     * <p>Unlike {@link #calculate(double[])}, this does not use {@link #probCorrect}. Datums
     * labelled {@code 0} are treated as negative and all others as positive. It is not called by
     * {@code calculate}, which rejects real-valued data.
     *
     * @param x current weight vector, indexed by feature
     */
    protected void calculateRVF(double[] x) {
        this.value = 0.0d;
        Arrays.fill(this.derivative, 0.0d);
        for (int i = 0; i < this.data.length; i++) {
            int[] features = this.data[i];
            double[] values = this.dataValues[i];
            double z = 0.0d;
            for (int f = 0; f < features.length; f++) {
                // values is parallel to features: values[f] belongs to feature features[f].
                z += x[features[f]] * values[f];
            }
            double loss;
            double increment;
            if (this.labels[i] == 0) {
                loss = softplus(z); // -log(1 - sigmoid(z))
                increment = sigmoid(z);
            } else {
                loss = softplus(-z); // -log(sigmoid(z))
                // sigmoid(z) - 1, written to avoid cancellation for large positive z.
                increment = -1.0d / (1.0d + Math.exp(z));
            }
            if (this.dataweights != null) {
                loss *= this.dataweights[i];
                increment *= this.dataweights[i];
            }
            this.value += loss;
            for (int f = 0; f < features.length; f++) {
                this.derivative[features[f]] += values[f] * increment;
            }
        }
        this.value += this.prior.compute(x, this.derivative);
    }

    /** Logistic function {@code 1 / (1 + exp(-t))}. */
    private static double sigmoid(double t) {
        return 1.0d / (1.0d + Math.exp(-t));
    }

    /** Computes {@code log(1 + exp(t))} without overflowing {@code exp} for large {@code |t|}. */
    private static double softplus(double t) {
        return t > 0.0d ? t + Math.log1p(Math.exp(-t)) : Math.log1p(Math.exp(t));
    }

    /** Binary features, no data weights, quadratic prior. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, int[] labels) {
        this(numFeatures, data, labels, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /** Binary features, no data weights, the given prior. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, int[] labels, LogPrior prior) {
        this(numFeatures, data, labels, prior, (float[]) null);
    }

    /** Binary features, the given data weights, quadratic prior. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, int[] labels, float[] dataweights) {
        this(numFeatures, data, labels, new LogPrior(LogPrior.LogPriorType.QUADRATIC), dataweights);
    }

    /** Binary features, the given prior and data weights. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, int[] labels, LogPrior prior, float[] dataweights) {
        this(numFeatures, data, (double[][]) null, labels, prior, dataweights);
    }

    /** Real-valued features, no data weights, quadratic prior. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, double[][] dataValues, int[] labels) {
        this(numFeatures, data, dataValues, labels, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /** Real-valued features, no data weights, the given prior. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, double[][] dataValues, int[] labels, LogPrior prior) {
        this(numFeatures, data, dataValues, labels, prior, null);
    }

    /** Fully specified except {@code probCorrect}, which defaults to {@code 0.7}. */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, double[][] dataValues, int[] labels, LogPrior prior, float[] dataweights) {
        this(numFeatures, data, dataValues, labels, prior, dataweights, DEFAULT_PROB_CORRECT);
    }

    /**
     * Creates the objective with every parameter given explicitly.
     *
     * @param numFeatures number of features (length of the weight vector)
     * @param data active feature indices of each datum
     * @param dataValues per-datum feature values parallel to {@code data}, or {@code null}
     * @param labels label of each datum
     * @param prior prior added to the objective
     * @param dataweights per-datum weights, or {@code null}
     * @param probCorrect mixing weight for datums not labelled {@code 1}, in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code probCorrect} is NaN or outside {@code [0, 1]}
     */
    public BiasedLogisticObjectiveFunction(int numFeatures, int[][] data, double[][] dataValues, int[] labels, LogPrior prior, float[] dataweights, double probCorrect) {
        if (!(probCorrect >= 0.0d && probCorrect <= 1.0d)) {
            throw new IllegalArgumentException("probCorrect must be in [0, 1], got " + probCorrect);
        }
        this.numFeatures = numFeatures;
        this.data = data;
        this.dataValues = dataValues;
        this.labels = labels;
        this.prior = prior;
        this.dataweights = dataweights;
        this.probCorrect = probCorrect;
    }
}
