package edu.stanford.nlp.sequences;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.StringUtils;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/sequences/SequenceGibbsSampler.class */
/**
 * Samples label sequences from a {@link SequenceModel} by repeatedly resampling one
 * position at a time, and uses those samples either to collect a set of sequences or
 * to search for a high-scoring sequence (plain sampling or annealing).
 *
 * <p>Every change to a sequence element is reported to the {@link SequenceListener}
 * supplied at construction time; the listener must therefore be non-null for any of
 * the sampling methods to work.
 *
 * <p>Diagnostic output is written to {@code System.err} and is controlled by the
 * static {@link #verbose} level (0 = quiet, 1 = progress, 2 = progress plus a dump of
 * all collected samples).
 */
public class SequenceGibbsSampler implements BestSequenceFinder {
    /** Shared source of randomness for initial sequences and for sampling. */
    private static final Random random = new Random();

    /** Verbosity of diagnostic output on {@code System.err}; see the class comment. */
    public static int verbose = 0;

    /** Tokens used only to label rows in {@link #printSamples}; may be {@code null}. */
    private final List<?> document;

    /** Number of samples collected by {@link #bestSequence}. */
    private final int numSamples;

    /** Number of forward sweeps performed between two collected samples. */
    private final int sampleInterval;

    /** Receives the initial sequence and every subsequent element update. */
    private final SequenceListener listener;

    /**
     * When {@code true}, annealing returns the sequence produced by the last iteration
     * instead of the highest-scoring sequence seen.
     */
    public boolean returnLastFoundSequence;

    /**
     * Creates a sampler.
     *
     * @param i                number of samples to collect in {@link #bestSequence}
     * @param i2               number of forward sweeps between collected samples
     * @param sequenceListener listener notified of the initial sequence and of every update
     * @param list             elements castable to {@code HasWord}, used only by
     *                         {@link #printSamples}; may be {@code null}
     * @param z                initial value of {@link #returnLastFoundSequence}
     */
    public SequenceGibbsSampler(int i, int i2, SequenceListener sequenceListener, List<?> list, boolean z) {
        this.numSamples = i;
        this.sampleInterval = i2;
        this.listener = sequenceListener;
        this.document = list;
        this.returnLastFoundSequence = z;
    }

    /**
     * Creates a sampler with {@link #returnLastFoundSequence} set to {@code false}.
     *
     * @param i                number of samples to collect in {@link #bestSequence}
     * @param i2               number of forward sweeps between collected samples
     * @param sequenceListener listener notified of the initial sequence and of every update
     * @param list             elements used only by {@link #printSamples}; may be {@code null}
     */
    public SequenceGibbsSampler(int i, int i2, SequenceListener sequenceListener, List<?> list) {
        this(i, i2, sequenceListener, list, false);
    }

    /**
     * Creates a sampler without a document and with {@link #returnLastFoundSequence}
     * set to {@code false}.
     *
     * @param i                number of samples to collect in {@link #bestSequence}
     * @param i2               number of forward sweeps between collected samples
     * @param sequenceListener listener notified of the initial sequence and of every update
     */
    public SequenceGibbsSampler(int i, int i2, SequenceListener sequenceListener) {
        this(i, i2, sequenceListener, null);
    }

    /**
     * Returns a new array holding the same values as the argument.
     *
     * @param iArr array to duplicate; must not be {@code null}
     * @return an independent copy of {@code iArr}
     */
    public static int[] copy(int[] iArr) {
        return iArr.clone();
    }

    /**
     * Builds a sequence by picking, independently for every position, one of the
     * values the model reports as possible for that position.
     *
     * @param sequenceModel model supplying the sequence length and the possible values
     * @return a freshly allocated random sequence of length {@code sequenceModel.length()}
     */
    public static int[] getRandomSequence(SequenceModel sequenceModel) {
        int[] sequence = new int[sequenceModel.length()];
        for (int pos = 0; pos < sequence.length; pos++) {
            int[] candidates = sequenceModel.getPossibleValues(pos);
            sequence[pos] = candidates[random.nextInt(candidates.length)];
        }
        return sequence;
    }

    /**
     * Finds a high-scoring sequence by sampling, starting from a random sequence and
     * using the sample count and interval given at construction time.
     *
     * @param sequenceModel model to sample from and score with
     * @return the best-scoring collected sample, or {@code null} if none was collected
     */
    @Override // edu.stanford.nlp.sequences.BestSequenceFinder
    public int[] bestSequence(SequenceModel sequenceModel) {
        return findBestUsingSampling(sequenceModel, this.numSamples, this.sampleInterval, getRandomSequence(sequenceModel));
    }

