"""
Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
"""
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
from torch_geometric.nn import ChebConv
class ChebNet(nn.Module):
""" 2 Layer ChebNet based on pytorch geometric.
Parameters
----------
nfeat : int
size of input feature dimension
nhid : int
number of hidden units
nclass : int
size of output dimension
    num_hops : int
number of hops in ChebConv
dropout : float
dropout rate for ChebNet
lr : float
learning rate for ChebNet
    weight_decay : float
        weight decay coefficient (L2 regularization) for ChebNet.
with_bias: bool
whether to include bias term in ChebNet weights.
device: str
'cpu' or 'cuda'.
Examples
--------
We can first load dataset and then train ChebNet.
    >>> from deeprobust.graph.data import Dataset, Dpr2Pyg
>>> from deeprobust.graph.defense import ChebNet
>>> data = Dataset(root='/tmp/', name='cora')
>>> adj, features, labels = data.adj, data.features, data.labels
>>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
>>> cheby = ChebNet(nfeat=features.shape[1],
nhid=16, num_hops=3,
nclass=labels.max().item() + 1,
dropout=0.5, device='cpu')
>>> cheby = cheby.to('cpu')
>>> pyg_data = Dpr2Pyg(data) # convert deeprobust dataset to pyg dataset
    >>> cheby.fit(pyg_data, patience=10, verbose=True) # train with early stopping
"""
def __init__(self, nfeat, nhid, nclass, num_hops=3, dropout=0.5, lr=0.01,
weight_decay=5e-4, with_bias=True, device=None):
super(ChebNet, self).__init__()
assert device is not None, "Please specify 'device'!"
self.device = device
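        # two Chebyshev spectral graph convolution layers; K=num_hops is the
        # order of the Chebyshev polynomial filter used in each layer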
self.conv1 = ChebConv(
nfeat,
nhid,
K=num_hops,
bias=with_bias)
self.conv2 = ChebConv(
nhid,
nclass,
K=num_hops,
bias=with_bias)
self.dropout = dropout
self.weight_decay = weight_decay
self.lr = lr
self.output = None
self.best_model = None
self.best_output = None
def forward(self, data):
x, edge_index = data.x, data.edge_index
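        # layer 1: Chebyshev convolution + ReLU, then dropout (only active in
        # training mode); layer 2 maps the hidden units to class scores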
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
    def initialize(self):
"""Initialize parameters of ChebNet.
"""
self.conv1.reset_parameters()
self.conv2.reset_parameters()
    def fit(self, pyg_data, train_iters=200, initialize=True, verbose=False, patience=500, **kwargs):
"""Train the ChebNet model, when idx_val is not None, pick the best model
according to the validation loss.
Parameters
----------
pyg_data :
pytorch geometric dataset object
train_iters : int
number of training epochs
initialize : bool
whether to initialize parameters before training
verbose : bool
whether to show verbose logs
patience : int
            patience (number of epochs) for early stopping based on the validation loss
"""
        self.device = next(self.parameters()).device  # infer device from the model parameters
if initialize:
self.initialize()
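        # pyg_data is a PyG dataset; its first (and only) graph carries
        # x, edge_index, y and the train/val/test masks used below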
self.data = pyg_data[0].to(self.device)
# By default, it is trained with early stopping on validation
self.train_with_early_stopping(train_iters, patience, verbose)
    def train_with_early_stopping(self, train_iters, patience, verbose):
"""early stopping based on the validation loss
"""
if verbose:
print('=== training ChebNet model ===')
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
labels = self.data.y
train_mask, val_mask = self.data.train_mask, self.data.val_mask
early_stopping = patience
        best_loss_val = float('inf')  # any finite validation loss improves on this
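        # `patience` counts down on epochs without improvement and is reset
        # whenever the validation loss reaches a new best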
for i in range(train_iters):
self.train()
optimizer.zero_grad()
output = self.forward(self.data)
loss_train = F.nll_loss(output[train_mask], labels[train_mask])
loss_train.backward()
optimizer.step()
if verbose and i % 10 == 0:
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
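            # evaluate on the validation set after every training epoch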
self.eval()
output = self.forward(self.data)
loss_val = F.nll_loss(output[val_mask], labels[val_mask])
if best_loss_val > loss_val:
best_loss_val = loss_val
self.output = output
weights = deepcopy(self.state_dict())
patience = early_stopping
else:
patience -= 1
if i > early_stopping and patience <= 0:
break
if verbose:
print('=== early stopping at {0}, loss_val = {1} ==='.format(i, best_loss_val) )
self.load_state_dict(weights)
    def test(self):
"""Evaluate ChebNet performance on test set.
Parameters
----------
idx_test :
node testing indices
"""
self.eval()
test_mask = self.data.test_mask
labels = self.data.y
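        # full-graph forward pass; only the test-mask rows enter the metrics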
output = self.forward(self.data)
# output = self.output
loss_test = F.nll_loss(output[test_mask], labels[test_mask])
acc_test = utils.accuracy(output[test_mask], labels[test_mask])
print("Test set results:",
"loss= {:.4f}".format(loss_test.item()),
"accuracy= {:.4f}".format(acc_test.item()))
return acc_test.item()
    def predict(self):
"""
Returns
-------
torch.FloatTensor
output (log probabilities) of ChebNet
"""
self.eval()
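        # inference-mode forward pass on the graph cached by fit()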
return self.forward(self.data)
if __name__ == "__main__":
from deeprobust.graph.data import Dataset, Dpr2Pyg
# from deeprobust.graph.defense import ChebNet
data = Dataset(root='/tmp/', name='cora')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
cheby = ChebNet(nfeat=features.shape[1],
nhid=16,
nclass=labels.max().item() + 1,
dropout=0.5, device='cpu')
cheby = cheby.to('cpu')
pyg_data = Dpr2Pyg(data)
    cheby.fit(pyg_data, verbose=True) # train with early stopping
cheby.test()
print(cheby.predict())
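
    # A minimal follow-up sketch: turn the log-probabilities returned by
    # predict() into hard class labels. `preds` is a variable introduced
    # here purely for illustration.
    preds = cheby.predict().argmax(dim=1)
    print(preds[:10])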