import logging
import random
import time
from io import BytesIO
from typing import Any, Dict, List, Union
logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
import numpy as np
import ray
import tenseal as ts
import torch
import torch.nn.functional as F
import torch_geometric
from huggingface_hub import hf_hub_download
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torchmetrics.functional.retrieval import retrieval_auroc
from torchmetrics.retrieval import RetrievalHitRate
from fedgraph.gnn_models import (
GCN,
GIN,
GNN_LP,
AggreGCN,
AggreGCN_Arxiv,
GCN_arxiv,
SAGE_products,
)
from fedgraph.train_func import test, train
from fedgraph.utils_lp import (
check_data_files_existance,
get_data,
get_data_loaders_per_time_step,
get_global_user_item_mapping,
)
from fedgraph.utils_nc import get_1hop_feature_sum
def load_trainer_data_from_hugging_face(trainer_id: int, args: Any) -> tuple:
"""Download and load the pre-partitioned tensors for one trainer from the Hugging Face Hub."""
repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}"
def download_and_load_tensor(file_name):
file_path = hf_hub_download(
repo_id=repo_name, repo_type="dataset", filename=file_name
)
with open(file_path, "rb") as f:
buffer = BytesIO(f.read())
tensor = torch.load(buffer)
print(f"Loaded {file_name}, size: {tensor.size()}")
return tensor
print(f"Loading client data {trainer_id}")
local_node_index = download_and_load_tensor("local_node_index.pt")
communicate_node_global_index = download_and_load_tensor(
"communicate_node_index.pt"
)
global_edge_index_client = download_and_load_tensor("adj.pt")
train_labels = download_and_load_tensor("train_labels.pt")
test_labels = download_and_load_tensor("test_labels.pt")
features = download_and_load_tensor("features.pt")
in_com_train_node_local_indexes = download_and_load_tensor("idx_train.pt")
in_com_test_node_local_indexes = download_and_load_tensor("idx_test.pt")
return (
local_node_index,
communicate_node_global_index,
global_edge_index_client,
train_labels,
test_labels,
features,
in_com_train_node_local_indexes,
in_com_test_node_local_indexes,
)
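# Usage sketch (illustrative; `args` is assumed to carry the `dataset`, `n_trainer`,
# `num_hops`, and `iid_beta` fields used in the repo name above):
#
#   (local_node_index, communicate_node_global_index, adj, train_labels,
#    test_labels, features, idx_train, idx_test) = \
#       load_trainer_data_from_hugging_face(trainer_id=0, args=args)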
class Trainer_General:
"""
A general trainer class for training GCN in a federated learning setup, which includes functionalities
required for training GCN models on a subset of a distributed dataset, handling local training and testing,
parameter updates, and feature aggregation.
Parameters
----------
rank : int
Unique identifier for the training instance (typically representing a trainer in federated learning).
local_node_index : torch.Tensor
Indices of nodes local to this trainer.
communicate_node_index : torch.Tensor
Indices of nodes that participate in communication during training.
adj : torch.Tensor
The adjacency matrix representing the graph structure.
train_labels : torch.Tensor
Labels of the training data.
test_labels : torch.Tensor
Labels of the testing data.
features : torch.Tensor
Node features for the entire graph.
idx_train : torch.Tensor
Indices of training nodes.
idx_test : torch.Tensor
Indices of test nodes.
args_hidden : int
Number of hidden units in the GCN model.
global_node_num : int
Total number of nodes in the global graph.
class_num : int
Number of classes for classification.
device : torch.device
The device (CPU or GPU) on which the model will be trained.
args : Any
Additional arguments required for model initialization and training.
"""
def __init__(
self,
rank: int,
# local_node_index: torch.Tensor,
# communicate_node_index: torch.Tensor,
# adj: torch.Tensor,
# train_labels: torch.Tensor,
# test_labels: torch.Tensor,
# features: torch.Tensor,
# idx_train: torch.Tensor,
# idx_test: torch.Tensor,
args_hidden: int,
# global_node_num: int,
# class_num: int,
device: torch.device,
args: Any,
local_node_index: torch.Tensor = None,
communicate_node_index: torch.Tensor = None,
adj: torch.Tensor = None,
train_labels: torch.Tensor = None,
test_labels: torch.Tensor = None,
features: torch.Tensor = None,
idx_train: torch.Tensor = None,
idx_test: torch.Tensor = None,
):
# from gnn_models import GCN_Graph_Classification
torch.manual_seed(rank)
if (
local_node_index is None
or communicate_node_index is None
or adj is None
or train_labels is None
or test_labels is None
or features is None
or idx_train is None
or idx_test is None
):
(
local_node_index,
communicate_node_index,
adj,
train_labels,
test_labels,
features,
idx_train,
idx_test,
) = load_trainer_data_from_hugging_face(rank, args)
self.rank = rank # rank = trainer ID
self.device = device
self.criterion = torch.nn.CrossEntropyLoss()
self.train_losses: list = []
self.train_accs: list = []
self.test_losses: list = []
self.test_accs: list = []
self.local_node_index = local_node_index.to(device)
self.communicate_node_index = communicate_node_index.to(device)
self.adj = adj.to(device)
self.train_labels = train_labels.to(device)
self.test_labels = test_labels.to(device)
self.features = features.to(device)
self.idx_train = idx_train.to(device)
self.idx_test = idx_test.to(device)
self.local_step = args.local_step
self.args_hidden = args_hidden
# self.global_node_num = global_node_num
# self.class_num = class_num
self.args = args
self.model = None
self.feature_aggregation = None
if self.args.method == "FedAvg":
# print("Loading feature as the feature aggregation for fedavg method")
self.feature_aggregation = self.features
def get_info(self):
# assert self.train_labels.numel() > 0, "train_labels is empty"
# assert self.test_labels.numel() > 0, "test_labels is empty"
return {
"features_num": len(self.features),
"label_num": max(
self.train_labels.max().item(), self.test_labels.max().item()
)
+ 1,
"feature_shape": self.features.shape[1],
"len_in_com_train_node_local_indexes": len(self.idx_train),
"len_in_com_test_node_local_indexes": len(self.idx_test),
"communicate_node_global_index": self.communicate_node_index,
}
def init_model(self, global_node_num, class_num):
self.global_node_num = global_node_num
self.class_num = class_num
self.feature_shape = None
self.scale_factor = 1e3
self.param_history = []
# seems that new trainer process will not inherit sys.path from parent, need to reimport!
if self.args.num_hops >= 1:
if self.args.dataset == "ogbn-arxiv":
print("running AggreGCN_Arxiv")
self.model = AggreGCN_Arxiv(
nfeat=self.features.shape[1],
nhid=self.args_hidden,
nclass=class_num,
dropout=0.5,
NumLayers=self.args.num_layers,
).to(self.device)
else:
self.model = AggreGCN(
nfeat=self.features.shape[1],
nhid=self.args_hidden,
nclass=class_num,
dropout=0.5,
NumLayers=self.args.num_layers,
).to(self.device)
else:
if "ogbn" in self.args.dataset: # all ogbn large datasets
print("Running GCN_arxiv")
self.model = GCN_arxiv(
nfeat=self.features.shape[1],
nhid=self.args_hidden,
nclass=class_num,
dropout=0.5,
NumLayers=self.args.num_layers,
).to(self.device)
elif self.args.dataset == "ogbn-products": # ogbn not coming here
self.model = SAGE_products(
nfeat=self.features.shape[1],
nhid=self.args_hidden,
nclass=class_num,
dropout=0.5,
NumLayers=self.args.num_layers,
).to(self.device)
else: # small datasets
self.model = GCN(
nfeat=self.features.shape[1],
nhid=self.args_hidden,
nclass=class_num,
dropout=0.5,
NumLayers=self.args.num_layers,
).to(self.device)
self.optimizer = torch.optim.SGD(
self.model.parameters(), lr=self.args.learning_rate, weight_decay=5e-4
)
@torch.no_grad()
def update_params(self, params: tuple, current_global_epoch: int) -> None:
"""
Updates the model parameters with global parameters received from the server.
Parameters
----------
params : tuple
A tuple containing the global parameters from the server.
current_global_epoch : int
The current global epoch number.
"""
# load global parameter from global server
self.model.to("cpu")
for (
p,
mp,
) in zip(params, self.model.parameters()):
mp.data = p
self.model.to(self.device)
def verify_param_ranges(self, params, stage="pre-encryption"):
"""Verify parameter ranges and print statistics"""
stats = []
for i, p in enumerate(params):
if isinstance(p, torch.Tensor):
p = p.detach().cpu()
stats.append(
{
"layer": i,
"min": float(p.min()),
"max": float(p.max()),
"mean": float(p.mean()),
"std": float(p.std()),
}
)
print(f"{stage} Layer {i} stats:")
print(f"Range: [{stats[-1]['min']:.6f}, {stats[-1]['max']:.6f}]")
print(f"Mean: {stats[-1]['mean']:.6f}")
print(f"Std: {stats[-1]['std']:.6f}")
return stats
def get_local_feature_sum(self) -> torch.Tensor:
"""
Computes the sum of features of all 1-hop neighbors for each node.
Returns
-------
one_hop_neighbor_feature_sum : torch.Tensor
The sum of features of 1-hop neighbors for each node.
"""
# Create a large matrix with known local node features
new_feature_for_trainer = torch.zeros(
self.global_node_num, self.features.shape[1]
).to(self.device)
new_feature_for_trainer[self.local_node_index] = self.features
# Sum of features of all 1-hop nodes for each node
one_hop_neighbor_feature_sum = get_1hop_feature_sum(
new_feature_for_trainer, self.adj, self.device
)
if self.args.use_encryption:
print(
f"Trainer {self.rank} - Original feature sum (first 10 and last 10 elements): "
f"{one_hop_neighbor_feature_sum.flatten()[:10].tolist()} ... {one_hop_neighbor_feature_sum.flatten()[-10:].tolist()}"
)
return one_hop_neighbor_feature_sum
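# Pre-training communication sketch for the plaintext path (illustrative; `server`
# stands for a hypothetical coordinator that element-wise sums the per-trainer
# feature sums and sends each trainer its slice):
#
#   local_sum = trainer.get_local_feature_sum()            # (global_node_num, F)
#   aggregated = server.sum_over_trainers(local_sum, ...)  # hypothetical call
#   trainer.load_feature_aggregation(aggregated[trainer.communicate_node_index])
#   trainer.relabel_adj()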
def get_local_feature_sum_og(self) -> tuple:
"""
Computes the sum of features of all 1-hop neighbors for each node; used for the plaintext version.
Returns
-------
(one_hop_neighbor_feature_sum, computation_time, data_size) : tuple
The sum of features of 1-hop neighbors for each node, the time spent computing it, and its size in bytes.
"""
computation_start = time.time()
new_feature_for_trainer = torch.zeros(
self.global_node_num, self.features.shape[1]
).to(self.device)
new_feature_for_trainer[self.local_node_index] = self.features
one_hop_neighbor_feature_sum = get_1hop_feature_sum(
new_feature_for_trainer, self.adj, self.device
)
computation_time = time.time() - computation_start
data_size = (
one_hop_neighbor_feature_sum.element_size()
* one_hop_neighbor_feature_sum.nelement()
)
print(f"Trainer {self.rank} - Computation time: {computation_time:.4f} seconds")
print(f"Trainer {self.rank} - Data size: {data_size / 1024:.2f} KB")
print(f"Trainer {self.rank} - Feature sum statistics:")
print(f"Shape: {one_hop_neighbor_feature_sum.shape}")
print(f"Mean: {one_hop_neighbor_feature_sum.mean().item():.6f}")
print(f"Std: {one_hop_neighbor_feature_sum.std().item():.6f}")
print(f"Min: {one_hop_neighbor_feature_sum.min().item():.6f}")
print(f"Max: {one_hop_neighbor_feature_sum.max().item():.6f}")
print(f"Non-zeros: {(one_hop_neighbor_feature_sum != 0).sum().item()}")
return one_hop_neighbor_feature_sum, computation_time, data_size
def load_feature_aggregation(self, feature_aggregation: torch.Tensor) -> None:
"""
Loads the aggregated features into the trainer. Used for the plaintext version.
Parameters
----------
feature_aggregation : torch.Tensor
The aggregated features to be loaded.
"""
# load_start = time.time()
self.feature_aggregation = feature_aggregation.float()
# load_time = time.time() - load_start
# data_size = (
# self.feature_aggregation.element_size()
# * self.feature_aggregation.nelement()
# )
# print(f"Trainer {self.rank} - Load time: {load_time:.4f} seconds")
# print(f"Trainer {self.rank} - Data size: {data_size / 1024:.2f} KB")
# return load_time
def encrypt_feature_sum(self, feature_sum=None):
"""Encrypt the flattened local 1-hop feature sum with CKKS (no extra scaling applied)."""
if feature_sum is None:
feature_sum = self.get_local_feature_sum()
flattened_sum = feature_sum.flatten()
enc_sum = ts.ckks_vector(self.he_context, flattened_sum.tolist()).serialize()
return enc_sum, feature_sum.shape
def decrypt_feature_sum(self, encrypted_sum, shape):
decrypted_rows = [
ts.ckks_vector_from(self.he_context, enc_row).decrypt()
for enc_row in encrypted_sum
]
decrypted_array = np.array(decrypted_rows)
return torch.from_numpy(decrypted_array).float().reshape(shape)
def get_encrypted_local_feature_sum(self):
# Same feature sum computation as original
new_feature_for_trainer = torch.zeros(
self.global_node_num, self.features.shape[1]
).to(self.device)
new_feature_for_trainer[self.local_node_index] = self.features
feature_sum = get_1hop_feature_sum(
new_feature_for_trainer, self.adj, self.device
)
# Encrypt the feature sum
encryption_start = time.time()
flattened = feature_sum.flatten().tolist()
encrypted = ts.ckks_vector(self.he_context, flattened).serialize()
encryption_time = time.time() - encryption_start
return encrypted, feature_sum.shape, encryption_time
def load_encrypted_feature_aggregation(self, encrypted_data):
encrypted_sum, shape = encrypted_data
decryption_start = time.time()
decrypted = ts.ckks_vector_from(self.he_context, encrypted_sum).decrypt()
# reshape and store
self.feature_aggregation = torch.tensor(decrypted).reshape(shape)[
self.communicate_node_index
]
return time.time() - decryption_start
def get_encrypted_params(self):
"""Get encrypted parameters with proper scaling"""
params_list = []
metadata = []
for param in self.model.parameters():
param_data = param.cpu().detach()
# scale
max_abs_val = torch.max(torch.abs(param_data))
scale = 1e3 if max_abs_val < 1e-3 else 1e2
scaled_params = (param_data * scale).flatten().tolist()
encrypted = ts.ckks_vector(self.he_context, scaled_params).serialize()
params_list.append(encrypted)
metadata.append({"shape": param_data.shape, "scale": scale})
return params_list, metadata
def load_encrypted_params(self, encrypted_data: tuple, current_global_epoch: int):
"""Load encrypted parameters with rescaling"""
params_list, metadata = encrypted_data
self.model.to("cpu")
# load each layer's parameters
for param, enc_param, meta in zip(
self.model.parameters(), params_list, metadata
):
decrypted = ts.ckks_vector_from(self.he_context, enc_param).decrypt()
param_data = torch.tensor(decrypted).reshape(meta["shape"])
param_data = param_data / meta["scale"] # Reverse scaling
param.data.copy_(param_data)
self.model.to(self.device)
return True
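# Round-trip sketch for the scale-then-encrypt scheme above (illustrative;
# `self.he_context` is assumed to be a TenSEAL CKKS context distributed by the
# coordinating server, e.g. created with
#   ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192,
#                    coeff_mod_bit_sizes=[60, 40, 40, 60])
#   ctx.global_scale = 2 ** 40
#   ctx.generate_galois_keys()
# ):
#
#   enc_params, metadata = trainer.get_encrypted_params()
#   # ... the server combines the ciphertexts from all trainers homomorphically ...
#   trainer.load_encrypted_params((enc_params, metadata), current_global_epoch=0)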
def use_fedavg_feature(self) -> None:
"""Use the trainer's own features as the feature aggregation (FedAvg)."""
self.feature_aggregation = self.features
def relabel_adj(self) -> None:
"""
Relabels the adjacency matrix based on the communication node index.
"""
# print(f"Max value in adj: {self.adj.max()}")
# print(
# f"Max value in communicate_node_index: {self.communicate_node_index.max()}"
# )
# distinct_values = torch.unique(self.adj.flatten())
# print(f"Number of distinct values in adj: {distinct_values.numel()}")
# print(f"distinct local: {len(self.local_node_index)}")
# print(f"distinct communic: {len(self.communicate_node_index)}")
# time.sleep(30)
_, self.adj, __, ___ = torch_geometric.utils.k_hop_subgraph(
self.communicate_node_index, 0, self.adj, relabel_nodes=True
)
# print(f"Max value in adj: {self.adj.max()}")
# print(
# f"Max value in communicate_node_index: {self.communicate_node_index.max()}"
# )
# distinct_values = torch.unique(self.adj.flatten())
# print(f"Number of distinct values in adj: {distinct_values.numel()}")
# print(f"distinct communic: {len(self.communicate_node_index)}")
def train(self, current_global_round: int) -> None:
"""
Performs local training for a specified number of iterations. This method
updates the model using the loaded feature aggregation and the adjacency matrix.
Parameters
----------
current_global_round : int
The current global training round.
"""
# clean cache
torch.cuda.empty_cache()
assert self.model is not None
self.model.to(self.device)
if self.feature_aggregation is None:
raise ValueError(
"feature_aggregation has not been set. Ensure pre-training communication is completed."
)
self.feature_aggregation = self.feature_aggregation.to(self.device)
if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
# batch preparation
train_mask = torch.zeros(
self.feature_aggregation.size(0), dtype=torch.bool
).to(self.device)
train_mask[self.idx_train] = True
node_labels = torch.full(
(self.feature_aggregation.size(0),), -1, dtype=torch.long
).to(self.device)
mask_indices = train_mask.nonzero(as_tuple=True)[0].to(self.device)
node_labels[train_mask] = self.train_labels[: len(mask_indices)]
data = Data(
x=self.feature_aggregation,
edge_index=self.adj,
train_mask=train_mask,
y=node_labels,
)
for iteration in range(self.local_step):
self.model.train()
if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
# print(f"Training with batch size {self.args.batch_size}")
loader = NeighborLoader(
data,
num_neighbors=[-1] * self.args.num_layers,
batch_size=self.args.batch_size,
input_nodes=self.idx_train,
shuffle=False,
num_workers=0,
)
batch_iter = iter(loader)
batch = next(batch_iter, None)
while batch is not None:
batch_feature_aggregation = batch.x
batch_adj_matrix = batch.edge_index
# print(f"Batch Feature Aggregation (Node Features): {batch_feature_aggregation.size()}")
# print(f"Batch Adjacency Matrix (Edge Index): {batch_adj_matrix}")
# print(f"Training Labels (Filtered by train_mask): {batch.y[batch.train_mask]}")
# print(f"Train Mask: {batch.train_mask}")
loss_train, acc_train = train(
iteration,
self.model,
self.optimizer,
batch_feature_aggregation,
batch_adj_matrix,
batch.y[batch.train_mask],
batch.train_mask,
)
# print(f"acc_train: {acc_train}")
self.train_losses.append(loss_train)
self.train_accs.append(acc_train)
batch = next(batch_iter, None)
else:
# print("Training with full batch")
# print(f"feature_aggregation size: {self.feature_aggregation.size()}")
# print(f"adj shape: {self.adj.size()}")
# print(f"Max value in adj: {self.adj.max()}")
# print(f"Max value in communicate_node_index: {self.communicate_node_index.max()}")
# Assuming class_num is the number of classes
train_labels = self.train_labels
class_num = self.class_num
assert (
train_labels.min() >= 0
), f"train_labels contains negative values: {train_labels.min()}"
assert (
train_labels.max() < class_num
), f"train_labels contains a value out of range: {train_labels.max()} (number of classes: {class_num})"
# time.sleep(30)
loss_train, acc_train = train(
iteration,
self.model,
self.optimizer,
self.feature_aggregation,
self.adj,
self.train_labels,
self.idx_train,
)
self.train_losses.append(loss_train)
self.train_accs.append(acc_train)
# print(f"acc_train: {acc_train}")
loss_test, acc_test = self.local_test()
self.test_losses.append(loss_test)
self.test_accs.append(acc_test)
# print(f"current round: {current_global_round}, acc_test: {acc_test}")
def local_test(self) -> list:
"""
Evaluates the model on the local test dataset.
Returns
-------
(list) : list
A list containing the test loss and accuracy [local_test_loss, local_test_acc].
"""
local_test_loss, local_test_acc = test(
self.model,
self.feature_aggregation,
self.adj,
self.test_labels,
self.idx_test,
)
return [local_test_loss, local_test_acc]
def get_params(self) -> tuple:
"""
Retrieves the current parameters of the model.
Returns
-------
(tuple) : tuple
A tuple containing the current parameters of the model.
"""
self.optimizer.zero_grad(set_to_none=True)
return tuple(self.model.parameters())
def get_all_loss_accuray(self) -> list:
"""
Returns all recorded training and testing losses and accuracies.
Returns
-------
(list) : list
A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.
"""
return [
np.array(self.train_losses),
np.array(self.train_accs),
np.array(self.test_losses),
np.array(self.test_accs),
]
def get_rank(self) -> int:
"""
Returns the rank (trainer ID) of the trainer.
Returns
-------
(int) : int
The rank (trainer ID) of this trainer instance.
"""
return self.rank
class Trainer_GC:
"""
A trainer class specialized for graph classification tasks, which includes the functionality required
for training GIN models on a subset of a distributed dataset, handling local training and testing,
parameter updates, and feature aggregation.
Parameters
----------
model: object
The model to be trained, which is based on the GIN model.
trainer_id: int
The ID of the trainer.
trainer_name: str
The name of the trainer.
train_size: int
The size of the training dataset.
dataLoader: dict
The dataloaders for training, validation, and testing.
optimizer: object
The optimizer for training.
args: Any
The arguments for the training.
Attributes
----------
model: object
The model to be trained, which is based on the GIN model.
id: int
The ID of the trainer.
name: str
The name of the trainer.
train_size: int
The size of the training dataset.
dataloader: dict
The dataloaders for training, validation, and testing.
optimizer: object
The optimizer for training.
args: object
The arguments for the training.
W: dict
The weights of the model.
dW: dict
The gradients of the model.
W_old: dict
The cached weights of the model.
gconv_names: list
The names of the gconv layers.
train_stats: Any
The training statistics of the model.
weights_norm: float
The norm of the weights of the model.
grads_norm: float
The norm of the gradients of the model.
conv_grads_norm: float
The norm of the gradients of the gconv layers.
conv_weights_norm: float
The norm of the weights of the gconv layers.
conv_dWs_norm: float
The norm of the weight updates (dW) of the gconv layers.
"""
def __init__(
self,
model: Any,
trainer_id: int,
trainer_name: str,
train_size: int,
dataloader: dict,
optimizer: object,
args: Any,
) -> None:
self.model = model.to(args.device)
self.id = trainer_id
self.name = trainer_name
self.train_size = train_size
self.dataloader = dataloader
self.optimizer = optimizer
self.args = args
self.W = {key: value for key, value in self.model.named_parameters()}
self.dW = {
key: torch.zeros_like(value) for key, value in self.model.named_parameters()
}
self.W_old = {
key: value.data.clone() for key, value in self.model.named_parameters()
}
self.gconv_names: Any = None
self.train_stats: dict[str, list[Any]] = {
"trainingAccs": [],
"valAccs": [],
"trainingLosses": [],
"valLosses": [],
"testAccs": [],
"testLosses": [],
}
self.weights_norm = 0.0
self.grads_norm = 0.0
self.conv_grads_norm = 0.0
self.conv_weights_norm = 0.0
self.conv_dWs_norm = 0.0
########### Public functions ###########
def update_params(self, server_params: Any) -> None:
"""
Update the model parameters by downloading the global model weights from the server.
Parameters
----------
server_params: Any
The global model weights received from the server (a mapping from parameter name to tensor).
"""
self.gconv_names = server_params.keys() # gconv layers
for k in server_params:
self.W[k].data = server_params[k].data.clone()
def reset_params(self) -> None:
"""
Reset the weights of the model to the cached weights.
The implementation is copying the cached weights (W_old) to the model weights (W).
"""
self.__copy_weights(target=self.W, source=self.W_old, keys=self.gconv_names)
def cache_weights(self) -> None:
"""
Cache the weights of the model.
The implementation is copying the model weights (W) to the cached weights (W_old).
"""
for name in self.W.keys():
self.W_old[name].data = self.W[name].data.clone()
def compute_update_norm(self, keys: dict) -> float:
"""
Compute the norm of the weight updates (i.e., dW) for this trainer over the given keys.
"""
dW = {}
for k in keys:
dW[k] = self.dW[k]
curr_dW = torch.norm(
torch.cat([value.flatten() for value in dW.values()])
).item()
return curr_dW
def compute_mean_norm(self, total_size: int, keys: dict) -> torch.Tensor:
"""
Compute this trainer's contribution to the mean update (dW weighted by train_size / total_size), flattened into a single tensor.
Returns
-------
curr_dW: torch.Tensor
The flattened, size-weighted parameter updates for the given keys.
"""
dW = {}
for k in keys:
dW[k] = self.dW[k] * self.train_size / total_size
curr_dW = torch.cat([value.flatten() for value in dW.values()])
return curr_dW
def set_stats_norms(self, train_stats: Any, is_gcfl: bool = False) -> None:
"""
Set the norms of the weights and gradients of the model, as well as the statistics of the training.
Parameters
----------
train_stats: dict
The training statistics of the model.
is_gcfl: bool, optional
Whether the training is for GCFL. The default is False.
"""
self.train_stats = train_stats
self.weights_norm = torch.norm(self.__flatten(self.W)).item()
if self.gconv_names is not None:
weights_conv = {key: self.W[key] for key in self.gconv_names}
self.conv_weights_norm = torch.norm(self.__flatten(weights_conv)).item()
grads_conv = {key: self.W[key].grad for key in self.gconv_names}
self.conv_grads_norm = torch.norm(self.__flatten(grads_conv)).item()
grads = {key: value.grad for key, value in self.W.items()}
self.grads_norm = torch.norm(self.__flatten(grads)).item()
if is_gcfl and self.gconv_names is not None:
dWs_conv = {key: self.dW[key] for key in self.gconv_names}
self.conv_dWs_norm = torch.norm(self.__flatten(dWs_conv)).item()
def local_train(
self, local_epoch: int, train_option: str = "basic", mu: float = 1
) -> None:
"""
This function is an interface of the trainer class to train the model locally.
It will call the train function specified for the training option, based on the args provided.
Parameters
----------
local_epoch: int
The number of local epochs
train_option: str, optional
The training option. The possible values are 'basic', 'prox', and 'gcfl'. The default is 'basic'.
'basic' - self-train and FedAvg
'prox' - FedProx that includes the proximal term
'gcfl' - GCFL, GCFL+ and GCFL+dWs
mu: float, optional
The proximal term. The default is 1.
"""
assert train_option in ["basic", "prox", "gcfl"], "Invalid training option."
if train_option == "gcfl":
self.__copy_weights(target=self.W_old, source=self.W, keys=self.gconv_names)
if train_option in ["basic", "prox"]:
train_stats = self.__train(
model=self.model,
dataloaders=self.dataloader,
optimizer=self.optimizer,
local_epoch=local_epoch,
device=self.args.device,
)
elif train_option == "gcfl":
train_stats = self.__train(
model=self.model,
dataloaders=self.dataloader,
optimizer=self.optimizer,
local_epoch=local_epoch,
device=self.args.device,
prox=True,
gconv_names=self.gconv_names,
Ws=self.W,
Wt=self.W_old,
mu=mu,
)
if train_option == "gcfl":
self.__subtract_weights(
target=self.dW, minuend=self.W, subtrahend=self.W_old
)
self.set_stats_norms(train_stats)
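# e.g. (sketch) one GCFL-style round with proximal coefficient mu; after this call,
# trainer.get_dW() holds the weight updates W - W_old:
#
#   trainer.local_train(local_epoch=1, train_option="gcfl", mu=0.01)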
def local_test(self, test_option: str = "basic", mu: float = 1) -> tuple:
"""
Final test of the model on the test dataset based on the test option.
Parameters
----------
test_option: str, optional
The test option. The possible values are 'basic' and 'prox'. The default is 'basic'.
'basic' - self-train, FedAvg, GCFL, GCFL+ and GCFL+dWs
'prox' - FedProx that includes the proximal term
mu: float, optional
The proximal term. The default is 1.
Returns
-------
(test_loss, test_acc, trainer_name, trainingAccs, valAccs): tuple(float, float, string, float, float)
The average loss and accuracy, trainer's name, trainer.train_stats["trainingAccs"][-1], trainer.train_stats["valAccs"][-1]
"""
assert test_option in ["basic", "prox"], "Invalid test option."
if test_option == "basic":
return self.__eval(
model=self.model,
test_loader=self.dataloader["test"],
device=self.args.device,
)
elif test_option == "prox":
return self.__eval(
model=self.model,
test_loader=self.dataloader["test"],
device=self.args.device,
prox=True,
gconv_names=self.gconv_names,
mu=mu,
Wt=self.W_old,
)
else:
raise ValueError("Invalid test option.")
def get_train_size(self) -> int:
return self.train_size
def get_weights(self, ks: Any) -> dict[str, Any]:
data: dict[str, Any] = {}
W = {}
dW = {}
for k in ks:
W[k], dW[k] = self.W[k], self.dW[k]
data["W"] = W
data["dW"] = dW
data["train_size"] = self.train_size
return data
def get_total_weight(self) -> Any:
return self.W
def get_dW(self) -> Any:
return self.dW
def get_name(self) -> str:
return self.name
def get_id(self) -> Any:
return self.id
def get_conv_grads_norm(self) -> Any:
return self.conv_grads_norm
def get_conv_dWs_norm(self) -> Any:
return self.conv_dWs_norm
########### Private functions ###########
def __train(
self,
model: Any,
dataloaders: dict,
optimizer: Any,
local_epoch: int,
device: str,
prox: bool = False,
gconv_names: Any = None,
Ws: Any = None,
Wt: Any = None,
mu: float = 0,
) -> dict:
"""
Train the model on the local dataset.
Parameters
----------
model: object
The model to be trained
dataloaders: dict
The dataloaders for training, validation, and testing
optimizer: Any
The optimizer for training
local_epoch: int
The number of local epochs
device: str
The device to run the training
prox: bool, optional
Whether to add the proximal term. The default is False.
gconv_names: Any, optional
The names of the gconv layers. The default is None.
Ws: Any, optional
The weights of the model. The default is None.
Wt: Any, optional
The target weights. The default is None.
mu: float, optional
The proximal term. The default is 0.
Returns
-------
(results): dict
The training statistics
Note
----
If prox is True, the function will add the proximal term to the loss function.
Make sure to provide the required arguments `gconv_names`, `Ws`, `Wt`, and `mu` for the proximal term.
"""
if prox:
assert (
(gconv_names is not None)
and (Ws is not None)
and (Wt is not None)
and (mu != 0)
), "Please provide the required arguments for the proximal term."
losses_train, accs_train, losses_val, accs_val, losses_test, accs_test = (
[],
[],
[],
[],
[],
[],
)
if prox:
convGradsNorm = []
train_loader, val_loader, test_loader = (
dataloaders["train"],
dataloaders["val"],
dataloaders["test"],
)
for _ in range(local_epoch):
model.train()
loss_train, acc_train, num_graphs = 0.0, 0.0, 0
for _, batch in enumerate(train_loader):
batch.to(device)
optimizer.zero_grad()
pred = model(batch)
label = batch.y
loss = model.loss(pred, label)
loss += (
mu / 2.0 * self.__prox_term(model, gconv_names, Wt) if prox else 0.0
) # add the proximal term if required
loss.backward()
optimizer.step()
loss_train += loss.item() * batch.num_graphs
acc_train += pred.max(dim=1)[1].eq(label).sum().item()
num_graphs += batch.num_graphs
loss_train /= num_graphs # get the average loss per graph
acc_train /= num_graphs # get the average accuracy per graph
loss_val, acc_val, _, _, _ = self.__eval(model, val_loader, device)
loss_test, acc_test, _, _, _ = self.__eval(model, test_loader, device)
losses_train.append(loss_train)
accs_train.append(acc_train)
losses_val.append(loss_val)
accs_val.append(acc_val)
losses_test.append(loss_test)
accs_test.append(acc_test)
if prox:
convGradsNorm.append(self.__calc_grads_norm(gconv_names, Ws))
# record the losses and accuracies for each epoch
res_dict = {
"trainingLosses": losses_train,
"trainingAccs": accs_train,
"valLosses": losses_val,
"valAccs": accs_val,
"testLosses": losses_test,
"testAccs": accs_test,
}
if prox:
res_dict["convGradsNorm"] = convGradsNorm
return res_dict
def __eval(
self,
model: GIN,
test_loader: Any,
device: str,
prox: bool = False,
gconv_names: Any = None,
mu: float = 0,
Wt: Any = None,
) -> tuple:
"""
Validate and test the model on the local dataset.
Parameters
----------
model: GIN
The model to be tested
test_loader: Any
The dataloader for testing
device: str
The device to run the testing
prox: bool, optional
Whether to add the proximal term. The default is False.
gconv_names: Any, optional
The names of the gconv layers. The default is None.
mu: float, optional
The proximal term. The default is 0.
Wt: Any, optional
The target weights. The default is None.
Returns
-------
(test_loss, test_acc, trainer_name, trainingAccs, valAccs): tuple(float, float, string, float, float)
The average loss and accuracy, trainer's name, trainer.train_stats["trainingAccs"][-1], trainer.train_stats["valAccs"][-1]
Note
----
If prox is True, the function will add the proximal term to the loss function.
Make sure to provide the required arguments `gconv_names`, `Ws`, `Wt`, and `mu` for the proximal term.
"""
if prox:
assert (
(gconv_names is not None) and (mu is not None) and (Wt is not None)
), "Please provide the required arguments for the proximal term."
model.eval()
total_loss, total_acc, num_graphs = 0.0, 0.0, 0
for batch in test_loader:
batch.to(device)
with torch.no_grad():
pred = model(batch)
label = batch.y
loss = model.loss(pred, label)
loss += (
mu / 2.0 * self.__prox_term(model, gconv_names, Wt) if prox else 0.0
)
total_loss += loss.item() * batch.num_graphs
total_acc += pred.max(dim=1)[1].eq(label).sum().item()
num_graphs += batch.num_graphs
current_training_acc = -1
current_val_acc = -1
if self.train_stats["trainingAccs"]:
current_training_acc = self.train_stats["trainingAccs"][-1]
if self.train_stats["valAccs"]:
current_val_acc = self.train_stats["valAccs"][-1]
return (
total_loss / num_graphs,
total_acc / num_graphs,
self.name,
current_training_acc, # if no data then return -1 for 1st train round
current_val_acc, # if no data then return -1 for 1st train round
)
def __prox_term(self, model: Any, gconv_names: Any, Wt: Any) -> torch.Tensor:
"""
Compute the proximal term.
Parameters
----------
model: Any
The model to be trained
gconv_names: Any
The names of the gconv layers
Wt: Any
The target weights
Returns
-------
prox: torch.tensor
The proximal term
"""
prox = torch.tensor(0.0, requires_grad=True)
for name, param in model.named_parameters():
# only add the prox term for sharing layers (gConv)
if name in gconv_names:
prox = prox + torch.norm(param - Wt[name]).pow(
2
) # force the weights to be close to the old weights
return prox
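# i.e. prox = sum over shared gconv layers l of ||W_l - Wt_l||^2, added to the
# training loss as (mu / 2) * prox in __train and __eval above.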
def __calc_grads_norm(self, gconv_names: Any, Ws: Any) -> float:
"""
Calculate the norm of the gradients of the gconv layers.
Parameters
----------
gconv_names: Any
The names of the gconv layers
Ws: Any
The weights of the model
Returns
-------
convGradsNorm: float
The norm of the gradients of the gconv layers
"""
grads_conv = {k: Ws[k].grad for k in gconv_names}
convGradsNorm = torch.norm(self.__flatten(grads_conv)).item()
return convGradsNorm
def __copy_weights(
self, target: dict, source: dict, keys: Union[list, None]
) -> None:
"""
Copy the source weights to the target weights.
Parameters
----------
target: dict
The target weights
source: dict
The source weights
keys: list, optional
The keys to be copied. The default is None.
"""
if keys is not None:
for name in keys:
target[name].data = source[name].data.clone()
def __subtract_weights(self, target: dict, minuend: dict, subtrahend: dict) -> None:
"""
Subtract the subtrahend from the minuend and store the result in the target.
Parameters
----------
target: dict
The target weights
minuend: dict
The minuend
subtrahend: dict
The subtrahend
"""
for name in target:
target[name].data = (
minuend[name].data.clone() - subtrahend[name].data.clone()
)
def __flatten(self, w: dict) -> torch.Tensor:
"""
Flatten the gradients of a trainer into a 1D tensor.
Parameters
----------
w: dict
The gradients of a trainer
"""
return torch.cat([v.flatten() for v in w.values()])
def calculate_weighted_weight(self, key: Any) -> torch.Tensor:
"""Return this trainer's weight for ``key`` scaled by its training-set size (for size-weighted averaging)."""
weighted_weight = torch.mul(self.W[key].data, self.train_size)
return weighted_weight
class Trainer_LP:
"""
A trainer class specialized for graph link prediction tasks, which includes the functionality required
for training GNN models on a subset of a distributed dataset, handling local training and testing,
parameter updates, and feature aggregation.
Parameters
----------
client_id : int
The ID of the client.
country_code : str
The country code of the client. Each client is associated with one country code.
user_id_mapping : dict
The mapping of user IDs.
item_id_mapping : dict
The mapping of item IDs.
number_of_users : int
The number of users.
number_of_items : int
The number of items.
meta_data : tuple
The metadata of the dataset.
dataset_path : str
The path to the local dataset files for this client.
hidden_channels : int, optional
The number of hidden channels in the GNN model. The default is 64.
"""
def __init__(
self,
client_id: int,
country_code: str,
user_id_mapping: dict,
item_id_mapping: dict,
number_of_users: int,
number_of_items: int,
meta_data: tuple,
dataset_path: str,
hidden_channels: int = 64,
):
self.client_id = client_id
self.country_code = country_code
print(f"checking code and file path: {country_code},{dataset_path}")
file_path = dataset_path
country_codes: List[str] = [self.country_code]
check_data_files_existance(country_codes, file_path)
# global user_id and item_id
self.data = get_data(
self.country_code, user_id_mapping, item_id_mapping, file_path
)
self.model = GNN_LP(
number_of_users, number_of_items, meta_data, hidden_channels
)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: '{self.device}'")
self.model = self.model.to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
def get_train_test_data_at_current_time_step(
self,
start_time_float_format: float,
end_time_float_format: float,
use_buffer: bool = False,
buffer_size: int = 10,
) -> None:
"""
Get the training and testing data at the current time step.
Parameters
----------
start_time_float_format : float
The start time in float format.
end_time_float_format : float
The end time in float format.
use_buffer : bool, optional
Whether to use the buffer. The default is False.
buffer_size : int, optional
The size of the buffer. The default is 10.
"""
print("loading buffer_train_data_list") if use_buffer else print(
"loading train_data and test_data"
)
load_res = get_data_loaders_per_time_step(
self.data,
start_time_float_format,
end_time_float_format,
use_buffer,
buffer_size,
)
if use_buffer:
(
self.global_train_data,
self.test_data,
self.buffer_train_data_list,
) = load_res
else:
self.train_data, self.test_data = load_res
def train(
self, client_id: int, local_updates: int, use_buffer: bool = False
) -> tuple:
"""
Perform local training for a specified number of iterations.
Parameters
----------
client_id : int
The ID of the client, echoed back in the return value.
local_updates : int
The number of local updates.
use_buffer : bool, optional
Whether to use the buffer. The default is False.
Returns
-------
(client_id, loss, train_finish_times) : tuple
[0] The ID of the client
[1] The final loss of the model
[2] The time taken for each local update
"""
train_finish_times = []
if use_buffer:
probabilities = [1 / len(self.buffer_train_data_list)] * len(
self.buffer_train_data_list
)
for i in range(local_updates):
if use_buffer:
train_data = random.choices(
self.buffer_train_data_list, weights=probabilities, k=1
)[0].to(self.device)
else:
train_data = self.train_data.to(self.device)
start_train_time = time.time()
self.optimizer.zero_grad()
pred = self.model(train_data)
ground_truth = train_data["user", "select", "item"].edge_label
loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
loss.backward()
self.optimizer.step()
train_finish_time = time.time() - start_train_time
train_finish_times.append(train_finish_time)
print(
f"client {self.client_id} local steps {i} loss {loss:.4f} train time {train_finish_time:.4f}"
)
return client_id, loss, train_finish_times
def test(self, clientId: int, use_buffer: bool = False) -> tuple:
"""
Test the model on the test data.
Parameters
----------
clientId : int
The ID of the client, echoed back in the return value.
use_buffer : bool, optional
Whether to use the buffer. The default is False.
Returns
-------
(clientId, auc, hit_rate_at_2, traveled_user_hit_rate_at_2) : tuple
[0] The ID of the client
[1] The AUC score
[2] The hit rate at 2
[3] The hit rate at 2 for traveled users
"""
preds, ground_truths = [], []
self.test_data.to(self.device)
with torch.no_grad():
if not use_buffer:
self.train_data.to(self.device)
preds.append(self.model.pred(self.train_data, self.test_data))
else:
self.global_train_data.to(self.device)
preds.append(self.model.pred(self.global_train_data, self.test_data))
ground_truths.append(self.test_data["user", "select", "item"].edge_label)
pred = torch.cat(preds, dim=0)
ground_truth = torch.cat(ground_truths, dim=0)
auc = retrieval_auroc(pred, ground_truth)
hit_rate_evaluator = RetrievalHitRate(top_k=2)
hit_rate_at_2 = hit_rate_evaluator(
pred,
ground_truth,
indexes=self.test_data["user", "select", "item"].edge_label_index[0],
)
traveled_user_hit_rate_at_2 = hit_rate_evaluator(
pred[self.traveled_user_edge_indices],
ground_truth[self.traveled_user_edge_indices],
indexes=self.test_data["user", "select", "item"].edge_label_index[0][
self.traveled_user_edge_indices
],
)
print(f"Test AUC: {auc:.4f}")
print(f"Test Hit Rate at 2: {hit_rate_at_2:.4f}")
print(f"Test Traveled User Hit Rate at 2: {traveled_user_hit_rate_at_2:.4f}")
return clientId, auc, hit_rate_at_2, traveled_user_hit_rate_at_2
def calculate_traveled_user_edge_indices(self, file_path: str) -> None:
"""
Calculate the indices of the edges of the traveled users.
Parameters
----------
file_path : str
The path to the file containing the traveled users.
"""
with open(file_path, "r") as a:
traveled_users = torch.tensor(
[int(line.split("\t")[0]) for line in a]
) # read the user IDs of the traveled users
mask = torch.isin(
self.test_data["user", "select", "item"].edge_label_index[0], traveled_users
) # mark the indices of the edges of the traveled users as True or False
self.traveled_user_edge_indices = torch.where(mask)[
0
] # get the indices of the edges of the traveled users
def set_model_parameter(
self, model_state_dict: dict, gnn_only: bool = False
) -> None:
"""
Load the model parameters from the global server.
Parameters
----------
model_state_dict : dict
The model parameters to be loaded.
gnn_only : bool, optional
Whether to load only the GNN parameters. The default is False.
"""
if gnn_only:
self.model.gnn.load_state_dict(model_state_dict)
else:
self.model.load_state_dict(model_state_dict)
def get_model_parameter(self, gnn_only: bool = False) -> dict:
"""
Get the model parameters.
Parameters
----------
gnn_only : bool, optional
Whether to get only the GNN parameters. The default is False.
Returns
-------
dict
The model parameters.
"""
if gnn_only:
return self.model.gnn.state_dict()
else:
return self.model.state_dict()
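# Parameter-exchange sketch for Trainer_LP (illustrative; `server_state` stands for
# the state_dict aggregated by a coordinating server, which is not part of this module):
#
#   local_state = trainer.get_model_parameter(gnn_only=True)
#   # ... server averages the GNN state_dicts from all clients into server_state ...
#   trainer.set_model_parameter(server_state, gnn_only=True)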