Source code for src.plot_curves

#!/usr/bin/python

######################################
# Imports
######################################

# External
import hydra
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from os.path import join as join_path
import pandas as pd
import seaborn as sns

######################################
# Main
######################################


def plot_curves(
    species_code: str,
    index_df: pd.DataFrame,
    data_dir: str,
    infile: str,
    x: str,
    y: str,
    hue: str,
    col: str,
    as_cat: list[str] = None,
    tracking_uri: str = None,
    enable_experiment_tracking: bool = False
) -> None:
    """
    Plot the growth curves.

    The species is looked up in ``index_df``; its ``class``, ``order`` and
    ``species`` values form the sub-directory of ``data_dir`` that holds
    ``infile``. The plot is written next to the input file as
    ``{y}_{x}.png``. When experiment tracking is enabled, the species
    metadata, a sample of the sorted ``y`` values of every ``hue`` group and
    the saved plot are logged to mlflow.

    Args:
        species_code (str): The shark species code.
        index_df (pd.DataFrame): The index dataframe containing metadata.
            Must provide the columns ``species_code``, ``species``,
            ``class`` and ``order``.
        data_dir (str): The data directory.
        infile (str): The input datafile.
        x (str): The x column.
        y (str): The y column.
        hue (str): The grouping column.
        col (str): The categorical column.
        as_cat (list[str], optional): Columns to cast to the category
            dtype before plotting. Defaults to None (no column is cast).
        tracking_uri (str, optional): The experiment tracking URI.
            Defaults to None.
        enable_experiment_tracking (bool, optional): Enable experiment
            tracking. Defaults to False.

    Raises:
        ValueError: If ``species_code`` does not match exactly one row of
            ``index_df``.
    """
    species_df = index_df.query("species_code == @species_code")
    if len(species_df) != 1:
        raise ValueError(
            f"Expected exactly one index entry for species code "
            f"{species_code!r}, found {len(species_df)}."
        )

    def extract_val(key):
        # Exactly one row is guaranteed above, so .item() yields a scalar.
        return species_df[key].values.item()

    species = extract_val("species")
    class_type = extract_val("class")
    order = extract_val("order")

    in_dir = join_path(data_dir, class_type, order, species)
    datafile = join_path(in_dir, infile)
    data_df = pd.read_csv(datafile)

    # A None value (the default, or a null config entry) means "cast nothing".
    for cat_col in as_cat or []:
        data_df[cat_col] = data_df[cat_col].astype("category")

    sns.catplot(
        x=x,
        y=y,
        hue=hue,
        col=col,
        col_wrap=3,
        data=data_df,
        kind="point",
        height=4,
        aspect=0.8,
    )

    outfile = join_path(in_dir, f"{y}_{x}.png")
    plt.savefig(outfile)
    # catplot opens a new figure on every call; close it once it is on disk
    # so figures do not pile up when this function is called in a loop.
    plt.close()

    if not enable_experiment_tracking or tracking_uri is None:
        return

    # Imported lazily so mlflow is only required when tracking is enabled.
    import mlflow

    mlflow.set_tracking_uri(tracking_uri)

    experiment_name = "plot_growth_curves"
    existing_exp = mlflow.get_experiment_by_name(experiment_name)
    if not existing_exp:
        mlflow.create_experiment(experiment_name)
    mlflow.set_experiment(experiment_name)

    def log_series(df, interval=0.01):
        series = df[y].dropna().sort_values().values
        # Log roughly every `interval` fraction of the sorted values. The
        # stride is clamped to 1 because groups with fewer than 1/interval
        # values would otherwise give a zero stride.
        stride = max(1, int(len(series) * interval))
        for i in range(0, len(series), stride):
            mlflow.log_metric(y, series[i], step=i)

    try:
        mlflow.set_tag("species", species)
        mlflow.log_param("species", species)
        mlflow.log_param("class", class_type)
        mlflow.log_param("order", order)

        for _, group_df in data_df.groupby(hue):
            log_series(group_df)

        mlflow.log_artifact(outfile)
    finally:
        # Always close the run so a failure here cannot leave it active for
        # the next call.
        mlflow.end_run()
@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(config: DictConfig) -> None:
    """
    Run the plotting pipeline for every configured species.

    Reads the species list, experiment-tracking settings, data locations and
    plot options from the configuration, loads the index CSV from the data
    directory and calls ``plot_curves`` once per species code.

    Args:
        config (DictConfig): The pipeline configuration.
    """
    # Common settings
    species_codes = config["common"]["species"]

    # Experiment tracking settings
    tracking_uri = config["experiment_tracking"]["tracking_uri"]
    tracking_enabled = config["experiment_tracking"]["enabled"]

    # Data locations
    data_section = config["data"]
    data_dir = data_section["dir"]
    index_name = data_section["index"]
    datafile_name = data_section["out"]

    # Plot options
    plot_section = config["plot"]
    x_col = plot_section["x"]
    y_col = plot_section["y"]
    hue_col = plot_section["hue"]
    facet_col = plot_section["col"]
    category_cols = plot_section["as_cat"]

    # The index is shared by all species, so it is loaded once up front.
    index_df = pd.read_csv(join_path(data_dir, index_name))

    for code in species_codes:
        plot_curves(
            species_code=code,
            index_df=index_df,
            data_dir=data_dir,
            infile=datafile_name,
            x=x_col,
            y=y_col,
            hue=hue_col,
            col=facet_col,
            as_cat=category_cols,
            tracking_uri=tracking_uri,
            enable_experiment_tracking=tracking_enabled,
        )
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()