```python
# for PyTorch
import torch
import torchvision

# for GPU checking
import psutil
import humanize
import os
import GPUtil as GPU
import matplotlib.pyplot as plt

# for training the NN
import numpy as np
import calendar
import time
from datetime import datetime
```
Settings
```python
# set a random seed for reproducibility
torch.manual_seed(0);
```
```python
# GPU check
GPUs = GPU.getGPUs()
# XXX: there is only one GPU on Colab, and even that one isn't guaranteed
gpu = GPUs[0]

def printm():
    process = psutil.Process(os.getpid())
    print("Gen RAM Free: " + humanize.naturalsize(psutil.virtual_memory().available),
          " | Proc size: " + humanize.naturalsize(process.memory_info().rss))
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(
        gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil * 100, gpu.memoryTotal))

printm()
```
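The cells below move the model, the loss, and the data to `device`, which is never defined in this section. A minimal sketch of the usual selection (an assumption on my part, not shown in the source):

```python
# use the GPU when available, otherwise fall back to the CPU (assumed convention)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
```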
```python
# split train data into train and valid data
N_train = len(train_data_origin)
N_train_new = int(N_train * 0.8)
N_valid_new = N_train - N_train_new
N_train_new, N_valid_new
```
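The cell above only computes the split sizes; the split itself and the `train_loader`/`valid_loader` consumed by `train_nn` below are not shown in this section. A sketch using `torch.utils.data.random_split`, assuming `train_data_origin` is a torchvision-style dataset loaded earlier, with an assumed batch size of 64:

```python
from torch.utils.data import random_split, DataLoader

# perform the 80/20 split computed above; train_data_origin is assumed to be
# a torchvision-style dataset loaded earlier in the notebook
train_data, valid_data = random_split(train_data_origin, [N_train_new, N_valid_new])

# loaders consumed by train_nn below; batch_size=64 is an assumed value
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=64, shuffle=False)
```

Likewise, the `model` handed to the optimizer in the next cell is not defined in this section. Since `train_nn` flattens each image into a 1D vector and the loss is cross-entropy, a compatible classifier could be a small MLP; the 28×28 input and 10 classes below are assumptions (MNIST-like data), not the original architecture:

```python
# a minimal MLP sketch; the layer sizes are assumptions, not the original architecture
model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 128),  # assumed flattened 28x28 grayscale input
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),       # assumed 10 output classes
)
model.to(device)
```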
```python
# Define the loss
Lossfunction = torch.nn.CrossEntropyLoss()
Lossfunction.to(device)

# Define the optimizer
Optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```
Define the training process
```python
def train_nn(nn, train_loader, valid_loader, optimizer, lossfunction, Epoch=10**10):
    # dir for saving temporary files
    if not os.path.exists('./temp'):
        os.mkdir('./temp')

    # create a unique ID for the temp files, to avoid overwriting them
    # when several trainings run at the same time
    training_ID = int(calendar.timegm(time.gmtime()))
    print(f'The ID for this training is {training_ID}.')

    # initialize the best valid loss, used for saving the best model
    best_valid_loss = 10**10

    # arrays to record the training process
    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    # count the epochs without any improvement, for early stopping
    not_improved = 0

    # training
    for epoch in range(Epoch):
        # timer
        start_time = datetime.now()

        # temp variables to aggregate loss and acc from mini-batches to the full epoch
        num_of_mini_batch = []
        loss_of_mini_batch = []
        acc_of_mini_batch = []

        nn.train()  # training mode (matters for dropout/batch-norm layers)
        for X_train, y_train in train_loader:
            # reshape 2D images into 1D, and move them to the device
            X_train = X_train.view(X_train.shape[0], -1).to(device)
            y_train = y_train.to(device)

            # forward propagation
            prediction_train = nn(X_train)

            # calculate the loss
            train_loss_mini_batch = lossfunction(prediction_train, y_train)

            # calculate the predicted class of the input data
            yhat_train = torch.argmax(prediction_train.data, 1)

            # calculate how many predictions are correct
            train_correct = torch.sum(yhat_train == y_train.data)

            # calculate the accuracy of the prediction
            train_acc_mini_batch = train_correct / y_train.numel()

            # update the parameters of the model
            optimizer.zero_grad()
            train_loss_mini_batch.backward()
            optimizer.step()

            # record the loss and acc of this mini-batch
            num_of_mini_batch.append(X_train.shape[0])
            loss_of_mini_batch.append(train_loss_mini_batch.item())
            acc_of_mini_batch.append(train_acc_mini_batch.item())

        # aggregate the loss/acc from the mini-batches to the full epoch
        train_loss = np.average(loss_of_mini_batch, weights=num_of_mini_batch)
        train_losses.append(train_loss)
        train_acc = np.average(acc_of_mini_batch, weights=num_of_mini_batch)
        train_accs.append(train_acc)

        # same as in training: calculate the loss and accuracy on the valid data
        num_of_mini_batch = []
        loss_of_mini_batch = []
        acc_of_mini_batch = []
        nn.eval()  # evaluation mode
        with torch.no_grad():
            for X_valid, y_valid in valid_loader:
                X_valid = X_valid.view(X_valid.shape[0], -1).to(device)
                y_valid = y_valid.to(device)
                prediction_valid = nn(X_valid)
                valid_loss_mini_batch = lossfunction(prediction_valid, y_valid).data
                yhat_valid = torch.argmax(prediction_valid.data, 1)
                valid_correct = torch.sum(yhat_valid == y_valid.data)
                valid_acc_mini_batch = valid_correct / y_valid.numel()
                num_of_mini_batch.append(X_valid.shape[0])
                loss_of_mini_batch.append(valid_loss_mini_batch.item())
                acc_of_mini_batch.append(valid_acc_mini_batch.item())
        valid_loss = np.average(loss_of_mini_batch, weights=num_of_mini_batch)
        valid_losses.append(valid_loss)
        valid_acc = np.average(acc_of_mini_batch, weights=num_of_mini_batch)
        valid_accs.append(valid_acc)

        # if the valid loss in the current epoch beats the best one so far, save this model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(nn, f'./temp/NN_temp_{training_ID}')
            random_state = torch.random.get_rng_state()
            torch.save(random_state, f'./temp/NN_temp_random_state_{training_ID}')
            not_improved = 0
        # otherwise the model did not improve in this epoch
        else:
            not_improved += 1
            if not_improved > 5:
                print('Early stop.')
                break

        # timer
        end_time = datetime.now()

        # print information about the current epoch
        if epoch % 1 == 0:  # print every epoch; raise the modulus to print less often
            print(f'| Epoch: {epoch:-5d} | Train acc: {train_acc:-.5f} | Train loss: {train_loss:-.5e} | Valid acc: {valid_acc:-.5f} | Valid loss: {valid_loss:-.5e} | run time: {end_time-start_time} |')

    print('Finished.')
    return torch.load(f'./temp/NN_temp_{training_ID}'), train_losses, valid_losses, train_accs, valid_accs
```
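The plotting cells below consume `train_losses`, `valid_losses`, `train_accs`, and `valid_accs`; a sketch of the call that would produce them, reusing the names defined above (`Epoch=100` is an assumed cap, since early stopping usually ends the run first):

```python
# train and keep the best model found on the validation set;
# Epoch=100 is an assumed cap, early stopping usually ends the run earlier
model, train_losses, valid_losses, train_accs, valid_accs = train_nn(
    model, train_loader, valid_loader, Optimizer, Lossfunction, Epoch=100)
```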
```python
plt.plot(train_losses, label='train loss');
plt.plot(valid_losses, label='valid loss');
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Loss of training')
plt.legend();
```
```python
plt.plot(train_accs, label='train acc');
plt.plot(valid_accs, label='valid acc');
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Acc of training')
plt.legend();
```