#include "clustddp.h"

/*
 * Count the distinct values among the first n entries of x.
 *
 * x : array of labels (read only)
 * n : number of entries of x to inspect
 *
 * Returns the number of distinct values, or 0 when n <= 0 (x is not
 * touched in that case).  No scratch storage is allocated.
 */
int nclust(const int *x, const int n) {
  int i,j,count = 0;
  if(n <= 0) return 0;
  for(i=0;i<n;i++) {
    /* x[i] starts a new cluster only if no earlier entry carries its value */
    for(j=0;j<i;j++) {
      if(x[j] == x[i]) break;
    }
    if(j == i) count++;
  }
  return count;
}

/** bait clustering DP **/
/*
 * Resample the truncated stick-breaking weights of the bait-level DP.
 *
 * m[k] is the number of baits whose label model->dp.z[j] equals k and
 * mrevsum[k] = m[k] + ... + m[nb-1].  The stick fractions are drawn as
 *   betap[k] ~ Beta(1 + m[k], model->dp.alpha + mrevsum[k+1]),  k < nb-1,
 * with betap[nb-1] fixed at 1, and the weights
 *   beta[k] = betap[k] * prod_{l<k} (1 - betap[l])
 * are built in log scale and written to model->dp.beta.
 *
 * Returns 0 on success.  Returns 1 if model->nb < 1 or if checknum()
 * (defined elsewhere) reports a nonzero status for a log-weight; in both
 * cases model->dp.beta is left exactly as it was on entry.
 */
int updateBeta(MODEL *model, DATA *data, const gsl_rng *r) {
  const int nb = model->nb;
  int j,k;
  double a,b;
  double logrest;   /* running sum of log(1 - betap[l]) for l < k */

  if(nb < 1) return 1;

  double betap[nb];
  double logbeta[nb];   /* staged so a failure never leaves dp.beta half-written */
  double m[nb];
  double mrevsum[nb];

  for(k=0;k<nb;k++) m[k] = 0.0;
  for(j=0;j<data->nbait;j++) m[model->dp.z[j]] += 1.0;
  mrevsum[nb-1] = m[nb-1];
  for(k=nb-2;k>=0;k--) mrevsum[k] = mrevsum[k+1] + m[k];

  for(k=0;k<nb-1;k++) {
    a = 1.0 + m[k];
    b = model->dp.alpha + mrevsum[k+1];
    betap[k] = gsl_ran_beta(r, a, b);
  }
  betap[nb-1] = 1.0;

  /* Same additions, in the same order, as summing log(1 - betap[l]) over
     l < k afresh for every k, but done once. */
  logrest = 0.0;
  for(k=0;k<nb;k++) {
    logbeta[k] = logrest + log(betap[k]);
    if(checknum(logbeta[k])) return 1;
    if(k < nb-1) logrest += log(1.0 - betap[k]);
  }

  for(k=0;k<nb;k++) model->dp.beta[k] = exp(logbeta[k]);
  return 0;
}

/*
 * Debug helper: write p[0..n-1] to stderr on one line, each value printed
 * with three decimals and followed by a tab, then terminate the line.
 */
void printp(const double *p, const int n) {
  int idx = 0;
  while(idx < n) {
    fprintf(stderr, "%.3f\t", p[idx]);
    idx++;
  }
  fprintf(stderr, "\n");
}


/*
 * Resample the bait-to-cluster labels model->dp.z, one bait at a time.
 *
 * For each bait j:
 *   - id[k] marks the candidate clusters: clusters that no bait is
 *     currently assigned to, plus the clusters holding any bait listed in
 *     data->inter[j];
 *   - p_minus_j is model->curlik with bait j's share removed, i.e. minus
 *     loglikBait() and minus log(pi) over all preys for j's current cluster;
 *   - loglikNew_j[] is passed to loglikBaitSample() (presumably filled with
 *     the per-cluster log-likelihood of bait j -- confirm against its
 *     definition), then the log(pi) terms are added for candidates;
 *   - the scores are shifted by vec_max_cond(), log(beta) is added, the
 *     result is exponentiated (non-candidates get weight 0) and a label is
 *     drawn with ranMultinom();
 *   - the stick weights and model->curlik are refreshed.
 *
 * Returns 0 on success and 1 on failure: loglikBaitSample() failed, the
 * drawn label was outside [0, nb) or rejected by checknum() (the old label
 * is kept in that case), or updateBeta() failed (curlik is still refreshed
 * before returning so it matches the stored labels).
 */
