Source code for bin.train_vae_gnn

#!/usr/bin/env python

# Imports

from os.path import join as join_path
from pathlib import Path

from adbnx_adapter import ADBNX_Adapter
from arango import ArangoClient
import hydra
import matplotlib.pyplot as plt
import mlflow
import networkx as nx
from omegaconf import DictConfig
import pandas as pd
from sklearn.decomposition import PCA
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv, VGAE
import torch_geometric.transforms as T
from torch_geometric.utils import from_networkx

# Classes

class VariationalGCNEncoder(torch.nn.Module):
    """
    GraphSAGE encoder that outputs the mean and log-standard deviation of a
    latent node distribution, for use as the encoder of a ``VGAE``.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        n_layers: int = 2,
        normalize: bool = False,
        bias: bool = True,
        aggr: str = "mean",
    ) -> None:
        """
        Build the encoder layers.

        Args:
            in_channels (int): The number of input channels.
            hidden_channels (int): The number of hidden channels.
            out_channels (int): The number of output (latent) channels.
            n_layers (int, optional): The number of hidden-to-hidden SAGE
                layers stacked after the input layer. Defaults to 2.
            normalize (bool, optional): Forwarded to every ``SAGEConv``.
                Defaults to False.
            bias (bool, optional): Whether the layers include a bias term.
                Defaults to True.
            aggr (str, optional): The neighbour aggregation scheme.
                Defaults to "mean".
        """
        super().__init__()

        # Options shared by every convolution in the encoder.
        sage_options = {"normalize": normalize, "aggr": aggr, "bias": bias}
        bottleneck_channels = 2 * out_channels

        # Creation order is kept stable so that seeded initialisation and
        # parameter registration do not change.
        self.layers = nn.ModuleList()
        self.conv1 = SAGEConv(in_channels, hidden_channels, **sage_options)
        self.conv2 = SAGEConv(hidden_channels, bottleneck_channels, **sage_options)
        self.conv_mu = SAGEConv(bottleneck_channels, out_channels, **sage_options)
        self.conv_logstd = SAGEConv(
            bottleneck_channels, out_channels, **sage_options
        )

        # The activated trunk is: input layer, then n_layers hidden layers.
        self.layers.append(self.conv1)
        for _ in range(n_layers):
            hidden_layer = SAGEConv(hidden_channels, hidden_channels, **sage_options)
            self.layers.append(hidden_layer)

        self.activation = F.leaky_relu

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the forward pass.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): The graph edge index.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The outputs of the mean and
            log-standard-deviation convolutions.
        """
        hidden = x
        for conv in self.layers:
            hidden = self.activation(conv(hidden, edge_index))

        # The bottleneck projection is not followed by an activation.
        hidden = self.conv2(hidden, edge_index)
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd
###################################### # Functions ######################################
def log_results(
    tracking_uri: str,
    experiment_prefix: str,
    grn_name: str,
    in_channels: int,
    config: DictConfig,
) -> None:
    """
    Record the run's tags and parameters in the MLflow experiment tracker.

    Args:
        tracking_uri (str): The tracking URI.
        experiment_prefix (str): The experiment name prefix.
        grn_name (str): The name of the GRN.
        in_channels (int): The number of input channels.
        config (DictConfig): The pipeline configuration; every entry of its
            "gnn" section is logged as a parameter.
    """
    experiment_name = f"{experiment_prefix}_train_vae_gnn"

    mlflow.set_tracking_uri(tracking_uri)
    if not mlflow.get_experiment_by_name(experiment_name):
        mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)

    for tag_key, tag_value in (("grn", grn_name), ("gnn", "VAE")):
        mlflow.set_tag(tag_key, tag_value)

    for param_key, param_value in (("grn", grn_name), ("in_channels", in_channels)):
        mlflow.log_param(param_key, param_value)

    gnn_settings = config["gnn"]
    for setting_name in gnn_settings:
        mlflow.log_param(setting_name, gnn_settings[setting_name])
