Source code for fedgraph.server_class

import pickle
import random
import sys
import time
from importlib.resources import files
from typing import Any

import networkx as nx
import numpy as np
import ray
import tenseal as ts
import torch
from dtaidistance import dtw

from fedgraph.gnn_models import (
    GCN,
    GNN_LP,
    AggreGCN,
    AggreGCN_Arxiv,
    GCN_arxiv,
    SAGE_products,
)


class Server:
    """
    This is a server class for federated learning, which is responsible for
    aggregating model parameters from different trainers, updating the central
    model, and then broadcasting the updated model parameters back to the trainers.

    Parameters
    ----------
    feature_dim : int
        The dimensionality of the feature vectors in the dataset.
    args_hidden : int
        The number of hidden units.
    class_num : int
        The number of classes for classification in the dataset.
    device : torch.device
        The device initialized for the server model.
    trainers : list[Trainer_General]
        A list of `Trainer_General` instances representing the trainers.
    args : Any
        Additional arguments required for initializing the server model and other configurations.

    Attributes
    ----------
    model : [AggreGCN, GCN_arxiv, SAGE_products, GCN]
        The central GCN model that is trained in a federated manner.
    trainers : list[Trainer_General]
        The list of trainer instances.
    num_of_trainers : int
        The number of trainers.
    """

    def __init__(
        self,
        feature_dim: int,
        args_hidden: int,
        class_num: int,
        device: torch.device,
        trainers: list,
        args: Any,
    ) -> None:
        self.args = args
        if self.args.num_hops >= 1:
            # Federated methods
            if "ogbn-arxiv" in self.args.dataset:
                print("Running AggreGCN_Arxiv")
                self.model = AggreGCN_Arxiv(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=self.args.num_layers,
                ).to(device)
            else:
                self.model = AggreGCN(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=self.args.num_layers,
                ).to(device)
        else:
            # 0-hop FedAvg methods
            if self.args.dataset == "ogbn-products":
                self.model = SAGE_products(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=self.args.num_layers,
                ).to(device)
            elif "ogbn" in self.args.dataset:
                print("Running GCN_arxiv")
                self.model = GCN_arxiv(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=self.args.num_layers,
                ).to(device)
            else:
                self.model = GCN(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=self.args.num_layers,
                ).to(device)

        self.trainers = trainers
        self.num_of_trainers = len(trainers)
        self.use_encryption = args.use_encryption
        if args.use_encryption:
            file_path = str(files("fedgraph").joinpath("he_context.pkl"))
            with open(file_path, "rb") as f:
                context_bytes = pickle.load(f)
            self.he_context = ts.context_from(context_bytes)
            self.aggregation_stats = []
            print("Loaded HE context with secret key.")
        self.device = device
        # self.broadcast_params(-1)
    @torch.no_grad()
    def zero_params(self) -> None:
        """
        Zeros out the parameters of the central model.
        """
        for p in self.model.parameters():
            p.zero_()
    def prepare_params_for_encryption(self, params):
        """
        Min-max normalize each parameter tensor and scale it to [0, 1000] before
        encryption, returning the scaled tensors together with the per-tensor
        metadata (min and range) needed to invert the transform after decryption.
        """
        processed_params = []
        metadata = []

        for param in params:
            param_min = param.min()
            param_max = param.max()
            param_range = param_max - param_min
            # Handle division by zero for constant tensors
            if param_range == 0:
                normalized = param - param_min
            else:
                normalized = (param - param_min) / param_range
            scaled = normalized * 1000
            processed_params.append(scaled)
            metadata.append({"min": param_min, "range": param_range})

        return processed_params, metadata
    def aggregate_encrypted_feature_sums(self, encrypted_sums):
        """
        Homomorphically add the encrypted feature sums received from the trainers
        and return the serialized result, its shape, and the aggregation time.
        """
        aggregation_start = time.time()
        first_sum = ts.ckks_vector_from(self.he_context, encrypted_sums[0][0])
        shape = encrypted_sums[0][1]
        for enc_sum, _ in encrypted_sums[1:]:
            next_sum = ts.ckks_vector_from(self.he_context, enc_sum)
            first_sum += next_sum
        return (first_sum.serialize(), shape), time.time() - aggregation_start
    def aggregate_encrypted_params(self, encrypted_params_list):
        """
        Homomorphically average the encrypted model parameters collected from all
        trainers, layer by layer, and return the serialized aggregates, the scaling
        metadata, and the aggregation time.
        """
        aggregation_start = time.time()
        first_params, metadata = encrypted_params_list[0]
        n_layers = len(first_params)

        # Aggregate each layer across trainers
        aggregated_params = []
        for layer_idx in range(n_layers):
            agg_layer = ts.ckks_vector_from(
                self.he_context, encrypted_params_list[0][0][layer_idx]
            )
            for trainer_params, _ in encrypted_params_list[1:]:
                next_layer = ts.ckks_vector_from(
                    self.he_context, trainer_params[layer_idx]
                )
                agg_layer += next_layer
            # Average over the number of trainers
            agg_layer *= 1.0 / self.num_of_trainers
            aggregated_params.append(agg_layer.serialize())

        aggregation_time = time.time() - aggregation_start
        return aggregated_params, metadata, aggregation_time
    def get_encrypted_params(self):
        """
        Normalize, scale, and CKKS-encrypt the central model's parameters,
        returning the serialized ciphertexts and the scaling metadata.
        """
        params = [p.data.cpu().detach() for p in self.model.parameters()]
        # Normalize and scale before encryption
        processed_params, metadata = self.prepare_params_for_encryption(params)
        encrypted_params = []
        for param in processed_params:
            param_list = param.flatten().tolist()
            encrypted = ts.ckks_vector(self.he_context, param_list).serialize()
            encrypted_params.append(encrypted)
        return encrypted_params, metadata
    @torch.no_grad()
    def train(
        self,
        current_global_epoch: int,
        sampling_type: str = "random",
        sample_ratio: float = 1,
    ) -> None:
        """
        Perform one training round: aggregate parameters from the sampled trainers
        (by index), update the central model, and broadcast the updated parameters
        back to all trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        sampling_type : str
            The strategy used to sample trainers, either 'random' or 'uniform'.
        sample_ratio : float
            The fraction of trainers sampled in this round, in (0, 1].
        """
        if self.use_encryption:
            if not hasattr(self, "aggregation_stats"):
                self.aggregation_stats = []
            train_refs = [
                trainer.train.remote(current_global_epoch) for trainer in self.trainers
            ]
            ray.get(train_refs)

            encryption_start = time.time()
            print("Starting encrypted parameter aggregation...")
            encrypted_params = [
                trainer.get_encrypted_params.remote() for trainer in self.trainers
            ]

            # Wait for all trainers and collect parameters
            params_list = []
            encryption_times = []
            enc_sizes = []
            while encrypted_params:
                ready, encrypted_params = ray.wait(encrypted_params)
                result = ray.get(ready[0])
                params_list.append(result)
                enc_size = sum(
                    len(p) for p in result[0]
                )  # Size of encrypted parameters
                enc_sizes.append(enc_size)
            encryption_time = time.time() - encryption_start

            # Aggregate parameters
            aggregated_params, metadata, agg_time = self.aggregate_encrypted_params(
                params_list
            )
            print(f"Parameter aggregation completed in {agg_time:.4f}s")
            agg_size = sum(len(p) for p in aggregated_params)

            # Distribute back to trainers
            decryption_start = time.time()
            decrypt_refs = [
                trainer.load_encrypted_params.remote(
                    (aggregated_params, metadata), current_global_epoch
                )
                for trainer in self.trainers
            ]
            decryption_times = ray.get(decrypt_refs)

            round_metrics = {
                "encryption_time": encryption_time,
                "decryption_times": decryption_times,
                "aggregation_time": agg_time,
                "upload_size": sum(enc_sizes),
                "download_size": agg_size * len(self.trainers),
            }
            self.aggregation_stats.append(round_metrics)
        else:
            # Normal (plaintext) training logic
            assert 0 < sample_ratio <= 1, "Sample ratio must be between 0 and 1"
            num_samples = int(self.num_of_trainers * sample_ratio)
            if sampling_type == "random":
                selected_trainers_indices = random.sample(
                    range(self.num_of_trainers), num_samples
                )
            elif sampling_type == "uniform":
                selected_trainers_indices = [
                    (
                        i
                        + int(self.num_of_trainers * sample_ratio)
                        * current_global_epoch
                    )
                    % self.num_of_trainers
                    for i in range(num_samples)
                ]
            else:
                raise ValueError("sampling_type must be either 'random' or 'uniform'")

            for trainer_idx in selected_trainers_indices:
                self.trainers[trainer_idx].train.remote(current_global_epoch)
            params = [
                self.trainers[trainer_idx].get_params.remote()
                for trainer_idx in selected_trainers_indices
            ]
            self.zero_params()
            self.model = self.model.to("cpu")

            while True:
                ready, left = ray.wait(params, num_returns=1, timeout=None)
                if ready:
                    for t in ready:
                        for p, mp in zip(ray.get(t), self.model.parameters()):
                            mp.data += p.cpu()
                params = left
                if not params:
                    break
            self.model = self.model.to(self.device)

            for p in self.model.parameters():
                p /= num_samples
            self.broadcast_params(current_global_epoch)
    def broadcast_params(self, current_global_epoch: int) -> None:
        """
        Broadcasts the current parameters of the central model to all trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        for trainer in self.trainers:
            trainer.update_params.remote(
                tuple(self.model.parameters()), current_global_epoch
            )  # run in submit order
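
# Illustrative sketch (not part of fedgraph): how a driver script might run plaintext
# FedAvg rounds with the Server above. `trainers`, `args`, and the attribute names
# `args.args_hidden` / `args.global_rounds` are hypothetical placeholders; it assumes
# Ray is already initialized and each trainer is a remote actor exposing
# train / get_params / update_params as used by Server.train and broadcast_params.
def _example_server_rounds(
    trainers: list, args: Any, feature_dim: int, class_num: int
) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    server = Server(
        feature_dim=feature_dim,
        args_hidden=args.args_hidden,  # hidden width, assumed to live on args
        class_num=class_num,
        device=device,
        trainers=trainers,
        args=args,
    )
    server.broadcast_params(current_global_epoch=-1)  # push initial weights
    for epoch in range(args.global_rounds):  # assumed round count on args
        # Each round trains the sampled trainers, sums their parameters into the
        # central model, divides by the sample count, and re-broadcasts.
        server.train(epoch, sampling_type="random", sample_ratio=1.0)
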
class Server_GC:
    """
    This is a server class for federated graph classification, which is responsible
    for aggregating model parameters from different trainers, updating the central
    model, and then broadcasting the updated model parameters back to the trainers.

    Parameters
    ----------
    model : torch.nn.Module
        The base model that the federated learning is performed on.
    device : torch.device
        The device to run the model on.
    use_cluster : bool
        Whether to use clustered federated learning.

    Attributes
    ----------
    model : torch.nn.Module
        The base model for the server.
    W : dict
        Dictionary containing the model parameters.
    model_cache : list
        List of tuples, where each tuple contains the model parameters and the
        accuracies of the trainers.
    """

    def __init__(
        self, model: torch.nn.Module, device: torch.device, use_cluster: bool
    ) -> None:
        self.model = model.to(device)
        self.W = {key: value for key, value in self.model.named_parameters()}
        self.model_cache: Any = []
        self.use_cluster = use_cluster

    ########### Public functions ###########
    def random_sample_trainers(self, all_trainers: list, frac: float) -> list:
        """
        Randomly sample a fraction of trainers.

        Parameters
        ----------
        all_trainers : list
            List of trainer objects.
        frac : float
            Fraction of trainers to be sampled.

        Returns
        -------
        (sampled_trainers) : list
            List of trainer objects.
        """
        return random.sample(all_trainers, int(len(all_trainers) * frac))
    def aggregate_weights(self, selected_trainers: list) -> None:
        """
        Perform weighted aggregation among selected trainers. The weights are the
        number of training samples.

        Parameters
        ----------
        selected_trainers : list
            List of trainer objects.
        """
        total_size = 0
        size_refs = [trainer.get_train_size.remote() for trainer in selected_trainers]
        while size_refs:
            ready, left = ray.wait(size_refs, num_returns=1, timeout=None)
            if ready:
                for t in ready:
                    total_size += ray.get(t)
            size_refs = left

        for k in self.W.keys():
            # Pass train_size, and weighted aggregate
            accumulate_list = []
            acc_refs = []
            for trainer in selected_trainers:
                acc_ref = trainer.calculate_weighted_weight.remote(k)
                acc_refs.append(acc_ref)
            while acc_refs:
                ready, left = ray.wait(acc_refs, num_returns=1, timeout=None)
                if ready:
                    for t in ready:
                        weighted_weight = ray.get(t)
                        accumulate_list.append(weighted_weight)
                acc_refs = left
            accumulate = torch.stack(accumulate_list)
            self.W[k].data = torch.div(torch.sum(accumulate, dim=0), total_size).clone()
    def compute_pairwise_similarities(self, trainers: list) -> np.ndarray:
        """
        This function computes the pairwise cosine similarities between the gradients
        of the trainers.

        Parameters
        ----------
        trainers : list
            List of trainer objects.

        Returns
        -------
        np.ndarray
            2D np.ndarray of shape len(trainers) * len(trainers), which contains the
            pairwise cosine similarities.
        """
        trainer_dWs = []
        for trainer in trainers:
            # Fetch each trainer's gradient dictionary once, then keep the keys
            # tracked by the server
            trainer_dW = ray.get(trainer.get_dW.remote())
            dW = {k: trainer_dW[k] for k in self.W.keys()}
            trainer_dWs.append(dW)

        return self.__pairwise_angles(trainer_dWs)
    def compute_pairwise_distances(
        self, seqs: list, standardize: bool = False
    ) -> np.ndarray:
        """
        This function computes the pairwise distances between the gradient norm
        sequences of the trainers.

        Parameters
        ----------
        seqs : list
            List of 1D np.ndarray, where each 1D np.ndarray contains the gradient
            norm sequence of a trainer.
        standardize : bool
            Whether to standardize the sequences before computing the distances.

        Returns
        -------
        distances : np.ndarray
            2D np.ndarray of shape len(seqs) * len(seqs), which contains the pairwise distances.
        """
        if standardize:
            # Standardize to only focus on the trends
            seqs = np.array(seqs)
            seqs = seqs / np.std(seqs, axis=1).reshape(-1, 1)
        distances = dtw.distance_matrix(seqs)
        return distances
    def min_cut(self, similarity: np.ndarray, idc: list) -> tuple:
        """
        This function computes the minimum cut of the graph defined by the pairwise
        cosine similarities.

        Parameters
        ----------
        similarity : np.ndarray
            2D np.ndarray of shape len(trainers) * len(trainers), which contains the
            pairwise cosine similarities.
        idc : list
            List of trainer indices.

        Returns
        -------
        (c1, c2) : tuple
            Tuple of two lists, where each list contains the indices of the trainers
            in a cluster.
        """
        g = nx.Graph()
        for i in range(len(similarity)):
            for j in range(len(similarity)):
                g.add_edge(i, j, weight=similarity[i][j])
        # Use the Stoer-Wagner algorithm to find the minimum cut
        _, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2
    def aggregate_clusterwise(self, trainer_clusters: list) -> None:
        """
        Perform weighted aggregation among the trainers in each cluster. The weights
        are the number of training samples.

        Parameters
        ----------
        trainer_clusters : list
            List of cluster-specified trainer groups, where each group contains the
            trainer objects in a cluster.
        """
        ks = self.W.keys()
        for cluster in trainer_clusters:  # cluster is a list of trainer objects
            weights_list = ray.get(
                [trainer.get_weights.remote(ks) for trainer in cluster]
            )
            # Unpack the list of dictionaries into separate lists for targs, sours, and train_sizes
            targs = [weights["W"] for weights in weights_list]
            sours = [(weights["dW"], weights["train_size"]) for weights in weights_list]
            total_size = sum([weights["train_size"] for weights in weights_list])
            # Pass train_size, and weighted aggregate
            self.__reduce_add_average(
                targets=targs, sources=sours, total_size=total_size
            )
    def compute_max_update_norm(self, cluster: list) -> float:
        """
        Compute the maximum update norm (i.e., dW) among the trainers in the cluster.
        This function is used to determine whether the cluster is ready to be split.

        Parameters
        ----------
        cluster : list
            List of trainer objects.

        Returns
        -------
        float
            The maximum update norm among the trainers in the cluster.
        """
        dw_refs = []
        for trainer in cluster:
            dw_ref = trainer.compute_update_norm.remote(self.W.keys())
            dw_refs.append(dw_ref)
        results = ray.get(dw_refs)
        max_dW = max(results)
        return max_dW
    def compute_mean_update_norm(self, cluster: list) -> float:
        """
        Compute the mean update norm (i.e., dW) among the trainers in the cluster.
        This function is used to determine whether the cluster is ready to be split.

        Parameters
        ----------
        cluster : list
            List of trainer objects.

        Returns
        -------
        float
            The norm of the mean update among the trainers in the cluster.
        """
        dw_refs = []
        total_size = sum(ray.get([c.get_train_size.remote() for c in cluster]))
        for trainer in cluster:
            dw_ref = trainer.compute_mean_norm.remote(total_size, self.W.keys())
            dw_refs.append(dw_ref)
        cluster_dWs = ray.get(dw_refs)

        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()
    def cache_model(self, idcs: list, params: dict, accuracies: list) -> None:
        """
        Cache the model parameters and accuracies of the trainers.

        Parameters
        ----------
        idcs : list
            List of trainer indices.
        params : dict
            Dictionary containing the model parameters of the trainers.
        accuracies : list
            List of accuracies of the trainers.
        """
        self.model_cache += [
            (
                idcs,
                {name: params[name].data.clone() for name in params},
                [accuracies[i] for i in idcs],
            )
        ]
    ########### Private functions ###########

    def __pairwise_angles(self, sources: list) -> np.ndarray:
        """
        Compute the pairwise cosine similarities between the gradients of the trainers
        into a 2D matrix.

        Parameters
        ----------
        sources : list
            List of dictionaries, where each dictionary contains the gradients of a trainer.

        Returns
        -------
        np.ndarray
            2D np.ndarray of shape len(sources) * len(sources), which contains the
            pairwise cosine similarities.
        """
        angles = torch.zeros([len(sources), len(sources)])
        for i, source1 in enumerate(sources):
            for j, source2 in enumerate(sources):
                s1 = self.__flatten(source1)
                s2 = self.__flatten(source2)
                angles[i, j] = (
                    torch.true_divide(
                        torch.sum(s1 * s2), max(torch.norm(s1) * torch.norm(s2), 1e-12)
                    )
                    + 1
                )

        return angles.numpy()

    def __flatten(self, source: dict) -> torch.Tensor:
        """
        Flatten the gradients of a trainer into a 1D tensor.

        Parameters
        ----------
        source : dict
            Dictionary containing the gradients of a trainer.

        Returns
        -------
        (flattened_gradients) : torch.Tensor
            1D tensor containing the flattened gradients.
        """
        return torch.cat([value.flatten() for value in source.values()])

    def __reduce_add_average(
        self, targets: list, sources: list, total_size: int
    ) -> None:
        """
        Perform weighted aggregation from the sources to the targets. The weights are
        the number of training samples.

        Parameters
        ----------
        targets : list
            List of dictionaries, where each dictionary contains the model parameters of a trainer.
        sources : list
            List of tuples, where each tuple contains the gradients and the number of
            training samples of a trainer.
        total_size : int
            Total number of training samples.
        """
        for target in targets:
            for name in target:
                weighted_stack = torch.stack(
                    [torch.mul(source[0][name].data, source[1]) for source in sources]
                )
                tmp = torch.div(torch.sum(weighted_stack, dim=0), total_size).clone()
                target[name].data += tmp
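
# Illustrative sketch (not part of fedgraph): one aggregation step for federated graph
# classification with Server_GC. `gc_trainers` is a hypothetical list of remote trainer
# actors implementing get_train_size / calculate_weighted_weight / get_dW / get_weights,
# and `base_model` is any torch.nn.Module; Ray is assumed to be initialized.
def _example_gc_aggregation(base_model: torch.nn.Module, gc_trainers: list) -> None:
    server = Server_GC(model=base_model, device=torch.device("cpu"), use_cluster=False)
    if not server.use_cluster:
        # Plain FedAvg over a random subset of trainers, weighted by train size.
        sampled = server.random_sample_trainers(gc_trainers, frac=1.0)
        server.aggregate_weights(sampled)
    else:
        # Clustered variant: split trainers by gradient similarity via a minimum cut,
        # then aggregate within each cluster.
        similarities = server.compute_pairwise_similarities(gc_trainers)
        idc = list(range(len(gc_trainers)))
        c1, c2 = server.min_cut(similarities, idc)
        clusters = [[gc_trainers[i] for i in c1], [gc_trainers[i] for i in c2]]
        server.aggregate_clusterwise(clusters)
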
class Server_LP:
    """
    This is a server class for federated graph link prediction, which is responsible
    for aggregating model parameters from different trainers, updating the central
    model, and then broadcasting the updated model parameters back to the trainers.

    Parameters
    ----------
    number_of_users : int
        The number of users in the dataset.
    number_of_items : int
        The number of items in the dataset.
    meta_data : tuple
        Tuple containing the meta data of the dataset.
    trainers : list
        List of trainer objects.
    args_cuda : bool
        Whether to run the model on GPU.
    """

    def __init__(
        self,
        number_of_users: int,
        number_of_items: int,
        meta_data: tuple,
        trainers: list,
        args_cuda: bool = False,
    ) -> None:
        # Create the base model
        self.global_model = GNN_LP(
            number_of_users, number_of_items, meta_data, hidden_channels=64
        )
        self.global_model = (
            self.global_model.cuda() if args_cuda else self.global_model
        )
        self.clients = trainers
    def fedavg(self, gnn_only: bool = False) -> dict:
        """
        This function performs federated averaging on the model parameters of the clients.

        Parameters
        ----------
        gnn_only : bool, optional
            Whether to average only the GNN parameters.

        Returns
        -------
        model_avg_parameter : dict
            The averaged model parameters.
        """
        local_model_parameters = [
            trainer.get_model_parameter.remote(gnn_only) for trainer in self.clients
        ]

        # Collect the model parameters as they become ready
        model_states = []
        while local_model_parameters:
            ready, left = ray.wait(local_model_parameters, num_returns=1, timeout=None)
            if ready:
                for t in ready:
                    model_states.append(ray.get(t))
            local_model_parameters = left

        model_avg_parameter = self.__average_parameter(model_states)
        return model_avg_parameter
    def set_model_parameter(
        self, model_state_dict: dict, gnn_only: bool = False
    ) -> None:
        """
        Set the model parameters.

        Parameters
        ----------
        model_state_dict : dict
            The model parameters.
        gnn_only : bool, optional
            Whether to set only the GNN parameters.
        """
        if gnn_only:
            self.global_model.gnn.load_state_dict(model_state_dict)
        else:
            self.global_model.load_state_dict(model_state_dict)
    def get_model_parameter(self, gnn_only: bool = False) -> dict:
        """
        Get the model parameters.

        Parameters
        ----------
        gnn_only : bool
            Whether to get only the GNN parameters.

        Returns
        -------
        dict
            The model parameters.
        """
        return (
            self.global_model.gnn.state_dict()
            if gnn_only
            else self.global_model.state_dict()
        )
    # Private functions
    def __average_parameter(self, states: list) -> dict:
        """
        This function averages the model parameters of the clients.

        Parameters
        ----------
        states : list
            List of model parameters.

        Returns
        -------
        global_state : dict
            The averaged model parameters.
        """
        global_state = dict()
        # Average all parameters
        for key in states[0]:
            global_state[key] = states[0][key]  # for the first client
            for i in range(1, len(states)):
                global_state[key] += states[i][key]
            global_state[key] /= len(states)  # average
        return global_state
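
# Illustrative sketch (not part of fedgraph): one FedAvg round for federated link
# prediction with Server_LP. `lp_trainers` is a hypothetical list of remote trainer
# actors exposing get_model_parameter(gnn_only); the user/item counts and meta_data
# tuple are placeholders, and Ray is assumed to be initialized.
def _example_lp_round(
    lp_trainers: list, number_of_users: int, number_of_items: int, meta_data: tuple
) -> dict:
    server = Server_LP(
        number_of_users=number_of_users,
        number_of_items=number_of_items,
        meta_data=meta_data,
        trainers=lp_trainers,
        args_cuda=False,
    )
    # Average the clients' current GNN weights and load them into the global model.
    avg_state = server.fedavg(gnn_only=True)
    server.set_model_parameter(avg_state, gnn_only=True)
    # The updated global parameters can then be sent back to each client.
    return server.get_model_parameter(gnn_only=True)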