"""Contains MLflow compatible models for versioning and deployment.
This module defines MLflow compatible models for versioning and deployment as microservices.
"""
from typing import Any
import bentoml
import gpytorch
import mlflow
import numpy as np
import pandas as pd
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
import zuko
from gpytorch.models import ApproximateGP
from lampe.inference import NPE
from mlflow.client import MlflowClient
from mlflow.models import infer_signature
from mlflow.pyfunc.context import Context
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.pool import global_max_pool
from ..data_model import RootCalibrationModel
from .surrogates import SingleTaskVariationalGPModel
def log_model(
    task: str,
    input_parameters: RootCalibrationModel,
    calibration_model: mlflow.pyfunc.PythonModel,
    artifacts: dict,
    simulation_uuid: str,
    signature_x: pd.DataFrame | np.ndarray | list | None = None,
    signature_y: pd.DataFrame | np.ndarray | list | None = None,
    model_config: dict | None = None,
) -> None:
    """Log the calibrator model to the registry.

    Logs the pyfunc model to the active MLflow run, registers it under the
    task name, attaches run metadata, and imports it into the BentoML model
    store for deployment.

    Args:
        task (str):
            The name of the current task for the experiment.
        input_parameters (RootCalibrationModel):
            The root calibration data model.
        calibration_model (mlflow.pyfunc.PythonModel):
            The calibrator to log.
        artifacts (dict):
            Experiment artifacts to log.
        simulation_uuid (str):
            The simulation uuid.
        signature_x (pd.DataFrame | np.ndarray | list | None, optional):
            The signature for data inputs. Defaults to None.
        signature_y (pd.DataFrame | np.ndarray | list | None, optional):
            The signature for data outputs. Defaults to None.
        model_config (dict | None, optional):
            The model configuration. Defaults to None.
    """
    # Only infer a signature when at least one side was supplied; otherwise
    # log the model without one.
    if signature_x is None and signature_y is None:
        signature = None
    else:
        signature = infer_signature(signature_x, signature_y)

    logged_model = mlflow.pyfunc.log_model(
        python_model=calibration_model,
        artifact_path=task,
        artifacts=artifacts,
        signature=signature,
        model_config=model_config,
    )

    # Register the logged model under the task name, tagging it so the
    # originating simulation can be traced back from the registry.
    model_uri = logged_model.model_uri
    model_version = mlflow.register_model(
        model_uri=model_uri,
        name=task,
        tags=dict(
            task=task,
            simulation_uuid=simulation_uuid,
            simulation_tag=input_parameters.simulation_tag,
        ),
    )
    client = MlflowClient(mlflow.get_tracking_uri())
    client.update_model_version(
        name=task,
        version=model_version.version,
        description=f"A root model calibrator for performing the following task: {task}",
    )

    # Mirror the model into the BentoML store, carrying over the active
    # run's tags, metrics, and params. Hoisted to a single active_run()
    # lookup rather than three separate calls.
    active_run = mlflow.active_run()
    bentoml.mlflow.import_model(
        task,
        model_uri,
        labels=active_run.data.tags,
        metadata={
            "metrics": active_run.data.metrics,
            "params": active_run.data.params,
        },
    )
class OptimisationModel(mlflow.pyfunc.PythonModel):
    """An optimisation calibration model."""

    def __init__(self) -> None:
        """The OptimisationModel constructor."""
        self.task = "optimisation"
        # The fitted optimisation study; populated by load_context().
        self.calibrator = None

    def load_context(self, context: Context) -> None:
        """Load the model context.

        Args:
            context (Context):
                The model context.
        """
        import joblib

        calibrator_data = context.artifacts["calibrator"]
        self.calibrator = joblib.load(calibrator_data)

    def predict(
        self, context: Context, model_input: pd.DataFrame, params: dict | None = None
    ) -> pd.DataFrame:
        """Make a model prediction.

        Returns the top n_trials trials from the calibrator's trial history,
        ranked by objective value (best first).

        Args:
            context (Context):
                The model context.
            model_input (pd.DataFrame):
                The model input data. Must contain a single-row "n_trials" column.
            params (dict, optional):
                Optional model parameters. Defaults to None.

        Raises:
            ValueError:
                Error raised when the calibrator has not been loaded.

        Returns:
            pd.DataFrame:
                The model prediction.
        """
        if self.calibrator is None:
            raise ValueError(f"The {self.task} calibrator has not been loaded.")

        n_trials = model_input["n_trials"].item()
        # Lower objective values are better, so sort ascending and take the head.
        trials_df: pd.DataFrame = self.calibrator.trials_dataframe().sort_values(
            "value", ascending=True
        )
        return trials_df.head(n_trials)