int updateZ2(MODEL *model, DATA *data, const gsl_rng *r) {
  int i,j,k,z,znew;
  int betaFailed;
  double maxp;
  double loglikNew_j[model->nb];
  double p[model->nb];
  int id[model->nb];
  double p_minus_j;

  for(j=0;j<data->nbait;j++) {
    /* candidate clusters: empty ones, plus those of j's listed partners */
    for(k=0;k<model->nb;k++) id[k] = 1;
    for(k=0;k<data->nbait;k++) id[model->dp.z[k]] = 0;
    for(k=0;k<data->ninter[j];k++) id[model->dp.z[data->inter[j][k]]] = 1;

    /* current log-posterior without bait j's contribution */
    p_minus_j = model->curlik - loglikBait(model,data,j,r);
    z = model->dp.z[j];
    for(i=0;i<data->nprey;i++) {
      p_minus_j -= log(model->hdp.pi[i][model->hdp.y[z][i]]);
    }

    for(k=0;k<model->nb;k++) loglikNew_j[k] = 0.0;
    if(loglikBaitSample(model,data,j,loglikNew_j,r,id)) {
      fprintf(stderr, "Error in bait-loglikelihood\n"); return 1;
    }
    for(k=0;k<model->nb;k++) {
      if(id[k]) {
        for(i=0;i<data->nprey;i++) loglikNew_j[k] += log(model->hdp.pi[i][model->hdp.y[k][i]]);
      }
    }

    /* shift before exponentiating so the weights cannot overflow */
    for(k=0;k<model->nb;k++) p[k] = p_minus_j + loglikNew_j[k];
    maxp = vec_max_cond(p, model->nb, id);
    for(k=0;k<model->nb;k++) p[k] -= maxp;
    for(k=0;k<model->nb;k++) p[k] += log(model->dp.beta[k]);
    for(k=0;k<model->nb;k++) p[k] = id[k] ? exp(p[k]) : 0.0;

    /* validate the draw before storing it: an out-of-range label kept in
       dp.z would later be used as an array index */
    znew = ranMultinom(r, p, model->nb);
    if(znew < 0 || znew >= model->nb) return 1;
    if(checknum(znew)) return 1;
    model->dp.z[j] = znew;

    betaFailed = updateBeta(model, data, r);
    model->curlik = loglik(model, data, r) + logPrior(model, data);
    if(betaFailed) return 1;
  }
  return 0;
}

/*
 * One sweep over the bait-level DP: updateBeta() followed by updateZ2().
 * A nonzero status from either step is reported on stderr; the second step
 * runs regardless of the outcome of the first.
 */
void baitClusteringDP(MODEL *model, DATA *data, const gsl_rng *r) {
  int status;

  status = updateBeta(model, data, r);
  if(status != 0) {
    fprintf(stderr, "baitClusterDP -- beta\n");
  }

  status = updateZ2(model, data, r);
  if(status != 0) {
    fprintf(stderr, "baitClusterDP -- Z\n");
  }
}

/** prey clustering HDP **/
/*
 * Resample the stick-breaking weights of the prey-level process.
 *
 * For every prey, the occupancy of each of the model->np prey clusters is
 * counted over all baits (through model->hdp.y[model->dp.z[bait]][prey]),
 * stick fractions are drawn from Beta(1 + count, model->hdp.gamma + tail
 * count) with the last fraction fixed at 1, and the resulting weights are
 * stored in model->hdp.pi[prey][].  Counts of preys with data->use[prey]
 * set are pooled and used in the same way, with model->hdp.rho, to fill
 * model->hdp.eta[].
 *
 * Returns 0 on success and 1 as soon as checknum() returns nonzero for a
 * sampled fraction of pi or for a computed weight; values written before
 * that point stay in the model.
 */