    /**
     * Collects samples and returns the one the model scores highest. Each time a new
     * best is found it is reported on {@code System.err}, regardless of {@link #verbose}.
     *
     * @param sequenceModel model to sample from and score with
     * @param i             number of samples to collect
     * @param i2            number of forward sweeps between collected samples
     * @param iArr          starting sequence; it is not modified
     * @return the highest-scoring sample, or {@code null} if no sample scored above
     *         negative infinity (in particular when {@code i <= 0})
     */
    public int[] findBestUsingSampling(SequenceModel sequenceModel, int i, int i2, int[] iArr) {
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int[] sample : collectSamples(sequenceModel, i, i2, iArr)) {
            double score = sequenceModel.scoreOf(sample);
            if (score > bestScore) {
                best = sample;
                bestScore = score;
                System.err.println("found new best (" + bestScore + ")");
                System.err.println(ArrayMath.toString(best));
            }
        }
        return best;
    }

    /**
     * Runs annealing from a random starting sequence.
     *
     * @param sequenceModel   model to sample from and score with
     * @param coolingSchedule supplies the iteration count and per-iteration temperature
     * @return see {@link #findBestUsingAnnealing(SequenceModel, CoolingSchedule, int[])}
     */
    public int[] findBestUsingAnnealing(SequenceModel sequenceModel, CoolingSchedule coolingSchedule) {
        return findBestUsingAnnealing(sequenceModel, coolingSchedule, getRandomSequence(sequenceModel));
    }

    /**
     * Runs one forward sweep per iteration of the cooling schedule, each at the
     * temperature the schedule gives for that iteration.
     *
     * @param sequenceModel   model to sample from and score with
     * @param coolingSchedule supplies the iteration count and per-iteration temperature
     * @param iArr            starting sequence; it is not modified
     * @return the sequence of the last iteration when {@link #returnLastFoundSequence}
     *         is set, otherwise the highest-scoring sequence seen; {@code null} if the
     *         schedule has no iterations
     */
    public int[] findBestUsingAnnealing(SequenceModel sequenceModel, CoolingSchedule coolingSchedule, int[] iArr) {
        if (verbose > 0) {
            System.err.println("Doing annealing");
        }
        this.listener.setInitialSequence(iArr);
        List<int[]> samples = new ArrayList<int[]>();
        int[] current = iArr;
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        if (!this.returnLastFoundSequence) {
            // NOTE(review): the result is unused. The call is kept only because it cannot be
            // confirmed from this file that scoreOf has no side effects; remove it if it is pure.
            sequenceModel.scoreOf(current);
        }
        for (int iter = 0; iter < coolingSchedule.numIterations(); iter++) {
            // Work on a copy so neither the caller's array nor a stored sample is altered.
            current = copy(current);
            sampleSequenceForward(sequenceModel, current, coolingSchedule.getTemperature(iter));
            samples.add(current);
            if (this.returnLastFoundSequence) {
                best = current;
            } else {
                double score = sequenceModel.scoreOf(current);
                if (score > bestScore) {
                    best = current;
                    bestScore = score;
                }
            }
            if (verbose > 0) {
                System.err.print(".");
            }
        }
        if (verbose > 1) {
            System.err.println();
            printSamples(samples, System.err);
        }
        if (verbose > 0) {
            System.err.println("done.");
        }
        return best;
    }

    /**
     * Collects samples starting from a random sequence.
     *
     * @param sequenceModel model to sample from
     * @param i             number of samples to collect
     * @param i2            number of forward sweeps between collected samples
     * @return the collected samples, in the order they were drawn
     */
    public List<int[]> collectSamples(SequenceModel sequenceModel, int i, int i2) {
        return collectSamples(sequenceModel, i, i2, getRandomSequence(sequenceModel));
    }

    /**
     * Collects {@code i} samples. Each sample is obtained by copying the previous one
     * (or the starting sequence) and running {@code i2} forward sweeps over the copy,
     * so the returned arrays are independent of each other and of {@code iArr}.
     *
     * @param sequenceModel model to sample from
     * @param i             number of samples to collect
     * @param i2            number of forward sweeps between collected samples
     * @param iArr          starting sequence; it is not modified
     * @return the collected samples, in the order they were drawn
     */
    public List<int[]> collectSamples(SequenceModel sequenceModel, int i, int i2, int[] iArr) {
        if (verbose > 0) {
            System.err.print("Collecting samples");
        }
        this.listener.setInitialSequence(iArr);
        List<int[]> samples = new ArrayList<int[]>();
        int[] current = iArr;
        for (int n = 0; n < i; n++) {
            current = copy(current);
            // Sample the stored array itself. Delegating to sampleSequenceRepeatedly would
            // only touch a throw-away copy and leave every collected sample equal to iArr.
            sweepInPlace(sequenceModel, current, i2);
            samples.add(current);
            if (verbose > 0) {
                System.err.print(".");
            }
            System.err.flush();
        }
        if (verbose > 1) {
            System.err.println();
            printSamples(samples, System.err);
        }
        if (verbose > 0) {
            System.err.println("done.");
        }
        return samples;
    }

    /**
     * Runs {@code i} forward sweeps over a private copy of {@code iArr}. The argument
     * itself is left untouched; the only externally visible effects are the calls made
     * on the listener.
     *
     * @param sequenceModel model to sample from
     * @param iArr          starting sequence; it is not modified
     * @param i             number of forward sweeps to run
     */
    public void sampleSequenceRepeatedly(SequenceModel sequenceModel, int[] iArr, int i) {
        sweepInPlace(sequenceModel, copy(iArr), i);
    }

    /**
     * Runs {@code i} forward sweeps starting from a random sequence.
     *
     * @param sequenceModel model to sample from
     * @param i             number of forward sweeps to run
     */
    public void sampleSequenceRepeatedly(SequenceModel sequenceModel, int i) {
        sampleSequenceRepeatedly(sequenceModel, getRandomSequence(sequenceModel), i);
    }

    /** Registers {@code sequence} with the listener, then runs {@code sweeps} forward sweeps over it in place. */
    private void sweepInPlace(SequenceModel sequenceModel, int[] sequence, int sweeps) {
        this.listener.setInitialSequence(sequence);
        for (int sweep = 0; sweep < sweeps; sweep++) {
            sampleSequenceForward(sequenceModel, sequence);
        }
    }

    /**
     * Resamples every position from first to last at temperature 1.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to resample; modified in place
     */
    public void sampleSequenceForward(SequenceModel sequenceModel, int[] iArr) {
        sampleSequenceForward(sequenceModel, iArr, 1.0d);
    }

    /**
     * Resamples every position from first to last.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to resample; modified in place
     * @param d             temperature, see {@link #samplePosition(SequenceModel, int[], int, double)}
     */
    public void sampleSequenceForward(SequenceModel sequenceModel, int[] iArr, double d) {
        for (int pos = 0; pos < iArr.length; pos++) {
            samplePosition(sequenceModel, iArr, pos, d);
        }
    }

    /**
     * Resamples every position from last to first at temperature 1.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to resample; modified in place
     */
    public void sampleSequenceBackward(SequenceModel sequenceModel, int[] iArr) {
        sampleSequenceBackward(sequenceModel, iArr, 1.0d);
    }

    /**
     * Resamples every position from last to first.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to resample; modified in place
     * @param d             temperature, see {@link #samplePosition(SequenceModel, int[], int, double)}
     */
    public void sampleSequenceBackward(SequenceModel sequenceModel, int[] iArr, double d) {
        for (int pos = iArr.length - 1; pos >= 0; pos--) {
            samplePosition(sequenceModel, iArr, pos, d);
        }
    }

    /**
     * Resamples a single position at temperature 1.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to update; modified in place
     * @param i             position to resample
     * @return the normalized probability of the value that was drawn
     */
    public double samplePosition(SequenceModel sequenceModel, int[] iArr, int i) {
        return samplePosition(sequenceModel, iArr, i, 1.0d);
    }

    /**
     * Resamples a single position. The model's scores for that position are scaled by
     * {@code 1 / d}, log-normalized and exponentiated, and the new value is drawn from
     * the resulting distribution. A temperature of exactly 0 puts all mass on the
     * top-scoring value. The listener is told which position changed and what it held
     * before.
     *
     * @param sequenceModel model to sample from
     * @param iArr          sequence to update; modified in place
     * @param i             position to resample
     * @param d             temperature; 1 leaves the scores unscaled
     * @return the normalized probability of the value that was drawn
     */
    public double samplePosition(SequenceModel sequenceModel, int[] iArr, int i, double d) {
        double[] distribution = sequenceModel.scoresOf(iArr, i);
        if (d == 0.0d) {
            // Zero temperature: leave only the arg-max with a finite log-score.
            int top = ArrayMath.argmax(distribution);
            Arrays.fill(distribution, Double.NEGATIVE_INFINITY);
            distribution[top] = 0.0d;
        } else if (d != 1.0d) {
            ArrayMath.multiplyInPlace(distribution, 1.0d / d);
        }
        ArrayMath.logNormalize(distribution);
        ArrayMath.expInPlace(distribution);
        int previous = iArr[i];
        int drawn = ArrayMath.sampleFromDistribution(distribution, random);
        iArr[i] = drawn;
        this.listener.updateSequenceElement(iArr, i, previous);
        return distribution[drawn];
    }

    /**
     * Prints one row per sequence position: the word at that position (padded or
     * trimmed to 10 characters) followed by the value each sample holds there. When no
     * document was supplied, the number of rows is taken from the first sample and the
     * word column shows {@code "null"}.
     *
     * @param list        samples to print; every element must be an {@code int[]}
     * @param printStream destination of the output
     */
    public void printSamples(List<?> list, PrintStream printStream) {
        final int rows;
        if (this.document != null) {
            rows = this.document.size();
        } else if (list.isEmpty()) {
            rows = 0;
        } else {
            rows = ((int[]) list.get(0)).length;
        }
        for (int row = 0; row < rows; row++) {
            HasWord token = this.document != null ? (HasWord) this.document.get(row) : null;
            String word = token != null ? token.word() : "null";
            printStream.print(StringUtils.padOrTrim(word, 10));
            for (Object sample : list) {
                printStream.print(" " + StringUtils.padLeft(((int[]) sample)[row], 2));
            }
            printStream.println();
        }
    }
}
