Source code for deeprobust.graph.defense.pgd

from torch.optim.optimizer import required
from torch.optim import Optimizer
import torch
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg  # makes sp.linalg.svds available for prox_nuclear_truncated

class PGD(Optimizer):
    """Proximal gradient descent.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining parameter groups
    proxs : iterable
        iterable of proximal operators
    alphas : iterable
        iterable of coefficients for proximal gradient descent
    lr : float
        learning rate
    momentum : float
        momentum factor (default: 0)
    weight_decay : float
        weight decay (L2 penalty) (default: 0)
    dampening : float
        dampening for momentum (default: 0)
    """

    def __init__(self, params, proxs, alphas, lr=required, momentum=0,
                 dampening=0, weight_decay=0):
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=False)
        super(PGD, self).__init__(params, defaults)

        for group in self.param_groups:
            group.setdefault('proxs', proxs)
            group.setdefault('alphas', alphas)

    def __setstate__(self, state):
        super(PGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            # 'proxs' and 'alphas' are restored together with the param
            # groups, so no extra defaults are needed here.

    def step(self, delta=0, closure=None):
        for group in self.param_groups:
            lr = group['lr']
            proxs = group['proxs']
            alphas = group['alphas']

            # Apply each proximal operator to every parameter in the group.
            # The plain gradient step is expected to be taken by a separate
            # optimizer before this is called.
            for param in group['params']:
                for prox_operator, alpha in zip(proxs, alphas):
                    # param.data.add_(lr, -param.grad.data)
                    # param.data.add_(delta)
                    param.data = prox_operator(param.data, alpha=alpha*lr)
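# Usage sketch (hedged): in Pro-GNN-style training, PGD.step() only applies
# the proximal operators; the plain gradient step comes from a separate
# optimizer. `adj_changes` and `compute_loss` are assumed names here.
#
#     import torch
#     from deeprobust.graph.defense.pgd import PGD, SGD, prox_operators
#
#     adj_changes = torch.nn.Parameter(torch.zeros(100, 100))
#     sgd = SGD([adj_changes], lr=0.01)
#     pgd = PGD([adj_changes], proxs=[prox_operators.prox_l1],
#               alphas=[5e-4], lr=0.01)
#
#     loss = compute_loss(adj_changes)  # hypothetical loss
#     loss.backward()
#     sgd.step()   # gradient step:  x <- x - lr * grad
#     pgd.step()   # proximal step:  x <- prox_l1(x, alpha * lr)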
class ProxOperators():
    """Proximal Operators."""

    def __init__(self):
        self.nuclear_norm = None
    def prox_l1(self, data, alpha):
        """Proximal operator for l1 norm (elementwise soft-thresholding)."""
        data = torch.mul(torch.sign(data), torch.clamp(torch.abs(data) - alpha, min=0))
        return data
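# Worked example: prox_l1 soft-thresholds each entry,
# prox(x) = sign(x) * max(|x| - alpha, 0):
#
#     x = torch.tensor([-1.5, -0.2, 0.0, 0.3, 2.0])
#     prox_operators.prox_l1(x, alpha=0.5)
#     # -> tensor([-1.0, 0.0, 0.0, 0.0, 1.5])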
    def prox_nuclear(self, data, alpha):
        """Proximal operator for nuclear norm (trace norm)."""
        U, S, V = np.linalg.svd(data.cpu())  # numpy returns V already transposed (V^T)
        U, S, V = torch.FloatTensor(U).cuda(), torch.FloatTensor(S).cuda(), torch.FloatTensor(V).cuda()
        self.nuclear_norm = S.sum()

        diag_S = torch.diag(torch.clamp(S - alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)
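# Worked example: prox_nuclear performs singular value thresholding, which
# shrinks the matrix toward low rank. (Requires a CUDA device as written,
# since the SVD factors are moved to .cuda().)
#
#     A = torch.diag(torch.tensor([3.0, 1.0, 0.2]))
#     B = prox_operators.prox_nuclear(A, alpha=0.5)
#     # singular values shrink from (3, 1, 0.2) to (2.5, 0.5, 0), so B has
#     # rank 2; prox_operators.nuclear_norm holds the pre-shrinkage sum 4.2.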
    def prox_nuclear_truncated_2(self, data, alpha, k=50):
        """Truncated version of the nuclear-norm proximal operator: keeps only
        the top-k singular triplets, computed with TensorLy."""
        import tensorly as tl
        tl.set_backend('pytorch')
        U, S, V = tl.truncated_svd(data.cpu(), n_eigenvecs=k)
        U, S, V = torch.FloatTensor(U).cuda(), torch.FloatTensor(S).cuda(), torch.FloatTensor(V).cuda()
        self.nuclear_norm = S.sum()

        S = torch.clamp(S - alpha, min=0)

        # build diag_S as a sparse diagonal matrix
        indices = torch.tensor((range(0, len(S)), range(0, len(S)))).cuda()
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size((len(S), len(S))))
        V = torch.spmm(diag_S, V)
        V = torch.matmul(U, V)
        return V

    def prox_nuclear_truncated(self, data, alpha, k=50):
        """Truncated version of the nuclear-norm proximal operator, computed
        with scipy.sparse.linalg.svds on a sparse copy of `data`."""
        indices = torch.nonzero(data).t()
        values = data[indices[0], indices[1]]  # modify this based on dimensionality
        # pass the explicit shape so trailing all-zero rows/columns are preserved
        data_sparse = sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()),
                                    shape=tuple(data.shape))
        U, S, V = sp.linalg.svds(data_sparse, k=k)
        U, S, V = torch.FloatTensor(U).cuda(), torch.FloatTensor(S).cuda(), torch.FloatTensor(V).cuda()
        self.nuclear_norm = S.sum()
        diag_S = torch.diag(torch.clamp(S - alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)

    def prox_nuclear_cuda(self, data, alpha):
        """Nuclear-norm proximal operator computed entirely on the GPU with
        torch.svd."""
        U, S, V = torch.svd(data)
        self.nuclear_norm = S.sum()
        S = torch.clamp(S - alpha, min=0)

        # build diag_S as a sparse diagonal matrix
        indices = torch.tensor([range(0, U.shape[0]), range(0, U.shape[0])]).cuda()
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size(U.shape))
        V = torch.spmm(diag_S, V.t_())
        V = torch.matmul(U, V)
        return V
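# The three variants above trade exactness for speed: the truncated versions
# keep only the top-k singular triplets, which is much cheaper for the large,
# nearly low-rank adjacency matrices these defenses operate on, while
# prox_nuclear_cuda runs a full SVD without leaving the GPU. Sanity check
# (hedged sketch; needs a CUDA device):
#
#     A = torch.diag(torch.tensor([3.0, 1.0, 0.2])).cuda()
#     B = prox_operators.prox_nuclear_cuda(A, alpha=0.5)
#     # B's singular values are (2.5, 0.5, 0), matching prox_nuclear.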
class SGD(Optimizer):
    """Stochastic gradient descent (optionally with momentum); mirrors
    torch.optim.SGD."""

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss
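# Quick sanity check (hedged sketch): one step reproduces the usual SGD update.
#
#     w = torch.nn.Parameter(torch.ones(3))
#     opt = SGD([w], lr=0.1, momentum=0.9)
#     (w * torch.arange(3.0)).sum().backward()   # grad = [0, 1, 2]
#     opt.step()
#     # w is now [1.0, 0.9, 0.8]: on the first step the momentum buffer is
#     # initialized to the gradient, so w <- w - lr * grad.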
# Shared module-level instance used by the defenses in this package.
prox_operators = ProxOperators()