#!/usr/bin/env python
######################################
# Imports
######################################
import hydra
import matplotlib.pyplot as plt
import networkx as nx
from omegaconf import DictConfig
from os.path import join as join_path
import pandas as pd
from pathlib import Path
######################################
# Functions
######################################
def construct_network(
    edge_list: pd.DataFrame, from_col: str, to_col: str, len_component: int = 5
) -> nx.Graph:
    """
    Construct an undirected graph from edge list data, dropping small components.

    Args:
        edge_list (pd.DataFrame):
            The edge list, one row per edge.
        from_col (str):
            The "from" column name.
        to_col (str):
            The "to" column name.
        len_component (int, optional):
            Component size threshold. Every connected component with at most
            this many nodes (i.e. size <= len_component) is removed from the
            graph. Defaults to 5.

    Returns:
        nx.Graph:
            The constructed graph, containing only connected components with
            more than ``len_component`` nodes.
    """
    # NOTE(review): sorting by the source column presumably fixes the node
    # insertion order so the seeded spring layout is reproducible; confirm.
    edges = edge_list.sort_values(from_col)
    G = nx.from_pandas_edgelist(edges, from_col, to_col, create_using=nx.Graph())

    # Collect the small components first: connected_components is a lazy
    # generator over G, so G must not be mutated while it is being consumed.
    small_components = [
        component
        for component in nx.connected_components(G)
        if len(component) <= len_component
    ]
    for component in small_components:
        G.remove_nodes_from(component)
    return G
def visualize_graph(G: nx.Graph, output_dir: str) -> str:
    """
    Visualise the graph and save the figure as ``graph.png``.

    Args:
        G (nx.Graph):
            The graph.
        output_dir (str):
            The directory to write ``graph.png`` into. This function does not
            create it.

    Returns:
        str:
            The path of the saved image file.
    """
    outfile = join_path(output_dir, "graph.png")
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        ax.set_xticks([])
        ax.set_yticks([])
        nx.draw_networkx(
            G,
            # Fixed seed so repeated runs produce the same layout.
            pos=nx.spring_layout(G, seed=42),
            ax=ax,
            with_labels=False,
            node_color="blue",
            node_size=10,
        )
        fig.savefig(outfile)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)
    return outfile
def calculate_metrics(G: nx.Graph, output_dir: str) -> dict[str, float]:
    """
    Calculate graph metrics and write them to ``metrics.csv``.

    Args:
        G (nx.Graph):
            The graph.
        output_dir (str):
            The directory to write ``metrics.csv`` into.

    Returns:
        dict[str, float]:
            Metric values keyed by the ``__name__`` of the networkx function
            that computed them (e.g. ``"diameter"``).
    """
    metric_funcs = (
        nx.diameter,
        nx.radius,
        nx.average_clustering,
        nx.node_connectivity,
        nx.degree_assortativity_coefficient,
        nx.degree_pearson_correlation_coefficient,
    )
    # NOTE(review): nx.diameter / nx.radius raise NetworkXError when G is not
    # connected; dropping small components does not guarantee that only a
    # single component remains.
    metrics = {func.__name__: func(G) for func in metric_funcs}

    outfile = join_path(output_dir, "metrics.csv")
    # A single row with one column per metric.
    pd.DataFrame(metrics, index=[0]).to_csv(outfile, index=False)
    return metrics
def log_results(
    tracking_uri: str,
    experiment_prefix: str,
    grn_name: str,
    edge_list_file: str,
    network_plot: str,
    metrics: dict[str, float],
) -> None:
    """
    Log experiment results to the MLflow experiment tracker.

    Results are logged in a new run under the experiment
    ``"<experiment_prefix>_eda"``, which is created if it does not exist.

    Args:
        tracking_uri (str):
            The tracking URI.
        experiment_prefix (str):
            The experiment name prefix.
        grn_name (str):
            The name of the GRN (logged as both a tag and a parameter).
        edge_list_file (str):
            The name of the edge list file.
        network_plot (str):
            The path to the network plot to add as an artifact.
        metrics (dict[str, float]):
            The dictionary of metrics, keyed by metric name.
    """
    # Imported lazily, presumably so mlflow is only required when tracking
    # is enabled.
    import mlflow

    mlflow.set_tracking_uri(tracking_uri)
    # set_experiment activates the experiment, creating it if it is missing.
    mlflow.set_experiment(f"{experiment_prefix}_eda")
    # The context manager always ends the run (as FAILED if logging raises),
    # so an error cannot leave a dangling active run behind.
    with mlflow.start_run():
        mlflow.set_tag("grn", grn_name)
        mlflow.log_param("grn", grn_name)
        mlflow.log_param("edge_list_file_name", edge_list_file)
        mlflow.log_metrics(metrics)
        mlflow.log_artifact(network_plot)
######################################
# Main
######################################
@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(config: DictConfig) -> None:
    """
    The main entry point for the graph exploratory data analysis (EDA) pipeline.

    Reads the preprocessed edge list of the configured GRN, builds the graph,
    writes a network plot and a metrics table to
    ``<data_dir>/<out_dir>/<grn input_dir>/eda``, and logs the results to the
    experiment tracker when tracking is enabled.

    Args:
        config (DictConfig):
            The pipeline configuration.
    """
    # Constants
    EXPERIMENT_PREFIX = config["experiment"]["name"]
    DATA_DIR = config["dir"]["data_dir"]
    PREPROCESS_DIR = config["dir"]["preprocessed_dir"]
    OUT_DIR = config["dir"]["out_dir"]
    GRN_NAME = config["grn"]["input_dir"]
    EDGE_LIST_FILE = config["grn"]["edge_list"]
    FROM_COL = config["grn"]["from_col"]
    TO_COL = config["grn"]["to_col"]
    TRACKING_URI = config["experiment_tracking"]["tracking_uri"]
    ENABLE_TRACKING = config["experiment_tracking"]["enabled"]

    # Load the preprocessed edge list and build the filtered graph.
    input_dir = join_path(DATA_DIR, PREPROCESS_DIR, GRN_NAME)
    edge_list = pd.read_csv(join_path(input_dir, EDGE_LIST_FILE))
    G = construct_network(edge_list, FROM_COL, TO_COL)

    # The plotting and metrics helpers write into output_dir but do not
    # create it, so create it here.
    output_dir = join_path(DATA_DIR, OUT_DIR, GRN_NAME, "eda")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    network_plot = visualize_graph(G, output_dir)
    metrics = calculate_metrics(G, output_dir)

    if ENABLE_TRACKING:
        log_results(
            TRACKING_URI,
            EXPERIMENT_PREFIX,
            GRN_NAME,
            EDGE_LIST_FILE,
            network_plot,
            metrics,
        )


if __name__ == "__main__":
    main()