Trainer Class

class fedgraph.trainer_class.Trainer_General(rank: int, local_node_index: Tensor, communicate_node_index: Tensor, adj: Tensor, train_labels: Tensor, test_labels: Tensor, features: Tensor, idx_train: Tensor, idx_test: Tensor, args_hidden: int, global_node_num: int, class_num: int, device: device, args: Any)

A general trainer for GCN models in a federated learning setup. Each instance trains on its own subset of a distributed graph dataset and handles local training and testing, parameter updates, and feature aggregation.

Parameters:
  • rank (int) – Unique identifier for the training instance (typically representing a client in federated learning).

  • local_node_index (torch.Tensor) – Indices of nodes local to this trainer.

  • communicate_node_index (torch.Tensor) – Indices of nodes that participate in communication during training.

  • adj (torch.Tensor) – The adjacency matrix representing the graph structure.

  • train_labels (torch.Tensor) – Labels of the training data.

  • test_labels (torch.Tensor) – Labels of the testing data.

  • features (torch.Tensor) – Node features for the entire graph.

  • idx_train (torch.Tensor) – Indices of training nodes.

  • idx_test (torch.Tensor) – Indices of test nodes.

  • args_hidden (int) – Number of hidden units in the GCN model.

  • global_node_num (int) – Total number of nodes in the global graph.

  • class_num (int) – Number of classes for classification.

  • device (torch.device) – The device (CPU or GPU) on which the model will be trained.

  • args (Any) – Additional arguments required for model initialization and training.

get_all_loss_accuray(*, _ray_trace_ctx=None) → list

Returns all recorded training and testing losses and accuracies.

Returns:

(list) – A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.

Return type:

list
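The returned list can be unpacked into its four series. The following is a minimal sketch that assumes the order given in the description above (training losses, training accuracies, testing losses, testing accuracies). The `history` values are made up for illustration, and `summarize` is a hypothetical helper, not part of fedgraph.

```python
# Stand-in for the value returned by get_all_loss_accuray(). The numbers are
# invented, and the ordering is assumed from the description above.
history = [
    [0.92, 0.61, 0.48],  # training losses, one entry per recorded step
    [0.55, 0.71, 0.80],  # training accuracies
    [0.95, 0.70, 0.59],  # testing losses
    [0.52, 0.66, 0.74],  # testing accuracies
]


def summarize(history):
    """Return the last recorded value of each metric as a dict (hypothetical helper)."""
    names = ("train_loss", "train_acc", "test_loss", "test_acc")
    return {name: series[-1] for name, series in zip(names, history)}


# Unpack the four series by position, then reduce each to its final value.
train_loss, train_acc, test_loss, test_acc = history
summary = summarize(history)
print(summary)
```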

get_local_feature_sum(*, _ray_trace_ctx=None) → Tensor

Computes the sum of features of all 1-hop neighbors for each node.

Returns:

one_hop_neighbor_feature_sum – The sum of the features of the 1-hop neighbors of each node.

Return type:

torch.Tensor
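To illustrate what a 1-hop neighbor feature sum is, here is a minimal sketch in plain Python, so that it runs without PyTorch. It is not the library's implementation. The function name `one_hop_feature_sum`, the edge-list input format, and the choice to exclude a node's own features are all assumptions made for the example.

```python
def one_hop_feature_sum(edges, features):
    """Sum the feature vectors of each node's 1-hop neighbors.

    edges    -- list of (source, target) pairs; an undirected edge is assumed
                to appear once in each direction
    features -- list of per-node feature vectors (lists of floats)
    """
    dim = len(features[0])
    sums = [[0.0] * dim for _ in features]
    for src, dst in edges:
        # Node `src` accumulates the features of its neighbor `dst`.
        for k in range(dim):
            sums[src][k] += features[dst][k]
    return sums


# Path graph 0 - 1 - 2 with two-dimensional features.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
features = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
result = one_hop_feature_sum(edges, features)
print(result)  # [[0.0, 2.0], [4.0, 3.0], [0.0, 2.0]]
```

With a dense adjacency matrix, the same quantity is the matrix product of the adjacency matrix and the feature matrix.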

get_params(*, _ray_trace_ctx=None) → tuple

Retrieves the current parameters of the model.

Returns:

(tuple) – A tuple containing the current parameters of the model.

Return type:

tuple

get_rank(*, _ray_trace_ctx=None) → int

Returns the rank (client ID) of the trainer.

Returns:

(int) – The rank (client ID) of this trainer instance.

Return type:

int

load_feature_aggregation(feature_aggregation: Tensor, *, _ray_trace_ctx=None) → None

Loads the aggregated features into the trainer.

Parameters:

feature_aggregation (torch.Tensor) – The aggregated features to be loaded.

local_test(*, _ray_trace_ctx=None) → list

Evaluates the model on the local test dataset.

Returns:

(list) – A list containing the test loss and accuracy [local_test_loss, local_test_acc].

Return type:

list

relabel_adj(*, _ray_trace_ctx=None) → None

Relabels the adjacency matrix based on the communication node index.
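One plausible reading of relabeling is that global node IDs in the edges are mapped to their positions in the communication node index, so that the local graph uses contiguous IDs starting at 0. The sketch below shows that idea in plain Python. The function name `relabel_edges` and the edge-list representation are assumptions for illustration, and the library's actual behavior may differ.

```python
def relabel_edges(edges, communicate_node_index):
    """Map global node IDs in `edges` to positions in `communicate_node_index`."""
    # Position of each global ID in the index list becomes its local ID.
    global_to_local = {g: i for i, g in enumerate(communicate_node_index)}
    return [(global_to_local[s], global_to_local[d]) for s, d in edges]


# Global nodes 4, 7 and 9 become local nodes 0, 1 and 2.
communicate_node_index = [4, 7, 9]
edges = [(4, 7), (7, 4), (7, 9), (9, 7)]
local_edges = relabel_edges(edges, communicate_node_index)
print(local_edges)  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```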

train(current_global_round: int, *, _ray_trace_ctx=None) → None

Performs local training for a specified number of iterations. This method updates the model using the loaded feature aggregation and the adjacency matrix.

Parameters:

current_global_round (int) – The current global training round.

update_params(params: tuple, current_global_epoch: int, *, _ray_trace_ctx=None) → None

Updates the model parameters with global parameters received from the server.

Parameters:
  • params (tuple) – A tuple containing the global parameters from the server.

  • current_global_epoch (int) – The current global epoch number.
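The methods above suggest a round structure in which the server pushes global parameters, each trainer trains locally, and the server collects the results. The sketch below walks through that call order with a hypothetical `StubTrainer` that only mimics the documented method names, with scalar parameters in place of tensors. The plain averaging in `average_params` is an assumption, since this page does not say how the server combines parameters. The `_ray_trace_ctx` keyword in the signatures suggests the real methods are invoked through Ray, which the sketch ignores.

```python
class StubTrainer:
    """Hypothetical stand-in that mimics the method names documented above."""

    def __init__(self, rank, params):
        self.rank = rank
        self.params = tuple(params)
        self.last_round = None

    def get_params(self):
        return self.params

    def update_params(self, params, current_global_epoch):
        # Overwrite local parameters with the global ones from the server.
        self.params = tuple(params)
        self.last_round = current_global_epoch

    def train(self, current_global_round):
        # Pretend local training nudges every parameter by the client's rank.
        self.params = tuple(p + 0.1 * self.rank for p in self.params)


def average_params(param_tuples):
    """Element-wise mean of equally long parameter tuples (assumed aggregation)."""
    n = len(param_tuples)
    return tuple(sum(values) / n for values in zip(*param_tuples))


trainers = [StubTrainer(rank=r, params=(0.0, 1.0)) for r in range(3)]
global_params = (0.0, 1.0)

for round_id in range(2):
    for t in trainers:
        t.update_params(global_params, round_id)  # server -> client
        t.train(round_id)                         # local training
    # client -> server: collect and combine the updated parameters
    global_params = average_params([t.get_params() for t in trainers])

print(global_params)
```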