Source code for bin.train_gnn

#!/usr/bin/env python

######################################
# Imports
######################################

from adbnx_adapter import ADBNX_Adapter
from arango import ArangoClient
import hydra
import mlflow
import networkx as nx
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx, negative_sampling
from sklearn.metrics import roc_auc_score
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv

######################################
# Classes
######################################


class SAGENet(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        n_layers: int = 5,
        normalize: bool = False,
        bias: bool = True,
        aggr: str = "mean",
        dropout_p: float = 0.25,
    ) -> None:
        """
        SAGENet constructor.

        Args:
            in_channels (int): The number of input channels.
            hidden_channels (int): The number of hidden channels.
            out_channels (int): The number of output channels.
            n_layers (int, optional): The number of SAGE convolutional
                layers. Defaults to 5.
            normalize (bool, optional): Whether to apply normalisation.
                Defaults to False.
            bias (bool, optional): Whether to include the bias term.
                Defaults to True.
            aggr (str, optional): The tensor aggregation type.
                Defaults to "mean".
            dropout_p (float, optional): The dropout layer probability.
                Defaults to 0.25.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        # Input projection: in_channels -> hidden_channels.
        self.conv1 = SAGEConv(
            in_channels, hidden_channels, normalize=normalize, aggr=aggr, bias=bias
        )
        # Applied after the hidden stack in predict().
        self.conv2 = SAGEConv(
            hidden_channels, hidden_channels, normalize=normalize, aggr=aggr, bias=bias
        )
        # Output projection: hidden_channels -> out_channels.
        self.conv3 = SAGEConv(
            hidden_channels, out_channels, normalize=normalize, aggr=aggr, bias=bias
        )
        # The hidden stack: conv1 plus n_layers hidden-to-hidden convolutions.
        self.layers.append(self.conv1)
        for _ in range(n_layers):
            self.layers.append(
                SAGEConv(
                    hidden_channels,
                    hidden_channels,
                    normalize=normalize,
                    aggr=aggr,
                    bias=bias,
                )
            )
        self.activation = F.leaky_relu
        self.dropout = F.dropout
        self.dropout_p = dropout_p
    def predict(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_label_index: torch.Tensor
    ) -> torch.Tensor:
        """
        The forward pass.

        Args:
            x (torch.Tensor): Input data.
            edge_index (torch.Tensor): The graph edge index.
            edge_label_index (torch.Tensor): The graph edge label indices.

        Returns:
            torch.Tensor: The logits for edge membership.
        """
        # Encode nodes with the hidden SAGE stack, then the final two
        # convolutions.
        for layer in self.layers:
            x = layer(x, edge_index)
            x = self.activation(x)
            # Pass training=self.training so dropout is disabled in eval
            # mode (F.dropout is otherwise always active).
            x = self.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)
        # Decode edges as the dot product of the two endpoint embeddings.
        x = x[edge_label_index[0]] * x[edge_label_index[1]]
        x = x.sum(dim=-1).view(-1)
        return x
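# A minimal sketch of how SAGENet scores a candidate edge; the toy
# two-node graph and channel sizes below are illustrative only, not
# values from this pipeline's config.
#
#   x = torch.randn(2, 8)                        # 2 nodes, 8 features each
#   edge_index = torch.tensor([[0], [1]])        # one edge: 0 -> 1
#   edge_label_index = torch.tensor([[0], [1]])  # score the pair (0, 1)
#   net = SAGENet(in_channels=8, hidden_channels=16, out_channels=4)
#   logit = net.predict(x, edge_index, edge_label_index)
#   prob = logit.sigmoid()                       # edge-membership probability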
######################################
# Functions
######################################
def log_results(
    tracking_uri: str,
    experiment_prefix: str,
    grn_name: str,
    in_channels: int,
    config: DictConfig,
) -> None:
    """
    Log experiment results to the experiment tracker.

    Args:
        tracking_uri (str): The tracking URI.
        experiment_prefix (str): The experiment name prefix.
        grn_name (str): The name of the GRN.
        in_channels (int): The number of input channels.
        config (DictConfig): The pipeline configuration.
    """
    mlflow.set_tracking_uri(tracking_uri)
    experiment_name = f"{experiment_prefix}_train_gnn"
    # Create the experiment on first use, then make it the active one.
    existing_exp = mlflow.get_experiment_by_name(experiment_name)
    if not existing_exp:
        mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)
    mlflow.set_tag("grn", grn_name)
    mlflow.set_tag("gnn", "SAGE")
    mlflow.log_param("grn", grn_name)
    mlflow.log_param("in_channels", in_channels)
    # Log every GNN hyperparameter from the Hydra config.
    for k in config["gnn"]:
        mlflow.log_param(k, config["gnn"][k])
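# A minimal usage sketch, assuming a local MLflow server; the URI and
# names are placeholders, and `config` stands in for the Hydra
# DictConfig passed to main().
#
#   log_results(
#       "http://localhost:5000", "demo", "my_grn", in_channels=8,
#       config=config,
#   )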
def get_graph(
    db_host: str,
    db_name: str,
    db_username: str,
    db_password: str,
    collection: str,
    feature_k: str = "expression",
) -> nx.Graph:
    """
    Retrieve the graph from the database.

    Args:
        db_host (str): The database host.
        db_name (str): The database name.
        db_username (str): The database username.
        db_password (str): The database password.
        collection (str): The database collection.
        feature_k (str, optional): The dictionary key for node features.
            Defaults to "expression".

    Returns:
        nx.Graph: The retrieved graph.
    """
    db = ArangoClient(hosts=db_host).db(
        db_name, username=db_username, password=db_password
    )
    adapter = ADBNX_Adapter(db)
    db_G = adapter.arangodb_graph_to_networkx(collection)
    # Make the graph undirected and relabel nodes with consecutive
    # integers so the node order is stable downstream.
    db_G = nx.Graph(db_G)
    db_G = nx.convert_node_labels_to_integers(db_G)
    # Rebuild the graph, keeping only the feature vector on each node.
    G = nx.Graph()
    G.add_edges_from(db_G.edges)
    for node_id, node_features in list(db_G.nodes(data=True)):
        features = list(node_features[feature_k].values())
        G.nodes[node_id][feature_k] = features
    return G
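# A minimal usage sketch, assuming a local ArangoDB instance; the host,
# credentials, and collection name below are placeholders.
#
#   G = get_graph(
#       "http://localhost:8529", "grn_db", "root", "password", "my_grn"
#   )
#   print(G.number_of_nodes(), G.number_of_edges())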
def get_split(
    G: Data, num_val: float, num_test: float, device: torch.device
) -> tuple[Data, Data, Data]:
    """
    Get the train-validation-test split.

    Args:
        G (Data): The graph.
        num_val (float): The proportion of validation data.
        num_test (float): The proportion of testing data.
        device (torch.device): The training device.

    Returns:
        tuple[Data, Data, Data]: The train-validation-test split.
    """
    transform = T.Compose(
        [
            # Node features live under "expression", not the default "x".
            T.NormalizeFeatures(attrs=["expression"]),
            T.ToDevice(device),
            # Negative training edges are sampled on the fly in
            # train_model(), so none are added here.
            T.RandomLinkSplit(
                num_val=num_val,
                num_test=num_test,
                is_undirected=True,
                add_negative_train_samples=False,
            ),
        ]
    )
    train_data, val_data, test_data = transform(G)
    return train_data, val_data, test_data
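# A minimal usage sketch: an 80/10/10 link-level split. `pyg_graph` is
# a hypothetical name for the output of from_networkx(get_graph(...)).
#
#   device = torch.device("cpu")
#   train_data, val_data, test_data = get_split(pyg_graph, 0.1, 0.1, device)
#   # Each split carries edge_label / edge_label_index for link prediction.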
def get_model_components(
    lr: float,
    in_channels: int,
    hidden_channels: int,
    out_channels: int,
    device: torch.device,
    n_layers: int,
    normalize: bool,
    bias: bool,
    aggr: str,
    dropout_p: float,
) -> tuple:
    """
    Get the components for training the model.

    Args:
        lr (float): The learning rate.
        in_channels (int): The number of input channels.
        hidden_channels (int): The number of hidden channels.
        out_channels (int): The number of output channels.
        device (torch.device): The training device.
        n_layers (int): The number of SAGE convolutional layers.
        normalize (bool): Whether to normalize the input tensors.
        bias (bool): Whether to include the bias term.
        aggr (str): The data aggregation method.
        dropout_p (float): The dropout probability.

    Returns:
        tuple: The model, optimiser, scheduler, and loss criterion.
    """
    model = SAGENet(
        in_channels,
        hidden_channels,
        out_channels,
        n_layers,
        normalize,
        bias,
        aggr,
        dropout_p,
    ).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    # Reduce the learning rate when the monitored metric plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "max", factor=0.05
    )
    # Edge membership is a binary classification problem over logits.
    criterion = torch.nn.BCEWithLogitsLoss()
    return model, optimizer, scheduler, criterion
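# A minimal usage sketch with hypothetical hyperparameters; none of
# these values come from the pipeline's Hydra config.
#
#   model, optimizer, scheduler, criterion = get_model_components(
#       lr=0.01, in_channels=8, hidden_channels=16, out_channels=4,
#       device=torch.device("cpu"), n_layers=2, normalize=False,
#       bias=True, aggr="mean", dropout_p=0.25,
#   )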
def train_model(
    model: torch.nn.Module,
    train_data: Data,
    val_data: Data,
    test_data: Data,
    n_epochs: int,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    enable_tracking: bool,
) -> float:
    """
    Train the graph neural network.

    Args:
        model (torch.nn.Module): The graph neural network.
        train_data (Data): The training data.
        val_data (Data): The validation data.
        test_data (Data): The testing data.
        n_epochs (int): The number of epochs.
        optimizer (torch.optim.Optimizer): The model optimiser.
        criterion (torch.nn.Module): The loss criterion.
        device (torch.device): The training device.
        enable_tracking (bool): Whether to enable experiment tracking.

    Returns:
        float: The final area-under-curve score.
    """

    def train() -> float:
        model.train()
        optimizer.zero_grad()
        # Sample one negative edge per positive training edge.
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index,
            num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1),
        )
        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        # Positive edges keep their labels; negatives are labelled 0.
        edge_label = torch.cat(
            [
                train_data.edge_label,
                train_data.edge_label.new_zeros(neg_edge_index.size(1)),
            ],
            dim=0,
        )
        out = model.predict(
            train_data.expression, train_data.edge_index, edge_label_index
        ).to(device)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        # Return a plain float so the loss can be logged directly.
        return float(loss)

    @torch.no_grad()
    def test(data: Data) -> float:
        model.eval()
        out = model.predict(
            data.expression, data.edge_index, data.edge_label_index
        ).sigmoid()
        return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

    # Log and print roughly every 5% of the run; the max() guards
    # against a zero interval when n_epochs < 20.
    log_interval = max(1, int(n_epochs * 0.05))
    for epoch in range(n_epochs):
        loss = train()
        val_auc = test(val_data)
        test_auc = test(test_data)
        if epoch % log_interval == 0:
            if enable_tracking:
                mlflow.log_metric("train_loss", loss, step=epoch)
                mlflow.log_metric("val_auc", val_auc, step=epoch)
                mlflow.log_metric("test_auc", test_auc, step=epoch)
            print(
                f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, "
                f"Test: {test_auc:.4f}"
            )

    final_test_auc = test(test_data)
    print(f"Final Test: {final_test_auc:.4f}")
    return final_test_auc
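# A minimal usage sketch, reusing the components and splits from the
# sketches above; the epoch count is illustrative.
#
#   auc = train_model(
#       model, train_data, val_data, test_data, n_epochs=20,
#       optimizer=optimizer, criterion=criterion,
#       device=torch.device("cpu"), enable_tracking=False,
#   )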
######################################
# Main
######################################
@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(config: DictConfig) -> None:
    """
    The main entry point for the GNN training pipeline.

    Args:
        config (DictConfig): The pipeline configuration.
    """
    EXPERIMENT_PREFIX = config["experiment"]["name"]
    GRN_NAME = config["grn"]["input_dir"]
    DB_HOST = config["db"]["host"]
    DB_NAME = config["db"]["name"]
    DB_USERNAME = config["db"]["username"]
    DB_PASSWORD = config["db"]["password"]
    NUM_VAL = config["gnn"]["num_val"]
    NUM_TEST = config["gnn"]["num_test"]
    HIDDEN_CHANNELS = config["gnn"]["hidden_channels"]
    OUT_CHANNELS = config["gnn"]["out_channels"]
    LR = config["gnn"]["lr"]
    N_EPOCHS = config["gnn"]["n_epochs"]
    N_LAYERS = config["gnn"]["n_layers"]
    NORMALIZE = config["gnn"]["normalize"]
    BIAS = config["gnn"]["bias"]
    AGGR = config["gnn"]["aggr"]
    DROPOUT_P = config["gnn"]["dropout_p"]
    TRACKING_URI = config["experiment_tracking"]["tracking_uri"]
    ENABLE_TRACKING = config["experiment_tracking"]["enabled"]

    G = get_graph(DB_HOST, DB_NAME, DB_USERNAME, DB_PASSWORD, GRN_NAME)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert the NetworkX graph to a torch_geometric Data object; the
    # "expression" node attribute becomes a node-feature tensor.
    G = from_networkx(G)
    train_data, val_data, test_data = get_split(G, NUM_VAL, NUM_TEST, device)
    in_channels = G.expression.shape[1]
    model, optimizer, scheduler, criterion = get_model_components(
        LR,
        in_channels,
        HIDDEN_CHANNELS,
        OUT_CHANNELS,
        device,
        N_LAYERS,
        NORMALIZE,
        BIAS,
        AGGR,
        DROPOUT_P,
    )
    if ENABLE_TRACKING:
        log_results(TRACKING_URI, EXPERIMENT_PREFIX, GRN_NAME, in_channels, config)
    final_test_auc = train_model(
        model,
        train_data,
        val_data,
        test_data,
        N_EPOCHS,
        optimizer,
        criterion,
        device,
        ENABLE_TRACKING,
    )
    if ENABLE_TRACKING:
        mlflow.log_metric("final_test_auc", final_test_auc)
        mlflow.end_run()
if __name__ == "__main__":
    main()
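# Hydra reads conf/config.yaml relative to this script and lets any
# value be overridden on the command line; the overrides below are
# illustrative, assuming the module lives at bin/train_gnn.py.
#
#   python bin/train_gnn.py gnn.lr=0.005 gnn.n_epochs=200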