'''
ab_hmdb51_splits.py

Jason Corso

Train and test an SVM on the HMDB51 data.
Uses the 3 splits provided by the HMDB creators.
Will produce the result statistic that we reported in the paper.

The processed HMDB51 data is available at
  http://www.cse.buffalo.edu/~jcorso/r/actionbank

MAKE sure that ../code is in your PYTHONPATH, i.e.,
  export PYTHONPATH=../code
before running this script

----
Information on the splits

There are 153 files in total in this folder,
[action]_test_split[1-3].txt, corresponding to the three splits reported
in the paper.

The format of each file is
  [video_name] [id]

The video is included in the training set if id is 1
The video is included in the testing set if id is 2
The video is not included for training/testing if id is 0

There should be 70 videos with id 1, 30 videos with id 2 in each txt file.

----
The following three videos are corrupt and we do not use them
(as of 30 May 2012)

  pour/How_to_pour_beer_pour_u_nm_np1_fr_goo_0.avi
  pour/How_to_pour_beer__eh__pour_u_nm_np1_fr_goo_0.avi
  talk/jonhs_netfreemovies_holygrail_talk_h_nm_np1_fr_med_6.avi
'''

import argparse
import glob
import gzip
import numpy as np
import os
import os.path
import random as rnd
import scipy.io as sio
import multiprocessing as mp

# NOTE: `banked_suffix` (used below to build the feature file names) is not
# defined in this file; it is expected to be provided by this star import.
from actionbank import *
import ab_svm


# (class, video name) pairs listed as corrupt in the module docstring; they are
# left out of both the training and the testing lists.
_CORRUPT_VIDEOS = frozenset([
    ("pour", "How_to_pour_beer_pour_u_nm_np1_fr_goo_0.avi"),
    ("pour", "How_to_pour_beer__eh__pour_u_nm_np1_fr_goo_0.avi"),
    ("talk", "jonhs_netfreemovies_holygrail_talk_h_nm_np1_fr_med_6.avi"),
])


def loadsplit(classes,path,splitnumber):
    '''
    Read the split files of one HMDB51 split for every class.

    classes     -- sequence of class names; the position of a class in this
                   sequence is used as its integer label
    path        -- directory containing the "<class>_test_split<N>.txt" files
    splitnumber -- which split to read (used to build the file name)

    Returns (trainfiles, testfiles).  Each is a list of
    [os.path.join(class, video_name), class_index] entries; a video goes to
    the training list when its id is '1', to the testing list when its id is
    '2', and is ignored for any other id.

    Raises ValueError if a non-blank line does not consist of exactly two
    whitespace-separated fields.
    '''
    trainfiles = []
    testfiles = []
    for ci, c in enumerate(classes):
        splitfile = os.path.join(path, "%s_test_split%d.txt" % (c, splitnumber))
        with open(splitfile) as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline at EOF)
                    continue
                (name, op) = fields

                # see note in the module docstring: these are corrupt videos
                if (c, name) in _CORRUPT_VIDEOS:
                    print("not adding %s,%s" % (c, name))
                    continue

                if op == '1':
                    trainfiles.append([os.path.join(c, name), ci])
                elif op == '2':
                    testfiles.append([os.path.join(c, name), ci])
    return trainfiles, testfiles


def _banked_path(root, relname):
    '''Return the path of the gzipped feature file for the video `relname`.'''
    return os.path.join(root, '%s%s' % (relname, banked_suffix))


def _load_bank(root, files, vlen):
    '''
    Load the feature vectors and labels for a list of [relname, label] entries.

    Returns (D, Y): D is a (len(files), vlen) uint8 array with one row per
    file, Y is a float array holding the label of each row.
    '''
    D = np.zeros((len(files), vlen), np.uint8)
    Y = np.array([label for (_, label) in files], dtype=np.float64)
    for fi, (relname, _) in enumerate(files):
        # consume the array while the gzip stream is still open
        with gzip.open(_banked_path(root, relname), "rb") as fp:
            D[fi, :] = np.load(fp)
    return D, Y


def main():
    '''Train and test the linear SVM on each of the three HMDB51 splits.'''
    parser = argparse.ArgumentParser(
        description="Script to train and test on the three standard splits of the HMDB51 data set using the included SVM code.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("root", help="path to the directory containing the action bank processed hmdb51 files structured as in root/class/class00_banked.npy.gz for each class")
    parser.add_argument("splits", help="path to the directory containing the HMDB51 splits files (153 of them)")

    args = parser.parse_args()

    vlen = 0

    # os.listdir returns entries in arbitrary order; sort so that the
    # class -> label index mapping is the same on every run and machine.
    classes = sorted(os.listdir(args.root))

    if (len(classes) != 51):
        print("error: found %d classes, but there should be 51" % (len(classes)))

    accs = np.zeros(3)

    for splitnumber in range(1, 4):
        print("working on split %d" % splitnumber)
        trainfiles, testfiles = loadsplit(classes, args.splits, splitnumber)

        print("have %d training files" % len(trainfiles))
        print("have %d testing files" % len(testfiles))

        if not trainfiles:
            raise SystemExit("error: no training files found for split %d" % splitnumber)

        if not vlen:
            # the vector length is taken from the first training file and
            # reused for every other file and split
            with gzip.open(_banked_path(args.root, trainfiles[0][0]), "rb") as fp:
                vlen = len(np.load(fp))
            print("vector length is %d" % vlen)

        Dtrain, Ytrain = _load_bank(args.root, trainfiles, vlen)
        Dtest, Ytest = _load_bank(args.root, testfiles, vlen)

        print(Dtrain.shape)
        print(Ytrain.shape)
        print(Dtest.shape)
        print(Ytest.shape)

        res = ab_svm.SVMLinear(Dtrain, np.int32(Ytrain), Dtest,
                               threads=mp.cpu_count()-1,
                               useLibLinear=True, useL1R=False)
        tp = np.sum(res == Ytest)
        acc = (np.float64(tp) / Dtest.shape[0]) * 100
        print('Accuracy is %.1f%%' % acc)
        accs[splitnumber-1] = acc

    print('Mean accuracy is %f' % (accs.mean()))


if __name__ == '__main__':
    main()