package edu.stanford.nlp.dcoref;

import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.SemanticHeadFinder;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.tregex.TregexMatcher;
import edu.stanford.nlp.trees.tregex.TregexPattern;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/dcoref/MentionExtractor.class */
/**
 * Attaches pre-extracted coreference mentions to their sentences and parse trees.
 *
 * <p>For each mention this class records the sentence context and determines a head word and a
 * syntactic subtree. It then orders each sentence's mentions by a pre-order walk of the parse
 * tree and links mentions that stand in apposition, predicate-nominative, or relative-pronoun
 * relations.
 *
 * <p>NOTE(review): {@link #nextDoc()} returns {@code null} in this class; presumably subclasses
 * supply the actual document reading — confirm.
 */
public class MentionExtractor {
    /** Head finder used for every head-terminal lookup in this class. */
    protected HeadFinder headFinder = new SemanticHeadFinder();
    /** Not referenced by the code in this class; presumably maintained by subclasses — TODO confirm. */
    protected String currentDocumentID;
    /** Passed to {@code Mention.process} for every arranged mention. */
    protected Dictionaries dictionaries;
    /** Not referenced by the code in this class. */
    public static final boolean VERBOSE = false;

    /**
     * Creates an extractor.
     *
     * @param dictionaries dictionaries handed to {@code Mention.process} during {@link #arrange}
     */
    public MentionExtractor(Dictionaries dictionaries) {
        this.dictionaries = dictionaries;
    }

    /**
     * Returns the mentions of the next document, grouped by sentence.
     *
     * @return always {@code null} in this class
     */
    public List<List<Mention>> nextDoc() {
        return null;
    }

    /**
     * Same as {@link #arrange(List, List, List, LexicalizedParser, boolean)} with no parser and
     * without merging token labels into the trees.
     */
    public List<List<Mention>> arrange(List<List<CoreLabel>> list, List<Tree> list2, List<List<Mention>> list3) {
        return arrange(list, list2, list3, null, false);
    }

    /** Returns the {@code IndexAnnotation} of the head terminal of {@code tree}. */
    private int getHeadIndex(Tree tree) {
        return ((Integer) ((CoreLabel) tree.headTerminal(this.headFinder).label()).get(CoreAnnotations.IndexAnnotation.class)).intValue();
    }

    /** Builds a map key identifying a subtree by its head index and its printed form. */
    private String treeToKey(Tree tree) {
        return getHeadIndex(tree) + ":" + tree.toString();
    }

