Source code for deeprobust.image.defense.TherEncoding

"""
This is an implementation of Thermometer Encoding.

References
----------
.. [1] Buckman, Jacob, Aurko Roy, Colin Raffel, and Ian Goodfellow. "Thermometer encoding: One hot way to resist adversarial examples." In International Conference on Learning Representations. 2018.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from torchvision import datasets, transforms
from deeprobust.image.netmodels.CNN import Net

import logging

## TODO
# class ther_attack(pgd_attack):
#     """
#     PGD attacks in response to thermometer encoding models
#     """
## TODO
# def adv_train():
#     """
#     adversarial training for thermometer encoding
#     """

def train(model, device, train_loader, optimizer, epoch):
    """Training process.

    Parameters
    ----------
    model :
        model to train
    device :
        device (cpu or cuda)
    train_loader :
        training data loader
    optimizer :
        optimizer
    epoch :
        current epoch index
    """
    logger.info('training')
    model.train()
    correct = 0
    bs = train_loader.batch_size

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)

        # Thermometer-encode the batch and fold the level dimension into the
        # channel dimension: (B, C, H, W) -> (B, C * LEVELS, H, W).
        encoding = Thermometer(data, LEVELS)
        encoding = encoding.permute(0, 2, 3, 1, 4)
        encoding = torch.flatten(encoding, start_dim=3)
        encoding = encoding.permute(0, 3, 1, 2)

        output = model(encoding)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        # Report loss and running accuracy every 10 batches.
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(), 100. * correct / (10 * bs)))
            correct = 0
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # Apply the same thermometer encoding and reshaping as in train().
            encoding = Thermometer(data, LEVELS)
            encoding = encoding.permute(0, 2, 3, 1, 4)
            encoding = torch.flatten(encoding, start_dim=3)
            encoding = encoding.permute(0, 3, 1, 2)

            output = model(encoding)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Clean loss: {:.3f}, Clean Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def Thermometer(x, levels, flattened=False):
    """
    Output
    ------
    Thermometer encoding of the input.
    """
    onehot = one_hot(x, levels)
    thermometer = one_hot_to_thermometer(onehot, levels)
    return thermometer
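
# Usage sketch (a minimal example; the tensor names are illustrative only),
# mirroring the reshaping performed in train() and test():
#
#   x = torch.rand(2, 1, 28, 28)             # batch of MNIST-sized images in [0, 1]
#   enc = Thermometer(x, 10)                 # (2, 1, 28, 28, 10)
#   enc = enc.permute(0, 2, 3, 1, 4)         # (2, 28, 28, 1, 10)
#   enc = torch.flatten(enc, start_dim=3)    # (2, 28, 28, 10)
#   enc = enc.permute(0, 3, 1, 2)            # (2, 10, 28, 28), ready for the CNN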
def one_hot(x, levels):
    """
    Output
    ------
    One-hot encoding of the input.
    """
    batch_size, channel, H, W = x.size()

    # Quantize each pixel in [0, 1] to an integer level index in [0, levels - 1].
    x = x.unsqueeze(4)
    x = torch.ceil(x * (levels - 1)).long()

    # Place a 1 at the quantized index along the new level dimension.
    onehot = torch.zeros(batch_size, channel, H, W, levels,
                         device=x.device).scatter_(4, x, 1)
    return onehot
def one_hot_to_thermometer(x, levels, flattened=False):
    """
    Convert one-hot encoding to thermometer encoding.
    """
    if flattened:
        pass  # TODO: check how to flatten

    # Cumulative sum along the level dimension turns a one-hot code into a
    # thermometer code (1s from the active level upward).
    thermometer = torch.cumsum(x, dim=4)

    if flattened:
        pass

    return thermometer
if __name__ == '__main__':
    logger = logging.getLogger('Thermometer Encoding')
    handler = logging.StreamHandler()  # handler for the logger
    handler.setFormatter(logging.Formatter('%(asctime)s'))
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    logger.info('Start training.')

    torch.manual_seed(100)
    device = torch.device("cuda")

    logger.info('Load trainset.')
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('deeprobust/image/data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=100, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('deeprobust/image/data', train=False,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=1000, shuffle=True)

    # TODO: change the channel according to the dataset.
    LEVELS = 10
    channel = 1

    model = Net(in_channel1=channel * LEVELS,
                out_channel1=32 * LEVELS,
                out_channel2=64 * LEVELS).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.2)
    logger.info('Load model.')

    save_model = True
    for epoch in range(1, 50 + 1):
        print('Running epoch ', epoch)

        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

        if save_model:
            torch.save(model.state_dict(),
                       "deeprobust/image/save_models/thermometer_encoding.pt")
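
# Loading sketch (an assumed downstream usage, not exercised in this script):
# rebuild the same Net and restore the saved weights for evaluation.
#
#   model = Net(in_channel1=1 * 10, out_channel1=32 * 10, out_channel2=64 * 10)
#   model.load_state_dict(torch.load(
#       "deeprobust/image/save_models/thermometer_encoding.pt",
#       map_location="cpu"))
#   model.eval()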