"""
Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019.
http://pengcui.thumedialab.com/papers/RGCN.pdf
Author's Tensorflow implementation:
https://github.com/thumanlab/nrlweb/tree/master/static/assets/download
"""
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.distributions.multivariate_normal import MultivariateNormal
from deeprobust.graph import utils
import torch.optim as optim
from copy import deepcopy
# TODO sparse implementation
class GGCL_F(Module):
    """Graph Gaussian Convolution Layer (GGCL) when the input is feature"""

    def __init__(self, in_features, out_features, dropout=0.6):
        super(GGCL_F, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize both weight matrices."""
        for weight in (self.weight_miu, self.weight_sigma):
            torch.nn.init.xavier_uniform_(weight)

    def forward(self, features, adj_norm1, adj_norm2, gamma=1):
        """Map node features to a Gaussian (mean, variance) and propagate it."""
        h = F.dropout(features, self.dropout, training=self.training)
        # cached on self: the caller's loss reads gc1.miu / gc1.sigma
        self.miu = F.elu(h @ self.weight_miu)
        self.sigma = F.relu(h @ self.weight_sigma)
        # variance-based attention: high-variance nodes contribute less
        att = torch.exp(-gamma * self.sigma)
        miu_out = adj_norm1 @ (self.miu * att)
        sigma_out = adj_norm2 @ (self.sigma * att * att)
        return miu_out, sigma_out
class GGCL_D(Module):
    """Graph Gaussian Convolution Layer (GGCL) when the input is distribution"""

    def __init__(self, in_features, out_features, dropout):
        super(GGCL_D, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize both weight matrices."""
        for weight in (self.weight_miu, self.weight_sigma):
            torch.nn.init.xavier_uniform_(weight)

    def forward(self, miu, sigma, adj_norm1, adj_norm2, gamma=1):
        """Propagate an input Gaussian (miu, sigma) through the graph."""
        mean = F.elu(torch.mm(F.dropout(miu, self.dropout, training=self.training),
                              self.weight_miu))
        var = F.relu(torch.mm(F.dropout(sigma, self.dropout, training=self.training),
                              self.weight_sigma))
        # variance-based attention: high-variance nodes contribute less
        att = torch.exp(-gamma * var)
        return adj_norm1 @ (mean * att), adj_norm2 @ (var * att * att)
class GaussianConvolution(Module):
    """[Deprecated] Alternative gaussian convolution layer.

    Propagates a Gaussian distribution (mean, variance) through the graph.
    If no normalized adjacency matrices are given, only the linear
    transforms are applied to the mean and variance.
    """

    def __init__(self, in_features, out_features):
        super(GaussianConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features))
        self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight_miu)
        torch.nn.init.xavier_uniform_(self.weight_sigma)

    def forward(self, previous_miu, previous_sigma, adj_norm1=None, adj_norm2=None, gamma=1):
        """Propagate (previous_miu, previous_sigma) one layer.

        Returns the transformed (mean, variance) pair.
        """
        if adj_norm1 is None and adj_norm2 is None:
            # Bug fix: this branch previously returned the mean transform
            # twice; the variance must be transformed by weight_sigma.
            return torch.mm(previous_miu, self.weight_miu), \
                   torch.mm(previous_sigma, self.weight_sigma)

        # attention weights downscale high-variance (less reliable) nodes
        Att = torch.exp(-gamma * previous_sigma)
        M = adj_norm1 @ (previous_miu * Att) @ self.weight_miu
        Sigma = adj_norm2 @ (previous_sigma * Att * Att) @ self.weight_sigma
        return M, Sigma

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
class RGCN(Module):
    """Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019.

    Parameters
    ----------
    nnodes : int
        number of nodes in the input graph
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    gamma : float
        hyper-parameter for RGCN. See more details in the paper.
    beta1 : float
        hyper-parameter for RGCN (weight of the KL regularizer). See more details in the paper.
    beta2 : float
        hyper-parameter for RGCN (weight of the L2 regularizer). See more details in the paper.
    lr : float
        learning rate for GCN
    dropout : float
        dropout rate for GCN
    device: str
        'cpu' or 'cuda'.
    """

    def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
        super(RGCN, self).__init__()
        self.device = device
        self.lr = lr
        self.gamma = gamma
        self.beta1 = beta1
        self.beta2 = beta2
        self.nclass = nclass
        # NOTE(review): self.nhid is halved but not used by the layers below
        # (they take the raw `nhid`); kept for backward compatibility with
        # any external code reading this attribute.
        self.nhid = nhid // 2
        # gc1 turns raw features into a Gaussian; gc2 propagates the Gaussian
        self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
        self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)
        self.dropout = dropout
        # one standard-normal sample per node, used for the
        # reparameterization step in forward()
        self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
                torch.diag_embed(torch.ones(nnodes, self.nclass)))
        self.adj_norm1, self.adj_norm2 = None, None
        self.features, self.labels = None, None

    def forward(self):
        """Run the two GGCL layers on the inputs cached by fit().

        Returns
        -------
        torch.FloatTensor
            log-softmax class scores of shape (nnodes, nclass)
        """
        features = self.features
        miu, sigma = self.gc1(features, self.adj_norm1, self.adj_norm2, self.gamma)
        miu, sigma = self.gc2(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma)
        # reparameterization: z = miu + eps * sqrt(sigma); 1e-8 guards sqrt(0)
        output = miu + self.gaussian.sample().to(self.device) * torch.sqrt(sigma + 1e-8)
        return F.log_softmax(output, dim=1)

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=True, **kwargs):
        """Train RGCN.

        Parameters
        ----------
        features :
            node features
        adj :
            the adjacency matrix. The format could be torch.tensor or scipy matrix
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), GCN training process will not adopt early stopping
        train_iters : int
            number of training epochs
        verbose : bool
            whether to show verbose logs

        Examples
        --------
        We can first load dataset and then train RGCN.

        >>> from deeprobust.graph.data import PrePtbDataset, Dataset
        >>> from deeprobust.graph.defense import RGCN
        >>> # load clean graph data
        >>> data = Dataset(root='/tmp/', name='cora', seed=15)
        >>> adj, features, labels = data.adj, data.features, data.labels
        >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
        >>> # load perturbed graph data
        >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora')
        >>> perturbed_adj = perturbed_data.adj
        >>> # train defense model
        >>> model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1],
                         nclass=labels.max()+1, nhid=32, device='cpu')
        >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val,
                      train_iters=200, verbose=True)
        >>> model.test(idx_test)
        """
        # Bug fix: `.todense()` used to be called unconditionally, which broke
        # the documented torch.tensor input; only densify sparse matrices.
        if hasattr(adj, 'todense'):
            adj = adj.todense()
        if hasattr(features, 'todense'):
            features = features.todense()
        adj, features, labels = utils.to_tensor(adj, features, labels, device=self.device)

        self.features, self.labels = features, labels
        # D^{-1/2} A D^{-1/2} for the means, D^{-1} A D^{-1} for the variances
        self.adj_norm1 = self._normalize_adj(adj, power=-1/2)
        self.adj_norm2 = self._normalize_adj(adj, power=-1)
        print('=== training rgcn model ===')
        self._initialize()
        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose=True):
        """Plain training loop with no validation-based model selection."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        self.train()
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward()
            loss_train = self._loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward()
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train and keep the model that performs best on the validation set
        (lowest validation loss or highest validation accuracy)."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)

        best_loss_val = 100
        best_acc_val = 0
        best_weights = None

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward()
            loss_train = self._loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward()
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                best_weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                best_weights = deepcopy(self.state_dict())

        print('=== picking the best model according to the performance on validation ===')
        # Bug fix: only self.output used to be saved, so predict() ran on the
        # final (not the best) weights; restore the picked model's weights.
        if best_weights is not None:
            self.load_state_dict(best_weights)

    def test(self, idx_test):
        """Evaluate the performance on test set"""
        self.eval()
        # use the output picked during training rather than a fresh
        # (stochastic) forward pass
        output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self):
        """
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of RGCN
        """
        self.eval()
        return self.forward()

    def _loss(self, input, labels):
        """NLL loss + beta1 * KL(N(miu1, sigma1) || N(0, I)) + beta2 * L2."""
        loss = F.nll_loss(input, labels)
        miu1 = self.gc1.miu
        sigma1 = self.gc1.sigma
        # closed-form KL divergence to the standard normal; 1e-8 guards log(0)
        kl_loss = 0.5 * (miu1.pow(2) + sigma1 - torch.log(1e-8 + sigma1)).mean(1)
        kl_loss = kl_loss.sum()
        norm2 = torch.norm(self.gc1.weight_miu, 2).pow(2) + \
                torch.norm(self.gc1.weight_sigma, 2).pow(2)
        return loss + self.beta1 * kl_loss + self.beta2 * norm2

    def _initialize(self):
        """Re-initialize both convolution layers before training."""
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

    def _normalize_adj(self, adj, power=-1/2):
        """Return D^power (A + I) D^power for a dense adjacency tensor."""
        A = adj + torch.eye(len(adj)).to(self.device)
        D_power = (A.sum(1)).pow(power)
        # isolated nodes produce inf entries; zero them out
        D_power[torch.isinf(D_power)] = 0.
        D_power = torch.diag(D_power)
        return D_power @ A @ D_power
if __name__ == "__main__":
    from deeprobust.graph.data import PrePtbDataset, Dataset

    dataset_name = 'pubmed'
    # clean graph plus its train/val/test split
    clean = Dataset(root='/tmp/', name=dataset_name, seed=15)
    adj, features, labels = clean.adj, clean.features, clean.labels
    idx_train, idx_val, idx_test = clean.idx_train, clean.idx_val, clean.idx_test
    # pre-attacked adjacency for the same dataset
    perturbed_adj = PrePtbDataset(root='/tmp/', name=dataset_name).adj
    # train the defense model on the perturbed graph
    model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1],
                 nclass=labels.max()+1, nhid=32, device='cuda').to('cuda')
    model.fit(features, perturbed_adj, labels, idx_train, idx_val,
              train_iters=200, verbose=True)
    model.test(idx_test)
    prediction_1 = model.predict()
    print(prediction_1)