def get_graph(
    db_host: str,
    db_name: str,
    db_username: str,
    db_password: str,
    collection: str,
    feature_k: str = "expression",
) -> nx.Graph:
    """
    Retrieve the graph from the database.

    The database graph is converted to an undirected graph with integer node
    labels. A fresh graph is then built that keeps only the edges and, for
    each node, the list of values stored under ``feature_k``.

    Args:
        db_host (str): The database host.
        db_name (str): The database name.
        db_username (str): The database username.
        db_password (str): The database password.
        collection (str): The name passed to the adapter to select the graph.
        feature_k (str, optional): The dictionary key for node features.
            Defaults to "expression".

    Returns:
        nx.Graph: The retrieved graph.

    Raises:
        KeyError: If a node has no ``feature_k`` attribute.
    """
    db = ArangoClient(hosts=db_host).db(
        db_name, username=db_username, password=db_password
    )
    adapter = ADBNX_Adapter(db)

    db_G = adapter.arangodb_graph_to_networkx(collection)
    db_G = nx.Graph(db_G)
    db_G = nx.convert_node_labels_to_integers(db_G)

    G = nx.Graph()
    G.add_edges_from(db_G.edges)
    # Nodes without any edge are not created by add_edges_from, which made the
    # feature assignment below fail with KeyError. Adding them after the edges
    # keeps the node order unchanged for graphs without isolated nodes.
    G.add_nodes_from(db_G.nodes)

    for node_id, node_features in db_G.nodes(data=True):
        G.nodes[node_id][feature_k] = list(node_features[feature_k].values())

    return G
def get_split(
    G: "Data", num_val: float, num_test: float, device: torch.device
) -> tuple["Data", "Data", "Data"]:
    """
    Get the train-validation-test link split.

    Features are normalised, the data is moved to ``device``, and the edges
    are then randomly split for link prediction. Negative edges are not
    sampled for the training split.

    Args:
        G (Data): The graph as a PyTorch Geometric data object (the caller in
            this file passes the result of ``from_networkx``).
        num_val (float): The proportion of validation data.
        num_test (float): The proportion of testing data.
        device (torch.device): The training device.

    Returns:
        tuple[Data, Data, Data]: The train, validation and test splits.
    """
    link_split = T.RandomLinkSplit(
        num_val=num_val,
        num_test=num_test,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=True,
    )
    transform = T.Compose([T.NormalizeFeatures(), T.ToDevice(device), link_split])

    train_data, val_data, test_data = transform(G)
    return train_data, val_data, test_data
def get_model_components(
    lr: float,
    in_channels: int,
    hidden_channels: int,
    out_channels: int,
    device: torch.device,
    n_layers: int,
    normalize: bool,
    bias: bool,
    aggr: str,
) -> tuple:
    """
    Build the model, optimiser and learning-rate scheduler.

    Args:
        lr (float): The learning rate.
        in_channels (int): The number of input channels.
        hidden_channels (int): The number of hidden channels.
        out_channels (int): The number of output channels.
        device (torch.device): The training device.
        n_layers (int): The number of SAGE convolutional layers.
        normalize (bool): Forwarded to the encoder's convolutions.
        bias (bool): Whether to include the bias term.
        aggr (str): The data aggregation method.

    Returns:
        tuple: ``(model, optimizer, scheduler)`` where the model is a ``VGAE``
        wrapping a ``VariationalGCNEncoder``, the optimiser is Adam, and the
        scheduler is ``ReduceLROnPlateau`` in "max" mode.
    """
    encoder = VariationalGCNEncoder(
        in_channels,
        hidden_channels,
        out_channels,
        n_layers,
        normalize,
        bias,
        aggr,
    )
    model = VGAE(encoder).to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "max", factor=0.05
    )

    return model, optimizer, scheduler
def train_model(
    model: torch.nn.Module,
    train_data: "Data",
    val_data: "Data",
    test_data: "Data",
    n_epochs: int,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    enable_tracking: bool,
) -> torch.nn.Module:
    """
    Train the graph neural network.

    Args:
        model (torch.nn.Module): The graph neural network. It must provide
            ``encode``, ``recon_loss``, ``kl_loss`` and ``test``.
        train_data (Data): The training data.
        val_data (Data): The validation data.
        test_data (Data): The testing data.
        n_epochs (int): The number of epochs.
        optimizer (torch.optim.Optimizer): The model optimiser.
        device (torch.device): The training device. Not used inside this
            function.
        enable_tracking (bool): Whether to enable experiment tracking.

    Returns:
        torch.nn.Module: The trained model (the same object that was passed in).
    """

    def train() -> float:
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.expression, train_data.edge_index)
        loss = model.recon_loss(z, train_data.pos_edge_label_index)
        # Scale the KL term by the number of nodes.
        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test(data):
        model.eval()
        z = model.encode(data.expression, data.edge_index)
        return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)

    # Report roughly every 5% of the run. The interval is clamped to 1 because
    # int(n_epochs * 0.05) is 0 for fewer than 20 epochs, which previously
    # raised ZeroDivisionError in the modulo below.
    log_every = max(1, int(n_epochs * 0.05))

    for epoch in range(n_epochs):
        loss = train()
        if epoch % log_every != 0:
            continue

        # The metrics are only consumed on reporting epochs, so they are only
        # computed here instead of after every epoch.
        val_auc, val_ap = test(val_data)
        test_auc, test_ap = test(test_data)
        if enable_tracking:
            mlflow.log_metric("train_loss", loss, step=epoch)
            mlflow.log_metric("val_auc", val_auc, step=epoch)
            mlflow.log_metric("val_ap", val_ap, step=epoch)
            mlflow.log_metric("test_auc", test_auc, step=epoch)
            mlflow.log_metric("test_ap", test_ap, step=epoch)
        print(
            f"Epoch: {epoch:03d}, loss {loss:.4f}",
            f"Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}",
            f"Test AUC: {test_auc:.4f}, Test AP: {test_ap:.4f}",
        )

    final_test_auc, final_test_ap = test(test_data)
    print(f"Final Test AUC: {final_test_auc:.4f}, Final Test AP: {final_test_ap:.4f}")
    if enable_tracking:
        mlflow.log_metric("final_test_auc", final_test_auc)
        mlflow.log_metric("final_test_ap", final_test_ap)

    return model
