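/*
 * seqTM.c
 *
 * Driver for the SPYP topic model: runs the Gibbs-sampling training or
 * prediction loop, saves the final model and outputs, and reports the
 * training log-likelihood.
 */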

#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_sf.h>
#include <gsl/gsl_math.h>

#include <scythestat/stat.h>
#include <scythestat/smath.h>

#include "NBMM.h"
#include "model.h"
#include "gibbs.h"
#include "util.h"
#include "params.h"

/* NOTE(review): MAX_M, SAMPLES and LDAINIT are not referenced in this file;
 * they may be used elsewhere or be stale. */
#define MAX_M 500
#define SAMPLES 30
#define LDAINIT 0
/* Run-time configuration, defined in another translation unit. This file
 * reads fields such as K_i, gibbs_max_iter and top_words from it. */
extern dpyp_params PARAMS;

/*
 * File-scope scratch buffers. estimate() in this file allocates and releases
 * them; presumably the sampling routines from the project headers use them.
 * Confirm with those routines before changing any size.
 */

/* model->K doubles, allocated in estimate(). */
double* vec;
/* NOTE(review): not referenced in this file. */
int* t_num;

/* (model->maxL + model->C) doubles, allocated in estimate(). */
double* y_probs;

/* NOTE(review): the allocation in estimate() is commented out, so this file
 * never allocates it and it keeps its initial NULL value. */
double* lglik;

/*
 * used in drawing normal: a (K+1)x(K+1) workspace matrix and a length-(K+1)
 * vector, both allocated in estimate()
 */
gsl_matrix *work;
gsl_vector *res1;

/* Constructed in estimate() with arguments (rows, cols, true, 0):
 * A and A_tot are (K+1)x1, B and B_tot are (K+1)x(K+1). */
Matrix<> *A;
Matrix<> *B;
Matrix<> *A_tot;
Matrix<> *B_tot;

/*
 * Report whether `dname` names an existing directory.
 *
 * Returns non-zero for a directory and 0 otherwise. When stat() fails, errno
 * is whatever stat() set. When the path exists but is not a directory, errno
 * is set to ENOTDIR.
 */
int directory_exist(const char *dname)
{
    struct stat info;

    if (stat(dname, &info) == 0) {
        const int is_dir = S_ISDIR(info.st_mode);

        if (!is_dir)
            errno = ENOTDIR;
        return is_dir;
    }

    return 0;
}

/*
 * Run the Gibbs sampler for PARAMS.gibbs_max_iter iterations. The loop body
 * always runs at least once.
 *
 * Training (pred == 0):
 *   - Each iteration resamples the training corpus with stm_e_gibbs().
 *   - If a held-out corpus c_te is given, it is resampled with
 *     held_out_gibbs() against the training counts `cts`.
 *   - Each iteration prints the CPU time (clock()) spent.
 *   - Every third iteration prints the training log-likelihood from
 *     spyp_marg_lglihood1().
 *   - After the loop it saves the model to "<root>/SPYP_final", the top words
 *     to "<root>/top_<N>_words_final.txt", and the topic assignments.
 *
 * Prediction (pred != 0): each iteration calls
 * held_out_gibbs(model, cts, ass, c, NULL). c_te must be NULL in this mode.
 *
 * Held-out state (counts and assignments) is created and released here. The
 * corpora themselves are owned by the caller.
 */
static void sampling(Model* model, Cts* cts,
		Assignment* ass, Corpus* c, Corpus* c_te, vocabulary* v, char* root, int pred)
{
	char str[BUFSIZ];
	Assignment* ass_te = NULL;
	Cts* cts_te = NULL;

	/*********** test data *************/
	if(c_te){
		assert(pred == 0);
		cts_te = new_cts(model->K, model->v, c_te);
		/* The held-out corpus needs its own assignment storage: it is filled
		 * by init_state_gibbs(), updated by held_out_gibbs() and released by
		 * free_assignment() below, just like `ass` for the training corpus.
		 * Previously this pointer was left NULL. */
		ass_te = new_assignment(c_te);
		init_state_gibbs(model, cts_te, ass_te, c_te, 1, 0, 0);
	}
	/*********************************/

	int ite = 1;
	do{
		if(pred == 0){
			clock_t t1 = clock();
			/* ite starts at 1, so the sampler is always run with its final
			 * argument set to 1 (an earlier `ite < 1` branch could never fire). */
			stm_e_gibbs(model, cts, ass, c, 1);
			clock_t t2 = clock() - t1;
			t1 = clock();

			/************ test data **********/
			if(c_te){
				held_out_gibbs(model, cts_te, ass_te, c_te, cts);
			}
			/********************************/

			t2 += (clock() - t1);
			printf(">>> running time for %d-iteration: %lf seconds\n",  ite, (double)t2/CLOCKS_PER_SEC);

			if(ite % 3 == 0){
				double lg_lihood = spyp_marg_lglihood1(c, cts, model, 1);
				printf("\n++++++++++++ SPYP train log-likelihood +++++++++++++++\n");
				printf(" ===> %lf\n\n", lg_lihood);
			}
		}else{
			held_out_gibbs(model, cts, ass, c, NULL);
		}

		ite++;
	}while(ite <= PARAMS.gibbs_max_iter);

	if(pred == 0){
		/* Bounded formatting: `root` is caller-supplied and may be long. */
		int n = snprintf(str, sizeof str, "%s/SPYP_final", root);
		if(n >= 0 && (size_t)n < sizeof str){
			/* An already existing directory is fine; report anything else. */
			if(mkdir(str, S_IRUSR | S_IWUSR | S_IXUSR) != 0 && errno != EEXIST){
				perror(str);
			}
			save_model(model, cts, c->ndocs, str);
		}else{
			fprintf(stderr, "sampling: output path too long, model not saved under %s\n", root);
		}

		n = snprintf(str, sizeof str, "%s/top_%d_words_final.txt", root, PARAMS.top_words);
		if(n >= 0 && (size_t)n < sizeof str){
			print_top_words(PARAMS.top_words, 0, model->K - 1, cts->m, cts, v, str);
		}else{
			fprintf(stderr, "sampling: output path too long, top words not saved under %s\n", root);
		}
		save_topic_assignmnet(c, ass, root);
	}

	/* Release the held-out state created above. The held-out corpus itself
	 * is freed by the caller. */
	if(c_te){
		free_assignment(ass_te, c_te);
		free_cts(cts_te, c_te, model->K);
	}
}

/*
 * Top-level entry point: build or load a model, initialise the Gibbs state
 * for corpus `c`, allocate the file-scope scratch buffers, run sampling(),
 * and release everything afterwards.
 *
 *   c      - training corpus (pred == 0) or corpus to predict on (pred != 0);
 *            released here via free_all().
 *   c_te   - optional held-out corpus (may be NULL); released here.
 *   v      - vocabulary; released here via free_all().
 *   root   - output directory. With pred != 0 the model is read from
 *            "<root>/SPYP_final".
 *   dir_p  - not used by this function.
 *   init   - non-zero: initial topic assignments are read from z_root.
 *   z_root - location passed to read_topic_assignmnet() when init is set.
 *   pred   - 0 trains a freshly initialised model; non-zero predicts with a
 *            previously saved model.
 *
 * Exits the process if the model path does not fit the path buffer or if a
 * scratch buffer cannot be allocated.
 */
void estimate(Corpus* c, Corpus* c_te, vocabulary* v, char* root, char* dir_p, int init, char* z_root, int pred)
{
	Model* model;
	Cts* cts;
	Assignment* ass;

	(void)dir_p;

	initial_rng();

	if(pred == 0){
		model = random_init(PARAMS.K_i, v->size, PARAMS.c, PARAMS.C, PARAMS.ell,
				PARAMS.L, PARAMS.nu, PARAMS.alpha, PARAMS.gamma);
	}else{
		char str[BUFSIZ];
		/* Bounded formatting: `root` is caller-supplied and may be long. */
		int n = snprintf(str, sizeof str, "%s/SPYP_final", root);
		if(n < 0 || (size_t)n >= sizeof str){
			fprintf(stderr, "estimate: model path too long: %s/SPYP_final\n", root);
			exit(EXIT_FAILURE);
		}
		model = read_model(str);
	}

	cts = new_cts(model->K, model->v, c);
	ass = init ? read_topic_assignmnet(c, z_root) : new_assignment(c);

	printf(">>>>>> Begin Gibbs sampling iteration for SPYP Model ......\n");
	if(pred == 0){
		init_state_gibbs(model, cts, ass, c, 0, init ? 1 : 0, 1);
	}else{
		init_state_gibbs(model, cts, ass, c, 0, 0, 0);
	}

	/* Scratch buffers shared with the sampler through file-scope globals. */
	vec = (double*)malloc(sizeof *vec * model->K);
	work = gsl_matrix_alloc(model->K + 1, model->K + 1);
	res1 = gsl_vector_alloc(model->K + 1);
	A = new Matrix<>(model->K + 1, 1, true, 0);
	B = new Matrix<>(model->K + 1, model->K + 1, true, 0);
	A_tot = new Matrix<>(model->K + 1, 1, true, 0);
	B_tot = new Matrix<>(model->K + 1, model->K + 1, true, 0);
	y_probs = (double*)malloc(sizeof *y_probs * (model->maxL + model->C));
	if(vec == NULL || y_probs == NULL){
		fprintf(stderr, "estimate: out of memory allocating sampler buffers\n");
		exit(EXIT_FAILURE);
	}

	sampling(model, cts, ass, c, c_te, v, root, pred);
	printf(">>>>>> End Gibbs sampling iteration \n");

	free_rng();
	free_all(c, model, cts, ass, v);
	free(vec);
	vec = NULL;
	if(c_te){
		free_corpus(c_te);
	}

	/* Reset the globals so that no stale pointer survives this call. */
	gsl_matrix_free(work);
	work = NULL;
	gsl_vector_free(res1);
	res1 = NULL;
	delete A;
	A = NULL;
	delete B;
	B = NULL;
	delete A_tot;
	A_tot = NULL;
	delete B_tot;
	B_tot = NULL;
	free(y_probs);
	y_probs = NULL;
}

/*
 * Average per-token log predictive likelihood of corpus `c` under the
 * current counts.
 *
 * For token w of document d, the predictive probability is
 *   sum_k (m[k][w] + gamma) / (M[k] + gamma * V) * n(k, d) / (N[d] + alpha * K).
 * Its log is accumulated. Tokens whose mixture is not positive are skipped,
 * but they are still counted in the divisor c->total.
 *
 * `train` is accepted for interface compatibility and is not used.
 * Returns 0.0 for an empty corpus (c->total <= 0) instead of dividing by 0.
 */
double spyp_marg_lglihood1(Corpus* c, Cts* cts, Model* model, int train)
{
	(void)train;

	if(c->total <= 0){
		return 0.0;
	}

	double lglik_sum = 0;

	for(int d = 0; d < c->ndocs; d++){
		/* Document-level normaliser; it does not depend on the token or topic. */
		const double doc_denom = cts->N[d] + model->alpha * model->K;

		for(int l = 0; l < c->docs[d].total; l++){
			const int w = c->docs[d].words[l];
			double val = 0;

			for(int k = 0; k < model->K; k++){
				val += (cts->m[k][w] + model->gamma) / (cts->M[k] + model->gamma * model->v)
						* (*cts->n)(k, d) / doc_denom;
			}
			if(val > 0){
				lglik_sum += gsl_sf_log(val);
			}
		}
	}
	return lglik_sum / c->total;
}

