Train a Graph Neural Network

Date published: 26/09/23

class bin.train_gnn.SAGENet(in_channels: int, hidden_channels: int, out_channels: int, n_layers: int = 5, normalize: bool = False, bias: bool = True, aggr: str = 'mean', dropout_p: float = 0.25)
predict(x: Tensor, edge_index: Tensor, edge_label_index: Tensor) → Tensor

Run the forward pass and compute logits for the candidate edges in edge_label_index.

Args:
x (torch.Tensor):

The input node features.

edge_index (torch.Tensor):

The graph edge index.

edge_label_index (torch.Tensor):

The indices of the labelled node pairs (candidate edges) to score.

Returns:
torch.Tensor:

The logits for edge membership.
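The docstring does not say how predict turns node embeddings into edge logits. A common choice in PyTorch Geometric link-prediction models is a dot-product decoder over the node pairs in edge_label_index. The dependency-free sketch below assumes that design; dot_product_logits and sigmoid are hypothetical helpers, and plain lists stand in for tensors.

```python
import math
from typing import List, Sequence


def dot_product_logits(
    z: Sequence[Sequence[float]],
    edge_label_index: Sequence[Sequence[int]],
) -> List[float]:
    """Score each candidate edge as the dot product of its endpoint embeddings.

    z holds one embedding per node (e.g. the output of the final SAGE layer).
    edge_label_index is a 2 x E index: row 0 holds sources, row 1 targets.
    """
    src, dst = edge_label_index
    if len(src) != len(dst):
        raise ValueError("edge_label_index must have two rows of equal length")
    return [sum(a * b for a, b in zip(z[s], z[d])) for s, d in zip(src, dst)]


def sigmoid(x: float) -> float:
    """Map a logit to a probability of edge membership (numerically stable)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


# Three nodes with 2-dimensional embeddings (toy values).
z = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]]
edge_label_index = [[0, 0], [1, 2]]  # candidate edges 0-1 and 0-2
logits = dot_product_logits(z, edge_label_index)
print(logits)  # [0.5, -1.0]
```

Returning logits rather than probabilities lets a loss such as binary cross-entropy with logits apply the sigmoid internally, which is more numerically stable.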

bin.train_gnn.get_graph(db_host: str, db_name: str, db_username: str, db_password: str, collection: str, feature_k: str = 'expression') → Graph

Retrieve the graph from the database.

Args:
db_host (str):

The database host.

db_name (str):

The database name.

db_username (str):

The database username.

db_password (str):

The database password.

collection (str):

The database collection.

feature_k (str):

The dictionary key for node features.

Returns:
nx.Graph:

The retrieved graph.
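get_graph returns a NetworkX graph whose nodes carry features under feature_k. The database schema is not documented here, so the sketch below assumes a hypothetical layout in which each already-fetched document describes one node: its "name", its feature vector under feature_k, and the "targets" it links to. It uses plain dictionaries instead of NetworkX to stay dependency-free; with NetworkX, the equivalent calls would be G.add_node(name, **{feature_k: vector}) and G.add_edge(name, target).

```python
from typing import Any, Dict, Iterable, List, Set, Tuple


def build_graph(
    documents: Iterable[Dict[str, Any]],
    feature_k: str = "expression",
) -> Tuple[Dict[str, List[float]], Dict[str, Set[str]]]:
    """Build node features and an undirected adjacency map from node documents.

    Hypothetical schema: {"name": str, feature_k: list of floats, "targets": list of str}.
    """
    features: Dict[str, List[float]] = {}
    adjacency: Dict[str, Set[str]] = {}
    for doc in documents:
        name = doc["name"]
        features[name] = list(doc[feature_k])
        adjacency.setdefault(name, set())
        for target in doc.get("targets", []):
            # Undirected graph: record the edge in both directions.
            adjacency[name].add(target)
            adjacency.setdefault(target, set()).add(name)
    # Every node that appears in an edge needs a feature vector for the GNN.
    missing = adjacency.keys() - features.keys()
    if missing:
        raise ValueError(f"nodes without features: {sorted(missing)}")
    return features, adjacency


# Toy documents, not real data.
docs = [
    {"name": "geneA", "expression": [0.1, 2.3], "targets": ["geneB"]},
    {"name": "geneB", "expression": [1.4, 0.0], "targets": []},
]
features, adjacency = build_graph(docs)
```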

bin.train_gnn.get_model_components(lr: float, in_channels: int, hidden_channels: int, out_channels: int, device: device, n_layers: int, normalize: bool, bias: bool, aggr: str, dropout_p: float) → tuple

Get the components for training the model.

Args:
lr (float):

The learning rate.

in_channels (int):

The number of input channels.

hidden_channels (int):

The number of hidden channels.

out_channels (int):

The number of output channels.

device (torch.device):

The training device.

n_layers (int):

The number of SAGE convolutional layers.

normalize (bool):

Whether to L2-normalize the output features of the SAGE convolutional layers.

bias (bool):

Whether to include the bias term.

aggr (str):

The neighbor aggregation method (e.g. 'mean').

dropout_p (float):

The dropout probability.

Returns:
tuple:

The components for training the model.
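The aggr, normalize and bias arguments share their names with those of PyTorch Geometric's SAGEConv, which updates each node as x_i' = W_1 x_i + W_2 · mean_{j ∈ N(i)} x_j and, with normalize=True, L2-normalizes the result. Assuming SAGENet stacks layers of this kind, the pure-Python sketch below shows what one mean-aggregation layer computes. sage_mean_layer and matvec are hypothetical names; this is an illustration, not the module's code.

```python
import math
from typing import Dict, List, Optional, Sequence

Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    """Multiply an (out x in) weight matrix by an input vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def sage_mean_layer(
    x: Dict[int, List[float]],
    neighbours: Dict[int, List[int]],
    w_self: Matrix,
    w_neigh: Matrix,
    bias: Optional[Sequence[float]] = None,
    normalize: bool = False,
) -> Dict[int, List[float]]:
    """One GraphSAGE layer with mean aggregation over each node's neighbours."""
    out: Dict[int, List[float]] = {}
    for v, h_v in x.items():
        nbrs = neighbours.get(v, [])
        if nbrs:
            agg = [sum(x[u][k] for u in nbrs) / len(nbrs) for k in range(len(h_v))]
        else:
            agg = [0.0] * len(h_v)  # isolated node: empty mean treated as zero
        # Root (self) term plus transformed neighbourhood mean.
        h = [a + b for a, b in zip(matvec(w_self, h_v), matvec(w_neigh, agg))]
        if bias is not None:
            h = [hi + bi for hi, bi in zip(h, bias)]
        if normalize:
            norm = math.sqrt(sum(hi * hi for hi in h))
            if norm > 0:
                h = [hi / norm for hi in h]
        out[v] = h
    return out


# Toy graph: node 0 is linked to nodes 1 and 2.
x = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbours = {0: [1, 2], 1: [0], 2: [0]}
identity = [[1.0, 0.0], [0.0, 1.0]]
out = sage_mean_layer(x, neighbours, identity, identity)
```

With identity weights, node 0's output is its own features [1.0, 0.0] plus the mean of its neighbours' features [0.5, 1.0], i.e. [1.5, 1.0].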

bin.train_gnn.get_split(G: Graph, num_val: float, num_test: float, device: device) → tuple[Graph, Graph, Graph]

Split the graph into training, validation and testing sets.

Args:
G (nx.Graph):

The graph.

num_val (float):

The proportion of validation data.

num_test (float):

The proportion of testing data.

device (torch.device):

The training device.

Returns:
tuple[nx.Graph, nx.Graph, nx.Graph]:

The train-validation-test split.
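The num_val and num_test parameters share their names with PyTorch Geometric's torch_geometric.transforms.RandomLinkSplit, so get_split may wrap that transform, although this is not confirmed here. The stdlib sketch below shows only the core proportional partition of edges; it is a simplification, since RandomLinkSplit also keeps held-out edges out of training message passing and can add negative samples. split_edges is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_edges(
    edges: Sequence[T],
    num_val: float,
    num_test: float,
    seed: int = 0,
) -> Tuple[List[T], List[T], List[T]]:
    """Randomly partition edges into train/validation/test sets by proportion."""
    if num_val < 0 or num_test < 0 or num_val + num_test >= 1:
        raise ValueError("num_val and num_test must be non-negative and sum to < 1")
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible split
    n_val = int(num_val * len(shuffled))
    n_test = int(num_test * len(shuffled))
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]  # the remainder is used for training
    return train, val, test


train_e, val_e, test_e = split_edges(list(range(10)), num_val=0.1, num_test=0.2)
```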

bin.train_gnn.log_results(tracking_uri: str, experiment_prefix: str, grn_name: str, in_channels: int, config: DictConfig) → None

Log experiment results to the experiment tracker.

Args:
tracking_uri (str):

The tracking URI.

experiment_prefix (str):

The experiment name prefix.

grn_name (str):

The name of the gene regulatory network (GRN).

in_channels (int):

The number of input channels.

config (DictConfig):

The pipeline configuration.
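The tracking_uri and experiment_prefix arguments suggest an MLflow-style tracker, though the tracker is not named here. Such trackers usually expect flat key/value parameters (e.g. mlflow.log_params takes a dictionary), whereas a DictConfig is nested. The sketch below flattens a plain nested dictionary into dotted keys; a DictConfig can first be converted with OmegaConf.to_container(config, resolve=True). flatten_config and the example config keys are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_config(
    config: Mapping[str, Any], prefix: str = "", sep: str = "."
) -> Dict[str, Any]:
    """Flatten a nested mapping, e.g. {"train": {"lr": 0.01}} -> {"train.lr": 0.01}."""
    flat: Dict[str, Any] = {}
    for key, value in config.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into nested sections, extending the dotted prefix.
            flat.update(flatten_config(value, name, sep))
        else:
            flat[name] = value
    return flat


cfg = {"train": {"lr": 0.01, "n_epochs": 100}, "model": {"aggr": "mean", "n_layers": 5}}
params = flatten_config(cfg)
```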

bin.train_gnn.main(config: DictConfig) → None

The main entry point for the training pipeline.

Args:
config (DictConfig):

The pipeline configuration.

bin.train_gnn.train_model(model: Module, train_data: Graph, val_data: Graph, test_data: Graph, n_epochs: int, optimizer: Module, criterion: Module, device: device, enable_tracking: bool) → float

Train the graph neural network.

Args:
model (torch.nn.Module):

The graph neural network.

train_data (nx.Graph):

The training data.

val_data (nx.Graph):

The validation data.

test_data (nx.Graph):

The testing data.

n_epochs (int):

The number of epochs.

optimizer (torch.nn.Module):

The model optimiser.

criterion (torch.nn.Module):

The loss criterion.

device (torch.device):

The training device.

enable_tracking (bool):

Whether to enable experiment tracking.

Returns:
float:

The final area-under-curve score.
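The docstring does not name the curve; link-prediction examples in PyTorch Geometric typically report ROC AUC (via scikit-learn's roc_auc_score). Assuming ROC AUC, the sketch below computes it from edge labels and predicted scores using the pairwise definition: the probability that a randomly chosen positive edge outscores a randomly chosen negative one. roc_auc is a hypothetical name.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a win. Runs in O(P * N), fine for a small sanity check.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Because AUC depends only on the ranking of scores, it gives the same value whether it is computed on raw logits or on sigmoid probabilities.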