    /**
     * Attaches mentions to their sentences, finds heads and subtrees, orders mentions by a
     * pre-order traversal of each parse tree, and marks inter-mention relations.
     *
     * @param list per-sentence token lists
     * @param list2 per-sentence parse trees, parallel to {@code list}
     * @param list3 per-sentence unordered mentions, parallel to {@code list}
     * @param lexicalizedParser parser used to re-parse a mention span when no constituent of the
     *     sentence tree covers it exactly. May be {@code null}: then POS tags are copied from the
     *     tree before head finding, and the right-most-noun heuristic is the only fallback.
     * @param z if true, first calls {@link #mergeLabels} so the tree leaves carry the sentence's
     *     token labels
     * @return one list per sentence holding that sentence's mentions in tree pre-order
     * @throws RuntimeException if no head can be located for a mention, or its head leaf is missing
     */
    public List<List<Mention>> arrange(List<List<CoreLabel>> list, List<Tree> list2, List<List<Mention>> list3, LexicalizedParser lexicalizedParser, boolean z) {
        List<List<Mention>> orderedBySentence = new ArrayList<List<Mention>>();
        for (int sent = 0; sent < list.size(); sent++) {
            List<CoreLabel> sentence = list.get(sent);
            Tree tree = list2.get(sent);
            List<Mention> mentions = list3.get(sent);
            if (z) {
                mergeLabels(tree, sentence);
            }
            // Without a parser, head finding may fall back to the tag-based heuristic, so tags
            // must be present before the mention loop runs.
            if (lexicalizedParser == null) {
                assignPos(tree, sentence);
            }
            for (Mention mention : mentions) {
                attachAndFindHead(mention, tree, sentence, lexicalizedParser);
            }
            if (lexicalizedParser != null) {
                assignPos(tree, sentence);
            }

            // Group mentions by the subtree they cover so they can be emitted in tree order.
            Map<String, List<Mention>> mentionsBySubtree = new HashMap<String, List<Mention>>();
            List<Tree> leaves = tree.getLeaves();
            for (Mention mention : mentions) {
                Tree headLeaf = leaves.get(mention.headIndex);
                if (headLeaf == null) {
                    throw new RuntimeException("Missing head tree for a mention!");
                }
                assignMentionSubTree(mention, headLeaf, tree);
                String key = treeToKey(mention.mentionSubTree);
                List<Mention> group = mentionsBySubtree.get(key);
                if (group == null) {
                    group = new ArrayList<Mention>();
                    mentionsBySubtree.put(key, group);
                }
                group.add(mention);
                mention.process(this.dictionaries);
            }

            List<Mention> ordered = new ArrayList<Mention>();
            orderedBySentence.add(ordered);
            for (Tree node : tree.preOrderNodeList()) {
                List<Mention> group = mentionsBySubtree.get(treeToKey(node));
                if (group != null) {
                    ordered.addAll(group);
                }
            }

            Set<Pair<Integer, Integer>> appositions = new HashSet<Pair<Integer, Integer>>();
            findAppositions(tree, appositions);
            markMentionRelation(ordered, appositions, "APPOSITION");
            Set<Pair<Integer, Integer>> predicateNominatives = new HashSet<Pair<Integer, Integer>>();
            findPredicateNominatives(tree, predicateNominatives);
            markMentionRelation(ordered, predicateNominatives, "PREDICATE_NOMINATIVE");
            Set<Pair<Integer, Integer>> relativePronouns = new HashSet<Pair<Integer, Integer>>();
            findRelativePronouns(tree, relativePronouns);
            markMentionRelation(ordered, relativePronouns, "RELATIVE_PRONOUN");
        }
        return orderedBySentence;
    }

