/*
 * seqTM.c
 *
 */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>

#include <gsl/gsl_vector.h>
#include <gsl/gsl_sf.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_randist.h>

#include "seqTM.h"
#include "model.h"
#include "gibbs.h"
#include "util.h"
#include "params.h"

extern dpyp_params PARAMS;

/*
 * File-scope work buffers. In the code visible here they are allocated and
 * freed by estimate(). They are defined without `static`, so they have
 * external linkage.
 * NOTE(review): presumably the Gibbs/SGLD code in other translation units
 * reads and writes them - confirm against gibbs.c.
 */

/* 5 * K doubles; spyp_marg_lglihood() uses the first K entries as the
 * normalized per-document topic weights. */
double* vector;
/* ndocs ints, filled with the identity mapping 0..ndocs-1 in estimate(). */
int* docs_id;
/* BATCHSIZE ints; allocated in estimate(), not written in the visible code. */
int* rdocs_id;
/* v * K doubles, each initialized in estimate() with a zero-mean Gaussian
 * draw of standard deviation sqrt(ETA). */
double *v_t;
/* v * K doubles; allocated in estimate(), not initialized in the visible code. */
double *exp_term;
/* K rows of v doubles each; allocated in estimate(), not initialized there. */
double **Nkw_ave;
/* K doubles; allocated in estimate(), not initialized there. */
double *Nk_ave;

extern gsl_rng* glob_r;

/*
 * Report whether `dname` names an existing directory.
 *
 * Returns the non-zero S_ISDIR() result when the path is a directory and 0
 * otherwise. If stat() fails, errno is left as stat() set it; if the path
 * exists but is not a directory, errno is set to ENOTDIR.
 */
int directory_exist(const char *dname)
{
    struct stat info;

    if (stat(dname, &info) == 0)
    {
        int is_dir = S_ISDIR(info.st_mode);

        if (is_dir == 0)
        {
            /* Path exists but is some other kind of file. */
            errno = ENOTDIR;
        }

        return is_dir;
    }

    /* stat() failed: the path is missing or inaccessible. */
    return 0;
}

/*
 * Recompute model->phi as the row-wise softmax of model->theta.
 *
 * theta is read as a flat vector in which topic k occupies the indices
 * [k * v, (k + 1) * v). For every topic k and word w:
 *
 *     phi(k, w) = exp(theta[k*v + w] - m_k) / sum_j exp(theta[k*v + j] - m_k)
 *
 * where m_k is the largest theta entry of topic k. Subtracting the row
 * maximum keeps every exponent <= 0 (no overflow) and guarantees that at
 * least one term equals 1, so the normalizer can never be zero.
 */
static void update_phi(Model *model)
{
	if(model->v <= 0){
		return;	/* empty vocabulary: nothing to normalize */
	}

	for(int k = 0; k < model->K; k++){
		const int base = k * model->v;

		/* The maximum must be taken over this topic's own row. Seeding it
		 * with theta[0] (which belongs to topic 0) could exceed every entry
		 * of row k by enough that all exponentials underflow to 0, turning
		 * the normalization below into 0/0. */
		double row_max = vget(model->theta, base);
		for(int w = 1; w < model->v; w++){
			double t = vget(model->theta, base + w);
			if(t > row_max){
				row_max = t;
			}
		}

		/* Normalizer of the shifted exponentials. */
		double norm = 0;
		for(int w = 0; w < model->v; w++){
			norm += exp(vget(model->theta, base + w) - row_max);
		}

		for(int w = 0; w < model->v; w++){
			mset(model->phi, k, w, exp(vget(model->theta, base + w) - row_max) / norm);
		}
	}
}

/*
 * Terminate the program if an snprintf() into a buffer of `cap` bytes failed
 * or was truncated. `written` is the value snprintf() returned.
 */
static void stm_require_fit(int written, size_t cap)
{
	if(written < 0 || (size_t)written >= cap){
		fprintf(stderr, "Output path does not fit in %lu bytes\n", (unsigned long)cap);
		exit(EXIT_FAILURE);
	}
}

