Trainer Class
- class fedgraph.trainer_class.Trainer_GC(model: Any, trainer_id: int, trainer_name: str, train_size: int, dataloader: dict, optimizer: object, args: Any)[source]
A trainer class for graph classification tasks. It provides the functionality needed to train GIN models on a subset of a distributed dataset: local training and testing, parameter updates, and feature aggregation.
- Parameters:
model (object) – The model to be trained, which is based on the GIN model.
trainer_id (int) – The ID of the trainer.
trainer_name (str) – The name of the trainer.
train_size (int) – The size of the training dataset.
dataloader (dict) – The dataloaders for training, validation, and testing.
optimizer (object) – The optimizer for training.
args (Any) – The arguments for the training.
- train_stats
The training statistics of the model.
- Type:
Any
- cache_weights(*, _ray_trace_ctx=None) None [source]
Cache the weights of the model. The implementation is copying the model weights (W) to the cached weights (W_old).
- compute_mean_norm(total_size: int, keys: dict, *, _ray_trace_ctx=None) Tensor [source]
Compute the mean update norm (i.e., dW) for the trainer.
- Returns:
curr_dW
- Return type:
Tensor
- compute_update_norm(keys: dict, *, _ray_trace_ctx=None) float [source]
Compute the max update norm (i.e., dW) for the trainer.
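To make the notion of an update norm concrete, here is a minimal sketch that assumes dW is the difference between the current weights (W) and the cached weights (W_old), and that the norm is the L2 norm of that difference flattened over the selected keys. The function `update_norm` and the plain-list weights are hypothetical stand-ins for illustration, not fedgraph's implementation.

```python
import math


def update_norm(weights, cached_weights, keys):
    """L2 norm of the flattened update dW = W - W_old over the given keys.

    Hypothetical helper: weights are plain dicts mapping a parameter name
    to a flat list of floats (standing in for tensors).
    """
    squared_sum = 0.0
    for key in keys:
        for new, old in zip(weights[key], cached_weights[key]):
            squared_sum += (new - old) ** 2
    return math.sqrt(squared_sum)


# Cached weights (W_old) and weights after local training (W).
W_old = {"conv.weight": [1.0, 2.0], "conv.bias": [0.0]}
W = {"conv.weight": [4.0, 6.0], "conv.bias": [0.5]}

# Restrict the norm to a subset of parameters via `keys`.
norm = update_norm(W, W_old, keys=["conv.weight"])
```

Restricting the computation to `keys` lets the caller measure the update on only the parameters that are actually shared.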
- local_test(test_option: str = 'basic', mu: float = 1, *, _ray_trace_ctx=None) tuple [source]
Final test of the model on the test dataset based on the test option.
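- Parameters:
test_option (str, optional) – The test option. The default is ‘basic’.
mu (float, optional) – The coefficient of the proximal term. The default is 1.
- Returns:
(test_loss, test_acc, trainer_name, trainingAccs, valAccs) – The average test loss and accuracy, the trainer’s name, and the most recent training and validation accuracies (trainer.train_stats[“trainingAccs”][-1] and trainer.train_stats[“valAccs”][-1]).
- Return type:
tuple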
- local_train(local_epoch: int, train_option: str = 'basic', mu: float = 1, *, _ray_trace_ctx=None) None [source]
This function is the trainer class's interface for training the model locally. It calls the training function that corresponds to the given training option.
- Parameters:
local_epoch (int) – The number of local epochs
train_option (str, optional) – The training option. The possible values are ‘basic’, ‘prox’, and ‘gcfl’. The default is ‘basic’. ‘basic’: self-training and FedAvg; ‘prox’: FedProx, which adds a proximal term; ‘gcfl’: GCFL, GCFL+, and GCFL+dWs.
mu (float, optional) – The coefficient of the proximal term. The default is 1.
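As an illustration of what the ‘prox’ option adds, the sketch below computes the standard FedProx proximal term, (mu / 2) * ||w - w_global||^2, and adds it to a local task loss. The helper names and the flat weight lists are hypothetical; the exact form used inside fedgraph may differ.

```python
def proximal_penalty(local_weights, global_weights, mu=1.0):
    """Standard FedProx proximal term: (mu / 2) * ||w - w_global||^2."""
    squared_distance = sum(
        (w - g) ** 2 for w, g in zip(local_weights, global_weights)
    )
    return 0.5 * mu * squared_distance


def prox_objective(task_loss, local_weights, global_weights, mu=1.0):
    """Local objective under FedProx: task loss plus the proximal term."""
    return task_loss + proximal_penalty(local_weights, global_weights, mu)


# Hypothetical flattened weights for one trainer and the global model.
local_w = [1.0, 2.0]
global_w = [0.0, 0.0]

penalty = proximal_penalty(local_w, global_w, mu=1.0)
objective = prox_objective(0.5, local_w, global_w, mu=1.0)
```

A larger `mu` keeps the local model closer to the global model, while `mu = 0` reduces the objective to the plain task loss.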
- reset_params(*, _ray_trace_ctx=None) None [source]
Reset the weights of the model to the cached weights. The implementation is copying the cached weights (W_old) to the model weights (W).
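The relationship between cache_weights and reset_params can be sketched as a copy in each direction between W and W_old. The class below is a hypothetical stand-in that uses plain dicts of lists instead of model tensors; it only illustrates the copy semantics described above.

```python
import copy


class WeightCache:
    """Hypothetical stand-in for the W / W_old pair held by a trainer."""

    def __init__(self, weights):
        self.W = weights
        self.W_old = copy.deepcopy(weights)

    def cache_weights(self):
        # Copy the current weights (W) into the cache (W_old).
        for name, values in self.W.items():
            self.W_old[name] = list(values)

    def reset_params(self):
        # Copy the cached weights (W_old) back into the model weights (W).
        for name, values in self.W_old.items():
            self.W[name] = list(values)


state = WeightCache({"w": [1.0, 2.0]})
state.cache_weights()      # snapshot before local training
state.W["w"][0] = 9.0      # simulate a local update
state.reset_params()       # roll back to the snapshot
```

Copying values rather than sharing references is what keeps the snapshot from changing when the model is updated.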
- class fedgraph.trainer_class.Trainer_General(rank: int, args_hidden: int, device: device, args: Any, local_node_index: Tensor | None = None, communicate_node_index: Tensor | None = None, adj: Tensor | None = None, train_labels: Tensor | None = None, test_labels: Tensor | None = None, features: Tensor | None = None, idx_train: Tensor | None = None, idx_test: Tensor | None = None)[source]
A general trainer class for training GCN in a federated learning setup, which includes functionalities required for training GCN models on a subset of a distributed dataset, handling local training and testing, parameter updates, and feature aggregation.
- Parameters:
rank (int) – Unique identifier for the training instance (typically representing a trainer in federated learning).
local_node_index (torch.Tensor) – Indices of nodes local to this trainer.
communicate_node_index (torch.Tensor) – Indices of nodes that participate in communication during training.
adj (torch.Tensor) – The adjacency matrix representing the graph structure.
train_labels (torch.Tensor) – Labels of the training data.
test_labels (torch.Tensor) – Labels of the testing data.
features (torch.Tensor) – Node features for the entire graph.
idx_train (torch.Tensor) – Indices of training nodes.
idx_test (torch.Tensor) – Indices of test nodes.
args_hidden (int) – Number of hidden units in the GCN model.
global_node_num (int) – Total number of nodes in the global graph.
class_num (int) – Number of classes for classification.
device (torch.device) – The device (CPU or GPU) on which the model will be trained.
args (Any) – Additional arguments required for model initialization and training.
- get_all_loss_accuray(*, _ray_trace_ctx=None) list [source]
Returns all recorded training and testing losses and accuracies.
- Returns:
(list) – A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.
- Return type:
list
- get_local_feature_sum(*, _ray_trace_ctx=None) Tensor [source]
Computes the sum of features of all 1-hop neighbors for each node and normalizes the result.
- Returns:
normalized_sum – The normalized sum of features of 1-hop neighbors for each node
- Return type:
torch.Tensor
- get_local_feature_sum_og(*, _ray_trace_ctx=None) Tensor [source]
Computes the sum of features of all 1-hop neighbors for each node. Used in the plaintext (non-encrypted) version.
- Returns:
one_hop_neighbor_feature_sum – The sum of features of 1-hop neighbors for each node
- Return type:
torch.Tensor
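The 1-hop neighbor feature sum described above can be sketched in plain Python. This is a minimal illustration that assumes the graph is given as a list of directed (source, target) edges and the features as one list per node; `one_hop_feature_sum` is a hypothetical helper, and the normalization applied by get_local_feature_sum is not specified here, so it is omitted.

```python
def one_hop_feature_sum(edges, features):
    """For each node, sum the feature vectors of its 1-hop neighbors.

    edges: list of (source, target) pairs; each undirected edge is
    listed in both directions.
    features: list of per-node feature vectors of equal length.
    """
    dim = len(features[0])
    sums = [[0.0] * dim for _ in features]
    for src, dst in edges:
        for d in range(dim):
            sums[dst][d] += features[src][d]
    return sums


# Path graph 0 - 1 - 2 with 2-dimensional node features.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
features = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

neighbor_sums = one_hop_feature_sum(edges, features)
```

Node 1 has two neighbors, so its row is the sum of the features of nodes 0 and 2; the end nodes each receive only node 1's features.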
- get_params(*, _ray_trace_ctx=None) tuple [source]
Retrieves the current parameters of the model.
- Returns:
(tuple) – A tuple containing the current parameters of the model.
- Return type:
tuple
- get_rank(*, _ray_trace_ctx=None) int [source]
Returns the rank (trainer ID) of the trainer.
- Returns:
(int) – The rank (trainer ID) of this trainer instance.
- Return type:
int
- load_encrypted_params(encrypted_data: tuple, current_global_epoch: int, *, _ray_trace_ctx=None)[source]
Loads encrypted parameters with rescaling.
- load_feature_aggregation(feature_aggregation: Tensor, *, _ray_trace_ctx=None) None [source]
Loads the aggregated features into the trainer. Used in the plaintext (non-encrypted) version.
- Parameters:
feature_aggregation (torch.Tensor) – The aggregated features to be loaded.
- local_test(*, _ray_trace_ctx=None) list [source]
Evaluates the model on the local test dataset.
- Returns:
(list) – A list containing the test loss and accuracy [local_test_loss, local_test_acc].
- Return type:
list
- relabel_adj(*, _ray_trace_ctx=None) None [source]
Relabels the adjacency matrix based on the communication node index.
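One plausible reading of this relabeling is that global node IDs in the edge list are mapped to their positions in the communication node index, so the local adjacency uses contiguous indices starting at 0. The sketch below illustrates that idea only; `relabel_edges` and the edge-list representation are assumptions, not fedgraph's implementation.

```python
def relabel_edges(edges, communicate_node_index):
    """Map global node IDs to their positions in communicate_node_index."""
    position = {node: i for i, node in enumerate(communicate_node_index)}
    return [(position[src], position[dst]) for src, dst in edges]


# Global IDs of the nodes this trainer communicates about.
communicate_node_index = [3, 7, 10]
# Edges expressed with global IDs.
global_edges = [(3, 7), (7, 10), (10, 3)]

local_edges = relabel_edges(global_edges, communicate_node_index)
```

After relabeling, the edge indices can be used directly against feature rows that are stored in the same order as `communicate_node_index`.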
- train(current_global_round: int, *, _ray_trace_ctx=None) None [source]
Performs local training for a specified number of iterations. This method updates the model using the loaded feature aggregation and the adjacency matrix.
- Parameters:
current_global_round (int) – The current global training round.
- class fedgraph.trainer_class.Trainer_LP(client_id: int, country_code: str, user_id_mapping: dict, item_id_mapping: dict, number_of_users: int, number_of_items: int, meta_data: tuple, dataset_path: str, hidden_channels: int = 64)[source]
A trainer class for graph link prediction tasks. It provides the functionality needed to train GNN models on a subset of a distributed dataset: local training and testing, parameter updates, and feature aggregation.
- Parameters:
client_id (int) – The ID of the client.
country_code (str) – The country code of the client. Each client is associated with one country code.
user_id_mapping (dict) – The mapping of user IDs.
item_id_mapping (dict) – The mapping of item IDs.
number_of_users (int) – The number of users.
number_of_items (int) – The number of items.
meta_data (tuple) – The metadata of the dataset.
dataset_path (str) – The path to the dataset.
hidden_channels (int, optional) – The number of hidden channels in the GNN model. The default is 64.
- calculate_traveled_user_edge_indices(file_path: str, *, _ray_trace_ctx=None) None [source]
Calculate the indices of the edges of the traveled users.
- Parameters:
file_path (str) – The path to the file containing the traveled users.
- get_model_parameter(gnn_only: bool = False, *, _ray_trace_ctx=None) dict [source]
Get the model parameters.
- get_train_test_data_at_current_time_step(start_time_float_format: float, end_time_float_format: float, use_buffer: bool = False, buffer_size: int = 10, *, _ray_trace_ctx=None) None [source]
Get the training and testing data at the current time step.
- set_model_parameter(model_state_dict: dict, gnn_only: bool = False, *, _ray_trace_ctx=None) None [source]
Load the model parameters from the global server.
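The gnn_only flag on get_model_parameter and set_model_parameter suggests that a trainer can exchange either the full set of parameters or only those of the GNN part. The sketch below shows one way such a filter could work, assuming parameters live in a flat dict whose GNN entries share a `gnn.` key prefix. Both the prefix and the helper `select_parameters` are hypothetical.

```python
def select_parameters(state_dict, gnn_only=False, gnn_prefix="gnn."):
    """Return all parameters, or only those under the (assumed) GNN prefix."""
    if not gnn_only:
        return dict(state_dict)
    return {
        name: value
        for name, value in state_dict.items()
        if name.startswith(gnn_prefix)
    }


# Hypothetical parameter names; values stand in for tensors.
state = {
    "gnn.conv1.weight": [0.1, 0.2],
    "gnn.conv2.weight": [0.3],
    "user_emb.weight": [0.5, 0.6],
}

shared = select_parameters(state, gnn_only=True)
everything = select_parameters(state)
```

Sharing only the GNN parameters would keep client-specific entries, such as the embedding in this example, local to each trainer.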
- test(clientId: int, use_buffer: bool = False, *, _ray_trace_ctx=None) tuple [source]
Test the model on the test data.