    /**
     * Records the sentence context of {@code mention} and sets its head word and head index.
     *
     * <p>The head index is always a 0-based position in the sentence. Head-finding strategies are
     * tried in order:
     * <ol>
     *   <li>the explicit head offset, if one is given;</li>
     *   <li>the single token, for one-token spans;</li>
     *   <li>the head terminal of a constituent exactly covering the span (from the sentence tree,
     *       else from re-parsing with {@code parser} when it is non-null);</li>
     *   <li>otherwise the right-most-noun heuristic.</li>
     * </ol>
     *
     * @throws RuntimeException if the head could not be located in the sentence
     */
    private void attachAndFindHead(Mention mention, Tree tree, List<CoreLabel> sentence, LexicalizedParser parser) {
        mention.contextParseTree = tree;
        mention.sentenceWords = sentence;
        mention.originalSpan = new ArrayList<CoreLabel>(sentence.subList(mention.originalStartIndex, mention.originalEndIndex));
        if (mention.originalHeadStartIndex != -1) {
            mention.headIndex = mention.originalHeadStartIndex;
            mention.headWord = mention.sentenceWords.get(mention.headIndex);
            return;
        }
        if (mention.originalSpan.size() == 1) {
            mention.headWord = mention.originalSpan.get(0);
            mention.headIndex = mention.originalStartIndex;
            return;
        }
        int firstIndex = ((Integer) mention.originalSpan.get(0).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        int lastIndex = ((Integer) mention.originalSpan.get(mention.originalSpan.size() - 1).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        Tree match = findExactMatch(tree, firstIndex, lastIndex);
        // Only re-parse when a parser was supplied; otherwise use the heuristic below.
        if (match == null && parser != null) {
            match = parser.apply((Object) mention.originalSpan);
        }
        if (match != null) {
            int headOffset = match.getLeaves().indexOf(match.headTerminal(this.headFinder));
            mention.headWord = mention.originalSpan.get(Math.min(mention.originalSpan.size() - 1, headOffset));
            mention.headIndex = indexOfIdentical(mention.sentenceWords, mention.headWord);
        } else {
            int headOffset = findRightMost(mention.originalSpan);
            mention.headWord = mention.originalSpan.get(headOffset);
            // findRightMost is span-relative; convert to a sentence position like the other branches.
            mention.headIndex = mention.originalStartIndex + headOffset;
        }
        if (mention.headIndex < 0) {
            throw new RuntimeException("Could not find mention head for mention: " + mention.originalSpan);
        }
    }

    /** Returns the first position in {@code words} holding exactly {@code target} (by reference), or -1. */
    private static int indexOfIdentical(List<?> words, Object target) {
        for (int i = 0; i < words.size(); i++) {
            if (words.get(i) == target) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Chooses the syntactic subtree of a mention whose head leaf is {@code headLeaf}.
     *
     * <p>The method walks up from the leaf toward {@code root}:
     * <ul>
     *   <li>Every ancestor labelled {@code NP} whose head terminal is {@code headLeaf} becomes the
     *       current subtree.</li>
     *   <li>Once a subtree is set, the walk stops at the first ancestor that does not qualify.</li>
     *   <li>If nothing qualifies, the leaf itself is used.</li>
     * </ul>
     * A subtree already present on the mention at entry counts as "set".
     */
    private void assignMentionSubTree(Mention mention, Tree headLeaf, Tree root) {
        Tree node = headLeaf;
        while ((node = node.parent(root)) != null) {
            boolean headedNp = node.headTerminal(this.headFinder) == headLeaf && node.value().equals("NP");
            if (headedNp) {
                mention.mentionSubTree = node;
            } else if (mention.mentionSubTree != null) {
                break;
            }
        }
        if (mention.mentionSubTree == null) {
            mention.mentionSubTree = headLeaf;
        }
    }

    /**
     * Prints a span as {@code (firstIndex, lastIndex)[word word ...]} followed by a newline.
     *
     * @param list a non-empty list of tokens carrying {@code IndexAnnotation} and {@code TextAnnotation}
     */
    static void printSpan(PrintStream printStream, List<CoreLabel> list) {
        printStream.printf("(%d, %d)", list.get(0).get(CoreAnnotations.IndexAnnotation.class), list.get(list.size() - 1).get(CoreAnnotations.IndexAnnotation.class));
        printStream.print("[");
        for (int i = 0; i < list.size(); i++) {
            if (i > 0) {
                printStream.print(" ");
            }
            printStream.print((String) list.get(i).get(CoreAnnotations.TextAnnotation.class));
        }
        printStream.println("]");
    }

    /**
     * Replaces each leaf label of {@code tree} with the corresponding token of {@code list}, then
     * re-indexes the leaves.
     *
     * <p>Before a token replaces a leaf label, the token's {@code ValueAnnotation} is set to that
     * leaf's value.
     */
    public static void mergeLabels(Tree tree, List<CoreLabel> list) {
        int position = 0;
        for (Tree leaf : tree.getLeaves()) {
            CoreLabel token = list.get(position++);
            token.set(CoreAnnotations.ValueAnnotation.class, leaf.value());
            leaf.setLabel(token);
        }
        tree.indexLeaves();
    }

    /** Copies each leaf's parent value (its preterminal tag) into the matching token's POS annotation. */
    static void assignPos(Tree tree, List<CoreLabel> list) {
        int position = 0;
        for (Tree leaf : tree.getLeaves()) {
            list.get(position++).set(CoreAnnotations.PartOfSpeechAnnotation.class, leaf.parent(tree).value());
        }
    }

    /** Returns whether position {@code i} lies in {@code [originalStartIndex, originalEndIndex)} of the mention. */
    static boolean inside(int i, Mention mention) {
        return i >= mention.originalStartIndex && i < mention.originalEndIndex;
    }

    /**
     * Runs the Tregex pattern {@code str} over {@code tree} and adds, for each match, the pair of
     * head positions of the nodes named {@code m2} and {@code m1} to {@code set}.
     *
     * @throws RuntimeException wrapping any failure to compile or apply the pattern
     */
    private void findTreePattern(Tree tree, String str, Set<Pair<Integer, Integer>> set) {
        try {
            TregexMatcher matcher = TregexPattern.compile(str).matcher(tree);
            while (matcher.find()) {
                addFoundPair(matcher.getNode("m1"), matcher.getNode("m2"), matcher.getMatch(), set);
            }
        } catch (Exception e) {
            // Propagate instead of terminating the JVM, which also hid the failure behind exit status 0.
            throw new RuntimeException("Failed to apply tree pattern: " + str, e);
        }
    }

    /**
     * Adds {@code (head(m2) - 1, head(m1) - 1)} to {@code set}.
     *
     * <p>The subtraction converts the 1-based {@code IndexAnnotation} to the 0-based positions
     * stored in {@code Mention.headIndex}. The whole-match node {@code tree3} is not used.
     */
    private void addFoundPair(Tree tree, Tree tree2, Tree tree3, Set<Pair<Integer, Integer>> set) {
        Tree m1Head = tree.headTerminal(this.headFinder);
        Tree m2Head = tree2.headTerminal(this.headFinder);
        int m2Index = ((Integer) ((CoreMap) m2Head.label()).get(CoreAnnotations.IndexAnnotation.class)).intValue() - 1;
        int m1Index = ((Integer) ((CoreMap) m1Head.label()).get(CoreAnnotations.IndexAnnotation.class)).intValue() - 1;
        set.add(new Pair<Integer, Integer>(Integer.valueOf(m2Index), Integer.valueOf(m1Index)));
    }

    /** Collects head pairs for NP appositions (comma-separated and parenthetical forms). */
    private void findAppositions(Tree tree, Set<Pair<Integer, Integer>> set) {
        findTreePattern(tree, "/^NP(?:-TMP|-ADV)?$/=m1 < (NP=m2 $- /^,$/ $-- NP=m3 !$ CC|CONJP)", set);
        findTreePattern(tree, "/^NP(?:-TMP|-ADV)?$/=m1 < (PRN=m2 < (NP < /^NNS?|CD$/ $-- /^-LRB-$/ $+ /^-RRB-$/))", set);
    }

    /** Collects head pairs for subject NPs linked to an NP by a form of "be". */
    private void findPredicateNominatives(Tree tree, Set<Pair<Integer, Integer>> set) {
        findTreePattern(tree, "S < (NP=m1 $.. (VP < ((/VB/ < /^(am|are|is|was|were|'m|'re|'s|be)$/) $.. NP=m2)))", set);
        findTreePattern(tree, "S < (NP=m1 $.. (VP < (VP < ((/VB/ < /^(be|been|being)$/) $.. NP=m2))))", set);
    }

    /** Collects head pairs for an NP and the WP/WDT relative pronoun of a following SBAR. */
    private void findRelativePronouns(Tree tree, Set<Pair<Integer, Integer>> set) {
        findTreePattern(tree, "NP < (NP=m1 $.. (SBAR < (WHNP < WP|WDT=m2)))", set);
    }

    /**
     * Records relation {@code str} between mentions whose head positions form a pair in
     * {@code set}, in either order.
     *
     * <p>For each such pair, the mention later in {@code list} receives the earlier one. An
     * unrecognised flag prints a diagnostic for every matching pair.
     */
    private void markMentionRelation(List<Mention> list, Set<Pair<Integer, Integer>> set, String str) {
        for (int i = 0; i < list.size(); i++) {
            Mention earlier = list.get(i);
            for (int j = i + 1; j < list.size(); j++) {
                Mention later = list.get(j);
                for (Pair<Integer, Integer> pair : set) {
                    if (!linksHeads(pair, earlier.headIndex, later.headIndex)) {
                        continue;
                    }
                    if (str.equals("APPOSITION")) {
                        later.addApposition(earlier);
                    } else if (str.equals("PREDICATE_NOMINATIVE")) {
                        later.addPredicateNominatives(earlier);
                    } else if (str.equals("RELATIVE_PRONOUN")) {
                        later.addRelativePronoun(earlier);
                    } else {
                        System.err.println("check flag in markMentionRelation (dcoref/MentionExtractor.java)");
                    }
                }
            }
        }
    }

    /** Returns whether {@code pair} connects positions {@code a} and {@code b} in either order. */
    private static boolean linksHeads(Pair<Integer, Integer> pair, int a, int b) {
        int first = pair.first.intValue();
        int second = pair.second.intValue();
        return (first == a && second == b) || (first == b && second == a);
    }

    /**
     * Returns the index of the right-most noun-like token in a span.
     *
     * <p>Noun-like means the tag starts with {@code NN}, or is {@code CD} or {@code %}. Only the
     * part of the span before the first {@code ,} or {@code IN} token is searched. If no such token
     * is found, the last index of the whole span is returned. The result is relative to {@code list}.
     */
    private static int findRightMost(List<CoreLabel> list) {
        int limit = list.size();
        for (int i = 0; i < list.size(); i++) {
            String tag = list.get(i).tag();
            if (tag.equals(",") || tag.equals("IN")) {
                limit = i;
                break;
            }
        }
        for (int i = limit - 1; i >= 0; i--) {
            String tag = list.get(i).tag();
            if (tag.startsWith("NN") || tag.equals("CD") || tag.equals("%")) {
                return i;
            }
        }
        return list.size() - 1;
    }

    /**
     * Returns the first node, in pre-order, whose leaves span exactly {@code IndexAnnotation}s
     * {@code i} through {@code i2}, or {@code null} if there is none.
     */
    public static Tree findExactMatch(Tree tree, int i, int i2) {
        List<Tree> leaves = tree.getLeaves();
        int first = ((Integer) ((CoreMap) leaves.get(0).label()).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        int last = ((Integer) ((CoreMap) leaves.get(leaves.size() - 1).label()).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        if (first == i && last == i2) {
            return tree;
        }
        for (Tree child : tree.children()) {
            Tree found = findExactMatch(child, i, i2);
            if (found != null) {
                return found;
            }
        }
        return null;
    }

    /**
     * Returns the best-scoring node whose leaves span exactly {@code i} through {@code i2}, or
     * {@code null} if there is none.
     *
     * <p>Every exact match scores 1.0 and only a strictly higher score replaces the current best,
     * so in practice this returns the first exact match in pre-order.
     */
    public static Tree findBestMatch(Tree tree, int i, int i2) {
        Pair<Tree, Double> best = new Pair<Tree, Double>(null, Double.valueOf(Double.MIN_VALUE));
        findBestMatch(tree, i, i2, best);
        return best.first;
    }

    /** Recursively updates {@code pair} (best node, best score) with exact span matches under {@code tree}. */
    private static void findBestMatch(Tree tree, int i, int i2, Pair<Tree, Double> pair) {
        List<Tree> leaves = tree.getLeaves();
        int first = ((Integer) ((CoreMap) leaves.get(0).label()).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        int last = ((Integer) ((CoreMap) leaves.get(leaves.size() - 1).label()).get(CoreAnnotations.IndexAnnotation.class)).intValue();
        if (first == i && last == i2 && 1.0d > pair.second.doubleValue()) {
            pair.first = tree;
            pair.second = Double.valueOf(1.0d);
        }
        for (Tree child : tree.children()) {
            findBestMatch(child, i, i2, pair);
        }
    }
}
