package org.dkpro.tc.examples.deeplearning.dl4j.sequence;

import it.unimi.dsi.fastutil.objects.Object2IntLinkedOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.dkpro.tc.examples.deeplearning.dl4j.sequence.BinaryWordVectorSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Turns tokenized, tagged sentences into an ND4J {@link DataSet} for sequence labelling.
 *
 * <p>Each instance owns a tag-to-index map. Tags are assigned consecutive indices in the order
 * they are first registered (via {@link #Vectorize(String[])} or while vectorizing), so the
 * position of a tag in {@link #getTagset()} equals its one-hot column in the label tensor.
 * {@link #vectorize} mutates this map without any synchronization.
 */
public class Vectorize {
    /** Tag to label-column index; the linked map keeps insertion order, which equals index order. */
    private final Object2IntMap<String> tagset;

    /** Creates a vectorizer with an empty tagset; tags are registered as they are encountered. */
    public Vectorize() {
        this.tagset = new Object2IntLinkedOpenHashMap<>();
    }

    /**
     * Creates a vectorizer whose tagset is pre-seeded with {@code tags}, e.g. an array obtained
     * from {@link #getTagset()}.
     *
     * <p>Tags receive indices in array order. A repeated tag keeps the index of its first
     * occurrence, so indices stay dense and never collide with tags registered later.
     *
     * @param tags the tags to register, in index order
     */
    public Vectorize(String[] tags) {
        this();
        for (String tag : tags) {
            // Index by current size rather than array position: with duplicates, position-based
            // indices leave gaps and a later new tag (indexed by size) would reuse a taken column.
            if (!tagset.containsKey(tag)) {
                tagset.put(tag, tagset.size());
            }
        }
    }

    /**
     * Builds a {@link DataSet} of word-embedding features, one-hot labels and masks.
     *
     * <p>Every sentence is laid out over {@code T = min(longest sentence, maxSequenceLength)} time
     * steps; tokens beyond {@code T} are dropped. The resulting arrays are:
     * <ul>
     *   <li>features: {@code [sentences, embeddingSize, T]}, where {@code embeddingSize} is the
     *       size reported by an {@link EmbeddingsFeature} built from {@code wordVectors};</li>
     *   <li>labels: {@code [sentences, numTags, T]}, with a 1 in the tag's column for each token;</li>
     *   <li>feature and label masks: {@code [sentences, T]}, 1 at positions holding a real token
     *       and 0 elsewhere.</li>
     * </ul>
     * Tags not yet in the tagset are registered with the next free index.
     *
     * @param sentences tokenized sentences; must not be empty
     * @param labels per-sentence tags aligned with {@code sentences}; each list needs at least as
     *     many entries as the (truncated) sentence it labels
     * @param wordVectors lookup that maps a token to its embedding vector
     * @param maxSequenceLength upper bound on the number of time steps
     * @param numTags number of label columns; every tag index must be below this
     * @param unusedFlag not read by this implementation; kept for signature compatibility
     * @return the assembled data set
     * @throws IOException propagated from building the {@code EmbeddingsFeature} or from the
     *     word-vector lookup (whichever declares it)
     * @throws IllegalArgumentException if {@code sentences} is empty, or a tag would need a label
     *     column {@code >= numTags}
     */
    public DataSet vectorize(List<List<String>> sentences, List<List<String>> labels,
            BinaryWordVectorSerializer.BinaryVectorizer wordVectors, int maxSequenceLength,
            int numTags, boolean unusedFlag) throws IOException {
        int embeddingSize = new EmbeddingsFeature(wordVectors).size();

        int longest = sentences.stream()
                .mapToInt(List::size)
                .max()
                .orElseThrow(() -> new IllegalArgumentException("No sentences to vectorize"));
        int steps = Math.min(longest, maxSequenceLength);
        int count = sentences.size();

        INDArray features = Nd4j.create(new int[] {count, embeddingSize, steps});
        INDArray labelArray = Nd4j.create(new int[] {count, numTags, steps});
        INDArray featuresMask = Nd4j.zeros(count, steps);
        INDArray labelsMask = Nd4j.zeros(count, steps);

        for (int s = 0; s < count; s++) {
            List<String> tokens = sentences.get(s);
            List<String> tags = labels.get(s);
            int length = Math.min(tokens.size(), steps);
            for (int t = 0; t < length; t++) {
                // Write the whole embedding vector into the feature axis at (sentence s, step t).
                INDArrayIndex[] column = {
                    NDArrayIndex.point(s), NDArrayIndex.all(), NDArrayIndex.point(t)
                };
                features.put(column, Nd4j.create(wordVectors.vectorize(tokens.get(t))));
                featuresMask.putScalar(new int[] {s, t}, 1.0d);
                labelArray.putScalar(s, tagIndex(tags.get(t), numTags), t, 1.0d);
                labelsMask.putScalar(new int[] {s, t}, 1.0d);
            }
        }
        return new DataSet(features, labelArray, featuresMask, labelsMask);
    }

    /**
     * Returns the label column for {@code tag}, registering unseen tags with the next free index.
     * The range check happens before registration, so a rejected tag leaves the tagset unchanged.
     *
     * @throws IllegalArgumentException if the tag's index does not fit in {@code numTags} columns
     */
    private int tagIndex(String tag, int numTags) {
        boolean known = tagset.containsKey(tag);
        int index = known ? tagset.getInt(tag) : tagset.size();
        if (index >= numTags) {
            throw new IllegalArgumentException("Tag '" + tag + "' needs label column " + index
                    + " but the label tensor only has " + numTags + " columns");
        }
        if (!known) {
            tagset.put(tag, index);
        }
        return index;
    }

    /**
     * Returns all registered tags in index order: element {@code k} is the tag whose label column
     * is {@code k}.
     *
     * @return a fresh array of the registered tags
     */
    public String[] getTagset() {
        return tagset.keySet().toArray(new String[0]);
    }
}
