/*
 * SNGGh: the source-to-group weights are given by q; this is an MNGG with fixed weights q, not really an SNGG.
 */

#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <assert.h>
#include "mex.h"
#include "slice-sampler.h"

#include "arms.c"

//#define RECORD

/********************
 * M:		concentration parameters for sources NGG, 1 * nS
 * Mg:		concentration parameters for parents of each group, 1 * nG
 * q:		subsampling rates from source s to group t, nG * nS
 * n:		counts on sources, 1 * nS
 * n_t:		counts on groups, 1 * nG
 * n_ts:	counts from group t to source s, nG * nS
 * nj:		counts on sources for topics k, 1 * K
 * nj_t:	counts on groups for topics k, nG * K
 * K_c:		#topics in each group, last one is #topics in total, 1 * (nG + 1)
 * K_ct:	#topics in each source, 1 * nS
 * K_t:		max #topics considered so far, 1 * 1
 * nG:		#groups
 * nS:		#sources
 * K_id:	indicator of which source each topic belongs to, 1 * K
 * G_id:	indicator of which group one document belongs to, 1 * nDP
 * nDP:		total number of documents, 1 * 1
 * ng:		#documents per group, 1 * nG
 * G_idx:	indexes of documents for each group, nG * ng
 * nw:		#words for each document, 1 * nDP
 * nd_t:	#tables for each document, 1 * nDP
 * nd_tk:	#tables for each document for each topic, K * nDP
 * nd_nk:	#words for each document for each topic, K * nDP
 * u:		auxiliary variables for each document, 1 * nDP
 * u_mid:	auxiliary variables for each group, 1 * nG
 */

/*
 * Sampler state shared by every routine in this file (shapes are listed in
 * the parameter table above).
 */
/* a: discount resampled by samplea_slice(); a_b: copy of a taken on entry to
 * mexFunction and used by the document-level terms; gamma_a: gamma(1 - a);
 * sumQU[s]: sum over groups j of q[j][s] * u_mid[j]. */
double a, a_b, gamma_a, *M, *Mg, **q, *sumQU;
/* u_idx: group index read by uterms() (set in sampleU()); output: 1 routes
 * progress messages to fid instead of printf.
 * q_i and q_j are not referenced anywhere in the visible source. */
int *n, *n_t, **n_ts, *nj, **nj_t, *K_c, *K_ct, *K_id, K_t, nS, nDP, u_idx, q_i, q_j, *tag_tr, output;
int nG, *G_id, *ng, **G_idx, *nw, *nw_all, *nd_t, **nd_tk, **nd_nk;
double *u, *u_mid;
/* Inputs of uuterms(): set per document in mexFunction before arms_simple(). */
double M_, N_;
int K_;
/* Passed as the 2nd and 3rd arguments of arms_simple() -- presumably the
 * left/right bounds of its initial envelope; confirm against arms.c. */
double lv = 0.001, rv = 100;
/* Log file handle; only opened when output == 1. */
FILE* fid;

/*
 * data loading functions, adapted from Teh's HDP code
*/
/* Larger of x1 and x2. Function-like macro: an argument may be evaluated twice, so avoid side effects. */
#define max1(x1,x2) ( (x1) < (x2) ? (x2) : (x1) )

/*
 * Generates funcname(mcell, shift): copies every cell of the MATLAB cell
 * array mcell into its own malloc'ed row of `type`, adding shift to each
 * converted element. Returns the malloc'ed array of row pointers; the caller
 * is responsible for freeing the rows and the array.
 */
#define mxReadCellVectorDef(funcname,type) \
  type **funcname(const mxArray *mcell, type shift) { \
    size_t ncell = mxGetNumberOfElements(mcell); \
    size_t c, e; \
    type **rows = (type**)malloc(sizeof(type*)*ncell);  \
    for ( c = 0 ; c < ncell ; c++ ) {  \
      const mxArray *cell = mxGetCell(mcell,c);  \
      size_t nelem = mxGetNumberOfElements(cell);  \
      const double *src = mxGetPr(cell);  \
      rows[c] = (type*)malloc(sizeof(type)*nelem);  \
      for ( e = 0 ; e < nelem ; e++ ) {  \
        rows[c][e] = (type)src[e] + shift;  \
      } \
    } \
    return rows; \
  }
mxReadCellVectorDef(mxReadIntCellVector,int);
mxReadCellVectorDef(mxReadDoubleCellVector,double);

/*
 * Generates funcname(numcell, numentry, var, shift, del): builds a
 * 1 x numcell MATLAB cell array whose jj-th cell is a 1 x numentry[jj]
 * double row holding var[jj][*] + shift. When del is non-zero every row of
 * var and var itself are released with free() after copying.
 */
#define mxWriteCellVectorDef(funcname, type) \
  mxArray *funcname(int numcell,int *numentry,type **var,type shift, int del) { \
    mxArray *cells = mxCreateCellMatrix(1,numcell); \
    int c, e; \
    for ( c = 0 ; c < numcell ; c++ ) { \
      mxArray *vec = mxCreateDoubleMatrix(1,numentry[c],mxREAL); \
      double *dst = mxGetPr(vec); \
      mxSetCell(cells,c,vec); \
      for ( e = 0 ; e < numentry[c] ; e++ ) { \
        dst[e] = var[c][e] + shift; \
      } \
      if ( del ) { \
        free(var[c]); \
      } \
    } \
    if ( del ) { \
      free(var); \
    } \
    return cells; \
  }
mxWriteCellVectorDef(mxWriteIntCellVector, int);
mxWriteCellVectorDef(mxWriteDoubleCellVector, double);

/*
 * Generates funcname(var, old_len, add_len, dim): grows an array of row
 * pointers from old_len to old_len + add_len rows. The existing row pointers
 * are moved into the new array (the rows themselves are not copied), each
 * added row is a fresh dim-element allocation filled with zeros, and the old
 * pointer array var is freed. Returns the new pointer array.
 */
#define reallo(funcname, type) \
	type** funcname(type** var, int old_len, int add_len, int dim) { \
		const int new_len = old_len + add_len; \
		type** grown = (type**)malloc(new_len * sizeof(type*)); \
		int row, col; \
		for(row = 0; row < new_len && row < old_len; row++){ \
			grown[row] = var[row]; \
			var[row] = NULL; \
		} \
		for(; row < new_len; row++){ \
			grown[row] = (type*)malloc(dim * sizeof(type)); \
			for(col = 0; col < dim; col++){ \
				grown[row][col] = (type)0; \
			} \
		} \
		free(var); \
		return grown; \
	}
reallo(realloInt, int);
reallo(realloDouble, double);

/*
 * Generates funcname(mvector, number, shift, init): copies the elements of
 * the MATLAB array mvector into a malloc'ed `type` vector, adding shift to
 * each converted element. The result has max(number, numel(mvector))
 * entries; entries beyond the input length are set to init. The caller owns
 * the returned buffer. The `str` macro argument is not used by the body.
 */
#define mxReadVectorDef(funcname,type,str) \
  type *funcname(const mxArray *mvector,int number,type shift,type init) { \
    const double *src = mxGetPr(mvector); \
    size_t have = mxGetNumberOfElements(mvector); \
    type *out;  \
    int idx; \
    if ( (size_t)number < have ) { \
      number = (int)have; \
    } \
    out = (type*) malloc(sizeof(type)*number); \
    for ( idx = 0 ; (size_t)idx < have ; idx++ ) { \
      out[idx] = (type)src[idx] + shift; \
    } \
    for ( idx = (int)have ; idx < number ; idx++ ) { \
      out[idx] = init; \
    } \
    return out; \
  } 
mxReadVectorDef(mxReadIntVector,int,"%d ");
mxReadVectorDef(mxReadDoubleVector,double,"%g ");

/*
 * Generates funcname(mm, nn, var, shift, del): creates an mm x nn MATLAB
 * double matrix and fills its mm*nn elements, in storage order, with
 * var[i] + shift. When del is non-zero var is released with free() after
 * copying. The `str` macro argument is not used by the body.
 */
#define mxWriteVectorDef(funcname, type, str) \
  mxArray *funcname(int mm,int nn,type *var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm,nn,mxREAL); \
    double *dst = mxGetPr(out); \
    const int total = mm*nn; \
    int idx; \
    for ( idx = 0 ; idx < total ; idx++ ) { \
      dst[idx] = var[idx] + shift; \
    } \
    if ( del ) { \
      free(var); \
    } \
    return out; \
  } 
mxWriteVectorDef(mxWriteIntVector, int, "%d ");
mxWriteVectorDef(mxWriteDoubleVector, double, "%g ");

/*
 * Generates funcname(marray, mm, nn, shift, init): converts the
 * column-major MATLAB matrix marray into a malloc'ed array of malloc'ed
 * rows. The result has max(mm, rows(marray)) rows and max(nn, cols(marray))
 * columns; cells covered by the input hold the converted value plus shift,
 * every other cell is set to init. The caller owns all returned memory.
 */
#define mxReadMatrixDef(funcname,type) \
  type **funcname(const mxArray *marray,int mm,int nn,type shift,type init) { \
    const double *src = mxGetPr(marray); \
    int m1 = mxGetM(marray); \
    int n1 = mxGetN(marray); \
    int r, c; \
    type **out; \
    if ( mm < m1 ) mm = m1; \
    if ( nn < n1 ) nn = n1; \
    out = (type**) malloc(sizeof(type*)*mm); \
    for ( r = 0 ; r < mm ; r++ ) { \
      out[r] = (type*) malloc(sizeof(type)*nn); \
      for ( c = 0 ; c < nn ; c++ ) { \
        if ( r < m1 && c < n1 ) \
          out[r][c] = (type)src[c*m1+r] + shift; \
        else \
          out[r][c] = init; \
      } \
    } \
    return out; \
  }
mxReadMatrixDef(mxReadIntMatrix,int);
mxReadMatrixDef(mxReadDoubleMatrix,double);

/*
 * Generates funcname(mm, nn, maxm, var, shift, del): creates an mm x nn
 * MATLAB double matrix whose (r, c) entry is var[r][c] + shift. When del is
 * non-zero the first mm rows of var, then rows mm..maxm-1, and finally var
 * itself are released with free().
 */
#define mxWriteMatrixDef(funcname, type) \
  mxArray *funcname(int mm,int nn,int maxm,type **var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm,nn,mxREAL); \
    double *dst = mxGetPr(out); \
    int r, c; \
    for ( r = 0 ; r < mm ; r++ ) { \
      for ( c = 0 ; c < nn ; c++ ) { \
        dst[r+mm*c] = var[r][c] + shift; \
      } \
    } \
    if ( del ) { \
      for ( r = 0 ; r < mm ; r++ ) { \
        free(var[r]); \
      } \
      for ( r = mm ; r < maxm ; r++ ) { \
        free(var[r]); \
      } \
      free(var); \
    } \
    return out; \
  }
mxWriteMatrixDef(mxWriteIntMatrix, int);
mxWriteMatrixDef(mxWriteDoubleMatrix, double);

/* Returns the first real element of the MATLAB array mscalar. */
double mxReadScalar(const mxArray *mscalar) 
{
	const double *value = mxGetPr(mscalar);
	return value[0];
}

mxArray *mxWriteScalar(double var) 
{ 
	mxArray *result;
	result = mxCreateDoubleMatrix(1,1,mxREAL);
	*mxGetPr(result) = var; 
	return result;
}

/*double randd()
{
	return (double)(rand() + 1) / (double)(RAND_MAX + 2);
}*/

double randgamma(double rr, double theta) 
{
	double aa, bb, cc, dd;
  	double uu, vv, ww, xx, yy, zz;

  	if ( rr <= 0.0 ) {
    	/* Not well defined, set to zero and skip. */
    	return 0.0;
  	} else if ( rr == 1.0 ) {
    	/* Exponential */
    	return - log(drand48()) / theta;
  	} else if ( rr < 1.0 ) {
    	/* Use Johnks generator */
    	cc = 1.0 / rr;
    	dd = 1.0 / (1.0-rr);
    	while (1) {
      		xx = pow(drand48(), cc);
      		yy = xx + pow(drand48(), dd);
      		if ( yy <= 1.0 ) {
        		return -log(drand48()) * xx / yy / theta;
      		}
    	}
  	} else { /* rr > 1.0 */
    	/* Use bests algorithm */
    	bb = rr - 1.0;
    	cc = 3.0 * rr - 0.75;
    	while (1) {
      		uu = drand48();
      		vv = drand48();
      		ww = uu * (1.0 - uu);
      		yy = sqrt(cc / ww) * (uu - 0.5);
      		xx = bb + yy;
      		if (xx >= 0) {
        		zz = 64.0 * ww * ww * ww * vv * vv;
        		if ( ( zz <= (1.0 - 2.0 * yy * yy / xx) ) ||
             		( log(zz) <= 2.0 * (bb * log(xx / bb) - yy) ) ) {
          			return xx / theta;
        		}
      		}
    	}
  	}
}

/*
 * Draws a beta variate as ga / (ga + gb), where ga and gb are
 * randgamma(aa, 1) and randgamma(bb, 1), drawn in that order.
 */
double randbeta(double aa, double bb) 
{
	const double ga = randgamma(aa, 1);
	const double gb = randgamma(bb, 1);
	return ga/(ga+gb);
}

/*
 * Fills pi with a normalised vector of gamma draws: element k (stored at
 * pi[k * skip]) is randgamma(alpha[k * skip], 1) divided by the sum of all
 * veclength draws.
 */
void randdir(double *pi, double *alpha, int veclength, int skip)
{
	const int span = veclength*skip;
	double total = 0.0;
	int off;

	for (off = 0 ; off < span ; off += skip){
		pi[off] = randgamma(alpha[off], 1);
		total += pi[off];
	}
	for (off = 0 ; off < span ; off += skip){
		pi[off] /= total;
	}
}

/*
 * Draws an index in [0, veclength) with probability proportional to the
 * non-negative weights pi[0..veclength-1] (they need not sum to one).
 *
 * The scan is bounded by veclength: if floating-point rounding leaves a
 * small positive residual after all weights have been subtracted, the last
 * index with a positive weight is returned instead of reading past the end
 * of pi. Zero-weight entries are never selected. Returns 0 when no entry
 * has a positive weight.
 */
int randmult(double* pi, int veclength)
{
	int idx, last_positive = 0;
	double total = 0.0, mass;

	for ( idx = 0 ; idx < veclength ; idx++ ){
		mxAssert(pi[idx] >= 0, "element less than zero!!!");
		total += pi[idx];
	}
	mass = drand48();
	if(mass <= 0 || mass >= 1){
		printf("rand = %f\n", mass);
		mxAssert(mass > 0 && mass < 1, "error in random number!!!");
	}
	mass *= total;
	for ( idx = 0 ; idx < veclength ; idx++ ){
		if ( !(pi[idx] > 0) ) {
			continue;	/* zero weight: cannot be drawn */
		}
		last_positive = idx;
		mass -= pi[idx];
		if ( mass <= 0.0 ) {
			return idx;
		}
	}
	/* Only reached through rounding error (or an all-zero weight vector). */
	return last_positive;
}


double Normal(double m, double s)
/* ========================================================================
* Returns a normal (Gaussian) distributed real number.
* NOTE: use s > 0.0
*
* Uses a very accurate approximation of the normal idf due to Odeh & Evans, 
* J. Applied Statistics, 1974, vol 23, pp 96-97.
* ========================================================================
*/
{ 
	const double p0 = 0.322232431088;     const double q0 = 0.099348462606;
	const double p1 = 1.0;                const double q1 = 0.588581570495;
	const double p2 = 0.342242088547;     const double q2 = 0.531103462366;
	const double p3 = 0.204231210245e-1;  const double q3 = 0.103537752850;
	const double p4 = 0.453642210148e-4;  const double q4 = 0.385607006340e-2;
	double u, t, p, q, z;
	
	u   = drand48();
	if (u < 0.5)
		t = sqrt(-2.0 * log(u));
	else
		t = sqrt(-2.0 * log(1.0 - u));
	p   = p0 + t * (p1 + t * (p2 + t * (p3 + t * p4)));
	q   = q0 + t * (q1 + t * (q2 + t * (q3 + t * q4)));
	if (u < 0.5)
		z = (p / q) - t;
	else
		z = t - (p / q);
	return (m + s * z);
}

/*
 * Natural logarithm of the gamma function for x > 0, computed with a
 * six-term Lanczos series. Returns NAN for x < 0 and INFINITY for x == 0
 * (a debug assertion also fires for x <= 0).
 */
double gammaln(double x)
{
	#define M_lnSqrt2PI 0.91893853320467274178
	static double gamma_series[] = {
		76.18009172947146,
		-86.50532032941677,
		24.01409824083091,
		-1.231739572450155,
		0.1208650973866179e-2,
		-0.5395239384953e-5
	};
	double denom, shifted, series;
	int term;

	mxAssert(x > 0, "argument less than zero!!!");
	if(x < 0){
		return NAN;
	}
	if(x == 0){
		return INFINITY;
	}
	/* Lanczos method */
	shifted = x + 5.5;
	series = 1.000000000190015;
	for(term = 0, denom = x + 1; term < 6; term++, denom += 1.0){
		series += gamma_series[term] / denom;
	}
	return( M_lnSqrt2PI + (x + 0.5) * log(shifted) - shifted + log(series / x));
}

/*
 * Gamma function evaluated as exp(gammaln(x)). Debug assertions require a
 * finite, positive argument and a finite result.
 */
double gamma(double x)
{
	double value;

	mxAssert(finite(x) && x > 0, "x less than zero!!!");
	value = exp(gammaln(x));
	mxAssert(finite(value), "result infinite!!!");
	return value;
}

double likelihood(double **mu, double *sum_mu, int **s, int **data, int *n, int nDP, double h, int V)
{
	int i, j, k, v;
	double smu, lik = 0;
	smu = h * V;
	for(i = 0; i < nDP; i++){
		for(j = 0; j < n[i]; j++){
			k = s[i][j] - 1;
			v = data[i][j] - 1;
			mu[k][v]--;
			sum_mu[k]--;
		}
	}
	for(i = 0; i < nDP; i++){
		for(j = 0; j < n[i]; j++){
			k = s[i][j] - 1;
			v = data[i][j] - 1;
			lik  += log((h + mu[k][v]) / (smu + sum_mu[k]));
			mu[k][v]++;
			sum_mu[k]++;
		}
	}
	return lik;
}

/*
 * Computes exp(-average per-word log-likelihood) separately for training
 * and test words and stores the two values in lik[0] and lik[1].
 *
 * For every document a topic mixture theta is built from the document's
 * topic counts (nd_nk, discounted by a_b and floored at 0) plus
 * Mg * a_b * (1 + u)^a_b times the group's beta row, then normalised over
 * the K_t topics. A document with tag_tr == 1 contributes its first nw
 * words to lik[0]; any other document contributes words nw..nw_all-1 to
 * lik[1]. Word ids in data are 1-based.
 *
 * theta : scratch buffer with at least K_t entries (overwritten).
 * h, V  : smoothing constant and vocabulary size for the topic-word term.
 *
 * Both averages divide by max(count, 1), so a split with no words yields
 * exp(0) = 1 instead of a 0/0 NaN.
 */
void CalTrTeLik(double* theta, int **data, double **mu, double *sum_mu, double**beta, double h, int V, double *lik)
{
	int i, k, l, g, v, doc, first, last, slot;
	double summ, smu, s1, s2, tmp_M;
	lik[0] = lik[1] = 0;
	s1 = 0; s2 = 0;
	smu = h * V;
	
	for(g = 0; g < nG; g++){
		for(i = 0; i < ng[g]; i++){
			doc = G_idx[g][i];
			tmp_M = Mg[g] * a_b * pow(1 + u[doc], a_b);

			/* unnormalised topic mixture of this document */
			summ = 0;
			for(k = 0; k < K_t; k++){
				theta[k] = (max1(nd_nk[k][doc] - a_b, 0) + tmp_M * beta[g][k]);
				summ += theta[k];
			}
			for(k = 0; k < K_t; k++){
				theta[k] /= summ;
			}

			/* pick the word range and the accumulator this document feeds */
			if(tag_tr[doc] == 1){
				slot = 0;
				first = 0;
				last = nw[doc];
				s1 += nw[doc];
			}else{
				slot = 1;
				first = nw[doc];
				last = nw_all[doc];
				s2 += nw_all[doc] - nw[doc];
			}
			for(l = first; l < last; l++){
				v = data[doc][l] - 1;
				summ = 0;
				for(k = 0; k < K_t; k++){
					summ += theta[k] * (h + mu[k][v]) / (smu + sum_mu[k]);
				}
				mxAssert(summ > 0, "sum less than zero!!!");
				lik[slot] += log(summ);
			}
		}
	}
	/* guard both denominators the same way */
	lik[0] = exp(-lik[0] / max1(s1, 1));
	lik[1] = exp(-lik[1] / max1(s2, 1));
}

/*
 * sample tables using CRP
*/
/*
 * Simulates the number of tables opened when numdata customers are seated
 * one after another: the first customer always opens a table, and customer
 * ii (counting from 1) opens a new one with probability
 *   alpha / (ii + alpha - a * tables_so_far).
 * Returns 0 when numdata is 0.
 */
int randnumtable(double alpha, double a, int numdata)
{
	int seated, tables;

	if (numdata == 0){
		return 0;
	}
	tables = 1;
	for ( seated = 1 ; seated < numdata ; seated++ ) {
		const double p_new = alpha / (seated+alpha-a*tables);
		if ( drand48() < p_new ) {
			tables++;
		}
	}
	return tables;
}

/*
 * sample M on top NGG
*/
/*
 * Resamples M[i] with randgamma(): first argument K_ct[i] + a0, second
 * argument (1 + sumQU[i])^a + b0 - 1.
 */
void sampleM(int i,	double a, double a0, double b0)
{
	const double shape = K_ct[i] + a0;
	const double rate = pow(1 + sumQU[i], a) + b0 - 1;

	M[i] = randgamma(shape, rate);
}

/*
 * Resamples one value shared by every source: a single randgamma() draw
 * with first argument K_c[nG] + 1 + a0 and second argument
 * sum_i ((1 + sumQU[i])^a - 1) + b0 is written to all entries of M.
 */
void sampleM_ave(double a, double a0, double b0)
{
	int src;
	double rate = 0;
	double shared;

	for(src = 0; src < nS; src++){
		rate += pow(1 + sumQU[src], a) - 1;
	}
	shared = randgamma(K_c[nG] + 1 + a0, rate + b0);
	M[0] = shared;
	for(src = 1; src < nS; src++){
		M[src] = shared;
	}
}

/*
 * sample u for each group/time
*/
/*
 * Log-density callback handed to arms_simple() by sampleU(), evaluated at
 * x = log(u) for group u_idx. Sources with a non-positive weight
 * q[u_idx][i] do not contribute. The data argument is unused.
 */
static double uterms(double x, void* data)
{
	const double *weights = q[u_idx];
	const double eu = exp(x);
	double logdens = n_t[u_idx] * x;
	int src;

	for(src = 0; src < nS; src++){
		double base;
		if(!(weights[src] > 0)){
			continue;
		}
		base = 1 + sumQU[src] + weights[src] * eu;
		logdens -= M[src] * pow(base, a_b);
		logdens -= (n[src] - a_b * K_ct[src]) * log(base);
	}
	return logdens;
}

/*
 * Resamples the auxiliary variable u_mid[j] of group j on the log scale.
 *
 * The group's current contribution q[j][i] * u_mid[j] is first removed from
 * every sumQU[i] (clamped at 0 against rounding error), so that uterms()
 * sees the sums without group j; after the arms_simple() draw the new
 * contribution is added back.
 */
void sampleU(int j)
{
	int src;
	double log_u = log(u_mid[j]);
	
	u_idx = j;	/* tells uterms() which group is being updated */
	for(src = 0; src < nS; src++){
		if(!(q[j][src] > 0)){
			continue;
		}
		sumQU[src] -= q[j][src] * u_mid[j];
		mxAssert(sumQU[src] >= 0, "sumQU less than zero!!!");
		if(sumQU[src] < 0){
			sumQU[src] = 0;
		}
	}
	arms_simple (8, &lv, &rv, uterms, NULL, 0, &log_u, &log_u);
	u_mid[j] = exp(log_u);
	for(src = 0; src < nS; src++){
		if(q[j][src] > 0){
			sumQU[src] += q[j][src] * u_mid[j];
		}
	}
}

/*
 * u term likelihood in the second level
*/
/*
 * Log-density callback handed to arms_simple() for a document-level
 * auxiliary variable, evaluated at x = log(u). Reads the globals N_, K_ and
 * M_ that the caller sets beforehand. The data argument is unused.
 *
 * NOTE(review): this uses the global `a`, whereas the other document-level
 * terms in this file use `a_b` -- confirm which discount is intended here.
 */
static double uuterms(double x, void* data)
{
	const double one_plus_u = exp(x) + 1;

	return N_ * x
		- (N_ - K_ * a) * log(one_plus_u)
		- M_ * pow(one_plus_u, a);
}

/*
 * a realization of \mu_t
*/
/*
 * Recomputes beta[g][0..K_t-1], the normalised topic weights of each group.
 *
 *  - topic with nj > 0: max(nj - a, 0) * q[g][src] / (1 + sumQU[src]) with
 *    src = K_id[topic], or 0 when q[g][src] is not positive;
 *  - topic with nj == 0 in the last slot (K_t - 1):
 *    a * sum_s q[g][s] * M[s] / (1 + sumQU[s])^(1 - a) over sources with
 *    positive weight;
 *  - any other topic with nj == 0: 0.
 * Each row is then divided by its sum.
 */
void updateBeta(double **beta)
{
	int grp, topic, src;

	for(grp = 0; grp < nG; grp++){
		double *row = beta[grp];
		double total = 0;

		for(topic = 0; topic < K_t; topic++){
			double weight = 0;

			if(nj[topic] != 0){
				mxAssert(K_id[topic] >= 0, "no source jump!!!");
				if(q[grp][K_id[topic]] > 0){
					weight = max1((nj[topic] - a), 0) * q[grp][K_id[topic]] / (1 + sumQU[K_id[topic]]);
				}
			}else{
				mxAssert(nj_t[grp][topic] == 0, "conflict!!!");
				if(topic == K_t - 1){
					/* last slot: mass reserved for the not-yet-used topic */
					for(src = 0; src < nS; src++){
						if(q[grp][src] > 0){
							weight += q[grp][src] * M[src] / pow(1 + sumQU[src], 1 - a);
						}
					}
					weight *= a;
				}
			}
			row[topic] = weight;
			total += weight;
		}
		for(topic = 0; topic < K_t; topic++){
			row[topic] /= total;
		}
	}
}

/*
 * sample \sigma
*/
/*
 * Function object evaluated by slice_sampler1d() in samplea_slice(): the
 * log-density used to resample `a`, evaluated at a candidate value x.
 * sumk is the total of K_ct over sources and sumlogj the sum of log jump
 * sizes, both supplied by the caller.
 */
struct aPosterior{
	int sumk;
	double sumlogj;

	aPosterior(int k_total, double log_jump_sum)
	:sumk(k_total), sumlogj(log_jump_sum)
	{}

	double operator()(double x) const
	{
		double logp = sumk * (log(x) - gammaln(1 - x));
		logp -= x * sumlogj;
		for(int src = 0; src < nS; src++){
			logp -= M[src] * pow(1 + sumQU[src], x);
		}
		return logp;
	}
};
/*
 * Resamples the global `a` with two slice-sampling sweeps. Each sweep draws
 * a gamma variate for every topic with nj > 0, sums the logs of those
 * draws, and passes an aPosterior built from them to slice_sampler1d().
 * Afterwards gamma_a is refreshed to gamma(1 - a).
 */
void samplea_slice()
{
	int src, topic, sweep;
	int sumk = 0;
	double sumlogj;
	
	for(src = 0; src < nS; src++){
		sumk += K_ct[src];
	}
	for(sweep = 0; sweep < 2; sweep++){
		sumlogj = 0;
		for(topic = 0; topic < K_t; topic++){
			if(nj[topic] <= 0){
				continue;
			}
			sumlogj += log(randgamma(nj[topic] - a, 1 + sumQU[K_id[topic]]));
		}
		aPosterior apos(sumk, sumlogj);
		a = slice_sampler1d(apos, a, drand48, 1.0e-10, 0.9999, 0.0, 10, 1000);
	}
	gamma_a = gamma(1 - a);
}

/******************
input: 	0: mu,	K * V						output: 0: mu
		1: sum_mu,	K								1: sum_mu
		2: s,	cell								2: nj
		3: M										3: nj_t
		4: q										4: n_ts
		5: n										5: Kid
		6: n_t										6: s
		7: n_ts										7: n
		8: nj										8: M
		9: nj_t										9: u
		10: K_c										10: Kct	
		11: K_ct									11: us
		12: K_id									12: L
		13: G_id									13: r
		14: ng										14: Kc
		15: nd_nk
		16: nd_tk
		17: nw
		18: gamma									
		19: burnin,	1 * 1
		20:	nsample, 1 * 1
		21: lag, 1 * 1
		22: calLik, 1 * 1
		23: data
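		24:	a
		25:	gamma_a
		26:	nw_all
		27:	tag_tr
		28:	output, 1 * 1 (1: write progress to the log file)
		29:	log file name (read only when output == 1)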
		
******************/
/*[samples, lik] = sample_zt_DHNGG_ave(mu, sum_mu, s, M, q, n, n_t, ...
    n_ts, nj, nj_t, K_c, K_ct, K_id, G_id, ng, nd_nk, nd_tk, nw, mu0, ...
    burnin, numbofits, every, dolik, data, a, gamma(1 - a), nw_all, tag_tr);*/
/*
 * MEX entry point: runs the sampler for burnin + nsample * lag iterations.
 *
 * Inputs (prhs), in the order they are read below:
 *   0 mu, 1 sum_mu, 2 s (cell), 3 M, 4 q, 5 n, 6 n_t, 7 n_ts, 8 nj, 9 nj_t,
 *   10 K_c, 11 K_ct, 12 K_id, 13 G_id, 14 ng, 15 nd_nk, 16 nd_tk, 17 nw,
 *   18 mu0, 19 burnin, 20 nsample, 21 lag, 22 calLik, 23 data (cell), 24 a,
 *   25 gamma_a, 26 nw_all, 27 tag_tr, 28 output flag,
 *   29 log file name (only read when output == 1).
 * Outputs (plhs):
 *   0: 1 x nsample struct array, one element per recorded sample;
 *   1, 2: the lik[0] / lik[1] values of CalTrTeLik for every likelihood
 *         evaluation (only assigned when the caller asks for them).
 *
 * Errors are reported with mexErrMsgTxt() so that a bad call returns to the
 * MATLAB prompt instead of terminating the host process. Buffers obtained
 * with malloc() before such an error are not released.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	const int npara = 20;
	char* log_file = NULL;
	int i, j, gg, ss, num_DP, v, k, V, K, kk, iter, maxIter, burnin, nsample, lag, calLik, nLik;
	int **s, **data, ns;
	double **mu, **beta, *sum_mu, *prob, *c_tmp, tmp_M, u_tmp, mu0, v_sum;
	double *lik_pr, *lik_te_pr, *comp_tmp, llik[2];
	int lik_iter = 0;
	mxArray *struc, *lik_tr_arr, *lik_te_arr;
	const char* fieldnames[npara] = {"a", "a_b", "mu", "sum_mu", "nj", "nj_t", "n_ts", "Kid", "s", "n", 
			"n_t", "nd_t", "nd_tk", "nd_nk", "M", "Mg", "u", "u_mid", "Kct", "Kc"};

	/* prhs[0..28] are read unconditionally below */
	if(nrhs < 29){
		mexErrMsgTxt("not enough input arguments: at least 29 are required.");
	}

	// read data and initialize	
	V = mxGetN(prhs[0]);
	K = mxGetNumberOfElements(prhs[8]); /* current maximal topics */
	nG = mxGetNumberOfElements(prhs[6]);
	nDP = mxGetNumberOfElements(prhs[13]);
	nS = mxGetNumberOfElements(prhs[5]);
	mu0 = mxGetScalar(prhs[18]);
	burnin = mxGetScalar(prhs[19]);
	nsample = mxGetScalar(prhs[20]);
	lag = mxGetScalar(prhs[21]);
	calLik = (int)mxGetScalar(prhs[22]);
	maxIter = burnin + nsample * lag;

	output = (int)mxGetScalar(prhs[28]);
	if(output == 1 && nrhs < 30){
		mexErrMsgTxt("output == 1 requires the log file name as the 30th argument.");
	}

	/*srand ( time(NULL) );*/
	/* open the log before any buffer is allocated, so a failure leaks nothing */
	if(output == 1){
		i = (mxGetM(prhs[29]) * mxGetN(prhs[29])) + 1;
		log_file = (char*)mxCalloc(i, sizeof(char)); 
		int status = mxGetString(prhs[29], log_file, i);
		mxAssert(status == 0, "read log file fail!!!");
		fid = fopen(log_file, "w");
		if(fid == NULL){
			mexErrMsgTxt("cannot open the log file for writing.");
		}
	}

	mu = mxReadDoubleMatrix(prhs[0], K, V, 0, 0);
	sum_mu = mxReadDoubleVector(prhs[1], K, 0, 0);
	s = mxReadIntCellVector(prhs[2], 0);
	M = mxReadDoubleVector(prhs[3], nS, 0, 0);
	q = mxReadDoubleMatrix(prhs[4], nG, nS, 0, 0);
	n = mxReadIntVector(prhs[5], nS, 0, 0);
	n_t = mxReadIntVector(prhs[6], nG, 0, 0);
	n_ts = mxReadIntMatrix(prhs[7], nG, nS, 0, 0);
	nj = mxReadIntVector(prhs[8], K, 0, 0);
	nj_t = mxReadIntMatrix(prhs[9], nG, K, 0, 0);
	K_c = mxReadIntVector(prhs[10], nG + 1, 0, 0);
	K_ct = mxReadIntVector(prhs[11], nS, 0, 0);
	K_id = mxReadIntVector(prhs[12], K, 0, 0);
	G_id = mxReadIntVector(prhs[13], nDP, 0, 0);
	ng = mxReadIntVector(prhs[14], nG, 0, 0);
	nd_nk = mxReadIntMatrix(prhs[15], K, nDP, 0, 0);
	nd_tk = mxReadIntMatrix(prhs[16], K, nDP, 0, 0);
	nw = mxReadIntVector(prhs[17], nDP, 0, 0);
	nw_all = mxReadIntVector(prhs[26], nDP, 0, 0);
	tag_tr = mxReadIntVector(prhs[27], nDP, 0, 0);
	data = mxReadIntCellVector(prhs[23], 0);
	a = mxGetScalar(prhs[24]);
	a_b = a;
	gamma_a = mxGetScalar(prhs[25]);
	
	/*
	 * file to store statistics
	 */
#ifdef RECORD
	FILE *fidstat = fopen("stat_SNGGh.txt", "w");
	/******************/
#endif
	
	v_sum = mu0 * V;
	prob = (double*)malloc(sizeof(double) * K);
	Mg = (double*)malloc(sizeof(double) * nG);
	c_tmp = (double*)malloc(sizeof(double) * nS);
	sumQU = (double*)malloc(sizeof(double) * nS);
	u = (double*)malloc(sizeof(double) * nDP);
	u_mid = (double*)malloc(sizeof(double) * nG);
	G_idx = (int**)malloc(sizeof(int*) * nG);
	beta = (double**)malloc(sizeof(double*) * nG);
	nd_t = (int*)malloc(sizeof(int) * nDP);
	
	/* per-document state: auxiliary variable and total table count */
	for(i = 0; i < nDP; i++){
		u[i] = 100;
		nd_t[i] = 0;
		for(k = 0; k < K; k++){
			nd_t[i] += nd_tk[k][i];
		}
	}
	/* per-group state and the list of documents belonging to each group */
	for(i = 0; i < nG; i++){
		u_mid[i] = 100;
		Mg[i] = 10;
		G_idx[i] = (int*)malloc(sizeof(int) * ng[i]);
		k = 0;
		for(j = 0; j < nDP; j++){
			if(G_id[j] == i){
				G_idx[i][k] = j;
				k++;
			}
		}
		mxAssert(k == ng[i], "#customer not agree!!!");
		beta[i] = (double*)malloc(sizeof(double) * K);
	}
	for(i = 0; i < nS; i++){
		sumQU[i] = 0;
		for(j = 0; j < nG; j++){
			if(q[j][i] > 0){
				sumQU[i] += q[j][i] * u_mid[j];
			}
		}
	}
	
	ns = 0;
	struc = mxCreateStructMatrix(1, nsample, npara, fieldnames);
	plhs[0] = struc;
	if(calLik == 0){
		nLik = 0;
	}else{
		nLik = (maxIter - 1) / calLik + 1;
	}
	/*
	 * The traces are always created because the loop writes into them, but
	 * plhs[1] / plhs[2] are only assigned when the caller requested them.
	 */
	lik_tr_arr = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_pr = mxGetPr(lik_tr_arr);
	lik_te_arr = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_te_pr = mxGetPr(lik_te_arr);
	if(nlhs > 1){
		plhs[1] = lik_tr_arr;
	}
	if(nlhs > 2){
		plhs[2] = lik_te_arr;
	}

	/**** delete empty classes ***/
	K_t = K_c[nG];
	while(K_t > 0 && sum_mu[K_t-1] == 0){
		K_t--;
	}
	if(K_t == 0){
		if(output == 0){
			printf("error: empty topics!\n");
		}else{
			fprintf(fid, "error: empty topics!\n");
			fclose(fid);
		}
#ifdef RECORD
		fclose(fidstat);
#endif
		/* return to MATLAB with an error instead of exit(), which would kill the host */
		mexErrMsgTxt("error: empty topics!");
	}
	k = 0;
	while(k < K_t - 1){
		if(sum_mu[k] == 0){
			mxAssert(nj[k] == 0, "#topics not equal to zero!!!");
			for(j = 0; j < nG; j++){
				mxAssert(nj_t[j][k] == 0, "non zero topics!!!");
			}
			/* shift every later topic down by one and park the empty row at the end */
			comp_tmp = mu[k];
			for(kk = k; kk < K_t - 1; kk++){
				mu[kk] = mu[kk + 1];
				sum_mu[kk] = sum_mu[kk + 1];
				K_id[kk] = K_id[kk + 1];
				nj[kk] = nj[kk + 1];
			}
			mu[K_t - 1] = comp_tmp;
			sum_mu[K_t - 1] = 0;
			K_id[K_t - 1] = -1;
			nj[K_t - 1] = 0;
			for(num_DP = 0; num_DP < nG; num_DP++){
				for(kk = k; kk < K_t - 1; kk++){
					nj_t[num_DP][kk] = nj_t[num_DP][kk + 1];
				}
				nj_t[num_DP][K_t - 1] = 0;
			}
			for(num_DP = 0; num_DP < nDP; num_DP++){
				for(kk = k; kk < K_t - 1; kk++){
					nd_tk[kk][num_DP] = nd_tk[kk + 1][num_DP];
					nd_nk[kk][num_DP] = nd_nk[kk + 1][num_DP];
				}
				nd_tk[K_t - 1][num_DP] = 0;
				nd_nk[K_t - 1][num_DP] = 0;
				for(j = 0; j < nw[num_DP]; j++){
					mxAssert(s[num_DP][j] != k + 1, "data in empty topic!!!");
					if(s[num_DP][j] > k + 1){
						s[num_DP][j]--;
					}
				}
			}
			K_t--;
		}else{
			mxAssert(nj[k] > 0, "empty topic!!!");
			k++;
		}
	}
	K_c[nG] = K_t;

	/* one extra slot represents the not-yet-used topic */
	K_t++;
	updateBeta(beta);
	for(iter = 0; iter < maxIter; iter++){
		// correct errors in sumQU
		if((iter + 1) % 300 == 0){
			for(i = 0; i < nS; i++){
				sumQU[i] = 0;
				for(j = 0; j < nG; j++){
					if(q[j][i] > 0){
						sumQU[i] += q[j][i] * u_mid[j];
					}
				}
			}
		}
		
		if(calLik && iter % calLik == 0){
			/*lik = likelihood(mu, sum_mu, s, data, nw, nDP, mu0, V);*/
			CalTrTeLik(prob, data, mu, sum_mu, beta, mu0, V, llik);
			lik_pr[lik_iter] = llik[0];
			lik_te_pr[lik_iter] = llik[1];
			lik_iter++;
			if(output == 0){
				printf("In iteration %d, K = %d, lik_tr = %f, lik_te = %f...\n", iter, K_c[nG], llik[0], llik[1]);
			}else{
				fprintf(fid, "In iteration %d, K = %d, lik_tr = %f, lik_te = %f...\n", iter, K_c[nG], llik[0], llik[1]);
			}
		}else{
			if(output == 0){
				printf("In iteration %d, K = %d...\n", iter, K_c[nG]);
			}else{
				fprintf(fid, "In iteration %d, K = %d...\n", iter, K_c[nG]);
			}
		}
		
		mxAssert(K_t > 0, "#topic less than zero!!!");
		for(i = 0; i < nS; i++){
			sampleM(i,	a, 5, 0.1);
			/*if(output == 0){
				printf("M[%d] = %f\n", i, M[i]);
			}else{
				fprintf(fid, "M[%d] = %f\n", i, M[i]);
			}*/
		}
		/*sampleM_ave(a, 0.1, 0.1);*/
		for(i = 0; i < nG; i++){
			sampleU(i);
		}
		printf("finish sampling u..\n");
		
		samplea_slice();
		
		// sample a realization of \mu_t
		updateBeta(beta);
		printf("finish updating beta..\n");
		
		for(num_DP = 0; num_DP < nDP; num_DP++){/*printf("dsds\n");*/
			int g;
			g = G_id[num_DP];
			
			/* inputs of uuterms() for this document */
			M_ = Mg[g];
			N_ = nw[num_DP];
			K_ = nd_t[num_DP];
			u_tmp = log(u[num_DP]);
			arms_simple (8, &lv, &rv, uuterms, NULL, 0, &u_tmp, &u_tmp);
			u[num_DP] = exp(u_tmp);
			
		    /* Update allocation variables, s */
			tmp_M = Mg[g] * a_b * pow(1 + u[num_DP], a_b);
			for(j = 0; j < nw[num_DP]; j++){
				v = data[num_DP][j] - 1;
				k = s[num_DP][j] - 1;
				mxAssert(q[g][K_id[k]] > 0, "region weight less than zero!!!");
				/* remove the word from its current topic */
				mu[k][v]--;
				mxAssert(mu[k][v] >= 0, "empty topic!!!");
				sum_mu[k]--;
				mxAssert(sum_mu[k] >= 0, "empty topic");
				nd_nk[k][num_DP]--;
				mxAssert(nd_nk[k][num_DP] >= 0, "empty topic!!!");
				
				for(kk = 0; kk < K_t; kk++){
					if(K_id[kk] >= 0 && q[g][K_id[kk]] > 0){
						prob[kk] = (max1(nd_nk[kk][num_DP] - a_b, 0) + tmp_M * beta[g][kk]);
					}else{
						prob[kk] = 0;
					}
					if(prob[kk]){
						prob[kk] *= (mu[kk][v] + mu0) / (sum_mu[kk] + v_sum);
					}
				}
		        kk = randmult(prob, K_t);
				s[num_DP][j] = kk + 1;
				mu[kk][v]++;
				mxAssert(mu[kk][v] > 0, "empty topic!!!");
				sum_mu[kk]++;
				mxAssert(sum_mu[kk] > 0, "empty topic!!!");
				nd_nk[kk][num_DP]++;
				mxAssert(nd_nk[kk][num_DP] > 0, "empty topic");
				
				/* first word of this topic in this document */
				if(nd_nk[kk][num_DP] == 1){
					K_++;
					u_tmp = log(u[num_DP]);
					arms_simple (8, &lv, &rv, uuterms, NULL, 0, &u_tmp, &u_tmp);
					u[num_DP] = exp(u_tmp);
					tmp_M = Mg[g] * a_b * pow(1 + u[num_DP], a_b);
					
					/* first word of this topic anywhere: pick its source */
					if(sum_mu[kk] == 1){
						for(ss = 0; ss < nS; ss++){
							if(q[g][ss] > 0){
								c_tmp[ss] = q[g][ss] * M[ss] / pow(1 + sumQU[ss], 1 - a_b);
							}else{
								c_tmp[ss] = 0;
							}
						}
						K_id[kk] = randmult(c_tmp, nS);
					}

					/* the spare slot was taken: open a new spare slot */
					if(kk == K_t - 1){
						double b_tmp;
						if(kk == K - 1){
							/* out of capacity: double every per-topic buffer */
							K *= 2;
							mu = realloDouble(mu, K_t, K - K_t, V);
							sum_mu = (double*)realloc(sum_mu, K * sizeof(double));
							nj = (int*)realloc(nj, K * sizeof(int));
							nd_tk = realloInt(nd_tk, K_t, K - K_t, nDP);
							nd_nk = realloInt(nd_nk, K_t, K - K_t, nDP);
							K_id = (int*)realloc(K_id, K * sizeof(int));
							for(k = kk + 1; k < K; k++){
								sum_mu[k] = 0;
								nj[k] = 0;
								K_id[k] = -1;
							}
							prob = (double*)realloc(prob, K * sizeof(double));
							for(k = 0; k < nG; k++){
								nj_t[k] = (int*)realloc(nj_t[k], K * sizeof(int));
								for(i = kk + 1; i < K; i++){
									nj_t[k][i] = 0;
								}
							}
							for(gg = 0; gg < nG; gg++){
								beta[gg] = (double*)realloc(beta[gg], K * sizeof(double));
							}
						}
						/*
						 * reassign probability to the new topic approximately, should really resample
						 */
						for(gg = 0; gg < nG; gg++){
							b_tmp = beta[gg][K_t - 1];
							beta[gg][K_t - 1] *= ((1 - a_b) / (Mg[gg] * a_b * pow((1 + u_mid[gg]), a_b) + 1 - a_b));
							beta[gg][K_t] = b_tmp - beta[gg][K_t - 1];
						}
						K_t++;
					}
				}
			}
			/*sample tables*/
		    for(k = 0; k < K_t; k++){
		    	if(K_id[k] >= 0){
					kk = nd_tk[k][num_DP];
					n_t[g] -= kk;
					mxAssert(n_t[g] >= 0, "empty table!!!");
					nj_t[g][k] -= kk;
					mxAssert(nj_t[g][k] >= 0, "empty table!!!");
					nd_t[num_DP] -= kk;
					mxAssert(nd_t[num_DP] >= 0, "empty table!!!");
					nj[k] -= kk;
					mxAssert(nj[k] >= 0, "empty table!!!");
					kk = nd_tk[k][num_DP] = randnumtable(tmp_M * beta[g][k], a_b, nd_nk[k][num_DP]);
					n_t[g] += kk;
					nj_t[g][k] += kk;
					nd_t[num_DP] += kk;
					nj[k] += kk;
		    	}
		    }
		}
		
		/* adjust n, n_ts */
		for(ss = 0; ss < nS; ss++){
			n[ss] = 0;
			for(gg = 0; gg < nG; gg++){
				n_ts[gg][ss] = 0;
			}
		}
		for(k = 0; k < K_t; k++){
			if(K_id[k] >= 0){
				n[K_id[k]] += nj[k];
			}
		}
		for(i = 0; i < nDP; i++){
		    for(k = 0; k< K_t; k++){
		    	if(K_id[k] >= 0){
		    		n_ts[G_id[i]][K_id[k]] += nd_tk[k][i];
		    	}
		    }
		}

		/*sample M*/
		for(gg = 0; gg < nG; gg++){
			double K_tot, theta_tot;
			K_tot = 0;
			theta_tot = 0;
			for(num_DP = 0; num_DP < ng[gg]; num_DP++){
				K_tot += nd_t[G_idx[gg][num_DP]];
				theta_tot += pow(1 + u[G_idx[gg][num_DP]], a_b) - 1;
			}
			mxAssert(K_tot >= 0 && theta_tot >= 0, "invalid values!!!");
			theta_tot = randgamma(K_tot + 0.1, theta_tot + 0.5);
			Mg[gg] = theta_tot;
		}
		
		/*adjust K_c, K_ct*/
		for(gg = 0; gg < nG; gg++){
			K_c[gg] = 0;
			for(k = 0; k < K_t; k++){
				if(K_id[k] >= 0 && nj_t[gg][k] > 0){
					K_c[gg]++;
				}
			}
		}
		for(gg = 0; gg < nS; gg++){
			K_ct[gg] = 0;
		}
		K_c[nG] = 0;
		for(k = 0; k < K_t; k++){
			if(K_id[k] >= 0){
				K_c[nG]++;
				K_ct[K_id[k]]++;
			}
		}
		
		/*** delete classes **/
		k = 0;
		while(k < K_t - 1){
			if(sum_mu[k] == 0){
				mxAssert(nj[k] == 0, "empty topic!!!");
				for(j = 0; j < nDP; j++){
					mxAssert(nd_nk[k][j] == 0, "empty topic!!!");
				}
				comp_tmp = mu[k];
				for(kk = k; kk < K_t - 1; kk++){
					mu[kk] = mu[kk + 1];
					sum_mu[kk] = sum_mu[kk + 1];
					K_id[kk] = K_id[kk + 1];
					nj[kk] = nj[kk + 1];
				}
				mu[K_t - 1] = comp_tmp;
				sum_mu[K_t - 1] = 0;
				K_id[K_t - 1] = -1;
				nj[K_t - 1] = 0;
				for(num_DP = 0; num_DP < nG; num_DP++){
					for(kk = k; kk < K_t - 1; kk++){
						nj_t[num_DP][kk] = nj_t[num_DP][kk + 1];
					}
					nj_t[num_DP][K_t - 1] = 0;
				}
				for(num_DP = 0; num_DP < nDP; num_DP++){
					for(kk = k; kk < K_t - 1; kk++){
						nd_tk[kk][num_DP] = nd_tk[kk + 1][num_DP];
						nd_nk[kk][num_DP] = nd_nk[kk + 1][num_DP];
					}
					nd_tk[K_t - 1][num_DP] = 0;
					nd_nk[K_t - 1][num_DP] = 0;
				}
				for(num_DP = 0; num_DP < nDP; num_DP++){
					for(j = 0; j < nw[num_DP]; j++){
						mxAssert(s[num_DP][j] != k + 1, "words in empty topic!!!");
						if(s[num_DP][j] > k + 1){
							s[num_DP][j]--;
						}
					}
				}
				K_t--;
			}else{
				k++;
			}
		}
		
		/*
		 * store statistics to calculate ESS
		 */
#ifdef RECORD
		if(iter >= maxIter - 1000){
			llik[0] = likelihood(mu, sum_mu, s, data, nw, nDP, mu0, V);
			fprintf(fidstat, "%f ", llik[0]);
			for(i = 0; i < nS; i++){
				fprintf(fidstat, "%f %f %d", M[i], sumQU[i], K_ct[i]);
			}
			fprintf(fidstat, "\n");
		}
#endif

		/**** record output **********/
		if((iter + 1) > burnin && (iter + 1 - burnin)%lag == 0){
			/* on the very last iteration the writers also free the sampler buffers */
			int del;
			if(iter != maxIter - 1){
				del = 0;
			}else{
				del = 1;
			}

			mxSetField(struc, ns, "a", mxWriteScalar(a));
			mxSetField(struc, ns, "a_b", mxWriteScalar(a_b));
			mxSetField(struc, ns, "mu", mxWriteDoubleMatrix(K_c[nG], V, K, mu, 0, del));
			mxSetField(struc, ns, "sum_mu", mxWriteDoubleVector(1, K_c[nG], sum_mu, 0, del));
			mxSetField(struc, ns, "nj", mxWriteIntVector(1, K_c[nG], nj, 0, del));
			mxSetField(struc, ns, "nj_t", mxWriteIntMatrix(nG, K_c[nG], nG, nj_t, 0, del));
			mxSetField(struc, ns, "n_ts", mxWriteIntMatrix(nG, nS, nG, n_ts, 0, del));
			mxSetField(struc, ns, "Kid", mxWriteIntVector(1, K_c[nG], K_id, 0, del));
			mxSetField(struc, ns, "s", mxWriteIntCellVector(nDP, nw, s, 0, del));
			mxSetField(struc, ns, "n", mxWriteIntVector(1, nS, n, 0, del));
			mxSetField(struc, ns, "n_t", mxWriteIntVector(1, nG, n_t, 0, del));
			mxSetField(struc, ns, "nd_t", mxWriteIntVector(1, nDP, nd_t, 0, del));
			mxSetField(struc, ns, "nd_tk", mxWriteIntMatrix(K_c[nG], nDP, K, nd_tk, 0, del));
			mxSetField(struc, ns, "nd_nk", mxWriteIntMatrix(K_c[nG], nDP, K, nd_nk, 0, del));
			mxSetField(struc, ns, "M", mxWriteDoubleVector(1, nS, M, 0, del));
			mxSetField(struc, ns, "Mg", mxWriteDoubleVector(1, nG, Mg, 0, del));
			mxSetField(struc, ns, "u", mxWriteDoubleVector(1, nDP, u, 0, del));
			mxSetField(struc, ns, "u_mid", mxWriteDoubleVector(1, nG, u_mid, 0, del));
			mxSetField(struc, ns, "Kct", mxWriteIntVector(1, nS, K_ct, 0, del));
			mxSetField(struc, ns, "Kc", mxWriteIntVector(1, nG + 1, K_c, 0, del));
			ns++;
		}

	}

	/* release everything that was not handed to the writers above */
	free(prob);
	for(j = 0; j < nDP; j++){
		free(data[j]);
	}
	for(j = 0; j < nG; j++){
		free(G_idx[j]);
		free(beta[j]);
		free(q[j]);
	}
	free(G_idx);
	free(beta);
	free(q);
	free(data);
	free(c_tmp);
	free(sumQU);
	free(G_id);
	free(ng);
	free(nw);
	free(nw_all);
	free(tag_tr);
	if(output == 1){
		mxFree(log_file);
		fclose(fid);
	}
#ifdef RECORD
	fclose(fidstat);
#endif
}