/*
 * Train the model.
 *
 * Runs PARAMS.gibbs_max_iter iterations (at least one, because the loop is
 * do/while). Each iteration refreshes model->phi with update_phi() and calls
 * stm_e_gibbs() on the training corpus. Every third iteration the value of
 * spyp_marg_lglihood() is printed and appended to "<root>/train.txt" (and to
 * "<root>/test.txt" for the held-out data when pred == 0).
 *
 * After the loop the model is saved under "<root>/SPYP_final", the top words
 * are written to "<root>/top_<N>_words_final.txt" and the topic assignments
 * are saved via save_topic_assignmnet().
 *
 * Parameters:
 *   model, cts, ass  training state; updated in place by the callees.
 *   c                training corpus.
 *   c_te             optional test corpus (may be NULL). Not freed here.
 *   v                vocabulary, used only for printing the top words.
 *   root             output directory; must already exist.
 *   pred             0: split c_te into two corpora with split_corpus(), run
 *                    held_out_gibbs() on the first part every iteration and
 *                    evaluate on the second part.
 *                    non-zero: test state is initialized on c_te itself but
 *                    no held-out sampling or test evaluation is performed.
 *
 * Any failure to build a path, open an output file or allocate memory
 * terminates the process with EXIT_FAILURE.
 */
static void sampling(Model* model, Cts* cts,
		Assignment* ass, Corpus* c, Corpus* c_te,
		vocabulary* v, char* root, int pred)
{
	/* Number of iterations between two perplexity reports. */
	const int eval_every = 3;

	int ite = 1;
	int samples = 0;
	int begin_k;
	clock_t t1, t2;
	double lg_lihood;
	char str[BUFSIZ];
	Corpus* c1 = NULL;
	Corpus* c2 = NULL;
	Assignment* ass_te = NULL;
	Cts* cts_te = NULL;

	// save train, test perplexity
	FILE *fpt_tr, *fpt_te;
	stm_require_fit(snprintf(str, sizeof str, "%s/train.txt", root), sizeof str);
	fpt_tr = fopen(str, "w");
	if(!fpt_tr){
		printf("Cannot open file %s\n", str);
		exit(EXIT_FAILURE);
	}
	stm_require_fit(snprintf(str, sizeof str, "%s/test.txt", root), sizeof str);
	fpt_te = fopen(str, "w");
	if(!fpt_te){
		printf("Cannot open file %s\n", str);
		fclose(fpt_tr);
		exit(EXIT_FAILURE);
	}

	/*********** test data *************/
	if(c_te){
		if(pred == 0){
			c1 = (Corpus*)malloc(sizeof(Corpus));
			c2 = (Corpus*)malloc(sizeof(Corpus));
			if(!c1 || !c2){
				fprintf(stderr, "Out of memory while splitting the test corpus\n");
				exit(EXIT_FAILURE);
			}
			split_corpus(c_te, c1, c2);
		}else{
			/* c1 aliases the caller's corpus here; it is not freed below. */
			c1 = c_te;
		}
		ass_te = new_assignment(c1);
		cts_te = new_cts(model->K, model->v, c1);
		init_state_gibbs(model, cts_te, ass_te, c1, 1, 0);
	}
	/*********************************/

	do{
		update_phi(model);

		/* Only the two sampling calls are timed; reporting is excluded. */
		t1 = clock();
		stm_e_gibbs(model, cts, ass, c);
		t2 = clock() - t1;
		t1 = clock();

		/************ test data **********/
		if(c_te && pred == 0){
			held_out_gibbs(model, cts_te, ass_te, c1);
		}
		/********************************/

		t2 += (clock() - t1);
		printf(">>> running time for %d-iteration: %lf seconds\n",  ite, (double)t2/CLOCKS_PER_SEC);

		/* NOTE(review): `samples` is counted but never consumed in this
		 * function - presumably a hook for posterior averaging; confirm. */
		if(ite > PARAMS.burn_in && ite % PARAMS.sampling_lag == 0){
			samples++;
		}

		if(ite % eval_every == 0){
			lg_lihood = spyp_marg_lglihood(c, cts, model);
			printf("\n++++++++++++ SPYP train perplexity +++++++++++++++\n");
			printf(" ===> %lf\n\n", lg_lihood);

			// save train, test perplexity
			fprintf(fpt_tr, "%lf\n", lg_lihood);

			if(c_te && pred == 0){
				/* Counts come from the first part (c1), tokens from the second (c2). */
				lg_lihood = spyp_marg_lglihood(c2, cts_te, model);
				printf("\n++++++++++++ SPYP test perplexity +++++++++++++++\n");
				printf(" ===> %lf\n\n", lg_lihood);

				// save train, test perplexity
				fprintf(fpt_te, "%lf\n", lg_lihood);
			}

			//sample_alpha(cts, model, c);
		}

		ite++;
	}while(ite <= PARAMS.gibbs_max_iter);

	stm_require_fit(snprintf(str, sizeof str, "%s/SPYP_final", root), sizeof str);
	/* An already existing directory is fine; anything else is worth a warning
	 * because save_model() is about to write into it. */
	if(mkdir(str, S_IRUSR | S_IWUSR | S_IXUSR) != 0 && errno != EEXIST){
		perror(str);
	}
	save_model(model, model->phi, str);

	begin_k = 0;
	stm_require_fit(snprintf(str, sizeof str, "%s/top_%d_words_final.txt", root, PARAMS.top_words), sizeof str);
	print_top_words(PARAMS.top_words, begin_k, begin_k + model->K-1, model->phi, cts, v, str);
	save_topic_assignmnet(c, ass, root);

	if(c_te){
		free_assignment(ass_te, c1);
		free_cts(cts_te, c1, model->K);
		if(pred == 0){
			/* Only the split copies are owned by this function. */
			free_corpus(c1);
			free_corpus(c2);
		}
	}

	// save train, test perplexity
	/* fclose() flushes buffered output, so a failure here means lost data. */
	if(fclose(fpt_tr) != 0){
		perror("train.txt");
	}
	if(fclose(fpt_te) != 0){
		perror("test.txt");
	}
}

