# Source code for fedgraph.trainer_class

import argparse
import random
import time
from typing import Any, Union

import numpy as np
import ray
import torch
import torch.nn.functional as F
import torch_geometric
from torchmetrics.functional.retrieval import retrieval_auroc
from torchmetrics.retrieval import RetrievalHitRate

from fedgraph.gnn_models import GCN, GIN, GNN_LP, AggreGCN, GCN_arxiv, SAGE_products
from fedgraph.train_func import test, train
from fedgraph.utils_lp import (
    get_data,
    get_data_loaders_per_time_step,
    get_global_user_item_mapping,
)
from fedgraph.utils_nc import get_1hop_feature_sum


class Trainer_General:
    """
    A general trainer class for training GCN in a federated learning setup,
    which includes functionalities required for training GCN models on a
    subset of a distributed dataset, handling local training and testing,
    parameter updates, and feature aggregation.

    Parameters
    ----------
    rank : int
        Unique identifier for the training instance (typically representing a
        trainer in federated learning). Also used as the torch random seed so
        each trainer's model initialization is reproducible.
    local_node_index : torch.Tensor
        Indices (in the global graph) of nodes local to this trainer.
    communicate_node_index : torch.Tensor
        Indices of nodes that participate in communication during training.
    adj : torch.Tensor
        The graph structure. It is passed as ``edge_index`` to
        ``torch_geometric.utils.k_hop_subgraph`` in ``relabel_adj``, so it is
        presumably in COO edge-index form -- TODO confirm with the caller.
    train_labels : torch.Tensor
        Labels of the training data.
    test_labels : torch.Tensor
        Labels of the testing data.
    features : torch.Tensor
        Node features of the local nodes. ``get_local_feature_sum`` scatters
        them into rows ``local_node_index`` of a global-size matrix, so one row
        per entry of ``local_node_index`` is expected.
    idx_train : torch.Tensor
        Indices of training nodes.
    idx_test : torch.Tensor
        Indices of test nodes.
    args_hidden : int
        Number of hidden units in the GCN model.
    global_node_num : int
        Total number of nodes in the global graph.
    class_num : int
        Number of classes for classification.
    device : torch.device
        The device (CPU or GPU) on which the model will be trained.
    args : Any
        Additional arguments required for model initialization and training.
        The attributes read here are ``num_hops``, ``fedtype``, ``dataset``,
        ``num_layers``, ``learning_rate`` and ``local_step``.
    """

    def __init__(
        self,
        rank: int,
        local_node_index: torch.Tensor,
        communicate_node_index: torch.Tensor,
        adj: torch.Tensor,
        train_labels: torch.Tensor,
        test_labels: torch.Tensor,
        features: torch.Tensor,
        idx_train: torch.Tensor,
        idx_test: torch.Tensor,
        args_hidden: int,
        global_node_num: int,
        class_num: int,
        device: torch.device,
        args: Any,
    ):
        # Seed before building the model so weight init is reproducible per rank.
        torch.manual_seed(rank)

        # Pick the architecture; all candidates share the same constructor
        # signature, so they are instantiated through a single call below.
        if args.num_hops >= 1 and args.fedtype == "fedgcn":
            model_cls: Any = AggreGCN
        elif args.dataset == "ogbn-arxiv":
            model_cls = GCN_arxiv
        elif args.dataset == "ogbn-products":
            model_cls = SAGE_products
        else:
            model_cls = GCN
        self.model = model_cls(
            nfeat=features.shape[1],
            nhid=args_hidden,
            nclass=class_num,
            dropout=0.5,
            NumLayers=args.num_layers,
        ).to(device)

        self.rank = rank  # rank = trainer ID
        self.device = device

        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=args.learning_rate, weight_decay=5e-4
        )
        self.criterion = torch.nn.CrossEntropyLoss()

        # Per-local-step history, filled by `train`.
        self.train_losses: list = []
        self.train_accs: list = []
        self.test_losses: list = []
        self.test_accs: list = []

        self.local_node_index = local_node_index.to(device)
        self.communicate_node_index = communicate_node_index.to(device)
        self.adj = adj.to(device)
        self.train_labels = train_labels.to(device)
        self.test_labels = test_labels.to(device)
        self.features = features.to(device)
        self.idx_train = idx_train.to(device)
        self.idx_test = idx_test.to(device)

        self.local_step = args.local_step
        self.global_node_num = global_node_num
        self.class_num = class_num

    @torch.no_grad()
    def update_params(self, params: tuple, current_global_epoch: int) -> None:
        """
        Updates the model parameters with global parameters received from the
        server.

        Parameters
        ----------
        params : tuple
            A tuple containing the global parameters from the server, in the
            same order as ``self.model.parameters()``.
        current_global_epoch : int
            The current global epoch number (not used by this method).
        """
        # Load the global parameters on CPU, then move the model back to the
        # training device (which also moves the newly assigned tensors).
        self.model.to("cpu")
        for global_param, local_param in zip(params, self.model.parameters()):
            local_param.data = global_param
        self.model.to(self.device)

    def get_local_feature_sum(self) -> torch.Tensor:
        """
        Computes the sum of features of all 1-hop neighbors for each node.

        Returns
        -------
        one_hop_neighbor_feature_sum : torch.Tensor
            The sum of features of 1-hop neighbors for each node, as computed
            by ``get_1hop_feature_sum``.
        """
        # Global-size matrix holding only the features this trainer knows.
        # It must live on the same device as `local_node_index`/`features`
        # (and `adj`), otherwise the indexed assignment fails on GPU.
        new_feature_for_trainer = torch.zeros(
            self.global_node_num,
            self.features.shape[1],
            device=self.features.device,
        )
        new_feature_for_trainer[self.local_node_index] = self.features
        # sum of features of all 1-hop nodes for each node
        one_hop_neighbor_feature_sum = get_1hop_feature_sum(
            new_feature_for_trainer, self.adj
        )
        return one_hop_neighbor_feature_sum

    def load_feature_aggregation(self, feature_aggregation: torch.Tensor) -> None:
        """
        Loads the aggregated features into the trainer. Must be called before
        ``train`` or ``local_test``, which read ``self.feature_aggregation``.

        Parameters
        ----------
        feature_aggregation : torch.Tensor
            The aggregated features to be loaded.
        """
        self.feature_aggregation = feature_aggregation

    def relabel_adj(self) -> None:
        """
        Restricts ``self.adj`` to the edges whose endpoints are both in
        ``communicate_node_index`` and relabels those nodes consecutively,
        using ``k_hop_subgraph`` with 0 hops.
        """
        _, self.adj, __, ___ = torch_geometric.utils.k_hop_subgraph(
            self.communicate_node_index, 0, self.adj, relabel_nodes=True
        )

    def train(self, current_global_round: int) -> None:
        """
        Performs ``local_step`` iterations of local training, using the loaded
        feature aggregation and the adjacency matrix. After each iteration the
        train and local-test loss/accuracy are appended to the history lists.

        Parameters
        ----------
        current_global_round : int
            The current global training round (not used by this method).
        """
        # clean cache
        torch.cuda.empty_cache()
        for iteration in range(self.local_step):
            self.model.train()
            loss_train, acc_train = train(
                iteration,
                self.model,
                self.optimizer,
                self.feature_aggregation,
                self.adj,
                self.train_labels,
                self.idx_train,
            )
            self.train_losses.append(loss_train)
            self.train_accs.append(acc_train)

            loss_test, acc_test = self.local_test()
            self.test_losses.append(loss_test)
            self.test_accs.append(acc_test)

    def local_test(self) -> list:
        """
        Evaluates the model on the local test dataset.

        Returns
        -------
        (list) : list
            A list containing the test loss and accuracy
            [local_test_loss, local_test_acc].
        """
        local_test_loss, local_test_acc = test(
            self.model,
            self.feature_aggregation,
            self.adj,
            self.test_labels,
            self.idx_test,
        )
        return [local_test_loss, local_test_acc]

    def get_params(self) -> tuple:
        """
        Retrieves the current parameters of the model. Gradients are cleared
        (set to None) first, so the returned parameters carry no ``.grad``.

        Returns
        -------
        (tuple) : tuple
            A tuple containing the current parameters of the model.
        """
        self.optimizer.zero_grad(set_to_none=True)
        return tuple(self.model.parameters())

    def get_all_loss_accuray(self) -> list:
        """
        Returns all recorded training and testing losses and accuracies.

        Returns
        -------
        (list) : list
            A list containing arrays of training losses, training accuracies,
            testing losses, and testing accuracies.
        """
        return [
            np.array(self.train_losses),
            np.array(self.train_accs),
            np.array(self.test_losses),
            np.array(self.test_accs),
        ]

    def get_all_loss_accuracy(self) -> list:
        """
        Correctly spelled alias of ``get_all_loss_accuray``; returns the same
        list of arrays.
        """
        return self.get_all_loss_accuray()

    def get_rank(self) -> int:
        """
        Returns the rank (trainer ID) of the trainer.

        Returns
        -------
        (int) : int
            The rank (trainer ID) of this trainer instance.
        """
        return self.rank
