/*
 * Decompiled with CFR 0.152.
 */
package projects.crispr;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractVariableLengthDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import java.text.NumberFormat;
import java.util.Arrays;
import projects.crispr.PositionPrior;

public class CRISPRDiffSM
extends AbstractVariableLengthDiffSM {
    private double[] prePars;
    private AbstractVariableLengthDiffSM bgModel;
    private AbstractVariableLengthDiffSM pairingModel;
    private PositionPrior position;
    private double ess;

    public CRISPRDiffSM(AlphabetContainer alphabets, AbstractDifferentiableStatisticalModel bgModel, AbstractVariableLengthDiffSM pairingModel, PositionPrior position, double ess) throws IllegalArgumentException, CloneNotSupportedException {
        super(alphabets, 0);
        this.bgModel = (AbstractVariableLengthDiffSM)bgModel.clone();
        this.pairingModel = (AbstractVariableLengthDiffSM)pairingModel.clone();
        this.position = position.clone();
        this.ess = ess;
    }

    public CRISPRDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public CRISPRDiffSM clone() throws CloneNotSupportedException {
        CRISPRDiffSM clone = (CRISPRDiffSM)super.clone();
        clone.bgModel = (AbstractVariableLengthDiffSM)this.bgModel.clone();
        clone.pairingModel = (AbstractVariableLengthDiffSM)this.pairingModel.clone();
        clone.position = this.position.clone();
        return clone;
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        lp += this.bgModel.getLogPriorTerm();
        lp += this.pairingModel.getLogPriorTerm();
        return lp += this.position.getLogPriorTerm();
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        this.bgModel.addGradientOfLogPriorTerm(grad, start);
        this.pairingModel.addGradientOfLogPriorTerm(grad, start += this.bgModel.getNumberOfParameters());
        this.position.addGradientOfLogPriorTerm(grad, start += this.pairingModel.getNumberOfParameters());
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.bgModel.initializeFunctionRandomly(freeParams);
        this.pairingModel.initializeFunctionRandomly(freeParams);
        this.position.initializeFunctionRandomly(freeParams);
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        return this.getLogScoreAndPartialDerivation(seq, start, seq.getLength() - 1, indices, partialDer);
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int end, IntList indices, DoubleList partialDer) {
        double ls = 0.0;
        double[] temp = new double[2];
        IntList pairingIndices = new IntList();
        IntList bgIndices = new IntList();
        IntList positionIndices = new IntList();
        DoubleList pairingDers = new DoubleList();
        DoubleList bgDers = new DoubleList();
        DoubleList positionDers = new DoubleList();
        int ders = partialDer.length();
        int i = start;
        while (i <= end) {
            pairingIndices.clear();
            bgIndices.clear();
            positionIndices.clear();
            pairingDers.clear();
            bgDers.clear();
            positionDers.clear();
            double pos = this.position.getLogScoreAndPartialDerivation(i, positionIndices, positionDers);
            double npos = Math.log1p(-Math.exp(pos));
            double fg = this.pairingModel.getLogScoreAndPartialDerivation(seq, i, i, pairingIndices, pairingDers);
            temp[0] = pos + fg;
            temp[1] = npos;
            double bg = 0.0;
            int order = this.bgModel.getMaximalMarkovOrder();
            if (order > i) {
                order = i;
            }
            if (order > 0 && i > 0) {
                bg += this.bgModel.getLogScoreAndPartialDerivation(seq, i - order, i, bgIndices, bgDers);
                int num = bgDers.length();
                bg -= this.bgModel.getLogScoreAndPartialDerivation(seq, i - order, i - 1, bgIndices, bgDers);
                bgDers.multiply(num, bgDers.length(), -1.0);
            } else {
                bg += this.bgModel.getLogScoreAndPartialDerivation(seq, i, i, bgIndices, bgDers);
            }
            temp[1] = temp[1] + bg;
            ls += Normalisation.logSumNormalisation(temp);
            int j = 0;
            while (j < bgDers.length()) {
                indices.add(bgIndices.get(j));
                partialDer.add(bgDers.get(j) * temp[1]);
                ++j;
            }
            int off = this.bgModel.getNumberOfParameters();
            int j2 = 0;
            while (j2 < pairingDers.length()) {
                indices.add(pairingIndices.get(j2) + off);
                partialDer.add(pairingDers.get(j2) * temp[0]);
                ++j2;
            }
            off += this.pairingModel.getNumberOfParameters();
            double t = -Math.exp(pos) / Math.exp(npos) * temp[1];
            if (temp[1] == 0.0 && Double.isInfinite(npos)) {
                t = -Math.exp(bg) / Math.exp(fg);
            }
            if (Double.isNaN(t) || Double.isInfinite(t)) {
                System.out.println("pos " + i + ": " + t + " " + pos + " " + npos + " " + Arrays.toString(temp));
            }
            int j3 = 0;
            while (j3 < positionDers.length()) {
                indices.add(positionIndices.get(j3) + off);
                partialDer.add(positionDers.get(j3) * temp[0]);
                indices.add(positionIndices.get(j3) + off);
                partialDer.add(positionDers.get(j3) * t);
                ++j3;
            }
            ++i;
        }
        i = ders;
        while (i < partialDer.length()) {
            double d = partialDer.get(i);
            if (Double.isInfinite(d) || Double.isNaN(d)) {
                try {
                    System.out.println(Arrays.toString(temp));
                    System.out.println(String.valueOf(i - ders) + " " + indices.get(i) + " " + d + " " + Arrays.toString(this.getCurrentParameterValues()) + " " + seq);
                }
                catch (Exception e) {
                    e.printStackTrace();
                }
                System.exit(1);
            }
            ++i;
        }
        return ls;
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return 40;
    }

    @Override
    public int getNumberOfParameters() {
        int num = this.bgModel.getNumberOfParameters();
        num += this.pairingModel.getNumberOfParameters();
        return num += this.position.getNumberOfParameters();
    }

    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] params = new double[this.getNumberOfParameters()];
        double[] temp = this.bgModel.getCurrentParameterValues();
        System.arraycopy(temp, 0, params, 0, temp.length);
        int off = this.bgModel.getNumberOfParameters();
        temp = this.pairingModel.getCurrentParameterValues();
        System.arraycopy(temp, 0, params, off, temp.length);
        temp = this.position.getCurrentParameterValues();
        System.arraycopy(temp, 0, params, off += this.pairingModel.getNumberOfParameters(), temp.length);
        return params;
    }

    @Override
    public void setParameters(double[] params, int start) {
        try {
            this.prePars = this.getCurrentParameterValues();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        this.bgModel.setParameters(params, start);
        this.pairingModel.setParameters(params, start += this.bgModel.getNumberOfParameters());
        this.position.setParameters(params, start += this.pairingModel.getNumberOfParameters());
    }

    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogScoreFor(seq, start, seq.getLength() - 1);
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        double ls = 0.0;
        double[] temp = new double[2];
        int i = start;
        while (i <= end) {
            double pos = this.position.getLogScore(i);
            double npos = Math.log1p(-Math.exp(pos));
            temp[0] = pos + this.pairingModel.getLogScoreFor(seq, i, i);
            temp[1] = npos;
            int order = this.bgModel.getMaximalMarkovOrder();
            if (order > i) {
                order = i;
            }
            temp[1] = order > 0 && i > 0 ? temp[1] + (this.bgModel.getLogScoreFor(seq, i - order, i) - this.bgModel.getLogScoreFor(seq, i - order, i - 1)) : temp[1] + this.bgModel.getLogScoreFor(seq, i, i);
            if (Double.isNaN(ls += Normalisation.getLogSum(temp)) || Double.isInfinite(ls)) {
                System.out.println(String.valueOf(Arrays.toString(temp)) + " " + pos + " " + npos + " " + this.pairingModel.getLogScoreFor(seq, i, i));
                try {
                    System.out.println(Arrays.toString(this.prePars));
                    System.out.println(Arrays.toString(this.getCurrentParameterValues()));
                }
                catch (Exception e) {
                    e.printStackTrace();
                }
                this.setParameters(this.prePars, 0);
                System.out.println(this);
                System.exit(1);
            }
            ++i;
        }
        return ls;
    }

    @Override
    public boolean isInitialized() {
        return this.bgModel.isInitialized() && this.pairingModel.isInitialized();
    }

    @Override
    public String toString(NumberFormat nf) {
        StringBuffer sb = new StringBuffer();
        sb.append(this.bgModel.toString(nf));
        sb.append("\n");
        sb.append(this.pairingModel.toString(nf));
        sb.append("\n");
        sb.append(this.position.toString());
        return sb.toString();
    }

    @Override
    public StringBuffer toXML() {
        return null;
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
    }

    @Override
    public double getLogNormalizationConstant(int length) {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
    }
}

