Training Function

fedgraph.train_func.accuracy(output: Tensor, labels: Tensor) → Tensor

This function returns the accuracy of the output with respect to the given ground truth.

Parameters:
  • output (torch.Tensor) – the output labels predicted by the model

  • labels (torch.Tensor) – ground truth labels

Returns:

(tensor) – Accuracy of the output with respect to the given ground truth

Return type:

torch.Tensor
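A minimal sketch of this computation follows. It assumes `output` holds one row of class scores (for example, log-probabilities) per sample and `labels` holds integer class indices. The name `accuracy_sketch` is hypothetical, and the code illustrates the idea rather than reproducing fedgraph's source.

```python
import torch


def accuracy_sketch(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical illustration: fraction of rows whose highest-scoring
    class in `output` matches the corresponding entry of `labels`."""
    # Pick the highest-scoring class for each sample (row).
    preds = output.argmax(dim=1).type_as(labels)
    # 1.0 where the prediction is correct, 0.0 otherwise.
    correct = preds.eq(labels).double()
    # Scalar tensor: number correct divided by number of samples.
    return correct.sum() / len(labels)


# 3 samples, 2 classes: the first two predictions are correct, the third is not.
output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
labels = torch.tensor([0, 1, 1])
acc = accuracy_sketch(output, labels)
print(acc.item())  # 0.6666666666666666
```

Returning a 0-dimensional tensor rather than a Python float matches the documented return type; callers can use `.item()` when they need a plain number.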

fedgraph.train_func.gc_avg_accuracy(frame: DataFrame, trainers: list) → float

This function calculates the weighted average accuracy of the trainers in the frame.

Parameters:
  • frame (pd.DataFrame) – The frame containing the accuracies of the trainers

  • trainers (list) – List of trainer objects

Returns:

(float) – The weighted average accuracy of the trainers in the frame

Return type:

float
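This page does not state what the average is weighted by. The sketch below assumes each trainer is weighted by its number of local training samples, and that `frame` is indexed by trainer name with a `test_acc` column. `HypotheticalTrainer`, `train_size`, `test_acc`, and `weighted_avg_accuracy_sketch` are illustrative names, not fedgraph's API.

```python
from dataclasses import dataclass

import numpy as np
import pandas as pd


@dataclass
class HypotheticalTrainer:
    # Hypothetical stand-in for a trainer object, with only the fields this sketch needs.
    name: str
    train_size: int  # assumed weight: number of local training samples


def weighted_avg_accuracy_sketch(frame: pd.DataFrame, trainers: list) -> float:
    # Assumption: `frame` is indexed by trainer name and has a "test_acc" column.
    accs = frame.loc[[t.name for t in trainers], "test_acc"].to_numpy()
    weights = np.array([t.train_size for t in trainers], dtype=float)
    # Weighted mean: sum(acc_i * w_i) / sum(w_i).
    return float(np.average(accs, weights=weights))


trainers = [HypotheticalTrainer("t0", 100), HypotheticalTrainer("t1", 300)]
frame = pd.DataFrame({"test_acc": [0.80, 0.60]}, index=["t0", "t1"])
print(weighted_avg_accuracy_sketch(frame, trainers))
```

Weighting by data size keeps a trainer with very little data from pulling the average as strongly as one with much more data; an unweighted mean would treat all trainers equally.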

fedgraph.train_func.test(model: Module, features: Tensor, adj: Tensor, test_labels: Tensor, idx_test: Tensor) → tuple

This function evaluates the model on the test data and calculates the loss and accuracy.

Parameters:
  • model (torch.nn.Module) – The model to evaluate

  • features (torch.Tensor) – Tensor representing the input features

  • adj (torch.Tensor) – Adjacency matrix

  • test_labels (torch.Tensor) – Ground truth labels for the test data.

  • idx_test (torch.Tensor) – Indices specifying the test data points

Returns:

  • loss_test.item() (float) – Loss of the model on the test data

  • acc_test.item() (float) – Accuracy of the model on the test data
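An evaluation step of this kind might look like the sketch below. It assumes the model returns per-class log-probabilities (so the loss is the negative log-likelihood), that `test_labels[i]` is the label of node `idx_test[i]`, and that `adj` is a dense tensor. `ToyGraphModel` and `evaluate_sketch` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGraphModel(nn.Module):
    """Hypothetical one-layer model: log_softmax(linear(adj @ x))."""

    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, n_classes)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Aggregate neighbour features, then produce per-class log-probabilities.
        return F.log_softmax(self.linear(adj @ x), dim=1)


def evaluate_sketch(model, features, adj, test_labels, idx_test):
    # Assumption: test_labels[i] is the label of node idx_test[i].
    model.eval()           # switch off training-only behaviour such as dropout
    with torch.no_grad():  # evaluation needs no gradients
        output = model(features, adj)[idx_test]
        loss_test = F.nll_loss(output, test_labels)  # assumes log-probability outputs
        acc_test = output.argmax(dim=1).eq(test_labels).double().mean()
    return loss_test.item(), acc_test.item()


torch.manual_seed(0)
features = torch.randn(5, 3)       # 5 nodes, 3 features each
adj = torch.eye(5)                 # dense adjacency (self-loops only)
idx_test = torch.tensor([3, 4])    # evaluate on nodes 3 and 4
test_labels = torch.tensor([0, 1])  # labels for those two nodes
model = ToyGraphModel(in_dim=3, n_classes=2)
loss, acc = evaluate_sketch(model, features, adj, test_labels, idx_test)
print(f"loss={loss:.4f} acc={acc:.2f}")
```

Returning `.item()` values gives plain Python floats, matching the documented return values.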

fedgraph.train_func.train(epoch: int, model: Module, optimizer: Optimizer, features: Tensor, adj: Tensor, train_labels: Tensor, idx_train: Tensor) → tuple

Trains the model on the training data: calculates the loss and accuracy of the model, performs backpropagation, and updates the model parameters.

Parameters:
  • epoch (int) – The index of the epoch in which the model is trained

  • model (torch.nn.Module) – The model to train

  • optimizer (torch.optim.Optimizer) – The optimizer used to update the model parameters

  • features (torch.FloatTensor) – Tensor representing the input features

  • adj (torch_sparse.tensor.SparseTensor) – Adjacency matrix

  • train_labels (torch.LongTensor) – Ground truth labels for the training data.

  • idx_train (torch.LongTensor) – Indices specifying the training data points

Returns:

  • loss_train.item() (float) – Loss of the model on the training data

  • acc_train.item() (float) – Accuracy of the model on the training data
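The training step described above (forward pass, loss, backpropagation, parameter update) can be sketched as follows. It assumes log-probability outputs with a negative log-likelihood loss, that `train_labels[i]` is the label of node `idx_train[i]`, and, for simplicity, a dense `adj` tensor instead of a `torch_sparse` `SparseTensor`. `ToyGraphModel` and `train_sketch` are hypothetical names, not fedgraph's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGraphModel(nn.Module):
    """Hypothetical one-layer model: log_softmax(linear(adj @ x))."""

    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, n_classes)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Aggregate neighbour features, then produce per-class log-probabilities.
        return F.log_softmax(self.linear(adj @ x), dim=1)


def train_sketch(epoch, model, optimizer, features, adj, train_labels, idx_train):
    # Assumption: train_labels[i] is the label of node idx_train[i].
    # `epoch` is accepted to mirror the documented signature but unused here.
    model.train()          # enable training-mode behaviour (e.g. dropout)
    optimizer.zero_grad()  # clear gradients left from the previous step
    output = model(features, adj)[idx_train]
    loss_train = F.nll_loss(output, train_labels)  # assumes log-probability outputs
    acc_train = output.argmax(dim=1).eq(train_labels).double().mean()
    loss_train.backward()  # backpropagation
    optimizer.step()       # update the model parameters
    return loss_train.item(), acc_train.item()


# Usage: several epochs on a tiny random graph.
torch.manual_seed(0)
features = torch.randn(6, 4)               # 6 nodes, 4 features each
adj = torch.eye(6)                         # dense adjacency (self-loops only)
idx_train = torch.tensor([0, 1, 2, 3])     # train on the first four nodes
train_labels = torch.tensor([0, 1, 0, 1])  # labels for those nodes
model = ToyGraphModel(in_dim=4, n_classes=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(50):
    loss, acc = train_sketch(epoch, model, optimizer, features, adj, train_labels, idx_train)
    losses.append(loss)
print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}, final train acc {acc:.2f}")
```

Each call performs exactly one optimizer step, so the caller's loop decides how many epochs are run. The reported loss and accuracy are measured before that step's update.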