int updateEtaPi(MODEL *model, DATA *data, const gsl_rng *r) {
  const int np = model->np;
  const int last = np - 1;
  int prey, bait, c, prev;
  double shape1, shape2;
  double etaFrac[np];
  double piFrac[np];
  double tot[np];       /* pooled counts over preys flagged in data->use */
  double totTail[np];   /* pooled tail sums */
  double cnt[np];       /* counts for the prey being processed */
  double cntTail[np];   /* cntTail[c] = cnt[c] + ... + cnt[np-1] */

  for(c=0;c<np;c++) {
    tot[c] = 0.0;
    totTail[c] = 0.0;
  }

  for(prey=0;prey<data->nprey;prey++) {
    for(c=0;c<np;c++) cnt[c] = 0.0;
    for(bait=0;bait<data->nbait;bait++) {
      cnt[model->hdp.y[model->dp.z[bait]][prey]] += 1.0;
    }

    cntTail[last] = cnt[last];
    for(c=last-1;c>=0;c--) cntTail[c] = cntTail[c+1] + cnt[c];

    /* stick fractions for this prey */
    for(c=0;c<last;c++) {
      shape1 = 1.0 + cnt[c];
      shape2 = model->hdp.gamma + cntTail[c+1];
      piFrac[c] = gsl_ran_beta(r, shape1, shape2);
      if(checknum(piFrac[c])) return 1;
    }
    piFrac[last] = 1.0;

    /* weights are assembled in log scale, then exponentiated in place */
    for(c=0;c<np;c++) {
      model->hdp.pi[prey][c] = 0.0;
      for(prev=0;prev<c;prev++) model->hdp.pi[prey][c] += log(1.0 - piFrac[prev]);
      model->hdp.pi[prey][c] += log(piFrac[c]);
      model->hdp.pi[prey][c] = exp(model->hdp.pi[prey][c]);
      if(checknum(model->hdp.pi[prey][c])) return 1;
    }

    if(!data->use[prey]) continue;
    for(c=0;c<np;c++) {
      tot[c] += cnt[c];
      totTail[c] += cntTail[c];
    }
  }

  /* pooled stick fractions and weights (eta) */
  for(c=0;c<last;c++) {
    shape1 = 1.0 + tot[c];
    shape2 = model->hdp.rho + totTail[c+1];
    etaFrac[c] = gsl_ran_beta(r, shape1, shape2);
  }
  etaFrac[last] = 1.0;

  for(c=0;c<np;c++) {
    model->hdp.eta[c] = 0.0;
    for(prev=0;prev<c;prev++) model->hdp.eta[c] += log(1.0 - etaFrac[prev]);
    model->hdp.eta[c] += log(etaFrac[c]);
    model->hdp.eta[c] = exp(model->hdp.eta[c]);
    if(checknum(model->hdp.eta[c])) return 1;
  }
  return 0;
}

/*
 * Resample the prey-level labels model->hdp.y[k][i] for every bait cluster
 * k and prey i.
 *
 * Bait clusters that currently hold at least one bait: the contribution of
 * the current label (logGaussian() over all IPs of the member baits plus
 * one log(pi) term per member bait) is removed from model->curlik, the same
 * quantity is computed for each of the model->np candidate labels, the
 * scores are shifted by vec_max() and exponentiated, a label is drawn with
 * ranMultinom(), and model->curlik is updated incrementally.
 *
 * Bait clusters with no member bait: a label is proposed from
 * model->hdp.pi[i] and accepted with probability
 * min(1, pi[i][proposal] / pi[i][current]).
 *
 * model->curlik is recomputed from scratch before a successful return.
 *
 * Returns 0 on success, 1 if a drawn label falls outside [0, np) or is
 * rejected by checknum(); the offending draw is never stored or used as an
 * array index.
 */
