Trainer Class

class fedgraph.trainer_class.Trainer_GC(model: Any, trainer_id: int, trainer_name: str, train_size: int, dataloader: dict, optimizer: object, args: Any)

A trainer class specialized for graph classification tasks. It provides the functionality needed to train GIN models on a subset of a distributed dataset, including local training and testing, parameter updates, and feature aggregation.

Parameters:
  • model (object) – The model to be trained, which is based on the GIN model.

  • trainer_id (int) – The ID of the trainer.

  • trainer_name (str) – The name of the trainer.

  • train_size (int) – The size of the training dataset.

  • dataloader (dict) – The dataloaders for training, validation, and testing.

  • optimizer (object) – The optimizer for training.

  • args (Any) – The arguments for the training.

model

The model to be trained, which is based on the GIN model.

Type:

object

id

The ID of the trainer.

Type:

int

name

The name of the trainer.

Type:

str

train_size

The size of the training dataset.

Type:

int

dataloader

The dataloaders for training, validation, and testing.

Type:

dict

optimizer

The optimizer for training.

Type:

object

args

The arguments for the training.

Type:

object

W

The weights of the model.

Type:

dict

dW

The weight updates of the model, i.e., the difference between the current weights (W) and the cached weights (W_old).

Type:

dict

W_old

The cached weights of the model.

Type:

dict

gconv_names

The names of the gconv layers.

Type:

list

train_stats

The training statistics of the model.

Type:

Any

weights_norm

The norm of the weights of the model.

Type:

float

grads_norm

The norm of the gradients of the model.

Type:

float

conv_grads_norm

The norm of the gradients of the gconv layers.

Type:

float

conv_weights_Norm

The norm of the weights of the gconv layers.

Type:

float

conv_dWs_norm

The norm of the weight updates (dW) of the gconv layers.

Type:

float

cache_weights(*, _ray_trace_ctx=None) → None

Cache the weights of the model by copying the model weights (W) into the cached weights (W_old).

calculate_weighted_weight(key: Any, *, _ray_trace_ctx=None) → tensor
compute_mean_norm(total_size: int, keys: dict, *, _ray_trace_ctx=None) → Tensor

Compute the mean update norm (i.e., the norm of dW) for the trainer.

Returns:

curr_dW

Return type:

Tensor

compute_update_norm(keys: dict, *, _ray_trace_ctx=None) → float

Compute the max update norm (i.e., the norm of dW) for the trainer.
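To make "update norm" concrete, the sketch below flattens the selected entries of a dW dictionary and takes their L2 norm, then takes the maximum over several trainers, in the way a GCFL-style server might compare trainers. It uses plain Python lists instead of torch tensors, and the helper names update_norm and max_update_norm are hypothetical, not part of fedgraph; how compute_mean_norm weights trainers by size is not modeled here.

```python
import math


def update_norm(dW: dict[str, list[float]], keys: list[str]) -> float:
    """L2 norm of the concatenation of dW[k] for k in keys (hypothetical helper)."""
    # Flattening and then taking the norm equals the square root of the
    # sum of squares over every selected entry.
    return math.sqrt(sum(x * x for k in keys for x in dW[k]))


def max_update_norm(trainer_dWs: list[dict[str, list[float]]], keys: list[str]) -> float:
    """Largest per-trainer update norm (hypothetical server-side helper)."""
    return max(update_norm(dW, keys) for dW in trainer_dWs)


# Two trainers; only the "conv" entry is included in the norm.
dW_a = {"conv": [3.0, 4.0], "head": [1.0]}
dW_b = {"conv": [1.0, 0.0], "head": [10.0]}

print(update_norm(dW_a, ["conv"]))              # 5.0
print(max_update_norm([dW_a, dW_b], ["conv"]))  # 5.0
```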

get_conv_dWs_norm(*, _ray_trace_ctx=None) → Any
get_conv_grads_norm(*, _ray_trace_ctx=None) → Any
get_dW(*, _ray_trace_ctx=None) → Any
get_id(*, _ray_trace_ctx=None) → Any
get_name(*, _ray_trace_ctx=None) → str
get_total_weight(*, _ray_trace_ctx=None) → Any
get_train_size(*, _ray_trace_ctx=None) → int
get_weights(ks: Any, *, _ray_trace_ctx=None) → dict[str, Any]
local_test(test_option: str = 'basic', mu: float = 1, *, _ray_trace_ctx=None) → tuple

Run a final evaluation of the model on the test dataset using the given test option.

Parameters:
  • test_option (str, optional) – The test option, either ‘basic’ or ‘prox’. The default is ‘basic’. Use ‘basic’ for self-train, FedAvg, GCFL, GCFL+, and GCFL+dWs; use ‘prox’ for FedProx, which includes the proximal term.

  • mu (float, optional) – The coefficient of the proximal term. The default is 1.

Returns:

(test_loss, test_acc, trainer_name, trainingAccs, valAccs) – The average test loss, the average test accuracy, the trainer’s name, and the most recent training and validation accuracies (trainer.train_stats[“trainingAccs”][-1] and trainer.train_stats[“valAccs”][-1]).

Return type:

tuple(float, float, str, float, float)
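The ‘prox’ option refers to FedProx's proximal term. In the standard FedProx formulation, the loss becomes task_loss + (mu / 2) * ||w - w_global||^2, which penalizes local weights w for drifting away from the global weights w_global. The sketch below assumes that formulation and uses flat Python lists in place of model parameters; whether fedgraph applies exactly this scaling is an assumption, and proximal_loss is a hypothetical helper.

```python
def proximal_loss(task_loss: float, w: list[float], w_global: list[float], mu: float = 1.0) -> float:
    """Task loss plus the FedProx proximal term (mu / 2) * ||w - w_global||^2."""
    if len(w) != len(w_global):
        raise ValueError("w and w_global must have the same length")
    # Squared Euclidean distance between local and global weights.
    sq_dist = sum((a - b) ** 2 for a, b in zip(w, w_global))
    return task_loss + 0.5 * mu * sq_dist


# Local weights drifted by (1, 2) from the global weights, so ||w - w_global||^2 = 5.
print(proximal_loss(0.5, [1.0, 2.0], [0.0, 0.0], mu=1.0))  # 3.0
```

Larger values of mu keep local models closer to the global model; mu = 0 recovers the plain task loss.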

local_train(local_epoch: int, train_option: str = 'basic', mu: float = 1, *, _ray_trace_ctx=None) → None

This method is the trainer's interface for local training. It calls the training function that corresponds to train_option, using the provided arguments.

Parameters:
  • local_epoch (int) – The number of local epochs.

  • train_option (str, optional) – The training option, one of ‘basic’, ‘prox’, or ‘gcfl’. The default is ‘basic’. Use ‘basic’ for self-train and FedAvg, ‘prox’ for FedProx (which includes the proximal term), and ‘gcfl’ for GCFL, GCFL+, and GCFL+dWs.

  • mu (float, optional) – The coefficient of the proximal term. The default is 1.
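The dispatch on train_option can be pictured as a lookup from the option string to a training routine. In this sketch, train_basic, train_prox, and train_gcfl are hypothetical placeholders, not fedgraph functions; each returns a string so that the control flow is visible.

```python
from typing import Callable


def train_basic(local_epoch: int, mu: float) -> str:
    return f"basic: {local_epoch} epochs"


def train_prox(local_epoch: int, mu: float) -> str:
    # Only the FedProx routine uses mu (the weight of the proximal term).
    return f"prox: {local_epoch} epochs, mu={mu}"


def train_gcfl(local_epoch: int, mu: float) -> str:
    return f"gcfl: {local_epoch} epochs"


# Hypothetical mapping from train_option to a training routine.
TRAIN_ROUTINES: dict[str, Callable[[int, float], str]] = {
    "basic": train_basic,  # self-train and FedAvg
    "prox": train_prox,    # FedProx
    "gcfl": train_gcfl,    # GCFL, GCFL+, GCFL+dWs
}


def local_train(local_epoch: int, train_option: str = "basic", mu: float = 1) -> str:
    try:
        routine = TRAIN_ROUTINES[train_option]
    except KeyError:
        raise ValueError(f"unknown train_option: {train_option!r}") from None
    return routine(local_epoch, mu)


print(local_train(5, "prox", mu=0.01))  # prox: 5 epochs, mu=0.01
```

A dictionary keeps the set of supported options in one place and turns an unsupported option into an explicit error rather than a silent fallback.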

reset_params(*, _ray_trace_ctx=None) → None

Reset the model weights to the cached weights by copying the cached weights (W_old) into the model weights (W).
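Together, cache_weights and reset_params act as a snapshot and restore of the weight dictionary. The sketch below uses plain Python lists in place of tensors. It shows why the cache must hold copies (the torch analogue would be clone()): otherwise, later in-place training updates would also change W_old. The WeightCache class is a hypothetical stand-in, not fedgraph's implementation.

```python
import copy


class WeightCache:
    """Hypothetical stand-in for the W / W_old bookkeeping of Trainer_GC."""

    def __init__(self, W: dict[str, list[float]]):
        self.W = W
        self.W_old = copy.deepcopy(W)

    def cache_weights(self) -> None:
        # Copy values, not references, so later in-place updates to W
        # do not leak into the cache.
        for k, v in self.W.items():
            self.W_old[k] = list(v)

    def reset_params(self) -> None:
        # Write the cached values back into the existing W entries in place.
        for k, v in self.W_old.items():
            self.W[k][:] = v


c = WeightCache({"conv": [0.0, 0.0]})
c.W["conv"][0] = 1.0   # simulate a training step
c.cache_weights()      # snapshot: W_old["conv"] == [1.0, 0.0]
c.W["conv"][1] = 5.0   # another training step
c.reset_params()
print(c.W["conv"])     # [1.0, 0.0]
```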

set_stats_norms(train_stats: Any, is_gcfl: bool = False, *, _ray_trace_ctx=None) → None

Set the norms of the weights and gradients of the model, as well as the statistics of the training.

Parameters:
  • train_stats (dict) – The training statistics of the model.

  • is_gcfl (bool, optional) – Whether the training is for GCFL. The default is False.
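One way to picture this bookkeeping is to compute L2 norms over the full weight and gradient dictionaries, and to add norms restricted to the gconv layers when is_gcfl is set. In this sketch, compute_norms and the keys of its result dictionary are hypothetical. Which quantities fedgraph actually records is inferred only from the attribute names listed above.

```python
import math


def l2(values: dict[str, list[float]], keys=None) -> float:
    """L2 norm over the flattened entries of `values` (optionally only `keys`)."""
    selected = values.keys() if keys is None else keys
    return math.sqrt(sum(x * x for k in selected for x in values[k]))


def compute_norms(W, grads, dW, gconv_names, is_gcfl=False) -> dict[str, float]:
    norms = {"weights_norm": l2(W), "grads_norm": l2(grads)}
    if is_gcfl:
        # GCFL-specific statistics restricted to the graph-convolution layers.
        norms["conv_weights_norm"] = l2(W, gconv_names)
        norms["conv_grads_norm"] = l2(grads, gconv_names)
        norms["conv_dWs_norm"] = l2(dW, gconv_names)
    return norms


W = {"gconv.0": [3.0, 4.0], "head": [12.0]}
grads = {"gconv.0": [0.0, 1.0], "head": [0.0]}
dW = {"gconv.0": [0.6, 0.8], "head": [0.0]}

print(compute_norms(W, grads, dW, ["gconv.0"], is_gcfl=True))
```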

update_params(server_params: Any, *, _ray_trace_ctx=None) → None

Update the model parameters by downloading the global model weights from the server.

Parameters:

server_params (Any) – The global model weights received from the server (Server_GC).
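In GCFL-style training, the server may hold only part of the model, such as the graph-convolution layers. In that case, a download step copies just the keys the server provides. The sketch below assumes that server_params is a dict keyed like the trainer's W. That shape and the helper name download_params are assumptions, not fedgraph's documented behavior.

```python
def download_params(W: dict[str, list[float]], server_params: dict[str, list[float]]) -> None:
    """Copy the server's values into W for every key the server provides (hypothetical)."""
    for k, v in server_params.items():
        if k not in W:
            raise KeyError(f"server sent unknown parameter {k!r}")
        # Copy into the existing list so other references to W[k] see the update.
        W[k][:] = v


W = {"gconv.0": [0.0, 0.0], "classifier": [7.0]}
download_params(W, {"gconv.0": [1.5, -2.0]})
print(W)  # {'gconv.0': [1.5, -2.0], 'classifier': [7.0]}
```

Layers that the server does not send (here "classifier") keep their local values.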

class fedgraph.trainer_class.Trainer_General(rank: int, args_hidden: int, device: device, args: Any, local_node_index: Tensor | None = None, communicate_node_index: Tensor | None = None, adj: Tensor | None = None, train_labels: Tensor | None = None, test_labels: Tensor | None = None, features: Tensor | None = None, idx_train: Tensor | None = None, idx_test: Tensor | None = None)

A general trainer class for federated learning with GCN models. It provides the functionality needed to train a GCN on a subset of a distributed dataset, including local training and testing, parameter updates, and feature aggregation.

Parameters:
  • rank (int) – Unique identifier for the training instance (typically representing a trainer in federated learning).

  • local_node_index (torch.Tensor) – Indices of nodes local to this trainer.

  • communicate_node_index (torch.Tensor) – Indices of nodes that participate in communication during training.

  • adj (torch.Tensor) – The adjacency matrix representing the graph structure.

  • train_labels (torch.Tensor) – Labels of the training data.

  • test_labels (torch.Tensor) – Labels of the testing data.

  • features (torch.Tensor) – Node features for the entire graph.

  • idx_train (torch.Tensor) – Indices of training nodes.

  • idx_test (torch.Tensor) – Indices of test nodes.

  • args_hidden (int) – Number of hidden units in the GCN model.

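  • global_node_num (int) – Total number of nodes in the global graph. Passed to init_model rather than to the constructor.

  • class_num (int) – Number of classes for classification. Passed to init_model rather than to the constructor.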

  • device (torch.device) – The device (CPU or GPU) on which the model will be trained.

  • args (Any) – Additional arguments required for model initialization and training.

decrypt_feature_sum(encrypted_sum, shape, *, _ray_trace_ctx=None)
encrypt_feature_sum(feature_sum, *, _ray_trace_ctx=None)
get_all_loss_accuray(*, _ray_trace_ctx=None) → list

Returns all recorded training and testing losses and accuracies.

Returns:

(list) – A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.

Return type:

list

get_encrypted_local_feature_sum(*, _ray_trace_ctx=None)
get_encrypted_params(*, _ray_trace_ctx=None)

Get the encrypted model parameters, with the appropriate scaling applied.

get_info(*, _ray_trace_ctx=None)
get_local_feature_sum(*, _ray_trace_ctx=None) → Tensor

Computes the sum of features of all 1-hop neighbors for each node and normalizes the result.

Returns:

normalized_sum – The normalized sum of features of 1-hop neighbors for each node

Return type:

torch.Tensor

get_local_feature_sum_og(*, _ray_trace_ctx=None) → Tensor

Computes the sum of features of all 1-hop neighbors for each node. Used in the plaintext (unencrypted) version.

Returns:

one_hop_neighbor_feature_sum – The sum of features of 1-hop neighbors for each node

Return type:

torch.Tensor
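Given a dense adjacency matrix A and a feature matrix X, the sum of 1-hop neighbor features for every node is the matrix product A X: row i adds up the feature rows of the nodes adjacent to node i. The sketch uses nested Python lists on a toy graph. Whether fedgraph's adjacency includes self-loops, and how get_local_feature_sum normalizes its result, are not modeled here. one_hop_feature_sum is a hypothetical helper.

```python
def one_hop_feature_sum(adj: list[list[float]], features: list[list[float]]) -> list[list[float]]:
    """Return adj @ features: row i is the sum of the features of i's neighbors."""
    n, d = len(features), len(features[0])
    return [
        [sum(adj[i][j] * features[j][c] for j in range(n)) for c in range(d)]
        for i in range(n)
    ]


# Path graph 0 - 1 - 2 (no self-loops), two features per node.
adj = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
features = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]

print(one_hop_feature_sum(adj, features))
# [[0.0, 2.0], [4.0, 1.0], [0.0, 2.0]]
```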

get_params(*, _ray_trace_ctx=None) → tuple

Retrieves the current parameters of the model.

Returns:

(tuple) – A tuple containing the current parameters of the model.

Return type:

tuple

get_rank(*, _ray_trace_ctx=None) → int

Returns the rank (trainer ID) of the trainer.

Returns:

(int) – The rank (trainer ID) of this trainer instance.

Return type:

int

init_model(global_node_num, class_num, *, _ray_trace_ctx=None)
load_encrypted_feature_aggregation(encrypted_data, *, _ray_trace_ctx=None)
load_encrypted_params(encrypted_data: tuple, current_global_epoch: int, *, _ray_trace_ctx=None)

Load encrypted parameters and rescale them.

load_feature_aggregation(feature_aggregation: Tensor, *, _ray_trace_ctx=None) → None

Loads the aggregated features into the trainer. Used in the plaintext (unencrypted) version.

Parameters:

feature_aggregation (torch.Tensor) – The aggregated features to be loaded.

local_test(*, _ray_trace_ctx=None) → list

Evaluates the model on the local test dataset.

Returns:

(list) – A list containing the test loss and accuracy [local_test_loss, local_test_acc].

Return type:

list

relabel_adj(*, _ray_trace_ctx=None) → None

Relabels the adjacency matrix based on the communication node index.
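Relabeling can be pictured as mapping global node IDs to their positions in communicate_node_index, so that the trainer can index a compact local adjacency. The sketch below makes two assumptions about the data layout: the adjacency is stored as an edge list of global IDs, and edges whose endpoints are not in the index are dropped. relabel_edges is a hypothetical helper.

```python
def relabel_edges(edges: list[tuple[int, int]], communicate_node_index: list[int]) -> list[tuple[int, int]]:
    """Rewrite (src, dst) global IDs as positions within communicate_node_index."""
    position = {node: i for i, node in enumerate(communicate_node_index)}
    return [
        (position[u], position[v])
        for u, v in edges
        if u in position and v in position  # skip edges that leave the index
    ]


edges = [(10, 42), (42, 7), (7, 99)]
print(relabel_edges(edges, [7, 10, 42]))  # [(1, 2), (2, 0)]
```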

train(current_global_round: int, *, _ray_trace_ctx=None) → None

Performs local training for a specified number of iterations. This method updates the model using the loaded feature aggregation and the adjacency matrix.

Parameters:

current_global_round (int) – The current global training round.

update_params(params: tuple, current_global_epoch: int, *, _ray_trace_ctx=None) → None

Updates the model parameters with global parameters received from the server.

Parameters:
  • params (tuple) – A tuple containing the global parameters from the server.

  • current_global_epoch (int) – The current global epoch number.

use_fedavg_feature(*, _ray_trace_ctx=None) → None
verify_param_ranges(params, stage='pre-encryption', *, _ray_trace_ctx=None)

Verify parameter ranges and print statistics.

class fedgraph.trainer_class.Trainer_LP(client_id: int, country_code: str, user_id_mapping: dict, item_id_mapping: dict, number_of_users: int, number_of_items: int, meta_data: tuple, dataset_path: str, hidden_channels: int = 64)

A trainer class specialized for graph link prediction tasks. It provides the functionality needed to train GNN models on a subset of a distributed dataset, including local training and testing, parameter updates, and feature aggregation.

Parameters:
  • client_id (int) – The ID of the client.

  • country_code (str) – The country code of the client. Each client is associated with one country code.

  • user_id_mapping (dict) – The mapping of user IDs.

  • item_id_mapping (dict) – The mapping of item IDs.

  • number_of_users (int) – The number of users.

  • number_of_items (int) – The number of items.

  • meta_data (tuple) – The metadata of the dataset.

  • dataset_path (str) – The path to the dataset.

  • hidden_channels (int, optional) – The number of hidden channels in the GNN model. The default is 64.

calculate_traveled_user_edge_indices(file_path: str, *, _ray_trace_ctx=None) → None

Calculate the indices of the edges of the traveled users.

Parameters:

file_path (str) – The path to the file containing the traveled users.

get_model_parameter(gnn_only: bool = False, *, _ray_trace_ctx=None) → dict

Get the model parameters.

Parameters:

gnn_only (bool, optional) – Whether to get only the GNN parameters. The default is False.

Returns:

The model parameters.

Return type:

dict

get_train_test_data_at_current_time_step(start_time_float_format: float, end_time_float_format: float, use_buffer: bool = False, buffer_size: int = 10, *, _ray_trace_ctx=None) → None

Get the training and testing data at the current time step.

Parameters:
  • start_time_float_format (float) – The start time in float format.

  • end_time_float_format (float) – The end time in float format.

  • use_buffer (bool, optional) – Whether to use the buffer. The default is False.

  • buffer_size (int, optional) – The size of the buffer. The default is 10.
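Selecting data for a time step can be pictured as filtering timestamped interactions to a window between the start and end times. The sketch below is hypothetical: the half-open window [start, end) and the (user, item, timestamp) layout are assumptions, and it does not model use_buffer or buffer_size.

```python
def edges_in_window(
    edges: list[tuple[int, int, float]], start_time: float, end_time: float
) -> list[tuple[int, int, float]]:
    """Keep (user, item, timestamp) triples with start_time <= timestamp < end_time."""
    return [e for e in edges if start_time <= e[2] < end_time]


edges = [(0, 5, 1.0), (1, 3, 2.5), (2, 5, 4.0)]
print(edges_in_window(edges, 2.0, 4.0))  # [(1, 3, 2.5)]
```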

set_model_parameter(model_state_dict: dict, gnn_only: bool = False, *, _ray_trace_ctx=None) → None

Load the model parameters from the global server.

Parameters:
  • model_state_dict (dict) – The model parameters to be loaded.

  • gnn_only (bool, optional) – Whether to load only the GNN parameters. The default is False.
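With gnn_only=True, only the GNN part of the model is read or written, which amounts to filtering the state dictionary by parameter name. The sketch assumes that GNN parameters share a "gnn." key prefix. That prefix, the list-valued state, and the helpers get_parameters and set_parameters are all hypothetical.

```python
def get_parameters(state: dict[str, list[float]], gnn_only: bool = False, prefix: str = "gnn.") -> dict:
    """Return a copy of the state, optionally restricted to keys starting with prefix."""
    return {k: list(v) for k, v in state.items() if not gnn_only or k.startswith(prefix)}


def set_parameters(
    state: dict[str, list[float]],
    incoming: dict[str, list[float]],
    gnn_only: bool = False,
    prefix: str = "gnn.",
) -> None:
    """Overwrite entries of state with incoming values (only GNN keys if gnn_only)."""
    for k, v in incoming.items():
        if gnn_only and not k.startswith(prefix):
            continue
        state[k] = list(v)


local = {"gnn.conv1": [0.1], "user_emb": [0.5]}
global_state = {"gnn.conv1": [0.9], "user_emb": [0.0]}

print(get_parameters(local, gnn_only=True))  # {'gnn.conv1': [0.1]}
set_parameters(local, global_state, gnn_only=True)
print(local)  # {'gnn.conv1': [0.9], 'user_emb': [0.5]}
```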

test(clientId: int, use_buffer: bool = False, *, _ray_trace_ctx=None) → tuple

Test the model on the test data.

Parameters:
  • clientId (int) – The ID of the client.

  • use_buffer (bool, optional) – Whether to use the buffer. The default is False.

Returns:

(auc, hit_rate_at_2, traveled_user_hit_rate_at_2) – The AUC score, the hit rate at 2, and the hit rate at 2 for traveled users.

Return type:

tuple
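Hit rate at k is commonly defined as the fraction of test cases whose true item appears among the model's k highest-scoring candidates. The sketch below follows that common definition with k = 2. How fedgraph builds the candidate sets and computes AUC is not shown, and hit_rate_at_k is a hypothetical helper.

```python
def hit_rate_at_k(scores: list[dict[int, float]], true_items: list[int], k: int = 2) -> float:
    """Fraction of cases whose true item is among the k highest-scoring candidates."""
    if len(scores) != len(true_items):
        raise ValueError("scores and true_items must have the same length")
    hits = 0
    for candidate_scores, target in zip(scores, true_items):
        # Rank candidate items by score, highest first, and keep the top k.
        top_k = sorted(candidate_scores, key=candidate_scores.get, reverse=True)[:k]
        hits += target in top_k
    return hits / len(true_items)


scores = [
    {10: 0.9, 11: 0.8, 12: 0.1},  # true item 11 ranks 2nd -> hit
    {10: 0.2, 11: 0.7, 12: 0.6},  # true item 10 ranks 3rd -> miss
]
print(hit_rate_at_k(scores, [11, 10]))  # 0.5
```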

train(client_id: int, local_updates: int, use_buffer: bool = False, *, _ray_trace_ctx=None) → tuple

Perform local training for a specified number of iterations.

Parameters:
  • client_id (int) – The ID of the client.

  • local_updates (int) – The number of local updates.

  • use_buffer (bool, optional) – Whether to use the buffer. The default is False.

Returns:

(loss, train_finish_times) – The loss of the model and the time taken for each local update.

Return type:

tuple

fedgraph.trainer_class.load_trainer_data_from_hugging_face(trainer_id, args)