Train a Graph Neural Network
Date published: 26/09/23
- class bin.train_gnn.SAGENet(in_channels: int, hidden_channels: int, out_channels: int, n_layers: int = 5, normalize: bool = False, bias: bool = True, aggr: str = 'mean', dropout_p: float = 0.25)
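SAGENet's signature points to a GraphSAGE-style stack of `n_layers` convolutions. Each has a configurable neighbourhood aggregator (`aggr='mean'` by default), an optional bias, and a `normalize` flag. The sketch below shows, in pure Python, what a single mean-aggregating SAGE layer computes. It follows the formulation in PyTorch Geometric's `SAGEConv` documentation: a linear map of the node's own features plus a linear map of the mean of its neighbours' features. In `SAGEConv`, `normalize=True` L2-normalises each layer's output. `sage_mean_layer`, `matvec` and the list-based matrices are hypothetical names for illustration. This is not the SAGENet implementation, which this reference does not show.

```python
import math
from typing import Dict, List, Optional

Vector = List[float]
Matrix = List[List[float]]


def matvec(weights: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) weight matrix by a length-cols vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def sage_mean_layer(
    features: Dict[str, Vector],
    adjacency: Dict[str, List[str]],
    w_self: Matrix,
    w_neigh: Matrix,
    bias: Optional[Vector] = None,
    normalize: bool = False,
) -> Dict[str, Vector]:
    """One GraphSAGE layer with mean aggregation (illustrative sketch)."""
    out: Dict[str, Vector] = {}
    for node, h in features.items():
        neighbours = adjacency.get(node, [])
        # Mean of the neighbours' features; isolated nodes aggregate to zeros.
        if neighbours:
            agg = [sum(features[n][i] for n in neighbours) / len(neighbours)
                   for i in range(len(h))]
        else:
            agg = [0.0] * len(h)
        # Combine the node's own transformed features with the aggregated ones.
        z = [a + b for a, b in zip(matvec(w_self, h), matvec(w_neigh, agg))]
        if bias is not None:
            z = [zi + bi for zi, bi in zip(z, bias)]
        if normalize:
            # L2-normalise the output vector, guarding against a zero vector.
            norm = math.sqrt(sum(zi * zi for zi in z)) or 1.0
            z = [zi / norm for zi in z]
        out[node] = z
    return out


# Usage with identity weights on a three-node toy graph.
identity = [[1.0, 0.0], [0.0, 1.0]]
feats = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
adj = {"a": ["b", "c"], "b": ["a"], "c": []}
print(sage_mean_layer(feats, adj, identity, identity))
# {'a': [1.5, 1.0], 'b': [1.0, 1.0], 'c': [1.0, 1.0]}
```

Stacking `n_layers` such layers, with a non-linearity and dropout (probability `dropout_p`) between them, gives the general shape a network like SAGENet would be expected to have.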
- bin.train_gnn.get_graph(db_host: str, db_name: str, db_username: str, db_password: str, collection: str, feature_k: str = 'expression') → Graph
Retrieve the graph from the database.
- Args:
  - db_host (str): The database host.
  - db_name (str): The database name.
  - db_username (str): The database username.
  - db_password (str): The database password.
  - collection (str): The database collection.
  - feature_k (str): The dictionary key for node features.
- Returns:
  - nx.Graph: The retrieved graph.
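`feature_k` names the key under which each node's features are stored in its attribute dictionary (default `'expression'`). The sketch below assumes every node stores a list of numbers under that key. It turns those attributes into a row-per-node feature matrix, from which `in_channels` can be read off. `node_feature_matrix` and the gene-style node names are hypothetical. With NetworkX, the input mapping could be obtained as `dict(G.nodes(data=True))`.

```python
from typing import Any, Dict, List, Tuple


def node_feature_matrix(
    node_attrs: Dict[Any, Dict[str, Any]],
    feature_k: str = "expression",
) -> Tuple[List[Any], List[List[float]]]:
    """Stack each node's `feature_k` vector into a row-per-node matrix."""
    # Sort nodes so that row i refers to the same node on every run.
    order = sorted(node_attrs, key=str)
    rows: List[List[float]] = []
    for node in order:
        # A KeyError here means the node lacks the requested feature key.
        row = [float(v) for v in node_attrs[node][feature_k]]
        if rows and len(row) != len(rows[0]):
            raise ValueError(
                f"node {node!r} has {len(row)} features, expected {len(rows[0])}"
            )
        rows.append(row)
    return order, rows


# Hypothetical expression attributes for two nodes.
attrs = {"geneB": {"expression": [3, 4]}, "geneA": {"expression": [1, 2]}}
order, x = node_feature_matrix(attrs)
in_channels = len(x[0])  # width of each feature vector = number of input channels
```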
- bin.train_gnn.get_model_components(lr: float, in_channels: int, hidden_channels: int, out_channels: int, device: device, n_layers: int, normalize: bool, bias: bool, aggr: str, dropout_p: float) → tuple
Get the components for training the model.
- Args:
  - lr (float): The learning rate.
  - in_channels (int): The number of input channels.
  - hidden_channels (int): The number of hidden channels.
  - out_channels (int): The number of output channels.
  - device (torch.device): The training device.
  - n_layers (int): The number of SAGE convolutional layers.
  - normalize (bool): Whether to normalize the input tensors.
  - bias (bool): Whether to include the bias term.
  - aggr (str): The neighbourhood aggregation method (e.g. 'mean').
  - dropout_p (float): The dropout probability.
- Returns:
  - tuple: The components for training the model.
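The returned tuple is not described further here. `train_model` takes a `criterion` and reports an AUC, so the task is presumably binary, for example predicting whether an edge exists. In that case a binary cross-entropy loss on raw logits, as provided by PyTorch's `torch.nn.BCEWithLogitsLoss`, would be a typical choice. That is an assumption, not something this reference states. The pure-Python sketch below shows the numerically stable form of that loss. `bce_with_logits` and `mean_bce_with_logits` are hypothetical helpers.

```python
import math
from typing import Sequence


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in its numerically stable form.

    Equivalent to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], rewritten as
    max(x, 0) - x*y + log(1 + exp(-|x|)) so that exp() never overflows.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Average the per-example loss (BCEWithLogitsLoss uses reduction='mean' by default)."""
    if len(logits) != len(targets) or not logits:
        raise ValueError("logits and targets must be non-empty and of equal length")
    return sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)


# A very negative logit for a positive target costs ~1000 without overflowing.
print(bce_with_logits(-1000.0, 1.0))
```

Working on logits rather than probabilities is also why such a loss is usually paired with a model whose final layer has no sigmoid.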
- bin.train_gnn.get_split(G: Graph, num_val: float, num_test: float, device: device) → tuple[Graph, Graph, Graph]
Get the train-validation-test split.
- Args:
  - G (nx.Graph): The graph.
  - num_val (float): The proportion of validation data.
  - num_test (float): The proportion of testing data.
  - device (torch.device): The training device.
- Returns:
  - tuple[nx.Graph, nx.Graph, nx.Graph]: The train-validation-test split.
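`num_val` and `num_test` are proportions, and the remainder presumably forms the training set. The stdlib sketch below shows such a random split, assuming it is taken over edges. That is usual for link prediction, and PyTorch Geometric's `RandomLinkSplit` accepts `num_val` and `num_test` as fractions, but the reference does not confirm this choice. `split_edges` and its `seed` argument are hypothetical, and the `device` argument has no counterpart here.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_edges(
    edges: Sequence[T],
    num_val: float,
    num_test: float,
    seed: int = 0,
) -> Tuple[List[T], List[T], List[T]]:
    """Randomly partition edges into train/validation/test lists by proportion."""
    if num_val < 0 or num_test < 0 or num_val + num_test >= 1:
        raise ValueError("num_val and num_test must be >= 0 and sum to less than 1")
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    # Proportions become counts by truncation; the training set gets the rest.
    n_val = int(num_val * len(shuffled))
    n_test = int(num_test * len(shuffled))
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]
    return train, val, test


edges = [(i, i + 1) for i in range(100)]
train, val, test = split_edges(edges, num_val=0.1, num_test=0.2)  # 70 / 10 / 20 edges
```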
- bin.train_gnn.log_results(tracking_uri: str, experiment_prefix: str, grn_name: str, in_channels: int, config: DictConfig) → None
Log experiment results to the experiment tracker.
- Args:
  - tracking_uri (str): The tracking URI.
  - experiment_prefix (str): The experiment name prefix.
  - grn_name (str): The name of the GRN.
  - in_channels (int): The number of input channels.
  - config (DictConfig): The pipeline configuration.
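Experiment trackers typically record parameters as flat key/value pairs, whereas a pipeline `DictConfig` is usually nested. One way to bridge the two is to convert the config with `OmegaConf.to_container(config, resolve=True)` and then flatten it into dotted keys. This is offered as an assumption, not a description of what `log_results` does. The helper `flatten_config` and the example config keys below are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_config(config: Mapping, parent: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten nested mappings into a single dict with dotted keys."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        name = f"{parent}{sep}{key}" if parent else str(key)
        if isinstance(value, Mapping):
            # Recurse into sub-sections such as the hypothetical "model" group.
            flat.update(flatten_config(value, name, sep))
        else:
            flat[name] = value
    return flat


# Hypothetical plain-dict config, standing in for OmegaConf.to_container(...).
config = {"model": {"lr": 0.01, "n_layers": 5}, "n_epochs": 100}
params = flatten_config(config)
params["in_channels"] = 2  # extra run metadata, e.g. the in_channels argument
```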
- bin.train_gnn.main(config: DictConfig) → None
The main entry point for the GNN training pipeline.
- Args:
  - config (DictConfig): The pipeline configuration.
- bin.train_gnn.train_model(model: Module, train_data: Graph, val_data: Graph, test_data: Graph, n_epochs: int, optimizer: Module, criterion: Module, device: device, enable_tracking: bool) → float
Train the graph neural network.
- Args:
  - model (torch.nn.Module): The graph neural network.
  - train_data (nx.Graph): The training data.
  - val_data (nx.Graph): The validation data.
  - test_data (nx.Graph): The testing data.
  - n_epochs (int): The number of epochs.
  - optimizer (torch.optim.Optimizer): The model optimiser.
  - criterion (torch.nn.Module): The loss criterion.
  - device (torch.device): The training device.
  - enable_tracking (bool): Whether to enable experiment tracking.
- Returns:
  - float: The final area under the curve (AUC) score.
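`train_model` returns an area-under-curve score. This sketch assumes it is the area under the ROC curve, the usual metric for link prediction and what `sklearn.metrics.roc_auc_score` computes. That quantity equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting half. The quadratic-time sketch below computes it directly from that definition. `roc_auc` is a hypothetical helper, not part of this module.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positive and negative examples")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive correctly ranked above the negative
            elif p == n:
                wins += 0.5  # a tie counts as half a correct ranking
    return wins / (len(pos) * len(neg))


print(roc_auc([1, 1, 0, 0], [0.9, 0.2, 0.8, 0.1]))  # 0.75: 3 of 4 pairs ranked correctly
```

For real evaluation sets, `sklearn.metrics.roc_auc_score(labels, scores)` should give the same value without the pairwise loop.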