Federated Link Prediction Example
Federated Link Prediction with STFL on the Link Prediction dataset.
(Time estimate: 3 minutes)
Load libraries
import os
import attridict
from fedgraph.federated_methods import run_fedgraph
Specify the Link Prediction configuration
# Parent of the current working directory
BASE_DIR = os.path.dirname(os.path.abspath("."))
DATASET_PATH = os.path.join(
    BASE_DIR, "data", "LPDataset"
)  # Can be changed to suit your setup
config = {
    "fedgraph_task": "LP",
    # method = ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"]
    "method": "STFL",
    # Dataset configuration
    # country_codes = ['US', 'BR', 'ID', 'TR', 'JP']
    "country_codes": ["JP"],
    "dataset_path": DATASET_PATH,
    # Setup configuration
    "device": "cpu",
    "use_buffer": False,
    "buffer_size": 300000,
    "online_learning": False,
    "seed": 10,
    # Model parameters
    "global_rounds": 8,
    "local_steps": 3,
    "hidden_channels": 64,
    # Output configuration
    "record_results": False,
    # System configuration
    "gpu": False,
    "num_cpus_per_trainer": 1,
    "num_gpus_per_trainer": 0,
    "use_cluster": False,  # whether to use Kubernetes for scalability
    "distribution_type": "average",  # how node counts are distributed among clients
    "batch_size": -1,  # -1 means full batch
}
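The comments in the configuration list the available methods and country codes. As a minimal sketch, a small check like the one below could catch a typo before a run starts. The helper `check_lp_config` is hypothetical and not part of fedgraph, and treating the commented lists as the complete set of accepted values is an assumption.

```python
# Options copied from the comments in the configuration above.
# Assumption: these lists are the full set of values fedgraph accepts.
KNOWN_METHODS = ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"]
KNOWN_COUNTRY_CODES = ["US", "BR", "ID", "TR", "JP"]


def check_lp_config(cfg):
    """Hypothetical helper (not a fedgraph API): return a list of problems."""
    problems = []
    if cfg.get("method") not in KNOWN_METHODS:
        problems.append(f"unknown method: {cfg.get('method')!r}")
    codes = cfg.get("country_codes", [])
    if not codes:
        problems.append("country_codes is empty")
    for code in codes:
        if code not in KNOWN_COUNTRY_CODES:
            problems.append(f"unknown country code: {code!r}")
    return problems


example = {"method": "STFL", "country_codes": ["JP"]}
print(check_lp_config(example))
```

An empty list means the two fields passed the check. The function takes a plain dictionary, so it can be called before the configuration is wrapped for the run.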
Run the fedgraph method
config = attridict(config)
run_fedgraph(config)
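Wrapping the dictionary presumably lets the library read settings as attributes (for example `config.method`) rather than by key. The sketch below illustrates that idea with a small stand-in class. `AttrView` is hypothetical, written only for this illustration, and is not how `attridict` is implemented.

```python
class AttrView(dict):
    """Hypothetical stand-in: a dict whose keys can also be read as attributes."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so dict methods still work.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


cfg = AttrView({"method": "STFL", "global_rounds": 8, "local_steps": 3})
print(cfg.method)
# Total number of local training steps in the run
print(cfg.global_rounds * cfg.local_steps)
```

Key access such as `cfg["method"]` keeps working, so code written against a plain dictionary is unaffected.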
/home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset not exists, creating directory
Downloading traveled_users from https://drive.google.com/uc?id=1RUsyGrsz4hmY3OA3b-oqyh5yqlks02-p...
Downloading...
From: https://drive.google.com/uc?id=1RUsyGrsz4hmY3OA3b-oqyh5yqlks02-p
To: /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/traveled_users.txt
100%|██████████| 552k/552k [00:00<00:00, 11.4MB/s]
Downloaded traveled_users to /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/traveled_users.txt
Downloading data_global from https://drive.google.com/uc?id=1CnBlVXqCbfjSswagTci5D7nAqO7laU_J...
Downloading...
From (original): https://drive.google.com/uc?id=1CnBlVXqCbfjSswagTci5D7nAqO7laU_J
From (redirected): https://drive.google.com/uc?id=1CnBlVXqCbfjSswagTci5D7nAqO7laU_J&confirm=t&uuid=c02dfd72-126f-449c-b744-09538a11dc23
To: /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/data_global.txt
100%|██████████| 278M/278M [00:02<00:00, 127MB/s]
Downloaded data_global to /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/data_global.txt
Downloading data_JP from https://drive.google.com/uc?id=1IPBW4dRYk52x8TahfBqFOh3GdxoYafJ2...
Downloading...
From: https://drive.google.com/uc?id=1IPBW4dRYk52x8TahfBqFOh3GdxoYafJ2
To: /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/data_JP.txt
100%|██████████| 28.7M/28.7M [00:00<00:00, 62.5MB/s]
Downloaded data_JP to /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/data_JP.txt
gpu not detected
start training
global rounds: 0
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) checking code and file path: JP,/home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset
(Trainer pid=1051) printing in getdata, path: /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset
(Trainer pid=1051) Loading data in /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/stable/data/LPDataset/data_JP.txt
(Trainer pid=1051) Device: 'cpu'
(Trainer pid=1051) loading train_data and test_data
(Trainer pid=1051) client 0 local steps 0 loss 0.7621 train time 6.3499
(Trainer pid=1051) client 0 local steps 1 loss 0.6572 train time 6.2429
clientId: 0 current_loss: 0.5709986090660095 train_finish_times: [6.349920749664307, 6.242884635925293, 6.108320713043213]
(Trainer pid=1051) client 0 local steps 2 loss 0.5710 train time 6.1083
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.7278755307197571 hit rate: 0.8611308336257935 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.7278755307197571 hit rate: 0.8611308336257935 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.7278755307197571 hit rate: 0.8611308336257935
global rounds: 1
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.7279
(Trainer pid=1051) Test Hit Rate at 2: 0.8611
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.4961 train time 6.2524
(Trainer pid=1051) client 0 local steps 1 loss 0.4308 train time 6.0702
clientId: 0 current_loss: 0.37478575110435486 train_finish_times: [6.252429962158203, 6.070228099822998, 6.064333200454712]
(Trainer pid=1051) client 0 local steps 2 loss 0.3748 train time 6.0643
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.808786153793335 hit rate: 0.9076176881790161 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.808786153793335 hit rate: 0.9076176881790161 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.808786153793335 hit rate: 0.9076176881790161
global rounds: 2
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8088
(Trainer pid=1051) Test Hit Rate at 2: 0.9076
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.3275 train time 6.1146
(Trainer pid=1051) client 0 local steps 1 loss 0.2882 train time 6.1155
clientId: 0 current_loss: 0.25669851899147034 train_finish_times: [6.1146180629730225, 6.115481615066528, 6.106277227401733]
(Trainer pid=1051) client 0 local steps 2 loss 0.2567 train time 6.1063
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8187956213951111 hit rate: 0.9138869047164917 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8187956213951111 hit rate: 0.9138869047164917 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8187956213951111 hit rate: 0.9138869047164917
global rounds: 3
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8188
(Trainer pid=1051) Test Hit Rate at 2: 0.9139
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.2326 train time 6.2146
(Trainer pid=1051) client 0 local steps 1 loss 0.2150 train time 6.0485
clientId: 0 current_loss: 0.2025657594203949 train_finish_times: [6.214552879333496, 6.048479080200195, 6.112559795379639]
(Trainer pid=1051) client 0 local steps 2 loss 0.2026 train time 6.1126
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8214533925056458 hit rate: 0.914951503276825 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8214533925056458 hit rate: 0.914951503276825 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8214533925056458 hit rate: 0.914951503276825
global rounds: 4
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8215
(Trainer pid=1051) Test Hit Rate at 2: 0.9150
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.1937 train time 6.1005
(Trainer pid=1051) client 0 local steps 1 loss 0.1871 train time 6.1238
clientId: 0 current_loss: 0.18176108598709106 train_finish_times: [6.100545883178711, 6.12384819984436, 6.105283975601196]
(Trainer pid=1051) client 0 local steps 2 loss 0.1818 train time 6.1053
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8232846856117249 hit rate: 0.9168440699577332 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8232846856117249 hit rate: 0.9168440699577332 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8232846856117249 hit rate: 0.9168440699577332
global rounds: 5
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8233
(Trainer pid=1051) Test Hit Rate at 2: 0.9168
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.1770 train time 6.1955
(Trainer pid=1051) client 0 local steps 1 loss 0.1726 train time 6.1975
clientId: 0 current_loss: 0.16818910837173462 train_finish_times: [6.195451259613037, 6.197492361068726, 6.185129165649414]
(Trainer pid=1051) client 0 local steps 2 loss 0.1682 train time 6.1851
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8249557018280029 hit rate: 0.9183818101882935 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8249557018280029 hit rate: 0.9183818101882935 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8249557018280029 hit rate: 0.9183818101882935
global rounds: 6
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8250
(Trainer pid=1051) Test Hit Rate at 2: 0.9184
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.1639 train time 6.1090
(Trainer pid=1051) client 0 local steps 1 loss 0.1597 train time 6.1722
clientId: 0 current_loss: 0.15554769337177277 train_finish_times: [6.109033584594727, 6.172162771224976, 6.105194807052612]
(Trainer pid=1051) client 0 local steps 2 loss 0.1555 train time 6.1052
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8264950513839722 hit rate: 0.918618381023407 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8264950513839722 hit rate: 0.918618381023407 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8264950513839722 hit rate: 0.918618381023407
global rounds: 7
Training in LP_train_global_round, number of clients: 1
(Trainer pid=1051) Test AUC: 0.8265
(Trainer pid=1051) Test Hit Rate at 2: 0.9186
(Trainer pid=1051) Test Traveled User Hit Rate at 2: 1.0000
(Trainer pid=1051) client 0 local steps 0 loss 0.1516 train time 6.2038
(Trainer pid=1051) client 0 local steps 1 loss 0.1479 train time 6.0647
clientId: 0 current_loss: 0.14428174495697021 train_finish_times: [6.203776121139526, 6.064650058746338, 6.149558067321777]
(Trainer pid=1051) client 0 local steps 2 loss 0.1443 train time 6.1496
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) auc score: 0.8278848528862 hit rate: 0.9189732670783997 traveled user hit rate: 1.0
Day 0 client Actor(run_LP.<locals>.setup_trainer_server.<locals>.Trainer, 8dfffc5d4b541b2dbcc5c18601000000) final auc score: 0.8278848528862 hit rate: 0.9189732670783997 traveled user hit rate: 1.0
Predict Day 20 average auc score: 0.8278848528862 hit rate: 0.9189732670783997
training is not complete
The whole process has ended
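Each global round in the log above ends with a "Predict Day 20 average auc score" line. A minimal sketch for pulling those numbers out of captured output follows. The function name `parse_round_metrics` and the regular expression are illustrative assumptions based only on the line format shown here, not on any fedgraph logging API.

```python
import re

# Sample lines copied from the log above (global rounds 0 and 7).
LOG = """\
Predict Day 20 average auc score: 0.7278755307197571 hit rate: 0.8611308336257935
Predict Day 20 average auc score: 0.8278848528862 hit rate: 0.9189732670783997
"""

# Assumption: the summary line keeps exactly this wording.
PATTERN = re.compile(
    r"Predict Day (\d+) average auc score: ([\d.]+) hit rate: ([\d.]+)"
)


def parse_round_metrics(text):
    """Return one (auc, hit_rate) tuple per matching log line, in order."""
    return [(float(m.group(2)), float(m.group(3))) for m in PATTERN.finditer(text)]


metrics = parse_round_metrics(LOG)
first_auc, last_auc = metrics[0][0], metrics[-1][0]
print(f"AUC change from first to last round: {last_auc - first_auc:.4f}")
```

Feeding the full captured output to the same function would give one tuple per global round, which is enough to plot AUC and hit rate against the round index.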
Total running time of the script: (3 minutes 40.697 seconds)
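To see how much of the total running time is local training, the per-step train times reported in the log can be summed and compared with the total. The sketch below copies the four-decimal values from the "train time" lines. What the remaining time is spent on (for example downloads, setup, and evaluation) is not broken down in the log, so that part is left as a single number.

```python
# Per-step train times in seconds, copied from the log: one row per global round.
TRAIN_TIMES = [
    [6.3499, 6.2429, 6.1083],
    [6.2524, 6.0702, 6.0643],
    [6.1146, 6.1155, 6.1063],
    [6.2146, 6.0485, 6.1126],
    [6.1005, 6.1238, 6.1053],
    [6.1955, 6.1975, 6.1851],
    [6.1090, 6.1722, 6.1052],
    [6.2038, 6.0647, 6.1496],
]
# Total running time of the script: 3 minutes 40.697 seconds
TOTAL_SCRIPT_SECONDS = 3 * 60 + 40.697

steps = sum(len(round_times) for round_times in TRAIN_TIMES)
training_seconds = sum(sum(round_times) for round_times in TRAIN_TIMES)
other_seconds = TOTAL_SCRIPT_SECONDS - training_seconds

print(f"local steps: {steps}")
print(f"training time: {training_seconds:.1f} s")
print(f"everything else: {other_seconds:.1f} s")
```

With 8 global rounds of 3 local steps each, the 24 steps account for roughly 147.5 seconds of the 220.7-second run.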