/*
 * malloc() wrapper that terminates the process with EXIT_FAILURE when a
 * non-empty request cannot be satisfied. `what` names the buffer in the
 * error message.
 */
static void *stm_alloc_or_die(size_t bytes, const char *what)
{
	void *p = malloc(bytes);
	/* malloc(0) may legitimately return NULL, so only non-empty requests fail. */
	if(!p && bytes != 0){
		fprintf(stderr, "Out of memory allocating %s (%lu bytes)\n", what, (unsigned long)bytes);
		exit(EXIT_FAILURE);
	}
	return p;
}

/*
 * Entry point for training: builds the model and sampler state, allocates the
 * file-scope work buffers, runs sampling() and releases everything again.
 *
 * Parameters:
 *   c       training corpus.
 *   c_te    optional test corpus (may be NULL).
 *   v       vocabulary; v->size is the vocabulary size given to the model.
 *   root    output directory handed to sampling().
 *   dir_p   unused in this function.
 *   init    non-zero: topic assignments are read from `z_root` with
 *           read_topic_assignmnet(); zero: a fresh assignment is created.
 *   z_root  location of the saved assignments; only read when init != 0.
 *
 * Ownership: c, v and the internally created model/cts/ass are passed to
 * free_all(), and c_te to free_corpus(), before returning.
 * NOTE(review): presumably those calls release the objects, so the caller
 * must not use c, c_te or v afterwards - confirm against util.c.
 */
