.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/intro_LP.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_intro_LP.py:


Federated Link Prediction Example
=================================

In this tutorial, you will learn the basic workflow of federated link
prediction through a runnable example. The tutorial assumes basic
familiarity with PyTorch and PyTorch Geometric (PyG).

(Time estimate: 20 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 11-31

.. code-block:: Python


    import argparse
    import copy
    import os
    import random
    import sys
    from pathlib import Path

    import attridict
    import numpy as np
    import torch
    import yaml

    sys.path.append("../fedgraph")
    sys.path.append("../../")
    from fedgraph.federated_methods import LP_train_global_round
    from fedgraph.server_class import Server_LP
    from fedgraph.trainer_class import Trainer_LP
    from fedgraph.utils_lp import *

.. GENERATED FROM PYTHON SOURCE LINES 32-38

Load configuration and check arguments
--------------------------------------
Here we load the experiment configuration from a YAML file. It holds the
parameters of the experiment. In this file, the user specifies the algorithm
(``method``) and the datasets, which are identified by country codes
(``country_codes``). We then run a few checks to make sure the arguments
are valid.

.. GENERATED FROM PYTHON SOURCE LINES 38-55

.. code-block:: Python


    config_file = "configs/config_LP.yaml"
    with open(config_file, "r") as file:
        args = attridict(yaml.safe_load(file))

    print(args)

    global_file_path = os.path.join(args.dataset_path, "data_global.txt")
    traveled_file_path = os.path.join(args.dataset_path, "traveled_users.txt")

    assert args.method in ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"], "Invalid method."
    assert all(
        code in ["US", "BR", "ID", "TR", "JP"] for code in args.country_codes
    ), "The country codes should be in 'US', 'BR', 'ID', 'TR', 'JP'"
    if args.use_buffer:
        assert args.buffer_size > 0, "The buffer size should be greater than 0."


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'method': 'FedLink', 'country_codes': ['JP'], 'dataset_path': 'data/LPDataset', 'global_file_path': 'data/LPDataset/data_five_countries.txt', 'traveled_file_path': 'data/LPDataset/traveled_users.txt', 'device': 'cuda', 'use_buffer': False, 'buffer_size': 300000, 'online_learning': False, 'seed': 10, 'global_rounds': 20, 'local_steps': 3, 'repeat_time': 10, 'hidden_channels': 64, 'record_results': False}


.. GENERATED FROM PYTHON SOURCE LINES 56-62

Generate data
-------------
Here we prepare the data for the experiment. If the data files for the
selected countries already exist under ``args.dataset_path``, they are loaded
from disk. Otherwise, they are first downloaded and generated; the output
below shows such a download. We then build the global user and item ID
mappings and define ``meta_data``, which lists the node types (``user`` and
``item``) and the edge types (``select`` and its reverse, ``rev_select``) of
the user-item graph.

.. GENERATED FROM PYTHON SOURCE LINES 62-78

.. code-block:: Python


    check_data_files_existance(args.country_codes, args.dataset_path)

    (
        user_id_mapping,
        item_id_mapping,
    ) = get_global_user_item_mapping(  # get global user and item mapping
        global_file_path=global_file_path
    )

    meta_data = (
        ["user", "item"],
        [("user", "select", "item"), ("item", "rev_select", "user")],
    )  # set meta_data


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading traveled_users from https://drive.google.com/uc?id=1RUsyGrsz4hmY3OA3b-oqyh5yqlks02-p...
    Downloading...
    From: https://drive.google.com/uc?id=1RUsyGrsz4hmY3OA3b-oqyh5yqlks02-p
    To: /home/docs/checkouts/readthedocs.org/user_builds/fedgraph/checkouts/latest/docs/examples/data/LPDataset/traveled_users.txt
    0%| | 0.00/552k [00:00= 0.01: print("training is not complete") # go to next day ( start_time, end_time, start_time_float_format, end_time_float_format, ) = to_next_day(start_time=start_time, end_time=end_time, method=args.method) # delete the train and test data of each client client_id = number_of_clients - 1 if not args.use_buffer: del clients[client_id].train_data else: del clients[client_id].global_train_data del clients[client_id].test_data if result_writer is not None and time_writer is not None: result_writer.close() time_writer.close()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loading train_data and test_data
    start training
    0
    client 0 local update 0 loss 0.8362 train time 9.2834
    client 0 local update 1 loss 0.6853 train time 8.8130
    client 0 local update 2 loss 0.5779 train time 8.7460
    Test AUC: 0.7354
    Test Hit Rate at 2: 0.8592
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.7353901267051697 hit rate: 0.8592382073402405 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.7353901267051697 hit rate: 0.8592382073402405
    1
    client 0 local update 0 loss 0.4993 train time 8.7960
    client 0 local update 1 loss 0.4378 train time 8.7492
    client 0 local update 2 loss 0.3854 train time 8.7488
    Test AUC: 0.8016
    Test Hit Rate at 2: 0.8987
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8015587329864502 hit rate: 0.898746132850647 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8015587329864502 hit rate: 0.898746132850647
    2
    client 0 local update 0 loss 0.3393 train time 8.7772
    client 0 local update 1 loss 0.2998 train time 8.8556
    client 0 local update 2 loss 0.2689 train time 8.7808
    Test AUC: 0.8153
    Test Hit Rate at 2: 0.9064
    Test Traveled User Hit Rate at 2: 1.0000
    Day 0 client 0 auc score: 0.8152586817741394 hit rate: 0.9064348340034485 traveled user hit rate: 1.0
    Predict Day 20 average auc score: 0.8152586817741394 hit rate: 0.9064348340034485
    3
    client 0 local update 0 loss 0.2467 train time 8.7660
    client 0 local update 1 loss 0.2315 train time 8.7531
    client 0 local update 2 loss 0.2207 train time 8.7660
    Test AUC: 0.8183
    Test Hit Rate at 2: 0.9077
    Test Traveled User Hit Rate at 2: 1.0000
    Day 0 client 0 auc score: 0.8182907700538635 hit rate: 0.9077360033988953 traveled user hit rate: 1.0
    Predict Day 20 average auc score: 0.8182907700538635 hit rate: 0.9077360033988953
    4
    client 0 local update 0 loss 0.2115 train time 8.7582
    client 0 local update 1 loss 0.2026 train time 8.7586
    client 0 local update 2 loss 0.1938 train time 8.7626
    Test AUC: 0.8206
    Test Hit Rate at 2: 0.9082
    Test Traveled User Hit Rate at 2: 1.0000
    Day 0 client 0 auc score: 0.8205775022506714 hit rate: 0.9082091450691223 traveled user hit rate: 1.0
    Predict Day 20 average auc score: 0.8205775022506714 hit rate: 0.9082091450691223
    5
    client 0 local update 0 loss 0.1858 train time 8.7598
    client 0 local update 1 loss 0.1796 train time 8.7653
    client 0 local update 2 loss 0.1755 train time 8.7564
    Test AUC: 0.8228
    Test Hit Rate at 2: 0.9107
    Test Traveled User Hit Rate at 2: 1.0000
    Day 0 client 0 auc score: 0.822797417640686 hit rate: 0.9106931686401367 traveled user hit rate: 1.0
    Predict Day 20 average auc score: 0.822797417640686 hit rate: 0.9106931686401367
    6
    client 0 local update 0 loss 0.1724 train time 8.7678
    client 0 local update 1 loss 0.1688 train time 8.7840
    client 0 local update 2 loss 0.1639 train time 8.7483
    Test AUC: 0.8245
    Test Hit Rate at 2: 0.9131
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8245153427124023 hit rate: 0.9130589365959167 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8245153427124023 hit rate: 0.9130589365959167
    7
    client 0 local update 0 loss 0.1588 train time 8.7553
    client 0 local update 1 loss 0.1548 train time 8.7690
    client 0 local update 2 loss 0.1520 train time 8.7415
    Test AUC: 0.8257
    Test Hit Rate at 2: 0.9141
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8257431983947754 hit rate: 0.9141234755516052 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8257431983947754 hit rate: 0.9141234755516052
    8
    client 0 local update 0 loss 0.1496 train time 8.7609
    client 0 local update 1 loss 0.1468 train time 8.7721
    client 0 local update 2 loss 0.1434 train time 8.7582
    Test AUC: 0.8270
    Test Hit Rate at 2: 0.9148
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8269596099853516 hit rate: 0.9148331880569458 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8269596099853516 hit rate: 0.9148331880569458
    9
    client 0 local update 0 loss 0.1399 train time 8.7718
    client 0 local update 1 loss 0.1368 train time 8.7785
    client 0 local update 2 loss 0.1343 train time 8.7508
    Test AUC: 0.8281
    Test Hit Rate at 2: 0.9153
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8281118869781494 hit rate: 0.9153063893318176 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8281118869781494 hit rate: 0.9153063893318176
    10
    client 0 local update 0 loss 0.1318 train time 8.7986
    client 0 local update 1 loss 0.1287 train time 8.7691
    client 0 local update 2 loss 0.1253 train time 8.7561
    Test AUC: 0.8291
    Test Hit Rate at 2: 0.9161
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8291079998016357 hit rate: 0.9161343574523926 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8291079998016357 hit rate: 0.9161343574523926
    11
    client 0 local update 0 loss 0.1221 train time 8.7978
    client 0 local update 1 loss 0.1193 train time 8.7577
    client 0 local update 2 loss 0.1167 train time 8.7848
    Test AUC: 0.8300
    Test Hit Rate at 2: 0.9157
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8300279378890991 hit rate: 0.9156612157821655 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8300279378890991 hit rate: 0.9156612157821655
    12
    client 0 local update 0 loss 0.1140 train time 8.7992
    client 0 local update 1 loss 0.1112 train time 8.7513
    client 0 local update 2 loss 0.1084 train time 8.7827
    Test AUC: 0.8309
    Test Hit Rate at 2: 0.9159
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.830906867980957 hit rate: 0.915897786617279 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.830906867980957 hit rate: 0.915897786617279
    13
    client 0 local update 0 loss 0.1059 train time 8.7906
    client 0 local update 1 loss 0.1036 train time 8.7490
    client 0 local update 2 loss 0.1014 train time 8.7772
    Test AUC: 0.8316
    Test Hit Rate at 2: 0.9164
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8316266536712646 hit rate: 0.9163709282875061 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8316266536712646 hit rate: 0.9163709282875061
    14
    client 0 local update 0 loss 0.0991 train time 8.7816
    client 0 local update 1 loss 0.0968 train time 8.7995
    client 0 local update 2 loss 0.0946 train time 8.8105
    Test AUC: 0.8322
    Test Hit Rate at 2: 0.9160
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8322358131408691 hit rate: 0.9160161018371582 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8322358131408691 hit rate: 0.9160161018371582
    15
    client 0 local update 0 loss 0.0926 train time 8.7885
    client 0 local update 1 loss 0.0907 train time 8.8183
    client 0 local update 2 loss 0.0888 train time 8.7826
    Test AUC: 0.8328
    Test Hit Rate at 2: 0.9168
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8327579498291016 hit rate: 0.9168440699577332 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8327579498291016 hit rate: 0.9168440699577332
    16
    client 0 local update 0 loss 0.0868 train time 8.8050
    client 0 local update 1 loss 0.0849 train time 8.8081
    client 0 local update 2 loss 0.0831 train time 8.7780
    Test AUC: 0.8332
    Test Hit Rate at 2: 0.9178
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8331815004348755 hit rate: 0.917790412902832 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8331815004348755 hit rate: 0.917790412902832
    17
    client 0 local update 0 loss 0.0814 train time 8.7943
    client 0 local update 1 loss 0.0797 train time 8.7707
    client 0 local update 2 loss 0.0780 train time 8.7316
    Test AUC: 0.8336
    Test Hit Rate at 2: 0.9177
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8335594534873962 hit rate: 0.9176720976829529 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8335594534873962 hit rate: 0.9176720976829529
    18
    client 0 local update 0 loss 0.0763 train time 8.7535
    client 0 local update 1 loss 0.0747 train time 8.7700
    client 0 local update 2 loss 0.0732 train time 8.7396
    Test AUC: 0.8339
    Test Hit Rate at 2: 0.9173
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8339011669158936 hit rate: 0.917317271232605 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8339011669158936 hit rate: 0.917317271232605
    19
    client 0 local update 0 loss 0.0717 train time 8.7922
    client 0 local update 1 loss 0.0702 train time 8.7827
    client 0 local update 2 loss 0.0687 train time 8.7516
    Test AUC: 0.8341
    Test Hit Rate at 2: 0.9173
    Test Traveled User Hit Rate at 2: 0.8571
    Day 0 client 0 auc score: 0.8341457843780518 hit rate: 0.917317271232605 traveled user hit rate: 0.8571428656578064
    Predict Day 20 average auc score: 0.8341457843780518 hit rate: 0.917317271232605
    training is not complete




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (10 minutes 32.792 seconds)


.. _sphx_glr_download_tutorials_intro_LP.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: intro_LP.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: intro_LP.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
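
The checks in *Load configuration and check arguments* stop at the first failed ``assert``. Another option is to collect them in one function that reports every problem at once. The following is a minimal sketch of that option. ``validate_lp_config`` is a hypothetical helper, not part of fedgraph. It takes the plain dictionary returned by ``yaml.safe_load`` and reuses the method names and country codes from the tutorial's assertions.

.. code-block:: Python

    from typing import Any, Dict, List

    # Allowed values, copied from the assertions in the tutorial.
    VALID_METHODS = ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"]
    VALID_COUNTRY_CODES = ["US", "BR", "ID", "TR", "JP"]


    def validate_lp_config(config: Dict[str, Any]) -> List[str]:
        """Return every problem found in ``config``; an empty list means it is valid.

        Hypothetical helper (not part of fedgraph): it repeats the tutorial's
        checks but collects all errors instead of stopping at the first one.
        """
        errors: List[str] = []
        if config.get("method") not in VALID_METHODS:
            errors.append(f"Invalid method: {config.get('method')!r}")
        bad_codes = [
            code
            for code in config.get("country_codes", [])
            if code not in VALID_COUNTRY_CODES
        ]
        if bad_codes:
            errors.append(f"Invalid country codes: {bad_codes}")
        # As in the tutorial, buffer_size is only checked when the buffer is enabled.
        if config.get("use_buffer") and config.get("buffer_size", 0) <= 0:
            errors.append("buffer_size must be greater than 0 when use_buffer is True")
        return errors


    # Values taken from the configuration printed by the tutorial.
    good_config = {"method": "FedLink", "country_codes": ["JP"], "use_buffer": False, "buffer_size": 300000}
    # A deliberately broken configuration, made up for illustration.
    bad_config = {"method": "GCN", "country_codes": ["FR"], "use_buffer": True, "buffer_size": 0}

    print(validate_lp_config(good_config))  # []
    for problem in validate_lp_config(bad_config):
        print(problem)

Python skips ``assert`` statements when it runs with the ``-O`` flag. Explicit checks like these still run in that mode.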
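
In *Generate data*, ``get_global_user_item_mapping`` builds the global user and item ID mappings, and ``meta_data`` describes the graph as a ``(node_types, edge_types)`` tuple. The sketch below illustrates the general idea in plain Python. It assumes the interactions are already available as ``(user, item)`` pairs. ``build_id_mapping`` and ``to_edge_index`` are hypothetical helpers and the toy data is made up. fedgraph reads its interactions from ``global_file_path``, and its actual mapping logic may differ. The variable names mirror the tutorial's for readability only.

.. code-block:: Python

    from typing import Dict, Hashable, Iterable, List, Tuple


    def build_id_mapping(raw_ids: Iterable[Hashable]) -> Dict[Hashable, int]:
        """Map each distinct raw ID to a contiguous integer index, in first-seen order."""
        mapping: Dict[Hashable, int] = {}
        for raw_id in raw_ids:
            if raw_id not in mapping:
                mapping[raw_id] = len(mapping)
        return mapping


    def to_edge_index(
        pairs: List[Tuple[Hashable, Hashable]],
        user_map: Dict[Hashable, int],
        item_map: Dict[Hashable, int],
    ) -> Tuple[List[int], List[int]]:
        """Turn (user, item) pairs into the source and target rows of a COO edge index."""
        src = [user_map[user] for user, _ in pairs]
        dst = [item_map[item] for _, item in pairs]
        return src, dst


    # Toy interactions, made up for illustration: user "u2" selects two items.
    pairs = [("u1", "cafe"), ("u2", "museum"), ("u2", "cafe"), ("u3", "park")]
    user_id_mapping = build_id_mapping(user for user, _ in pairs)
    item_id_mapping = build_id_mapping(item for _, item in pairs)

    # Same (node types, edge types) structure as ``meta_data`` in the tutorial.
    meta_data = (
        ["user", "item"],
        [("user", "select", "item"), ("item", "rev_select", "user")],
    )

    select_edges = to_edge_index(pairs, user_id_mapping, item_id_mapping)
    # The reverse relation ("item", "rev_select", "user") swaps source and target.
    rev_select_edges = (select_edges[1], select_edges[0])

    print(user_id_mapping)  # {'u1': 0, 'u2': 1, 'u3': 2}
    print(item_id_mapping)  # {'cafe': 0, 'museum': 1, 'park': 2}
    print(select_edges)  # ([0, 1, 1, 2], [0, 1, 0, 2])

In PyTorch Geometric, the two index lists would usually be stacked into a ``torch.long`` tensor of shape ``[2, num_edges]`` and assigned to ``data["user", "select", "item"].edge_index`` on a ``HeteroData`` object. ``HeteroData.metadata()`` returns a tuple in the same format as ``meta_data``.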
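
The training log reports the test AUC and the hit rate at 2 after every global round. The sketch below shows one common way to compute both metrics. It works from model scores for observed (positive) user-item edges and for sampled non-edges (negatives). It is plain Python for illustration, and the function names and scores are hypothetical. fedgraph's evaluation code may define the hit rate differently, for example in how it chooses negative candidates.

.. code-block:: Python

    from typing import Sequence


    def auc_score(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
        """ROC AUC: probability that a random positive edge outscores a random negative one.

        Ties count as half a win. This O(P * N) loop is for illustration only.
        """
        wins = 0.0
        for p in pos_scores:
            for n in neg_scores:
                if p > n:
                    wins += 1.0
                elif p == n:
                    wins += 0.5
        return wins / (len(pos_scores) * len(neg_scores))


    def hit_rate_at_k(
        pos_scores: Sequence[float],
        neg_scores_per_pos: Sequence[Sequence[float]],
        k: int = 2,
    ) -> float:
        """Fraction of positive edges that rank in the top k among their own candidates.

        Assumed protocol: each positive edge is ranked against its own list of
        negative candidate scores and counts as a hit if its 1-based rank is <= k.
        """
        hits = 0
        for p, negs in zip(pos_scores, neg_scores_per_pos):
            rank = 1 + sum(1 for n in negs if n > p)  # negatives scored strictly higher
            if rank <= k:
                hits += 1
        return hits / len(pos_scores)


    # Made-up scores for three observed (positive) user-item edges.
    pos = [0.9, 0.4, 0.7]
    neg = [0.8, 0.3, 0.5]  # scores of sampled non-edges, shared for the AUC
    neg_per_pos = [[0.8, 0.95, 0.1], [0.6, 0.5, 0.45], [0.2, 0.3, 0.1]]

    print(auc_score(pos, neg))  # 6 of 9 pairs ordered correctly
    print(hit_rate_at_k(pos, neg_per_pos, k=2))  # 2 of 3 positives in the top 2

For larger evaluations, ``sklearn.metrics.roc_auc_score(y_true, y_score)`` computes the same AUC from binary labels and scores without the quadratic loop.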