from typing import Any
import ray
import torch
from fedgraph.gnn_models import GCN, AggreGCN, GCN_arxiv, SAGE_products
from fedgraph.trainer_class import Trainer_General
class Server:
    """
    Central server for federated learning.

    The server aggregates model parameters from the trainers, replaces the
    central model's parameters with their average, and then broadcasts the
    updated parameters back to the trainers.

    Parameters
    ----------
    feature_dim : int
        The dimensionality of the feature vectors in the dataset.
    args_hidden : int
        The number of hidden units.
    class_num : int
        The number of classes for classification in the dataset.
    device : torch.device
        The device initialized for the server model. NOTE: the constructor
        accepts this argument but never uses it; the model is not moved to
        `device` here.
    trainers : list[Trainer_General]
        A list of `Trainer_General` instances representing the trainers.
        They are invoked through ``.remote(...)`` calls.
    args : Any
        Additional configuration. The constructor reads `args.num_hops`,
        `args.fedtype`, `args.dataset` and `args.num_layers`.

    Attributes
    ----------
    model : [AggreGCN, GCN_arxiv, SAGE_products, GCN]
        The central GCN model that is trained in a federated manner.
    trainers : list[Trainer_General]
        The list of trainer instances.
    num_of_trainers : int
        The number of trainers.
    """

    def __init__(
        self,
        feature_dim: int,
        args_hidden: int,
        class_num: int,
        device: torch.device,
        trainers: list,
        args: Any,
    ) -> None:
        """
        Build the central model and push its initial parameters to all trainers.

        The model class is chosen from `args`: `AggreGCN` for "fedgcn" with at
        least one hop, otherwise `GCN_arxiv` / `SAGE_products` for the
        corresponding OGB datasets, and plain `GCN` for everything else.
        """
        # server model on cpu
        if args.num_hops >= 1 and args.fedtype == "fedgcn":
            model_cls = AggreGCN
        elif args.dataset == "ogbn-arxiv":
            model_cls = GCN_arxiv
        elif args.dataset == "ogbn-products":
            model_cls = SAGE_products
        else:
            model_cls = GCN

        # Every supported model is constructed with the same keyword arguments.
        self.model = model_cls(
            nfeat=feature_dim,
            nhid=args_hidden,
            nclass=class_num,
            dropout=0.5,
            NumLayers=args.num_layers,
        )

        self.trainers = trainers
        self.num_of_trainers = len(trainers)

        # Epoch -1 marks the initial broadcast that happens before any training.
        self.broadcast_params(-1)

    @torch.no_grad()
    def zero_params(self) -> None:
        """
        Zeros out the parameters of the central model.
        """
        for p in self.model.parameters():
            p.zero_()

    @torch.no_grad()
    def train(self, current_global_epoch: int) -> None:
        """
        Run one federated round.

        Asks every trainer to train, sums the parameters they return into the
        central model, divides by the number of trainers, and broadcasts the
        averaged parameters back to the trainers.

        If there are no trainers the round is a no-op: the central model is
        left untouched instead of being zeroed and divided by zero.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        if not self.trainers:
            # Nothing to aggregate; zeroing and dividing by zero trainers
            # would leave the central model full of NaNs.
            return

        for trainer in self.trainers:
            trainer.train.remote(current_global_epoch)

        # Submitted after `train` on each trainer; presumably each trainer
        # runs its remote calls in submit order, so these are the
        # post-training parameters -- confirm against the trainer class.
        pending = [trainer.get_params.remote() for trainer in self.trainers]

        self.zero_params()

        # Accumulate each trainer's parameters as soon as they are available.
        while pending:
            ready, pending = ray.wait(pending, num_returns=1, timeout=None)
            for ref in ready:
                for p, mp in zip(ray.get(ref), self.model.parameters()):
                    mp.data += p.cpu()

        # Turn the accumulated sum into an average.
        for p in self.model.parameters():
            p /= self.num_of_trainers

        self.broadcast_params(current_global_epoch)

    def broadcast_params(self, current_global_epoch: int) -> None:
        """
        Broadcasts the current parameters of the central model to all trainers.

        Parameters
        ----------
        current_global_epoch : int
            The current global epoch number during the federated learning process.
        """
        for trainer in self.trainers:
            trainer.update_params.remote(
                tuple(self.model.parameters()), current_global_epoch
            )  # run in submit order