package org.dkpro.tc.examples.deeplearning.dl4j.document;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/dkpro/tc/examples/deeplearning/dl4j/document/NewsIterator.class */
public class NewsIterator implements DataSetIterator {
    private final WordVectors wordVectors;
    private final int batchSize;
    private final int vectorSize;
    private int maxLength;
    private final String dataDirectory;
    private final List<Pair<String, List<String>>> categoryData;
    private int cursor;
    private int totalNews;
    private int newsPosition;
    private final List<String> labels;
    private int currCategory;

    /* loaded from: input_file:org/dkpro/tc/examples/deeplearning/dl4j/document/NewsIterator$Builder.class */
    public static class Builder {
        private String dataDirectory;
        private WordVectors wordVectors;
        private int batchSize;

        public Builder dataDirectory(String str) {
            this.dataDirectory = str;
            return this;
        }

        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        public Builder batchSize(int i) {
            this.batchSize = i;
            return this;
        }

        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            return this;
        }

        public NewsIterator build() throws Exception {
            return new NewsIterator(this.dataDirectory, this.wordVectors, this.batchSize);
        }

        public String toString() {
            return "org.deeplearning4j.examples.recurrent.ProcessNews.NewsIterator.Builder(dataDirectory=" + this.dataDirectory + ", wordVectors=" + this.wordVectors + ", batchSize=" + this.batchSize + ")";
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    private NewsIterator(String str, WordVectors wordVectors, int i) throws Exception {
        this.categoryData = new ArrayList();
        this.cursor = 0;
        this.totalNews = 0;
        this.newsPosition = 0;
        this.currCategory = 0;
        this.dataDirectory = str;
        this.batchSize = i;
        this.vectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
        this.wordVectors = wordVectors;
        populateData();
        this.labels = new ArrayList();
        for (int i2 = 0; i2 < this.categoryData.size(); i2++) {
            this.labels.add(this.categoryData.get(i2).getKey());
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public DataSet next(int i) {
        if (this.cursor >= this.totalNews) {
            throw new NoSuchElementException();
        }
        try {
            return nextDataSet(i);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private DataSet nextDataSet(int i) throws IOException {
        ArrayList arrayList = new ArrayList(i);
        int[] iArr = new int[i];
        int i2 = 0;
        while (i2 < i && this.cursor < totalExamples()) {
            if (this.currCategory < this.categoryData.size()) {
                arrayList.add(((List) this.categoryData.get(this.currCategory).getValue()).get(this.newsPosition));
                iArr[i2] = Integer.parseInt((String) this.categoryData.get(this.currCategory).getKey());
                this.currCategory++;
                this.cursor++;
            } else {
                this.currCategory = 0;
                this.newsPosition++;
                i2--;
            }
            i2++;
        }
        ArrayList arrayList2 = new ArrayList(arrayList.size());
        this.maxLength = 0;
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            List asList = Arrays.asList(((String) it.next()).replaceAll(" 0", "").split(" "));
            this.maxLength = this.maxLength < asList.size() ? asList.size() : this.maxLength;
            arrayList2.add(asList);
        }
        INDArray create = Nd4j.create(new int[]{arrayList.size(), this.vectorSize, this.maxLength});
        INDArray create2 = Nd4j.create(new int[]{arrayList.size(), this.categoryData.size(), this.maxLength});
        INDArray zeros = Nd4j.zeros(arrayList.size(), this.maxLength);
        INDArray zeros2 = Nd4j.zeros(arrayList.size(), this.maxLength);
        int[] iArr2 = new int[2];
        for (int i3 = 0; i3 < arrayList.size(); i3++) {
            List list = (List) arrayList2.get(i3);
            iArr2[0] = i3;
            for (int i4 = 0; i4 < list.size() && i4 < this.maxLength; i4++) {
                create.put(new INDArrayIndex[]{NDArrayIndex.point(i3), NDArrayIndex.all(), NDArrayIndex.point(i4)}, this.wordVectors.getWordVectorMatrix((String) list.get(i4)));
                iArr2[1] = i4;
                zeros.putScalar(iArr2, 1.0d);
            }
            int i5 = iArr[i3] - 1;
            int min = Math.min(list.size(), this.maxLength);
            create2.putScalar(new int[]{i3, i5, min - 1}, 1.0d);
            zeros2.putScalar(new int[]{i3, min - 1}, 1.0d);
        }
        return new DataSet(create, create2, zeros, zeros2);
    }

    public INDArray loadFeaturesFromFile(File file, int i) throws IOException {
        return loadFeaturesFromString(FileUtils.readFileToString(file), i);
    }

    public INDArray loadFeaturesFromString(String str, int i) {
        ArrayList<String> arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (String str2 : arrayList) {
            if (this.wordVectors.hasWord(str2)) {
                arrayList2.add(str2);
            }
        }
        INDArray create = Nd4j.create(new int[]{1, this.vectorSize, Math.max(i, arrayList2.size())});
        for (int i2 = 0; i2 < arrayList.size() && i2 < i; i2++) {
            create.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i2)}, this.wordVectors.getWordVectorMatrix((String) arrayList.get(i2)));
        }
        return create;
    }

    private void populateData() throws Exception {
        List readLines = FileUtils.readLines(new File(this.dataDirectory, "instanceVectors.txt"), "utf-8");
        String replaceAll = ((String) FileUtils.readLines(new File(this.dataDirectory, "outcomeVectors.txt"), "utf-8").get(0)).replaceAll(" ", "");
        HashMap hashMap = new HashMap();
        for (int i = 0; i < readLines.size(); i++) {
            String str = (String) readLines.get(i);
            String substring = str.substring(1, str.length() - 1);
            String str2 = replaceAll.charAt(i) + "";
            List list = (List) hashMap.get(str2);
            if (list == null) {
                list = new ArrayList();
            }
            list.add(substring);
            hashMap.put(str2, list);
            this.totalNews++;
        }
        for (Map.Entry entry : hashMap.entrySet()) {
            this.categoryData.add(Pair.of(entry.getKey(), entry.getValue()));
        }
    }

    public int totalExamples() {
        return this.totalNews;
    }

    public int inputColumns() {
        return this.vectorSize;
    }

    public int totalOutcomes() {
        return this.categoryData.size();
    }

    public void reset() {
        this.cursor = 0;
        this.newsPosition = 0;
        this.currCategory = 0;
    }

    public boolean resetSupported() {
        return true;
    }

    public boolean asyncSupported() {
        return true;
    }

    public int batch() {
        return this.batchSize;
    }

    public int cursor() {
        return this.cursor;
    }

    public int numExamples() {
        return totalExamples();
    }

    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        throw new UnsupportedOperationException();
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public boolean hasNext() {
        return this.cursor < numExamples();
    }

    /* renamed from: next, reason: merged with bridge method [inline-methods] */
    public DataSet m0next() {
        return next(this.batchSize);
    }

    public void remove() {
    }

    public DataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException("Not implemented");
    }

    public int getMaxLength() {
        return this.maxLength;
    }
}
