package de.tudarmstadt.ukp.inception.recommendation.api.evaluation;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Collector;

/* loaded from: input_file:de/tudarmstadt/ukp/inception/recommendation/api/evaluation/EvaluationResult.class */
/**
 * Result of evaluating a recommender: a {@link ConfusionMatrix} filled from (predicted, gold)
 * label pairs, plus the training/test set sizes and training-data ratio the evaluation used.
 * <p>
 * Labels in {@code ignoreLabels} are not counted by {@link #getNumOfLabels()} and always score
 * {@code 0} as a class in precision/recall. For accuracy, their true positives are not counted,
 * and matrix entries whose gold label is ignored are subtracted from the denominator.
 * <p>
 * NOTE(review): the metric code treats {@code ConfusionMatrix.getEntryCount(a, b)} as
 * {@code (predicted, gold)}, mirroring the argument order the collector passes to
 * {@code incrementCounts} — confirm against {@link ConfusionMatrix}.
 */
public class EvaluationResult implements Serializable {
    private static final long serialVersionUID = 5842125748342833451L;
    private final int trainingSetSize;
    private final int testSetSize;
    private final double trainingDataRatio;
    private boolean skippedEvaluation;
    private String errorMsg;
    /** Labels excluded from all metric computations; insertion-ordered. */
    private final Set<String> ignoreLabels;
    private ConfusionMatrix confusionMatrix;

    /**
     * {@link Collector} that accumulates {@link LabelPair}s into a {@link ConfusionMatrix} and
     * wraps the final matrix in an {@link EvaluationResult}.
     */
    public static class EvaluationResultCollector
        implements Collector<LabelPair, ConfusionMatrix, EvaluationResult> {
        private final Set<String> ignoreLabels;
        private final int testSize;
        private final int trainSize;
        private final double trainDataRatio;

        /**
         * @param trainSize size of the training set recorded in the result
         * @param testSize size of the test set recorded in the result
         * @param trainDataRatio training-data ratio recorded in the result
         * @param ignoreLabels labels to exclude from the metrics; {@code null} means none
         */
        public EvaluationResultCollector(int trainSize, int testSize, double trainDataRatio,
                String... ignoreLabels) {
            this.ignoreLabels = new LinkedHashSet<>();
            this.testSize = testSize;
            this.trainSize = trainSize;
            this.trainDataRatio = trainDataRatio;
            if (ignoreLabels != null) {
                Collections.addAll(this.ignoreLabels, ignoreLabels);
            }
        }

        /** Creates a collector with zero sizes, zero ratio and no ignored labels. */
        public EvaluationResultCollector() {
            this(0, 0, 0.0d);
        }

        @Override
        public Supplier<ConfusionMatrix> supplier() {
            return ConfusionMatrix::new;
        }

        @Override
        public BiConsumer<ConfusionMatrix, LabelPair> accumulator() {
            return (matrix, pair) -> matrix.incrementCounts(pair.getPredictedLabel(),
                    pair.getGoldLabel());
        }

        @Override
        public BinaryOperator<ConfusionMatrix> combiner() {
            // Merge the right partial matrix into the left one and reuse the left instance.
            return (left, right) -> {
                left.addMatrix(right);
                return left;
            };
        }

        @Override
        public Function<ConfusionMatrix, EvaluationResult> finisher() {
            return matrix -> new EvaluationResult(matrix, trainSize, testSize, trainDataRatio,
                    ignoreLabels);
        }

        @Override
        public Set<Collector.Characteristics> characteristics() {
            return Collections.emptySet();
        }
    }

    /** Creates an empty result: fresh matrix, zero sizes, zero ratio, no ignored labels. */
    public EvaluationResult() {
        this(0, 0, 0.0d);
    }

    /**
     * @param confusionMatrix the matrix the metrics are computed from
     * @param trainingSetSize size of the training set
     * @param testSetSize size of the test set
     * @param trainingDataRatio ratio of data used for training
     * @param ignoreLabels labels to exclude from the metrics (copied); {@code null} means none
     */
    public EvaluationResult(ConfusionMatrix confusionMatrix, int trainingSetSize, int testSetSize,
            double trainingDataRatio, Set<String> ignoreLabels) {
        this.ignoreLabels = new LinkedHashSet<>();
        if (ignoreLabels != null) {
            this.ignoreLabels.addAll(ignoreLabels);
        }
        this.confusionMatrix = confusionMatrix;
        this.trainingSetSize = trainingSetSize;
        this.testSetSize = testSetSize;
        this.trainingDataRatio = trainingDataRatio;
    }

    /**
     * Creates a result with a fresh, empty matrix and no ignored labels.
     *
     * @param trainingSetSize size of the training set
     * @param testSetSize size of the test set
     * @param trainingDataRatio ratio of data used for training
     */
    public EvaluationResult(int trainingSetSize, int testSetSize, double trainingDataRatio) {
        this(new ConfusionMatrix(), trainingSetSize, testSetSize, trainingDataRatio,
                Collections.emptySet());
    }

