package edu.stanford.nlp.CLclassify;

import edu.stanford.nlp.CLling.BasicDatum;
import edu.stanford.nlp.CLling.Datum;
import edu.stanford.nlp.CLling.RVFDatum;
import edu.stanford.nlp.CLstats.ClassicCounter;
import edu.stanford.nlp.CLstats.GeneralizedCounter;
import edu.stanford.nlp.CLutil.FileLines;
import edu.stanford.nlp.CLutil.Index;
import edu.stanford.nlp.CLutil.Pair;
import edu.stanford.nlp.CLutil.ScoredComparator;
import edu.stanford.nlp.CLutil.ScoredObject;
import java.io.BufferedReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/CLclassify/Dataset.class */
public class Dataset extends GeneralDataset {
    /**
     * Running count of lines handed to {@link #svmLightLineToDatum(String)}; it is only read when
     * that method reports a malformed token. It is static and is never reset in this class, so the
     * number keeps growing across successive files.
     */
    private static int line1 = 0;

    /** Creates an empty dataset whose backing arrays start with room for 10 datums. */
    public Dataset() {
        this(10);
    }

    /**
     * Creates an empty dataset that uses caller-supplied indices.
     *
     * @param i initial capacity of the backing arrays
     * @param index index installed as the feature index
     * @param index2 index installed as the label index
     */
    public Dataset(int i, Index index, Index index2) {
        // initialize() installs fresh empty indices; they are replaced by the supplied ones below.
        initialize(i);
        this.labelIndex = index2;
        this.featureIndex = index;
    }

    /**
     * Creates an empty dataset with fresh label and feature indices.
     *
     * @param i initial capacity of the backing arrays
     */
    public Dataset(int i) {
        initialize(i);
    }

    /**
     * Wraps existing arrays as a dataset; the number of datums is taken to be {@code iArr2.length}.
     *
     * @param index label index
     * @param iArr label id of each datum
     * @param index2 feature index
     * @param iArr2 feature ids of each datum, one row per datum
     */
    public Dataset(Index index, int[] iArr, Index index2, int[][] iArr2) {
        this(index, iArr, index2, iArr2, iArr2.length);
    }