class Trainer_GC:
    """
    A trainer class specified for graph classification tasks, which includes
    functionalities required for training GIN models on a subset of a
    distributed dataset, handling local training and testing, parameter
    updates, and feature aggregation.

    Parameters
    ----------
    model: object
        The model to be trained, which is based on the GIN model. It must
        expose a ``loss(pred, label)`` method.
    trainer_id: int
        The ID of the trainer.
    trainer_name: str
        The name of the trainer.
    train_size: int
        The size of the training dataset.
    dataloader: dict
        The dataloaders for training, validation, and testing, under the keys
        ``"train"``, ``"val"`` and ``"test"``.
    optimizer: object
        The optimizer for training.
    args: Any
        The arguments for the training; ``args.device`` is read.

    Attributes
    ----------
    model: object
        The model to be trained, which is based on the GIN model.
    id: int
        The ID of the trainer.
    name: str
        The name of the trainer.
    train_size: int
        The size of the training dataset.
    dataloader: dict
        The dataloaders for training, validation, and testing.
    optimizer: object
        The optimizer for training.
    args: object
        The arguments for the training.
    W: dict
        The model parameters keyed by name (the live parameter objects, so
        updating ``W[k].data`` updates the model).
    dW: dict
        The weight update ``W - W_old`` computed during GCFL local training.
    W_old: dict
        The cached weights of the model (copies, not the live parameters).
    gconv_names: list
        The names of the shared gconv layers; set by ``update_params``.
    train_stats: Any
        The training statistics of the model.
    weights_norm: float
        The norm of the weights of the model.
    grads_norm: float
        The norm of the gradients of the model.
    conv_grads_norm: float
        The norm of the gradients of the gconv layers.
    conv_weights_norm: float
        The norm of the weights of the gconv layers.
    conv_dWs_norm: float
        The norm of the weight updates (dW) of the gconv layers; only
        computed for GCFL training.
    """

    def __init__(
        self,
        model: Any,
        trainer_id: int,
        trainer_name: str,
        train_size: int,
        dataloader: dict,
        optimizer: object,
        args: Any,
    ) -> None:
        self.model = model.to(args.device)
        self.id = trainer_id
        self.name = trainer_name
        self.train_size = train_size
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.args = args

        self.W = {key: value for key, value in self.model.named_parameters()}
        self.dW = {
            key: torch.zeros_like(value) for key, value in self.model.named_parameters()
        }
        self.W_old = {
            key: value.data.clone() for key, value in self.model.named_parameters()
        }

        self.gconv_names: Any = None

        self.train_stats = ([0], [0], [0], [0])
        self.weights_norm = 0.0
        self.grads_norm = 0.0
        self.conv_grads_norm = 0.0
        self.conv_weights_norm = 0.0
        self.conv_dWs_norm = 0.0

    ########### Public functions ###########
    def update_params(self, server: Any) -> None:
        """
        Update the model parameters by downloading the global model weights
        from the server. Also records the server's weight names as the shared
        gconv layers.

        Parameters
        ----------
        server: Server_GC
            The server object that contains the global model weights in
            ``server.W``.
        """
        self.gconv_names = server.W.keys()  # gconv layers
        for k in server.W:
            self.W[k].data = server.W[k].data.clone()

    def reset_params(self) -> None:
        """
        Reset the gconv weights of the model to the cached weights, by copying
        ``W_old`` into ``W`` for the keys in ``gconv_names``.
        """
        self.__copy_weights(target=self.W, source=self.W_old, keys=self.gconv_names)

    def cache_weights(self) -> None:
        """
        Cache the weights of the model by copying every entry of ``W`` into
        ``W_old``. ``W_old`` is the anchor of the FedProx proximal term.
        """
        for name in self.W.keys():
            self.W_old[name].data = self.W[name].data.clone()

    def set_stats_norms(self, train_stats: Any, is_gcfl: bool = False) -> None:
        """
        Set the norms of the weights and gradients of the model, as well as
        the statistics of the training.

        Parameters
        ----------
        train_stats: dict
            The training statistics of the model.
        is_gcfl: bool, optional
            Whether the training is for GCFL; if so, the norm of the gconv
            weight updates (dW) is also computed. The default is False.
        """
        self.train_stats = train_stats

        self.weights_norm = torch.norm(self.__flatten(self.W)).item()

        if self.gconv_names is not None:
            weights_conv = {key: self.W[key] for key in self.gconv_names}
            self.conv_weights_norm = torch.norm(self.__flatten(weights_conv)).item()

            grads_conv = {key: self.W[key].grad for key in self.gconv_names}
            self.conv_grads_norm = torch.norm(self.__flatten(grads_conv)).item()

        grads = {key: value.grad for key, value in self.W.items()}
        self.grads_norm = torch.norm(self.__flatten(grads)).item()

        if is_gcfl and self.gconv_names is not None:
            dWs_conv = {key: self.dW[key] for key in self.gconv_names}
            self.conv_dWs_norm = torch.norm(self.__flatten(dWs_conv)).item()

    def local_train(
        self, local_epoch: int, train_option: str = "basic", mu: float = 1
    ) -> None:
        """
        This function is an interface of the trainer class to train the model
        locally. It will call the train function specified for the training
        option, based on the args provided.

        Parameters
        ----------
        local_epoch: int
            The number of local epochs
        train_option: str, optional
            The training option. The possible values are 'basic', 'prox', and
            'gcfl'. The default is 'basic'.
            'basic' - self-train and FedAvg
            'prox' - FedProx that includes the proximal term (anchored at
            ``W_old``, see ``cache_weights``)
            'gcfl' - GCFL, GCFL+ and GCFL+dWs (plain training, then the gconv
            weight update ``dW = W - W_old`` is computed)
        mu: float, optional
            The proximal term coefficient; only used by 'prox'. The default
            is 1.

        Raises
        ------
        AssertionError / ValueError
            If ``train_option`` is not one of the supported values.
        """
        assert train_option in ["basic", "prox", "gcfl"], "Invalid training option."

        if train_option == "gcfl":
            # Snapshot the shared weights so dW can be computed after training.
            self.__copy_weights(target=self.W_old, source=self.W, keys=self.gconv_names)

        if train_option in ["basic", "gcfl"]:
            train_stats = self.__train(
                model=self.model,
                dataloaders=self.dataloader,
                optimizer=self.optimizer,
                local_epoch=local_epoch,
                device=self.args.device,
            )
        elif train_option == "prox":
            train_stats = self.__train(
                model=self.model,
                dataloaders=self.dataloader,
                optimizer=self.optimizer,
                local_epoch=local_epoch,
                device=self.args.device,
                prox=True,
                gconv_names=self.gconv_names,
                Ws=self.W,
                Wt=self.W_old,
                mu=mu,
            )
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError("Invalid training option.")

        if train_option == "gcfl":
            self.__subtract_weights(
                target=self.dW, minuend=self.W, subtrahend=self.W_old
            )

        self.set_stats_norms(train_stats, is_gcfl=(train_option == "gcfl"))

    def local_test(self, test_option: str = "basic", mu: float = 1) -> tuple:
        """
        Final test of the model on the test dataset based on the test option.

        Parameters
        ----------
        test_option: str, optional
            The test option. The possible values are 'basic' and 'prox'.
            The default is 'basic'.
            'basic' - self-train, FedAvg, GCFL, GCFL+ and GCFL+dWs
            'prox' - FedProx that includes the proximal term
        mu: float, optional
            The proximal term coefficient. The default is 1.

        Returns
        -------
        (test_loss, test_acc): tuple(float, float)
            The average loss and accuracy
        """
        assert test_option in ["basic", "prox"], "Invalid test option."

        if test_option == "basic":
            return self.__eval(
                model=self.model,
                test_loader=self.dataloader["test"],
                device=self.args.device,
            )
        elif test_option == "prox":
            return self.__eval(
                model=self.model,
                test_loader=self.dataloader["test"],
                device=self.args.device,
                prox=True,
                gconv_names=self.gconv_names,
                mu=mu,
                Wt=self.W_old,
            )
        else:
            raise ValueError("Invalid test option.")

    ########### Private functions ###########
    def __train(
        self,
        model: Any,
        dataloaders: dict,
        optimizer: Any,
        local_epoch: int,
        device: str,
        prox: bool = False,
        gconv_names: Any = None,
        Ws: Any = None,
        Wt: Any = None,
        mu: float = 0,
    ) -> dict:
        """
        Train the model on the local dataset.

        Parameters
        ----------
        model: object
            The model to be trained
        dataloaders: dict
            The dataloaders for training, validation, and testing
        optimizer: Any
            The optimizer for training
        local_epoch: int
            The number of local epochs
        device: str
            The device to run the training
        prox: bool, optional
            Whether to add the proximal term. The default is False.
        gconv_names: Any, optional
            The names of the gconv layers. The default is None.
        Ws: Any, optional
            The weights of the model. The default is None.
        Wt: Any, optional
            The target weights. The default is None.
        mu: float, optional
            The proximal term coefficient. The default is 0.

        Returns
        -------
        (results): dict
            Per-epoch training statistics ("trainingLosses", "trainingAccs",
            "valLosses", "valAccs", "testLosses", "testAccs", plus
            "convGradsNorm" when prox is True).

        Note
        ----
        If prox is True, the function will add the proximal term to the loss
        function. Make sure to provide the required arguments `gconv_names`,
        `Ws`, `Wt`, and `mu` for the proximal term.
        """
        if prox:
            assert (
                (gconv_names is not None)
                and (Ws is not None)
                and (Wt is not None)
                and (mu != 0)
            ), "Please provide the required arguments for the proximal term."

        losses_train, accs_train = [], []
        losses_val, accs_val = [], []
        losses_test, accs_test = [], []
        conv_grads_norm = []  # only filled when prox is True

        train_loader = dataloaders["train"]
        val_loader = dataloaders["val"]
        test_loader = dataloaders["test"]

        for _ in range(local_epoch):
            model.train()
            loss_train, acc_train, num_graphs = 0.0, 0.0, 0

            for batch in train_loader:
                batch.to(device)
                optimizer.zero_grad()
                pred = model(batch)
                label = batch.y
                loss = model.loss(pred, label)
                if prox:
                    # Pull the shared (gconv) layers towards the target weights.
                    loss = loss + mu / 2.0 * self.__prox_term(model, gconv_names, Wt)
                loss.backward()
                optimizer.step()
                loss_train += loss.item() * batch.num_graphs
                acc_train += pred.max(dim=1)[1].eq(label).sum().item()
                num_graphs += batch.num_graphs

            loss_train /= num_graphs  # average loss per graph
            acc_train /= num_graphs  # average accuracy per graph

            loss_val, acc_val = self.__eval(model, val_loader, device)
            loss_test, acc_test = self.__eval(model, test_loader, device)

            losses_train.append(loss_train)
            accs_train.append(acc_train)
            losses_val.append(loss_val)
            accs_val.append(acc_val)
            losses_test.append(loss_test)
            accs_test.append(acc_test)

            if prox:
                conv_grads_norm.append(self.__calc_grads_norm(gconv_names, Ws))

        # record the losses and accuracies for each epoch
        res_dict = {
            "trainingLosses": losses_train,
            "trainingAccs": accs_train,
            "valLosses": losses_val,
            "valAccs": accs_val,
            "testLosses": losses_test,
            "testAccs": accs_test,
        }
        if prox:
            res_dict["convGradsNorm"] = conv_grads_norm

        return res_dict

    def __eval(
        self,
        model: "GIN",
        test_loader: Any,
        device: str,
        prox: bool = False,
        gconv_names: Any = None,
        mu: float = 0,
        Wt: Any = None,
    ) -> tuple:
        """
        Validate and test the model on the local dataset.

        Parameters
        ----------
        model: GIN
            The model to be tested
        test_loader: Any
            The dataloader for testing
        device: str
            The device to run the testing
        prox: bool, optional
            Whether to add the proximal term. The default is False.
        gconv_names: Any, optional
            The names of the gconv layers. The default is None.
        mu: float, optional
            The proximal term coefficient. The default is 0.
        Wt: Any, optional
            The target weights. The default is None.

        Returns
        -------
        (test_loss, test_acc): tuple(float, float)
            The average loss and accuracy

        Note
        ----
        If prox is True, the function will add the proximal term to the loss
        function. Make sure to provide the required arguments `gconv_names`,
        `Wt`, and `mu` for the proximal term.
        """
        if prox:
            # Wt is a dict of target weights, so it must be checked against None.
            assert (
                (gconv_names is not None) and (mu is not None) and (Wt is not None)
            ), "Please provide the required arguments for the proximal term."

        model.eval()
        total_loss, total_acc, num_graphs = 0.0, 0.0, 0

        for batch in test_loader:
            batch.to(device)
            with torch.no_grad():
                pred = model(batch)
                label = batch.y
                loss = model.loss(pred, label)
                if prox:
                    loss = loss + mu / 2.0 * self.__prox_term(model, gconv_names, Wt)

            total_loss += loss.item() * batch.num_graphs
            total_acc += pred.max(dim=1)[1].eq(label).sum().item()
            num_graphs += batch.num_graphs

        return total_loss / num_graphs, total_acc / num_graphs

    def __prox_term(self, model: Any, gconv_names: Any, Wt: Any) -> torch.Tensor:
        """
        Compute the proximal term: the sum over gconv layers of the squared L2
        distance between the current parameters and the target weights.

        Parameters
        ----------
        model: Any
            The model to be trained
        gconv_names: Any
            The names of the gconv layers
        Wt: Any
            The target weights

        Returns
        -------
        prox: torch.Tensor
            The proximal term
        """
        prox = torch.tensor(0.0, requires_grad=True)
        for name, param in model.named_parameters():
            # only add the prox term for sharing layers (gConv)
            if name in gconv_names:
                # force the weights to be close to the old weights
                prox = prox + torch.norm(param - Wt[name]).pow(2)
        return prox

    def __calc_grads_norm(self, gconv_names: Any, Ws: Any) -> float:
        """
        Calculate the norm of the gradients of the gconv layers.

        Parameters
        ----------
        gconv_names: Any
            The names of the gconv layers
        Ws: Any
            The model weights whose ``.grad`` is read

        Returns
        -------
        convGradsNorm: float
            The norm of the gradients of the gconv layers
        """
        grads_conv = {k: Ws[k].grad for k in gconv_names}
        return torch.norm(self.__flatten(grads_conv)).item()

    def __copy_weights(
        self, target: dict, source: dict, keys: Union[list, None]
    ) -> None:
        """
        Copy the source weights to the target weights. Nothing is copied when
        ``keys`` is None.

        Parameters
        ----------
        target: dict
            The target weights
        source: dict
            The source weights
        keys: list, optional
            The keys to be copied.
        """
        if keys is not None:
            for name in keys:
                target[name].data = source[name].data.clone()

    def __subtract_weights(self, target: dict, minuend: dict, subtrahend: dict) -> None:
        """
        Subtract the subtrahend from the minuend and store the result in the
        target, for every key of ``target``.

        Parameters
        ----------
        target: dict
            The target weights
        minuend: dict
            The minuend
        subtrahend: dict
            The subtrahend
        """
        for name in target:
            target[name].data = (
                minuend[name].data.clone() - subtrahend[name].data.clone()
            )

    def __flatten(self, w: dict) -> torch.Tensor:
        """
        Flatten the tensors (weights or gradients) in a dict into one 1D
        tensor.

        Parameters
        ----------
        w: dict
            The tensors to flatten
        """
        return torch.cat([v.flatten() for v in w.values()])
