# Source code for fedgraph.federated_methods

from typing import Any

import attridict
import numpy as np
import ray
import torch

from fedgraph.server_class import Server
from fedgraph.trainer_class import Trainer_General
from fedgraph.utils import get_1hop_feature_sum


def FedGCN_Train(args: attridict, data: tuple) -> None:
    """
    Train a FedGCN model with Ray-distributed trainers.

    Starts Ray, creates one remote trainer per client, runs the FedGCN
    pre-training feature-aggregation exchange, trains for
    ``args.global_rounds`` global rounds, and prints the average final test
    loss and accuracy across trainers (weighted by each trainer's number of
    test nodes). Ray is shut down afterwards, including when an exception is
    raised part-way through.

    Parameters
    ----------
    args : attridict
        Experiment configuration. This function reads ``dataset``, ``gpu``,
        ``n_trainer``, ``num_hops`` and ``global_rounds``. The whole object
        is also forwarded to every trainer and to the server.
    data : tuple
        11-tuple ``(edge_index, features, labels, idx_train, idx_test,
        class_num, split_node_indexes, communicate_node_indexes,
        in_com_train_node_indexes, in_com_test_node_indexes,
        edge_indexes_clients)``. The per-client sequences are indexed by
        trainer rank.
    """
    # Unpack before starting Ray so a malformed ``data`` tuple fails fast.
    (
        edge_index,
        features,
        labels,
        _idx_train,  # not used here; trainers use their per-client indexes
        _idx_test,
        class_num,
        split_node_indexes,
        communicate_node_indexes,
        in_com_train_node_indexes,
        in_com_test_node_indexes,
        edge_indexes_clients,
    ) = data

    # The listed datasets use a 16-unit hidden layer; all others use 256.
    if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
        args_hidden = 16
    else:
        args_hidden = 256

    num_cpus_per_client = 1
    # specifying a target GPU
    if args.gpu:
        device = torch.device("cuda")
        num_gpus_per_client = 1
    else:
        device = torch.device("cpu")
        num_gpus_per_client = 0

    ray.init()
    try:
        ###################################################################
        # Define and Send Data to Trainers
        # --------------------------------
        # FedGraph first determines the resources for each trainer, then
        # sends the data to each remote trainer.
        @ray.remote(
            num_gpus=num_gpus_per_client,
            num_cpus=num_cpus_per_client,
            scheduling_strategy="SPREAD",
        )
        class Trainer(Trainer_General):
            """Ray actor wrapper around ``Trainer_General``."""

            def __init__(self, *args: Any, **kwds: Any):
                super().__init__(*args, **kwds)

        trainers = [
            Trainer.remote(  # type: ignore
                rank=i,
                local_node_index=split_node_indexes[i],
                communicate_node_index=communicate_node_indexes[i],
                adj=edge_indexes_clients[i],
                train_labels=labels[communicate_node_indexes[i]][
                    in_com_train_node_indexes[i]
                ],
                test_labels=labels[communicate_node_indexes[i]][
                    in_com_test_node_indexes[i]
                ],
                features=features[split_node_indexes[i]],
                idx_train=in_com_train_node_indexes[i],
                idx_test=in_com_test_node_indexes[i],
                args_hidden=args_hidden,
                global_node_num=len(features),
                class_num=class_num,
                device=device,
                args=args,
            )
            for i in range(args.n_trainer)
        ]

        ###################################################################
        # Define Server
        # -------------
        # Server class is defined for federated aggregation (e.g., FedAvg)
        # without knowing the local trainer data.
        server = Server(
            features.shape[1], args_hidden, class_num, device, trainers, args
        )

        ###################################################################
        # Pre-Train Communication of FedGCN
        # ---------------------------------
        # Clients send their local feature sum to the server. The server
        # aggregates all local feature sums and sends the global feature
        # sum of specific nodes back to each client.
        pending = [
            trainer.get_local_feature_sum.remote() for trainer in server.trainers
        ]
        global_feature_sum = torch.zeros_like(features)
        # Accumulate each result as soon as its trainer finishes. The loop
        # condition also makes an empty trainer list a no-op instead of
        # calling ray.wait on an empty list.
        while pending:
            ready, pending = ray.wait(pending, num_returns=1, timeout=None)
            for ref in ready:
                global_feature_sum += ray.get(ref)
        print("server aggregates all local neighbor feature sums")

        # test if aggregation is correct
        if args.num_hops != 0:
            assert (
                global_feature_sum != get_1hop_feature_sum(features, edge_index)
            ).sum() == 0

        # Wait for every client to load its slice, so remote errors surface
        # here and the message below is accurate.
        ray.get(
            [
                trainer.load_feature_aggregation.remote(
                    global_feature_sum[node_index]
                )
                for trainer, node_index in zip(
                    server.trainers, communicate_node_indexes
                )
            ]
        )
        print("clients received feature aggregation from server")

        ray.get([trainer.relabel_adj.remote() for trainer in server.trainers])

        ###################################################################
        # Federated Training
        # ------------------
        # The server starts training of all clients and aggregates the
        # parameters at every global round.
        print("global_rounds", args.global_rounds)
        for round_idx in range(args.global_rounds):
            server.train(round_idx)

        ###################################################################
        # Summarize Experiment Results
        # ----------------------------
        # The server collects the local test loss and accuracy from all
        # clients, then calculates the overall test loss and accuracy.
        # Weights are aligned with the trainers that actually produced
        # results.
        test_data_weights = [
            len(node_index)
            for node_index in in_com_test_node_indexes[: len(server.trainers)]
        ]
        results = np.array(
            ray.get([trainer.local_test.remote() for trainer in server.trainers])
        )
        # Each row is assumed to be (test_loss, test_accuracy) -- matches
        # the column names printed below.
        average_final_test_loss = np.average(
            [row[0] for row in results], weights=test_data_weights, axis=0
        )
        average_final_test_accuracy = np.average(
            [row[1] for row in results], weights=test_data_weights, axis=0
        )
        print(f"average_final_test_loss, {average_final_test_loss}")
        print(f"average_final_test_accuracy, {average_final_test_accuracy}")
    finally:
        # Always release the Ray runtime, even if setup or training failed.
        ray.shutdown()