def view_embeddings(
    model: torch.nn.Module, data: "Data", output_dir: str, enable_tracking: bool
) -> str:
    """
    View the latent embeddings in 2D.

    The node embeddings are encoded with the model, projected onto their
    first two principal components and written as a scatter plot to
    ``graph.png`` inside ``output_dir``.

    Args:
        model (torch.nn.Module): The variational autoencoder.
        data (Data): The graph data, providing ``expression`` and
            ``edge_index``.
        output_dir (str): The output directory for saving plots. It is
            created if it does not exist.
        enable_tracking (bool): Whether experiment tracking is enabled. When
            true, the saved plot is logged as an MLflow artifact.

    Returns:
        str: The path of the saved visualisation.
    """
    embeddings = model.encode(data.expression, data.edge_index).detach().cpu().numpy()

    transformer = PCA(n_components=2)
    emb_transformed = pd.DataFrame(
        transformer.fit_transform(embeddings), columns=["x", "y"]
    )

    # Keep a handle on the figure so it can be saved and closed explicitly,
    # rather than relying on pyplot's "current figure" and leaving it open.
    figure = emb_transformed.plot.scatter("x", "y").get_figure()
    try:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        outfile = join_path(output_dir, "graph.png")
        figure.savefig(outfile)
    finally:
        plt.close(figure)

    if enable_tracking:
        mlflow.log_artifact(outfile)

    return outfile
###################################### # Main ######################################
@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(config: DictConfig) -> None:
    """
    The main entry point for training the variational graph autoencoder.

    Loads the graph from the database, splits it for link prediction, trains
    the model, and saves a 2D view of the learned embeddings.

    Args:
        config (DictConfig): The pipeline configuration.
    """
    # Configuration is read up front, in a fixed order, so a missing key
    # fails before any database or training work starts.
    experiment_prefix = config["experiment"]["name"]
    data_dir = config["dir"]["data_dir"]
    out_dir = config["dir"]["out_dir"]
    grn_name = config["grn"]["input_dir"]

    db_host = config["db"]["host"]
    db_name = config["db"]["name"]
    db_username = config["db"]["username"]
    db_password = config["db"]["password"]

    num_val = config["gnn"]["num_val"]
    num_test = config["gnn"]["num_test"]
    hidden_channels = config["gnn"]["hidden_channels"]
    out_channels = config["gnn"]["out_channels"]
    lr = config["gnn"]["lr"]
    n_epochs = config["gnn"]["n_epochs"]
    n_layers = config["gnn"]["n_layers"]
    normalize = config["gnn"]["normalize"]
    bias = config["gnn"]["bias"]
    aggr = config["gnn"]["aggr"]

    tracking_uri = config["experiment_tracking"]["tracking_uri"]
    tracking_enabled = config["experiment_tracking"]["enabled"]

    nx_graph = get_graph(db_host, db_name, db_username, db_password, grn_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    graph_data = from_networkx(nx_graph)
    train_data, val_data, test_data = get_split(graph_data, num_val, num_test, device)
    in_channels = graph_data.expression.shape[1]

    # The scheduler is returned by the factory but is not used by this script.
    model, optimizer, _scheduler = get_model_components(
        lr,
        in_channels,
        hidden_channels,
        out_channels,
        device,
        n_layers,
        normalize,
        bias,
        aggr,
    )

    if tracking_enabled:
        log_results(tracking_uri, experiment_prefix, grn_name, in_channels, config)

    model = train_model(
        model,
        train_data,
        val_data,
        test_data,
        n_epochs,
        optimizer,
        device,
        tracking_enabled,
    )

    output_dir = join_path(data_dir, out_dir, grn_name, "vae_gnn")
    view_embeddings(model, train_data, output_dir, tracking_enabled)

    if tracking_enabled:
        mlflow.end_run()
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()