class SensitivityAnalysisModel(mlflow.pyfunc.PythonModel):
    """A sensitivity analysis calibration model."""

    def __init__(self) -> None:
        """The SensitivityAnalysisModel constructor."""
        self.task = "sensitivity_analysis"
        # The fitted sensitivity analysis results; populated by load_context().
        self.calibrator = None

    def load_context(self, context: Context) -> None:
        """Load the model context.

        Args:
            context (Context):
                The model context.
        """
        import joblib

        calibrator_data = context.artifacts["calibrator"]
        self.calibrator = joblib.load(calibrator_data)

    def predict(
        self, context: Context, model_input: pd.DataFrame, params: dict | None = None
    ) -> pd.DataFrame:
        """Make a model prediction.

        Filters the total sensitivity index table down to the parameter names
        requested in the input.

        Args:
            context (Context):
                The model context.
            model_input (pd.DataFrame):
                The model input data. Must contain a "name" column.
            params (dict, optional):
                Optional model parameters. Defaults to None.

        Raises:
            ValueError:
                Error raised when the calibrator has not been loaded.

        Returns:
            pd.DataFrame:
                The model prediction.
        """
        if self.calibrator is None:
            raise ValueError(f"The {self.task} calibrator has not been loaded.")

        # `names` is referenced inside DataFrame.query via the @names local
        # variable syntax, hence the F841 suppression.
        names = model_input["name"].values  # noqa: F841
        si_df = self.calibrator.total_si_df
        return si_df.query("name in @names")
class AbcModel(mlflow.pyfunc.PythonModel):
    """An Approximate Bayesian Computation calibration model."""

    def __init__(self) -> None:
        """The AbcModel constructor."""
        self.task = "abc"
        # The ABC sampling results DataFrame; populated by load_context().
        self.calibrator = None

    def load_context(self, context: Context) -> None:
        """Load the model context.

        Args:
            context (Context):
                The model context.
        """
        import joblib

        calibrator_data = context.artifacts["calibrator"]
        self.calibrator = joblib.load(calibrator_data)

    def predict(
        self, context: Context, model_input: pd.DataFrame, params: dict | None = None
    ) -> pd.DataFrame:
        """Make a model prediction.

        Returns the ABC sampling results, optionally filtered to the requested
        population indices.

        Args:
            context (Context):
                The model context.
            model_input (pd.DataFrame):
                The model input data. Must contain a "t" column; a first value
                of -1 (or an empty column) selects all populations.
            params (dict, optional):
                Optional model parameters. Defaults to None.

        Raises:
            ValueError:
                Error raised when the calibrator has not been loaded.

        Returns:
            pd.DataFrame:
                The model prediction.
        """
        if self.calibrator is None:
            raise ValueError(f"The {self.task} calibrator has not been loaded.")

        # .values yields an ndarray (the original list[int] annotation was wrong).
        t: np.ndarray = model_input["t"].values
        sampling_df = self.calibrator
        # -1 is the sentinel for "return every population".
        if len(t) == 0 or t[0] == -1:
            return sampling_df
        return sampling_df.query("t in @t")
class SnpeModel(mlflow.pyfunc.PythonModel):
    """A Sequential neural posterior estimation calibration model."""

    def __init__(self) -> None:
        """The SnpeModel constructor."""
        self.task = "snpe"
        # All four attributes are populated by load_context().
        self.inference = None
        self.posterior = None
        self.parameter_intervals = None
        self.statistics_df = None

    def load_context(self, context: Context) -> None:
        """Load the model context.

        Args:
            context (Context):
                The model context.
        """
        import joblib

        def load_data(k: str) -> Any:
            # Deserialise a single named artifact from the model context.
            artifact = context.artifacts[k]
            return joblib.load(artifact)

        self.inference = load_data("inference")
        self.posterior = load_data("posterior")
        self.parameter_intervals = load_data("parameter_intervals")
        statistics_list = load_data("statistics_list")
        self.statistics_df = pd.DataFrame(statistics_list)

    def predict(
        self, context: Context, model_input: pd.DataFrame, params: dict | None = None
    ) -> pd.DataFrame:
        """Make a model prediction.

        Draws 100 posterior samples conditioned on the observed summary
        statistics supplied in the input.

        Args:
            context (Context):
                The model context.
            model_input (pd.DataFrame):
                The model input data. Must contain "statistic_name" and
                "statistic_value" columns.
            params (dict, optional):
                Optional model parameters. Defaults to None.

        Raises:
            ValueError:
                Error raised when the calibrator has not been loaded.
            NotImplementedError:
                Error raised for unsupported inference types.

        Returns:
            pd.DataFrame:
                The model prediction.
        """
        if any(
            prop is None
            for prop in (
                self.inference,
                self.posterior,
                self.parameter_intervals,
                self.statistics_df,
            )
        ):
            raise ValueError(f"The {self.task} calibrator has not been loaded.")

        # Only summary-statistic conditioning is supported.
        if context.model_config["inference_type"] != "summary_statistics":
            raise NotImplementedError("Inference for outputs unsupported.")

        statistic_names = self.statistics_df.statistic_name.unique()
        filtered_inputs = model_input.query("statistic_name in @statistic_names")
        if len(filtered_inputs) == 0:
            return filtered_inputs

        # Reindex so the observed values line up with the order the posterior
        # was trained on.
        filtered_inputs = filtered_inputs.set_index("statistic_name")
        filtered_inputs = filtered_inputs.loc[statistic_names]
        observed_values = filtered_inputs["statistic_value"].values

        posterior_samples = self.posterior.sample((100,), x=observed_values)
        # parameter_intervals iterates its parameter names (dict-like keys).
        names = list(self.parameter_intervals)
        return pd.DataFrame(posterior_samples, columns=names)