    /**
     * @return the number of labels in the confusion matrix that are not ignored
     */
    public int getNumOfLabels() {
        Set<String> labels = confusionMatrix.getLabels();
        if (ignoreLabels.isEmpty()) {
            return labels.size();
        }
        return Math.toIntExact(labels.stream().filter(label -> !ignoreLabels.contains(label)).count());
    }

    /**
     * Computes accuracy as the sum of non-ignored diagonal entries divided by the matrix total
     * minus the entries whose gold label is ignored.
     *
     * @return the accuracy, or {@code 0.0} when the adjusted total is not positive
     */
    public double computeAccuracyScore() {
        double truePositives = 0.0d;
        double ignoredAsGold = 0.0d;
        for (String label : confusionMatrix.getLabels()) {
            if (!ignoreLabels.contains(label)) {
                truePositives += confusionMatrix.getEntryCount(label, label);
            }
            ignoredAsGold += countIgnoreLabelsAsGold(label);
        }
        double total = confusionMatrix.getTotal() - ignoredAsGold;
        return total > 0.0d ? truePositives / total : 0.0d;
    }

    /**
     * Sums the matrix entries for {@code label} (first argument) across every ignored label
     * (second argument).
     */
    private double countIgnoreLabelsAsGold(String label) {
        double count = 0.0d;
        // One iterator for the whole loop. The previous form created a new iterator on every
        // loop test and referenced an undeclared variable.
        for (String ignoreLabel : ignoreLabels) {
            count += confusionMatrix.getEntryCount(label, ignoreLabel);
        }
        return count;
    }

    /**
     * @return macro-averaged precision over the non-ignored labels
     */
    public double computePrecisionScore() {
        // Denominator for a label: its row entries, skipping ignored labels in the second slot.
        return calcMetricAverage((label, other) -> ignoreLabels.contains(other) ? 0.0d
                : confusionMatrix.getEntryCount(label, other));
    }

    /**
     * @return macro-averaged recall over the non-ignored labels
     */
    public double computeRecallScore() {
        // Denominator for a label: its column entries; ignored labels contribute nothing.
        return calcMetricAverage((label, other) -> ignoreLabels.contains(label) ? 0.0d
                : confusionMatrix.getEntryCount(other, label));
    }

    /**
     * Averages a per-class metric over the non-ignored labels. For each label, the class metric
     * is its diagonal count divided by the sum of {@code countFunction(label, other)} over all
     * labels.
     */
    private double calcMetricAverage(ToDoubleBiFunction<String, String> countFunction) {
        int numOfLabels = getNumOfLabels();
        if (numOfLabels <= 0) {
            return 0.0d;
        }
        Set<String> labels = confusionMatrix.getLabels();
        double sum = 0.0d;
        for (String label : labels) {
            double truePositives = ignoreLabels.contains(label) ? 0.0d
                    : confusionMatrix.getEntryCount(label, label);
            double denominator = 0.0d;
            for (String other : labels) {
                denominator += countFunction.applyAsDouble(label, other);
            }
            sum += calcClassMetric(label, truePositives, denominator);
        }
        return sum / numOfLabels;
    }

    /** Returns {@code numerator / denominator}, or {@code 0} for ignored labels or empty classes. */
    private double calcClassMetric(String label, double numerator, double denominator) {
        if (denominator > 0.0d && !ignoreLabels.contains(label)) {
            return numerator / denominator;
        }
        return 0.0d;
    }

    /**
     * @return the harmonic mean of macro precision and macro recall, or {@code 0.0} if both are 0
     */
    public double computeF1Score() {
        double precision = computePrecisionScore();
        double recall = computeRecallScore();
        if (precision > 0.0d || recall > 0.0d) {
            return (2.0d * precision * recall) / (precision + recall);
        }
        return 0.0d;
    }

    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    public int getTestSetSize() {
        return testSetSize;
    }

    public double getTrainDataRatio() {
        return trainingDataRatio;
    }

    public void setEvaluationSkipped(boolean skipped) {
        this.skippedEvaluation = skipped;
    }

    public boolean isEvaluationSkipped() {
        return skippedEvaluation;
    }

    public Optional<String> getErrorMsg() {
        return Optional.ofNullable(errorMsg);
    }

    public void setErrorMsg(String errorMsg) {
        this.errorMsg = errorMsg;
    }

    public void setConfusionMatrix(ConfusionMatrix confusionMatrix) {
        this.confusionMatrix = confusionMatrix;
    }

    /** @return a collector with zero sizes, zero ratio and no ignored labels */
    public static EvaluationResultCollector collector() {
        return new EvaluationResultCollector();
    }

    /** @return an empty result flagged as skipped */
    public static EvaluationResult skipped() {
        EvaluationResult result = new EvaluationResult();
        result.setEvaluationSkipped(true);
        return result;
    }

    /**
     * @param trainSize size of the training set
     * @param testSize size of the test set
     * @param trainDataRatio ratio of data used for training
     * @param ignoreLabels labels to exclude from the metrics
     * @return a collector producing results with the given metadata
     */
    public static EvaluationResultCollector collector(int trainSize, int testSize,
            double trainDataRatio, String... ignoreLabels) {
        return new EvaluationResultCollector(trainSize, testSize, trainDataRatio, ignoreLabels);
    }
}