void estimate(Corpus* c, Corpus* c_te, vocabulary* v, char* root, char* dir_p, int init, char* z_root)
{
	Model* model;
	Cts* cts;
	Assignment* ass;

	(void)dir_p;

	initial_rng();

	model = random_init(PARAMS.K_i, v->size, PARAMS.alpha, PARAMS.gamma,
			PARAMS.ALPHA, PARAMS.ETA, PARAMS.C, PARAMS.NUMLF, PARAMS.BATCHSIZE);

	cts = new_cts(model->K, model->v, c);
	if(init){
		ass = read_topic_assignmnet(c, z_root);
	}else{
		ass = new_assignment(c);
	}

	printf(">>>>>> Begin Gibbs sampling iteration for SPYP Model ......\n");
	/* The last argument is 1 when the assignments were loaded from disk. */
	init_state_gibbs(model, cts, ass, c, 0, init ? 1 : 0);

	/* Sizes are computed in size_t so that v * K cannot overflow an int. */
	const size_t n_topics = (size_t)model->K;
	const size_t n_words = (size_t)model->v;
	const size_t n_cells = n_words * n_topics;

	vector = (double*)stm_alloc_or_die(sizeof(double) * 5 * n_topics, "vector");

	docs_id = (int*)stm_alloc_or_die(sizeof(int) * (size_t)c->ndocs, "docs_id");
	rdocs_id = (int*)stm_alloc_or_die(sizeof(int) * (size_t)model->BATCHSIZE, "rdocs_id");
	for(int d = 0; d < c->ndocs; d++){
		docs_id[d] = d;
	}
	v_t = (double*)stm_alloc_or_die(sizeof(double) * n_cells, "v_t");
	exp_term = (double*)stm_alloc_or_die(sizeof(double) * n_cells, "exp_term");
	Nkw_ave = (double**)stm_alloc_or_die(sizeof(double*) * n_topics, "Nkw_ave");
	Nk_ave = (double*)stm_alloc_or_die(sizeof(double) * n_topics, "Nk_ave");
	for(size_t k = 0; k < n_topics; k++){
		Nkw_ave[k] = (double*)stm_alloc_or_die(sizeof(double) * n_words, "Nkw_ave row");
	}
	/* One zero-mean Gaussian draw with standard deviation sqrt(ETA) per
	 * (topic, word) cell. */
	for(size_t i = 0; i < n_cells; i++){
		v_t[i] = gsl_ran_gaussian(glob_r, sqrt(model->ETA));
		assert(gsl_finite(v_t[i]));
	}

	sampling(model, cts, ass, c, c_te, v, root, 0);
	printf(">>>>>> End Gibbs sampling iteration \n");

	/* The row count was captured in n_topics, so this does not depend on the
	 * model still being alive. */
	free(Nk_ave);
	for(size_t k = 0; k < n_topics; k++){
		free(Nkw_ave[k]);
	}
	free(Nkw_ave);
	free_rng();
	free_all(c, model, cts, ass, v);
	free(vector);
	free(docs_id);
	free(rdocs_id);
	free(v_t);
	free(exp_term);

	/* The buffers are globals: clear them so that a later access fails fast
	 * instead of touching freed memory. */
	Nk_ave = NULL;
	Nkw_ave = NULL;
	vector = NULL;
	docs_id = NULL;
	rdocs_id = NULL;
	v_t = NULL;
	exp_term = NULL;

	if(c_te){
		free_corpus(c_te);
	}
}

/*
 * Per-token perplexity of corpus `c` under the current model.
 *
 * For every document d the topic weights are
 *     w_k = (cts->n[d][k] + alpha) / sum_j (cts->n[d][j] + alpha)
 * and each token with word id w contributes log(sum_k w_k * phi(k, w)).
 * Despite the function's name, the value returned is not a log-likelihood
 * but exp(-L / N), where L is the summed token log-probability and N the
 * total number of tokens.
 *
 * Uses the file-scope buffer `vector` as scratch space; it must hold at
 * least model->K doubles. An empty corpus (N == 0) yields NaN.
 */
double spyp_marg_lglihood(Corpus* c, Cts* cts, Model* model)
{
	double lik = 0;
	long long ns = 0;	/* total token count; wide enough for large corpora */

	for(int d = 0; d < c->ndocs; d++){
		/* Smoothed, normalized topic weights of document d. */
		double s = 0;
		for(int k = 0; k < model->K; k++){
			vector[k] = cts->n[d][k] + model->alpha;
			assert(vector[k] > 0);
			s += vector[k];
		}
		for(int k = 0; k < model->K; k++){
			vector[k] /= s;
		}

		ns += c->docs[d].total;
		for(int l = 0; l < c->docs[d].total; l++){
			double val = 0;
			int w = c->docs[d].words[l];
			for(int k = 0; k < model->K; k++){
				val += vector[k] * mget(model->phi, k, w);
			}
			/* Checked once the mixture is complete: a single topic may
			 * contribute exactly 0 (underflowed phi) without the token
			 * probability being 0, so asserting inside the loop was too
			 * strict. */
			assert(val > 0);
			lik += gsl_sf_log(val);
		}
	}
	return exp(-lik / ns);
}
