Federated Graph Classification Example

Federated Graph Classification with GCFL+dWs on the MUTAG dataset.

(Time estimate: 3 minutes)

Load libraries

import attridict

from fedgraph.federated_methods import run_fedgraph

Specify the Graph Classification configuration

config = {
    "fedgraph_task": "GC",
    # General configuration
    # algorithm options: "SelfTrain", "FedAvg", "FedProx", "GCFL", "GCFL+", "GCFL+dWs"
    "algorithm": "GCFL+dWs",
    # Dataset configuration
    "dataset": "MUTAG",
    "is_multiple_dataset": False,
    "datapath": "./data",
    "convert_x": False,
    "overlap": False,
    # Setup configuration
    "device": "cpu",
    "seed": 10,
    "seed_split_data": 42,
    # Model parameters
    "num_trainers": 2,
    "num_rounds": 200,  # Number of federated rounds; used by all algorithms except "SelfTrain"
    "local_epoch": 1,  # Local epochs per round; used by all algorithms except "SelfTrain"
    # Specific to "SelfTrain" (used instead of "num_rounds" and "local_epoch")
    "local_epoch_selftrain": 200,
    "lr": 0.001,
    "weight_decay": 0.0005,
    "nlayer": 3,  # Number of model layers
    "hidden": 64,  # Hidden layer dimension
    "dropout": 0.5,  # Dropout rate
    "batch_size": 128,
    "gpu": False,
    "num_cpus_per_trainer": 1,
    "num_gpus_per_trainer": 0,
    # FedProx specific parameter
    "mu": 0.01,  # Regularization parameter, only used in "FedProx"
    # GCFL specific parameters
    "standardize": False,  # Used only in "GCFL", "GCFL+", "GCFL+dWs"
    "seq_length": 5,  # Sequence length, only used in "GCFL", "GCFL+", "GCFL+dWs"
    "epsilon1": 0.05,  # Cluster-split threshold on the norm of the mean update; used only in "GCFL", "GCFL+", "GCFL+dWs"
    "epsilon2": 0.1,  # Cluster-split threshold on the largest per-trainer update norm; used only in "GCFL", "GCFL+", "GCFL+dWs"
    # Output configuration
    "outbase": "./outputs",
    "save_files": False,
    # Scalability and Cluster Configuration
    "use_cluster": False,  # Use Kubernetes for scalability if True
}
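In the original GCFL paper (Xie et al., 2021), `epsilon1` and `epsilon2` decide when a cluster of trainers is split: a split happens when the norm of the cluster's averaged update is small (the cluster as a whole is near a stationary point) while at least one trainer's individual update norm is still large (some trainers disagree). The sketch below illustrates that rule with NumPy. The function `should_split_cluster` is hypothetical and not part of the fedgraph API; fedgraph may add further conditions (for example a minimum cluster size or warm-up rounds), and GCFL+/GCFL+dWs additionally use sequences of length `seq_length` to decide how trainers are regrouped.

```python
import numpy as np


def should_split_cluster(updates, epsilon1=0.05, epsilon2=0.1):
    """Hypothetical GCFL-style split test (not the fedgraph API).

    updates: list of 1-D arrays, one flattened weight update per trainer
    in the cluster.
    """
    stacked = np.stack(updates)
    # Norm of the averaged update: small when the cluster is near a stationary point.
    mean_norm = np.linalg.norm(stacked.mean(axis=0))
    # Largest individual update norm: large when some trainer still disagrees.
    max_norm = np.linalg.norm(stacked, axis=1).max()
    return bool(mean_norm < epsilon1 and max_norm > epsilon2)


# Two trainers pulling in opposite directions: the average cancels out,
# but each individual update is large, so the cluster should be split.
conflicting = [np.array([0.3, 0.0]), np.array([-0.3, 0.0])]
print(should_split_cluster(conflicting))  # True

# Two trainers that agree: the average update is large, so no split.
agreeing = [np.array([0.3, 0.0]), np.array([0.3, 0.0])]
print(should_split_cluster(agreeing))  # False
```

Lowering `epsilon1` or raising `epsilon2` makes splits rarer, so trainers stay in larger shared clusters for longer.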

Run the fedgraph method
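The configuration stays a plain dictionary until this step, where `attridict` wraps it so its keys can also be read as attributes (for example `config.algorithm`); presumably `run_fedgraph` relies on that attribute-style access. Keeping the base dictionary plain also makes it easy to compare algorithms by changing a single key. The sketch below uses a hypothetical stand-in class, `AttrConfig`, so it runs without fedgraph or attridict installed; with fedgraph available you would call `run_fedgraph(attridict(cfg))` for each variant instead of printing.

```python
import copy


class AttrConfig(dict):
    """Hypothetical stand-in for attridict: a dict whose keys are also attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


base_config = {"fedgraph_task": "GC", "algorithm": "GCFL+dWs", "dataset": "MUTAG"}

configs = []
for algorithm in ["SelfTrain", "FedAvg", "GCFL+dWs"]:
    cfg = copy.deepcopy(base_config)  # keep the base dict untouched
    cfg["algorithm"] = algorithm
    configs.append(AttrConfig(cfg))
    # With fedgraph installed: run_fedgraph(attridict(cfg))

for cfg in configs:
    print(cfg.algorithm, cfg.dataset)
```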

config = attridict(config)
run_fedgraph(config)

Out:

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Processing...
Done!
Dataset name:  MUTAG  Total number of graphs:  188
using CPU

Done setting up devices.
Running GCFL+dWs ...
(Trainer pid=761) inx: 0
(Trainer pid=761) dataset_trainer_name: 0-MUTAG
(Trainer pid=761) dataloaders: {'train': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c11ba60>, 'val': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c19a770>, 'test': <torch_geometric.loader.dataloader.DataLoader object at 0x7f771c1aaa70>}
(Trainer pid=761) num_node_features: 7
(Trainer pid=761) num_graph_labels: 2
(Trainer pid=761) train_size: 74
  > Training round 10 finished.
  > Training round 20 finished.
  > Training round 30 finished.
  > Training round 40 finished.
  > Training round 50 finished.
  > Training round 60 finished.
  > Training round 70 finished.
  > Training round 80 finished.
  > Training round 90 finished.
  > Training round 100 finished.
  > Training round 110 finished.
  > Training round 120 finished.
  > Training round 130 finished.
  > Training round 140 finished.
  > Training round 150 finished.
  > Training round 160 finished.
  > Training round 170 finished.
  > Training round 180 finished.
  > Training round 190 finished.
  > Training round 200 finished.
         test_acc
0-MUTAG       0.7
1-MUTAG       0.6
Average test accuracy: 0.6493333333333333
(Trainer pid=760) inx: 1
(Trainer pid=760) dataset_trainer_name: 1-MUTAG
(Trainer pid=760) dataloaders: {'train': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962baf7ca0>, 'val': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962b9888b0>, 'test': <torch_geometric.loader.dataloader.DataLoader object at 0x7f962b9891b0>}
(Trainer pid=760) num_node_features: 7
(Trainer pid=760) num_graph_labels: 2
(Trainer pid=760) train_size: 76
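The log reports each trainer's `train_size` (74 and 76 graphs). In FedAvg-style aggregation, trainers' parameters are combined with weights proportional to their training-set sizes; as far as we know, the original GCFL implementation applies the same size weighting within each cluster, though we have not verified fedgraph's internals. The sketch below shows that weighting with small NumPy arrays standing in for model parameters; `weighted_average` is a hypothetical helper, not a fedgraph function.

```python
import numpy as np


def weighted_average(params, sizes):
    """Hypothetical helper: average parameter arrays weighted by training-set size."""
    weights = np.asarray(sizes, dtype=float)
    weights /= weights.sum()  # normalize so the weights sum to 1
    stacked = np.stack(params)
    # Contract the trainer axis: sum_i w_i * params_i
    return np.tensordot(weights, stacked, axes=1)


# Toy "parameters" for the two trainers in the log above.
params = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
sizes = [74, 76]  # train_size values reported by the trainers
print(weighted_average(params, sizes))  # [0.49333333 0.50666667]
```

With nearly equal training sets, as here, the result is close to a plain mean; the weighting matters more when trainers hold very different amounts of data.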

Total running time of the script: (0 minutes 24.123 seconds)
