/*
 * TNGG: marginal Gibbs sampler for the TNGG model, restricted to exactly two groups/time points (nDP == 2).
 */

#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <sys/types.h>
#include <assert.h>
#include "mex.h"

#include "slice-sampler.h"

//#define RECORD

#ifdef _MSC_VER
#define finite _finite
#define isnan _isnan
#endif

#ifdef	 __USE_ISOC99
/* INFINITY and NAN are defined by the ISO C99 standard */
#else
/* Fallback +infinity for toolchains without C99 INFINITY: 1/0 evaluates to
 * +inf under IEEE-754 arithmetic. The zero is kept in a variable, presumably
 * so the compiler does not reject a constant division by zero. */
double my_infinity(void) {
  double denominator = 0.0;
  return 1.0 / denominator;
}
/* Fallback quiet NaN for toolchains without C99 NAN: 0/0 is NaN under IEEE-754. */
double my_nan(void) {
  double denominator = 0.0;
  return denominator / denominator;
}
#define INFINITY my_infinity()
#define NAN my_nan()
#endif

/********************
 * Sampler state (nDP groups, nS sources, K topic slots):
 * M:		concentration parameters of the source DPs, 1 * nS
 * q:		subsampling rates from source s to group t, nDP * nS
 * n:		token counts per source, 1 * nS
 * n_t:		token counts per group, 1 * nDP
 * n_ts:	token counts in group t attributed to source s, nDP * nS
 * nj:		token counts per topic k, 1 * K
 * nj_t:	token counts per group for each topic k, nDP * K
 * K_c:		number of topics in each group; the last entry is the total number of topics, 1 * (nDP + 1)
 * K_ct:	number of topics in each source, 1 * nS
 * r:		indicators of inheritance from source s to group t, nDP * K (not used by this marginal sampler)
 * rr:		indicators of random jumps inherited from source s to group t, nDP * nS * K_e (not used by this marginal sampler)
 * K_t:		maximum number of topics considered so far, 1 * 1
 * nDP:		number of groups
 * nS:		number of sources
 * JJ:		jump (weight) attached to each topic slot, 1 * K
 * K_e:		number of random jumps in each source (not used by this marginal sampler)
 * K_id:	index of the source each topic belongs to, 1 * K
 */

/*
 * Mutable sampler state shared by the samplers and posterior functors below.
 * a:          discount parameter; resampled in (1e-10, 0.9999) by samplea_slice.
 * gamma_a:    cached gamma(1 - a), refreshed by samplea_slice.
 * sumzj:      per-group (2 entries) sums of the jumps of topics active in that
 *             group, filled on the last sweep of sampleJ and read by sampleU.
 * JJ:         jump of each topic slot (0 for empty slots).
 * n_t_tot:    total tokens per group; calPerp treats tokens [n_t, n_t_tot) as held-out.
 * tag_tr:     per-group flag read by calPerp (1 = score the training tokens).
 * u_idx, q_i, q_j: index of the variable currently being slice-sampled; the
 *             posterior functors read it here because they take no index argument.
 * output:     0 = progress printed to the console; any other value writes to
 *             fid, which mexFunction opens only when output == 1.
 * u:          per-group auxiliary variable, slice-sampled on the log scale.
 * qa, qb:     NOTE(review): not referenced in this file; the q posteriors carry
 *             their own qa/qb members.
 * lv, rv:     bounds on log(u) for the slice sampler; qlv, qrv: bounds on q.
 * fid:        progress log file.
 */
double a, *M, **q, *sumzj, *JJ, gamma_a;
int *n, *n_t, *n_t_tot, **n_ts, *nj, **nj_t, *K_c, *K_ct, *K_id, K_t, nS, nDP, *tag_tr, u_idx, q_i, q_j, output;
double *u, qa, qb;
double lv = -20, rv = 20, qlv = 1e-10, qrv = 0.999999;
FILE* fid;

/*
 * data loading functions, adapted from Teh's HDP code
*/
/* Larger of two values. Arguments may be evaluated twice, so avoid side effects. */
#define max1(x1,x2) ( (x2) > (x1) ? (x2) : (x1) )

#define mxReadCellVectorDef(funcname,type) \
  type **funcname(const mxArray *mcell, type shift) { \
    mxArray *mvector; \
    double *mdouble; \
    int ii, jj; \
    type **result; \
    result = (type**)malloc(sizeof(type*)*mxGetNumberOfElements(mcell));  \
    for ( jj = 0 ; jj < mxGetNumberOfElements(mcell) ; jj++ ) {  \
      mvector = mxGetCell(mcell,jj);  \
      mdouble = mxGetPr(mvector);  \
      result[jj] = (type*)malloc(sizeof(type)*mxGetNumberOfElements(mvector));  \
      for ( ii = 0 ; ii < mxGetNumberOfElements(mvector) ; ii++ )  \
        result[jj][ii] = (type)mdouble[ii] + shift;  \
    } \
    return result; \
  }
mxReadCellVectorDef(mxReadIntCellVector,int);
mxReadCellVectorDef(mxReadDoubleCellVector,double);

/*
 * Defines `mxArray *funcname(numcell, numentry, var, shift, del)`: builds a
 * 1 x numcell cell array whose c-th cell is a 1 x numentry[c] double row
 * vector holding var[c][e] + shift. When del is non-zero, every row of var and
 * then var itself are freed.
 */
#define mxWriteCellVectorDef(funcname, type) \
  mxArray *funcname(int numcell,int *numentry,type **var,type shift, int del) { \
    mxArray *cells = mxCreateCellMatrix(1, numcell); \
    int c, e; \
    for ( c = 0 ; c < numcell ; c++ ) { \
      mxArray *vec = mxCreateDoubleMatrix(1, numentry[c], mxREAL); \
      double *dst = mxGetPr(vec); \
      mxSetCell(cells, c, vec); \
      for ( e = 0 ; e < numentry[c] ; e++ ) \
        dst[e] = var[c][e] + shift; \
      if ( del ) \
        free(var[c]); \
    } \
    if ( del ) \
      free(var); \
    return cells; \
  }
mxWriteCellVectorDef(mxWriteIntCellVector, int);
mxWriteCellVectorDef(mxWriteDoubleCellVector, double);

/*
 * Defines `type **funcname(var, old_len, add_len, dim)`: returns a new
 * row-pointer array with old_len + add_len slots. The leading rows are moved
 * over from var (their slots in var are set to NULL). Each added row is a
 * fresh zero-filled array of dim elements. The old outer array var is freed.
 */
#define reallo(funcname, type) \
	type** funcname(type** var, int old_len, int add_len, int dim) { \
		const int total = old_len + add_len; \
		const int keep = old_len < total ? old_len : total; \
		type** grown = (type**)malloc(total * sizeof(type*)); \
		int r, c; \
		for (r = 0; r < keep; r++) { \
			grown[r] = var[r]; \
			var[r] = NULL; \
		} \
		for (r = keep; r < total; r++) { \
			grown[r] = (type*)malloc(dim * sizeof(type)); \
			for (c = 0; c < dim; c++) \
				grown[r][c] = (type)0; \
		} \
		free(var); \
		return grown; \
	}
reallo(realloInt, int);
reallo(realloDouble, double);

/*
 * Defines `type *funcname(mvector, number, shift, init)`: returns a malloc'ed
 * array of max(number, numel(mvector)) elements. It holds the vector's entries
 * converted to `type` plus shift, followed by init padding. The str argument
 * is unused.
 */
#define mxReadVectorDef(funcname,type,str) \
  type *funcname(const mxArray *mvector,int number,type shift,type init) { \
    const double *src = mxGetPr(mvector); \
    const size_t have = mxGetNumberOfElements(mvector); \
    type *vec; \
    int idx; \
    number = max1(number, have); \
    vec = (type*) malloc(sizeof(type) * number); \
    for ( idx = 0 ; idx < have ; idx++ ) \
      vec[idx] = (type)src[idx] + shift; \
    for ( idx = (int)have ; idx < number ; idx++ ) \
      vec[idx] = init; \
    return vec; \
  } 
mxReadVectorDef(mxReadIntVector,int,"%d ");
mxReadVectorDef(mxReadDoubleVector,double,"%g ");

/*
 * Defines `mxArray *funcname(mm, nn, var, shift, del)`: copies the mm * nn
 * elements of var (plus shift) into a new mm x nn double mxArray, in the
 * linear order var is stored. Frees var when del is non-zero.
 */
#define mxWriteVectorDef(funcname, type, str) \
  mxArray *funcname(int mm,int nn,type *var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm, nn, mxREAL); \
    double *dst = mxGetPr(out); \
    const int total = mm * nn; \
    int idx; \
    for ( idx = 0 ; idx < total ; idx++ ) \
      dst[idx] = var[idx] + shift; \
    if ( del ) \
      free(var); \
    return out; \
  } 
mxWriteVectorDef(mxWriteIntVector, int, "%d ");
mxWriteVectorDef(mxWriteDoubleVector, double, "%g ");

/*
 * Defines `type **funcname(marray, mm, nn, shift, init)`: returns a malloc'ed
 * row-pointer matrix of max(mm, rows) x max(nn, cols). Cells covered by the
 * MATLAB (column-major) array get the converted value plus shift; every other
 * cell gets init.
 */
#define mxReadMatrixDef(funcname,type) \
  type **funcname(const mxArray *marray,int mm,int nn,type shift,type init) { \
    const double *src = mxGetPr(marray); \
    const int rows_in = mxGetM(marray); \
    const int cols_in = mxGetN(marray); \
    type **rows; \
    int r, c; \
    if ( mm < rows_in ) mm = rows_in; \
    if ( nn < cols_in ) nn = cols_in; \
    rows = (type**) malloc(sizeof(type*) * mm); \
    for ( r = 0 ; r < mm ; r++ ) { \
      rows[r] = (type*) malloc(sizeof(type) * nn); \
      for ( c = 0 ; c < nn ; c++ ) \
        rows[r][c] = ( r < rows_in && c < cols_in ) \
          ? (type)src[c * rows_in + r] + shift : init; \
    } \
    return rows; \
  }
mxReadMatrixDef(mxReadIntMatrix,int);
mxReadMatrixDef(mxReadDoubleMatrix,double);

/*
 * Defines `mxArray *funcname(mm, nn, maxm, var, shift, del)`: copies the first
 * mm rows / nn columns of var (plus shift) into a new mm x nn double mxArray.
 * When del is non-zero, it frees rows 0 .. max(mm, maxm) - 1 and then var.
 */
#define mxWriteMatrixDef(funcname, type) \
  mxArray *funcname(int mm,int nn,int maxm,type **var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm, nn, mxREAL); \
    double *dst = mxGetPr(out); \
    int r, c; \
    for ( r = 0 ; r < mm ; r++ ) \
      for ( c = 0 ; c < nn ; c++ ) \
        dst[r + mm * c] = var[r][c] + shift; \
    if ( del ) { \
      const int last = maxm > mm ? maxm : mm; \
      for ( r = 0 ; r < last ; r++ ) \
        free(var[r]); \
      free(var); \
    } \
    return out; \
  }
mxWriteMatrixDef(mxWriteIntMatrix, int);
mxWriteMatrixDef(mxWriteDoubleMatrix, double);

/* Return the first real element of a numeric mxArray. */
double mxReadScalar(const mxArray *mscalar) 
{
	const double *values = mxGetPr(mscalar);
	return values[0];
}

/* Wrap var in a newly created 1 x 1 real double mxArray. */
mxArray *mxWriteScalar(double var) 
{ 
	mxArray *boxed = mxCreateDoubleMatrix(1, 1, mxREAL);
	mxGetPr(boxed)[0] = var;
	return boxed;
}

/*double randd()
{
	return (double)(rand() + 1) / (double)(RAND_MAX + 2);
}*/

/*
 * Natural log of the gamma function, computed with a 6-term Lanczos series.
 * Special cases: NaN for x < 0, +inf for x == 0, and x itself for any other
 * non-finite x.
 */
double gammaln(double x)
{
	#define M_lnSqrt2PI 0.91893853320467274178
	static const double lanczos[6] = {
		76.18009172947146,
		-86.50532032941677,
		24.01409824083091,
		-1.231739572450155,
		0.1208650973866179e-2,
		-0.5395239384953e-5
	};
	double series, term_denom, shifted;
	int idx;

	if (x < 0) {
		return NAN;
	}
	if (x == 0) {
		return INFINITY;
	}
	if (!finite(x)) {
		return x;
	}

	/* Denominators are x+1, x+2, ..., built incrementally as in the series. */
	series = 1.000000000190015;
	term_denom = x + 1;
	for (idx = 0; idx < 6; idx++) {
		series += lanczos[idx] / term_denom;
		term_denom += 1.0;
	}
	shifted = x + 5.5;
	return M_lnSqrt2PI + (x + 0.5) * log(shifted) - shifted + log(series / x);
}

/* Gamma function as exp(gammaln(x)); NaN for x < 0, +inf for x == 0. */
double gamma(double x)
{
	if (x == 0) {
		return INFINITY;
	}
	if (x < 0) {
		return NAN;
	}
	return exp(gammaln(x));
}

/*
 * Draw from Gamma(shape rr, rate theta) using drand48().
 * Shape <= 0 returns 0. Shape 1 is an exponential draw. Shape < 1 uses
 * Johnk's generator; shape > 1 uses Best's rejection algorithm.
 */
double randgamma(double rr, double theta) 
{
	if (rr <= 0.0) {
		/* Not well defined, set to zero and skip. */
		return 0.0;
	}
	if (rr == 1.0) {
		return - log(drand48()) / theta;
	}
	if (rr < 1.0) {
		const double inv_shape = 1.0 / rr;
		const double inv_comp = 1.0 / (1.0 - rr);
		for (;;) {
			const double x = pow(drand48(), inv_shape);
			const double y = x + pow(drand48(), inv_comp);
			if (y <= 1.0) {
				return -log(drand48()) * x / y / theta;
			}
		}
	}
	{
		const double b = rr - 1.0;
		const double c = 3.0 * rr - 0.75;
		for (;;) {
			const double u1 = drand48();
			const double u2 = drand48();
			const double w = u1 * (1.0 - u1);
			const double y = sqrt(c / w) * (u1 - 0.5);
			const double x = b + y;
			double z;
			if (!(x >= 0)) {
				continue;
			}
			z = 64.0 * w * w * w * u2 * u2;
			if ((z <= (1.0 - 2.0 * y * y / x)) ||
			    (log(z) <= 2.0 * (b * log(x / b) - y))) {
				return x / theta;
			}
		}
	}
}

/* Draw from Beta(aa, bb) as X / (X + Y) with X ~ Gamma(aa, 1), Y ~ Gamma(bb, 1). */
double randbeta(double aa, double bb) 
{
	const double x = randgamma(aa, 1);
	const double y = randgamma(bb, 1);
	return x / (x + y);
}

/*
 * Draw a Dirichlet(alpha) sample into pi by normalising independent
 * Gamma(alpha[k], 1) draws. Both arrays are read with stride skip, so element k
 * lives at pi[k * skip] and alpha[k * skip]. Nothing is written unless
 * veclength and skip are both positive.
 */
void randdir(double *pi, double *alpha, int veclength, int skip)
{
	double total = 0.0;
	int idx;

	if (veclength <= 0 || skip <= 0) {
		return;
	}
	for (idx = 0; idx < veclength; idx++) {
		const double draw = randgamma(alpha[idx * skip], 1);
		pi[idx * skip] = draw;
		total += draw;
	}
	for (idx = 0; idx < veclength; idx++) {
		pi[idx * skip] /= total;
	}
}

/*
 * Draw an index in [0, veclength) with probability proportional to the
 * non-negative, unnormalised weights pi[0..veclength-1].
 *
 * The cumulative walk is bounded by veclength. If floating-point rounding (or
 * a NaN weight) leaves positive mass after every weight has been subtracted,
 * the last positive-weight index is returned instead of reading past the end
 * of pi. All-zero weights (or veclength <= 0) yield 0.
 */
int randmult(double* pi, int veclength)
{
	int i, last_positive = 0;
	double sum = 0.0, mass;

	for (i = 0; i < veclength; i++) {
		sum += pi[i];
		assert(pi[i] >= 0);
		if (pi[i] > 0) {
			last_positive = i;
		}
	}
	mass = drand48();
	if (mass <= 0 || mass >= 1) {
		printf("rand = %f\n", mass);
		assert(mass > 0 && mass < 1);
	}
	mass *= sum;
	for (i = 0; i < veclength; i++) {
		mass -= pi[i];
		if (mass <= 0.0) {
			return i;
		}
	}
	/* Residual mass from rounding: fall back to the last reachable entry. */
	return last_positive;
}


/*
 * Return a draw from a Gaussian with mean m and standard deviation s
 * (use s > 0.0). One drand48() variate is pushed through the Odeh & Evans
 * rational approximation of the inverse normal CDF (J. Applied Statistics,
 * 1974, vol 23, pp 96-97).
 */
double Normal(double m, double s)
{ 
	static const double num_c[5] = {0.322232431088, 1.0, 0.342242088547, 0.204231210245e-1, 0.453642210148e-4};
	static const double den_c[5] = {0.099348462606, 0.588581570495, 0.531103462366, 0.103537752850, 0.385607006340e-2};
	const double uniform = drand48();
	const int lower_tail = (uniform < 0.5);
	/* Work with the smaller tail probability; the sign is restored below. */
	const double tail = sqrt(-2.0 * log(lower_tail ? uniform : 1.0 - uniform));
	const double num = num_c[0] + tail * (num_c[1] + tail * (num_c[2] + tail * (num_c[3] + tail * num_c[4])));
	const double den = den_c[0] + tail * (den_c[1] + tail * (den_c[2] + tail * (den_c[3] + tail * den_c[4])));
	const double z = lower_tail ? (num / den) - tail : tail - (num / den);
	return m + s * z;
}

/*
 * Sequential predictive log-likelihood of the current assignments.
 * First every token (data[i][j], topic s[i][j], both 1-based) is removed from
 * the counts mu / sum_mu. The tokens are then re-added one by one, each
 * contributing log((h + mu[k][v]) / (h*V + sum_mu[k])) before its own count is
 * restored. On return mu and sum_mu hold their original values.
 */
double likelihood(double **mu, double *sum_mu, int **s, int **data, int *n, int nDP, double h, int V)
{
	const double smoothing_total = h * V;
	double acc = 0;
	int doc, pos;

	for (doc = 0; doc < nDP; doc++) {
		for (pos = 0; pos < n[doc]; pos++) {
			const int topic = s[doc][pos] - 1;
			const int word = data[doc][pos] - 1;
			mu[topic][word] -= 1;
			sum_mu[topic] -= 1;
		}
	}
	for (doc = 0; doc < nDP; doc++) {
		for (pos = 0; pos < n[doc]; pos++) {
			const int topic = s[doc][pos] - 1;
			const int word = data[doc][pos] - 1;
			acc += log((h + mu[topic][word]) / (smoothing_total + sum_mu[topic]));
			mu[topic][word] += 1;
			sum_mu[topic] += 1;
		}
	}
	return acc;
}

/*
 * Perplexity of the current state.
 * For each group, theta[0..K_t] is filled with normalised predictive topic
 * weights. An occupied topic k gets (nj[k] - a) times its q/u factor, and only
 * the first empty slot gets the new-topic mass. Each scored token with word v
 * then contributes log sum_k theta[k] * (h + mu[k][v]) / (h*V + sum_mu[k]).
 * Groups with tag_tr == 1 score tokens [0, n_t) into perp[0]; the others score
 * tokens [n_t, n_t_tot) into perp[1]. Each total is finally turned into
 * exp(-total / count). Assumes two groups (1 - grp is the other one) and that
 * theta holds at least K_t + 1 entries.
 */
void calPerp(double* theta, int **data, double **mu, double *sum_mu, double h, int V, double *perp)
{
	const double smoothing_total = h * V;
	int n_train = 0, n_test = 0;
	int grp, k, src, tok;

	perp[0] = perp[1] = 0;
	for (grp = 0; grp < nDP; grp++) {
		const int other = 1 - grp;
		int new_slot_used = 0;
		int first_tok, end_tok, slot;
		double norm = 0;

		for (k = 0; k <= K_t; k++) {
			if (nj[k] != 0) {
				const int sid = K_id[k];
				theta[k] = (nj[k] - a) * (q[grp][sid] * (1 - q[other][sid]) / (1 + u[grp])
						+ q[0][sid] * q[1][sid] / (1 + u[0] + u[1]));
				continue;
			}
			theta[k] = 0;
			if (new_slot_used) {
				continue;
			}
			for (src = 0; src < nS; src++) {
				theta[k] += a * M[src] * (q[grp][src] * (1 - q[other][src]) / pow(1 + u[grp], 1 - a) 
						+ q[0][src] * q[1][src] / pow(1 + u[0] + u[1], 1 - a));
			}
			new_slot_used = 1;
		}
		for (k = 0; k <= K_t; k++) {
			norm += theta[k];
		}
		for (k = 0; k <= K_t; k++) {
			theta[k] /= norm;
		}

		if (tag_tr[grp] == 1) {
			first_tok = 0;
			end_tok = n_t[grp];
			slot = 0;
			n_train += n_t[grp];
		} else {
			first_tok = n_t[grp];
			end_tok = n_t_tot[grp];
			slot = 1;
			n_test += n_t_tot[grp] - n_t[grp];
		}
		for (tok = first_tok; tok < end_tok; tok++) {
			const int word = data[grp][tok] - 1;
			double mix = 0;
			for (k = 0; k <= K_t; k++) {
				mix += theta[k] * (h + mu[k][word]) / (smoothing_total + sum_mu[k]);
			}
			perp[slot] += log(mix);
		}
	}
	perp[0] = exp(-perp[0] / max1(n_train, 1));
	perp[1] = exp(-perp[1] / max1(n_test, 1));
}

/*
 * Resample M[i] for source i from a gamma with shape K_ct[i] + a0. The rate is
 * b0 plus (1 + u)^a - 1 terms, weighted by q[0][i](1 - q[1][i]) for group 0,
 * q[1][i](1 - q[0][i]) for group 1, and q[0][i]q[1][i] for u = u[0] + u[1].
 * The parameter a shadows the global discount.
 */
void sampleM(int i,	double a, double a0, double b0)
{
	const double in0 = q[0][i];
	const double in1 = q[1][i];
	const double only0 = in0 * (1 - in1) * (pow(1 + u[0], a) - 1);
	const double only1 = in1 * (1 - in0) * (pow(1 + u[1], a) - 1);
	const double both = in0 * in1 * (pow(1 + u[0] + u[1], a) - 1);
	M[i] = randgamma(K_ct[i] + a0, only0 + only1 + both + b0);
}

/*
 * Unnormalised log conditional of x = log(u[j]), j = u_idx, used by sampleU:
 *   n_t[j] * x - sumzj * e^x
 *   - sum_s M[s] * ( q[j][s](1 - q[j'][s]) (1 + e^x)^a
 *                    + q[0][s] q[1][s] (1 + e^x + u[j'])^a ),
 * with j' = 1 - j (two groups assumed).
 */
struct uPosterior{
	double sumzj;
	uPosterior(double sumzj)
	:sumzj(sumzj)
	{}
	double operator()(double x) const
	{
		const int me = u_idx;
		const int other = 1 - u_idx;
		const double ex = exp(x);
		double logp = n_t[me] * x - sumzj * ex;
		int src;
		for (src = 0; src < nS; src++) {
			logp -= M[src] * (q[me][src] * (1 - q[other][src]) * pow(1 + ex, a)
					+ q[0][src] * q[1][src] * pow(1 + ex + u[other], a));
		}
		return logp;
	}
};

void sampleU(int j)
{
	int k;
	double log_u = log(u[j]);
	
	u_idx = j;
	
	uPosterior upos(sumzj[j]);
	log_u = slice_sampler1d(upos, log_u, drand48, lv, rv, 0.0, 10, 1000);
		
	u[j] = exp(log_u);
}

struct qPosterior{
	double qa, qb;
	qPosterior(double qa, double qb)
	:qa(qa), qb(qb)
	{}
	double operator()(double x) const
	{
		int k;
		double val = (qa - 1) * log(x) + (qb - 1) * log(1 - x);
		for(k = 0; k <= K_t; k++){
			if(nj[k] == 0){
				continue;
			}
			if(K_id[k] == q_i){
				if(x < 1e-10){
					continue;
				}
				if(x > 0.99999){
					val -= u[q_j] * JJ[k];
				}else{
					val += log(1 - x + x * exp(-u[q_j] * JJ[k]));
				}
				mxAssert(finite(val), "value infinite!!!");
			}
		}
		val -= x * M[q_i] * (q[1-q_j][q_i] * (pow(1 + u[0] + u[1], a) - 1) + (1 - q[1-q_j][q_i]) * (pow(1 + u[q_j], a) - 1)
				- q[1-q_j][q_i] * (pow(1 + u[1-q_j], a) - 1));
		mxAssert(finite(val), "value infinite!!!");
		return val;
	}
};
void sampleQ(double q_a, double q_b)
{
	int i, j;
	
	qPosterior qpos(q_a, q_b);
	
	for(j = 0; j < nDP; j++){
		for(i = 0; i < nS; i++){
			q_i = i;
			q_j = j;
			
			q[j][i] = slice_sampler1d(qpos, q[j][i], drand48, qlv, qrv, 0.0, 10, 1000 * nDP);
		}
	}
}

/*
 * Unnormalised log conditional of x = q[q_j][q_i] used by sampleQ_1. It keeps
 * only the Beta(qa, qb) term and the x * M[q_i] linear term of qPosterior;
 * sampleQ_1 folds the sampled per-topic inclusion counts into qa/qb instead.
 */
struct qPosterior_1{
	double qa, qb;
	qPosterior_1(double qa, double qb)
	:qa(qa), qb(qb)
	{}
	double operator()(double x) const
	{
		const double q_other = q[1-q_j][q_i];
		double lp = (qa - 1) * log(x) + (qb - 1) * log(1 - x);
		lp -= x * M[q_i] * (q_other * (pow(1 + u[0] + u[1], a) - 1) + (1 - q_other) * (pow(1 + u[q_j], a) - 1)
				- q_other * (pow(1 + u[1-q_j], a) - 1));
		mxAssert(finite(lp), "value infinite!!!");
		return lp;
	}
};
void sampleQ_1(double q_a, double q_b)
{
	int i, j, k, iter, c1, c2;
	double qq;
	
	for(j = 0; j < nDP; j++){
		for(i = 0; i < nS; i++){
			q_i = i;
			q_j = j;

			for(iter = 0; iter < 2; iter++){
				c1 = 0; c2 = 0;
				for(k = 0; k <= K_t; k++){
					if(nj[k] == 0){
						continue;
					}
					if(K_id[k] == q_i){
						qq = q[j][i] * exp(-u[j] * JJ[k]);
						if(qq / (1 - q[j][i] + qq) > drand48()){
							c1++;
						}else{
							c2++;
						}
					}
				}
				qPosterior_1 qpos(c1 + q_a, c2 + q_b);
				q[j][i] = slice_sampler1d(qpos, q[j][i], drand48, qlv, qrv, 0.0, 10, 1000);
			}
		}
	}
}

/*
 * sample \sigma
*/
struct aPosterior{
	int sumk;
	double sumlogJ;
	aPosterior(int sumk, double sumlogJ)
	:sumk(sumk), sumlogJ(sumlogJ)
	{}
	double operator()(double x) const
	{
		int i, k;
		double acc;
		acc = sumk * (log(x) - gammaln(1 - x));mxAssert(finite(acc), "aa");
		acc -= x * sumlogJ;mxAssert(finite(acc), "bb");
		for(i = 0; i < nS; i++){
			acc -= M[i] * (q[0][i] * (1 - q[1][i]) * pow(1 + u[0], x) + q[1][i] * (1 - q[0][i]) * pow(1 + u[1], x)
					+ q[0][i] * q[1][i] * pow(1 + u[0] + u[1], x));
		}mxAssert(finite(acc), "cc");
		return acc;
	}
};
void samplea_slice()
{
	int k, sumk = 0;
	double sumlogJ = 0;
	for(k = 0; k < nS; k++){
		sumk += K_ct[k];
	}
	for(k = 0; k <= K_t; k++){
		if(nj[k] > 0){
			mxAssert(JJ[k] > 0, "jumps less than zero!!!");
			sumlogJ += log(JJ[k]);
		}
	}
	aPosterior apos(sumk, sumlogJ);
	a = slice_sampler1d(apos, a, drand48, 1.0e-10, 0.9999, 0.0, 10, 1000);
	gamma_a = gamma(1 - a);
}

/*
 * sample jumps with observations
*/
void sampleJ()
{
	int i, j, k, z1, z2;
	double sumzu, qq;
	for(k = 0; k <= K_t; k++){
		if(nj[k] > 0){
			sumzu = 0;
			for(j = 0; j < nDP; j++){
				sumzu += q[j][K_id[k]] * u[j];
			}
			JJ[k] = randgamma(nj[k] - a, 1 + sumzu);
		}else{
			JJ[k] = 0;
		}
	}
	sumzj[0] = 0; sumzj[1] = 0;
	for(i = 0; i < 3; i++){
		for(k = 0; k <= K_t; k++){
			if(nj[k] > 0){
				sumzu = 0;
				if(i == 2){
					z1 = 0; z2 = 0;
				}
				for(j = 0; j < nDP; j++){
					if(nj_t[j][k] > 0 || q[j][K_id[k]] == 1.0){
						if(i == 2){
							if(j == 0){
								z1 += 1;
							}else{
								z2 += 1;
							}
						}
						sumzu += u[j];
					}else{
						qq = q[j][K_id[k]] * exp(-u[j] * JJ[k]);
						if(drand48() < qq / (1 - q[j][K_id[k]] + qq)){
							if(i == 2){
								if(j == 0){
									z1 += 1;
								}else{
									z2 += 1;
								}
							}
							sumzu += u[j];
						}
					}
				}
				JJ[k] = randgamma(nj[k] - a, 1 + sumzu); mxAssert(JJ[k] > 0, "jumps less than zero!!!");
				if(i == 2){
					sumzj[0] += z1 * JJ[k];
					sumzj[1] += z2 * JJ[k];
				}
			}else{
				JJ[k] = 0;
			}
		}
	}
}

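/******************
mexFunction arguments (prhs) and results (plhs):
input: 	0: mu,	K * V, topic-word counts
		1: sum_mu,	1 * K
		2: s,	cell, 1-based topic of each training token
		3: M,	1 * nS
		4: q,	nDP * nS
		5: n,	1 * nS
		6: n_t,	1 * nDP, training tokens per group
		7: n_ts,	nDP * nS
		8: nj,	1 * K
		9: nj_t,	nDP * K
		10: K_c,	1 * (nDP + 1)
		11: K_ct,	1 * nS
		12: K_id,	1 * K
		13: mu0,	1 * 1, pseudo-count added to every mu entry
		14: burnin,	1 * 1
		15:	nsample, 1 * 1
		16: lag, 1 * 1
		17: calLik, 1 * 1, compute perplexity every calLik iterations (0 = never)
		18: data, cell, 1-based word index of each token (training tokens first)
		19:	a
		20:	gamma_a
		21: n_t_tot, 1 * nDP, training + held-out tokens per group
		22: tag_tr, 1 * nDP, 1 = score the group's training tokens, otherwise its held-out tokens
		23: output, 0 = print progress to the console, otherwise write it to the log file
		24: log file name (read only when output == 1)
output:	0: 1 * nsample struct array with fields mu, sum_mu, nj, nj_t, n_ts, Kid, s, n, M, u, Kct, Kc, q, a
		1: training perplexity, one entry per calLik iterations
		2: held-out perplexity, same schedule
		3: CPU time in seconds (clock()) of the last 1000 iterations
******************/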

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	const int npara = 14;
	clock_t t1;
	char* log_file;
	int i, j, ss, num_DP, v, k, k1, V, K, kk, iter, maxIter, burnin, nsample, lag, calLik, nLik;
	int **s, **data, ns;
	double **mu, *sum_mu, *prob, *c_tmp, mu0, v_sum, lik;
	double *lik_pr, *lik_te_pr, *comp_tmp, llik[2];
	int lik_iter = 0;
	mxArray *struc;
	const char* fieldnames[npara] = {"mu", "sum_mu", "nj", "nj_t", "n_ts", "Kid", "s", "n", "M", "u", "Kct", "Kc", "q", "a"};
	// read data and initialize	
	V = mxGetN(prhs[0]);
	K = mxGetNumberOfElements(prhs[8]); /* current maximal topics */
	nDP = mxGetNumberOfElements(prhs[6]);
	if(nDP != 2){
		printf("Applicable only for 2 times!!!\n");
		exit(0);
	}
	tag_tr = mxReadIntVector(prhs[22], nDP, 0, 0);
	nS = mxGetNumberOfElements(prhs[5]);
	mu0 = mxGetScalar(prhs[13]);
	burnin = mxGetScalar(prhs[14]);
	nsample = mxGetScalar(prhs[15]);
	lag = mxGetScalar(prhs[16]);
	calLik = (int)mxGetScalar(prhs[17]);
	maxIter = burnin + nsample * lag;

	mu = mxReadDoubleMatrix(prhs[0], K, V, 0, 0);
	sum_mu = mxReadDoubleVector(prhs[1], K, 0, 0);
	s = mxReadIntCellVector(prhs[2], 0);
	M = mxReadDoubleVector(prhs[3], nS, 0, 0);
	q = mxReadDoubleMatrix(prhs[4], nDP, nS, 0, 0);
	n = mxReadIntVector(prhs[5], nS, 0, 0);
	n_t = mxReadIntVector(prhs[6], nDP, 0, 0);
	n_t_tot = mxReadIntVector(prhs[21], nDP, 0, 0);
	n_ts = mxReadIntMatrix(prhs[7], nDP, nS, 0, 0);
	nj = mxReadIntVector(prhs[8], K, 0, 0);
	nj_t = mxReadIntMatrix(prhs[9], nDP, K, 0, 0);
	K_c = mxReadIntVector(prhs[10], nDP + 1, 0, 0);
	K_ct = mxReadIntVector(prhs[11], nS, 0, 0);
	K_id = mxReadIntVector(prhs[12], K, 0, 0);
	data = mxReadIntCellVector(prhs[18], 0);
	a = mxGetScalar(prhs[19]);
	gamma_a = mxGetScalar(prhs[20]);
	
	/*srand ( time(NULL) );*/
	output = (int)mxGetScalar(prhs[23]);
	if(output == 1){
		i = (mxGetM(prhs[24]) * mxGetN(prhs[24])) + 1;
		log_file = (char*)mxCalloc(i, sizeof(char)); 
		int status = mxGetString(prhs[24], log_file, i);
		mxAssert(status == 0, "read log file fail!!!");
		fid = fopen(log_file, "w");
	}
	
	/*
	 * file to store statistics
	 */
#ifdef RECORD
	FILE *fidstat = fopen("stat_TNGG_marg2.txt", "w");
	/******************/
#endif
	
	v_sum = mu0 * V;
	prob = (double*)malloc(sizeof(double) * K);
	JJ = (double*)malloc(sizeof(double) * K);
	sumzj = (double*)malloc(sizeof(double) * 2);
	c_tmp = (double*)malloc(sizeof(double) * nS);
	u = (double*)malloc(sizeof(double) * nDP); //mxReadDoubleVector(prhs[25], nDP, 0, 0);

	for(i = 0; i < nDP; i++){
		u[i] = 100;
	}
	//u[0] = 37.7; u[1] = 40.2;
	
	ns = 0;
	struc = mxCreateStructMatrix(1, nsample, npara, fieldnames);
	plhs[0] = struc;
	if(calLik == 0){
		nLik = 0;
	}else{
		nLik = (maxIter - 1) / calLik + 1;
	}
	plhs[1] = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_pr = mxGetPr(plhs[1]);
	plhs[2] = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_te_pr = mxGetPr(plhs[2]);
	plhs[3] = mxCreateDoubleMatrix(1, 1, mxREAL);

	/**** delete empty classes ***/
	K_t = K_c[nDP];
	while(K_t > 0 && sum_mu[K_t-1] == 0){
		K_t--;
	}
	if(K_t == 0){
		if(output == 0){
			printf("error: empty topics!\n");
		}else{
			fprintf(fid, "error: empty topics!\n");
		}
		exit(0);
	}
	k = 0;
	while(k < K_t - 1){
		if(sum_mu[k] == 0){
			assert(nj[k] == 0);
			for(j = 0; j < nDP; j++){
				assert(nj_t[j][k] == 0);
			}
			comp_tmp = mu[k];
			for(kk = k; kk < K_t - 1; kk++){
				mu[kk] = mu[kk + 1];
				sum_mu[kk] = sum_mu[kk + 1];
				K_id[kk] = K_id[kk + 1];
				nj[kk] = nj[kk + 1];
			}
			mu[K_t - 1] = comp_tmp;
			sum_mu[K_t - 1] = 0;
			K_id[K_t - 1] = -1;
			nj[K_t - 1] = 0;
			for(num_DP = 0; num_DP < nDP; num_DP++){
				for(kk = k; kk < K_t - 1; kk++){
					nj_t[num_DP][kk] = nj_t[num_DP][kk + 1];
				}
				nj_t[num_DP][K_t - 1] = 0;
				for(j = 0; j < n_t[num_DP]; j++){
					assert(s[num_DP][j] != k + 1);
					if(s[num_DP][j] > k + 1){
						s[num_DP][j]--;
					}
				}
			}
			K_t--;
		}else{
			assert(nj[k] > 0);
			k++;
		}
	}
	K_c[nDP] = K_t;

	for(iter = 0; iter < maxIter; iter++){
		if(calLik && iter % calLik == 0){
			//llik[0] = likelihood(mu, sum_mu, s, data, n_t, nDP, mu0, V); llik[1] = 0;
			calPerp(prob, data, mu, sum_mu, mu0, V, llik);
			lik_pr[lik_iter] = llik[0];
			lik_te_pr[lik_iter] = llik[1];
			if(output == 0){
				printf("In iteration %d, K = %d, lik_tr = %f, lik_te = %f, a = %f...\n", iter, K_c[nDP], lik_pr[lik_iter], lik_te_pr[lik_iter], a);
			}else{
				fprintf(fid, "In iteration %d, K = %d, lik_tr = %f, lik_te = %f, a = %f...\n", iter, K_c[nDP], lik_pr[lik_iter], lik_te_pr[lik_iter], a);
			}
			lik_iter++;
		}else{
			if(output == 0){
				printf("In iteration %d, K = %d...\n", iter, K_c[nDP]);
			}else{
				fprintf(fid, "In iteration %d, K = %d...\n", iter, K_c[nDP]);
			}
		}

		K_t = K_c[nDP];
		assert(K_t > 0);
		sampleJ();
		for(i = 0; i < nS; i++){
			sampleM(i,	a, 0.1, 0.1);
			/*printf("M[%d] = %f\n", i, M[i]);*/
		}
		for(i = 0; i < nDP; i++){
			sampleU(i);
		}//printf("finish sampling u..\n");
		//sampleQ(0.5, 0.5);
		//sampleQ_1(0.5, 0.5);
		//samplea_slice();
		
		for(num_DP = 0; num_DP < nDP; num_DP++){/*printf("dsds\n");*/
			int tag;
			
		    /* Update allocation variables, s */
			for(j = 0; j < n_t[num_DP]; j++){
				v = data[num_DP][j] - 1;
				k = s[num_DP][j] - 1;
				assert(q[num_DP][K_id[k]] > 0);
				mu[k][v]--;
				assert(mu[k][v] >= 0);
				sum_mu[k]--;
				assert(sum_mu[k] >= 0);
				nj_t[num_DP][k]--;
				assert(nj_t[num_DP][k] >= 0);
				if(nj_t[num_DP][k] == 0){
					K_c[num_DP]--;
					assert(K_c[num_DP] >= 0);
				}
				nj[k]--;
				assert(nj[k] >= 0);
				n[K_id[k]]--;
				n_ts[num_DP][K_id[k]]--;
				if(nj[k] == 0){
					K_c[nDP]--;
					K_ct[K_id[k]]--;
					assert(K_c[nDP] >= 0);
					K_id[k] = -1;
				}

				tag = 0;
				for(kk = 0; kk <= K_t; kk++){
					if(nj[kk] == 0){
						assert(nj_t[num_DP][kk] == 0);
						if(tag == 0){
							prob[kk] = 0;
							for(ss = 0; ss < nS; ss++){
								c_tmp[ss] = a * M[ss] * (q[num_DP][ss] * (1 - q[1-num_DP][ss]) / pow(1 + u[num_DP], 1 - a) 
										+ q[0][ss] * q[1][ss] / pow(1 + u[0] + u[1], 1 - a));
								prob[kk] += c_tmp[ss];
							}
							tag = 1;
						}else{
							prob[kk] = 0;
						}
					}else{
						assert(K_id[kk] >= 0);
						prob[kk] = (nj[kk] - a) * (q[num_DP][K_id[kk]] * (1 - q[1-num_DP][K_id[kk]]) / (1 + u[num_DP])
								+ q[0][K_id[kk]] * q[1][K_id[kk]] / (1 + u[0] + u[1]));
					}
					assert(mu[kk][v] >= 0);
					prob[kk] *= (mu[kk][v] + mu0) / (sum_mu[kk] + v_sum);
					assert(prob[kk] >= 0);
				}
		        kk = randmult(prob, K_t + 1);
				s[num_DP][j] = kk + 1;
				mu[kk][v]++;
				assert(mu[kk][v] > 0);
				sum_mu[kk]++;
				assert(sum_mu[kk] > 0);
				nj_t[num_DP][kk]++;
				assert(nj_t[num_DP][kk] > 0);
				nj[kk]++;
				if(nj_t[num_DP][kk] == 1){
					K_c[num_DP]++;
					assert(K_c[num_DP] > 0);
				}
				if(nj[kk] == 1){
					K_id[kk] = randmult(c_tmp, nS);
					K_c[nDP]++;
					K_ct[K_id[kk]]++;
					assert(K_c[nDP] > 0);
				}
				
				n[K_id[kk]]++;
				n_ts[num_DP][K_id[kk]]++;
				assert(q[num_DP][K_id[kk]] > 0);

				if(nj_t[num_DP][kk] == 1){
					assert(nj[kk] > 0);

					if(kk == K - 1){
						assert(K_t == K - 1);
						K *= 2;
						mu = realloDouble(mu, K_t + 1, K - K_t - 1, V);
						sum_mu = (double*)realloc(sum_mu, K * sizeof(double));
						nj = (int*)realloc(nj, K * sizeof(int));
						K_id = (int*)realloc(K_id, K * sizeof(int));
						for(k = kk + 1; k < K; k++){
							sum_mu[k] = 0;
							nj[k] = 0;
							K_id[k] = -1;
						}
						prob = (double*)realloc(prob, K * sizeof(double));
						JJ = (double*)realloc(JJ, K * sizeof(double));
						for(k = 0; k < nDP; k++){
							nj_t[k] = (int*)realloc(nj_t[k], K * sizeof(int));
							for(i = kk + 1; i < K; i++){
								nj_t[k][i] = 0;
							}
						}
					}
					if(kk == K_t){
						K_t++;
					}
				}
			}

		}

		/*** delete classes **/
		k = 0;
		while(k < K_t - 1){
			if(sum_mu[k] == 0){
				assert(nj[k] == 0);
				for(j = 0; j < nDP; j++){
					assert(nj_t[j][k] == 0);
				}
				comp_tmp = mu[k];
				for(kk = k; kk < K_t - 1; kk++){
					mu[kk] = mu[kk + 1];
					sum_mu[kk] = sum_mu[kk + 1];
					K_id[kk] = K_id[kk + 1];
					nj[kk] = nj[kk + 1];
					JJ[kk] = JJ[kk + 1];
				}
				mu[K_t - 1] = comp_tmp;
				sum_mu[K_t - 1] = 0;
				K_id[K_t - 1] = -1;
				nj[K_t - 1] = 0;
				JJ[K_t - 1] = 0;
				for(num_DP = 0; num_DP < nDP; num_DP++){
					for(kk = k; kk < K_t - 1; kk++){
						nj_t[num_DP][kk] = nj_t[num_DP][kk + 1];
					}
					nj_t[num_DP][K_t - 1] = 0;
					for(j = 0; j < n_t[num_DP]; j++){
						assert(s[num_DP][j] != k + 1);
						if(s[num_DP][j] > k + 1){
							s[num_DP][j]--;
						}
					}
				}
				K_t--;
			}else{
				k++;
			}
		}
		K_c[nDP] = K_t;
		
		/*
		 * store statistics to calculate ESS
		 */
		if(iter == maxIter - 1000){
			t1 = clock();
		}
#ifdef RECORD
		if(iter >= maxIter - 1000){
			llik[0] = likelihood(mu, sum_mu, s, data, n_t, nDP, mu0, V);
			fprintf(fidstat, "%f ", llik[0]);
			for(i = 0; i < nS; i++){
				fprintf(fidstat, "%f %d ", M[i], K_ct[i]);
			}
			for(i = 0; i < nDP; i++){
				fprintf(fidstat, "%f ", u[i]);
			}
			fprintf(fidstat, "%d\n", K_t);
		}
#endif

		/**** record output **********/
		if((iter + 1) > burnin && (iter + 1 - burnin)%lag == 0){
			int del;
			mxArray *MM;
			if(iter != maxIter - 1){
				del = 0;
			}else{
				del = 1;
			}

			mxSetField(struc, ns, "mu", mxWriteDoubleMatrix(K_c[nDP], V, K, mu, 0, del));
			mxSetField(struc, ns, "sum_mu", mxWriteDoubleVector(1, K_c[nDP], sum_mu, 0, del));
			mxSetField(struc, ns, "nj", mxWriteIntVector(1, K_c[nDP], nj, 0, del));
			mxSetField(struc, ns, "nj_t", mxWriteIntMatrix(nDP, K_c[nDP], nDP, nj_t, 0, del));
			mxSetField(struc, ns, "n_ts", mxWriteIntMatrix(nDP, nS, nDP, n_ts, 0, del));
			mxSetField(struc, ns, "Kid", mxWriteIntVector(1, K_c[nDP], K_id, 0, del));
			mxSetField(struc, ns, "s", mxWriteIntCellVector(nDP, n_t, s, 0, del));
			mxSetField(struc, ns, "n", mxWriteIntVector(1, nS, n, 0, del));
			mxSetField(struc, ns, "M", mxWriteDoubleVector(1, nS, M, 0, del));
			mxSetField(struc, ns, "u", mxWriteDoubleVector(1, nDP, u, 0, del));
			mxSetField(struc, ns, "Kct", mxWriteIntVector(1, nS, K_ct, 0, del));
			mxSetField(struc, ns, "Kc", mxWriteIntVector(1, nDP + 1, K_c, 0, del));
			mxSetField(struc, ns, "q", mxWriteDoubleMatrix(nDP, nS, nDP, q, 0, 0));
			mxSetField(struc, ns, "a", mxWriteScalar(a));
			ns++;
		}

	}

	t1 = clock() - t1;
	*mxGetPr(plhs[3]) = (double)t1/CLOCKS_PER_SEC;
	
	free(prob);
	for(j = 0; j < nDP; j++){
		free(data[j]);
		free(q[j]);
	}
	free(data);
	free(q);
	free(c_tmp);
	free(n_t);
	free(n_t_tot);
	free(JJ);
	free(sumzj);
	free(tag_tr);
#ifdef RECORD
	fclose(fidstat);
#endif
}
