package org.dkpro.tc.core.task.deep;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.io.FileUtils;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService;
import org.dkpro.lab.task.Discriminator;
import org.dkpro.lab.task.impl.ExecutableTaskBase;
import org.dkpro.tc.core.DeepLearningConstants;

/* loaded from: input_file:org/dkpro/tc/core/task/deep/EmbeddingTask.class */
/**
 * Task that writes a pruned copy of a pre-trained embedding file into the task's output folder.
 *
 * <p>The embedding file is read line by line; each non-blank line is split at its first space
 * into a token and the remaining vector text. Only lines whose token is known (from the
 * vocabulary file, or from the word-to-integer mapping file when integer vectorization is
 * enabled) are kept. Every known entry that has no line in the embedding file is appended with a
 * vector produced by {@link #randomVector(int)}.
 *
 * <p>If no embedding is configured, the task only requests the output folder and writes nothing.
 */
public class EmbeddingTask extends ExecutableTaskBase {
    /** Storage key of the folder the pruned embedding file is written to. */
    public static final String OUTPUT_KEY = "output";

    /** Storage key of the folder the vocabulary and instance-mapping files are read from. */
    public static final String INPUT_MAPPING = "mappingInput";

    /** Path of the pre-trained embedding file; {@code null} means there is nothing to prune. */
    @Discriminator(name = DeepLearningConstants.DIM_PRETRAINED_EMBEDDINGS)
    private String embedding;

    /** When {@code true}, tokens are replaced by their mapped value in the pruned output. */
    @Discriminator(name = DeepLearningConstants.DIM_VECTORIZE_TO_INTEGER)
    private boolean integerVectorization;

    /**
     * Number of space-separated vector components, taken from the first non-blank embedding line;
     * stays {@code -1} until such a line has been read.
     */
    int lenVec = -1;

    /**
     * Runs the pruning step that matches the configured discriminators.
     *
     * @param taskContext context used to resolve the input and output folders
     * @throws Exception if reading the inputs or writing the pruned embedding fails
     */
    public void execute(TaskContext taskContext) throws Exception {
        if (this.embedding == null) {
            // No embedding configured: only request the output folder
            // (presumably so that it exists for downstream tasks - TODO confirm).
            taskContext.getFolder(OUTPUT_KEY, StorageService.AccessMode.READWRITE);
            return;
        }
        if (this.integerVectorization) {
            integerPreparation(taskContext);
        } else {
            wordPreparation(taskContext);
        }
    }

