package org.dkpro.tc.examples.deeplearning.dl4j.sequence;

import java.io.BufferedOutputStream;
import java.io.DataInput;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Locale;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Converts DL4J {@link WordVectors} to a compact binary file and reads that file back through
 * memory-mapped buffers via {@link BinaryVectorizer}.
 *
 * <p>File layout, as produced by
 * {@link #convertWordVectorsToBinary(WordVectors, boolean, Locale, Path)}:
 * <ol>
 *   <li>a {@link Header}: magic {@code "dl4jw2v"}, version byte, word count, vector length,
 *       caseless flag, locale tag;</li>
 *   <li>{@code wordCount} words sorted by {@link String#compareTo}, each written with
 *       {@link DataOutputStream#writeUTF};</li>
 *   <li>the vector used for unknown words ({@code vectorLength} floats);</li>
 *   <li>{@code wordCount} vectors of {@code vectorLength} floats, in the same order as the
 *       words.</li>
 * </ol>
 * Floats are stored big-endian (the {@link ByteBuffer} default).
 */
public class BinaryWordVectorSerializer {
    /**
     * Sentinel token name for the unknown-word vector. {@link BinaryVectorizer#vectorize} returns
     * that vector for any word missing from the vocabulary.
     */
    public static final String UNK = "-=*>UNKNOWN TOKEN<*=-";

    /**
     * Read-only view over a file written by
     * {@link #convertWordVectorsToBinary(WordVectors, boolean, Locale, Path)}. The vocabulary is
     * held on the heap; vectors are read from memory-mapped partitions of the file.
     *
     * <p>Not thread-safe: {@link #vectorize} repositions shared buffers.
     */
    public static class BinaryVectorizer {
        private final Header header;
        /** Sorted vocabulary; entry {@code i} belongs to the {@code i}-th stored vector. */
        public final String[] words;
        /** Mapped vector partitions, each holding at most {@link #maxVectorsPerPartition} vectors. */
        private final FloatBuffer[] parts;
        /** Largest number of vectors whose bytes fit into a single (int-sized) mapping. */
        private final int maxVectorsPerPartition;
        private final Locale locale;
        private final float[] unk;

        /**
         * Maps the vector section of {@code file} into memory.
         *
         * @param header        header already read from the file
         * @param file          file to map; the mappings stay valid after it is closed
         * @param words         sorted vocabulary read from the file
         * @param vectorsOffset byte offset of the first word vector (after the unknown vector)
         * @param unk           vector returned for unknown words
         * @throws IOException if mapping the file fails
         */
        BinaryVectorizer(Header header, RandomAccessFile file, String[] words, long vectorsOffset,
                float[] unk) throws IOException {
            this.header = header;
            this.words = words;
            this.unk = unk;
            this.locale = parseLocale(header.locale);

            int vectorBytes = header.vectorLength * Float.BYTES;
            this.maxVectorsPerPartition = Integer.MAX_VALUE / vectorBytes;
            // Ceiling division expressed without the int overflow of (a + b - 1) / b.
            int partitionCount = words.length / maxVectorsPerPartition
                    + (words.length % maxVectorsPerPartition == 0 ? 0 : 1);
            long partitionBytes = (long) maxVectorsPerPartition * vectorBytes;

            this.parts = new FloatBuffer[partitionCount];
            FileChannel channel = file.getChannel();
            for (int p = 0; p < partitionCount; p++) {
                // long arithmetic: p * partitionBytes can exceed Integer.MAX_VALUE once p >= 2.
                long offset = vectorsOffset + p * partitionBytes;
                // A full partition unless this is the tail, which holds only the remainder.
                int vectorsInPartition = Math.min(maxVectorsPerPartition,
                        words.length - p * maxVectorsPerPartition);
                long size = (long) vectorsInPartition * vectorBytes;
                this.parts[p] = channel.map(FileChannel.MapMode.READ_ONLY, offset, size)
                        .asFloatBuffer();
            }
        }

        /**
         * Resolves the stored locale string. Files written by earlier versions store
         * {@link Locale#toString()} (e.g. {@code en_US}), which {@link Locale#forLanguageTag}
         * rejects and maps to the root locale. Converting {@code _} to {@code -} accepts both
         * that form and proper BCP 47 tags.
         */
        private static Locale parseLocale(String tag) {
            return Locale.forLanguageTag(tag.replace('_', '-'));
        }

        /** Returns the index of {@code word} in {@link #words}, or a negative value if absent. */
        private int indexOf(String word) {
            String key = header.caseless ? word.toLowerCase(locale) : word;
            return Arrays.binarySearch(words, key);
        }

        /** @return number of floats per vector */
        public int getVectorSize() {
            return this.header.vectorLength;
        }

        /**
         * Checks whether a word is in the vocabulary. In caseless mode the word is lowercased
         * with the file's locale first.
         */
        public boolean contains(String str) {
            return indexOf(str) >= 0;
        }

        /**
         * Looks up the vector for a word. In caseless mode the word is lowercased with the file's
         * locale first.
         *
         * @param str word to look up
         * @return a fresh array with the word's vector, or a copy of the unknown-word vector if
         *         the word is not in the vocabulary
         * @throws IOException declared for API compatibility
         */
        public float[] vectorize(String str) throws IOException {
            int index = indexOf(str);
            if (index < 0) {
                // Copy so callers cannot corrupt the shared unknown-word vector.
                return unk.clone();
            }
            FloatBuffer part = parts[index / maxVectorsPerPartition];
            part.position((index % maxVectorsPerPartition) * header.vectorLength);
            float[] vector = new float[header.vectorLength];
            part.get(vector);
            return vector;
        }

        /**
         * Opens a binary word-vector file.
         *
         * @param path file written by
         *             {@link #convertWordVectorsToBinary(WordVectors, boolean, Locale, Path)}
         * @return a vectorizer backed by memory-mapped views of the file
         * @throws IOException if the file cannot be read or is not a valid vector file
         */
        public static BinaryVectorizer load(Path path) throws IOException {
            // Opened read-only: nothing is written. A read-write channel would also let
            // FileChannel.map silently grow a truncated file instead of failing.
            // Closing the file is safe: a mapping does not depend on its channel staying open.
            try (RandomAccessFile file = new RandomAccessFile(path.toFile(), "r")) {
                Header header = Header.read(file);
                String[] words = new String[header.wordCount];
                for (int i = 0; i < words.length; i++) {
                    words[i] = file.readUTF();
                }
                System.out.println("Loaded " + words.length);

                byte[] unkBytes = new byte[header.vectorLength * Float.BYTES];
                file.readFully(unkBytes);
                float[] unk = new float[header.vectorLength];
                ByteBuffer.wrap(unkBytes).asFloatBuffer().get(unk);

                return new BinaryVectorizer(header, file, words, file.getFilePointer(), unk);
            }
        }
    }

    /** Fixed-size metadata block at the start of a binary word-vector file. */
    public static class Header {
        private static final String MAGIC = "dl4jw2v";
        private int version = 1;
        private int wordCount;
        private int vectorLength;
        private boolean caseless;
        /** Locale tag used for lowercasing in caseless mode. */
        private String locale;

        /**
         * Reads and validates a header.
         *
         * @param dataInput source positioned at the start of the file
         * @return the parsed header
         * @throws IOException if the magic, version or size fields are invalid, or reading fails
         */
        public static Header read(DataInput dataInput) throws IOException {
            byte[] magic = new byte[MAGIC.length()];
            dataInput.readFully(magic);
            if (!MAGIC.equals(new String(magic, StandardCharsets.US_ASCII))) {
                throw new IOException("The file you provided is either not a DL4J binary word vectors file or corrupted.");
            }
            Header header = new Header();
            header.version = dataInput.readByte();
            if (1 != header.version) {
                throw new IOException("Not supported file format version.");
            }
            header.wordCount = dataInput.readInt();
            header.vectorLength = dataInput.readInt();
            header.caseless = dataInput.readBoolean();
            header.locale = dataInput.readUTF();
            // Reject values that would otherwise surface later as NegativeArraySizeException,
            // division by zero, or int overflow in byte-size computations.
            if (header.wordCount < 0 || header.vectorLength <= 0
                    || header.vectorLength > Integer.MAX_VALUE / Float.BYTES) {
                throw new IOException("Corrupted header: wordCount=" + header.wordCount
                        + ", vectorLength=" + header.vectorLength);
            }
            return header;
        }

        /**
         * Writes this header and flushes. The stream is not closed, because it belongs to the
         * caller.
         */
        public void write(OutputStream outputStream) throws IOException {
            DataOutputStream out = new DataOutputStream(outputStream);
            out.write(MAGIC.getBytes(StandardCharsets.US_ASCII));
            out.writeByte(this.version);
            out.writeInt(this.wordCount);
            out.writeInt(this.vectorLength);
            out.writeBoolean(this.caseless);
            out.writeUTF(this.locale);
            out.flush();
        }
    }

    /** Writes {@code wordVectors} to {@code path}, case-sensitively, using {@link Locale#US}. */
    public static void convertWordVectorsToBinary(WordVectors wordVectors, Path path) throws IOException {
        convertWordVectorsToBinary(wordVectors, false, Locale.US, path);
    }

    /**
     * Reloads {@code path} and reports every mismatch against {@code wordVectors} on stdout.
     * Checks that the unknown-word vector is reproducible and stored correctly, and that each
     * word's vector matches.
     */
    public static void verify(WordVectors wordVectors, Path path) throws IOException {
        BinaryVectorizer loaded = BinaryVectorizer.load(path);
        if (loaded.contains(UNK)) {
            System.out.printf("Unknown word is contained in vocabulary!%n");
        }
        int layerSize = wordVectors.lookupTable().layerSize();
        float[] expectedUnk = makeUnk(layerSize).data().asFloat();
        float[] regeneratedUnk = makeUnk(layerSize).data().asFloat();
        float[] storedUnk = loaded.vectorize(UNK);
        if (!Arrays.equals(expectedUnk, regeneratedUnk)) {
            System.out.printf("Unstable generated unknown word%n");
        }
        if (!Arrays.equals(expectedUnk, storedUnk)) {
            System.out.printf("Vectors differ for unknown word%n");
        }
        for (String word : wordVectors.vocab().words()) {
            if (!Arrays.equals(ArrayUtil.toFloats(wordVectors.getWordVector(word)), loaded.vectorize(word))) {
                System.out.printf("Vectors differ for word [%s]%n", word);
            }
        }
    }

    /**
     * Builds the unknown-word vector: {@code Nd4j.rand} output with the fixed seed
     * {@code 12345}, shifted by -0.5 and divided by {@code size}.
     */
    private static INDArray makeUnk(int size) {
        return Nd4j.rand(1, size, 12345L).subi(0.5).divi(size);
    }

    /**
     * Writes {@code wordVectors} to {@code path} in the binary layout described on this class.
     *
     * <p>NOTE(review): words are stored exactly as given, while caseless lookups lowercase the
     * query. In caseless mode, vocabulary entries containing uppercase characters are therefore
     * unreachable. Confirm whether callers lowercase the vocabulary beforehand.
     *
     * @param wordVectors source model
     * @param caseless    whether lookups should lowercase the query word
     * @param locale      locale used for that lowercasing; stored as a BCP 47 tag
     * @param path        destination file (overwritten)
     * @throws IOException if writing fails or a vector's length differs from the layer size
     */
    public static void convertWordVectorsToBinary(WordVectors wordVectors, boolean caseless, Locale locale, Path path) throws IOException {
        Header header = new Header();
        header.version = 1;
        header.wordCount = wordVectors.vocab().words().size();
        header.vectorLength = wordVectors.lookupTable().layerSize();
        header.caseless = caseless;
        // A BCP 47 tag (e.g. "en-US") is what Locale.forLanguageTag on the reading side expects;
        // Locale.toString() ("en_US") would be parsed as the root locale.
        header.locale = locale.toLanguageTag();

        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(new FileOutputStream(path.toFile())))) {
            header.write(out);

            System.out.println("Sorting data...");
            String[] words = (String[]) wordVectors.vocab().words().toArray(new String[header.wordCount]);
            // Natural String order, matching the Arrays.binarySearch lookups in BinaryVectorizer.
            Arrays.sort(words);

            System.out.println("Writing strings...");
            for (String word : words) {
                out.writeUTF(word);
            }

            System.out.println("Writing UNK vector...");
            writeVector(out, makeUnk(header.vectorLength).data().asFloat(), header.vectorLength, UNK);

            System.out.println("Writing vectors...");
            for (String word : words) {
                writeVector(out, ArrayUtil.toFloats(wordVectors.getWordVector(word)), header.vectorLength, word);
            }
        }
    }

    /**
     * Writes {@code vector} as big-endian floats. It fails rather than emit a record of the wrong
     * length, which would shift the offset of every later vector in the file.
     */
    private static void writeVector(DataOutputStream out, float[] vector, int expectedLength,
            String word) throws IOException {
        if (vector.length != expectedLength) {
            throw new IOException("Vector for word [" + word + "] has length " + vector.length
                    + ", expected " + expectedLength);
        }
        ByteBuffer bytes = ByteBuffer.allocate(vector.length * Float.BYTES);
        bytes.asFloatBuffer().put(vector);
        out.write(bytes.array());
    }
}
