FedGraph Example#
In this tutorial, you will learn the basic workflow of FedGraph with a runnable example. This tutorial assumes that you have basic familiarity with PyTorch and PyTorch Geometric (PyG).
(Time estimate: 15 minutes)
import argparse
from typing import Any
import numpy as np
import ray
import torch
from fedgraph.data_process import load_data
from fedgraph.server_class import Server
from fedgraph.trainer_class import Trainer_General
from fedgraph.utils import (
    get_1hop_feature_sum,
    get_in_comm_indexes,
    label_dirichlet_partition,
)

ray.init()
np.random.seed(42)
torch.manual_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", default="cora", type=str)
parser.add_argument("-f", "--fedtype", default="fedgcn", type=str)
parser.add_argument("-c", "--global_rounds", default=100, type=int)
parser.add_argument("-i", "--local_step", default=3, type=int)
parser.add_argument("-lr", "--learning_rate", default=0.5, type=float)
parser.add_argument("-n", "--n_trainer", default=2, type=int)
parser.add_argument("-nl", "--num_layers", default=2, type=int)
parser.add_argument("-nhop", "--num_hops", default=2, type=int)
parser.add_argument("-g", "--gpu", action="store_true") # if -g, use gpu
parser.add_argument("-iid_b", "--iid_beta", default=10000, type=float)
parser.add_argument("-l", "--logdir", default="./runs", type=str)
args = parser.parse_args()
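If you run this tutorial interactively (e.g., in a Jupyter kernel) rather than as a script, parse_args() may fail on the kernel's own command-line arguments; a small workaround, shown here as a sketch, is to parse an empty argument list instead:
args = parser.parse_args([])  # sketch for notebooks: use the defaults, ignore the kernel's argv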
Data Loading#
FedGraph uses torch_geometric.data.Data
to handle the data. Here, we
use Cora, a PyG built-in dataset, as an example. To load your own
dataset into FedGraph, simply load your data
into “features, adj, labels, idx_train, idx_val, idx_test”.
Alternatively, you can create a dataset in PyG; please refer to the
creating-your-own-datasets tutorial in PyG.
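For instance, a minimal sketch (not part of FedGraph) of extracting those tensors from a Planetoid-style PyG dataset with boolean masks might look like this; the exact adjacency format FedGraph expects (here, a sparse tensor exposing .coo()) is an assumption to verify against your installed version:
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="./data", name="Cora")  # any Planetoid-style dataset
data = dataset[0]
features, labels = data.x, data.y
idx_train = data.train_mask.nonzero(as_tuple=True)[0]
idx_val = data.val_mask.nonzero(as_tuple=True)[0]
idx_test = data.test_mask.nonzero(as_tuple=True)[0]
# adj would be built from data.edge_index, e.g., as a torch_sparse SparseTensor,
# so that row, col, edge_attr = adj.coo() works as in the code below.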
features, adj, labels, idx_train, idx_val, idx_test = load_data(args.dataset)
class_num = labels.max().item() + 1
if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
args_hidden = 16
else:
args_hidden = 256
row, col, edge_attr = adj.coo()
edge_index = torch.stack([row, col], dim=0)
num_cpus_per_client = 1
# specifying a target GPU
if args.gpu:
device = torch.device("cuda")
edge_index = edge_index.to("cuda:0")
num_gpus_per_client = 1
else:
device = torch.device("cpu")
num_gpus_per_client = 0
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Split Graph for Federated Learning#
FedGraph currently has two partition methods, label_dirichlet_partition and community_partition_non_iid, to split a large graph across multiple trainers.
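The iid_beta argument controls how heterogeneous the split is: a large beta (such as the default 10000) gives a near-IID split, while a small beta concentrates each class on a few trainers. An illustrative sketch of the underlying idea (not FedGraph's implementation): for each class, the trainers' shares of that class are drawn from a Dirichlet distribution.
rng = np.random.default_rng(0)
print(rng.dirichlet([0.1] * 3))      # small beta: skewed shares, non-IID
print(rng.dirichlet([10000.0] * 3))  # large beta: near-uniform shares, IID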
split_node_indexes = label_dirichlet_partition(
labels, len(labels), class_num, args.n_trainer, beta=args.iid_beta
)
for i in range(args.n_trainer):
split_node_indexes[i] = np.array(split_node_indexes[i])
split_node_indexes[i].sort()
split_node_indexes[i] = torch.tensor(split_node_indexes[i])
(
communicate_node_indexes,
in_com_train_node_indexes,
in_com_test_node_indexes,
edge_indexes_clients,
) = get_in_comm_indexes(
edge_index,
split_node_indexes,
args.n_trainer,
args.num_hops,
idx_train,
idx_test,
)
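get_in_comm_indexes computes, for each trainer, the set of nodes it needs to communicate about: its local nodes plus their num_hops-hop neighborhood, so that a num_layers-layer GCN has the receptive field it needs, along with train/test indices and an edge index relabeled into that set. A conceptual sketch of the neighborhood expansion (illustrative only, not FedGraph's exact code):
def k_hop_nodes(local_nodes: torch.Tensor, edge_index: torch.Tensor, k: int) -> torch.Tensor:
    # grow the node set one hop at a time until k hops are covered
    nodes = local_nodes
    for _ in range(k):
        mask = torch.isin(edge_index[0], nodes)  # edges whose source is in the set
        nodes = torch.unique(torch.cat([nodes, edge_index[1][mask]]))
    return nodes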
Define and Send Data to Trainers#
FedGraph first determines the resources for each trainer, then sends the data to each remote trainer.
@ray.remote(
num_gpus=num_gpus_per_client,
num_cpus=num_cpus_per_client,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_General):
def __init__(self, *args: Any, **kwds: Any):
super().__init__(*args, **kwds)
trainers = [
Trainer.remote( # type: ignore
rank=i,
local_node_index=split_node_indexes[i],
communicate_node_index=communicate_node_indexes[i],
adj=edge_indexes_clients[i],
train_labels=labels[communicate_node_indexes[i]][in_com_train_node_indexes[i]],
test_labels=labels[communicate_node_indexes[i]][in_com_test_node_indexes[i]],
features=features[split_node_indexes[i]],
idx_train=in_com_train_node_indexes[i],
idx_test=in_com_test_node_indexes[i],
args_hidden=args_hidden,
global_node_num=len(features),
class_num=class_num,
device=device,
args=args,
)
for i in range(args.n_trainer)
]
Define Server#
The Server class handles federated aggregation (e.g., FedAvg) without access to the local trainer data.
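The rest of the example relies on a server object wrapping the remote trainers. Based on the imported fedgraph.server_class.Server, its construction looks roughly like this; the exact argument order is an assumption, so check the Server signature in your installed FedGraph version:
# constructor arguments are an assumption; verify against fedgraph.server_class.Server
server = Server(features.shape[1], args_hidden, class_num, device, trainers, args)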
Pre-Train Communication of FedGCN#
Clients send their local feature sums to the server; the server aggregates them and sends the global feature sum of the relevant nodes back to each client.
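Conceptually, the quantity being aggregated for each node is the sum of its 1-hop neighbors' features. A self-contained sketch of that computation (illustrative; it mirrors what get_1hop_feature_sum is used for in the correctness check below):
def one_hop_feature_sum(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    # for every edge (u, v), add node u's features into node v's row
    out = torch.zeros_like(x)
    src, dst = edge_index
    out.index_add_(0, dst, x[src])
    return out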
local_neighbor_feature_sums = [
trainer.get_local_feature_sum.remote() for trainer in server.trainers
]
global_feature_sum = torch.zeros_like(features)
while True:
ready, left = ray.wait(local_neighbor_feature_sums, num_returns=1, timeout=None)
if ready:
for t in ready:
global_feature_sum += ray.get(t)
local_neighbor_feature_sums = left
if not local_neighbor_feature_sums:
break
print("server aggregates all local neighbor feature sums")
# verify the aggregation: the federated sum should match the centralized 1-hop feature sum
if args.num_hops != 0:
    assert (global_feature_sum != get_1hop_feature_sum(features, edge_index)).sum() == 0
for i in range(args.n_trainer):
server.trainers[i].load_feature_aggregation.remote(
global_feature_sum[communicate_node_indexes[i]]
)
print("clients received feature aggregation from server")
[trainer.relabel_adj.remote() for trainer in server.trainers]
server aggregates all local neighbor feature sums
clients received feature aggregation from server
[ObjectRef(8849b62d89cb30f920a56446c0f420423265c6300100000001000000), ObjectRef(80e22aed7718a1258fc3a191e873299ea42790f90100000001000000)]
Federated Training#
The server starts training on all clients and aggregates the model parameters at every global round.
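Each global round, every client runs local_step optimization steps on its own subgraph, after which the server averages the clients' parameters. A minimal FedAvg sketch (illustrative only, not Server.train itself):
def fedavg(state_dicts, weights):
    # weighted average of the clients' model parameters, key by key
    total = float(sum(weights))
    return {
        key: sum(w * sd[key] for sd, w in zip(state_dicts, weights)) / total
        for key in state_dicts[0]
    }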
print("global_rounds", args.global_rounds)
for i in range(args.global_rounds):
server.train(i)
global_rounds 100
Summarize Experiment Results#
The server collects the local test loss and accuracy from all clients and then calculates the overall test loss and accuracy.
train_data_weights = [len(i) for i in in_com_train_node_indexes]
test_data_weights = [len(i) for i in in_com_test_node_indexes]
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_final_test_loss = np.average(
[row[0] for row in results], weights=test_data_weights, axis=0
)
average_final_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
print(average_final_test_loss, average_final_test_accuracy)
ray.shutdown()
0.9552325581908226 0.789
Total running time of the script: (0 minutes 22.001 seconds)