# dpl switches the script between its two output modes: the plotting sections
# below open PDF devices only when dpl is TRUE, and dpl is also handed to
# printfl().
dpl <- FALSE
# dpl <- TRUE

# Load helper definitions. printfl() is called later and is not defined in
# this file, so it presumably comes from notes-funs.R.
source("notes-funs.R")
######################################################################
if (TRUE) {
  cat("### init h2o\n")
  # Attach h2o and start (or connect to) the H2O instance used by every
  # section below. nthreads = -1 and the "10g" memory cap are passed through
  # to h2o.init() unchanged.
  library(h2o)
  h2o.init(max_mem_size = "10g", nthreads = -1)
}
######################################################################
if (TRUE) {
  cat("### read in mnist data\n")

  # Training data: read the CSV, turn the C785 column into a factor, then
  # upload the data frame to H2O under the key "train60".
  train60D <- read.csv("mnist-train.csv")
  train60D[["C785"]] <- as.factor(train60D[["C785"]])
  train60 <- as.h2o(train60D, destination_frame = "train60")

  # Test data gets exactly the same treatment, stored under the key "test".
  testD <- read.csv("mnist-test.csv")
  testD[["C785"]] <- as.factor(testD[["C785"]])
  test <- as.h2o(testD, destination_frame = "test")

  # Column indices used by all model fits: 1..784 as predictors, 785 as the
  # response.
  x <- 1:784
  y <- 785

  # Show what now exists in the R workspace and in the H2O key store.
  print(ls())
  print(h2o.ls())
}
######################################################################
if (TRUE) {
  cat("### plot digits\n")
  # plot_mnist_images() is not defined in this file; presumably it comes from
  # Visualize.R.
  source("Visualize.R")

  # Draw the first nine training rows on a 3 x 3 layout, each titled with the
  # value in its response column. In dpl mode the figure goes to a PDF.
  if (dpl) {
    pdf(file = "plot-digits.pdf", width = 12, height = 10)
  }
  par(mfrow = c(3, 3))
  for (i in seq_len(9)) {
    temp <- train60[i, x]
    plot_mnist_images(temp)
    title(main = train60[i, y])
  }
  if (dpl) {
    dev.off()
  }

  # Row counts per level of C785.
  # NOTE(review): unlike the other printfl() calls in this script, this one is
  # skipped entirely when dpl is FALSE -- confirm that is intended.
  temp <- h2o.group_by(train60, by = "C785", nrow("C785"))
  if (dpl) {
    printfl(as.data.frame(temp), dpl, "y-counts.rtxt")
  }
}
######################################################################
if (TRUE) {
  cat("### split train60 into train/val\n")

  # Split train60 in two with ratio 1/6 for the first part: the first part
  # becomes the validation frame, the remainder the training frame.
  # NOTE(review): set.seed() is called immediately before the split; whether
  # the H2O split follows R's RNG depends on the h2o client -- confirm, or pass
  # seed= to h2o.splitFrame() if strict reproducibility is required.
  set.seed(99)
  parts <- h2o.splitFrame(train60, ratios = 1.0 / 6.0)
  valid <- parts[[1]]
  train <- parts[[2]]
  rm(parts)

  # Quick look at the resulting frame object and the keys held by H2O.
  cat("is train S4: ", isS4(train), "\n")
  print(train[1:4, 1:5])
  print(attr(train, "id"))
  h2o.ls()
}
######################################################################
if (TRUE) {
  cat("### default Random Forests\n")

  # The fitted model is cached at ./files/mRFdef: train and save it only when
  # that file is missing, otherwise load the saved copy.
  fp <- file.path("./files", "mRFdef")
  if (!file.exists(fp)) {
    mRFdef <- h2o.randomForest(
      x = x,
      y = y,
      training_frame = train,
      validation_frame = valid,
      model_id = "mRFdef"
    )
    h2o.saveModel(mRFdef, path = "./files")
  } else {
    mRFdef <- h2o.loadModel(fp)
  }

  cat("is model S4:", isS4(mRFdef), "\n")
  cat("model id: ", mRFdef@model_id, "\n")

  # Confusion matrix on the validation frame, reported through printfl().
  convRFdef <- h2o.confusionMatrix(mRFdef, valid = TRUE)
  printfl(convRFdef, dpl, "defRF-conf.rtxt")
}
######################################################################
if (TRUE) {
  cat("### default dnn\n")

  # Same caching scheme as the default random forest: reuse
  # ./files/mDNNdef when present, otherwise fit with default settings and save.
  fp <- file.path("./files", "mDNNdef")
  if (!file.exists(fp)) {
    mDNNdef <- h2o.deeplearning(
      x = x,
      y = y,
      training_frame = train,
      validation_frame = valid,
      model_id = "mDNNdef"
    )
    h2o.saveModel(mDNNdef, path = "./files")
  } else {
    mDNNdef <- h2o.loadModel(fp)
  }
  cat("model id: ", mDNNdef@model_id, "\n")

  # Validation confusion matrix, reported through printfl().
  convDNNdef <- h2o.confusionMatrix(mDNNdef, valid = TRUE)
  printfl(convDNNdef, dpl, "defDNN-conf.rtxt")

  # Validation metrics: hit-ratio column and mean per-class error.
  perfDNNdef <- h2o.performance(mDNNdef, valid = TRUE)
  dnnMetrics <- perfDNNdef@metrics
  print(dnnMetrics$hit_ratio_table$hit_ratio)
  print(dnnMetrics$mean_per_class_error)
}
######################################################################
if (TRUE) {
  cat("### RF tuned\n")
  # Tuned random forests are cached as saved models under ./files. If any are
  # found they are loaded into listModels; otherwise the grid below is trained
  # and every resulting model is saved there.
  #
  # list.files() interprets `pattern` as a regular expression, not a shell
  # glob, so the file-name prefix is anchored explicitly instead of being
  # written as "Grid_DRF_*" (where "*" would only quantify the underscore).
  modelNames <- list.files("./files", pattern = "^Grid_DRF_")
  if (length(modelNames) != 0) {
    listModels <- lapply(modelNames, function(modelName) {
      h2o.loadModel(path = file.path("./files", modelName))
    })
  } else { # this takes a long time
    # Grid over ntrees, mtries and min_rows, two candidate values each.
    gRF <- h2o.grid("randomForest",
      hyper_params = list(
        ntrees = c(100, 500),
        mtries = c(28, 50),
        min_rows = c(2, 5)
      ),
      x = x, y = y, training_frame = train, validation_frame = valid
    )

    # Fetch every model of the grid and persist it for later runs.
    listModels <- lapply(gRF@model_ids, function(id) h2o.getModel(id))
    for (m in listModels) h2o.saveModel(m, path = "./files")
  }
}
######################################################################
if (TRUE) {
  cat("### see performance of RF tuned\n")
  # Everything below indexes into listModels; stop with a clear message rather
  # than an out-of-bounds subscript if no tuned model is available.
  numModels <- length(listModels)
  if (numModels == 0) {
    stop("listModels is empty: no tuned RF models to evaluate")
  }

  # Validation mean-per-class error of each tuned model, in listModels order.
  # The confusion matrix of each model is printed along the way.
  mrate <- rep(0, numModels)
  for (i in seq_along(listModels)) {
    print(h2o.confusionMatrix(listModels[[i]], valid = TRUE))
    mrate[i] <- h2o.performance(listModels[[i]], valid = TRUE)@metrics$mean_per_class_error
  }

  if (dpl) pdf(file = "mrate-rftuned.pdf", height = 10, width = 12)
  # Force a single panel so the figure looks the same in both modes: without
  # this, a screen device still carries the 3 x 3 layout set by the digit
  # plots, whereas a freshly opened PDF device does not.
  par(mfrow = c(1, 1))
  plot(mrate, pch = 16, col = "red", cex.axis = 1.5, cex.lab = 1.5, type = "b", cex = 1.5, lty = 3)
  if (dpl) dev.off()

  # Model with the smallest validation error and its hyper-parameters.
  bestRF <- listModels[[which.min(mrate)]]
  cat("bestRF has:\n")
  cat("ntrees,mtries,min_rows: ", bestRF@parameters$ntrees, bestRF@parameters$mtries,
      bestRF@parameters$min_rows, "\n")

  # Hyper-parameters of every model, in the same order as mrate.
  for (i in seq_along(listModels)) {
    cat("ntrees,mtries,min_rows: ", listModels[[i]]@parameters$ntrees,
        listModels[[i]]@parameters$mtries, listModels[[i]]@parameters$min_rows, "\n")
  }

  # Print the full grid for reference. These values repeat the ones given to
  # h2o.grid() in the "RF tuned" section and must be kept in sync by hand.
  hyper_params <- list(
    ntrees = c(100, 500),
    mtries = c(28, 50),
    min_rows = c(2, 5)
  )
  print(expand.grid(hyper_params))

  print(h2o.confusionMatrix(bestRF, valid = TRUE))
}
######################################################################
if (TRUE) {
  cat("### DNN tuning\n")
  # Candidate values for the deep-learning grid search.

  # network structure
  hidden_opt <- list(c(200, 200),
                     c(300, 300))

  # activation
  activation_opt <- c("TanhWithDropout", "RectifierWithDropout")

  # input
  # input_dropout_ratio_opt = c(0.2, 0)

  # hidden (one ratio per hidden layer)
  hidden_dropout_ratios_opt <- list(c(.1, .1), c(.5, .5))

  # l1 regularization
  l1_opt <- c(1e-4, 1e-2)

  # l2 max
  max_w2_opt <- c(3.4028235e+38, 50)

  hyper_params <- list(hidden = hidden_opt,
                       activation = activation_opt,
                       hidden_dropout_ratios = hidden_dropout_ratios_opt,
                       l1 = l1_opt,
                       max_w2 = max_w2_opt
                       )

  # hyper_params = list(hidden = hidden_opt)
  # hyper_params = list(max_w2=max_w2_opt)
  # hyper_params = list(activation = activation_opt,
  #                     hidden_dropout_ratios = hidden_dropout_ratios_opt)

  ##################################################
  ## do DNN grid
  # Saved grid models under ./files are reused when present; otherwise the
  # grid is trained and each model is saved (overwriting existing files).
  #
  # list.files() takes a regular expression, not a shell glob, so the prefix
  # is anchored explicitly instead of being written "Grid_DeepLearning_*".
  modelNames <- list.files("./files", pattern = "^Grid_DeepLearning_")
  if (length(modelNames) != 0) {
    listModels <- lapply(modelNames, function(modelName) {
      h2o.loadModel(path = file.path("./files", modelName))
    })
  } else { # this takes a long time
    gDNN <- h2o.grid("deeplearning",
                     hyper_params = hyper_params,
                     x = x, y = y, training_frame = train, validation_frame = valid,
                     epochs = 200)

    listModels <- lapply(gDNN@model_ids, function(id) h2o.getModel(id))
    for (m in listModels) h2o.saveModel(m, path = "./files", force = TRUE)
  }

  # Stop with a clear message rather than an out-of-bounds subscript if no
  # model is available.
  numModels <- length(listModels)
  if (numModels == 0) {
    stop("listModels is empty: no tuned DNN models to evaluate")
  }

  # Validation mean-per-class error of each DNN, in listModels order; the
  # confusion matrix of each model is printed along the way.
  mratednn <- rep(0, numModels)
  for (i in seq_along(listModels)) {
    print(h2o.confusionMatrix(listModels[[i]], valid = TRUE))
    mratednn[i] <- h2o.performance(listModels[[i]], valid = TRUE)@metrics$mean_per_class_error
  }

  # Compare DNN grid errors against the RF grid errors (mrate, computed in the
  # RF performance section). The x-range spans whichever series is longer so
  # that neither is clipped.
  par(mfrow = c(1, 1))
  rgy <- range(c(mratednn, mrate))
  nmRF <- length(mrate)
  plot(c(1, max(numModels, nmRF)), rgy, type = "n", xlab = "grid run number", ylab = "miss-class",
       cex.lab = 1.5, cex.axis = 1.5)
  points(1:numModels, mratednn, col = "blue", pch = 15, cex = 2, type = "b", lty = 2)
  points(seq_len(nmRF), mrate, col = "red", pch = 16, cex = 2, type = "b", lty = 3)
  legend("topleft", legend = c("DNN", "RF"), col = c("blue", "red"), lty = c(2, 3), lwd = c(3, 3), pch = c(15, 16))
  # dev.copy2pdf(file="compare-grids.pdf",height=8,width=12)


  # DNN with the smallest validation error.
  bestDNN <- listModels[[which.min(mratednn)]]
  print(h2o.confusionMatrix(bestDNN, valid = TRUE))


  # Disabled scratch fit, kept for manual experiments.
  if (FALSE) {
    mDNNtemp <- h2o.deeplearning(x, y, train,
                hidden = c(100, 100),
                activation = "TanhWithDropout",
                input_dropout_ratio = .2,
                hidden_dropout_ratios = c(.5, .2),
                l1 = 1e-3,
                max_w2 = 10,
                epochs = 200,
                model_id = "mDNNtemp",
                validation_frame = valid)
  }
}
######################################################################
### best nn on (train,val) -> pred on test
# Stack the training and validation frames into one frame for the final fit.
trainval <- h2o.rbind(train, valid)

# The final model is cached at ./files/mDNNfinal: fit and save it only when
# that file is missing, otherwise load the saved copy. The hyper-parameters
# are hard-coded here; each value appears among the DNN grid candidates.
# NOTE(review): the test frame is supplied as validation_frame, so the
# "validation" confusion matrix printed below is computed on test data --
# confirm that the fit does not also use it for stopping/model selection.
fp <- file.path("./files", "mDNNfinal")
if (!file.exists(fp)) {
  mDNNfinal <- h2o.deeplearning(
    x = x,
    y = y,
    training_frame = trainval,
    validation_frame = test,
    model_id = "mDNNfinal",
    hidden = c(200, 200),
    activation = "TanhWithDropout",
    hidden_dropout_ratios = c(.1, .1),
    l1 = 1e-4,
    epochs = 200
  )

  h2o.saveModel(mDNNfinal, path = "./files")
} else {
  mDNNfinal <- h2o.loadModel(fp)
}

print(h2o.confusionMatrix(mDNNfinal, valid = TRUE))

######################################################################
# In dpl mode, remove every object from the workspace once the run is done.
if (dpl) {
  rm(list = ls())
}