    /**
     * Wraps existing arrays as a dataset. The arrays are stored by reference, not copied.
     * Note the argument order (label index first) differs from {@code Dataset(int, Index, Index)},
     * which takes the feature index first.
     *
     * @param index label index
     * @param iArr label id of each datum
     * @param index2 feature index
     * @param iArr2 feature ids of each datum, one row per datum
     * @param i number of datums stored in the arrays
     */
    public Dataset(Index index, int[] iArr, Index index2, int[][] iArr2, int i) {
        this.labelIndex = index;
        this.labels = iArr;
        this.featureIndex = index2;
        this.data = iArr2;
        this.size = i;
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [int[], int[][], java.lang.Object] */
    /* JADX WARN: Type inference failed for: r0v7, types: [int[], int[][], java.lang.Object] */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public Pair<GeneralDataset, GeneralDataset> split(double d) {
        int size = (int) (d * size());
        int size2 = size() - size;
        ?? r0 = new int[size];
        int[] iArr = new int[size];
        ?? r02 = new int[size2];
        int[] iArr2 = new int[size2];
        System.arraycopy(this.data, 0, r0, 0, size);
        System.arraycopy(this.labels, 0, iArr, 0, size);
        System.arraycopy(this.data, size, r02, 0, size2);
        System.arraycopy(this.labels, size, iArr2, 0, size2);
        if (!(this instanceof WeightedDataset)) {
            return new Pair<>(new Dataset(this.labelIndex, iArr2, this.featureIndex, r02, size2), new Dataset(this.labelIndex, iArr, this.featureIndex, r0, size));
        }
        float[] fArr = new float[size2];
        float[] fArr2 = new float[size];
        WeightedDataset weightedDataset = (WeightedDataset) this;
        System.arraycopy(weightedDataset.weights, 0, fArr2, 0, size);
        System.arraycopy(weightedDataset.weights, size, fArr, 0, size2);
        return new Pair<>(new WeightedDataset(this.labelIndex, iArr2, this.featureIndex, r02, size2, fArr), new WeightedDataset(this.labelIndex, iArr, this.featureIndex, r0, size, fArr2));
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [int[], int[][], java.lang.Object] */
    /* JADX WARN: Type inference failed for: r0v6, types: [int[], int[][], java.lang.Object] */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public Pair<GeneralDataset, GeneralDataset> split(int i, int i2) {
        int i3 = i2 - i;
        int size = size() - i3;
        ?? r0 = new int[i3];
        int[] iArr = new int[i3];
        ?? r02 = new int[size];
        int[] iArr2 = new int[size];
        System.arraycopy(this.data, i, r0, 0, i3);
        System.arraycopy(this.labels, i, iArr, 0, i3);
        System.arraycopy(this.data, 0, r02, 0, i);
        System.arraycopy(this.data, i2, r02, i, size() - i2);
        System.arraycopy(this.labels, 0, iArr2, 0, i);
        System.arraycopy(this.labels, i2, iArr2, i, size() - i2);
        if (!(this instanceof WeightedDataset)) {
            return new Pair<>(new Dataset(this.labelIndex, iArr2, this.featureIndex, r02, size), new Dataset(this.labelIndex, iArr, this.featureIndex, r0, i3));
        }
        float[] fArr = new float[size];
        float[] fArr2 = new float[i3];
        WeightedDataset weightedDataset = (WeightedDataset) this;
        System.arraycopy(weightedDataset.weights, i, fArr2, 0, i3);
        System.arraycopy(weightedDataset.weights, 0, fArr, 0, i);
        System.arraycopy(weightedDataset.weights, i2, fArr, i, size() - i2);
        return new Pair<>(new WeightedDataset(this.labelIndex, iArr2, this.featureIndex, r02, size, fArr), new WeightedDataset(this.labelIndex, iArr, this.featureIndex, r0, i3, fArr2));
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [int[], int[][]] */
    public Dataset getRandomSubDataset(double d, int i) {
        int size = (int) (d * size());
        HashSet hashSet = new HashSet();
        Random random = new Random();
        int size2 = size();
        while (hashSet.size() < size) {
            hashSet.add(Integer.valueOf(random.nextInt(size2)));
        }
        ?? r0 = new int[size];
        int[] iArr = new int[size];
        int i2 = 0;
        Iterator it = hashSet.iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            r0[i2] = this.data[intValue];
            iArr[i2] = this.labels[intValue];
            i2++;
        }
        return new Dataset(this.labelIndex, iArr, this.featureIndex, r0);
    }

    /**
     * Returns the matrix of feature values. This class keeps feature ids only, so it has no
     * value matrix to offer.
     *
     * @return always {@code null}
     */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public double[][] getValuesArray() {
        return null;
    }

    /**
     * Reads an SVMLight-format file into a new dataset with fresh feature and label indices.
     *
     * @param str name of the file to read
     * @return the dataset built from the file
     */
    public static Dataset readSVMLightFormat(String str) {
        return readSVMLightFormat(str, new Index(), new Index());
    }

    /**
     * Reads an SVMLight-format file into a new dataset with fresh indices, also recording every
     * raw line that was read.
     *
     * @param str name of the file to read
     * @param list receives each line of the file, in order
     * @return the dataset built from the file
     */
    public static Dataset readSVMLightFormat(String str, List<String> list) {
        return readSVMLightFormat(str, new Index(), new Index(), list);
    }

    /**
     * Reads an SVMLight-format file into a new dataset that uses the supplied indices; the raw
     * lines are not recorded.
     *
     * @param str name of the file to read
     * @param index feature index to use
     * @param index2 label index to use
     * @return the dataset built from the file
     */
    public static Dataset readSVMLightFormat(String str, Index index, Index index2) {
        return readSVMLightFormat(str, index, index2, null);
    }

    /**
     * Reads an SVMLight-format file into a new dataset that uses the supplied indices. Each line
     * is converted with {@link #svmLightLineToDatum(String)} and added to the dataset.
     *
     * @param str name of the file to read
     * @param index feature index to use
     * @param index2 label index to use
     * @param list if non-null, receives each raw line of the file, in order
     * @return the dataset built from the file
     * @throws RuntimeException if reading or parsing fails; the original exception is attached as
     *         the cause
     */
    public static Dataset readSVMLightFormat(String str, Index index, Index index2, List<String> list) {
        try {
            Dataset dataset = new Dataset(10, index, index2);
            Iterator<String> lines = new FileLines(str).iterator();
            while (lines.hasNext()) {
                String line = lines.next();
                if (list != null) {
                    list.add(line);
                }
                dataset.add(svmLightLineToDatum(line));
            }
            return dataset;
        } catch (Exception e) {
            // Keep the failure's cause instead of printing it and throwing an empty exception.
            throw new RuntimeException("Failed to read SVMLight file " + str, e);
        }
    }

    /**
     * Converts one SVMLight-format line ({@code label id:value id:value ...}) into a datum.
     * <p>
     * The first whitespace-separated token is the label. For every following {@code id:value}
     * token, the integer feature {@code id} is added {@code (int) value} times, so the value is
     * treated as a repetition count. A constant {@code "###"} feature is appended to every datum.
     * Each call also advances the shared line counter used in error messages.
     *
     * @param str the line to parse
     * @return a datum with the repeated integer features plus {@code "###"}, labelled with the
     *         first token
     * @throws IllegalArgumentException if a feature token contains no {@code id:value} pair
     * @throws NumberFormatException if a value is not a number or a used id is not an integer
     */
    public static Datum svmLightLineToDatum(String str) {
        line1++;
        String[] tokens = str.split("\\s+");
        // Raw on purpose: the list mixes Integer ids with the String marker added at the end.
        ArrayList features = new ArrayList();
        for (int t = 1; t < tokens.length; t++) {
            String[] idAndValue = tokens[t].split(":");
            if (idAndValue.length != 2) {
                System.err.println("Dataset error: line " + line1);
                if (idAndValue.length < 2) {
                    // There is no value to read; fail clearly rather than with an index error.
                    throw new IllegalArgumentException(
                            "Malformed feature '" + tokens[t] + "' on line " + line1 + ": expected id:value");
                }
                // Tokens with extra ':' parts are still read leniently from their first two parts.
            }
            int repetitions = (int) Double.parseDouble(idAndValue[1]);
            if (repetitions > 0) {
                // Parse the id once, and only when it is actually used.
                Integer featureId = Integer.valueOf(idAndValue[0]);
                for (int r = 0; r < repetitions; r++) {
                    features.add(featureId);
                }
            }
        }
        features.add("###");
        return new BasicDatum(features, tokens[0]);
    }

    /**
     * Counts, for every feature, the number of datums in which it occurs at least once. Repeated
     * occurrences of a feature within one datum count once.
     *
     * @return counter mapping each feature object to its datum count
     */
    public ClassicCounter getFeatureCounter() {
        ClassicCounter featureCounts = new ClassicCounter();
        // Stop at size, not data.length: rows beyond size are unused capacity and hold no datum.
        for (int row = 0; row < this.size; row++) {
            BasicDatum datum = (BasicDatum) getDatum(row);
            for (Object feature : new HashSet<Object>(datum.asFeatures())) {
                featureCounts.incrementCount(feature, 1.0d);
            }
        }
        return featureCounts;
    }

    /**
     * Appends a datum, using its features and its label.
     *
     * @param datum the datum to add
     */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public void add(Datum datum) {
        add(datum.asFeatures(), datum.label());
    }

    /**
     * Appends a datum given as a feature collection and a label.
     *
     * @param collection the datum's features
     * @param obj the datum's label
     */
    public void add(Collection collection, Object obj) {
        // Grow first: addLabel and addFeatures both write at position this.size.
        ensureSize();
        addLabel(obj);
        addFeatures(collection);
        // Only now does the new row become part of the dataset.
        this.size++;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Type inference failed for: r0v13, types: [int[], int[][], java.lang.Object] */
    public void ensureSize() {
        if (this.labels.length == this.size) {
            int[] iArr = new int[this.size * 2];
            System.arraycopy(this.labels, 0, iArr, 0, this.size);
            this.labels = iArr;
            ?? r0 = new int[this.size * 2];
            System.arraycopy(this.data, 0, r0, 0, this.size);
            this.data = r0;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Registers the label in the label index and stores its id for the datum being added (the
     * slot at position {@code size}). Does not advance {@code size}.
     *
     * @param obj the label of the datum being added
     */
    public void addLabel(Object obj) {
        this.labelIndex.add(obj);
        this.labels[this.size] = this.labelIndex.indexOf(obj);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Registers every feature in the feature index and stores the resulting ids as the feature
     * row of the datum being added (the slot at position {@code size}). Features for which the
     * index reports a negative id after the add are left out. Does not advance {@code size}.
     *
     * @param collection the features of the datum being added
     */
    public void addFeatures(Collection collection) {
        int[] buffer = new int[collection.size()];
        int used = 0;
        for (Iterator it = collection.iterator(); it.hasNext(); ) {
            Object feature = it.next();
            this.featureIndex.add(feature);
            int featureId = this.featureIndex.indexOf(feature);
            if (featureId >= 0) {
                buffer[used++] = featureId;
            }
        }
        // Store a row trimmed to the ids that were actually kept.
        this.data[this.size] = Arrays.copyOf(buffer, used);
    }

    /**
     * Resets this dataset to empty: fresh label and feature indices, and backing arrays with room
     * for {@code i} datums.
     *
     * @param i initial capacity of the backing arrays
     */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    protected void initialize(int i) {
        this.labelIndex = new Index();
        this.featureIndex = new Index();
        this.labels = new int[i];
        // One (initially null) row per datum; each row is allocated when its datum is added.
        this.data = new int[i][];
        this.size = 0;
    }

    /**
     * Rebuilds the datum stored at a position by mapping its feature ids and label id back to
     * objects through the indices.
     *
     * @param i position of the datum
     * @return a new {@link BasicDatum} with that datum's features and label
     */
    public Datum getDatum(int i) {
        return new BasicDatum(this.featureIndex.objects(this.data[i]), this.labelIndex.get(this.labels[i]));
    }

    /**
     * Rebuilds the datum stored at a position as a real-valued-feature datum. Each occurrence of a
     * feature in the stored row adds 1.0 to that feature's count.
     *
     * @param i position of the datum
     * @return a new {@link RVFDatum} with per-feature occurrence counts and the datum's label
     */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public RVFDatum getRVFDatum(int i) {
        ClassicCounter featureValues = new ClassicCounter();
        for (Iterator features = this.featureIndex.objects(this.data[i]).iterator(); features.hasNext(); ) {
            featureValues.incrementCount(features.next(), 1.0d);
        }
        return new RVFDatum(featureValues, this.labelIndex.get(this.labels[i]));
    }

    /** Prints the text produced by {@link #toSummaryStatistics()} to standard error. */
    @Override // edu.stanford.nlp.CLclassify.GeneralDataset
    public void summaryStatistics() {
        System.err.println(toSummaryStatistics());
    }

    /**
     * Builds a three-line summary: the number of datums, the number of labels followed by the
     * bracketed, comma-separated label list, and the number of feature types.
     *
     * @return the summary text, each line terminated by {@code "\n"}
     */
    public String toSummaryStatistics() {
        StringBuilder summary = new StringBuilder();
        summary.append("numDatums: ");
        summary.append(this.size);
        summary.append("\n");

        summary.append("numLabels: ");
        summary.append(this.labelIndex.size());
        summary.append(" [");
        String separator = "";
        for (Iterator labelsIt = this.labelIndex.iterator(); labelsIt.hasNext(); ) {
            summary.append(separator);
            summary.append(labelsIt.next());
            separator = ", ";
        }
        summary.append("]\n");

        summary.append("numFeatures Phi(X) types): ");
        summary.append(this.featureIndex.size());
        summary.append("\n");
        return summary.toString();
    }

    /**
     * Drops rare features, using a per-pattern minimum count.
     * <p>
     * Each feature's string form is tested against the patterns in list order. The FIRST pattern
     * that matches the whole string decides: the feature is kept only if its count (from
     * {@code getFeatureCounts()}, defined outside this class) is at least that pattern's
     * threshold. A feature that matches no pattern is always kept. Afterwards the feature index
     * is replaced by one holding only the kept features, and every stored row is rewritten to the
     * new ids with dropped features removed.
     *
     * @param list (pattern, minimum count) pairs, consulted in order
     */
    public void applyFeatureCountThreshold(List<Pair<Pattern, Integer>> list) {
        float[] featureCounts = getFeatureCounts();

        // Decide which features survive.
        Index kept = new Index();
        for (Iterator features = this.featureIndex.iterator(); features.hasNext(); ) {
            Object feature = features.next();
            String featureText = String.valueOf(feature);
            boolean matchedSomePattern = false;
            for (Pair<Pattern, Integer> threshold : list) {
                if (threshold.first().matcher(featureText).matches()) {
                    matchedSomePattern = true;
                    if (featureCounts[this.featureIndex.indexOf(feature)] >= threshold.second.intValue()) {
                        kept.add(feature);
                    }
                    // The first matching pattern is final; without this exit a feature below its
                    // threshold would fall through and be kept anyway.
                    break;
                }
            }
            if (!matchedSomePattern) {
                kept.add(feature);
            }
        }

        // Map every old feature id to its new id (negative when the feature was dropped).
        int[] oldToNew = new int[this.featureIndex.size()];
        for (int oldId = 0; oldId < oldToNew.length; oldId++) {
            oldToNew[oldId] = kept.indexOf(this.featureIndex.get(oldId));
        }

        // Rewrite each stored row with the new ids.
        for (int row = 0; row < this.size; row++) {
            int[] oldRow = this.data[row];
            int[] buffer = new int[oldRow.length];
            int used = 0;
            for (int oldId : oldRow) {
                int newId = oldToNew[oldId];
                if (newId >= 0) {
                    buffer[used++] = newId;
                }
            }
            this.data[row] = Arrays.copyOf(buffer, used);
        }
        this.featureIndex = kept;
    }

    /**
     * Writes the dataset as a dense 0/1 matrix. The header line lists every feature, each preceded
     * by the delimiter. Each following line holds one datum: its label, then for every feature a
     * delimiter and {@code 1} if the datum has that feature or {@code 0} if not.
     *
     * @param printWriter destination of the matrix
     */
    public void printFullFeatureMatrix(PrintWriter printWriter) {
        int numFeatures = this.featureIndex.size();
        for (int col = 0; col < numFeatures; col++) {
            printWriter.print(LinearClassifier.TEXT_SERIALIZATION_DELIMITER + this.featureIndex.get(col));
        }
        printWriter.println();

        // Stop at size, not labels.length: slots beyond size are unused capacity.
        for (int row = 0; row < this.size; row++) {
            // The row's own label, looked up through its label id (not the row number).
            printWriter.print(this.labelIndex.get(this.labels[row]));
            boolean[] present = new boolean[numFeatures];
            for (int featureId : this.data[row]) {
                present[featureId] = true;
            }
            for (int col = 0; col < numFeatures; col++) {
                printWriter.print(LinearClassifier.TEXT_SERIALIZATION_DELIMITER + (present[col] ? "1" : "0"));
            }
            // One datum per line, as in printSparseFeatureMatrix.
            printWriter.println();
        }
    }

    /** Writes the sparse feature matrix to standard output through an auto-flushing writer. */
    public void printSparseFeatureMatrix() {
        printSparseFeatureMatrix(new PrintWriter((OutputStream) System.out, true));
    }

    /**
     * Writes one line per datum: the datum's label followed by each of its features, every
     * feature preceded by the delimiter.
     *
     * @param printWriter destination of the matrix
     */
    public void printSparseFeatureMatrix(PrintWriter printWriter) {
        for (int row = 0; row < this.size; row++) {
            printWriter.print(this.labelIndex.get(this.labels[row]));
            int[] featureIds = this.data[row];
            for (int col = 0; col < featureIds.length; col++) {
                printWriter.print(LinearClassifier.TEXT_SERIALIZATION_DELIMITER + this.featureIndex.get(featureIds[col]));
            }
            printWriter.println();
        }
    }

    /**
     * Small demonstration: builds a three-datum dataset, prints its summary, applies a feature
     * count threshold of 2, trains a linear classifier and prints the predicted class and the
     * class probabilities for a new datum.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        Dataset dataset = new Dataset();
        dataset.add(new BasicDatum(Arrays.asList("fever", "cough", "congestion"), "cold"));
        dataset.add(new BasicDatum(Arrays.asList("fever", "cough", "nausea"), "flu"));
        dataset.add(new BasicDatum(Arrays.asList("cough", "congestion"), "cold"));
        dataset.summaryStatistics();
        // The int overload is not declared in this class; presumably inherited from GeneralDataset.
        dataset.applyFeatureCountThreshold(2);
        LinearClassifier linearClassifier = (LinearClassifier) new LinearClassifierFactory().trainClassifier(dataset);
        // Unlabelled datum to classify.
        BasicDatum basicDatum = new BasicDatum(Arrays.asList("cough", "fever"));
        System.out.println(linearClassifier.classOf(basicDatum));
        System.out.println(linearClassifier.probabilityOf(basicDatum));
    }

    /**
     * Re-expresses every stored label id in terms of another label index and then adopts that
     * index. The label array is trimmed first. Each id becomes whatever {@code indexOf} returns
     * for the label in the new index.
     *
     * @param index the label index to switch to
     */
    public void changeLabelIndex(Index index) {
        int[] trimmed = trimToSize(this.labels);
        this.labels = trimmed;
        for (int pos = 0; pos < trimmed.length; pos++) {
            Object label = this.labelIndex.get(trimmed[pos]);
            trimmed[pos] = index.indexOf(label);
        }
        this.labelIndex = index;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    public void changeFeatureIndex(Index index) {
        this.data = trimToSize(this.data);
        this.labels = trimToSize(this.labels);
        ?? r0 = new int[this.data.length];
        for (int i = 0; i < this.data.length; i++) {
            int[] iArr = new int[this.data[i].length];
            int i2 = 0;
            for (int i3 = 0; i3 < this.data[i].length; i3++) {
                int indexOf = index.indexOf(this.featureIndex.get(this.data[i][i3]));
                if (indexOf >= 0) {
                    int i4 = i2;
                    i2++;
                    iArr[i4] = indexOf;
                }
            }
            r0[i] = new int[i2];
            System.arraycopy(iArr, 0, r0[i], 0, i2);
        }
        this.data = r0;
        this.featureIndex = index;
    }

    /**
     * Keeps only the best-scoring features according to {@link #getInformationGains()}.
     * Features are sorted with {@code ScoredComparator.DESCENDING_COMPARATOR}, the first
     * {@code i} form the new feature index, and every stored row is rewritten to the new ids with
     * unselected features removed.
     *
     * @param i maximum number of features to keep
     */
    public void selectFeaturesBinaryInformationGain(int i) {
        double[] gains = getInformationGains();

        // Pair every feature with its score and order them.
        ArrayList scored = new ArrayList();
        for (int featureId = 0; featureId < gains.length; featureId++) {
            scored.add(new ScoredObject(this.featureIndex.get(featureId), gains[featureId]));
        }
        Collections.sort(scored, ScoredComparator.DESCENDING_COMPARATOR);

        // The leading entries make up the new index.
        Index selected = new Index();
        int keep = Math.min(i, scored.size());
        for (int rank = 0; rank < keep; rank++) {
            selected.add(((ScoredObject) scored.get(rank)).object());
        }

        // Translate each row to the new ids.
        for (int row = 0; row < this.size; row++) {
            int[] oldRow = this.data[row];
            int[] buffer = new int[oldRow.length];
            int used = 0;
            for (int oldId : oldRow) {
                int newId = selected.indexOf(this.featureIndex.get(oldId));
                if (newId != -1) {
                    buffer[used++] = newId;
                }
            }
            this.data[row] = Arrays.copyOf(buffer, used);
        }
        this.featureIndex = selected;
    }

    /**
     * Computes, for every feature, the information gain of its presence/absence with respect to
     * the label: {@code H(label) - H(label | feature present or absent)}, in bits.
     * <p>
     * Side effect: the data and label arrays are trimmed to the dataset size.
     * <p>
     * NOTE(review): this differs numerically from the previous implementation in two ways.
     * (1) The label entropy is now included in the result; before, the array was filled with the
     * entropy and then every entry was overwritten, so the entropy was lost and features present
     * in all or none of the datums (score 0) outranked every other feature (negative score).
     * (2) The number of datums with a label but WITHOUT the feature is now
     * {@code count(label) - count(feature, label)} rather than {@code size - count(feature, label)}.
     *
     * @return array indexed by feature id; 0 for a feature present in every datum or in none
     */
    public double[] getInformationGains() {
        this.data = trimToSize(this.data);
        this.labels = trimToSize(this.labels);

        ClassicCounter featureCounter = new ClassicCounter();   // datums containing each feature
        ClassicCounter labelCounter = new ClassicCounter();     // datums carrying each label
        GeneralizedCounter jointCounter = new GeneralizedCounter(2); // datums per (feature, label)

        for (int row = 0; row < this.labels.length; row++) {
            Object label = this.labelIndex.get(this.labels[row]);
            labelCounter.incrementCount(label);
            // Count each feature once per datum, however often it repeats in the row.
            HashSet<Integer> seen = new HashSet<>();
            for (int featureId : this.data[row]) {
                if (seen.add(featureId)) {
                    Object feature = this.featureIndex.get(featureId);
                    featureCounter.incrementCount(feature);
                    jointCounter.incrementCount2D(feature, label, 1.0d);
                }
            }
        }

        final double total = size();
        final double log2 = Math.log(2.0d);

        // Entropy of the label distribution.
        double entropy = 0.0d;
        for (int labelId = 0; labelId < this.labelIndex.size(); labelId++) {
            double pLabel = labelCounter.getCount(this.labelIndex.get(labelId)) / total;
            if (pLabel > 0.0d) { // skip unused labels: 0 * log(0) would be NaN
                entropy -= pLabel * (Math.log(pLabel) / log2);
            }
        }

        double[] gains = new double[this.featureIndex.size()];
        for (int featureId = 0; featureId < gains.length; featureId++) {
            Object feature = this.featureIndex.get(featureId);
            double featureCount = featureCounter.getCount(feature);
            double notFeatureCount = total - featureCount;
            if (featureCount == 0.0d || notFeatureCount == 0.0d) {
                // A feature that is always or never present tells nothing about the label.
                gains[featureId] = 0.0d;
                continue;
            }
            double pFeature = featureCount / total;
            double pNotFeature = 1.0d - pFeature;

            double sumFeature = 0.0d;    // sum over labels of p(label|feature) * log2 p(label|feature)
            double sumNotFeature = 0.0d; // same, conditioned on the feature being absent
            for (int labelId = 0; labelId < this.labelIndex.size(); labelId++) {
                Object label = this.labelIndex.get(labelId);
                double featureLabelCount = jointCounter.getCount(feature, label);
                double notFeatureLabelCount = labelCounter.getCount(label) - featureLabelCount;
                if (featureLabelCount > 0.0d) {
                    double p = featureLabelCount / featureCount;
                    sumFeature += p * (Math.log(p) / log2);
                }
                if (notFeatureLabelCount > 0.0d) {
                    double pNot = notFeatureLabelCount / notFeatureCount;
                    sumNotFeature += pNot * (Math.log(pNot) / log2);
                }
            }
            // The two sums are negated conditional entropies, hence the addition.
            gains[featureId] = entropy + (pFeature * sumFeature) + (pNotFeature * sumNotFeature);
        }
        return gains;
    }

    /** Returns a short description of the form {@code "Dataset of size N"}. */
    public String toString() {
        return "Dataset of size " + this.size;
    }

    /**
     * Builds a three-line summary with the number of data points, of active feature tokens and
     * of active feature types.
     *
     * @return the summary text (previously this returned the writer's identity string, such as
     *         {@code java.io.PrintWriter@1b2c3d}, instead of the text written to it)
     */
    public String toSummaryString() {
        StringWriter buffer = new StringWriter();
        PrintWriter printWriter = new PrintWriter(buffer);
        printWriter.println("Number of data points: " + size());
        printWriter.println("Number of active feature tokens: " + numFeatureTokens());
        printWriter.println("Number of active feature types:" + numFeatureTypes());
        printWriter.flush();
        // The text lives in the StringWriter; PrintWriter does not override toString().
        return buffer.toString();
    }

    /**
     * Writes one SVMLight-format line: the class number, a space, then {@code id:count} entries
     * in ascending id order, each followed by a space. Ids are written as key + 1.
     *
     * @param printWriter destination of the line
     * @param classicCounter feature counts keyed by integer feature id
     * @param i class number written at the start of the line
     */
    public static void printSVMLightFormat(PrintWriter printWriter, ClassicCounter<Integer> classicCounter, int i) {
        Integer[] sortedKeys = (Integer[]) classicCounter.keySet().toArray(new Integer[0]);
        Arrays.sort(sortedKeys);
        StringBuilder entries = new StringBuilder();
        for (int k = 0; k < sortedKeys.length; k++) {
            Integer key = sortedKeys[k];
            entries.append(key.intValue() + 1)
                    .append(":")
                    .append(classicCounter.getCount(key))
                    .append(" ");
        }
        printWriter.println(i + " " + entries);
    }
}