int updateY(MODEL *model, DATA *data, const gsl_rng *r) {
  int i,j,k,l,id,m,id2,pos,draw;
  int map[model->nb][data->nbait];  /* map[k][0..nmap[k]-1] = baits in cluster k; unused slots hold -1 */
  int nmap[model->nb];              /* number of baits in each cluster */
  double prob[model->np];
  double maxprob;
  double loglikNew_ki[model->np];
  double p_minus_ki;
  double mhratio;

  /* first identify the bait clustering map (from cluster to sample) */
  for(k=0;k<model->nb;k++) {
    nmap[k] = 0;
    for(j=0;j<data->nbait;j++) map[k][j] = -1;
  }
  for(j=0;j<data->nbait;j++) {
    k = model->dp.z[j];
    map[k][nmap[k]] = j;
    (nmap[k])++;
  }

  for(k=0;k<model->nb;k++) {
    if(nmap[k] > 0) {
      for(i=0;i<data->nprey;i++) {
        /* log-posterior without the (k,i) contribution under the current label */
        id = model->hdp.y[k][i];
        p_minus_ki = model->curlik;
        for(j=0;j<nmap[k];j++) {
          pos = map[k][j];
          for(m=0;m<data->n_b2ip[pos];m++) {
            id2 = data->b2ip[pos][m];
            p_minus_ki -= logGaussian(data->d[id2][i], model->hdp.omega[id], model->sigma2[i], r);
          }
          p_minus_ki -= log(model->hdp.pi[i][id]);
        }

        /* the same contribution under every candidate label */
        for(l=0;l<model->np;l++) {
          loglikNew_ki[l] = 0.0;
          for(j=0;j<nmap[k];j++) {
            pos = map[k][j];
            for(m=0;m<data->n_b2ip[pos];m++) {
              id2 = data->b2ip[pos][m];
              loglikNew_ki[l] += logGaussian(data->d[id2][i], model->hdp.omega[l], model->sigma2[i], r);
            }
            loglikNew_ki[l] += log(model->hdp.pi[i][l]);
          }
        }

        /* unnormalised log-posterior of each candidate label */
        for(l=0;l<model->np;l++) prob[l] = p_minus_ki + loglikNew_ki[l];

        /* shift by the maximum before exponentiating */
        maxprob = vec_max(prob, model->np);
        for(l=0;l<model->np;l++) prob[l] -= maxprob;
        for(l=0;l<model->np;l++) prob[l] = exp(prob[l]);

        /* same guard as updateZ2: never index with an unchecked draw */
        draw = ranMultinom(r, prob, model->np);
        if(draw < 0 || draw >= model->np) return 1;
        if(checknum(draw)) return 1;
        model->hdp.y[k][i] = draw;

        /* update curlik */
        model->curlik = p_minus_ki + loglikNew_ki[draw];
      }
    }
    else {
      for(i=0;i<data->nprey;i++) {
        for(l=0;l<model->np;l++) prob[l] = model->hdp.pi[i][l];
        draw = ranMultinom(r, prob, model->np);
        if(draw < 0 || draw >= model->np) return 1;
        mhratio = 0.0;
        mhratio += log(model->hdp.pi[i][draw]);
        mhratio -= log(model->hdp.pi[i][model->hdp.y[k][i]]);
        mhratio = GSL_MIN(1.0,exp(mhratio));
        if(gsl_ran_flat(r,0.0,1.0) < mhratio) model->hdp.y[k][i] = draw;
      }
    }
  }

  model->curlik = loglik(model, data, r) + logPrior(model, data);

  return 0;
}


/*
 * Draw a fresh parameter pair into omega:
 *   omega[0] = lambda + gsl_ran_gaussian(r, sqrt(nu))
 *   omega[1] = 1 / gsl_ran_gamma(r, a, 1/b), replaced by mv whenever the
 *              draw is not >= mv
 * The Gaussian draw is taken before the gamma draw.  data is not used.
 */
void sampleBase(const gsl_rng *r, DATA *data, double *omega, double lambda, double nu, double a, double b, double mv) {
  double draw;

  draw = gsl_ran_gaussian(r, sqrt(nu)) + lambda;
  omega[0] = draw;

  draw = 1.0 / gsl_ran_gamma(r, a, 1.0/b);
  omega[1] = (draw >= mv) ? draw : mv;
}

/*
 * Random-walk proposal for a parameter pair.  A uniform draw picks which
 * component moves:
 *   draw <  0.5 : propose[0] = omega[0] + N(0, sd 1.0), propose[1] = omega[1]
 *   otherwise   : propose[0] = omega[0],
 *                 propose[1] = omega[1] + N(0, sd 0.5), replaced by mv
 *                 whenever the result is not >= mv
 * data, a and b are not used.
 */
