package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLutil.Function;
import edu.stanford.nlp.CLutil.Pair;
import edu.stanford.nlp.CLutil.Triple;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/CrossValidator.class */
public class CrossValidator {
    /** Dataset that is re-split on every fold. */
    private final GeneralDataset originalTrainData;
    /** Number of folds (the k in k-fold cross validation); always positive. */
    private final int kfold;
    /** One caller-owned state slot per fold, returned alongside that fold's split. */
    private final SavedState[] savedStates;

    /**
     * Iterates over the k folds, yielding for fold {@code j} the pair produced by
     * {@code originalTrainData.split(start_j, end_j)} plus that fold's {@link SavedState}.
     * NOTE(review): presumably the first element is the training portion and the second the
     * held-out range [start_j, end_j) — confirm against {@code GeneralDataset.split}.
     */
    public class CrossValidationIterator implements Iterator<Triple<GeneralDataset, GeneralDataset, SavedState>> {
        /** Index of the next fold to return. */
        int iter = 0;

        CrossValidationIterator() {
        }

        /** Returns {@code true} while fewer than {@code kfold} folds have been returned. */
        @Override
        public boolean hasNext() {
            return iter < kfold;
        }

        /**
         * Not supported.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void remove() {
            throw new UnsupportedOperationException("CrossValidationIterator doesn't support remove()");
        }

        /**
         * Returns the next fold's split together with its saved-state slot.
         *
         * @return (first part of the split, second part of the split, this fold's state)
         * @throws NoSuchElementException if all {@code kfold} folds have already been returned
         */
        @Override
        public Triple<GeneralDataset, GeneralDataset, SavedState> next() {
            if (!hasNext()) {
                // The Iterator contract requires an exception here rather than a null return.
                throw new NoSuchElementException("All " + kfold + " folds have already been returned");
            }
            int size = originalTrainData.size();
            int start = foldBoundary(size, iter, kfold);
            int end = foldBoundary(size, iter + 1, kfold);
            Pair<GeneralDataset, GeneralDataset> split = originalTrainData.split(start, end);
            SavedState state = savedStates[iter];
            iter++;
            return new Triple<>(split.first(), split.second(), state);
        }
    }

    /**
     * Opaque per-fold state slot. The same instance is returned for a given fold on every
     * iteration, so callers can stash data in {@link #state} and see it again later.
     */
    public static class SavedState {
        public Object state;
    }

    /**
     * Creates a 5-fold cross validator over {@code generalDataset}.
     *
     * @param generalDataset data to split; must not be null
     */
    public CrossValidator(GeneralDataset generalDataset) {
        this(generalDataset, 5);
    }

    /**
     * Creates a k-fold cross validator over {@code generalDataset}.
     *
     * @param generalDataset data to split; must not be null
     * @param kfold number of folds; must be positive
     * @throws NullPointerException if {@code generalDataset} is null
     * @throws IllegalArgumentException if {@code kfold <= 0}
     */
    public CrossValidator(GeneralDataset generalDataset, int kfold) {
        if (kfold <= 0) {
            throw new IllegalArgumentException("kfold must be positive, got " + kfold);
        }
        this.originalTrainData = Objects.requireNonNull(generalDataset, "generalDataset");
        this.kfold = kfold;
        this.savedStates = new SavedState[kfold];
        for (int fold = 0; fold < kfold; fold++) {
            this.savedStates[fold] = new SavedState();
        }
    }

    /**
     * Index at which fold {@code fold} starts, i.e. {@code floor(size * fold / kfold)}.
     * The product is computed in {@code long} so large datasets do not overflow {@code int};
     * the result never exceeds {@code size}, so narrowing back to {@code int} is safe.
     */
    private static int foldBoundary(int size, int fold, int kfold) {
        return (int) ((long) size * fold / kfold);
    }

    /** Returns a fresh iterator over all {@code kfold} folds. */
    private Iterator<Triple<GeneralDataset, GeneralDataset, SavedState>> iterator() {
        return new CrossValidationIterator();
    }

    /**
     * Applies {@code function} to every fold and returns the arithmetic mean of its results.
     *
     * @param function evaluated once per fold on (first split part, second split part, fold state);
     *                 must not return null (unboxing would throw)
     * @return sum of the per-fold results divided by {@code kfold}
     */
    public double computeAverage(Function<Triple<GeneralDataset, GeneralDataset, SavedState>, Double> function) {
        double total = 0.0d;
        Iterator<Triple<GeneralDataset, GeneralDataset, SavedState>> folds = iterator();
        while (folds.hasNext()) {
            total += function.apply(folds.next()).doubleValue();
        }
        return total / this.kfold;
    }

    /**
     * Smoke-test entry point: reads an SVMLight-format file and materializes the first fold.
     *
     * @param strArr {@code strArr[0]} is the path of the SVMLight file
     */
    public static void main(String[] strArr) {
        if (strArr.length < 1) {
            System.err.println("Usage: CrossValidator <svmLightFile>");
            return;
        }
        Iterator<Triple<GeneralDataset, GeneralDataset, SavedState>> it = new CrossValidator(Dataset.readSVMLightFormat(strArr[0])).iterator();
        if (it.hasNext()) {
            it.next();
        }
    }
}
