Federated Graph Classification Example
Federated Graph Classification with GCFL+dWs on the MUTAG dataset.
(Time estimate: 3 minutes)
Load libraries
import attridict
from fedgraph.federated_methods import run_fedgraph
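run_fedgraph expects an attribute-style configuration object, which is why attridict is imported: it wraps a plain dictionary so each key can also be read as an attribute. A minimal illustrative sketch of that behavior (not part of the example run; the exact key names are arbitrary here):

cfg = attridict({"algorithm": "GCFL+dWs", "device": "cpu"})
print(cfg.algorithm)   # attribute-style access -> "GCFL+dWs"
print(cfg["device"])   # ordinary dict-style access still works -> "cpu"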
Specify the Graph Classification configuration
config = {
"fedgraph_task": "GC",
# General configuration
# algorithm options: "SelfTrain", "FedAvg", "FedProx", "GCFL", "GCFL+", "GCFL+dWs"
"algorithm": "GCFL+dWs",
# Dataset configuration
"dataset": "MUTAG",
"is_multiple_dataset": False,
"datapath": "./data",
"convert_x": False,
"overlap": False,
# Setup configuration
"device": "cpu",
"seed": 10,
"seed_split_data": 42,
# Model parameters
"num_trainers": 2,
"num_rounds": 200, # Used by "FedAvg" and "GCFL" (not used in "SelfTrain")
"local_epoch": 1, # Used by "FedAvg" and "GCFL"
# Specific for "SelfTrain" (used instead of "num_rounds" and "local_epoch")
"local_epoch_selftrain": 200,
"lr": 0.001,
"weight_decay": 0.0005,
"nlayer": 3, # Number of model layers
"hidden": 64, # Hidden layer dimension
"dropout": 0.5, # Dropout rate
"batch_size": 128,
"gpu": False,
"num_cpus_per_trainer": 1,
"num_gpus_per_trainer": 0,
# FedProx specific parameter
"mu": 0.01, # Regularization parameter, only used in "FedProx"
# GCFL specific parameters
"standardize": False, # Used only in "GCFL", "GCFL+", "GCFL+dWs"
"seq_length": 5, # Sequence length, only used in "GCFL", "GCFL+", "GCFL+dWs"
"epsilon1": 0.05, # Privacy epsilon1, specific to "GCFL", "GCFL+", "GCFL+dWs"
"epsilon2": 0.1, # Privacy epsilon2, specific to "GCFL", "GCFL+", "GCFL+dWs"
# Output configuration
"outbase": "./outputs",
"save_files": False,
# Scalability and Cluster Configuration
"use_cluster": False, # Use Kubernetes for scalability if True
}
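The same dictionary can drive any of the listed algorithms. As an illustrative variant (not executed in this example), switching to plain FedAvg only requires changing the "algorithm" key; the GCFL-specific entries ("standardize", "seq_length", "epsilon1", "epsilon2") and the FedProx-specific "mu" are then simply ignored:

# Illustrative variant only, not run here: reuse the configuration above for FedAvg.
fedavg_config = dict(config)
fedavg_config["algorithm"] = "FedAvg"  # GCFL- and FedProx-specific keys are ignored by FedAvg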
Run the fedgraph method
config = attridict(config)
run_fedgraph(config)
Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Processing...
Done!
Dataset name: MUTAG Total number of graphs: 188
using CPU
Done setting up devices.
Running GCFL+dWs ...
(Trainer pid=761) inx: 0
(Trainer pid=761) dataset_trainer_name: 0-MUTAG
(Trainer pid=761) dataloaders: {'train': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c11ba60>, 'val': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c19a770>, 'test': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c1aaa70>}
(Trainer pid=761) num_node_features: 7
(Trainer pid=761) num_graph_labels: 2
(Trainer pid=761) train_size: 74
> Training round 10 finished.
> Training round 20 finished.
> Training round 30 finished.
> Training round 40 finished.
> Training round 50 finished.
> Training round 60 finished.
> Training round 70 finished.
> Training round 80 finished.
> Training round 90 finished.
> Training round 100 finished.
> Training round 110 finished.
> Training round 120 finished.
> Training round 130 finished.
> Training round 140 finished.
> Training round 150 finished.
> Training round 160 finished.
> Training round 170 finished.
> Training round 180 finished.
> Training round 190 finished.
> Training round 200 finished.
test_acc
0-MUTAG 0.7
1-MUTAG 0.6
Average test accuracy: 0.6493333333333333
(Trainer pid=760) inx: 1
(Trainer pid=760) dataset_trainer_name: 1-MUTAG
(Trainer pid=760) dataloaders: {'train': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962baf7ca0>, 'val': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962b9888b0>, 'test': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962b9891b0>}
(Trainer pid=760) num_node_features: 7
(Trainer pid=760) num_graph_labels: 2
(Trainer pid=760) train_size: 76
Total running time of the script: (0 minutes 24.123 seconds)
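If a CUDA-capable GPU is available, the same script can be pointed at it by flipping the device-related keys before calling run_fedgraph. A sketch under that assumption (the exact string accepted for "device" may differ; "cuda" is assumed here):

# Hypothetical GPU variant, assuming a CUDA device is available and that
# "cuda" is an accepted value for the "device" key.
config["device"] = "cuda"
config["gpu"] = True
config["num_gpus_per_trainer"] = 1
run_fedgraph(config)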