    /**
     * Prunes the embedding to the words of the vocabulary file, keeping the original lines, and
     * appends a generated vector for every vocabulary word the embedding does not contain.
     */
    private void wordPreparation(TaskContext taskContext) throws Exception {
        Set<String> vocabulary = loadVocabulary(taskContext);
        // try-with-resources closes both streams on every exit path, including a failure while
        // the fallback vectors are being written.
        try (BufferedReader reader = getEmbeddingReader();
                BufferedWriter writer = getPrunedEmbeddingWriter(taskContext)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                int separator = separatorIndex(line);
                String token = line.substring(0, separator);
                // remove() reports whether the token was still pending, so each word is
                // written at most once even if the embedding file repeats it.
                if (vocabulary.remove(token)) {
                    writer.write(line + "\n");
                }
                if (this.lenVec < 0) {
                    this.lenVec = line.substring(separator + 1).split(" ").length;
                }
            }
            // randomVector uses a fixed seed, so the fallback is identical for every missing
            // word; compute it once. If no embedding line was read, lenVec is still -1 and the
            // fallback is the empty string.
            String fallbackVector = randomVector(this.lenVec);
            for (String missingWord : vocabulary) {
                writer.write(missingWord + " " + fallbackVector + "\n");
            }
        }
    }

    /**
     * Reads the vocabulary file from the mapping input folder.
     *
     * @return a mutable set holding every non-empty line of the file
     */
    private Set<String> loadVocabulary(TaskContext taskContext) throws IOException {
        File vocabularyFile = new File(
                taskContext.getFolder(INPUT_MAPPING, StorageService.AccessMode.READONLY),
                DeepLearningConstants.FILENAME_VOCABULARY);
        List<String> lines = FileUtils.readLines(vocabularyFile, StandardCharsets.UTF_8);
        Set<String> vocabulary = new HashSet<>();
        for (String entry : lines) {
            if (!entry.isEmpty()) {
                vocabulary.add(entry);
            }
        }
        return vocabulary;
    }

    /**
     * Prunes the embedding to the words of the instance-mapping file, writing the mapped value in
     * place of each word, and appends a generated vector for every mapped word the embedding does
     * not contain.
     */
    private void integerPreparation(TaskContext taskContext) throws Exception {
        Map<String, String> wordToId = loadWord2IntegerMap(taskContext);
        // try-with-resources closes both streams on every exit path, including a failure while
        // the fallback vectors are being written.
        try (BufferedReader reader = getEmbeddingReader();
                BufferedWriter writer = getPrunedEmbeddingWriter(taskContext)) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                int separator = separatorIndex(line);
                String token = line.substring(0, separator);
                String vector = line.substring(separator + 1);
                // The map never holds null values, so a non-null result means the token was
                // still pending; removing it prevents duplicates in the output.
                String id = wordToId.remove(token);
                if (id != null) {
                    writer.write(id + " " + vector + "\n");
                }
                if (this.lenVec < 0) {
                    this.lenVec = vector.split(" ").length;
                }
            }
            // Fixed seed => the same vector for every missing entry; compute it once.
            String fallbackVector = randomVector(this.lenVec);
            for (String missingId : wordToId.values()) {
                writer.write(missingId + " " + fallbackVector + "\n");
            }
        }
    }

    /**
     * Locates the first space of a non-blank embedding line, i.e. the boundary between the token
     * and its vector.
     *
     * @throws IOException if the line contains no space and therefore carries no vector
     */
    private static int separatorIndex(String line) throws IOException {
        int separator = line.indexOf(' ');
        if (separator < 0) {
            throw new IOException(
                    "Malformed embedding line (no space between token and vector): [" + line + "]");
        }
        return separator;
    }

    /** Opens the configured embedding file as UTF-8 text. */
    private BufferedReader getEmbeddingReader() throws Exception {
        return new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(this.embedding)), StandardCharsets.UTF_8));
    }

    /** Opens the pruned-embedding file in the output folder for UTF-8 writing. */
    private BufferedWriter getPrunedEmbeddingWriter(TaskContext taskContext) throws Exception {
        File prunedFile = new File(
                taskContext.getFolder(OUTPUT_KEY, StorageService.AccessMode.READWRITE),
                DeepLearningConstants.FILENAME_PRUNED_EMBEDDING);
        return new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(prunedFile), StandardCharsets.UTF_8));
    }

    /**
     * Reads the instance-mapping file from the mapping input folder. Each non-empty line is split
     * on tabs; the first field becomes the key and the second the value.
     *
     * @return a mutable map from first field to second field
     * @throws IOException if the file cannot be read or a non-empty line has fewer than two
     *         tab-separated fields
     */
    private Map<String, String> loadWord2IntegerMap(TaskContext taskContext) throws IOException {
        File mappingFile = new File(
                taskContext.getFolder(INPUT_MAPPING, StorageService.AccessMode.READONLY),
                DeepLearningConstants.FILENAME_INSTANCE_MAPPING);
        List<String> lines = FileUtils.readLines(mappingFile, StandardCharsets.UTF_8);
        Map<String, String> wordToId = new HashMap<>();
        for (String entry : lines) {
            if (entry.isEmpty()) {
                continue;
            }
            String[] fields = entry.split("\t");
            if (fields.length < 2) {
                throw new IOException(
                        "Malformed mapping line (expected two tab-separated fields): [" + entry
                                + "]");
            }
            wordToId.put(fields[0], fields[1]);
        }
        return wordToId;
    }

    /**
     * Builds a space-separated vector of pseudo-random components. Each component is
     * {@code (nextFloat() - 0.5) / i}, formatted with five decimals using {@link Locale#US}.
     *
     * @param i number of components; for values of zero or less the result is the empty string
     * @param j seed of the random generator; the same seed always yields the same vector
     * @return the components joined by single spaces, without a trailing space
     */
    public static String randomVector(int i, long j) {
        Random random = new Random(j);
        StringBuilder sb = new StringBuilder();
        for (int component = 0; component < i; component++) {
            if (component > 0) {
                sb.append(" ");
            }
            sb.append(String.format(Locale.US, "%.5f", (random.nextFloat() - 0.5f) / i));
        }
        return sb.toString();
    }

    /**
     * Builds a vector with {@link #randomVector(int, long)} using the fixed seed
     * {@code 123456789}, so repeated calls with the same size return the same text.
     *
     * @param i number of components
     * @return the generated vector text
     */
    public static String randomVector(int i) {
        return randomVector(i, 123456789L);
    }
}