void updateBase(const gsl_rng *r, DATA *data, double *omega, double *propose, double a, double b, double mv) {
  double moved;
  const int moveFirst = gsl_ran_flat(r,0.0,1.0) < 0.5;

  if(!moveFirst) {
    propose[0] = omega[0];
    moved = gsl_ran_gaussian(r, 0.5) + omega[1];
    propose[1] = (moved >= mv) ? moved : mv;
    return;
  }

  moved = gsl_ran_gaussian(r, 1.0) + omega[0];
  propose[0] = moved;
  propose[1] = omega[1];
}

/* Evaluate the log prior density of the base measure */

/*
 * Log-density of the base measure at omega = (mu, v):
 *   mu ~ Normal(mean lambda, variance nu)
 *   v  ~ Inverse-Gamma(shape a, scale b)
 *
 *   log p = -0.5*log(2*pi*nu) - (mu - lambda)^2 / (2*nu)
 *           + a*log(b) - lgamma(a) - (a + 1)*log(v) - b/v
 *
 * The quadratic term is subtracted: a larger distance between mu and
 * lambda must lower the density.
 *
 * omega  : omega[0] = mu, omega[1] = v (read only)
 * Returns the log-density.
 */
double logBaseEvaluate(double *omega, double lambda, double nu, double a, double b) {
  const double mu = omega[0];
  const double v = omega[1];
  double res;
  res = -.5 * log(nu * 2.0 * M_PI) - pow(mu - lambda, 2.0) / (2.0 * nu);
  res += a * log(b) - gsl_sf_lngamma(a) - (a + 1.0) * log(v) - b / v;
  return res;
}

/*
 * Sum of logGaussian() terms over every (prey, bait, IP) triple whose
 * prey-level label model->hdp.y[model->dp.z[bait]][prey] equals k,
 * evaluated with the parameter vector omega.
 */
static double omegaClusterLoglik(MODEL *model, DATA *data, const gsl_rng *r, int k, double *omega) {
  int prey, bait, rep;
  double total = 0.0;
  for(prey=0;prey<data->nprey;prey++) {
    for(bait=0;bait<data->nbait;bait++) {
      if(model->hdp.y[model->dp.z[bait]][prey] != k) continue;
      for(rep=0;rep<data->n_b2ip[bait];rep++) {
        total += logGaussian(data->d[data->b2ip[bait][rep]][prey], omega, model->sigma2[prey], r);
      }
    }
  }
  return total;
}

/*
 * Update the parameter vectors model->hdp.omega[k] of all model->np prey
 * clusters.
 *
 * A cluster that is referenced by at least one (bait, prey) pair gets a
 * Metropolis-Hastings step: updateBase() proposes a new vector, the
 * acceptance ratio combines the change in the cluster's logGaussian() sum
 * with the difference of logBaseEvaluate() at the proposal and at the
 * current value, and on acceptance the vector is replaced and
 * model->curlik is recomputed.  A cluster that is not referenced is
 * redrawn with sampleBase().
 *
 * Always returns 0.
 */
int updateOmega(MODEL *model, DATA *data, const gsl_rng *r) {
  int prey, bait, c, d;
  double u, ratio;
  double cand[_DIM_];
  double restLik, curPart, candPart;
  double candTotal, curTotal;
  int occupied[model->np];

  /* flag the clusters that some (bait, prey) pair points to */
  for(c=0;c<model->np;c++) occupied[c] = 0;
  for(prey=0;prey<data->nprey;prey++) {
    for(bait=0;bait<data->nbait;bait++) {
      occupied[model->hdp.y[model->dp.z[bait]][prey]] = 1;
    }
  }

  for(c=0;c<model->np;c++) {
    if(!occupied[c]) {
      sampleBase(r, data, model->hdp.omega[c], model->lambda, model->nu, model->a, model->b, model->minvar);
      continue;
    }

    curPart = omegaClusterLoglik(model, data, r, c, model->hdp.omega[c]);
    restLik = model->curlik - curPart;

    updateBase(r, data, model->hdp.omega[c], cand, model->a, model->b, model->minvar);

    candPart = omegaClusterLoglik(model, data, r, c, cand);

    candTotal = restLik + candPart;
    curTotal = restLik + curPart;

    /* log acceptance ratio, then back to probability scale capped at 1 */
    ratio = candTotal - curTotal;
    ratio += logBaseEvaluate(cand, model->lambda, model->nu, model->a, model->b);
    ratio -= logBaseEvaluate(model->hdp.omega[c], model->lambda, model->nu, model->a, model->b);
    ratio = exp(ratio);
    if(ratio > 1.0) ratio = 1.0;

    u = gsl_ran_flat(r, 0.0, 1.0);
    if(!(u <= ratio)) continue;   /* rejected: keep the current vector */

    for(d=0;d<_DIM_;d++) model->hdp.omega[c][d] = cand[d];
    model->curlik = loglik(model, data, r) + logPrior(model, data);
  }
  return 0;
}

