Integrating Tidymodels and Targets

Star Wars Dataset

We will reuse the Star Wars dataset yet again to demonstrate how to integrate Tidymodels with the targets package. Targets lets users build reproducible pipelines for general-purpose workflows.

# Constants

# MLflow tracking server URL. Honour MLFLOW_TRACKING_URI (the variable the
# MLflow CLI reads, exported in the serving section below) when it is set, and
# fall back to the server address used throughout this document otherwise.
MLFLOW_URL = Sys.getenv("MLFLOW_TRACKING_URI", unset = "")
if (!nzchar(MLFLOW_URL)) {
  MLFLOW_URL = "http://mlflow:5000"
}

# Imports

library(carrier)
library(DataExplorer)
library(knitr)
library(reticulate)
use_condaenv("r-mlflow-1.30.0")
library(mlflow)
mlflow::mlflow_set_tracking_uri(MLFLOW_URL)
library(targets)
library(tidymodels)
library(tidyverse)

# Load data

# Load the two modelling columns (height = outcome, mass = predictor) from the
# Star Wars data shipped with dplyr.
#
# Characters without a recorded height are dropped: height is the outcome, and
# a made-up value would be learned as if it were real. A missing mass is kept
# as NA on purpose, so that the recipe's step_impute_mean(mass) fills it with
# the training-set mean. Replacing it with 0 would bias the model and make
# that imputation step a no-op.
#
# Returns a tibble with columns `height` (int) and `mass` (dbl, may contain NA).
load_sw_data = function() {
  dplyr::starwars |>
    select(height, mass) |>
    drop_na(height)
}

# Materialise the modelling data and preview the first rows.
data = load_sw_data()
head(data)
# A tibble: 6 × 2
  height  mass
   <int> <dbl>
1    172    77
2    167    75
3     96    32
4    202   136
5    150    49
6    178   120

We will perform an 80-20 train-test split to evaluate the generalisability of our model.

# Fix the RNG seed so that the 80/20 split, and the cross-validation folds
# drawn from the training set further below, are reproducible between runs.
set.seed(2022)
data_split = initial_split(data, prop = 0.8)
data_train = training(data_split)
data_test = testing(data_split)

Tidymodels

Tidymodels is an R framework for machine learning modelling, inspired by the functional programming style of the tidyverse. Unlike the popular caret package, which is a single package containing many machine learning methods and tools, Tidymodels is an entire framework composed of a collection of packages.

Recipes

The purpose of Tidymodels recipes is to create reproducible data preprocessing pipelines. A recipe is composed of a sequence of data preprocessing steps.

# Build the preprocessing recipe for the height ~ mass regression.
#
# Every column starts with the neutral "support" role. Height is then marked
# as the outcome and mass as the only predictor. Missing mass values are
# mean-imputed, and numeric non-outcome columns are centred and scaled.
preprocess_data_recipe = function(data_train) {
  base_recipe = recipe(data_train)

  with_roles = base_recipe |>
    update_role(everything(), new_role = "support") |>
    update_role(height, new_role = "outcome") |>
    update_role(mass, new_role = "predictor")

  with_roles |>
    step_impute_mean(mass) |>
    step_normalize(all_numeric(), -all_outcomes())
}

sw_recipe = preprocess_data_recipe(data_train)

Random Forest

One purpose of Tidymodels is to provide a layer of abstraction over different packages. For instance, several packages such as randomForest and ranger implement the random forest algorithm. With Tidymodels, we can easily switch between these implementations and specify whether we are performing regression or classification.

# Random forest regression spec backed by ranger. The number of trees is left
# as a tune() placeholder to be filled in by the grid search.
get_rf = function() {
  spec = rand_forest(trees = tune())
  spec = set_engine(spec, "ranger")
  set_mode(spec, "regression")
}
sw_model = get_rf()

Workflow

Next, we define a Tidymodels workflow. This allows us to combine the above preprocessing steps with a random forest regressor. Tidymodels also encourages a high degree of modularity: we can save complex preprocessing recipes and easily switch between different models.

# Bundle a preprocessing recipe and a model spec into a single workflow.
define_workflow = function(sw_recipe, sw_model) {
  wf = workflows::workflow()
  wf = add_recipe(wf, sw_recipe)
  add_model(wf, sw_model)
}
sw_workflow = define_workflow(sw_recipe, sw_model)

Hyperparameter Tuning

We will tune the number of decision trees used within the random forest ensemble to find its optimal value.

# Candidate forest sizes: 50, 100, 150, 200 trees.
tree_grid = seq(from = 50, to = 200, by = 50)
sw_grid = expand_grid(trees = tree_grid)

# 5-fold cross-validation over the candidate grid.
sw_folds = vfold_cv(data_train, v = 5)
sw_grid_results = tune_grid(sw_workflow, resamples = sw_folds, grid = sw_grid)

# Pick the smallest forest whose RMSE is within 5% of the best one.
hyperparameters = select_by_pct_loss(
  sw_grid_results,
  trees,
  metric = "rmse",
  limit = 5
)

MLFlow

We will next demonstrate how to integrate Tidymodels with MLFlow.

Registering Models

We will first create a new model in the model registry.

client = mlflow_client()

# Start from a clean registry entry by deleting any previous "sw_rf" model.
# Deletion fails when the model has not been registered yet. That case is
# expected, so surface the reason as a message instead of silently discarding
# it (which would also hide e.g. an unreachable tracking server). The result is
# kept invisible so nothing is autoprinted here.
invisible(tryCatch(
  mlflow_delete_registered_model("sw_rf", client = client),
  error = function(e) {
    message(
      "No existing registered model 'sw_rf' was deleted: ",
      conditionMessage(e)
    )
    NULL
  }
))

mlflow_create_registered_model(
  "sw_rf",
  client = client,
  description = "Perform predictions for Star Wars characters using Random Forest."
)
$name
[1] "sw_rf"

$creation_timestamp
[1] 1.668677e+12

$last_updated_timestamp
[1] 1.668677e+12

$description
[1] "Perform predictions for Star Wars characters using Random Forest."

We will next execute an MLFlow run.

MLFlow Run

Metric Tracking

We will log the metrics and parameters for the random forest run.

# See https://mdneuzerling.com/post/tracking-tidymodels-with-mlflow/

# Log every explicitly-set argument of the workflow's parsnip model spec
# (e.g. `trees` after finalisation) as an MLflow parameter of `run`.
# Arguments whose expression is NULL (left at their defaults) are skipped.
# Returns `workflow` unchanged so the function can sit inside a pipe.
log_workflow_parameters = function(workflow, client, run) {
  model_args = workflows::extract_spec_parsnip(workflow)$args
  arg_names = names(model_args)
  # The args are stored as quosures; strip them down to their expressions.
  arg_values = purrr::map(model_args, rlang::get_expr)
  is_set = !purrr::map_lgl(arg_values, is.null)

  purrr::walk2(
    arg_names[is_set],
    arg_values[is_set],
    function(key, value) {
      mlflow_log_param(key, value, client = client, run_id = run$run_uuid)
    }
  )

  workflow
}

# Log each yardstick metric row whose `.estimator` matches `estimator` as an
# MLflow metric (key = `.metric`, value = `.estimate`) on `run`.
#
# `metrics` is a yardstick metrics tibble (columns .metric, .estimator,
# .estimate, possibly more). It is returned unchanged, so the function can sit
# at the end of a pipe.
log_metrics = function(metrics, estimator = "standard", client, run) {
  # `.env$` pins `estimator` to the function argument; without it a column
  # named `estimator` in `metrics` would silently take precedence.
  # Only the two columns the logger needs are kept, so extra columns (e.g. a
  # `.config` column) cannot break the row-wise call below.
  metrics |>
    filter(.estimator == .env$estimator) |>
    select(.metric, .estimate) |>
    pwalk(function(.metric, .estimate) {
      mlflow_log_metric(.metric, .estimate, client = client, run_id = run$run_uuid)
    })
  metrics
}

Next, we will initiate the tidymodels run with MLFlow integration.

s3_bucket = "s3://mlflow/sw_rf"
# Begin the run.
experiment = mlflow_set_experiment(experiment_name = "sw_rf", artifact_location = s3_bucket)
# Look up the experiment id and pass it explicitly, exactly as the targets
# pipeline does. The run is then created under "sw_rf" rather than under
# whatever experiment MLflow would infer for a client-based run.
experiment_id = mlflow_get_experiment(name = "sw_rf", client = client)$experiment_id
run = mlflow_start_run(client = client, experiment_id = experiment_id)

# Train the model.
# Plug the tuned hyperparameters into the workflow, record them on the MLflow
# run, then fit the finalised workflow on the training data.
train_rf = function(data_train, sw_workflow, hyperparameters, client, run) {
  finalised = finalize_workflow(sw_workflow, hyperparameters)
  finalised = log_workflow_parameters(finalised, client = client, run = run)
  fit(finalised, data_train)
}
sw_rf = train_rf(data_train, sw_workflow, hyperparameters, client, run)

# Wrap the fitted workflow in a carrier crate: a self-contained function,
# called as f(x) with new data `x`, that MLflow can serialise and later load.
# The fitted workflow travels inside the crate as the explicitly passed
# object `sw_rf`. The prediction method is referenced through the workflows
# namespace (`:::`), presumably so it resolves even in a fresh session where
# workflows is not attached. The function body is serialised and printed
# verbatim, so keep it as is.
package_rf = function(sw_rf) {
  carrier::crate(
    function(x) workflows:::predict.workflow(sw_rf, x),
    sw_rf = sw_rf
  )
}
packaged_sw_rf = package_rf(sw_rf)

# Log params and metrics.
# Score the fitted model on the test set (RMSE, MAE, R^2), log the metrics to
# the MLflow run, and return the metrics tibble.
get_metrics = function(sw_rf, data_test, client, run) {
  # Bind the observed heights next to the predictions so that `truth` and
  # `estimate` are both columns of the data handed to the metric set. The
  # original passed a free vector (`data_test$height`) as `truth`, which relies
  # on yardstick evaluating that argument as an arbitrary expression.
  sw_rf |>
    predict(new_data = data_test) |>
    bind_cols(select(data_test, height)) |>
    metric_set(rmse, mae, rsq)(truth = height, estimate = .pred) |>
    log_metrics(client = client, run = run)
}
metrics = get_metrics(sw_rf, data_test, client, run)

# Log predictions and actual values
# Each test-set prediction is logged as the MLflow metric "prediction" and each
# observed height as "actual", using the row position as the metric step, so
# the two series line up in the MLflow UI. Returns the actual heights
# invisibly.
load_pred_actual = function(sw_rf, data_test, client, run) {
  predictions = predict(sw_rf, new_data = data_test)[[".pred"]]
  for (i in seq_along(predictions)) {
    mlflow_log_metric("prediction", predictions[[i]], step = as.numeric(i),
                      client = client, run_id = run$run_uuid)
  }

  actuals = data_test$height
  for (i in seq_along(actuals)) {
    mlflow_log_metric("actual",  actuals[[i]], step = i,
                      client = client, run_id = run$run_uuid)
  }

  invisible(actuals)
}
load_pred_actual(sw_rf, data_test, client, run)

# Write the crated model to a local directory, then upload that directory to
# the run's artifact store (registration happens in the next step).
crated_model = "/tmp/sw_rf"
saved_model = mlflow_save_model(packaged_sw_rf, path = crated_model)
logged_model = mlflow_log_artifact(
  path = crated_model,
  run_id = run$run_uuid,
  client = client
)
2022/11/17 09:29:47 INFO mlflow.store.artifact.cli: Logged artifact from local dir /tmp/sw_rf to artifact_path=None
# Register the run's artifact location as a new version of the "sw_rf" model.
versioned_model = mlflow_create_model_version(
  name = "sw_rf",
  source = run$artifact_uri,
  run_id = run$run_uuid,
  client = client
)

# Generate report.
# Build a DataExplorer HTML profile of `data` in `output_dir` and log it as an
# artifact of the MLflow run. Returns the logging result invisibly.
# `output_dir` defaults to the original hard-coded "/tmp", so existing calls
# behave as before.
generate_sw_report = function(data, client, run, output_dir = "/tmp") {
  report_file = "star_wars.html"
  # List-columns are excluded, presumably because the report cannot profile
  # them.
  data |>
    select(where(~ !is.list(.x))) |>
    create_report(output_file = report_file, output_dir = output_dir,
                  report_title = "Star Wars Report", quiet = TRUE)
  # Build the path from the same pieces passed to create_report() so the two
  # cannot drift apart.
  logged_report = mlflow_log_artifact(file.path(output_dir, report_file),
                                      client = client, run_id = run$run_uuid)
  invisible(logged_report)
}
sw_report = generate_sw_report(data, client, run)
2022/11/17 09:29:53 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars.html to artifact_path=None
# Save plots.
# Draw a height-vs-mass scatter plot to a PNG file at `path` and log it as an
# artifact of the MLflow run. Returns the logging result invisibly.
# `path` defaults to the original hard-coded location.
plot_sw = function(data, client, run, path = "/tmp/star_wars_characters.png") {
  # Render inside a helper so that on.exit() closes the PNG device even if
  # plot() fails. Otherwise the device would stay open and capture later
  # plots.
  render_png = function() {
    png(filename = path)
    on.exit(dev.off(), add = TRUE)
    plot(data$height, data$mass, xlab = "height", ylab = "mass")
  }
  render_png()
  # The device is closed at this point, so the PNG file is complete.
  logged_plot = mlflow_log_artifact(path, client = client, run_id = run$run_uuid)
  invisible(logged_plot)
}
sw_plot = plot_sw(data, client, run)
2022/11/17 09:29:55 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.png to artifact_path=None
# Save tibble.
# Write the modelling data to CSV and attach it to the run as an artifact.
data_csv = "/tmp/star_wars_characters.csv"
data |> write_csv(data_csv)
logged_csv = mlflow_log_artifact(
  path = data_csv,
  run_id = run$run_uuid,
  client = client
)
2022/11/17 09:29:56 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.csv to artifact_path=None
# End run: mark the MLflow run as finished.
run_end = mlflow_end_run(client = client, run_id = run$run_uuid)

Loading and Serving Models

Next, we will load the random forest model from the registry.

# Inspect the in-memory crate before removing it from the R environment.
packaged_sw_rf |> print()
<crate> 165.57 kB
* function: 55.74 kB
* `sw_rf`: 110.06 kB
function(x) workflows:::predict.workflow(sw_rf, x)
# Remove the model from the R environment.
rm(list = "packaged_sw_rf")

# Load version 1 of the registered "sw_rf" model back from the registry.
packaged_sw_rf = mlflow_load_model(model_uri = "models:/sw_rf/1")
packaged_sw_rf |> print()
<crate> 176.10 kB
* function: 55.74 kB
* `sw_rf`: 120.53 kB
function(x) workflows:::predict.workflow(sw_rf, x)

Finally, we will demonstrate how to deploy the model using a model-as-a-service approach. We will first demonstrate how to launch the model using bash.


# Point the MLflow CLI at the tracking server so "models:/" URIs resolve.
export MLFLOW_TRACKING_URI="http://mlflow:5000"

# ping http://0.0.0.0:9000/predict 
mlflow models serve --model-uri "models:/sw_rf/1" --host 0.0.0.0 --port 9000

Targets

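We will next replicate the above workflow using the targets package. Pipelines should be defined in a _targets.R file, and the helper functions defined above (load_sw_data(), train_rf(), and so on) must also be available to the pipeline, for example by sourcing them from _targets.R.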

# _targets.R
#
# targets decides execution order only from the target names that appear in
# each command. Many MLflow steps act purely through side effects. Every step
# that must run after another therefore names its prerequisite explicitly
# inside a `{ ... }` block, for example:
#   - set the tracking URI before creating the client;
#   - write the crate before uploading it;
#   - upload everything before ending the run.
#
# NOTE(review): helper functions such as load_sw_data(), train_rf() and
# plot_sw() are not defined in this file. They must be made available to the
# pipeline (e.g. via tar_source() or source()); confirm where they live.

library(targets)

tar_option_set(packages = c(
  "carrier", 
  "DataExplorer",
  "knitr",
  "mlflow",
  "reticulate",
  "tidyverse", 
  "tidymodels"
  )
)

list(
  tar_target(MLFLOW_URL, "http://mlflow:5000"),
  tar_target(conda_active, use_condaenv("r-mlflow-1.30.0")),
  tar_target(mlflow_uri, {
    conda_active  # activate the conda environment first
    mlflow::mlflow_set_tracking_uri(MLFLOW_URL)
  }),
  tar_target(data, load_sw_data()),
  tar_target(data_split, initial_split(data, prop = 0.8)),
  tar_target(data_train, training(data_split)),
  tar_target(data_test, testing(data_split)),
  tar_target(sw_recipe, preprocess_data_recipe(data_train)),
  tar_target(sw_model, get_rf()),
  tar_target(sw_workflow, define_workflow(sw_recipe, sw_model)),
  tar_target(tree_grid, seq(50, 200, by = 50)),
  tar_target(sw_grid, expand_grid(trees = tree_grid)),
  tar_target(
    sw_grid_results,
    tune_grid(sw_workflow, resamples = vfold_cv(data_train, v = 5), grid = sw_grid)
  ),
  tar_target(
    hyperparameters, 
    select_by_pct_loss(sw_grid_results, metric = "rmse", limit = 5, trees)
  ),
  tar_target(client, {
    mlflow_uri  # the client must be created after the tracking URI is set
    mlflow_client()
  }),
  tar_target(s3_bucket, "s3://mlflow/sw_rf"),
  tar_target(experiment, {
    mlflow_uri  # talk to the configured tracking server
    mlflow_set_experiment(experiment_name = "sw_rf", artifact_location = s3_bucket)
  }),
  tar_target(experiment_id, {
    experiment  # the experiment must exist before it can be looked up
    mlflow_get_experiment(name = "sw_rf", client = client)$experiment_id
  }),
  tar_target(run, mlflow_start_run(client = client, experiment_id = experiment_id)),
  tar_target(sw_rf, train_rf(data_train, sw_workflow, hyperparameters, client, run)),
  tar_target(packaged_sw_rf, package_rf(sw_rf)),
  tar_target(metrics, get_metrics(sw_rf, data_test, client, run)),
  tar_target(pred_actual, load_pred_actual(sw_rf, data_test, client, run)),
  tar_target(crated_model, "/tmp/sw_rf"),
  tar_target(saved_model, mlflow_save_model(packaged_sw_rf, crated_model)),
  tar_target(logged_model, {
    saved_model  # upload only after the crate has been written to disk
    mlflow_log_artifact(crated_model, client = client, run_id = run$run_uuid)
  }),
  tar_target(versioned_model, {
    logged_model  # register only after the model artifacts are uploaded
    mlflow_create_model_version("sw_rf", run$artifact_uri,
                                run_id = run$run_uuid,
                                client = client)
  }),
  tar_target(sw_report, generate_sw_report(data, client, run)),
  tar_target(sw_plot, plot_sw(data, client, run)),
  tar_target(data_csv, "/tmp/star_wars_characters.csv"),
  tar_target(written_csv, write_csv(data, data_csv)),
  tar_target(logged_csv, {
    written_csv  # upload only after the CSV has been written
    mlflow_log_artifact(data_csv, client = client, run_id = run$run_uuid)
  }),
  tar_target(run_end, {
    # Close the run only after every step that logs to it has finished.
    list(sw_rf, metrics, pred_actual, versioned_model, sw_report, sw_plot, logged_csv)
    mlflow_end_run(run_id = run$run_uuid, client = client)
  })
)

Next, we can visualise the pipeline's dependency graph as a network.

# Render the pipeline's dependency graph.
targets::tar_visnetwork()
── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
✔ broom        1.0.1      ✔ recipes      1.0.2 
✔ dials        1.0.0      ✔ rsample      1.1.0 
✔ dplyr        1.0.10     ✔ tibble       3.1.8 
✔ ggplot2      3.3.6      ✔ tidyr        1.2.1 
✔ infer        1.0.3      ✔ tune         1.0.1 
✔ modeldata    1.0.1      ✔ workflows    1.1.0 
✔ parsnip      1.0.2      ✔ workflowsets 1.0.0 
✔ purrr        0.3.5      ✔ yardstick    1.1.0 
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Search for functions across packages at https://www.tidymodels.org/find/
── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
✔ readr   2.1.3     ✔ forcats 0.5.2
✔ stringr 1.4.1     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ readr::col_factor() masks scales::col_factor()
✖ purrr::discard()    masks scales::discard()
✖ dplyr::filter()     masks stats::filter()
✖ stringr::fixed()    masks recipes::fixed()
✖ dplyr::lag()        masks stats::lag()
✖ readr::spec()       masks yardstick::spec()

We can execute the pipeline using the tar_make() function.

# Build every outdated target in the pipeline.
targets::tar_make()

Session Information

R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] forcats_0.5.2      stringr_1.4.1      readr_2.1.3        tidyverse_1.3.2   
 [5] yardstick_1.1.0    workflowsets_1.0.0 workflows_1.1.0    tune_1.0.1        
 [9] tidyr_1.2.1        tibble_3.1.8       rsample_1.1.0      recipes_1.0.2     
[13] purrr_0.3.5        parsnip_1.0.2      modeldata_1.0.1    infer_1.0.3       
[17] ggplot2_3.3.6      dplyr_1.0.10       dials_1.0.0        scales_1.2.1      
[21] broom_1.0.1        tidymodels_1.0.0   targets_0.13.5     mlflow_1.30.0     
[25] reticulate_1.26    knitr_1.40         DataExplorer_0.8.2 carrier_0.1.0     

loaded via a namespace (and not attached):
  [1] googledrive_2.0.0   colorspace_2.0-3    ellipsis_0.3.2     
  [4] class_7.3-20        fs_1.5.2            base64enc_0.1-3    
  [7] rstudioapi_0.14     listenv_0.8.0       furrr_0.3.1        
 [10] prodlim_2019.11.13  fansi_1.0.3         lubridate_1.8.0    
 [13] xml2_1.3.3          codetools_0.2-18    splines_4.2.1      
 [16] zeallot_0.1.0       jsonlite_1.8.3      dbplyr_2.2.1       
 [19] png_0.1-7           compiler_4.2.1      httr_1.4.4         
 [22] backports_1.4.1     assertthat_0.2.1    Matrix_1.4-1       
 [25] fastmap_1.1.0       gargle_1.2.1        cli_3.4.1          
 [28] later_1.3.0         htmltools_0.5.3     tools_4.2.1        
 [31] igraph_1.3.5        gtable_0.3.1        glue_1.6.2         
 [34] rappdirs_0.3.3      Rcpp_1.0.9          cellranger_1.1.0   
 [37] DiceDesign_1.9      vctrs_0.5.0         iterators_1.0.14   
 [40] timeDate_4021.106   gower_1.0.0         xfun_0.34          
 [43] globals_0.16.1      networkD3_0.4       ps_1.7.2           
 [46] rvest_1.0.3         lifecycle_1.0.3     googlesheets4_1.0.1
 [49] future_1.28.0       MASS_7.3-57         ipred_0.9-13       
 [52] hms_1.1.2           promises_1.2.0.1    parallel_4.2.1     
 [55] swagger_3.33.1      yaml_2.3.6          forge_0.2.0        
 [58] gridExtra_2.3       rpart_4.1.16        stringi_1.7.8      
 [61] foreach_1.5.2       lhs_1.1.5           hardhat_1.2.0      
 [64] lava_1.7.0          rlang_1.0.6         pkgconfig_2.0.3    
 [67] evaluate_0.17       lattice_0.20-45     htmlwidgets_1.5.4  
 [70] processx_3.8.0      tidyselect_1.2.0    parallelly_1.32.1  
 [73] magrittr_2.0.3      R6_2.5.1            generics_0.1.3     
 [76] base64url_1.4       ini_0.3.1           DBI_1.1.3          
 [79] haven_2.5.1         pillar_1.8.1        withr_2.5.0        
 [82] survival_3.3-1      nnet_7.3-17         future.apply_1.9.1 
 [85] crayon_1.5.2        modelr_0.1.9        utf8_1.2.2         
 [88] tzdb_0.3.0          rmarkdown_2.17      readxl_1.4.1       
 [91] grid_4.2.1          data.table_1.14.4   callr_3.7.2        
 [94] reprex_2.0.2        digest_0.6.30       httpuv_1.6.6       
 [97] openssl_2.0.4       munsell_0.5.0       GPfit_1.0-8        
[100] askpass_1.1