"""
This module helps train models of different architectures easily: select a model architecture and a training dataset, and the corresponding trained model is produced.
"""
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F #233
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
# Supported architectures: name -> (netmodels module path, factory that builds
# an untrained network from that module). Modules are imported lazily inside
# ``train`` so only the selected architecture's module is loaded.
_MODEL_REGISTRY = {
    'CNN': ('deeprobust.image.netmodels.CNN', lambda module: module.Net()),
    'ResNet18': ('deeprobust.image.netmodels.resnet', lambda module: module.ResNet18()),
    'ResNet34': ('deeprobust.image.netmodels.resnet', lambda module: module.ResNet34()),
    'ResNet50': ('deeprobust.image.netmodels.resnet', lambda module: module.ResNet50()),
    'densenet': ('deeprobust.image.netmodels.densenet', lambda module: module.densenet_cifar()),
    'vgg11': ('deeprobust.image.netmodels.vgg', lambda module: module.VGG('VGG11')),
    'vgg13': ('deeprobust.image.netmodels.vgg', lambda module: module.VGG('VGG13')),
    'vgg16': ('deeprobust.image.netmodels.vgg', lambda module: module.VGG('VGG16')),
    'vgg19': ('deeprobust.image.netmodels.vgg', lambda module: module.VGG('VGG19')),
}


def train(model, data, device, maxepoch, data_path='./', save_per_epoch=10, seed=100):
    """Train a network of the chosen architecture on the chosen dataset.

    The network is optimised with SGD (lr=0.1, momentum=0.5) and a StepLR
    schedule (step_size=100, gamma=0.1). In each epoch the selected netmodels
    module's own ``train`` and ``test`` helpers are run. The weights are saved
    to ``./trained_models/<data>_<model>_epoch_<epoch>.pt`` every
    ``save_per_epoch`` epochs and after the final epoch.

    Parameters
    ----------
    model : str
        Architecture name, one of 'CNN', 'ResNet18', 'ResNet34', 'ResNet50',
        'densenet', 'vgg11', 'vgg13', 'vgg16', 'vgg19'.
    data : str
        Dataset name, 'MNIST' or 'CIFAR10'.
    device : str or torch.device
        Device to train on, e.g. 'cpu' or 'cuda'.
    maxepoch : int
        Number of training epochs.
    data_path : str
        Directory where the dataset is stored or downloaded (default './').
    save_per_epoch : int
        Save a checkpoint every this many epochs (default 10). A value of 0
        disables periodic saving; the final epoch is still saved.
    seed : int
        Seed passed to ``torch.manual_seed`` (default 100).

    Raises
    ------
    ValueError
        If ``model`` is not a supported architecture name.

    Examples
    --------
    >>> import deeprobust.image.netmodels.train_model as trainmodel
    >>> trainmodel.train('CNN', 'MNIST', 'cuda', 20)
    """
    import importlib

    # Validate before touching any data so a typo fails fast with a clear
    # error, instead of an UnboundLocalError after the dataset was downloaded.
    if model not in _MODEL_REGISTRY:
        raise ValueError('Unsupported model {!r}; expected one of: {}'.format(
            model, ', '.join(_MODEL_REGISTRY)))

    torch.manual_seed(seed)
    train_loader, test_loader = feed_dataset(data, data_path)

    module_path, build = _MODEL_REGISTRY[model]
    MODEL = importlib.import_module(module_path)
    train_net = build(MODEL).to(device)

    optimizer = optim.SGD(train_net.parameters(), lr=0.1, momentum=0.5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    save_dir = './trained_models/'
    for epoch in range(1, maxepoch + 1):
        print(epoch)
        MODEL.train(train_net, device, train_loader, optimizer, epoch)
        MODEL.test(train_net, device, test_loader)

        # Guard against save_per_epoch == 0, which would divide by zero.
        periodic = bool(save_per_epoch) and epoch % save_per_epoch == 0
        if periodic or epoch == maxepoch:
            if os.path.isdir(save_dir):
                print('Save model.')
            else:
                # makedirs(exist_ok=True) also creates missing parents and
                # tolerates the directory appearing concurrently.
                os.makedirs(save_dir, exist_ok=True)
                print('Make directory and save model.')
            filename = '{}_{}_epoch_{}.pt'.format(data, model, epoch)
            torch.save(train_net.state_dict(), os.path.join(save_dir, filename))
        scheduler.step()
def feed_dataset(data, data_dict, train_batch_size=128, test_batch_size=1000):
    """Build the train and test DataLoaders for a named dataset.

    Parameters
    ----------
    data : str
        Dataset name, 'CIFAR10' or 'MNIST'.
    data_dict : str
        Root directory for the torchvision dataset. The dataset is downloaded
        there if it is not already present.
    train_batch_size : int
        Batch size of the training loader (default 128).
    test_batch_size : int
        Batch size of the test loader (default 1000).

    Returns
    -------
    tuple of torch.utils.data.DataLoader
        ``(train_loader, test_loader)``. Both loaders shuffle.

    Raises
    ------
    NotImplementedError
        For 'ImageNet', which is recognised but not supported yet.
    ValueError
        For any other unknown dataset name.
    """
    if data == 'CIFAR10':
        # Crop/flip augmentation is applied to the training split only. No
        # Normalize step is used, so tensors stay in the [0, 1] range produced
        # by ToTensor. (Previously disabled CIFAR10 stats: mean
        # (0.4914, 0.4822, 0.4465), std (0.2023, 0.1994, 0.2010).)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dataset_cls = datasets.CIFAR10
    elif data == 'MNIST':
        # The same ToTensor + Normalize pipeline is used for both splits.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        test_transform = train_transform
        dataset_cls = datasets.MNIST
    elif data == 'ImageNet':
        # Previously a silent `pass` that crashed later with UnboundLocalError.
        raise NotImplementedError('ImageNet loading is not implemented yet.')
    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CIFAR10' or 'MNIST'".format(data))

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(data_dict, train=True, download=True, transform=train_transform),
        batch_size=train_batch_size,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(data_dict, train=False, download=True, transform=test_transform),
        batch_size=test_batch_size,
        shuffle=True)
    return train_loader, test_loader