/*
 * One sweep over the prey-level process: updateOmega(), updateEtaPi() and
 * updateY(), in that order.  A nonzero status from a step is reported on
 * stderr; the remaining steps still run.
 */
void preyClusteringHDP(MODEL *model, DATA *data, const gsl_rng *r) {
  int status;

  status = updateOmega(model, data, r);
  if(status != 0) {
    fprintf(stderr, "preyClusterHDP -- Omega\n");
  }

  status = updateEtaPi(model, data, r);
  if(status != 0) {
    fprintf(stderr, "preyClusterHDP -- Eta or Pi\n");
  }

  status = updateY(model, data, r);
  if(status != 0) {
    fprintf(stderr, "preyClusterHDP -- Y\n");
  }
}

/*
 * Overwrite every entry data->d[j][i] whose flag data->zero[j][i] is set
 * with a draw from gsl_ran_gaussian(r, 0.1) shifted by -1.0.  Entries are
 * visited IP by IP, prey by prey.  model is not used.
 */
void imputeMiss(MODEL *model, DATA *data, const gsl_rng *r) {
  int ip, prey;
  for(ip=0;ip<data->nIP;ip++) {
    for(prey=0;prey<data->nprey;prey++) {
      if(!data->zero[ip][prey]) continue;
      data->d[ip][prey] = gsl_ran_gaussian(r, 0.1) - 1.0;
    }
  }
}

/*
 * Redraw model->sigma2[i] for every prey i as the reciprocal of a gamma
 * variate.  Shape starts at model->nbait / 2 and rate at model->nbait;
 * each IP of each bait then adds 0.5 to the shape and half the squared
 * difference between data->d[ip][i] and the first component of the prey's
 * current cluster parameter (model->hdp.omega[.][0]) to the rate.
 *
 * NOTE(review): the starting values read model->nbait while the loops run
 * over data->nbait -- confirm the two always agree.
 */
void updateSigma2(MODEL *model, DATA *data, const gsl_rng *r) {
  int prey, bait, rep, clus, ip;
  double shape, rate;
  for(prey=0;prey<data->nprey;prey++) {
    /* starting values contributed by the inverse gamma prior */
    shape = ((double) model->nbait / 2);
    rate = ((double) model->nbait);
    for(bait=0;bait<data->nbait;bait++) {
      clus = model->hdp.y[model->dp.z[bait]][prey];
      for(rep=0;rep<data->n_b2ip[bait];rep++) {
        ip = data->b2ip[bait][rep];
        shape += .5;
        rate += .5 * pow(data->d[ip][prey] - model->hdp.omega[clus][0], 2.0);
      }
    }
    model->sigma2[prey] = 1.0 / gsl_ran_gamma(r, shape, 1.0/rate);
  }
}

/** main function **/
/*
 * One blocked Gibbs iteration: a bait-level sweep (baitClusteringDP),
 * two prey-level sweeps (preyClusteringHDP), then a full recomputation of
 * model->curlik as loglik() + logPrior().
 *
 * Missing-value imputation, the sigma2 update and variable selection are
 * not part of the iteration.  upvar is not used.
 */
void blockedGibbs(MODEL *model, DATA *data, const gsl_rng *r, int upvar) {
  int pass = 0;

  baitClusteringDP(model, data, r);

  while(pass < 2) {
    preyClusteringHDP(model, data, r);
    pass++;
  }

  model->curlik = loglik(model, data, r) + logPrior(model, data);
}





