Trainer Class
- class fedgraph.trainer_class.Trainer_General(rank: int, local_node_index: Tensor, communicate_node_index: Tensor, adj: Tensor, train_labels: Tensor, test_labels: Tensor, features: Tensor, idx_train: Tensor, idx_test: Tensor, args_hidden: int, global_node_num: int, class_num: int, device: device, args: Any)
A general trainer class for training GCN models in a federated learning setup. Each trainer works on a subset of a distributed dataset and handles local training and testing, parameter updates, and feature aggregation.
- Parameters:
rank (int) – Unique identifier for the training instance (typically representing a client in federated learning).
local_node_index (torch.Tensor) – Indices of nodes local to this trainer.
communicate_node_index (torch.Tensor) – Indices of nodes that participate in communication during training.
adj (torch.Tensor) – The adjacency matrix representing the graph structure.
train_labels (torch.Tensor) – Labels of the training data.
test_labels (torch.Tensor) – Labels of the testing data.
features (torch.Tensor) – Node features for the entire graph.
idx_train (torch.Tensor) – Indices of training nodes.
idx_test (torch.Tensor) – Indices of test nodes.
args_hidden (int) – Number of hidden units in the GCN model.
global_node_num (int) – Total number of nodes in the global graph.
class_num (int) – Number of classes for classification.
device (torch.device) – The device (CPU or GPU) on which the model will be trained.
args (Any) – Additional arguments required for model initialization and training.
- get_all_loss_accuray(*, _ray_trace_ctx=None) list
Returns all recorded training and testing losses and accuracies.
- Returns:
(list) – A list containing arrays of training losses, training accuracies, testing losses, and testing accuracies.
- Return type:
list
- get_local_feature_sum(*, _ray_trace_ctx=None) Tensor
Computes the sum of features of all 1-hop neighbors for each node.
- Returns:
one_hop_neighbor_feature_sum (torch.Tensor) – The sum of the features of each node's 1-hop neighbors.
- Return type:
torch.Tensor
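To make the 1-hop neighbor sum concrete, here is a minimal dependency-free sketch. It is not the library's implementation: the helper name `one_hop_feature_sum`, the edge-list input, and the choice to exclude a node's own features are all assumptions made for illustration. The real method operates on tensors held by the trainer.

```python
def one_hop_feature_sum(edges, features):
    """Sum the feature vectors of each node's 1-hop neighbors.

    edges: (source, target) pairs; list both directions for an undirected graph.
    features: equal-length feature vectors, indexed by node id.
    Hypothetical helper for illustration only.
    """
    dim = len(features[0])
    sums = [[0.0] * dim for _ in features]
    for src, dst in edges:
        # The feature vector of `src` contributes to the sum stored at `dst`.
        for k in range(dim):
            sums[dst][k] += features[src][k]
    return sums


# Path graph 0 - 1 - 2, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
features = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
neighbor_sums = one_hop_feature_sum(edges, features)
```

With a dense adjacency matrix, the same quantity is the matrix product of the adjacency matrix and the feature matrix.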
- get_params(*, _ray_trace_ctx=None) tuple
Retrieves the current parameters of the model.
- Returns:
(tuple) – A tuple containing the current parameters of the model.
- Return type:
tuple
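One possible use of the returned tuple is averaging parameters across clients on the server side. The sketch below assumes an unweighted element-wise mean and represents each parameter as a flat list of floats. Both are simplifications, and `average_params` is a hypothetical helper, not part of fedgraph.

```python
def average_params(client_params):
    """Element-wise mean over clients.

    client_params: one tuple per client, each holding flat lists of floats
    in the same order and with the same lengths. Hypothetical helper.
    """
    num_clients = len(client_params)
    averaged = []
    # zip(*client_params) groups the same parameter across all clients.
    for same_param in zip(*client_params):
        averaged.append([sum(vals) / num_clients for vals in zip(*same_param)])
    return tuple(averaged)


client_a = ([1.0, 2.0], [0.5])
client_b = ([3.0, 4.0], [1.5])
global_params = average_params([client_a, client_b])
```

A weighted mean (for example, by the number of local training nodes) would follow the same structure with per-client weights.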
- get_rank(*, _ray_trace_ctx=None) int
Returns the rank (client ID) of the trainer.
- Returns:
(int) – The rank (client ID) of this trainer instance.
- Return type:
int
- load_feature_aggregation(feature_aggregation: Tensor, *, _ray_trace_ctx=None) None
Loads the aggregated features into the trainer.
- Parameters:
feature_aggregation (torch.Tensor) – The aggregated features to be loaded.
- local_test(*, _ray_trace_ctx=None) list
Evaluates the model on the local test dataset.
- Returns:
(list) – A list containing the test loss and accuracy [local_test_loss, local_test_acc].
- Return type:
list
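As an illustration of the `[local_test_loss, local_test_acc]` shape, the sketch below computes a mean negative log-likelihood and an accuracy from per-node log-probabilities. The loss function is an assumption, since this page does not state which loss the trainer uses, and `nll_and_accuracy` is a hypothetical helper.

```python
import math


def nll_and_accuracy(log_probs, labels):
    """Return [mean NLL loss, accuracy] for per-node log-probabilities.

    Hypothetical helper; the loss actually used by the trainer may differ.
    """
    total_loss = 0.0
    correct = 0
    for row, label in zip(log_probs, labels):
        total_loss -= row[label]  # negative log-likelihood of the true class
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    n = len(labels)
    return [total_loss / n, correct / n]


log_probs = [
    [math.log(0.9), math.log(0.1)],
    [math.log(0.2), math.log(0.8)],
    [math.log(0.6), math.log(0.4)],
]
labels = [0, 1, 1]
local_test_loss, local_test_acc = nll_and_accuracy(log_probs, labels)
```

In this toy input two of the three nodes are classified correctly, so the accuracy is 2/3.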
- relabel_adj(*, _ray_trace_ctx=None) None
Relabels the adjacency matrix based on the communication node index.
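Relabeling here can be read as mapping global node ids to positions within the communication node index, so the local adjacency uses contiguous indices starting at 0. The sketch below shows that idea on an edge list. The helper `relabel_edges` is hypothetical, and dropping edges that touch nodes outside the index is an assumption rather than documented behavior.

```python
def relabel_edges(edges, communicate_node_index):
    """Map global node ids in `edges` to positions in `communicate_node_index`.

    Edges with an endpoint outside the index are dropped (an assumption).
    Hypothetical helper for illustration only.
    """
    position = {node: i for i, node in enumerate(communicate_node_index)}
    return [
        (position[src], position[dst])
        for src, dst in edges
        if src in position and dst in position
    ]


communicate_node_index = [4, 7, 9]
global_edges = [(4, 7), (7, 9), (9, 4), (2, 4)]
local_edges = relabel_edges(global_edges, communicate_node_index)
```

Here global nodes 4, 7, and 9 become local nodes 0, 1, and 2, and the edge from node 2 is discarded because node 2 is not in the index.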
- train(current_global_round: int, *, _ray_trace_ctx=None) None
Performs local training for a specified number of iterations. This method updates the model using the loaded feature aggregation and the adjacency matrix.
- Parameters:
current_global_round (int) – The current global training round.