/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.continuous;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jtem.numericalMethods.algebra.linear.Determinant;
import de.jtem.numericalMethods.algebra.linear.Inversion;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * A multivariate Gaussian density over continuous observations, parameterized by a mean vector and a
 * precision (inverse covariance) matrix.
 *
 * <p>Observations are read from sequences in one of three layouts:
 * <ul>
 * <li>{@code dimensionAlongPositions == true}: a window of exactly {@code dimension} positions is one
 * observation, its i-th position being the i-th coordinate;</li>
 * <li>{@code dimension == 1}: every position of a sequence is a separate scalar observation;</li>
 * <li>otherwise: every position of a {@link MultiDimensionalSequence} with {@code dimension}
 * sub-sequences is a separate {@code dimension}-dimensional observation.</li>
 * </ul>
 * In the last two layouts the log-probability of a sequence is the sum of the per-position log-densities.
 *
 * <p>{@link #getLogPriorTerm()} evaluates, up to additive constants, a Normal-Wishart-shaped prior: the
 * mean is Gaussian around {@code priorMean} with precision {@code essMu} times the model precision, and
 * the precision matrix has a Wishart-shaped term with {@code essPrec} degrees of freedom and
 * {@code priorPrecMat} in the trace term. {@link #train(DataSet, double[])} computes the corresponding
 * point estimates. Either parameter can instead be fixed to its prior value.
 */
public class MultivariateGaussian
extends AbstractTrainableStatisticalModel {
    /** Number of coordinates of one observation. */
    private int dimension;
    /** Whether one observation is laid out along the positions of a sequence (see class documentation). */
    private boolean dimensionAlongPositions;
    /** Current mean vector. */
    private double[] mean;
    /** Current precision (inverse covariance) matrix. */
    private double[][] precMat;
    /** Scratch buffer holding one observation vector. */
    private double[] tempValues;
    /** Scratch matrix used by {@link #train(DataSet, double[])}; swapped with {@code precMat} after inversion. */
    private double[][] precTemp;
    /** Whether {@link #train(DataSet, double[])} has been called (or the state was restored from XML). */
    private boolean isInitialized;
    /** Prior mean; also the value of the mean when {@code meanFixed} is set. */
    private double[] priorMean;
    /** Prior matrix added to the scatter matrix; also the precision used when {@code precisionFixed} is set. */
    private double[][] priorPrecMat;
    /** Equivalent sample size (pseudo-count) of the prior on the mean. */
    private double essMu;
    /** Equivalent sample size (degrees of freedom) of the prior on the precision matrix. */
    private double essPrec;
    /** If set, the mean is not estimated but copied from {@code priorMean}. */
    private boolean meanFixed;
    /** If set, the precision matrix is not estimated but copied from {@code priorPrecMat}. */
    private boolean precisionFixed;

    /**
     * Creates a model with a flat prior: zero prior mean, zero prior matrix, {@code essMu = 0} and
     * {@code essPrec = dimension}, with both parameters estimated from data.
     *
     * @param dimension the number of coordinates of one observation
     * @param dimensionAlongPositions whether one observation is laid out along the positions of a sequence
     * @throws CloneNotSupportedException if the prior matrix cannot be cloned
     */
    public MultivariateGaussian(int dimension, boolean dimensionAlongPositions) throws CloneNotSupportedException {
        this(dimension, dimensionAlongPositions, new double[dimension], MultivariateGaussian.getPrecMat(dimension, 0.0), 0.0, dimension, false, false);
    }

    /**
     * Creates a model whose prior matrix is {@code priorPrec} times the identity, using {@code ess} as the
     * equivalent sample size of both the mean and the precision prior.
     *
     * @param dimension the number of coordinates of one observation
     * @param dimensionAlongPositions whether one observation is laid out along the positions of a sequence
     * @param priorMean the prior mean, of length {@code dimension}
     * @param priorPrec the diagonal value of the prior matrix
     * @param ess the equivalent sample size used for both {@code essMu} and {@code essPrec}
     * @throws CloneNotSupportedException if the prior arrays cannot be cloned
     */
    public MultivariateGaussian(int dimension, boolean dimensionAlongPositions, double[] priorMean, double priorPrec, double ess) throws CloneNotSupportedException {
        this(dimension, dimensionAlongPositions, priorMean, MultivariateGaussian.getPrecMat(dimension, priorPrec), ess, ess, false, false);
    }

    /**
     * Returns a {@code dimension x dimension} matrix with {@code priorPrec} on the diagonal and zeros elsewhere.
     */
    private static double[][] getPrecMat(int dimension, double priorPrec) {
        double[][] mat = new double[dimension][dimension];
        for (int i = 0; i < dimension; i++) {
            mat[i][i] = priorPrec;
        }
        return mat;
    }

    /**
     * Creates a fully specified model.
     *
     * @param dimension the number of coordinates of one observation (at least 1)
     * @param dimensionsAlongPositions whether one observation is laid out along the positions of a sequence
     * @param priorMean the prior mean, of length {@code dimension}; copied
     * @param priorPrecMat the {@code dimension x dimension} prior matrix; copied
     * @param essMu the equivalent sample size of the prior on the mean
     * @param essPrec the equivalent sample size of the prior on the precision matrix
     * @param meanFixed whether the mean is fixed to {@code priorMean} instead of being estimated
     * @param precFixed whether the precision matrix is fixed to {@code priorPrecMat} instead of being estimated
     * @throws CloneNotSupportedException if the prior matrix cannot be cloned
     * @throws IllegalArgumentException if {@code dimension < 1} or the prior arrays have the wrong shape
     */
    public MultivariateGaussian(int dimension, boolean dimensionsAlongPositions, double[] priorMean, double[][] priorPrecMat, double essMu, double essPrec, boolean meanFixed, boolean precFixed) throws CloneNotSupportedException {
        super(new AlphabetContainer((Alphabet)new ContinuousAlphabet()), dimensionsAlongPositions ? dimension : 0);
        if (dimension < 1) {
            throw new IllegalArgumentException("dimension must be at least 1, got " + dimension);
        }
        if (priorMean == null) {
            throw new NullPointerException("priorMean");
        }
        if (priorMean.length != dimension) {
            throw new IllegalArgumentException("priorMean has length " + priorMean.length + ", expected " + dimension);
        }
        MultivariateGaussian.checkSquare(priorPrecMat, dimension, "priorPrecMat");
        this.dimension = dimension;
        this.dimensionAlongPositions = dimensionsAlongPositions;
        this.mean = new double[dimension];
        this.tempValues = new double[dimension];
        this.precMat = new double[dimension][dimension];
        this.precTemp = new double[dimension][dimension];
        this.priorMean = (double[])priorMean.clone();
        this.priorPrecMat = (double[][])ArrayHandler.clone((Cloneable[])priorPrecMat);
        this.essMu = essMu;
        this.essPrec = essPrec;
        this.meanFixed = meanFixed;
        this.precisionFixed = precFixed;
    }

    /**
     * Verifies that {@code matrix} is a non-null {@code dimension x dimension} matrix.
     *
     * @throws IllegalArgumentException if the shape is wrong
     */
    private static void checkSquare(double[][] matrix, int dimension, String name) {
        if (matrix == null) {
            throw new NullPointerException(name);
        }
        if (matrix.length != dimension) {
            throw new IllegalArgumentException(name + " has " + matrix.length + " rows, expected " + dimension);
        }
        for (int i = 0; i < matrix.length; i++) {
            if (matrix[i] == null || matrix[i].length != dimension) {
                throw new IllegalArgumentException(name + " row " + i + " does not have length " + dimension);
            }
        }
    }

    /**
     * Restores a model from its XML representation as produced by {@link #toXML()}.
     *
     * @param stringBuff the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public MultivariateGaussian(StringBuffer stringBuff) throws NonParsableException {
        super(stringBuff);
    }

    /**
     * Returns a deep copy of this model; all parameter, prior and scratch arrays are copied.
     */
    @Override
    public MultivariateGaussian clone() throws CloneNotSupportedException {
        MultivariateGaussian clone = (MultivariateGaussian)super.clone();
        clone.mean = (double[])this.mean.clone();
        clone.precMat = (double[][])ArrayHandler.clone((Cloneable[])this.precMat);
        clone.precTemp = (double[][])ArrayHandler.clone((Cloneable[])this.precTemp);
        clone.priorMean = (double[])this.priorMean.clone();
        clone.priorPrecMat = (double[][])ArrayHandler.clone((Cloneable[])this.priorPrecMat);
        clone.tempValues = (double[])this.tempValues.clone();
        return clone;
    }

    /**
     * Serializes the model state. The enclosing tag name is kept as {@code MultivariateGaussianDiffSM}
     * so that previously written XML remains readable by {@link #fromXML(StringBuffer)}.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer sb = new StringBuffer();
        XMLParser.appendObjectWithTags(sb, this.dimension, "dimension");
        XMLParser.appendObjectWithTags(sb, this.dimensionAlongPositions, "dimensionAlongPositions");
        XMLParser.appendObjectWithTags(sb, this.essMu, "essMu");
        XMLParser.appendObjectWithTags(sb, this.essPrec, "essPrec");
        XMLParser.appendObjectWithTags(sb, this.isInitialized, "isInitialized");
        XMLParser.appendObjectWithTags(sb, this.mean, "mean");
        XMLParser.appendObjectWithTags(sb, this.meanFixed, "meanFixed");
        XMLParser.appendObjectWithTags(sb, this.precisionFixed, "precisionFixed");
        XMLParser.appendObjectWithTags(sb, this.precMat, "precMat");
        XMLParser.appendObjectWithTags(sb, this.precTemp, "precTemp");
        XMLParser.appendObjectWithTags(sb, this.priorMean, "priorMean");
        XMLParser.appendObjectWithTags(sb, this.priorPrecMat, "priorPrecMat");
        XMLParser.addTags(sb, "MultivariateGaussianDiffSM");
        return sb;
    }

    /**
     * Estimates the mean and/or precision matrix from weighted data.
     *
     * <p>A free mean is set to {@code (essMu * priorMean + sum w x) / (essMu + sum w)}, summing over all
     * observations {@code x}; a fixed mean is reset to {@code priorMean}. A free precision matrix is set to
     * the inverse of
     * {@code (priorPrecMat + S + essMu (mean - priorMean)(mean - priorMean)^T) / (essPrec - dimension + sum w)},
     * where {@code S} is the weighted scatter matrix of the observations around the new mean; a fixed
     * precision matrix is reset to {@code priorPrecMat}.
     *
     * @param data the training sequences, laid out as described in the class documentation
     * @param weights one weight per sequence (applied to every observation it contains), or {@code null}
     *        for unit weights
     * @throws Exception if a sequence does not match the configured observation layout
     */
    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        if (this.meanFixed) {
            this.mean = (double[])this.priorMean.clone();
        } else {
            this.estimateMean(data, weights);
        }
        if (this.precisionFixed) {
            this.precMat = (double[][])ArrayHandler.clone((Cloneable[])this.priorPrecMat);
        } else {
            this.estimatePrecision(data, weights);
        }
        this.isInitialized = true;
    }

    /** Sets {@code mean} to the prior-smoothed weighted average of all observations. */
    private void estimateMean(DataSet data, double[] weights) throws Exception {
        // The prior acts as essMu pseudo-observations located at priorMean.
        double totalWeight = this.essMu;
        for (int d = 0; d < this.dimension; d++) {
            this.mean[d] = this.essMu * this.priorMean[d];
        }
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            double w = weights == null ? 1.0 : weights[i];
            int count = this.countObservations(seq);
            for (int o = 0; o < count; o++) {
                this.fillObservation(seq, o);
                for (int d = 0; d < this.dimension; d++) {
                    this.mean[d] += this.tempValues[d] * w;
                }
                totalWeight += w;
            }
        }
        for (int d = 0; d < this.dimension; d++) {
            this.mean[d] /= totalWeight;
        }
    }

    /** Sets {@code precMat} to the inverse of the prior-smoothed weighted covariance around {@code mean}. */
    private void estimatePrecision(DataSet data, double[] weights) throws Exception {
        for (double[] row : this.precMat) {
            Arrays.fill(row, 0.0);
        }
        // Accumulate the weighted scatter matrix S of the observations around the current mean.
        double totalWeight = 0.0;
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            double w = weights == null ? 1.0 : weights[i];
            int count = this.countObservations(seq);
            for (int o = 0; o < count; o++) {
                this.fillObservation(seq, o);
                for (int r = 0; r < this.dimension; r++) {
                    double[] row = this.precMat[r];
                    double diffR = this.tempValues[r] - this.mean[r];
                    for (int c = 0; c < this.dimension; c++) {
                        row[c] += diffR * (this.tempValues[c] - this.mean[c]) * w;
                    }
                }
                totalWeight += w;
            }
        }
        // Outer product of the deviation of the mean from the prior mean (contribution of the mean prior).
        for (int r = 0; r < this.dimension; r++) {
            for (int c = 0; c < this.dimension; c++) {
                this.precTemp[r][c] = (this.mean[r] - this.priorMean[r]) * (this.mean[c] - this.priorMean[c]);
            }
        }
        double denominator = this.essPrec - (double)this.dimension + totalWeight;
        for (int r = 0; r < this.dimension; r++) {
            for (int c = 0; c < this.dimension; c++) {
                this.precMat[r][c] = (this.priorPrecMat[r][c] + this.precMat[r][c] + this.essMu * this.precTemp[r][c]) / denominator;
            }
        }
        // precMat now holds a covariance estimate; invert it into precTemp and swap the buffers.
        Inversion.compute(this.precMat, this.precTemp);
        double[][] swap = this.precMat;
        this.precMat = this.precTemp;
        this.precTemp = swap;
    }

    /**
     * Checks that {@code seq} fits the configured observation layout and returns how many observation
     * vectors it contains (1 if the dimension runs along the positions, otherwise its length).
     *
     * @throws Exception if the sequence does not fit the layout
     */
    private int countObservations(Sequence seq) throws Exception {
        if (this.dimensionAlongPositions) {
            if (seq.getLength() != this.dimension) {
                throw new Exception("sequence has length " + seq.getLength() + ", expected " + this.dimension);
            }
            return 1;
        }
        if (this.dimension != 1) {
            this.asMultiDimensional(seq);
        }
        return seq.getLength();
    }

    /**
     * Copies the {@code index}-th observation vector of {@code seq} into {@code tempValues}. The sequence
     * must already have been validated by {@link #countObservations(Sequence)}; {@code index} is ignored
     * when the dimension runs along the positions.
     */
    private void fillObservation(Sequence seq, int index) throws Exception {
        if (this.dimensionAlongPositions) {
            for (int d = 0; d < this.dimension; d++) {
                this.tempValues[d] = seq.continuousVal(d);
            }
        } else if (this.dimension == 1) {
            this.tempValues[0] = seq.continuousVal(index);
        } else {
            ((MultiDimensionalSequence)seq).fillContainer(this.tempValues, index);
        }
    }

    /**
     * Casts {@code seq} to a {@link MultiDimensionalSequence} with exactly {@code dimension} sub-sequences.
     *
     * @throws Exception if {@code seq} has the wrong type or the wrong number of sub-sequences
     */
    private MultiDimensionalSequence asMultiDimensional(Sequence seq) throws Exception {
        if (!(seq instanceof MultiDimensionalSequence)) {
            throw new Exception("expected a MultiDimensionalSequence with " + this.dimension + " sub-sequences, got " + seq.getClass().getName());
        }
        MultiDimensionalSequence multi = (MultiDimensionalSequence)seq;
        if (multi.getNumberOfSequences() != this.dimension) {
            throw new Exception("sequence has " + multi.getNumberOfSequences() + " sub-sequences, expected " + this.dimension);
        }
        return multi;
    }

    /**
     * Returns the log-density of {@code values} under the current mean and precision matrix.
     *
     * @param values the observation vector, expected to have length {@code dimension}
     * @return the Gaussian log-density
     */
    public double getLogProbFor(double[] values) {
        return this.logDensity(values, this.logDetPrecision());
    }

    /** Returns the natural logarithm of the determinant of the current precision matrix. */
    private double logDetPrecision() {
        return Math.log(Determinant.compute(this.precMat));
    }

    /**
     * Gaussian log-density of {@code values} given a precomputed log-determinant of {@code precMat}, so
     * that callers scoring many observations compute the determinant only once.
     */
    private double logDensity(double[] values, double logDetPrec) {
        double val = 0.0;
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < values.length; j++) {
                val -= (values[i] - this.mean[i]) * this.precMat[i][j] * (values[j] - this.mean[j]);
            }
        }
        val += logDetPrec - (double)this.dimension * Math.log(Math.PI * 2);
        return val * 0.5;
    }

    /**
     * Returns the log-density of positions {@code startpos..endpos} (inclusive) of {@code sequence}.
     * If the dimension runs along the positions, the range must span exactly {@code dimension} positions
     * and forms one observation; otherwise the log-densities of the observations at each position are summed.
     *
     * @throws Exception if the range or sequence does not match the configured observation layout
     */
    @Override
    public double getLogProbFor(Sequence sequence, int startpos, int endpos) throws Exception {
        if (this.dimensionAlongPositions) {
            if (endpos - startpos + 1 != this.dimension) {
                throw new Exception(String.valueOf(endpos) + " - " + startpos + " != " + this.dimension);
            }
            for (int i = startpos; i <= endpos; i++) {
                this.tempValues[i - startpos] = sequence.continuousVal(i);
            }
            return this.logDensity(this.tempValues, this.logDetPrecision());
        }
        MultiDimensionalSequence multi = this.dimension == 1 ? null : this.asMultiDimensional(sequence);
        // The precision matrix does not change while scoring, so its determinant is computed once.
        double logDetPrec = this.logDetPrecision();
        double val = 0.0;
        for (int i = startpos; i <= endpos; i++) {
            if (multi == null) {
                this.tempValues[0] = sequence.continuousVal(i);
            } else {
                multi.fillContainer(this.tempValues, i);
            }
            val += this.logDensity(this.tempValues, logDetPrec);
        }
        return val;
    }

    /**
     * Returns the log-prior of the current parameters, up to additive constants. Terms belonging to a
     * fixed parameter are omitted.
     */
    @Override
    public double getLogPriorTerm() throws Exception {
        double lp = 0.0;
        if (!this.meanFixed) {
            // Gaussian prior on the mean: -essMu/2 * (priorMean - mean)^T precMat (priorMean - mean).
            for (int i = 0; i < this.dimension; i++) {
                for (int j = 0; j < this.dimension; j++) {
                    lp += (this.priorMean[i] - this.mean[i]) * this.precMat[i][j] * (this.priorMean[j] - this.mean[j]);
                }
            }
            lp *= -this.essMu / 2.0;
        }
        double logDet = this.logDetPrecision();
        if (!this.precisionFixed) {
            // Wishart-shaped term: (essPrec - d - 1)/2 * log|precMat| - 1/2 * trace(priorPrecMat * precMat).
            lp += (this.essPrec - (double)this.dimension - 1.0) / 2.0 * logDet;
            for (int i = 0; i < this.dimension; i++) {
                for (int k = 0; k < this.dimension; k++) {
                    lp -= 0.5 * this.priorPrecMat[i][k] * this.precMat[k][i];
                }
            }
        }
        if (!this.meanFixed) {
            // Normalizer of the mean's Gaussian prior, which depends on the precision matrix.
            lp += 0.5 * logDet;
        }
        return lp;
    }

    /** Returns a short name identifying the model type and its dimension. */
    @Override
    public String getInstanceName() {
        return "MultivariateGaussian(dimension=" + this.dimension + ")";
    }

    /** Not supported by this model; always returns {@code null}. */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Restores the state written by {@link #toXML()} and rebuilds the fields that are not serialized
     * (scratch buffer, alphabet and length).
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "MultivariateGaussianDiffSM");
        this.dimension = XMLParser.extractObjectForTags(xml, "dimension", Integer.TYPE);
        this.dimensionAlongPositions = XMLParser.extractObjectForTags(xml, "dimensionAlongPositions", Boolean.TYPE);
        this.essMu = XMLParser.extractObjectForTags(xml, "essMu", Double.TYPE);
        this.essPrec = XMLParser.extractObjectForTags(xml, "essPrec", Double.TYPE);
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.mean = (double[])XMLParser.extractObjectForTags(xml, "mean");
        this.meanFixed = XMLParser.extractObjectForTags(xml, "meanFixed", Boolean.TYPE);
        this.precisionFixed = XMLParser.extractObjectForTags(xml, "precisionFixed", Boolean.TYPE);
        this.precMat = (double[][])XMLParser.extractObjectForTags(xml, "precMat");
        this.precTemp = (double[][])XMLParser.extractObjectForTags(xml, "precTemp");
        this.priorMean = (double[])XMLParser.extractObjectForTags(xml, "priorMean");
        this.priorPrecMat = (double[][])XMLParser.extractObjectForTags(xml, "priorPrecMat");
        this.tempValues = new double[this.dimension];
        this.alphabets = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
        this.length = this.dimensionAlongPositions ? this.dimension : 0;
    }

    /**
     * Returns the mean vector, a blank line, and then the precision matrix one row per line. Numbers are
     * formatted with {@code nf} when it is non-null, otherwise with {@link Arrays#toString(double[])}.
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        MultivariateGaussian.appendVector(sb, this.mean, nf);
        sb.append("\n\n");
        for (int i = 0; i < this.precMat.length; i++) {
            if (i > 0) {
                sb.append('\n');
            }
            MultivariateGaussian.appendVector(sb, this.precMat[i], nf);
        }
        return sb.toString();
    }

    /** Appends {@code values} as {@code [a, b, ...]}, formatting each entry with {@code nf} if non-null. */
    private static void appendVector(StringBuilder sb, double[] values, NumberFormat nf) {
        if (nf == null) {
            sb.append(Arrays.toString(values));
            return;
        }
        sb.append('[');
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(nf.format(values[i]));
        }
        sb.append(']');
    }
}

