Train a Variational Autoencoder Graph Neural Network

Date published: 26/09/23

class bin.train_vae_gnn.VariationalGCNEncoder(in_channels: int, hidden_channels: int, out_channels: int, n_layers: int = 2, normalize: bool = False, bias: bool = True, aggr: str = 'mean')
forward(x: Tensor, edge_index: Tensor) -> tuple[Tensor, Tensor]

The forward pass.

Args:
x (torch.Tensor):

Input data.

edge_index (torch.Tensor):

The graph edge index.

Returns:
tuple[torch.Tensor, torch.Tensor]:

The convolutional mean and log-standard deviation.
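
As an illustration of how a (mean, log-standard deviation) pair is commonly consumed in a variational autoencoder, the sketch below applies the reparameterisation trick and computes the KL divergence against a standard normal prior. It is a dependency-free sketch that uses nested lists in place of tensors; the helper names `reparameterize` and `kl_divergence` and the toy values are assumptions for illustration and are not taken from this module.

```python
import math
import random


def reparameterize(mu, logstd, rng):
    """Sample z = mu + eps * exp(logstd) element-wise, with eps ~ N(0, 1)."""
    return [
        [m + rng.gauss(0.0, 1.0) * math.exp(s) for m, s in zip(mu_row, ls_row)]
        for mu_row, ls_row in zip(mu, logstd)
    ]


def kl_divergence(mu, logstd):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims, averaged over nodes."""
    per_node = [
        -0.5 * sum(1.0 + 2.0 * s - m * m - math.exp(2.0 * s)
                   for m, s in zip(mu_row, ls_row))
        for mu_row, ls_row in zip(mu, logstd)
    ]
    return sum(per_node) / len(per_node)


# Hypothetical encoder output: two nodes, two latent dimensions each.
mu = [[0.0, 0.0], [1.0, -1.0]]
logstd = [[0.0, 0.0], [0.0, 0.0]]

z = reparameterize(mu, logstd, random.Random(0))
kl = kl_divergence(mu, logstd)
```

Returning the log-standard deviation rather than the standard deviation keeps the encoder output unconstrained, since `exp(logstd)` is always positive.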

bin.train_vae_gnn.get_graph(db_host: str, db_name: str, db_username: str, db_password: str, collection: str, feature_k: str = 'expression') -> Graph

Retrieve the graph from the database.

Args:
db_host (str):

The database host.

db_name (str):

The database name.

db_username (str):

The database username.

db_password (str):

The database password.

collection (str):

The database collection.

feature_k (str):

The dictionary key for node features.

Returns:
nx.Graph:

The retrieved graph.

bin.train_vae_gnn.get_model_components(lr: float, in_channels: int, hidden_channels: int, out_channels: int, device: device, n_layers: int, normalize: bool, bias: bool, aggr: str) -> tuple

Get the components for training the model.

Args:
lr (float):

The learning rate.

in_channels (int):

The number of input channels.

hidden_channels (int):

The number of hidden channels.

out_channels (int):

The number of output channels.

device (torch.device):

The training device.

n_layers (int):

The number of SAGE convolutional layers.

normalize (bool):

Whether to normalize the input tensors.

bias (bool):

Whether to include the bias term.

aggr (str):

The data aggregation method.

Returns:
tuple:

The components for training the model.

bin.train_vae_gnn.get_split(G: Graph, num_val: float, num_test: float, device: device) -> tuple[Graph, Graph, Graph]

Get train-validation-test split.

Args:
G (nx.Graph):

The graph.

num_val (float):

The proportion of validation data.

num_test (float):

The proportion of testing data.

device (torch.device):

The training device.

Returns:
tuple[nx.Graph, nx.Graph, nx.Graph]:

The train-validation-test split.
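
To make the meaning of the `num_val` and `num_test` proportions concrete, here is a minimal sketch that partitions a list of edges by those fractions. It assumes the split operates on edges, as is usual for link prediction; the actual function returns three graph objects and may split differently. The helper `split_edges` and its `seed` parameter are hypothetical.

```python
import random


def split_edges(edges, num_val, num_test, seed=0):
    """Shuffle edges and partition them into train, validation and test lists."""
    if not 0.0 <= num_val + num_test < 1.0:
        raise ValueError("num_val + num_test must be in [0, 1)")
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)  # seeded for reproducibility
    n_val = int(len(shuffled) * num_val)
    n_test = int(len(shuffled) * num_test)
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]  # everything left over is training data
    return train, val, test


# Toy path graph with 20 edges.
edges = [(i, i + 1) for i in range(20)]
train_edges, val_edges, test_edges = split_edges(edges, num_val=0.1, num_test=0.2)
```

With 20 edges, `num_val=0.1` and `num_test=0.2`, this yields 2 validation edges, 4 test edges and 14 training edges, with no edge shared between the three sets.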

bin.train_vae_gnn.log_results(tracking_uri: str, experiment_prefix: str, grn_name: str, in_channels: int, config: DictConfig) -> None

Log experiment results to the experiment tracker.

Args:
tracking_uri (str):

The tracking URI.

experiment_prefix (str):

The experiment name prefix.

grn_name (str):

The name of the GRN.

in_channels (int):

The number of input channels.

config (DictConfig):

The pipeline configuration.

bin.train_vae_gnn.main(config: DictConfig) -> None

The main entry point for the training pipeline.

Args:
config (DictConfig):

The pipeline configuration.

bin.train_vae_gnn.train_model(model: Module, train_data: Graph, val_data: Graph, test_data: Graph, n_epochs: int, optimizer: Module, device: device, enable_tracking: bool) -> float

Train the graph neural network.

Args:
model (torch.nn.Module):

The graph neural network.

train_data (nx.Graph):

The training data.

val_data (nx.Graph):

The validation data.

test_data (nx.Graph):

The testing data.

n_epochs (int):

The number of epochs.

optimizer (torch.nn.Module):

The model optimiser.

device (torch.device):

The training device.

enable_tracking (bool):

Whether to enable experiment tracking.

Returns:
float:

The final area-under-curve score.
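
For readers unfamiliar with the metric, the following sketch computes an area-under-curve score from first principles. It assumes the score is the ROC AUC over predicted scores for held-out (positive) edges and sampled non-edges (negative), which this page does not state. The helper `auc_score` and the score values are hypothetical.

```python
def auc_score(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    if not pos_scores or not neg_scores:
        raise ValueError("need at least one positive and one negative score")
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


# Toy predicted scores for true edges and for sampled non-edges.
pos = [0.9, 0.8, 0.4]
neg = [0.7, 0.3, 0.1]
auc = auc_score(pos, neg)
```

Here 8 of the 9 positive-negative pairs are ordered correctly, so the score is 8/9. A score of 1.0 means every true edge outranks every non-edge, and 0.5 corresponds to random ranking.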

bin.train_vae_gnn.view_embeddings(model: Module, data: Graph, output_dir: str, enable_tracking: bool) -> str

View the latent embeddings in 2D.

Args:
model (torch.nn.Module):

The variational autoencoder.

data (nx.Graph):

The graph data.

output_dir (str):

The output directory for saving plots.

enable_tracking (bool):

Whether experiment tracking is enabled.

Returns:
str:

The path to the saved visualisation.