from typing import Any
import numpy as np
import ray
import torch
import torch_geometric
from fedgraph.gnn_models import GCN, AggreGCN, GCN_arxiv, SAGE_products
from fedgraph.train_func import test, train
from fedgraph.utils import get_1hop_feature_sum
class Trainer_General:
    """
    A general trainer class for training GCN in a federated learning setup, which includes functionalities
    required for training GCN models on a subset of a distributed dataset, handling local training and testing,
    parameter updates, and feature aggregation.

    ``load_feature_aggregation`` must be called before ``train`` or
    ``local_test``: both read ``self.feature_aggregation``, which is only
    assigned there.

    Parameters
    ----------
    rank : int
        Unique identifier for the training instance (typically representing a client in federated learning).
        Also used to seed torch's RNG before the model is constructed.
    local_node_index : torch.Tensor
        Indices of nodes local to this trainer.
    communicate_node_index : torch.Tensor
        Indices of nodes that participate in communication during training.
    adj : torch.Tensor
        The adjacency matrix representing the graph structure.
    train_labels : torch.Tensor
        Labels of the training data.
    test_labels : torch.Tensor
        Labels of the testing data.
    features : torch.Tensor
        Node features; ``features.shape[1]`` is used as the model input width.
    idx_train : torch.Tensor
        Indices of training nodes.
    idx_test : torch.Tensor
        Indices of test nodes.
    args_hidden : int
        Number of hidden units in the GCN model.
    global_node_num : int
        Total number of nodes in the global graph.
    class_num : int
        Number of classes for classification.
    device : torch.device
        The device (CPU or GPU) on which the model will be trained.
    args : Any
        Additional arguments required for model initialization and training.
        The attributes read here are ``num_hops``, ``fedtype``, ``dataset``
        (only when the fedgcn branch is not taken), ``num_layers``,
        ``learning_rate`` and ``local_step``.
    """

    def __init__(
        self,
        rank: int,
        local_node_index: torch.Tensor,
        communicate_node_index: torch.Tensor,
        adj: torch.Tensor,
        train_labels: torch.Tensor,
        test_labels: torch.Tensor,
        features: torch.Tensor,
        idx_train: torch.Tensor,
        idx_test: torch.Tensor,
        args_hidden: int,
        global_node_num: int,
        class_num: int,
        device: torch.device,
        args: Any,
    ):
        # Seed torch's RNG with the client rank before the model is built.
        torch.manual_seed(rank)

        # Pick the architecture first, then construct it once: every branch
        # takes exactly the same keyword arguments.
        if args.num_hops >= 1 and args.fedtype == "fedgcn":
            model_cls = AggreGCN
        elif args.dataset == "ogbn-arxiv":
            model_cls = GCN_arxiv
        elif args.dataset == "ogbn-products":
            model_cls = SAGE_products
        else:
            model_cls = GCN

        self.model = model_cls(
            nfeat=features.shape[1],
            nhid=args_hidden,
            nclass=class_num,
            dropout=0.5,
            NumLayers=args.num_layers,
        ).to(device)

        self.rank = rank  # rank = client ID
        self.device = device

        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=args.learning_rate, weight_decay=5e-4
        )
        self.criterion = torch.nn.CrossEntropyLoss()

        # Per-local-step history, appended to by ``train``.
        self.train_losses: list = []
        self.train_accs: list = []
        self.test_losses: list = []
        self.test_accs: list = []

        # Keep every tensor the trainer touches on the training device.
        self.local_node_index = local_node_index.to(device)
        self.communicate_node_index = communicate_node_index.to(device)
        self.adj = adj.to(device)
        self.train_labels = train_labels.to(device)
        self.test_labels = test_labels.to(device)
        self.features = features.to(device)
        self.idx_train = idx_train.to(device)
        self.idx_test = idx_test.to(device)

        self.local_step = args.local_step
        self.global_node_num = global_node_num
        self.class_num = class_num

    @torch.no_grad()
    def update_params(self, params: tuple, current_global_epoch: int) -> None:
        """
        Updates the model parameters with global parameters received from the server.

        Values are copied into the existing parameter tensors, so the model
        never shares storage with the tensors in ``params``.

        Parameters
        ----------
        params : tuple
            A tuple containing the global parameters from the server, in the
            same order as ``self.model.parameters()``.
        current_global_epoch : int
            The current global epoch number. Not used by this method.
        """
        # ``copy_`` handles a source on a different device, so the model does
        # not need to be moved to the CPU and back to receive the new values.
        for global_param, local_param in zip(params, self.model.parameters()):
            local_param.copy_(global_param)

    def get_local_feature_sum(self) -> torch.Tensor:
        """
        Computes the sum of features of all 1-hop neighbors for each node.

        Returns
        -------
        one_hop_neighbor_feature_sum : torch.Tensor
            The sum of features of 1-hop neighbors for each node
        """
        # Global-sized matrix that is zero everywhere except the rows of the
        # locally known nodes. It is allocated on the same device as the
        # features and index tensors so the indexed assignment below does not
        # mix devices.
        new_feature_for_client = torch.zeros(
            self.global_node_num,
            self.features.shape[1],
            device=self.features.device,
        )
        new_feature_for_client[self.local_node_index] = self.features

        # sum of features of all 1-hop nodes for each node
        one_hop_neighbor_feature_sum = get_1hop_feature_sum(
            new_feature_for_client, self.adj
        )
        return one_hop_neighbor_feature_sum

    def load_feature_aggregation(self, feature_aggregation: torch.Tensor) -> None:
        """
        Loads the aggregated features into the trainer.

        The tensor is stored as-is (no copy, no device move) and is what
        ``train`` and ``local_test`` later pass to the model.

        Parameters
        ----------
        feature_aggregation : torch.Tensor
            The aggregated features to be loaded.
        """
        self.feature_aggregation = feature_aggregation

    def relabel_adj(self) -> None:
        """
        Relabels the adjacency matrix based on the communication node index.

        ``self.adj`` is replaced by the edge index returned from
        ``torch_geometric.utils.k_hop_subgraph`` called with
        ``communicate_node_index``, zero hops and ``relabel_nodes=True``.
        """
        # Only the second return value (the edge index) is kept.
        _, self.adj, __, ___ = torch_geometric.utils.k_hop_subgraph(
            self.communicate_node_index, 0, self.adj, relabel_nodes=True
        )

    def train(self, current_global_round: int) -> None:
        """
        Performs local training for a specified number of iterations. This method
        updates the model using the loaded feature aggregation and the adjacency matrix.

        After every local step the model is also evaluated with
        ``local_test``; the train and test loss/accuracy of each step are
        appended to the history lists.

        Parameters
        ----------
        current_global_round : int
            The current global training round. Not used by this method.
        """
        # clean cache
        torch.cuda.empty_cache()

        for iteration in range(self.local_step):
            self.model.train()

            # ``train`` here is the module-level training function, not this method.
            loss_train, acc_train = train(
                iteration,
                self.model,
                self.optimizer,
                self.feature_aggregation,
                self.adj,
                self.train_labels,
                self.idx_train,
            )
            self.train_losses.append(loss_train)
            self.train_accs.append(acc_train)

            loss_test, acc_test = self.local_test()
            self.test_losses.append(loss_test)
            self.test_accs.append(acc_test)

    def local_test(self) -> list:
        """
        Evaluates the model on the local test dataset.

        Returns
        -------
        (list) : list
            A list containing the test loss and accuracy [local_test_loss, local_test_acc].
        """
        local_test_loss, local_test_acc = test(
            self.model,
            self.feature_aggregation,
            self.adj,
            self.test_labels,
            self.idx_test,
        )
        return [local_test_loss, local_test_acc]

    def get_params(self) -> tuple:
        """
        Retrieves the current parameters of the model.

        Gradients are cleared (set to ``None``) before the parameters are
        returned. The returned objects are the model's live parameters, not
        copies.

        Returns
        -------
        (tuple) : tuple
            A tuple containing the current parameters of the model.
        """
        self.optimizer.zero_grad(set_to_none=True)
        return tuple(self.model.parameters())

    def get_all_loss_accuray(self) -> list:
        """
        Returns all recorded training and testing losses and accuracies.

        Returns
        -------
        (list) : list
            A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.
        """
        return [
            np.array(self.train_losses),
            np.array(self.train_accs),
            np.array(self.test_losses),
            np.array(self.test_accs),
        ]

    def get_rank(self) -> int:
        """
        Returns the rank (client ID) of the trainer.

        Returns
        -------
        (int) : int
            The rank (client ID) of this trainer instance.
        """
        return self.rank