class Trainer_LP:
    """
    A trainer class specified for graph link prediction tasks, which includes
    functionalities required for training GNN models on a subset of a
    distributed dataset, handling local training and testing, parameter
    updates, and feature aggregation.

    Parameters
    ----------
    client_id : int
        The ID of the client.
    country_code : str
        The country code of the client. Each client is associated with one
        country code.
    user_id_mapping : dict
        The mapping of user IDs.
    item_id_mapping : dict
        The mapping of item IDs.
    number_of_users : int
        The number of users.
    number_of_items : int
        The number of items.
    meta_data : tuple
        The metadata of the dataset.
    hidden_channels : int, optional
        The number of hidden channels in the GNN model. The default is 64.
    """

    def __init__(
        self,
        client_id: int,
        country_code: str,
        user_id_mapping: dict,
        item_id_mapping: dict,
        number_of_users: int,
        number_of_items: int,
        meta_data: tuple,
        hidden_channels: int = 64,
    ):
        self.client_id = client_id
        self.country_code = country_code
        # global user_id and item_id
        self.data = get_data(self.country_code, user_id_mapping, item_id_mapping)

        self.model = GNN_LP(
            number_of_users, number_of_items, meta_data, hidden_channels
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device: '{self.device}'")
        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def get_train_test_data_at_current_time_step(
        self,
        start_time_float_format: float,
        end_time_float_format: float,
        use_buffer: bool = False,
        buffer_size: int = 10,
    ) -> None:
        """
        Get the training and testing data at the current time step.

        With ``use_buffer`` this sets ``global_train_data``, ``test_data`` and
        ``buffer_train_data_list``; otherwise ``train_data`` and ``test_data``.

        Parameters
        ----------
        start_time_float_format : float
            The start time in float format.
        end_time_float_format : float
            The end time in float format.
        use_buffer : bool, optional
            Whether to use the buffer. The default is False.
        buffer_size : int, optional
            The size of the buffer. The default is 10.
        """
        if use_buffer:
            print("loading buffer_train_data_list")
        else:
            print("loading train_data and test_data")

        load_res = get_data_loaders_per_time_step(
            self.data,
            start_time_float_format,
            end_time_float_format,
            use_buffer,
            buffer_size,
        )

        if use_buffer:
            (
                self.global_train_data,
                self.test_data,
                self.buffer_train_data_list,
            ) = load_res
        else:
            self.train_data, self.test_data = load_res

    def train(self, local_updates: int, use_buffer: bool = False) -> tuple:
        """
        Perform local training for a specified number of iterations.

        Parameters
        ----------
        local_updates : int
            The number of local updates.
        use_buffer : bool, optional
            Whether to use the buffer. If True, each update trains on one
            graph drawn uniformly at random from ``buffer_train_data_list``.
            The default is False.

        Returns
        -------
        (loss, train_finish_times) : tuple
            [0] The loss tensor of the last update (None if
                ``local_updates`` is 0)
            [1] The time taken for each local update
        """
        train_finish_times: list = []
        loss = None  # stays None when no update is performed

        if use_buffer:
            probabilities = [1 / len(self.buffer_train_data_list)] * len(
                self.buffer_train_data_list
            )

        for i in range(local_updates):
            if use_buffer:
                train_data = random.choices(
                    self.buffer_train_data_list, weights=probabilities, k=1
                )[0].to(self.device)
            else:
                train_data = self.train_data.to(self.device)

            start_train_time = time.time()
            self.optimizer.zero_grad()
            pred = self.model(train_data)
            ground_truth = train_data["user", "select", "item"].edge_label
            loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
            loss.backward()
            self.optimizer.step()

            train_finish_time = time.time() - start_train_time
            train_finish_times.append(train_finish_time)
            print(
                f"client {self.client_id} local update {i} loss {loss:.4f} train time {train_finish_time:.4f}"
            )

        return loss, train_finish_times

    def test(self, use_buffer: bool = False) -> tuple:
        """
        Test the model on the test data. Requires
        ``calculate_traveled_user_edge_indices`` to have been called for the
        current ``test_data``.

        Parameters
        ----------
        use_buffer : bool, optional
            Whether to use the buffer (then ``global_train_data`` is used as
            the message-passing graph instead of ``train_data``). The default
            is False.

        Returns
        -------
        (auc, hit_rate_at_2, traveled_user_hit_rate_at_2) : tuple
            [0] The AUC score
            [1] The hit rate at 2
            [2] The hit rate at 2 for traveled users
        """
        self.test_data.to(self.device)
        train_graph = self.global_train_data if use_buffer else self.train_data
        with torch.no_grad():
            train_graph.to(self.device)
            pred = self.model.pred(train_graph, self.test_data)

        edge_store = self.test_data["user", "select", "item"]
        ground_truth = edge_store.edge_label
        user_index = edge_store.edge_label_index[0]  # query (user) per edge

        auc = retrieval_auroc(pred, ground_truth)
        hit_rate_evaluator = RetrievalHitRate(top_k=2)
        hit_rate_at_2 = hit_rate_evaluator(pred, ground_truth, indexes=user_index)
        traveled = self.traveled_user_edge_indices
        traveled_user_hit_rate_at_2 = hit_rate_evaluator(
            pred[traveled],
            ground_truth[traveled],
            indexes=user_index[traveled],
        )

        print(f"Test AUC: {auc:.4f}")
        print(f"Test Hit Rate at 2: {hit_rate_at_2:.4f}")
        print(f"Test Traveled User Hit Rate at 2: {traveled_user_hit_rate_at_2:.4f}")
        return auc, hit_rate_at_2, traveled_user_hit_rate_at_2

    def calculate_traveled_user_edge_indices(self, file_path: str) -> None:
        """
        Calculate the indices of the test edges whose user is a traveled user
        and store them in ``self.traveled_user_edge_indices``.

        Parameters
        ----------
        file_path : str
            The path to the file containing the traveled users. The user ID is
            read from the first tab-separated field of each non-blank line.
        """
        user_index = self.test_data["user", "select", "item"].edge_label_index[0]
        with open(file_path, "r", encoding="utf-8") as fh:
            # Build on the same device/dtype as the edge index so torch.isin
            # works even after test() has moved test_data to the GPU, and an
            # empty file yields an empty integer tensor.
            traveled_users = torch.tensor(
                [int(line.split("\t")[0]) for line in fh if line.strip()],
                dtype=user_index.dtype,
                device=user_index.device,
            )
        # mark the edges of the traveled users as True or False
        mask = torch.isin(user_index, traveled_users)
        # get the indices of the edges of the traveled users
        self.traveled_user_edge_indices = torch.where(mask)[0]

    def set_model_parameter(
        self, model_state_dict: dict, gnn_only: bool = False
    ) -> None:
        """
        Load the model parameters from the global server.

        Parameters
        ----------
        model_state_dict : dict
            The model parameters to be loaded.
        gnn_only : bool, optional
            Whether to load only the GNN parameters. The default is False.
        """
        if gnn_only:
            self.model.gnn.load_state_dict(model_state_dict)
        else:
            self.model.load_state_dict(model_state_dict)

    def get_model_parameter(self, gnn_only: bool = False) -> dict:
        """
        Get the model parameters.

        Parameters
        ----------
        gnn_only : bool, optional
            Whether to get only the GNN parameters. The default is False.

        Returns
        -------
        dict
            The model parameters.
        """
        if gnn_only:
            return self.model.gnn.state_dict()
        return self.model.state_dict()