Source code for fedgraph.server_class

from typing import Any

import ray
import torch

from fedgraph.gnn_models import GCN, AggreGCN, GCN_arxiv, SAGE_products
from fedgraph.trainer_class import Trainer_General


class Server:
    """
    This is a server class for federated learning which is responsible for
    aggregating model parameters from different clients, updating the central
    model, and then broadcasting the updated model parameters back to the
    trainers.

    Parameters
    ----------
    feature_dim : int
        The dimensionality of the feature vectors in the dataset.
    args_hidden : int
        The number of hidden units.
    class_num : int
        The number of classes for classification in the dataset.
    device : torch.device
        The device initialized for the server model. Not used by this class:
        the model is never moved to it, and aggregation moves incoming
        parameters to the CPU.
    trainers : list
        Ray actor handles for ``Trainer_General`` instances. The server invokes
        ``train``, ``get_params`` and ``update_params`` on them via
        ``.remote(...)``.
    args : Any
        Configuration object. This class reads ``num_hops``, ``fedtype``,
        ``dataset`` and ``num_layers`` from it.

    Attributes
    ----------
    model : AggreGCN | GCN_arxiv | SAGE_products | GCN
        The central model that is trained in a federated manner.
    trainers : list
        The trainer actor handles.
    num_of_trainers : int
        The number of trainers.
    """

    def __init__(
        self,
        feature_dim: int,
        args_hidden: int,
        class_num: int,
        device: torch.device,
        trainers: list,
        args: Any,
    ) -> None:
        # Pick the architecture first; every variant takes the same
        # constructor arguments, so it is built in a single call below.
        if args.num_hops >= 1 and args.fedtype == "fedgcn":
            model_cls = AggreGCN
        elif args.dataset == "ogbn-arxiv":
            model_cls = GCN_arxiv
        elif args.dataset == "ogbn-products":
            model_cls = SAGE_products
        else:
            model_cls = GCN

        # server model on cpu (it is never moved to `device`)
        self.model = model_cls(
            nfeat=feature_dim,
            nhid=args_hidden,
            nclass=class_num,
            dropout=0.5,
            NumLayers=args.num_layers,
        )
        self.trainers = trainers
        self.num_of_trainers = len(trainers)
        # Initial sync so every trainer starts from the server's parameters;
        # -1 is passed as the epoch for this pre-training broadcast.
        self.broadcast_params(-1)

    @torch.no_grad()
    def zero_params(self) -> None:
        """
        Zeros out the parameters of the central model in place.
        """
        for p in self.model.parameters():
            p.zero_()

    @torch.no_grad()
    def train(self, current_global_epoch: int) -> None:
        """
        Run one federated round.

        The round asks every trainer to train locally and collects their
        parameters. It then replaces the central model's parameters with the
        unweighted mean of the collected parameters and broadcasts the result
        back to the trainers. If there are no trainers, the method returns
        without modifying the model.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        if not self.trainers:
            # Averaging over zero trainers would divide the zeroed parameters
            # by 0 and fill the central model with NaN.
            return

        for trainer in self.trainers:
            trainer.train.remote(current_global_epoch)
        # get_params is submitted after train on the same actor; assuming the
        # trainers are synchronous Ray actors, it runs after training finishes.
        pending = [trainer.get_params.remote() for trainer in self.trainers]

        self.zero_params()
        # Accumulate each trainer's parameters as soon as they are ready,
        # rather than blocking until every trainer has finished.
        while pending:
            ready, pending = ray.wait(pending, num_returns=1)
            for ref in ready:
                for p, mp in zip(ray.get(ref), self.model.parameters()):
                    mp.add_(p.cpu())

        for p in self.model.parameters():
            p.div_(self.num_of_trainers)
        self.broadcast_params(current_global_epoch)

    def broadcast_params(self, current_global_epoch: int) -> None:
        """
        Broadcasts the current parameters of the central model to all trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        if not self.trainers:
            return
        # Serialize the parameters into the object store once and share the
        # reference; Ray dereferences top-level ObjectRef arguments, so each
        # trainer still receives the tuple of tensors.
        params_ref = ray.put(tuple(self.model.parameters()))
        for trainer in self.trainers:
            trainer.update_params.remote(
                params_ref, current_global_epoch
            )  # run in submit order