class SurrogateModel(mlflow.pyfunc.PythonModel):
    """A surrogate calibration model."""

    def __init__(self) -> None:
        """The SurrogateModel constructor."""
        self.task = "surrogate"
        # All attributes below are populated by load_context().
        self.state_dict = None
        self.X_scaler = None
        self.Y_scaler = None
        self.model = None
        self.likelihood = None
        self.column_names = None

    def load_context(self, context: Context) -> None:
        """Load the model context.

        Args:
            context (Context):
                The model context.
        """
        import joblib

        def load_data(k: str) -> Any:
            # Deserialise a single named artifact from the model context.
            artifact = context.artifacts[k]
            return joblib.load(artifact)

        state_dict_path = context.artifacts["state_dict"]
        self.state_dict = torch.load(state_dict_path)

        # NOTE(review): only the "cost_emulator" surrogate type constructs a
        # model here; for any other type self.model/self.likelihood stay None
        # and predict() will raise. The state_dict is loaded into the GP model
        # only — the likelihood is freshly constructed, so any fitted noise
        # parameters are presumably inside the model state; confirm upstream.
        if context.model_config["surrogate_type"] == "cost_emulator":
            inducing_points_path = context.artifacts["inducing_points"]
            inducing_points = torch.load(inducing_points_path).double()
            self.model = SingleTaskVariationalGPModel(inducing_points).double()
            self.likelihood = gpytorch.likelihoods.GaussianLikelihood().double()
            self.model.load_state_dict(self.state_dict)
            self.model.eval()

        self.X_scaler = load_data("X_scaler")
        self.Y_scaler = load_data("Y_scaler")
        self.column_names = load_data("column_names")

    def predict(
        self, context: Context, model_input: pd.DataFrame, params: dict | None = None
    ) -> pd.DataFrame:
        """Make a model prediction.

        Scales the inputs, evaluates the GP surrogate, and returns the
        predictive mean with its confidence region (inverse-transformed back
        to the original output scale for the cost emulator).

        Args:
            context (Context):
                The model context.
            model_input (pd.DataFrame):
                The model input data. Must contain the surrogate's training columns.
            params (dict, optional):
                Optional model parameters. Defaults to None.

        Raises:
            ValueError:
                Error raised when the calibrator has not been loaded.

        Returns:
            pd.DataFrame:
                The model prediction.
        """
        if any(
            prop is None
            for prop in (
                self.state_dict,
                self.X_scaler,
                self.Y_scaler,
                self.model,
                self.likelihood,
                self.column_names,
            )
        ):
            raise ValueError(f"The {self.task} calibrator has not been loaded.")

        # Restrict to (and order by) the columns the surrogate was trained on.
        filtered_df = model_input[self.column_names]
        X = self.X_scaler.transform(filtered_df.values)
        X = torch.Tensor(X).double()

        # Inference only — no autograd graph needed; results are identical
        # since they were detached before use anyway.
        with torch.no_grad():
            predictions = self.likelihood(self.model(X))
            mean = predictions.mean.detach().cpu().numpy()
            lower, upper = predictions.confidence_region()
            lower, upper = lower.detach().cpu().numpy(), upper.detach().cpu().numpy()

        # Map predictions back to the original (unscaled) output space.
        if context.model_config["surrogate_type"] == "cost_emulator":
            mean = self.Y_scaler.inverse_transform(mean.reshape(-1, 1)).flatten()
            lower = self.Y_scaler.inverse_transform(lower.reshape(-1, 1)).flatten()
            upper = self.Y_scaler.inverse_transform(upper.reshape(-1, 1)).flatten()

        return pd.DataFrame(
            {"discrepancy": mean, "lower_bound": lower, "upper_bound": upper}
        )