Source code for fedgraph.server_class

import random
from typing import Any

import networkx as nx
import numpy as np
import ray
import torch
from dtaidistance import dtw

from fedgraph.gnn_models import GCN, GNN_LP, AggreGCN, GCN_arxiv, SAGE_products


class Server:
    """
    This is a server class for federated learning which is responsible for
    aggregating model parameters from different trainers, updating the central
    model, and then broadcasting the updated model parameters back to the
    trainers.

    Parameters
    ----------
    feature_dim : int
        The dimensionality of the feature vectors in the dataset.
    args_hidden : int
        The number of hidden units.
    class_num : int
        The number of classes for classification in the dataset.
    device : torch.device
        The device initialized for the server model.
    trainers : list[Trainer_General]
        A list of `Trainer_General` instances representing the trainers.
    args : Any
        Additional arguments required for initializing the server model and
        other configurations.

    Attributes
    ----------
    model : [AggreGCN, GCN_arxiv, SAGE_products, GCN]
        The central GCN model that is trained in a federated manner.
    trainers : list[Trainer_General]
        The list of trainer instances.
    num_of_trainers : int
        The number of trainers.
    """

    def __init__(
        self,
        feature_dim: int,
        args_hidden: int,
        class_num: int,
        device: torch.device,
        trainers: list,
        args: Any,
    ) -> None:
        # server model on cpu
        if args.num_hops >= 1 and args.fedtype == "fedgcn":
            self.model = AggreGCN(
                nfeat=feature_dim,
                nhid=args_hidden,
                nclass=class_num,
                dropout=0.5,
                NumLayers=args.num_layers,
            )
        else:
            if args.dataset == "ogbn-arxiv":
                self.model = GCN_arxiv(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=args.num_layers,
                )
            elif args.dataset == "ogbn-products":
                self.model = SAGE_products(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=args.num_layers,
                )
            else:
                self.model = GCN(
                    nfeat=feature_dim,
                    nhid=args_hidden,
                    nclass=class_num,
                    dropout=0.5,
                    NumLayers=args.num_layers,
                )

        self.trainers = trainers
        self.num_of_trainers = len(trainers)
        self.broadcast_params(-1)

    @torch.no_grad()
    def zero_params(self) -> None:
        """
        Zeros out the parameters of the central model.
        """
        for p in self.model.parameters():
            p.zero_()

    @torch.no_grad()
    def train(self, current_global_epoch: int) -> None:
        """
        Run one global training round: aggregate parameters from the trainers,
        update the central model, and then broadcast the updated parameters
        back to the trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        for trainer in self.trainers:
            trainer.train.remote(current_global_epoch)
        params = [trainer.get_params.remote() for trainer in self.trainers]
        self.zero_params()

        while True:
            ready, left = ray.wait(params, num_returns=1, timeout=None)
            if ready:
                for t in ready:
                    for p, mp in zip(ray.get(t), self.model.parameters()):
                        mp.data += p.cpu()
            params = left
            if not params:
                break

        for p in self.model.parameters():
            p /= self.num_of_trainers
        self.broadcast_params(current_global_epoch)

    def broadcast_params(self, current_global_epoch: int) -> None:
        """
        Broadcasts the current parameters of the central model to all trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        for trainer in self.trainers:
            trainer.update_params.remote(
                tuple(self.model.parameters()), current_global_epoch
            )  # run in submit order
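

# Usage sketch (illustrative, not taken from fedgraph itself): ``Server`` expects
# its trainers to be Ray actors exposing ``train``, ``get_params`` and
# ``update_params`` remote methods (e.g. ``Trainer_General``). The trainer
# constructor arguments and the ``args`` fields below are placeholders only::
#
#     ray.init()
#     trainers = [Trainer_General.remote(...) for _ in range(args.n_trainer)]
#     server = Server(feature_dim, args_hidden, class_num, device, trainers, args)
#     for epoch in range(args.global_rounds):
#         server.train(epoch)  # aggregate trainer params, average, broadcast back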


class Server_GC:
    """
    This is a server class for federated graph classification which is
    responsible for aggregating model parameters from different trainers,
    updating the central model, and then broadcasting the updated model
    parameters back to the trainers.

    Parameters
    ----------
    model: torch.nn.Module
        The base model that the federated learning is performed on.
    device: torch.device
        The device to run the model on.

    Attributes
    ----------
    model: torch.nn.Module
        The base model for the server.
    W: dict
        Dictionary containing the model parameters.
    model_cache: list
        List of tuples, where each tuple contains the model parameters and the
        accuracies of the trainers.
    """

    def __init__(self, model: torch.nn.Module, device: torch.device) -> None:
        self.model = model.to(device)
        self.W = {key: value for key, value in self.model.named_parameters()}
        self.model_cache: Any = []

    ########### Public functions ###########

    def random_sample_trainers(self, all_trainers: list, frac: float) -> list:
        """
        Randomly sample a fraction of trainers.

        Parameters
        ----------
        all_trainers: list
            list of trainer objects
        frac: float
            fraction of trainers to be sampled

        Returns
        -------
        (sampled_trainers): list
            list of trainer objects
        """
        return random.sample(all_trainers, int(len(all_trainers) * frac))

    def aggregate_weights(self, selected_trainers: list) -> None:
        """
        Perform weighted aggregation among the selected trainers. The weights
        are the numbers of training samples.

        Parameters
        ----------
        selected_trainers: list
            list of trainer objects
        """
        total_size = 0
        for trainer in selected_trainers:
            total_size += trainer.train_size
        for k in self.W.keys():
            # pass train_size, and weighted aggregate
            accumulate = torch.stack(
                [
                    torch.mul(trainer.W[k].data, trainer.train_size)
                    for trainer in selected_trainers
                ]
            )
            self.W[k].data = torch.div(torch.sum(accumulate, dim=0), total_size).clone()

    def compute_pairwise_similarities(self, trainers: list) -> np.ndarray:
        """
        This function computes the pairwise cosine similarities between the
        gradients of the trainers.

        Parameters
        ----------
        trainers: list
            list of trainer objects

        Returns
        -------
        np.ndarray
            2D np.ndarray of shape len(trainers) * len(trainers), which contains
            the pairwise cosine similarities
        """
        trainer_dWs = []
        for trainer in trainers:
            dW = {}
            for k in self.W.keys():
                dW[k] = trainer.dW[k]
            trainer_dWs.append(dW)

        return self.__pairwise_angles(trainer_dWs)

    def compute_pairwise_distances(
        self, seqs: list, standardize: bool = False
    ) -> np.ndarray:
        """
        This function computes the pairwise distances between the gradient norm
        sequences of the trainers.

        Parameters
        ----------
        seqs: list
            list of 1D np.ndarray, where each 1D np.ndarray contains the
            gradient norm sequence of a trainer
        standardize: bool
            whether to standardize the gradient norm sequences before computing
            the distances

        Returns
        -------
        distances: np.ndarray
            2D np.ndarray of shape len(seqs) * len(seqs), which contains the
            pairwise distances
        """
        if standardize:
            # standardize to only focus on the trends
            seqs = np.array(seqs)
            seqs = seqs / np.std(seqs, axis=1).reshape(-1, 1)
            distances = dtw.distance_matrix(seqs)
        else:
            distances = dtw.distance_matrix(seqs)
        return distances
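
    # Illustrative note (not from the original source): with ``standardize=True``
    # each gradient-norm sequence is divided by its own standard deviation, so only
    # the shape of the curve matters. For example, assuming a ``server_gc`` instance::
    #
    #     seqs = [np.array([1.0, 2.0, 3.0]), np.array([10.0, 20.0, 30.0])]
    #     server_gc.compute_pairwise_distances(seqs, standardize=True)
    #     # the DTW distance between the two standardized sequences is 0,
    #     # because after scaling their trends coincide exactly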

    def min_cut(self, similarity: np.ndarray, idc: list) -> tuple:
        """
        This function computes the minimum cut of the graph defined by the
        pairwise cosine similarities.

        Parameters
        ----------
        similarity: np.ndarray
            2D np.ndarray of shape len(trainers) * len(trainers), which contains
            the pairwise cosine similarities
        idc: list
            list of trainer indices

        Returns
        -------
        (c1, c2): tuple
            tuple of two lists, where each list contains the indices of the
            trainers in a cluster
        """
        g = nx.Graph()
        for i in range(len(similarity)):
            for j in range(len(similarity)):
                g.add_edge(i, j, weight=similarity[i][j])
        # use the Stoer-Wagner algorithm to find the minimum cut
        _, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2

    def aggregate_clusterwise(self, trainer_clusters: list) -> None:
        """
        Perform weighted aggregation among the trainers in each cluster. The
        weights are the numbers of training samples.

        Parameters
        ----------
        trainer_clusters: list
            list of cluster-specified trainer groups, where each group contains
            the trainer objects in a cluster
        """
        for cluster in trainer_clusters:  # cluster is a list of trainer objects
            targs, sours = [], []
            total_size = 0
            for trainer in cluster:
                W = {}
                dW = {}
                for k in self.W.keys():
                    W[k] = trainer.W[k]
                    dW[k] = trainer.dW[k]
                targs.append(W)
                sours.append((dW, trainer.train_size))
                total_size += trainer.train_size
            # pass train_size, and weighted aggregate
            self.__reduce_add_average(
                targets=targs, sources=sours, total_size=total_size
            )

    def compute_max_update_norm(self, cluster: list) -> float:
        """
        Compute the maximum update norm (i.e., dW) among the trainers in the
        cluster. This function is used to determine whether the cluster is
        ready to be split.

        Parameters
        ----------
        cluster: list
            list of trainer objects

        Returns
        -------
        (max_dW): float
            the maximum update norm among the trainers in the cluster
        """
        max_dW = -np.inf
        for trainer in cluster:
            dW = {}
            for k in self.W.keys():
                dW[k] = trainer.dW[k]
            curr_dW = torch.norm(self.__flatten(dW)).item()
            max_dW = max(max_dW, curr_dW)

        return max_dW

    def compute_mean_update_norm(self, cluster: list) -> float:
        """
        Compute the mean update norm (i.e., dW) among the trainers in the
        cluster. This function is used to determine whether the cluster is
        ready to be split.

        Parameters
        ----------
        cluster: list
            list of trainer objects

        Returns
        -------
        float
            the norm of the weighted mean update across the trainers in the cluster
        """
        cluster_dWs = []
        for trainer in cluster:
            dW = {}
            for k in self.W.keys():
                # dW[k] = trainer.dW[k]
                dW[k] = (
                    trainer.dW[k]
                    * trainer.train_size
                    / sum([c.train_size for c in cluster])
                )
            cluster_dWs.append(self.__flatten(dW))

        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()

    def cache_model(self, idcs: list, params: dict, accuracies: list) -> None:
        """
        Cache the model parameters and accuracies of the trainers.

        Parameters
        ----------
        idcs: list
            list of trainer indices
        params: dict
            dictionary containing the model parameters of the trainers
        accuracies: list
            list of accuracies of the trainers
        """
        self.model_cache += [
            (
                idcs,
                {name: params[name].data.clone() for name in params},
                [accuracies[i] for i in idcs],
            )
        ]

    ########### Private functions ###########
    def __pairwise_angles(self, sources: list) -> np.ndarray:
        """
        Compute the pairwise cosine similarities between the gradients of the
        trainers into a 2D matrix.

        Parameters
        ----------
        sources: list
            list of dictionaries, where each dictionary contains the gradients
            of a trainer

        Returns
        -------
        np.ndarray
            2D np.ndarray of shape len(sources) * len(sources), which contains
            the pairwise cosine similarities
        """
        angles = torch.zeros([len(sources), len(sources)])
        for i, source1 in enumerate(sources):
            for j, source2 in enumerate(sources):
                s1 = self.__flatten(source1)
                s2 = self.__flatten(source2)
                angles[i, j] = (
                    torch.true_divide(
                        torch.sum(s1 * s2), max(torch.norm(s1) * torch.norm(s2), 1e-12)
                    )
                    + 1
                )

        return angles.numpy()

    def __flatten(self, source: dict) -> torch.Tensor:
        """
        Flatten the gradients of a trainer into a 1D tensor.

        Parameters
        ----------
        source: dict
            dictionary containing the gradients of a trainer

        Returns
        -------
        (flattened_gradients): torch.Tensor
            1D tensor containing the flattened gradients
        """
        return torch.cat([value.flatten() for value in source.values()])

    def __reduce_add_average(
        self, targets: list, sources: list, total_size: int
    ) -> None:
        """
        Perform weighted aggregation from the sources to the targets. The
        weights are the numbers of training samples.

        Parameters
        ----------
        targets: list
            list of dictionaries, where each dictionary contains the model
            parameters of a trainer
        sources: list
            list of tuples, where each tuple contains the gradients and the
            number of training samples of a trainer
        total_size: int
            total number of training samples
        """
        for target in targets:
            for name in target:
                weighted_stack = torch.stack(
                    [torch.mul(source[0][name].data, source[1]) for source in sources]
                )
                tmp = torch.div(torch.sum(weighted_stack, dim=0), total_size).clone()
                target[name].data += tmp
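

# Usage sketch (illustrative, not taken from fedgraph itself): the clustering
# utilities above are typically combined so that a cluster of trainers is split
# when its averaged update is small while individual updates remain large. The
# thresholds ``EPS_1``/``EPS_2`` and the index list ``idc`` are placeholders::
#
#     similarities = server_gc.compute_pairwise_similarities(trainers)
#     if (
#         server_gc.compute_mean_update_norm(cluster) < EPS_1
#         and server_gc.compute_max_update_norm(cluster) > EPS_2
#     ):
#         c1, c2 = server_gc.min_cut(similarities[idc][:, idc], idc)
#     server_gc.aggregate_clusterwise(trainer_clusters)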


class Server_LP:
    """
    This is a server class for federated graph link prediction which is
    responsible for aggregating model parameters from different trainers,
    updating the central model, and then broadcasting the updated model
    parameters back to the trainers.

    Parameters
    ----------
    number_of_users: int
        The number of users in the dataset.
    number_of_items: int
        The number of items in the dataset.
    meta_data: tuple
        Tuple containing the meta data of the dataset.
    args_cuda: bool
        Whether to run the model on GPU.
    """

    def __init__(
        self,
        number_of_users: int,
        number_of_items: int,
        meta_data: tuple,
        args_cuda: bool = False,
    ) -> None:
        # create the base model
        self.global_model = GNN_LP(
            number_of_users, number_of_items, meta_data, hidden_channels=64
        )
        self.global_model = self.global_model.cuda() if args_cuda else self.global_model

    def fedavg(self, clients: list, gnn_only: bool = False) -> dict:
        """
        This function performs federated averaging on the model parameters of
        the clients.

        Parameters
        ----------
        clients: list
            List of client objects
        gnn_only: bool, optional
            Whether to get only the GNN parameters

        Returns
        -------
        model_avg_parameter: dict
            The averaged model parameters
        """
        model_states = []
        for i in range(len(clients)):
            local_model_parameter = clients[i].get_model_parameter(gnn_only)
            model_states.append(local_model_parameter)
        model_avg_parameter = self.__average_parameter(model_states)
        return model_avg_parameter

    def set_model_parameter(
        self, model_state_dict: dict, gnn_only: bool = False
    ) -> None:
        """
        Set the model parameters.

        Parameters
        ----------
        model_state_dict: dict
            The model parameters
        gnn_only: bool, optional
            Whether to set only the GNN parameters
        """
        if gnn_only:
            self.global_model.gnn.load_state_dict(model_state_dict)
        else:
            self.global_model.load_state_dict(model_state_dict)

    def get_model_parameter(self, gnn_only: bool = False) -> dict:
        """
        Get the model parameters.

        Parameters
        ----------
        gnn_only: bool
            Whether to get only the GNN parameters

        Returns
        -------
        dict
            The model parameters
        """
        return (
            self.global_model.gnn.state_dict()
            if gnn_only
            else self.global_model.state_dict()
        )
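
    # Usage sketch (illustrative, not taken from fedgraph itself): one FedAvg round
    # for link prediction, assuming each entry of ``clients`` implements
    # ``get_model_parameter(gnn_only)`` and returns a state dict::
    #
    #     avg_state = server_lp.fedavg(clients, gnn_only=True)
    #     server_lp.set_model_parameter(avg_state, gnn_only=True)
    #     new_global = server_lp.get_model_parameter(gnn_only=True)  # broadcast this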

    # Private functions
    def __average_parameter(self, states: list) -> dict:
        """
        This function averages the model parameters of the clients.

        Parameters
        ----------
        states: list
            List of model parameters

        Returns
        -------
        global_state: dict
            The averaged model parameters
        """
        global_state = dict()
        # Average all parameters
        for key in states[0]:
            global_state[key] = states[0][key]  # for the first client
            for i in range(1, len(states)):
                global_state[key] += states[i][key]
            global_state[key] /= len(states)  # average
        return global_state
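

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the fedgraph API): it
    # exercises the weighted aggregation and gradient-similarity utilities of
    # ``Server_GC`` with two synthetic trainers. ``_FakeTrainer`` is defined here
    # only for demonstration and mimics the ``W``/``dW``/``train_size`` attributes
    # that the real trainers expose.
    class _FakeTrainer:
        def __init__(self, model: torch.nn.Module, train_size: int) -> None:
            self.W = {k: v for k, v in model.named_parameters()}
            self.dW = {k: torch.randn_like(v) for k, v in model.named_parameters()}
            self.train_size = train_size

    demo_server = Server_GC(torch.nn.Linear(4, 2), torch.device("cpu"))
    demo_trainers = [_FakeTrainer(torch.nn.Linear(4, 2), n) for n in (10, 30)]

    # weighted FedAvg: the trainer with train_size 30 contributes 3x the other one
    demo_server.aggregate_weights(demo_trainers)
    # pairwise cosine similarities (shifted by +1) between the synthetic gradients
    print(demo_server.compute_pairwise_similarities(demo_trainers))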