Training Function

fedgraph.train_func.accuracy(output: Tensor, labels: Tensor) → Tensor

Computes the accuracy of the model output with respect to the given ground-truth labels.

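Parameters:
  • output (torch.Tensor) – Output predicted by the model

  • labels (torch.Tensor) – Ground-truth labels

Returns: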

(tensor) – Accuracy of the output with respect to the given ground-truth labels

Return type:

torch.Tensor
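
As a rough illustration of what an accuracy metric like this typically computes, the sketch below takes the highest-scoring class in each row of output as the prediction and returns the fraction of predictions that match labels. It assumes output has shape (num_samples, num_classes) and labels holds integer class indices; accuracy_sketch is a hypothetical name, not part of fedgraph.

```python
import torch


def accuracy_sketch(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: fraction of rows whose arg-max class equals the label."""
    # Predicted class = index of the largest score in each row.
    preds = output.argmax(dim=1)
    # 1.0 where the prediction matches the ground truth, 0.0 otherwise.
    correct = preds.eq(labels).float()
    # Average over all samples.
    return correct.sum() / labels.numel()


if __name__ == "__main__":
    output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    labels = torch.tensor([1, 0, 0])
    print(accuracy_sketch(output, labels))
```

In this example the predictions are classes 1, 0 and 1, so two of the three rows match and the result is 2/3.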

fedgraph.train_func.test(model: Module, features: Tensor, adj: Tensor, test_labels: Tensor, idx_test: Tensor) → tuple

Evaluates the model and computes its loss and accuracy on the test data.

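Parameters:
  • model (torch.nn.Module) – Model to be tested

  • features (torch.FloatTensor) – Tensor representing the input features

  • adj (torch_sparse.tensor.SparseTensor) – Adjacency matrix

  • test_labels (torch.LongTensor) – Ground-truth labels for the data

  • idx_test (torch.LongTensor) – Indices specifying the test data points

Returns: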

  • loss_test.item() (float) – Loss of the model on the test data

  • acc_test.item() (float) – Accuracy of the model on the test data
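
An evaluation routine with this signature would plausibly run a forward pass without gradient tracking and score only the rows selected by idx_test. The sketch below assumes the model is called as model(features, adj) and returns log-probabilities, so torch.nn.functional.nll_loss serves as the loss, and that test_labels holds a label for every node and is indexed with idx_test. evaluate_sketch and the toy TinyGraphModel (which takes a dense adjacency matrix) are hypothetical stand-ins, not fedgraph code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGraphModel(nn.Module):
    """Hypothetical one-layer model: aggregate neighbour features, then classify."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, features: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Dense adjacency for simplicity; returns per-node log-probabilities.
        return F.log_softmax(self.linear(adj @ features), dim=1)


def evaluate_sketch(
    model: nn.Module,
    features: torch.Tensor,
    adj: torch.Tensor,
    test_labels: torch.Tensor,
    idx_test: torch.Tensor,
) -> tuple:
    model.eval()  # switch layers such as dropout to inference behaviour
    with torch.no_grad():  # gradients are not needed for evaluation
        output = model(features, adj)
        loss_test = F.nll_loss(output[idx_test], test_labels[idx_test])
        preds = output[idx_test].argmax(dim=1)
        acc_test = preds.eq(test_labels[idx_test]).float().mean()
    # Plain Python floats, mirroring loss_test.item() and acc_test.item().
    return loss_test.item(), acc_test.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    features = torch.randn(4, 3)
    adj = torch.eye(4)  # self-loops only
    labels = torch.tensor([0, 1, 0, 1])
    idx_test = torch.tensor([2, 3])
    model = TinyGraphModel(in_features=3, num_classes=2)
    loss, acc = evaluate_sketch(model, features, adj, labels, idx_test)
    print(f"test loss {loss:.4f}, test accuracy {acc:.2f}")
```

Wrapping the forward pass in torch.no_grad() avoids building an autograd graph, which saves memory when only the metrics are needed.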

fedgraph.train_func.train(epoch: int, model: Module, optimizer: Optimizer, features: Tensor, adj: Tensor, train_labels: Tensor, idx_train: Tensor) → tuple

Trains the model: computes its loss and accuracy on the training data, performs backpropagation, and updates the model parameters.

Parameters:
  • epoch (int) – Number of the epoch in which the model is being trained

  • model (torch.nn.Module) – Model to be trained

  • optimizer (torch.optim.Optimizer) – Optimizer used to update the model parameters

  • features (torch.FloatTensor) – Tensor representing the input features

  • adj (torch_sparse.tensor.SparseTensor) – Adjacency matrix

  • train_labels (torch.LongTensor) – Ground-truth labels for the data

  • idx_train (torch.LongTensor) – Indices specifying the training data points

Returns:

  • loss_train.item() (float) – Loss of the model on the training data

  • acc_train.item() (float) – Accuracy of the model on the training data
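
One training step matching this description could look like the sketch below: a forward pass, the loss on the idx_train rows, backpropagation, and an optimizer update. It assumes the model is called as model(features, adj) and returns log-probabilities, so torch.nn.functional.nll_loss applies, and that train_labels holds a label for every node. train_step_sketch and TinyGraphModel are hypothetical; in this sketch the epoch argument is used only for an optional log line.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGraphModel(nn.Module):
    """Hypothetical one-layer model: aggregate neighbour features, then classify."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, features: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Dense adjacency for simplicity; returns per-node log-probabilities.
        return F.log_softmax(self.linear(adj @ features), dim=1)


def train_step_sketch(
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    features: torch.Tensor,
    adj: torch.Tensor,
    train_labels: torch.Tensor,
    idx_train: torch.Tensor,
    log: bool = False,
) -> tuple:
    model.train()  # enable training-mode behaviour (e.g. dropout)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    output = model(features, adj)
    # Loss and accuracy are computed on the training rows only.
    loss_train = F.nll_loss(output[idx_train], train_labels[idx_train])
    preds = output[idx_train].argmax(dim=1)
    acc_train = preds.eq(train_labels[idx_train]).float().mean()
    loss_train.backward()  # backpropagation
    optimizer.step()  # parameter update
    if log:
        print(f"epoch {epoch:3d}  loss {loss_train.item():.4f}  acc {acc_train.item():.2f}")
    return loss_train.item(), acc_train.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    features = torch.randn(6, 4)
    adj = torch.eye(6)
    labels = torch.tensor([0, 1, 2, 0, 1, 2])
    idx_train = torch.tensor([0, 1, 2, 3])
    model = TinyGraphModel(in_features=4, num_classes=3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
    for epoch in range(20):
        train_step_sketch(epoch, model, optimizer, features, adj, labels,
                          idx_train, log=epoch % 5 == 0)
```

Calling optimizer.zero_grad() before the forward pass keeps gradients from one step from accumulating into the next; the returned loss and accuracy are those measured before the parameter update.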