import argparse
import copy
import datetime
import os
import pickle
import random
import socket
import sys
import time
from importlib.resources import files
from pathlib import Path
from typing import Any, Dict, List, Optional
import attridict
import numpy as np
import pandas as pd
import ray
import tenseal as ts
import torch
from fedgraph.data_process import data_loader
from fedgraph.gnn_models import GIN
from fedgraph.monitor_class import Monitor
from fedgraph.server_class import Server, Server_GC, Server_LP
from fedgraph.train_func import gc_avg_accuracy
from fedgraph.trainer_class import Trainer_GC, Trainer_General, Trainer_LP
from fedgraph.utils_gc import setup_server, setup_trainers
from fedgraph.utils_lp import (
check_data_files_existance,
get_global_user_item_mapping,
get_start_end_time,
to_next_day,
)
from fedgraph.utils_nc import get_1hop_feature_sum, save_all_trainers_data
try:
from .differential_privacy import Server_DP, Trainer_General_DP
DP_AVAILABLE = True
print("✓ Differential Privacy support loaded")
except ImportError:
DP_AVAILABLE = False
print("⚠️ Differential Privacy not available")
try:
from .low_rank import Server_LowRank, Trainer_General_LowRank
LOWRANK_AVAILABLE = True
except ImportError:
LOWRANK_AVAILABLE = False
[docs]
def run_fedgraph(args: attridict) -> None:
"""
Run the training process for the specified task.
This is the function for running different federated graph learning tasks,
including Node Classification (NC), Graph Classification (GC), and Link Prediction (LP)
in the following functions.
Parameters
----------
args : attridict
Configuration arguments that must include 'fedgraph_task' key with value
in ['NC', 'GC', 'LP'].
data: Any
Input data for the federated learning task. Format depends on the specific task and
will be explained in more detail below inside specific functions.
""" # Validate configuration for low-rank compression
if hasattr(args, "use_lowrank") and args.use_lowrank:
if args.fedgraph_task != "NC":
raise ValueError(
"Low-rank compression currently only supported for NC tasks"
)
if args.method != "FedAvg":
raise ValueError(
"Low-rank compression currently only supported for FedAvg method"
)
if args.use_encryption:
raise ValueError(
"Cannot use both encryption and low-rank compression simultaneously"
)
# Load data
if args.fedgraph_task != "NC" or not args.use_huggingface:
data = data_loader(args)
else:
data = None
if args.fedgraph_task == "NC":
if hasattr(args, "use_lowrank") and args.use_lowrank:
run_NC_lowrank(args, data)
else:
run_NC(args, data)
elif args.fedgraph_task == "GC":
run_GC(args, data)
elif args.fedgraph_task == "LP":
run_LP(args)
[docs]
def run_fedgraph_enhanced(args: attridict) -> None:
"""
Enhanced run function with support for HE, DP, and Low-Rank compression.
"""
# Validate mutually exclusive privacy options
privacy_options = [
getattr(args, "use_encryption", False),
getattr(args, "use_dp", False),
getattr(args, "use_lowrank", False),
]
privacy_count = sum(privacy_options)
if privacy_count > 1:
privacy_names = []
if getattr(args, "use_encryption", False):
privacy_names.append("Homomorphic Encryption")
if getattr(args, "use_dp", False):
privacy_names.append("Differential Privacy")
if getattr(args, "use_lowrank", False):
privacy_names.append("Low-Rank Compression")
raise ValueError(
f"Cannot use multiple privacy/compression methods simultaneously: {', '.join(privacy_names)}"
)
# Print selected method
if getattr(args, "use_encryption", False):
print("=== Using Homomorphic Encryption ===")
elif getattr(args, "use_dp", False):
print("=== Using Differential Privacy ===")
print(
f"DP parameters: ε={getattr(args, 'dp_epsilon', 1.0)}, δ={getattr(args, 'dp_delta', 1e-5)}"
)
elif getattr(args, "use_lowrank", False):
print("=== Using Low-Rank Compression ===")
else:
print("=== Using Standard FedGraph ===")
# Load data
if args.fedgraph_task != "NC" or not args.use_huggingface:
data = data_loader(args)
else:
data = None
# Route to appropriate implementation
if args.fedgraph_task == "NC":
if getattr(args, "use_dp", False):
run_NC_dp(args, data)
elif getattr(args, "use_lowrank", False):
run_NC_lowrank(args, data)
else:
run_NC(args, data) # Original with HE support
elif args.fedgraph_task == "GC":
run_GC(args, data)
elif args.fedgraph_task == "LP":
run_LP(args)
[docs]
def run_NC(args: attridict, data: Any = None) -> None:
"""
Train a Federated Graph Classification model using multiple trainers.
Implements FL for node classification tasks with support of homomorphic encryption.
Use configuration argument "use_encryption" to indicate the boolean flag for
homomorphic encryption or plaintext calculation of feature and/or gradient aggregation
during pre-training and training. Current algorithm that supports encryption includes
'FedAvg' and 'FedGCN'.
Parameters
----------
args: attridict
Configuration arguments
data: tuple
"""
monitor = Monitor(use_cluster=args.use_cluster)
monitor.init_time_start()
ray.init()
start_time = time.time()
torch.manual_seed(42)
pretrain_upload: float = 0.0
pretrain_download: float = 0.0
if args.num_hops == 0:
print("Changing method to FedAvg")
args.method = "FedAvg"
if not args.use_huggingface:
(
edge_index,
features,
labels,
idx_train,
idx_test,
class_num,
split_node_indexes,
communicate_node_global_indexes,
in_com_train_node_local_indexes,
in_com_test_node_local_indexes,
global_edge_indexes_clients,
) = data
if args.saveto_huggingface:
save_all_trainers_data(
split_node_indexes=split_node_indexes,
communicate_node_global_indexes=communicate_node_global_indexes,
global_edge_indexes_clients=global_edge_indexes_clients,
labels=labels,
features=features,
in_com_train_node_local_indexes=in_com_train_node_local_indexes,
in_com_test_node_local_indexes=in_com_test_node_local_indexes,
n_trainer=args.n_trainer,
args=args,
)
if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
args_hidden = 16
else:
args_hidden = 256
num_cpus_per_trainer = args.num_cpus_per_trainer
# specifying a target GPU
if args.gpu:
device = torch.device("cuda")
num_gpus_per_trainer = args.num_gpus_per_trainer
else:
device = torch.device("cpu")
num_gpus_per_trainer = 0
#######################################################################
# Define and Send Data to Trainers
# --------------------------------
# FedGraph first determines the resources for each trainer, then send
# the data to each remote trainer.
@ray.remote(
num_gpus=num_gpus_per_trainer,
num_cpus=num_cpus_per_trainer,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_General):
def __init__(self, *args: Any, **kwds: Any):
super().__init__(*args, **kwds)
args_obj = kwds.get("args", {})
self.use_encryption = (
getattr(args_obj, "use_encryption", False)
if hasattr(args_obj, "use_encryption")
else args_obj.get("use_encryption", False)
)
if self.use_encryption:
file_path = str(files("fedgraph").joinpath("he_context.pkl"))
with open(file_path, "rb") as f:
context_bytes = pickle.load(f)
self.he_context = ts.context_from(context_bytes)
print(f"Trainer {self.rank} loaded HE context")
def get_memory_usage(self):
"""Get current memory usage and local graph info"""
import psutil
process = psutil.Process()
memory_mb = process.memory_info().rss / (1024 * 1024)
num_nodes = (
len(self.local_node_index) if hasattr(self, "local_node_index") else 0
)
num_edges = (
self.adj.shape[1]
if hasattr(self, "adj") and len(self.adj.shape) > 1
else 0
)
return {
"trainer_id": getattr(self, "rank", "unknown"),
"memory_mb": memory_mb,
"num_nodes": num_nodes,
"num_edges": num_edges,
}
if args.use_huggingface:
trainers = [
Trainer.remote( # type: ignore
rank=i,
args_hidden=args_hidden,
device=device,
args=args,
)
for i in range(args.n_trainer)
]
else: # load from the server
trainers = [
Trainer.remote( # type: ignore
rank=i,
args_hidden=args_hidden,
# global_node_num=len(features),
# class_num=class_num,
device=device,
args=args,
local_node_index=split_node_indexes[i],
communicate_node_index=communicate_node_global_indexes[i],
adj=global_edge_indexes_clients[i],
train_labels=labels[communicate_node_global_indexes[i]][
in_com_train_node_local_indexes[i]
],
test_labels=labels[communicate_node_global_indexes[i]][
in_com_test_node_local_indexes[i]
],
features=features[split_node_indexes[i]],
idx_train=in_com_train_node_local_indexes[i],
idx_test=in_com_test_node_local_indexes[i],
)
for i in range(args.n_trainer)
]
# Retrieve data information from all trainers
trainer_information = [
ray.get(trainers[i].get_info.remote()) for i in range(len(trainers))
]
# Extract necessary details from trainer information
global_node_num = sum([info["features_num"] for info in trainer_information])
class_num = max([info["label_num"] for info in trainer_information])
feature_shape = trainer_information[0]["feature_shape"]
train_data_weights = [
info["len_in_com_train_node_local_indexes"] for info in trainer_information
]
test_data_weights = [
info["len_in_com_test_node_local_indexes"] for info in trainer_information
]
communicate_node_global_indexes = [
info["communicate_node_global_index"] for info in trainer_information
]
ray.get(
[
trainers[i].init_model.remote(global_node_num, class_num)
for i in range(len(trainers))
]
)
#######################################################################
# Define Server
# -------------
# Server class is defined for federated aggregation (e.g., FedAvg)
# without knowing the local trainer data
if args.use_huggingface:
server = Server(feature_shape, args_hidden, class_num, device, trainers, args)
else:
server = Server(
features.shape[1], args_hidden, class_num, device, trainers, args
)
# End initialization time tracking
server.broadcast_params(-1)
monitor.init_time_end()
pretrain_start = time.time()
monitor.pretrain_time_start()
if args.method != "FedAvg":
#######################################################################
# Pre-Train Communication of FedGCN
# ---------------------------------
# Clients send their local feature sum to the server, and the server
# aggregates all local feature sums and send the global feature sum
# of specific nodes back to each trainer.
if args.use_encryption:
print("Starting encrypted feature aggregation...")
encrypted_data = [
trainer.get_encrypted_local_feature_sum.remote()
for trainer in server.trainers
]
results = ray.get(encrypted_data)
encrypted_sums = [(r[0], r[1]) for r in results] # (encrypted_sum, shape)
encryption_times = [r[2] for r in results]
enc_sizes = [len(r[0]) for r in results] # size of encrypted data
# aggregate at server
(
aggregated_result,
aggregation_time,
) = server.aggregate_encrypted_feature_sums(encrypted_sums)
agg_size = len(aggregated_result[0])
load_feature_refs = [
trainer.load_encrypted_feature_aggregation.remote(aggregated_result)
for trainer in server.trainers
]
decryption_times = ray.get(load_feature_refs)
pretrain_time = time.time() - pretrain_start
pretrain_upload = sum(enc_sizes) / (1024 * 1024) # MB
pretrain_download = agg_size * len(server.trainers) / (1024 * 1024) # MB
pretrain_comm_cost = pretrain_upload + pretrain_download
# print performance metrics
print("\nPre-training Phase Metrics:")
print(f"Total Pre-training Time: {pretrain_time:.2f} seconds")
print(f"Pre-training Upload: {pretrain_upload:.2f} MB")
print(f"Pre-training Download: {pretrain_download:.2f} MB")
print(f"Total Pre-training Communication Cost: {pretrain_comm_cost:.2f} MB")
else:
pretrain_upload = 0
pretrain_download = 0
local_neighbor_feature_sums = [
trainer.get_local_feature_sum.remote() for trainer in server.trainers
]
# Record uploaded data sizes
upload_sizes = []
global_feature_sum = torch.zeros_like(features)
while True:
ready, left = ray.wait(
local_neighbor_feature_sums, num_returns=1, timeout=None
)
if ready:
for t in ready:
local_sum = ray.get(t)
global_feature_sum += local_sum
# Calculate size of uploaded data
upload_sizes.append(
local_sum.element_size() * local_sum.nelement()
)
local_neighbor_feature_sums = left
if not local_neighbor_feature_sums:
break
# Calculate total upload size
pretrain_upload = sum(upload_sizes) / (1024 * 1024) # MB
print("server aggregates all local neighbor feature sums")
# TODO: Verify that the aggregated global feature sum matches the true 1-hop feature sum for correctness checking.
# test if aggregation is correct
# if not args.use_huggingface and args.num_hops != 0:
# assert (
# global_feature_sum
# != get_1hop_feature_sum(features, edge_index, device)
# ).sum() == 0
# Calculate and record download sizes
download_sizes = []
for i in range(args.n_trainer):
communicate_nodes = (
communicate_node_global_indexes[i].clone().detach().to(device)
)
trainer_aggregation = global_feature_sum[communicate_nodes]
# Calculate download size for each trainer
download_sizes.append(
trainer_aggregation.element_size() * trainer_aggregation.nelement()
)
server.trainers[i].load_feature_aggregation.remote(trainer_aggregation)
# Calculate total download size
pretrain_download = sum(download_sizes) / (1024 * 1024) # MB
print("clients received feature aggregation from server")
[trainer.relabel_adj.remote() for trainer in server.trainers]
monitor.pretrain_time_end()
monitor.add_pretrain_comm_cost(
upload_mb=pretrain_upload,
download_mb=pretrain_download,
)
monitor.train_time_start()
#######################################################################
# Federated Training
# ------------------
# The server start training of all trainers and aggregate the parameters
# at every global round.
training_start = time.time()
# Time tracking variables for pure training and communication
total_pure_training_time = 0.0 # forward + gradient descent
total_communication_time = 0.0 # parameter aggregation
print("global_rounds", args.global_rounds)
global_acc_list = []
for i in range(args.global_rounds):
# Pure training phase - forward + gradient descent only
pure_training_start = time.time()
# Execute only training (forward + gradient descent)
train_refs = [trainer.train.remote(i) for trainer in server.trainers]
ray.get(train_refs)
pure_training_end = time.time()
round_training_time = pure_training_end - pure_training_start
total_pure_training_time += round_training_time
# Communication phase - parameter aggregation and broadcast
comm_start = time.time()
if args.use_encryption:
# Encrypted parameter aggregation
encrypted_params = [
trainer.get_encrypted_params.remote() for trainer in server.trainers
]
params_list = ray.get(encrypted_params)
# Server-side aggregation
aggregated_params, metadata, _ = server.aggregate_encrypted_params(
params_list
)
# Distribute aggregated parameters
decrypt_refs = [
trainer.load_encrypted_params.remote((aggregated_params, metadata), i)
for trainer in server.trainers
]
ray.get(decrypt_refs)
else:
# Regular parameter aggregation
# Get parameters from all trainers
params_refs = [trainer.get_params.remote() for trainer in server.trainers]
param_results = ray.get(params_refs)
# Aggregate parameters on server - avoid in-place operations
server.zero_params()
# Move model to CPU for aggregation
server.model = server.model.to("cpu")
# Aggregate parameters safely
for param_result in param_results:
for p, mp in zip(param_result, server.model.parameters()):
mp.data = mp.data + p.cpu()
# Move back to device and average
server.model = server.model.to(server.device)
# Average the parameters
with torch.no_grad():
for p in server.model.parameters():
p.data = p.data / len(server.trainers)
# Broadcast updated parameters to all trainers
server.broadcast_params(i)
comm_end = time.time()
round_comm_time = comm_end - comm_start
total_communication_time += round_comm_time
# Testing phase (not counted in training or communication time)
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
global_acc_list.append(average_test_accuracy)
print(f"Round {i+1}: Global Test Accuracy = {average_test_accuracy:.4f}")
print(
f"Round {i+1}: Training Time = {round_training_time:.2f}s, Communication Time = {round_comm_time:.2f}s"
)
model_size_mb = server.get_model_size() / (1024 * 1024)
monitor.add_train_comm_cost(
upload_mb=model_size_mb * args.n_trainer,
download_mb=model_size_mb * args.n_trainer,
)
monitor.train_time_end()
total_time = time.time() - training_start
# Print time breakdown
print(f"\n{'='*80}")
print("TIME BREAKDOWN (excluding initialization)")
print(f"{'='*80}")
print(
f"Total Pure Training Time (forward + gradient descent): {total_pure_training_time:.2f} seconds"
)
print(
f"Total Communication Time (parameter aggregation): {total_communication_time:.2f} seconds"
)
print(f"Total Training + Communication Time: {total_time:.2f} seconds")
print(f"Training Time Percentage: {(total_pure_training_time/total_time)*100:.1f}%")
print(
f"Communication Time Percentage: {(total_communication_time/total_time)*100:.1f}%"
)
print(
f"Average Training Time per Round: {total_pure_training_time/args.global_rounds:.2f} seconds"
)
print(
f"Average Communication Time per Round: {total_communication_time/args.global_rounds:.2f} seconds"
)
print(f"{'='*80}")
# Print for plotting use - now shows pure training time
print(
f"[Pure Training Time] Dataset: {args.dataset}, Batch Size: {args.batch_size}, Trainers: {args.n_trainer}, "
f"Hops: {args.num_hops}, IID Beta: {args.iid_beta} => Pure Training Time = {total_pure_training_time:.2f} seconds"
)
print(
f"[Communication Time] Dataset: {args.dataset}, Batch Size: {args.batch_size}, Trainers: {args.n_trainer}, "
f"Hops: {args.num_hops}, IID Beta: {args.iid_beta} => Communication Time = {total_communication_time:.2f} seconds"
)
if args.use_encryption:
if hasattr(server, "aggregation_stats") and server.aggregation_stats:
training_upload = sum(
[r["upload_size"] for r in server.aggregation_stats]
) / (
1024 * 1024
) # MB
training_download = sum(
[r["download_size"] for r in server.aggregation_stats]
) / (
1024 * 1024
) # MB
else:
training_upload = training_download = 0
training_comm_cost = training_upload + training_download
monitor.add_train_comm_cost(
upload_mb=training_upload,
download_mb=training_download,
)
print("\nTraining Phase Metrics:")
print(
f"Total Training Time: {total_pure_training_time:.2f} seconds"
) # Use pure training time
print(f"Training Upload: {training_upload:.2f} MB")
print(f"Training Download: {training_download:.2f} MB")
print(f"Total Training Communication Cost: {training_comm_cost:.2f} MB")
# Overall totals
total_exec_time = time.time() - start_time
total_upload = pretrain_upload + training_upload
total_download = pretrain_download + training_download
total_comm_cost = total_upload + total_download
print("\nOverall Totals:")
print(f"Total Execution Time: {total_exec_time:.2f} seconds")
print(f"Total Upload: {total_upload:.2f} MB")
print(f"Total Download: {total_download:.2f} MB")
print(f"Total Communication Cost: {total_comm_cost:.2f} MB")
print(f"Pre-training Time %: {(pretrain_time/total_exec_time)*100:.1f}%")
print(f"Training Time %: {(total_pure_training_time/total_exec_time)*100:.1f}%")
print(
f"Communication Time %: {(total_communication_time/total_exec_time)*100:.1f}%"
)
#######################################################################
# Summarize Experiment Results
# ----------------------------
# The server collects the local test loss and accuracy from all trainers
# then calculate the overall test loss and accuracy.
# train_data_weights = [len(i) for i in in_com_train_node_indexes]
# test_data_weights = [len(i) for i in in_com_test_node_indexes]
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_final_test_loss = np.average(
[row[0] for row in results], weights=test_data_weights, axis=0
)
average_final_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
print(f"average_final_test_loss, {average_final_test_loss}")
print(f"Average test accuracy, {average_final_test_accuracy}")
print("\n" + "=" * 80)
print("INDIVIDUAL TRAINER MEMORY USAGE")
print("=" * 80)
memory_stats_refs = [trainer.get_memory_usage.remote() for trainer in trainers]
memory_stats = ray.get(memory_stats_refs)
# Replace the existing memory statistics section with this:
print("\n" + "=" * 100)
print("TRAINER MEMORY vs LOCAL GRAPH SIZE")
print("=" * 100)
print(
f"{'Trainer':<8} {'Memory(MB)':<12} {'Nodes':<8} {'Edges':<8} {'Memory/Node':<12} {'Memory/Edge':<12}"
)
print("-" * 100)
memory_stats_refs = [trainer.get_memory_usage.remote() for trainer in trainers]
memory_stats = ray.get(memory_stats_refs)
total_memory = 0
total_nodes = 0
total_edges = 0
max_memory = 0
min_memory = float("inf")
max_trainer = 0
min_trainer = 0
for stats in memory_stats:
trainer_id = stats["trainer_id"]
memory_mb = stats["memory_mb"]
num_nodes = stats["num_nodes"]
num_edges = stats["num_edges"]
# Calculate memory per node and edge
memory_per_node = memory_mb / num_nodes if num_nodes > 0 else 0
memory_per_edge = memory_mb / num_edges if num_edges > 0 else 0
total_memory += memory_mb
total_nodes += num_nodes
total_edges += num_edges
if memory_mb > max_memory:
max_memory = memory_mb
max_trainer = trainer_id
if memory_mb < min_memory:
min_memory = memory_mb
min_trainer = trainer_id
print(
f"{trainer_id:<8} {memory_mb:<12.1f} {num_nodes:<8} {num_edges:<8} {memory_per_node:<12.3f} {memory_per_edge:<12.3f}"
)
avg_memory = total_memory / len(trainers)
avg_nodes = total_nodes / len(trainers)
avg_edges = total_edges / len(trainers)
print("=" * 100)
print(f"Total Memory Usage: {total_memory:.1f} MB ({total_memory/1024:.2f} GB)")
print(f"Total Nodes: {total_nodes}, Total Edges: {total_edges}")
print(f"Average Memory per Trainer: {avg_memory:.1f} MB")
print(f"Average Nodes per Trainer: {avg_nodes:.1f}")
print(f"Average Edges per Trainer: {avg_edges:.1f}")
print(f"Max Memory: {max_memory:.1f} MB (Trainer {max_trainer})")
print(f"Min Memory: {min_memory:.1f} MB (Trainer {min_trainer})")
print(f"Overall Memory/Node Ratio: {total_memory/total_nodes:.3f} MB/node")
print(f"Overall Memory/Edge Ratio: {total_memory/total_edges:.3f} MB/edge")
print("=" * 100)
if monitor is not None:
monitor.print_comm_cost()
# Calculate required metrics for CSV output
total_exec_time = time.time() - start_time
# Get model size - works in both cluster and local environments
model_size_mb = 0.0
total_params = 0
if hasattr(server, "get_model_size"):
model_size_mb = server.get_model_size() / (1024 * 1024)
elif len(trainers) > 0:
# Fallback: calculate from first trainer's model
trainer_info = (
ray.get(trainers[0].get_info.remote())
if hasattr(trainers[0], "get_info")
else {}
)
if "model_params" in trainer_info:
total_params = trainer_info["model_params"]
model_size_mb = (total_params * 4) / (1024 * 1024) # float32 = 4 bytes
# Get peak memory from existing memory_stats (already collected above)
peak_memory_mb = 0.0
if memory_stats:
peak_memory_mb = max([stats["memory_mb"] for stats in memory_stats])
# Calculate average round time
avg_round_time = (
total_pure_training_time / args.global_rounds if args.global_rounds > 0 else 0.0
)
# Get total communication cost from monitor (works in cluster)
total_comm_cost_mb = 0.0
if monitor:
total_comm_cost_mb = (
monitor.pretrain_theoretical_comm_MB + monitor.train_theoretical_comm_MB
)
# Print CSV format result - compatible with cluster logging
print(f"\n{'='*80}")
print("CSV FORMAT RESULT:")
print(
"DS,IID,BS,TotalTime[s],PureTrainingTime[s],CommTime[s],FinalAcc[%],CommCost[MB],PeakMem[MB],AvgRoundTime[s],ModelSize[MB],TotalParams"
)
print(
f"{args.dataset},{args.iid_beta},{args.batch_size},"
f"{total_exec_time:.1f},"
f"{total_pure_training_time:.1f},"
f"{total_communication_time:.1f},"
f"{average_final_test_accuracy:.2f},"
f"{total_comm_cost_mb:.1f},"
f"{peak_memory_mb:.1f},"
f"{avg_round_time:.3f},"
f"{model_size_mb:.3f},"
f"{total_params}"
)
print("=" * 80)
print(f"\n{'='*80}")
print(f"EXPERIMENT SUMMARY")
print(f"{'='*80}")
print(f"Dataset: {args.dataset}")
print(f"Method: {args.method}")
print(f"Trainers: {args.n_trainer}")
print(f"IID Beta: {args.iid_beta}")
print(f"Batch Size: {args.batch_size}")
print(f"Hops: {args.num_hops}")
print(f"Total Execution Time: {time.time() - start_time:.2f} seconds")
print(f"Pure Training Time: {total_pure_training_time:.2f} seconds")
print(f"Communication Time: {total_communication_time:.2f} seconds")
print(f"Pretrain Comm Cost: {pretrain_upload + pretrain_download:.2f} MB")
print(f"Training Comm Cost: {monitor.train_theoretical_comm_MB:.2f} MB")
if args.use_encryption:
print(f"Total Comm Cost: {total_comm_cost:.2f} MB")
print(f"{'='*80}\n")
ray.shutdown()
[docs]
def run_NC_dp(args: attridict, data: Any = None) -> None:
"""
Enhanced NC training with Differential Privacy support for FedGCN pre-training.
"""
monitor = Monitor(use_cluster=args.use_cluster)
monitor.init_time_start()
ray.init()
start_time = time.time()
torch.manual_seed(42)
pretrain_upload: float = 0.0
pretrain_download: float = 0.0
if args.num_hops == 0:
print("Changing method to FedAvg")
args.method = "FedAvg"
if not args.use_huggingface:
(
edge_index,
features,
labels,
idx_train,
idx_test,
class_num,
split_node_indexes,
communicate_node_global_indexes,
in_com_train_node_local_indexes,
in_com_test_node_local_indexes,
global_edge_indexes_clients,
) = data
if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
args_hidden = 16
else:
args_hidden = 256
num_cpus_per_trainer = args.num_cpus_per_trainer
if args.gpu:
device = torch.device("cuda")
num_gpus_per_trainer = args.num_gpus_per_trainer
else:
device = torch.device("cpu")
num_gpus_per_trainer = 0
# Define DP-enhanced trainer class
@ray.remote(
num_gpus=num_gpus_per_trainer,
num_cpus=num_cpus_per_trainer,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_General_DP):
def __init__(self, *args: Any, **kwds: Any):
super().__init__(*args, **kwds)
# Create trainers (same as original)
if args.use_huggingface:
trainers = [
Trainer.remote(
rank=i,
args_hidden=args_hidden,
device=device,
args=args,
)
for i in range(args.n_trainer)
]
else:
trainers = [
Trainer.remote(
rank=i,
args_hidden=args_hidden,
device=device,
args=args,
local_node_index=split_node_indexes[i],
communicate_node_index=communicate_node_global_indexes[i],
adj=global_edge_indexes_clients[i],
train_labels=labels[communicate_node_global_indexes[i]][
in_com_train_node_local_indexes[i]
],
test_labels=labels[communicate_node_global_indexes[i]][
in_com_test_node_local_indexes[i]
],
features=features[split_node_indexes[i]],
idx_train=in_com_train_node_local_indexes[i],
idx_test=in_com_test_node_local_indexes[i],
)
for i in range(args.n_trainer)
]
# Get trainer information
trainer_information = [
ray.get(trainers[i].get_info.remote()) for i in range(len(trainers))
]
global_node_num = sum([info["features_num"] for info in trainer_information])
class_num = max([info["label_num"] for info in trainer_information])
train_data_weights = [
info["len_in_com_train_node_local_indexes"] for info in trainer_information
]
test_data_weights = [
info["len_in_com_test_node_local_indexes"] for info in trainer_information
]
communicate_node_global_indexes = [
info["communicate_node_global_index"] for info in trainer_information
]
ray.get(
[
trainers[i].init_model.remote(global_node_num, class_num)
for i in range(len(trainers))
]
)
# Create DP-enhanced server
server = Server_DP(
features.shape[1], args_hidden, class_num, device, trainers, args
)
server.broadcast_params(-1)
monitor.init_time_end()
# DP-enhanced pre-training
pretrain_start = time.time()
monitor.pretrain_time_start()
if args.method != "FedAvg":
print("Starting DP-enhanced feature aggregation...")
# Get local feature sums with DP preprocessing
local_feature_data = [
trainer.get_dp_local_feature_sum.remote() for trainer in server.trainers
]
results = ray.get(local_feature_data)
local_feature_sums = [r[0] for r in results] # Extract tensors
computation_stats = [r[1] for r in results] # Extract stats
# Calculate upload sizes
upload_sizes = [
local_sum.element_size() * local_sum.nelement()
for local_sum in local_feature_sums
]
pretrain_upload = sum(upload_sizes) / (1024 * 1024) # MB
# DP aggregation at server
global_feature_sum, dp_stats = server.aggregate_dp_feature_sums(
local_feature_sums
)
# Print DP statistics
server.print_dp_stats(dp_stats)
# Distribute back to trainers
download_sizes = []
for i in range(args.n_trainer):
communicate_nodes = (
communicate_node_global_indexes[i].clone().detach().to(device)
)
trainer_aggregation = global_feature_sum[communicate_nodes]
download_sizes.append(
trainer_aggregation.element_size() * trainer_aggregation.nelement()
)
server.trainers[i].load_feature_aggregation.remote(trainer_aggregation)
pretrain_download = sum(download_sizes) / (1024 * 1024) # MB
[trainer.relabel_adj.remote() for trainer in server.trainers]
monitor.pretrain_time_end()
monitor.add_pretrain_comm_cost(
upload_mb=pretrain_upload,
download_mb=pretrain_download,
)
# Regular training phase (same as original)
monitor.train_time_start()
print("Starting federated training with DP-enhanced pre-training...")
global_acc_list = []
for i in range(args.global_rounds):
server.train(i)
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
global_acc_list.append(average_test_accuracy)
print(f"Round {i+1}: Global Test Accuracy = {average_test_accuracy:.4f}")
model_size_mb = server.get_model_size() / (1024 * 1024)
monitor.add_train_comm_cost(
upload_mb=model_size_mb * args.n_trainer,
download_mb=model_size_mb * args.n_trainer,
)
monitor.train_time_end()
# Final evaluation
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_final_test_loss = np.average(
[row[0] for row in results], weights=test_data_weights, axis=0
)
average_final_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
print(f"Final test loss: {average_final_test_loss:.4f}")
print(f"Final test accuracy: {average_final_test_accuracy:.4f}")
# Print final privacy budget
if args.use_dp:
server.privacy_accountant.print_privacy_budget()
if monitor is not None:
monitor.print_comm_cost()
ray.shutdown()
[docs]
def run_NC_lowrank(args: attridict, data: Any = None) -> None:
if not LOWRANK_AVAILABLE:
raise ImportError(
"Low-rank compression modules not available. Please implement the low-rank functionality in fedgraph.low_rank"
)
print("=== Running NC with Low-Rank Compression ===")
print(f"Low-rank method: {getattr(args, 'lowrank_method', 'fixed')}")
if hasattr(args, "lowrank_method"):
if args.lowrank_method == "fixed":
print(f"Fixed rank: {getattr(args, 'fixed_rank', 10)}")
elif args.lowrank_method == "adaptive":
print(
f"Target compression ratio: {getattr(args, 'compression_ratio', 2.0)}"
)
elif args.lowrank_method == "energy":
print(f"Energy threshold: {getattr(args, 'energy_threshold', 0.95)}")
monitor = Monitor(use_cluster=args.use_cluster)
monitor.init_time_start()
ray.init()
start_time = time.time()
torch.manual_seed(42)
if args.num_hops == 0:
print("Changing method to FedAvg")
args.method = "FedAvg"
if not args.use_huggingface:
(
edge_index,
features,
labels,
idx_train,
idx_test,
class_num,
split_node_indexes,
communicate_node_global_indexes,
in_com_train_node_local_indexes,
in_com_test_node_local_indexes,
global_edge_indexes_clients,
) = data
if args.saveto_huggingface:
save_all_trainers_data(
split_node_indexes=split_node_indexes,
communicate_node_global_indexes=communicate_node_global_indexes,
global_edge_indexes_clients=global_edge_indexes_clients,
labels=labels,
features=features,
in_com_train_node_local_indexes=in_com_train_node_local_indexes,
in_com_test_node_local_indexes=in_com_test_node_local_indexes,
n_trainer=args.n_trainer,
args=args,
)
# Model configuration
if args.dataset in ["simulate", "cora", "citeseer", "pubmed", "reddit"]:
args_hidden = 16
else:
args_hidden = 256
# Device configuration
num_cpus_per_trainer = args.num_cpus_per_trainer
if args.gpu:
device = torch.device("cuda")
num_gpus_per_trainer = args.num_gpus_per_trainer
else:
device = torch.device("cpu")
num_gpus_per_trainer = 0
@ray.remote(
num_gpus=num_gpus_per_trainer,
num_cpus=num_cpus_per_trainer,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_General_LowRank): # Use low-rank trainer instead
def __init__(self, *args: Any, **kwds: Any):
super().__init__(*args, **kwds)
# Create trainers
if args.use_huggingface:
trainers = [
Trainer.remote(
rank=i,
args_hidden=args_hidden,
device=device,
args=args,
)
for i in range(args.n_trainer)
]
else:
trainers = [
Trainer.remote(
rank=i,
args_hidden=args_hidden,
device=device,
args=args,
local_node_index=split_node_indexes[i],
communicate_node_index=communicate_node_global_indexes[i],
adj=global_edge_indexes_clients[i],
train_labels=labels[communicate_node_global_indexes[i]][
in_com_train_node_local_indexes[i]
],
test_labels=labels[communicate_node_global_indexes[i]][
in_com_test_node_local_indexes[i]
],
features=features[split_node_indexes[i]],
idx_train=in_com_train_node_local_indexes[i],
idx_test=in_com_test_node_local_indexes[i],
)
for i in range(args.n_trainer)
]
# Get trainer information
trainer_information = [
ray.get(trainers[i].get_info.remote()) for i in range(len(trainers))
]
global_node_num = sum([info["features_num"] for info in trainer_information])
class_num = max([info["label_num"] for info in trainer_information])
train_data_weights = [
info["len_in_com_train_node_local_indexes"] for info in trainer_information
]
test_data_weights = [
info["len_in_com_test_node_local_indexes"] for info in trainer_information
]
# Initialize models
ray.get(
[
trainers[i].init_model.remote(global_node_num, class_num)
for i in range(len(trainers))
]
)
server = Server_LowRank(
features.shape[1], args_hidden, class_num, device, trainers, args
)
# End initialization
server.broadcast_params(-1)
monitor.init_time_end()
monitor.pretrain_time_start()
monitor.pretrain_time_end()
monitor.train_time_start()
print("Starting federated training with low-rank compression...")
global_acc_list = []
for i in range(args.global_rounds):
server.train(i)
# Evaluation
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
global_acc_list.append(average_test_accuracy)
print(f"Round {i+1}: Global Test Accuracy = {average_test_accuracy:.4f}")
# Communication cost tracking (enhanced with compression-aware sizing)
model_size_mb = server.get_model_size() / (1024 * 1024)
monitor.add_train_comm_cost(
upload_mb=model_size_mb * args.n_trainer,
download_mb=model_size_mb * args.n_trainer,
)
if (i + 1) % 10 == 0 and hasattr(server, "print_compression_stats"):
server.print_compression_stats()
monitor.train_time_end()
# Final evaluation
results = [trainer.local_test.remote() for trainer in server.trainers]
results = np.array([ray.get(result) for result in results])
average_final_test_loss = np.average(
[row[0] for row in results], weights=test_data_weights, axis=0
)
average_final_test_accuracy = np.average(
[row[1] for row in results], weights=test_data_weights, axis=0
)
print(f"Final test loss: {average_final_test_loss:.4f}")
print(f"Final test accuracy: {average_final_test_accuracy:.4f}")
# Print final compression statistics
if hasattr(server, "print_compression_stats"):
server.print_compression_stats()
if monitor is not None:
monitor.print_comm_cost()
ray.shutdown()
[docs]
def run_GC(args: attridict, data: Any) -> None:
"""
Entrance of the training process for graph classification.
Supports multiple federated learning algorithms including FedAvg, FedProx, GCFL,
GCFL+, and GCFL+dWs. Implements client-server architecture with Ray for distributed
computing.
Parameters
----------
args: attridict
The configuration arguments.
data: Any
Dictionary mapping dataset names to their respective graph data including
dataloaders, number of node features, number of graph labels, and train size
base_model: Any
The base model on which the federated learning is based. It applies for both the server and the trainers (default: GIN).
"""
# transfer the config to argparse
#################### set seeds and devices ####################
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "../fedgraph"))
sys.path.append(os.path.join(current_dir, "../../"))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
monitor = Monitor(use_cluster=args.use_cluster)
monitor.init_time_start()
base_model = GIN
num_cpus_per_trainer = args.num_cpus_per_trainer
# specifying a target GPU
if args.gpu:
print("using GPU")
args.device = torch.device("cuda")
num_gpus_per_trainer = args.num_gpus_per_trainer
else:
print("using CPU")
args.device = torch.device("cpu")
num_gpus_per_trainer = 0
#################### set output directory ####################
# outdir_base = os.path.join(args.outbase, f'seqLen{args.seq_length}')
if args.save_files:
outdir_base = args.outbase + "/" + f"{args.algorithm}"
outdir = os.path.join(outdir_base, f"oneDS-nonOverlap")
if args.algorithm in ["SelfTrain"]:
outdir = os.path.join(outdir, f"{args.dataset}")
elif args.algorithm in ["FedAvg", "FedProx"]:
outdir = os.path.join(outdir, f"{args.dataset}-{args.num_trainers}trainers")
elif args.algorithm in ["GCFL"]:
outdir = os.path.join(
outdir,
f"{args.dataset}-{args.num_trainers}trainers",
f"eps_{args.epsilon1}_{args.epsilon2}",
)
elif args.algorithm in ["GCFL+", "GCFL+dWs"]:
outdir = os.path.join(
outdir,
f"{args.dataset}-{args.num_trainers}trainers",
f"eps_{args.epsilon1}_{args.epsilon2}",
f"seqLen{args.seq_length}",
)
Path(outdir).mkdir(parents=True, exist_ok=True)
print(f"Output Path: {outdir}")
#################### save statistics of data on trainers ####################
# if args.save_files and df_stats:
# outdir_stats = os.path.join(outdir, f"stats_train_data.csv")
# df_stats.to_csv(outdir_stats)
# print(f"The statistics of the data are written to {outdir_stats}")
#################### setup server and trainers ####################
ray.init()
@ray.remote(
num_gpus=num_gpus_per_trainer,
num_cpus=num_cpus_per_trainer,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_GC):
def __init__(self, idx, splited_data, dataset_trainer_name, cmodel_gc, args): # type: ignore
print(f"inx: {idx}")
print(f"dataset_trainer_name: {dataset_trainer_name}")
"""acquire data"""
dataloaders, num_node_features, num_graph_labels, train_size = splited_data
print(f"dataloaders: {dataloaders}")
print(f"num_node_features: {num_node_features}")
print(f"num_graph_labels: {num_graph_labels}")
print(f"train_size: {train_size}")
"""build optimizer"""
optimizer = torch.optim.Adam(
params=filter(lambda p: p.requires_grad, cmodel_gc.parameters()),
lr=args.lr,
weight_decay=args.weight_decay,
)
super().__init__( # type: ignore
model=cmodel_gc,
trainer_id=idx,
trainer_name=dataset_trainer_name,
train_size=train_size,
dataloader=dataloaders,
optimizer=optimizer,
args=args,
)
trainers = [
Trainer.remote( # type: ignore
idx=idx,
splited_data=data[dataset_trainer_name],
dataset_trainer_name=dataset_trainer_name,
# "GIN model for GC",
cmodel_gc=base_model(
nfeat=data[dataset_trainer_name][1],
nhid=args.hidden,
nclass=data[dataset_trainer_name][2],
nlayer=args.nlayer,
dropout=args.dropout,
),
args=args,
)
for idx, dataset_trainer_name in enumerate(data.keys())
]
server = Server_GC(
base_model(nlayer=args.nlayer, nhid=args.hidden), args.device, args.use_cluster
)
# TODO: check and modify whether deepcopy should be added.
# trainers = copy.deepcopy(init_trainers)
# server = copy.deepcopy(init_server)
# End initialization time tracking after server setup is complete
monitor.init_time_end()
print("\nDone setting up devices.")
################ choose the algorithm to run ################
print(f"Running {args.algorithm} ...")
model_parameters = {
"SelfTrain": lambda: run_GC_selftrain(
trainers=trainers,
server=server,
local_epoch=args.local_epoch,
monitor=monitor,
),
"FedAvg": lambda: run_GC_Fed_algorithm(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
algorithm="FedAvg",
monitor=monitor,
),
"FedProx": lambda: run_GC_Fed_algorithm(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
algorithm="FedProx",
mu=args.mu,
),
"GCFL": lambda: run_GCFL_algorithm(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
EPS_1=args.epsilon1,
EPS_2=args.epsilon2,
algorithm_type="gcfl",
monitor=monitor,
),
"GCFL+": lambda: run_GCFL_algorithm(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
EPS_1=args.epsilon1,
EPS_2=args.epsilon2,
algorithm_type="gcfl_plus",
seq_length=args.seq_length,
standardize=args.standardize,
monitor=monitor,
),
"GCFL+dWs": lambda: run_GCFL_algorithm(
trainers=trainers,
server=server,
communication_rounds=args.num_rounds,
local_epoch=args.local_epoch,
EPS_1=args.epsilon1,
EPS_2=args.epsilon2,
algorithm_type="gcfl_plus_dWs",
seq_length=args.seq_length,
standardize=args.standardize,
monitor=monitor,
),
}
if args.algorithm in model_parameters:
output = model_parameters[args.algorithm]()
else:
raise ValueError(f"Unknown model: {args.algorithm}")
#################### save the output ####################
if args.save_files:
outdir_result = os.path.join(outdir, f"accuracy_seed{args.seed}.csv")
pd.DataFrame(output).to_csv(outdir_result)
print(f"The output has been written to file: {outdir_result}")
if monitor is not None:
monitor.print_comm_cost()
ray.shutdown()
# The following code is the implementation of different federated graph classification methods.
[docs]
def run_GC_selftrain(
trainers: list, server: Any, local_epoch: int, monitor: Optional[Monitor] = None
) -> dict:
"""
Run the training and testing process of self-training algorithm.
It only trains the model locally, and does not perform weights aggregation.
It is useful as a baseline comparison for federated methods.
Parameters
----------
trainers: list
List of trainers, each of which is a Trainer_GC object
server: Any
Server_GC object
local_epoch: int
Number of local epochs
Returns
-------
all_accs: dict
Dictionary with training and test accuracies for each trainer
"""
# all trainers are initialized with the same weights
if monitor is not None:
monitor.pretrain_time_start()
global_params_id = ray.put(server.W)
for trainer in trainers:
trainer.update_params.remote(global_params_id)
if monitor is not None:
monitor.pretrain_time_end()
all_accs = {}
acc_refs = []
if monitor is not None:
monitor.train_time_start()
for trainer in trainers:
trainer.local_train.remote(local_epoch=local_epoch)
acc_ref = trainer.local_test.remote()
acc_refs.append(acc_ref)
while True:
ready, left = ray.wait(acc_refs, num_returns=1, timeout=None)
if ready:
for t in ready:
_, acc, trainer_name, trainingaccs, valaccs = ray.get(t)
all_accs[trainer_name] = [
trainingaccs,
valaccs,
acc,
]
print(" > {} done.".format(trainer_name))
print(f"trainingaccs: {trainingaccs}, valaccs: {valaccs}, acc: {acc}")
acc_refs = left
if not acc_refs:
break
if monitor is not None:
model_size_mb = server.get_model_size() / (1024 * 1024)
monitor.add_train_comm_cost(
upload_mb=0, # No parameter upload in self-training
download_mb=model_size_mb * len(trainers),
)
monitor.train_time_end()
frame = pd.DataFrame(all_accs).T.iloc[:, [2]]
frame.columns = ["test_acc"]
print(frame)
# TODO: delete to make speed faster
print(f"Average test accuracy: {gc_avg_accuracy(frame, trainers)}")
if monitor is not None:
monitor.print_comm_cost()
return frame
[docs]
def run_GC_Fed_algorithm(
trainers: list,
server: Any,
communication_rounds: int,
local_epoch: int,
algorithm: str,
mu: float = 0.0,
sampling_frac: float = 1.0,
monitor: Optional[Monitor] = None,
) -> pd.DataFrame:
"""
Run the training and testing process of FedAvg or FedProx algorithm.
It trains the model locally, aggregates the weights to the server,
and downloads the global model within each communication round.
Parameters
----------
trainers: list
List of trainers, each of which is a Trainer_GC object
server: Any
Server_GC object
communication_rounds: int
Number of communication rounds
local_epoch: int
Number of local epochs
algorithm: str
Algorithm to run, either 'FedAvg' or 'FedProx'
mu: float, optional
Proximal term for FedProx (default is 0.0)
sampling_frac: float, optional
Fraction of trainers to sample (default is 1.0)
Returns
-------
frame: pd.DataFrame
Pandas dataframe with test accuracies
"""
if monitor is not None:
monitor.pretrain_time_start()
global_params_id = ray.put(server.W)
for trainer in trainers:
trainer.update_params.remote(global_params_id)
if monitor is not None:
monitor.pretrain_time_end()
if monitor is not None:
monitor.train_time_start()
for c_round in range(1, communication_rounds + 1):
if (c_round) % 10 == 0:
# print the current round every 10 rounds
print(f" > Training round {c_round} finished.")
if c_round == 1:
selected_trainers = trainers
else:
selected_trainers = server.random_sample_trainers(trainers, sampling_frac)
for trainer in selected_trainers:
if algorithm == "FedAvg":
trainer.local_train.remote(local_epoch=local_epoch)
elif algorithm == "FedProx":
trainer.local_train.remote(
local_epoch=local_epoch, train_option="prox", mu=mu
)
else:
raise ValueError(
"Invalid algorithm. Choose either 'FedAvg' or 'FedProx'."
)
server.aggregate_weights(selected_trainers)
if monitor is not None:
model_size_mb = server.get_model_size() / (1024 * 1024)
num_clients = len(selected_trainers)
monitor.add_train_comm_cost(
upload_mb=model_size_mb * num_clients,
download_mb=0,
)
ray.internal.free([global_params_id]) # Free the old weight memory
global_params_id = ray.put(server.W)
for trainer in selected_trainers:
trainer.update_params.remote(global_params_id)
if algorithm == "FedProx":
trainer.cache_weights.remote()
if monitor is not None:
# Download cost: server sends parameters to clients
monitor.add_train_comm_cost(
upload_mb=0,
download_mb=model_size_mb * num_clients,
)
if monitor is not None:
monitor.train_time_end()
# Test phase
frame = pd.DataFrame()
acc_refs = []
for trainer in trainers:
acc_ref = trainer.local_test.remote()
acc_refs.append(acc_ref)
while acc_refs:
ready, left = ray.wait(acc_refs, num_returns=1, timeout=None)
if ready:
for t in ready:
_, acc, trainer_name, trainingaccs, valaccs = ray.get(t)
frame.loc[trainer_name, "test_acc"] = acc
acc_refs = left
def highlight_max(s: pd.Series) -> list:
is_max = s == s.max()
return ["background-color: yellow" if v else "" for v in is_max]
fs = frame.style.apply(highlight_max).data
print(fs)
print(f"Average test accuracy: {gc_avg_accuracy(frame, trainers)}")
if monitor is not None:
monitor.print_comm_cost()
return frame
[docs]
def run_GCFL_algorithm(
trainers: list,
server: Any,
communication_rounds: int,
local_epoch: int,
EPS_1: float,
EPS_2: float,
algorithm_type: str,
seq_length: int = 0,
standardize: bool = True,
monitor: Optional[Monitor] = None,
) -> pd.DataFrame:
"""
Run the specified GCFL algorithm.
Parameters
----------
trainers: list
List of trainers, each of which is a Trainer_GC object
server: Any
Server_GC object
communication_rounds: int
Number of communication rounds
local_epoch: int
Number of local epochs
EPS_1: float
Threshold for mean update norm
EPS_2: float
Threshold for max update norm
algorithm_type: str
Type of algorithm ('gcfl', 'gcfl_plus', 'gcfl_plus_dWs')
seq_length: int, optional
The length of the gradient norm sequence, required for 'gcfl_plus' and 'gcfl_plus_dWs'
standardize: bool, optional
Whether to standardize the distance matrix, required for 'gcfl_plus' and 'gcfl_plus_dWs'
Returns
-------
frame: pandas.DataFrame
Pandas dataframe with test accuracies
"""
if algorithm_type not in ["gcfl", "gcfl_plus", "gcfl_plus_dWs"]:
raise ValueError(
"Invalid algorithm_type. Must be 'gcfl', 'gcfl_plus', or 'gcfl_plus_dWs'."
)
if monitor is not None:
monitor.pretrain_time_start()
cluster_indices = [np.arange(len(trainers)).astype("int")]
trainer_clusters = [[trainers[i] for i in idcs] for idcs in cluster_indices]
# Initialize clustering statistics tracking
from typing import Dict, List, Union
clustering_stats: Dict[str, Any] = {
"total_clustering_events": 0,
"similarity_computations": 0,
"dtw_computations": 0,
"model_cache_operations": 0,
"rounds_with_clustering": [],
"cluster_sizes_per_round": [],
}
global_params_id = ray.put(server.W)
if algorithm_type in ["gcfl_plus", "gcfl_plus_dWs"]:
seqs_grads: Dict[int, List[Any]] = {
ray.get(c.get_id.remote()): [] for c in trainers
}
# Perform update_params before communication rounds for GCFL+ and GCFL+ dWs
for trainer in trainers:
trainer.update_params.remote(global_params_id)
if monitor is not None:
monitor.pretrain_time_end()
acc_trainers: List[Any] = []
if monitor is not None:
monitor.train_time_start()
for c_round in range(1, communication_rounds + 1):
if (c_round) % 10 == 0:
print(f" > Training round {c_round} finished.")
round_upload_mb: float = 0.0
round_download_mb: float = 0.0
round_clustering_occurred = False
if c_round == 1:
# Perform update_params at the beginning of the first communication round
# ray.internal.free(
# [global_params_id]
# ) # Free the old weight memory in object store
global_params_id = ray.put(server.W)
for trainer in trainers:
trainer.update_params.remote(global_params_id)
# Initial parameter distribution cost
if monitor is not None:
model_size_mb = server.get_model_size() / (1024 * 1024)
round_download_mb += model_size_mb * len(trainers)
# Local training phase
reset_params_refs = []
participating_trainers = server.random_sample_trainers(trainers, frac=1.0)
for trainer in participating_trainers:
trainer.local_train.remote(local_epoch=local_epoch, train_option="gcfl")
reset_params_ref = trainer.reset_params.remote()
reset_params_refs.append(reset_params_ref)
ray.get(reset_params_refs)
# Add communication cost for reset_params operation (parameter retrieval after training)
if monitor is not None:
model_size_mb = server.get_model_size() / (1024 * 1024)
round_upload_mb += model_size_mb * len(participating_trainers)
# Gradient/weight change collection phase - get actual data sizes
for trainer in participating_trainers:
if algorithm_type == "gcfl_plus":
grad_norm = ray.get(trainer.get_conv_grads_norm.remote())
seqs_grads[ray.get(trainer.get_id.remote())].append(grad_norm)
# Gradient norm is typically a scalar (8 bytes for float64)
round_upload_mb += 8 / (1024 * 1024)
elif algorithm_type == "gcfl_plus_dWs":
dw_norm = ray.get(trainer.get_conv_dWs_norm.remote())
seqs_grads[ray.get(trainer.get_id.remote())].append(dw_norm)
# Weight change norm is typically a scalar (8 bytes for float64)
round_upload_mb += 8 / (1024 * 1024)
# Clustering decision phase - communication cost for update norm computations
cluster_indices_new = []
model_size_mb = server.get_model_size() / (1024 * 1024)
for idc in cluster_indices:
max_norm = server.compute_max_update_norm([trainers[i] for i in idc])
mean_norm = server.compute_mean_update_norm([trainers[i] for i in idc])
# Only add clustering-specific communication cost when clustering condition is met
if mean_norm < EPS_1 and max_norm > EPS_2 and len(idc) > 2 and c_round > 20:
# Record that clustering occurred in this round
round_clustering_occurred = True
clustering_stats["total_clustering_events"] = (
clustering_stats.get("total_clustering_events", 0) + 1
)
# marginal condition for gcfl, gcfl+, gcfl+dws
if algorithm_type == "gcfl" or all(
len(value) >= seq_length for value in seqs_grads.values()
):
# Record model cache operation
clustering_stats["model_cache_operations"] = (
clustering_stats.get("model_cache_operations", 0) + 1
)
# Cache model - full weight data uses actual model size
full_weight = ray.get(trainers[idc[0]].get_total_weight.remote())
server.cache_model(idc, full_weight, acc_trainers)
round_upload_mb += model_size_mb
if algorithm_type == "gcfl":
# Record similarity computation
clustering_stats["similarity_computations"] = (
clustering_stats.get("similarity_computations", 0) + 1
)
# Similarity computation - requires gradients from all trainers
similarity_matrix = server.compute_pairwise_similarities(
trainers
)
# Use actual model size for gradient transmission
round_upload_mb += model_size_mb * len(trainers)
c1, c2 = server.min_cut(similarity_matrix[idc][:, idc], idc)
cluster_indices_new += [c1, c2]
else: # gcfl+, gcfl+dws
# Record DTW computation
clustering_stats["dtw_computations"] = (
clustering_stats.get("dtw_computations", 0) + 1
)
# Sequence data: seq_length scalars per trainer
seq_data_size_bytes = (
seq_length * len(idc) * 8
) # 8 bytes per scalar
round_upload_mb += seq_data_size_bytes / (1024 * 1024)
tmp = [seqs_grads[id][-seq_length:] for id in idc]
dtw_distances = server.compute_pairwise_distances(
tmp, standardize
)
c1, c2 = server.min_cut(
np.max(dtw_distances) - dtw_distances, idc
)
cluster_indices_new += [c1, c2]
seqs_grads = {ray.get(c.get_id.remote()): [] for c in trainers}
else:
cluster_indices_new += [idc]
else:
cluster_indices_new += [idc]
# Record clustering statistics for this round
if round_clustering_occurred:
if isinstance(clustering_stats["rounds_with_clustering"], list):
clustering_stats["rounds_with_clustering"].append(c_round)
if isinstance(clustering_stats["cluster_sizes_per_round"], list):
clustering_stats["cluster_sizes_per_round"].append(len(cluster_indices_new))
cluster_indices = cluster_indices_new
trainer_clusters = [[trainers[i] for i in idcs] for idcs in cluster_indices]
# Cluster-wise aggregation phase - always happens but cost varies based on clustering
for cluster in trainer_clusters:
cluster_size = len(cluster)
# Use actual model size for parameter transmission
model_size_mb = server.get_model_size() / (1024 * 1024)
# Basic aggregation communication (always happens regardless of clustering)
# Each trainer uploads weights for aggregation
round_upload_mb += model_size_mb * cluster_size # Weight parameters only
# Training sizes are small and always needed
round_upload_mb += (4 * cluster_size) / (
1024 * 1024
) # Training sizes (int32)
# After aggregation, updated parameters are sent back to cluster
round_download_mb += model_size_mb * cluster_size
server.aggregate_clusterwise(trainer_clusters)
# Local testing phase - add communication cost for parameter retrieval during testing
acc_trainers = []
acc_trainers_refs = [trainer.local_test.remote() for trainer in trainers]
# Collect the model parameters as they become ready
while acc_trainers_refs:
ready, left = ray.wait(acc_trainers_refs, num_returns=1, timeout=None)
if ready:
for t in ready:
acc_trainers.append(ray.get(t)[1])
# Test result communication cost is negligible (single float value)
acc_trainers_refs = left
# Record communication cost for this round
if monitor is not None:
monitor.add_train_comm_cost(
upload_mb=round_upload_mb,
download_mb=round_download_mb,
)
# Print detailed clustering statistics
print("\n" + "=" * 50)
print("CLUSTERING STATISTICS")
print("=" * 50)
print(f"Algorithm: {algorithm_type}")
print(
f"Clustering Events: {clustering_stats['total_clustering_events']}/{communication_rounds}"
)
print(
f"Clustering Frequency: {clustering_stats['total_clustering_events']/communication_rounds:.1%}"
)
if clustering_stats["rounds_with_clustering"]:
print(f"Clustering Rounds: {clustering_stats['rounds_with_clustering']}")
print("=" * 50)
# Final model caching
for idc in cluster_indices:
server.cache_model(
idc, ray.get(trainers[idc[0]].get_total_weight.remote()), acc_trainers
)
if monitor is not None:
monitor.train_time_end()
# Build results
results = np.zeros([len(trainers), len(server.model_cache)])
for i, (idcs, W, accs) in enumerate(server.model_cache):
results[idcs, i] = np.array(accs)
frame = pd.DataFrame(
results,
columns=["FL Model"]
+ ["Model {}".format(i) for i in range(results.shape[1] - 1)],
index=[
"{}".format(ray.get(trainers[i].get_name.remote()))
for i in range(results.shape[0])
],
)
frame = pd.DataFrame(frame.max(axis=1))
frame.columns = ["test_acc"]
print(frame)
print(f"Average test accuracy: {gc_avg_accuracy(frame, trainers)}")
if monitor is not None:
monitor.print_comm_cost()
return frame
[docs]
def run_LP(args: Any) -> None:
"""
Implements various federated learning methods for link prediction tasks with support
for online learning and buffer mechanisms. Handles temporal aspects of link prediction
and cross-region user interactions.
Algorithm choices include ('STFL', 'StaticGNN', '4D-FED-GNN+', 'FedLink').
Parameters
----------
args: attridict
The configuration arguments.
"""
monitor = Monitor(use_cluster=args.use_cluster)
def setup_trainer_server(
country_codes: list,
user_id_mapping: Any,
item_id_mapping: Any,
meta_data: tuple,
hidden_channels: int = 64,
) -> tuple:
"""
Setup the trainer and server
Parameters
----------
country_codes: list
The list of country codes
user_id_mapping: Any
The user id mapping
item_id_mapping: Any
The item id mapping
meta_data: tuple
The meta data
hidden_channels: int, optional
The number of hidden channels
Returns
-------
(list, Server_LP): tuple
[0]: The list of clients
[1]: The server
"""
number_of_clients = len(country_codes)
number_of_users, number_of_items = len(user_id_mapping.keys()), len(
item_id_mapping.keys()
)
num_cpus_per_client = args.num_cpus_per_trainer
if args.gpu == True:
device = torch.device("cuda")
print("gpu detected")
num_gpus_per_client = args.num_gpus_per_trainer
else:
device = torch.device("cpu")
num_gpus_per_client = 0
print("gpu not detected")
@ray.remote(
num_gpus=num_gpus_per_client,
num_cpus=num_cpus_per_client,
scheduling_strategy="SPREAD",
)
class Trainer(Trainer_LP):
def __init__(self, *args, **kwargs): # type: ignore
super().__init__(*args, **kwargs)
print(
f"[Debug] Trainer running on node IP: {ray.util.get_node_ip_address()}"
)
clients = [
Trainer.remote( # type: ignore
i,
country_code=args.country_codes[i],
user_id_mapping=user_id_mapping,
item_id_mapping=item_id_mapping,
number_of_users=number_of_users,
number_of_items=number_of_items,
meta_data=meta_data,
dataset_path=args.dataset_path,
hidden_channels=args.hidden_channels,
)
for i in range(number_of_clients)
]
server = Server_LP( # the concrete information of users and items is not available in the server
number_of_users=number_of_users,
number_of_items=number_of_items,
meta_data=meta_data,
trainers=clients,
)
print(
f"[Debug] Server running on IP: {socket.gethostbyname(socket.gethostname())}"
)
return clients, server
method = args.method
use_buffer = args.use_buffer
buffer_size = args.buffer_size
online_learning = args.online_learning
global_rounds = args.global_rounds
local_steps = args.local_steps
hidden_channels = args.hidden_channels
record_results = args.record_results
country_codes = args.country_codes
current_dir = os.path.dirname(os.path.abspath(__file__))
ray.init()
monitor.init_time_start()
# Append paths relative to the current script's directory
sys.path.append(os.path.join(current_dir, "../fedgraph"))
sys.path.append(os.path.join(current_dir, "../../"))
dataset_path = args.dataset_path
global_file_path = os.path.join(dataset_path, "data_global.txt")
traveled_file_path = os.path.join(dataset_path, "traveled_users.txt")
# check the validity of the input
assert method in ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"], "Invalid method."
assert all(
code in ["US", "BR", "ID", "TR", "JP"] for code in country_codes
), "The country codes should be in 'US', 'BR', 'ID', 'TR', 'JP'"
if use_buffer:
assert buffer_size > 0, "The buffer size should be greater than 0."
check_data_files_existance(country_codes, dataset_path)
# get global user and item mapping
user_id_mapping, item_id_mapping = get_global_user_item_mapping(
global_file_path=global_file_path
)
# set meta_data
meta_data = (
["user", "item"],
[("user", "select", "item"), ("item", "rev_select", "user")],
)
# repeat the training process
number_of_clients = len(country_codes) # each country is a client
clients, server = setup_trainer_server(
country_codes=country_codes,
user_id_mapping=user_id_mapping,
item_id_mapping=item_id_mapping,
meta_data=meta_data,
hidden_channels=hidden_channels,
)
server.monitor = monitor
# End initialization time tracking
monitor.init_time_end()
"""Broadcast the global model parameter to all clients"""
monitor.pretrain_time_start()
global_model_parameter = (
server.get_model_parameter()
) # fetch the global model parameter
# TODO: add memory optimization here by move ref to shared raylet
for i in range(number_of_clients):
clients[i].set_model_parameter.remote(
global_model_parameter
) # broadcast the global model parameter to all clients
"""Determine the start and end time of the conditional information"""
(
start_time,
end_time,
prediction_days,
start_time_float_format,
end_time_float_format,
) = get_start_end_time(online_learning=online_learning, method=method)
if record_results:
file_name = (
f"{method}_buffer_{use_buffer}_{buffer_size}_online_{online_learning}.txt"
)
result_writer = open(file_name, "a+")
time_writer = open("train_time_" + file_name, "a+")
else:
result_writer = None
time_writer = None
monitor.pretrain_time_end()
monitor.train_time_start()
# from 2012-04-03 to 2012-04-13
for day in range(prediction_days): # make predictions for each day
# get the train and test data for each client at the current time step
for i in range(number_of_clients):
clients[i].get_train_test_data_at_current_time_step.remote(
start_time_float_format,
end_time_float_format,
use_buffer=use_buffer,
buffer_size=buffer_size,
)
clients[i].calculate_traveled_user_edge_indices.remote(
file_path=traveled_file_path
)
if online_learning:
print(f"start training for day {day + 1}")
else:
print(f"start training")
for iteration in range(global_rounds):
# each client train on local graph
print(f"global rounds: {iteration}")
current_loss = LP_train_global_round(
server=server,
local_steps=local_steps,
use_buffer=use_buffer,
method=method,
online_learning=online_learning,
prediction_day=day,
curr_iteration=iteration,
global_rounds=global_rounds,
record_results=record_results,
result_writer=result_writer,
time_writer=time_writer,
)
if current_loss >= 0.01:
print("training is not complete")
# go to next day
(
start_time,
end_time,
start_time_float_format,
end_time_float_format,
) = to_next_day(start_time=start_time, end_time=end_time, method=method)
monitor.train_time_end()
if result_writer is not None and time_writer is not None:
result_writer.close()
time_writer.close()
if monitor is not None:
monitor.print_comm_cost()
print("The whole process has ended")
ray.shutdown()
[docs]
def LP_train_global_round(
server: Any,
local_steps: int,
use_buffer: bool,
method: str,
online_learning: bool,
prediction_day: int,
curr_iteration: int,
global_rounds: int,
record_results: bool = False,
result_writer: Any = None,
time_writer: Any = None,
) -> float:
"""
This function trains the clients for a global round, handles model aggregation,
updates the server model with the average of the client models, and and evaluates
performance metrics including AUC scores and hit rates.
Supports different training methods.
Parameters
----------
clients : list
List of client objects
server : Any
Server object
local_steps : int
Number of local steps
use_buffer : bool
Specifies whether to use buffer
method : str
Specifies the method
online_learning : bool
Specifies online learning
prediction_day : int
Prediction day
curr_iteration : int
Current iteration
global_rounds : int
Global rounds
record_results : bool, optional
Record model AUC and Running time
result_writer : Any, optional
File writer object
time_writer : Any, optional
File writer object
Returns
-------
current_loss : float
Loss of the model on the training data
"""
if record_results:
assert result_writer is not None and time_writer is not None
# local training
number_of_clients = len(server.clients)
print(f"Training in LP_train_global_round, number of clients: {number_of_clients}")
local_training_results = []
for client_id in range(number_of_clients):
# current_loss, train_finish_times
local_training_result_ref = server.clients[client_id].train.remote(
client_id=client_id, local_updates=local_steps, use_buffer=use_buffer
) # local training
local_training_results.append(local_training_result_ref)
while True:
ready, left = ray.wait(local_training_results, num_returns=1, timeout=None)
if ready:
for t in ready:
client_id, current_loss, train_finish_times = ray.get(t)
print(
f"clientId: {client_id} current_loss: {current_loss} train_finish_times: {train_finish_times}"
)
if record_results:
for train_finish_time in train_finish_times:
time_writer.write(
f"client {str(client_id)} train time {str(train_finish_time)}\n"
)
print(
f"client {str(client_id)} train time {str(train_finish_time)}\n"
)
local_training_results = left
if not local_training_results:
break
# aggregate the parameters and broadcast to the clients
gnn_only = True if method == "FedLink (OnlyAvgGNN)" else False
if method != "StaticGNN":
model_avg_parameter = server.fedavg(gnn_only)
server.set_model_parameter(model_avg_parameter, gnn_only)
for client_id in range(number_of_clients):
server.clients[client_id].set_model_parameter.remote(
model_avg_parameter, gnn_only
)
model_size_mb = 0.0
if hasattr(server, "get_model_size") and hasattr(server, "monitor"):
model_size_mb = server.get_model_size() / (1024 * 1024)
server.monitor.add_train_comm_cost(
upload_mb=model_size_mb * number_of_clients,
download_mb=model_size_mb * number_of_clients,
)
# ======== Add embedding size to theoretical train communication cost ========
if method in ["STFL", "FedLink", "4D-FED-GNN+"]:
number_of_users = server.number_of_users
number_of_items = server.number_of_items
embedding_dim = server.hidden_channels
float_size = 4 # float32
embedding_param_size_bytes = (
(number_of_users + number_of_items) * embedding_dim * float_size
)
embedding_param_size_MB = embedding_param_size_bytes / (1024 * 1024)
server.monitor.add_train_comm_cost(
upload_mb=embedding_param_size_MB * number_of_clients,
download_mb=embedding_param_size_MB * number_of_clients,
)
print(
f"//Log Theoretical Embedding Communication Cost Added (Train Phase): {embedding_param_size_MB * number_of_clients * 2:.2f} MB //end"
)
# test the model
test_results = [
server.clients[client_id].test.remote(server.clients[client_id], use_buffer)
for client_id in range(number_of_clients)
]
avg_auc, avg_hit_rate, avg_traveled_user_hit_rate = 0.0, 0.0, 0.0
# for client_id in range(number_of_clients):
# auc_score, hit_rate, traveled_user_hit_rate = server.clients[client_id].test(
# use_buffer=use_buffer
# ) # local testing
# avg_auc += auc_score
# avg_hit_rate += hit_rate
# avg_traveled_user_hit_rate += traveled_user_hit_rate
# print(
# f"Day {prediction_day} client {client_id} auc score: {auc_score} hit rate: {
# hit_rate} traveled user hit rate: {traveled_user_hit_rate}"
# )
# # write final test_auc
# if curr_iteration + 1 == global_rounds and record_results:
# result_writer.write(
# f"Day {prediction_day} client {client_id} final auc score: {auc_score} hit rate: {
# hit_rate} traveled user hit rate: {traveled_user_hit_rate}\n"
# )
while test_results:
ready, left = ray.wait(test_results, num_returns=1, timeout=None)
if ready:
for t in ready:
client_id, auc_score, hit_rate, traveled_user_hit_rate = ray.get(t)
avg_auc += auc_score
avg_hit_rate += hit_rate
avg_traveled_user_hit_rate += traveled_user_hit_rate
print(
f"Day {prediction_day} client {client_id} auc score: {auc_score} hit rate: {hit_rate} traveled user hit rate: {traveled_user_hit_rate}"
)
# write final test_auc
if curr_iteration + 1 == global_rounds and record_results:
result_writer.write(
f"Day {prediction_day} client {client_id} final auc score: {auc_score} hit rate: {hit_rate} traveled user hit rate: {traveled_user_hit_rate}\n"
)
print(
f"Day {prediction_day} client {client_id} final auc score: {auc_score} hit rate: {hit_rate} traveled user hit rate: {traveled_user_hit_rate}\n"
)
test_results = left
avg_auc /= number_of_clients
avg_hit_rate /= number_of_clients
if online_learning:
print(
f"Predict Day {prediction_day + 1} average auc score: {avg_auc} hit rate: {avg_hit_rate}"
)
else:
print(f"Predict Day 20 average auc score: {avg_auc} hit rate: {avg_hit_rate}")
return current_loss