
Commit

fix bugs on the loss function in fga.py and ig_attack.py
ChandlerBang committed Jul 29, 2020
1 parent f4cc99c commit fb047ac
Showing 6 changed files with 217 additions and 57 deletions.
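Note on the fix: before this commit, both FGA and IGAttack computed their attack loss over the whole training split, F.nll_loss(output[idx_train], labels[idx_train]), so the gradients reflected global training error rather than the effect of a perturbation on the node being attacked. The new code drops the idx_train argument and instead maximizes the target node's loss under the surrogate's own predictions (pseudo-labels). A minimal sketch of the corrected objective, with variable names taken from the diffs below (it restates the changed lines, not additional library code):

import torch.nn.functional as F

# Pseudo-labels: the surrogate's own predictions on the clean graph,
# computed once up front and detached from the autograd graph.
pseudo_labels = surrogate.predict().detach().argmax(1)

# Old objective (buggy): training loss over the whole training set.
#   loss = F.nll_loss(output[idx_train], labels[idx_train])

# New objective (targeted): the loss of the target node alone, measured
# against its pseudo-label; gradients of this loss drive the perturbations.
output = surrogate(modified_features, adj_norm)
loss = F.nll_loss(output[[target_node]], pseudo_labels[[target_node]])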
11 changes: 5 additions & 6 deletions deeprobust/graph/targeted_attack/fga.py
@@ -48,7 +48,7 @@ class FGA(BaseAttack):
>>> target_node = 0
>>> model = FGA(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device='cpu').to('cpu')
>>> # Attack
>>> model.attack(features, adj, labels, idx_train, target_node, n_perturbations=5)
>>> model.attack(features, adj, labels, target_node, n_perturbations=5)
>>> modified_adj = model.modified_adj
"""
@@ -67,7 +67,7 @@ def __init__(self, model, nnodes, feature_shape=None, attack_structure=True, att
self.feature_changes = Parameter(torch.FloatTensor(feature_shape))
self.feature_changes.data.fill_(0)

def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_perturbations, **kwargs):
def attack(self, ori_features, ori_adj, labels, target_node, n_perturbations, **kwargs):
"""Generate perturbations on the input graph.
Parameters
@@ -78,8 +78,6 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur
Original (unperturbed) node feature matrix
labels :
node labels
idx_train :
node training indices
target_node : int
target node index to be attacked
n_perturbations : int
@@ -93,15 +91,16 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur

self.surrogate.eval()
print('number of perturbations: %s' % n_perturbations)
pseudo_labels = self.surrogate.predict().detach().argmax(1)

for i in range(n_perturbations):
modified_row = modified_adj[target_node] + self.adj_changes
modified_adj[target_node] = modified_row
adj_norm = utils.normalize_adj_tensor(modified_adj)

if self.attack_structure:
output = self.surrogate(modified_features, adj_norm)
loss = F.nll_loss(output[idx_train], labels[idx_train])
# acc_train = accuracy(output[idx_train], labels[idx_train])
loss = F.nll_loss(output[[target_node]], pseudo_labels[[target_node]])
grad = torch.autograd.grad(loss, self.adj_changes, retain_graph=True)[0]
grad = grad * (-2*modified_row + 1)
grad[target_node] = 0
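For context on how this loss is used, the greedy step in FGA.attack can be restated as a standalone function. The gradient and masking lines follow the diff above; the final edge selection lies in the collapsed part of the file, so the argmax-and-flip description is an inference, and the helper name is hypothetical:

import torch
import torch.nn.functional as F

def fga_edge_scores(surrogate, modified_features, adj_norm, adj_changes,
                    modified_row, pseudo_labels, target_node):
    # adj_changes must be the Parameter the target node's row (and hence
    # adj_norm) was built from, as in FGA.attack.
    output = surrogate(modified_features, adj_norm)
    loss = F.nll_loss(output[[target_node]], pseudo_labels[[target_node]])

    # d(loss)/d(adj_changes): how toggling each candidate edge of the
    # target node would change its loss.
    grad = torch.autograd.grad(loss, adj_changes, retain_graph=True)[0]

    # (-2*a + 1) is +1 where no edge exists and -1 where one does, so the
    # largest score marks the single add/remove that increases the target
    # node's loss the most.
    grad = grad * (-2 * modified_row + 1)
    grad[target_node] = 0  # never toggle the self-loop
    return grad

FGA then flips the highest-scoring edge and repeats this n_perturbations times.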
29 changes: 15 additions & 14 deletions deeprobust/graph/targeted_attack/ig_attack.py
@@ -1,8 +1,6 @@
"""
Topology Attack and Defense for Graph Neural Networks: An Optimization Perspective
https://arxiv.org/pdf/1906.04214.pdf
Tensorflow Implementation:
https://github.com/KaidiXu/GCN_ADV_Train
Adversarial Examples on Graph Data: Deep Insights into Attack and Defense
https://arxiv.org/pdf/1903.01610.pdf
"""

import torch
@@ -57,7 +55,7 @@ class IGAttack(BaseAttack):
>>> target_node = 0
>>> model = IGAttack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device='cpu').to('cpu')
>>> # Attack
>>> model.attack(features, adj, labels, idx_train, target_node, n_perturbations=5, steps=10)
>>> model.attack(features, adj, labels, target_node, n_perturbations=5, steps=10)
>>> modified_adj = model.modified_adj
>>> modified_features = model.modified_features
@@ -73,7 +71,7 @@ def __init__(self, model, nnodes=None, feature_shape=None, attack_structure=True
self.modified_features = None
self.target_node = None

def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_perturbations, steps=10, **kwargs):
def attack(self, ori_features, ori_adj, labels, target_node, n_perturbations, steps=10, **kwargs):
"""Generate perturbations on the input graph.
Parameters
@@ -84,8 +82,6 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur
Original (unperturbed) adjacency matrix
labels :
node labels
idx_train :
node training indices
target_node : int
target node index to be attacked
n_perturbations : int
@@ -98,6 +94,8 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur
self.surrogate.eval()
self.target_node = target_node

self.pseudo_labels = self.surrogate.predict().detach().argmax(1)

modified_adj = ori_adj.todense()
modified_features = ori_features.todense()
adj, features, labels = utils.to_tensor(modified_adj, modified_features, labels, device=self.device)
@@ -106,9 +104,9 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur
s_e = np.zeros(adj.shape[1])
s_f = np.zeros(features.shape[1])
if self.attack_structure:
s_e = self.calc_importance_edge(features, adj_norm, labels, idx_train, steps)
s_e = self.calc_importance_edge(features, adj_norm, labels, steps)
if self.attack_features:
s_f = self.calc_importance_feature(features, adj_norm, labels, idx_train, steps)
s_f = self.calc_importance_feature(features, adj_norm, labels, steps)

for t in (range(n_perturbations)):
s_e_max = np.argmax(s_e)
@@ -127,7 +125,7 @@ def attack(self, ori_features, ori_adj, labels, idx_train, target_node, n_pertur
self.check_adj(modified_adj)


def calc_importance_edge(self, features, adj_norm, labels, idx_train, steps):
def calc_importance_edge(self, features, adj_norm, labels, steps):
"""Calculate integrated gradient for edges
"""
baseline_add = adj_norm.clone()
@@ -147,7 +145,8 @@ def calc_importance_edge(self, features, adj_norm, labels, idx_train, steps):

for new_adj in scaled_inputs:
output = self.surrogate(features, new_adj)
loss = F.nll_loss(output[idx_train], labels[idx_train])
loss = F.nll_loss(output[[self.target_node]],
self.pseudo_labels[[self.target_node]])
adj_grad = torch.autograd.grad(loss, adj_norm)[0]
adj_grad = adj_grad[i][j]
_sum += adj_grad
@@ -165,7 +164,7 @@ def calc_importance_edge(self, features, adj_norm, labels, idx_train, steps):
integrated_grad_list = (-2 * adj[self.target_node] + 1) * integrated_grad_list
return integrated_grad_list

def calc_importance_feature(self, features, adj_norm, labels, idx_train, steps):
def calc_importance_feature(self, features, adj_norm, labels, steps):
"""Calculate integrated gradient for features
"""
baseline_add = features.clone()
@@ -185,7 +184,7 @@ def calc_importance_feature(self, features, adj_norm, labels, idx_train, steps):

for new_features in scaled_inputs:
output = self.surrogate(new_features, adj_norm)
loss = F.nll_loss(output[idx_train], labels[idx_train])
loss = F.nll_loss(output[[self.target_node]],
self.pseudo_labels[[self.target_node]])

feature_grad = torch.autograd.grad(loss, features)[0]
feature_grad = feature_grad[i][j]
_sum += feature_grad
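The scores s_e and s_f above are integrated-gradients approximations: the gradient of the target node's loss is averaged over inputs interpolated between a baseline and the current graph, then scaled by the input-baseline difference. A self-contained sketch for a single adjacency entry (i, j), using the target-node loss introduced in this commit; the helper name is hypothetical, and it differentiates with respect to the interpolated graph directly, which is a slight simplification of calc_importance_edge:

import torch
import torch.nn.functional as F

def edge_importance(surrogate, features, adj_norm, pseudo_labels,
                    target_node, i, j, steps=10):
    # Baseline for an existing edge: the same normalized adjacency with
    # entry (i, j) zeroed out (the library handles the symmetric case of a
    # missing edge with a baseline entry of 1 instead).
    baseline = adj_norm.detach().clone()
    baseline[i, j] = 0.

    total = 0.
    for k in range(1, steps + 1):
        # Point on the straight path from the baseline to the input.
        new_adj = (baseline + float(k) / steps
                   * (adj_norm.detach() - baseline)).requires_grad_(True)
        output = surrogate(features, new_adj)
        loss = F.nll_loss(output[[target_node]],
                          pseudo_labels[[target_node]])
        # Gradient of the target node's loss at the interpolated graph.
        total += torch.autograd.grad(loss, new_adj)[0][i, j].item()

    # Riemann-sum approximation of the integrated gradient.
    return (adj_norm[i, j] - baseline[i, j]).item() * total / steps

Edges (and, analogously, features) with the largest scores are then flipped one at a time until n_perturbations is reached.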
113 changes: 83 additions & 30 deletions examples/graph/test_fga.py
@@ -34,14 +34,14 @@

# Setup Surrogate model
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)
nhid=16, device=device)

surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train)
surrogate.fit(features, adj, labels, idx_train, idx_val)

# Setup Attack Model
target_node = 0
model = FGA(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device=device)
model = FGA(surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)

def main():
@@ -51,7 +51,7 @@ def main():
degrees = adj.sum(0).A1
n_perturbations = int(degrees[u]) # How many perturbations to perform. Default: Degree of the node

model.attack(features, adj, labels, idx_train, target_node, n_perturbations)
model.attack(features, adj, labels, target_node, n_perturbations)

print('=== testing GCN on original(clean) graph ===')
test(adj, features, target_node)
@@ -82,22 +82,41 @@ def test(adj, features, target_node):

return acc_test.item()

def select_nodes():
'''
selecting nodes as reported in nettack paper:
(i) the 10 nodes with highest margin of classification, i.e. they are clearly correctly classified,
(ii) the 10 nodes with lowest margin (but still correctly classified) and
(iii) 20 more nodes randomly
'''

def single_test(adj, features, target_node):
''' test on GCN (poisoning attack)'''
gcn = GCN(nfeat=features.shape[1],
nhid=16,
nclass=labels.max().item() + 1,
dropout=0.5, device=device)

gcn = gcn.to(device)

gcn.fit(features, adj, labels, idx_train)

gcn.eval()
output = gcn.predict()
probs = torch.exp(output[[target_node]])
acc_test = accuracy(output[[target_node]], labels[target_node])
# print("Test set results:", "accuracy= {:.4f}".format(acc_test.item()))
return acc_test.item()

def select_nodes(target_gcn=None):
'''
selecting nodes as reported in nettack paper:
(i) the 10 nodes with highest margin of classification, i.e. they are clearly correctly classified,
(ii) the 10 nodes with lowest margin (but still correctly classified) and
(iii) 20 more nodes randomly
'''

if target_gcn is None:
target_gcn = GCN(nfeat=features.shape[1],
nhid=16,
nclass=labels.max().item() + 1,
dropout=0.5, device=device)
target_gcn = target_gcn.to(device)
target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
target_gcn.eval()
output = target_gcn.predict()

margin_dict = {}
for idx in idx_test:
@@ -113,42 +132,76 @@ def select_nodes():

return high + low + other

def multi_test():
# attack first 50 nodes in idx_test
def multi_test_poison():
# test on 40 nodes on poisoning attack
cnt = 0
degrees = adj.sum(0).A1
node_list = select_nodes()
num = len(node_list)
print('=== Attacking %s nodes respectively ===' % num)
print('=== [Poisoning] Attacking %s nodes respectively ===' % num)
for target_node in tqdm(node_list):
n_perturbations = int(degrees[target_node])
model = FGA(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device=device)
model = FGA(surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)
model.attack(features, adj, labels, idx_train, target_node, n_perturbations)
acc = single_test(model.modified_adj, features, target_node)
model.attack(features, adj, labels, target_node, n_perturbations)
modified_adj = model.modified_adj
acc = single_test(modified_adj, features, target_node)
if acc == 0:
cnt += 1
print('misclassification rate : %s' % (cnt/num))

def single_test(adj, features, target_node):
''' test on GCN (poisoning attack)'''
gcn = GCN(nfeat=features.shape[1],
def single_test(adj, features, target_node, gcn=None):
if gcn is None:
# test on GCN (poisoning attack)
gcn = GCN(nfeat=features.shape[1],
nhid=16,
nclass=labels.max().item() + 1,
dropout=0.5, device=device)

gcn = gcn.to(device)

gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
gcn.eval()
output = gcn.predict()
else:
# test on GCN (evasion attack)
output = gcn.predict(features, adj)
probs = torch.exp(output[[target_node]])

# acc_test = accuracy(output[[target_node]], labels[target_node])
acc_test = (output.argmax(1)[target_node] == labels[target_node])
return acc_test.item()

def multi_test_evasion():
# test on 40 nodes on evasion attack
target_gcn = GCN(nfeat=features.shape[1],
nhid=16,
nclass=labels.max().item() + 1,
dropout=0.5, device=device)

gcn = gcn.to(device)
target_gcn = target_gcn.to(device)

gcn.fit(features, adj, labels, idx_train)
target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)

gcn.eval()
output = gcn.predict()
probs = torch.exp(output[[target_node]])
acc_test = accuracy(output[[target_node]], labels[target_node])
# print("Test set results:", "accuracy= {:.4f}".format(acc_test.item()))
return acc_test.item()
cnt = 0
degrees = adj.sum(0).A1
node_list = select_nodes(target_gcn)
num = len(node_list)

print('=== [Evasion] Attacking %s nodes respectively ===' % num)
for target_node in tqdm(node_list):
n_perturbations = int(degrees[target_node])
model = FGA(surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)
model.attack(features, adj, labels, target_node, n_perturbations)
modified_adj = model.modified_adj

acc = single_test(modified_adj, features, target_node, gcn=target_gcn)
if acc == 0:
cnt += 1
print('misclassification rate : %s' % (cnt/num))

if __name__ == '__main__':
main()
multi_test()
multi_test_evasion()
multi_test_poison()
(Diffs for the remaining changed files are not expanded here.)
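Putting it together, the updated attack call no longer takes idx_train. A short end-to-end sketch of the new usage; the dataset-loading lines are not part of this diff and follow the library's standard examples, so treat the root path and dataset name as illustrative:

import torch
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN
from deeprobust.graph.targeted_attack import FGA

device = 'cpu'
data = Dataset(root='/tmp/', name='cora')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Surrogate trained on the clean graph; its predictions supply the
# pseudo-labels used inside the attack.
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
                nhid=16, device=device).to(device)
surrogate.fit(features, adj, labels, idx_train, idx_val)

# Attack a single node; note the new signature without idx_train.
target_node = 0
n_perturbations = int(adj.sum(0).A1[target_node])  # node degree, as above
model = FGA(surrogate, nnodes=adj.shape[0], device=device).to(device)
model.attack(features, adj, labels, target_node, n_perturbations)
modified_adj = model.modified_adj

# Poisoning evaluation: retrain a fresh GCN on the perturbed graph and
# check the target node, mirroring single_test above.
gcn = GCN(nfeat=features.shape[1], nhid=16, nclass=labels.max().item() + 1,
          dropout=0.5, device=device).to(device)
gcn.fit(features, modified_adj, labels, idx_train, idx_val, patience=30)
print('target still classified correctly:',
      gcn.predict().argmax(1)[target_node].item() == labels[target_node])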
