"""
CSE 455/555 Homework 2 --- Work with Support Vector Machines
Jason Corso (jcorso@acm.org)

You are to use the prpy and examples code already provided as a base for this 
functionality.

Make sure they are in your python path before you execute this file.

Refer to the assignment description for full details.  There are six places in 
this file (A-F) in which you are required to make changes and add code.  They 
are listed below.

Matrix notation is from the cvxopt documentation (discussed below); relations 
to equations in the Burges tutorial are given in parentheses.
A Ln. 139 -- Define the elements of the matrix P (Burges Eq. 16)
B Ln. 146 -- Define the elements of vector q (Burges Eq. 16)
C Ln. 150 -- Define the two sets of constraints (Burges p. 9)
D Ln. 169 -- Compute the explicit weight vector (Burges Eq. 14)
E Ln. 176 -- Compute the bias
F Ln. 182 -- Compute the margin

***THIS CODE WILL NOT RUN UNTIL AT LEAST A, B and C are done***

Once the changes are complete, you should be able to just run the file and this 
will perform the full set of steps required in the homework assignment.  This 
is what we'll do...

"""

import numpy as np
import scipy as sp
import scipy.io as sio
import matplotlib.pyplot as plt
import PIL
import os
import cvxopt as cvx    # these are new to this assignment, for convex optimization
import cvxopt.solvers

import prpy as pr


def debugPlot_SVM(X,Y,SV,w,b=None,h=None,m=None,fname=None):
    ''' Debug two-class plotting and draw the linear discriminant.

        (X,Y) are the data set; only the first two columns of X are plotted,
        and samples with Y == -1 get one marker style, all others a second
        SV flags the samples to ring as support vectors (None to skip)

        w is the weight vector, b (if exists) is the bias (w'x + b is the classifier)
        m is the margin (if exists)

        h is the figure handle to draw into (a new figure is created if None)
        fname is a filename at which to save the figure (and not display it to the screen)
        Returns the figure handle that was drawn into.
    '''

    if h is None:
        h = plt.figure()
    else:
        plt.figure(h.number)

    plt.clf()
    plt.axis([-10,10,-10,10])
    plt.grid(True)

    # No hold=True on the plot calls: overplotting is the default, and the
    # keyword is rejected by matplotlib 3.x.
    if SV is not None:
        for i in range(Y.shape[0]):
            if SV[i]:
                plt.plot(X[i,0],X[i,1],'go',markersize=14.0,
                        markerfacecolor=[1,1,1],markeredgecolor='g',markeredgewidth=1.5)

    for i in range(Y.shape[0]):
        style = pr.datatools.kDraw[0] if Y[i] == -1 else pr.datatools.kDraw[1]
        plt.plot(X[i,0],X[i,1],style,markersize=8.0)

    pr.lindisc.plotWVector(w,b,m)

    plt.draw()
    if fname is not None:
        plt.savefig(fname)
    else:
        plt.show()

    return h



def linearSVM(X,Y,fname=None):
    '''
    Train a linear SVM on the data set (X,Y) via the dual QP (solved by cvxopt).

     X is n by d
     Y is n by 1 and is -1 or +1 for classes; assume (linear) separability.
     fname is the filename at which to save the figure of the learned SVM
      (handed to debugPlot_SVM unchanged)

     Returns w,b,m: the discriminant direction, bias, and margin
        (the sign of w'x+b classifies a sample); w, b and m are also printed.
    '''

    n = len(Y)
    d = X.shape[1]

    # CVXOPT provides our quadratic programming routine (see the homework
    # description for getting it; it is not bundled with EPD).
    # From the user guide:
    #  http://abel.ee.ucla.edu/cvxopt/userguide/#
    #  cvxopt.solvers.qp(P, q[, G, h[, A, b[, solver[, initvals]]]])
    #
    #   minimize      1/2 x'Px + q'x
    #   subject to    Gx <= h  AND   Ax = b
    #
    # Variables below follow this notation from the documentation wherever
    # possible.
    # BUT note this b is different from the bias term in the SVM; it is named b
    # only to keep the same notation as in the cvxopt doc.
    #
    # Please note that CVXOPT uses its own "matrix" class.  But we can build
    # numpy arrays and then directly convert them to matrix, as in the case
    # of P below.
    #  http://abel.ee.ucla.edu/cvxopt/userguide/matrices.html#the-numpy-array-interface
    #
    # Note that we'll consider only relatively small problems, as this code is
    # inefficient.
    #
    # Note for clarity, in the language of the Burges tutorial paper, we are
    # maximizing the dual function L_D, but the cvxopt program seeks to
    # minimize its objective.  So, we multiply the objective function by -1.
    #
    # Note, there is no slack variable in our formulation

    # Build P first
    P = np.zeros((n,n))

    ###################################
    ###  A.  Fill In Here (define P)
    ###################################

    # a tiny ridge on the diagonal increases numerical stability of P
    P = P + 1e-10*np.eye(n)

    ###################################
    ###  B.  Fill In Here (define q)
    ###################################

    ###################################
    ###  C.  Fill In Here (define G,h,A,b as the 2 SVM QP constraints (no slack,dual-form))
    ###################################


    # set up for cvx solver (q, G, h, A and b must have been defined in B and C)
    P = cvx.matrix(P)
    q = cvx.matrix(q)
    G = cvx.matrix(G)
    h = cvx.matrix(h)
    A = cvx.matrix(A).T   # be careful here, with this matrix shape
    b = cvx.matrix(b)

    sol = cvx.solvers.qp(P,q,G,h,A,b)
    alpha = np.asarray(sol['x'])
    # this alpha is the alpha in the Burges solution

    # compute w
    w = np.zeros(d)
    ###################################
    ###  D.  Fill In Here (solve for w, weight vector)
    ###################################

    # compute b (from here on b is the SVM bias, no longer the QP's b)
    tol = 1e-7   # you should consider anything < tol to effectively be 0
    b = 0
    ###################################
    ###  E.  Fill In Here (solve for b, bias, as in w'*x+b)
    ###################################

    # compute the margin
    m = 0
    ###################################
    ###  F.  Fill In Here (solve for m, the margin)
    ###################################

    SV = alpha>tol   # entries of alpha above tol flag the support vectors
    print(w)   # single-argument print() works on both Python 2 and 3
    print(b)
    print(m)
    debugPlot_SVM(X,Y,SV=SV,w=w,b=b,m=m,fname=fname)

    return w,b,m



if __name__ == '__main__':
    # Single-argument print() calls behave the same on Python 2 and 3.
    print('Running hw2_svm.py as a main script; will proceed for training and testing.')
    print('This code assumes you are running the data in the hw2 directory with the hw2 data files available in this current directory.')

    # For each of the three data sets: load it from disk, train the linear
    # SVM on it, and save the resulting figure as ./figout<k>.pdf
    for k in (1, 2, 3):
        Z = np.load('./hw2_data%d.npz' % k)
        X = Z['X']
        Y = Z['Y']

        (w,b,m) = linearSVM(X,Y,'./figout%d.pdf' % k)



