Random Forest Regression with Tidymodels

Star Wars Dataset

We will reuse the Star Wars dataset, and demonstrate how to use Tidymodels.

# Constants

# MLflow tracking server. Defaults to the "mlflow" host used by this
# deployment, but can be overridden through the MLFLOW_TRACKING_URI
# environment variable (an empty value is treated as unset).
MLFLOW_URL <- Sys.getenv("MLFLOW_TRACKING_URI", unset = "")
if (!nzchar(MLFLOW_URL)) {
  MLFLOW_URL <- "http://mlflow:5000"
}

# Imports

library(carrier)
library(DataExplorer)
library(knitr)
library(reticulate)
use_condaenv("r-mlflow-1.30.0")
library(mlflow)
mlflow::mlflow_set_tracking_uri(MLFLOW_URL)
library(tidymodels)
library(tidyverse)

# Load data: keep height (the outcome) and mass (the single predictor).
# Characters without a recorded height are dropped, because a missing outcome
# cannot be imputed. Missing masses are deliberately left as NA rather than
# replaced with 0 (a fake mass of 0 would be learned as if it were real) so
# that step_impute_mean(mass) in the preprocessing recipe can impute them.
data <- dplyr::starwars |>
  select(height, mass) |>
  filter(!is.na(height))

data |>
  head()
# A tibble: 6 × 2
  height  mass
   <int> <dbl>
1    172    77
2    167    75
3     96    32
4    202   136
5    150    49
6    178   120

We will perform an 80-20 train-test split to evaluate the generalisability of our model.

# 80/20 train-test split. Fixing the seed makes the split, and every later
# random draw in this session (e.g. the vfold_cv() folds used for tuning),
# reproducible, so repeated MLflow runs are comparable.
set.seed(42)
data_split <- initial_split(data, prop = 0.8)
data_train <- training(data_split)
data_test <- testing(data_split)

Tidymodels

Tidymodels is an R framework for machine learning modelling, inspired by the functional programming style adopted by the tidyverse. Unlike the popular caret package, which is a single package containing many machine learning methods and tools, Tidymodels is an entire framework composed of a collection of packages.

Recipes

The purpose of Tidymodels recipes is to create reproducible data preprocessing pipelines. A recipe is composed of a sequence of data preprocessing steps.

# Preprocessing recipe: height is the outcome and mass the only predictor.
# Missing masses are imputed with the mean estimated when the recipe is
# prepped, then the numeric predictors are centred and scaled; the outcome
# is not normalised.
sw_recipe <- recipe(height ~ mass, data = data_train) |>
  step_impute_mean(mass) |>
  step_normalize(all_numeric_predictors())

Random Forest

One purpose of Tidymodels is to provide a common layer of abstraction over different modelling packages. For instance, several packages, such as randomForest and ranger, implement the random forest algorithm. With Tidymodels, we can easily switch between these implementations and specify whether we are performing regression or classification.

# Random forest regressor fitted with the ranger engine. The number of trees
# is a tune() placeholder that is filled in after the grid search.
sw_model <- rand_forest(trees = tune()) |>
  set_mode(mode = "regression") |>
  set_engine(engine = "ranger")

Workflow

Next, we define a Tidymodels workflow, which combines the above preprocessing steps with a random forest regressor. Tidymodels also encourages a high degree of modularity: we can save complex preprocessing recipes and easily switch between different models.

# Workflow bundling the model specification with the preprocessing recipe,
# so both are fitted and applied together.
sw_workflow <- workflows::workflow() |>
  add_model(spec = sw_model) |>
  add_recipe(recipe = sw_recipe)

Hyperparameter Tuning

We will tune the number of decision trees used in the random forest ensemble.

# Candidate forest sizes: 50, 100, 150 and 200 trees.
tree_grid <- seq(from = 50, to = 200, by = 50)
sw_grid <- tibble(trees = tree_grid)

# Score every candidate with 5-fold cross-validation on the training set.
sw_grid_results <- tune_grid(
  sw_workflow,
  resamples = vfold_cv(data_train, v = 5),
  grid = sw_grid
)

# Pick the simplest candidate (ordered by `trees`, fewest first) whose RMSE
# loss relative to the best candidate is within 5 percent.
hyperparameters <- select_by_pct_loss(
  sw_grid_results,
  trees,
  metric = "rmse",
  limit = 5
)

autoplot(sw_grid_results, metric = "rmse")

MLFlow

We will next demonstrate how to integrate Tidymodels with MLFlow.

Registering Models

We will first create a new model in the model registry, deleting any existing registered model with the same name.

client <- mlflow_client()

# Start from a clean registry entry by deleting any existing "sw_rf" model.
# Deletion is best-effort (e.g. the model does not exist on a first run), so
# a failure does not abort the notebook -- but it is reported instead of being
# silently discarded, so connection or permission problems remain visible.
tryCatch(
  mlflow_delete_registered_model("sw_rf", client = client),
  error = function(e) {
    message(
      "Could not delete registered model 'sw_rf': ",
      conditionMessage(e)
    )
  }
)
mlflow_create_registered_model(
  "sw_rf",
  client = client,
  description = "Perform predictions for Star Wars characters using Random Forest."
)
$name
[1] "sw_rf"

$creation_timestamp
[1] 1.668548e+12

$last_updated_timestamp
[1] 1.668548e+12

$description
[1] "Perform predictions for Star Wars characters using Random Forest."

We will next execute an MLFlow run.

MLFlow Run

Metric Tracking

We will log the metrics and parameters for the random forest run.

# See https://mdneuzerling.com/post/tracking-tidymodels-with-mlflow/

# Logs the main arguments of a workflow's model specification as MLflow
# parameters on `run`, skipping arguments whose expression is NULL (i.e. not
# set). Returns `workflow` unchanged so the function can sit inside a pipe.
#
# workflow: a workflows::workflow containing a parsnip model spec.
# client:   MLflow client used for logging.
# run:      MLflow run object; its `run_uuid` identifies the target run.
log_workflow_parameters <- function(workflow, client, run) {
  model_spec <- workflows::extract_spec_parsnip(workflow)
  # The spec stores its arguments as quosures; unwrap them to bare values.
  arg_exprs <- lapply(model_spec$args, rlang::get_expr)
  arg_exprs <- Filter(Negate(is.null), arg_exprs)
  for (arg_name in names(arg_exprs)) {
    mlflow_log_param(
      arg_name,
      arg_exprs[[arg_name]],
      client = client,
      run_id = run$run_uuid
    )
  }
  workflow
}

# Logs yardstick metrics to an MLflow run and returns `metrics` unchanged, so
# the function can sit at the end of a pipe.
#
# metrics:   a yardstick metrics tibble with `.metric`, `.estimator` and
#            `.estimate` columns (other columns are ignored).
# estimator: only rows whose `.estimator` equals this value are logged.
# client:    MLflow client used for logging.
# run:       MLflow run object; its `run_uuid` identifies the target run.
log_metrics <- function(metrics, estimator = "standard", client, run) {
  metrics |>
    # `.env$` makes it explicit that `estimator` is the function argument,
    # not a column of `metrics`.
    filter(.estimator == .env$estimator) |>
    # Keep only the columns the logger needs so that extra columns do not
    # break the row-wise call below.
    select(.metric, .estimate) |>
    pwalk(function(.metric, .estimate) {
      mlflow_log_metric(
        .metric,
        .estimate,
        client = client,
        run_id = run$run_uuid
      )
    })
  metrics
}

Next, we will start an MLFlow run and train the Tidymodels workflow within it.

# S3 location used as the experiment's artifact store.
s3_bucket <- "s3://mlflow/sw_rf"

# Begin the run: make "sw_rf" the active experiment, then start a run.
experiment <- mlflow_set_experiment(
  experiment_name = "sw_rf",
  artifact_location = s3_bucket
)
run <- mlflow_start_run(client = client)

# Fill in the tuned tree count, record the model arguments in MLflow, and
# fit the finalised workflow on the training set.
sw_rf <- sw_workflow |>
  finalize_workflow(parameters = hyperparameters) |>
  log_workflow_parameters(client = client, run = run) |>
  fit(data = data_train)

# Wrap the fitted workflow in a crated prediction function for MLflow.
# predict.workflow is reached with `:::` (NOTE(review): presumably because it
# is an unexported S3 method; calling it directly also loads the workflows
# namespace when the crate is used in a fresh session -- confirm).
packaged_sw_rf <- carrier::crate(
  function(x) workflows:::predict.workflow(sw_rf, x),
  sw_rf = sw_rf
)

# Evaluate on the held-out test set and log the metrics to the run.
# Predictions are bound to the test data so that `truth` and `estimate` both
# refer to columns of the data passed to the metric set, which is what
# yardstick's metric functions expect (an outer-environment vector such as
# `data_test$height` is not a column reference).
metrics <- sw_rf |>
  predict(new_data = data_test) |>
  bind_cols(data_test) |>
  metric_set(rmse, mae, rsq)(truth = height, estimate = .pred) |>
  log_metrics(client = client, run = run)

# Log each test-set prediction and the matching observed height as MLflow
# metric series. The step is the row position within data_test, so the two
# series can be compared point by point.
predict(sw_rf, new_data = data_test) |>
  pull(.pred) |>
  iwalk(\(value, row) {
    mlflow_log_metric(
      "prediction",
      value,
      step = as.numeric(row),
      client = client,
      run_id = run$run_uuid
    )
  })

data_test |>
  pull(height) |>
  iwalk(\(value, row) {
    mlflow_log_metric(
      "actual",
      value,
      step = row,
      client = client,
      run_id = run$run_uuid
    )
  })

# Save the crated model to a local directory with mlflow_save_model(), then
# upload that directory as artifacts of the current run.
crated_model <- "/tmp/sw_rf"
saved_model <- mlflow_save_model(packaged_sw_rf, path = crated_model)
logged_model <- mlflow_log_artifact(
  path = crated_model,
  client = client,
  run_id = run$run_uuid
)
2022/11/17 09:31:44 INFO mlflow.store.artifact.cli: Logged artifact from local dir /tmp/sw_rf to artifact_path=None
# Register the run's artifact root as a new version of the "sw_rf" model.
versioned_model <- mlflow_create_model_version(
  name = "sw_rf",
  source = run$artifact_uri,
  run_id = run$run_uuid,
  client = client
)

# Generate an exploratory DataExplorer report and attach it to the run.
# List columns are removed first (NOTE(review): presumably because the report
# expects atomic columns; currently a no-op for the height/mass data).
# The output location is defined once so the logged path always matches the
# file that create_report() writes.
report_dir <- "/tmp"
report_file <- "star_wars.html"
sw_report <- data |>
  select(where(\(column) !is.list(column))) |>
  create_report(
    output_file = report_file,
    output_dir = report_dir,
    report_title = "Star Wars Report",
    quiet = TRUE
  )
logged_report <- mlflow_log_artifact(
  file.path(report_dir, report_file),
  client = client,
  run_id = run$run_uuid
)
2022/11/17 09:31:49 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars.html to artifact_path=None
# Save a scatter plot of height against mass and attach it to the run.
sw_plot <- "/tmp/star_wars_characters.png"

# Writes a scatter plot of `x` against `y` to the PNG file `path`, passing
# `...` on to plot(). The PNG device is closed via on.exit() even if plotting
# fails, so an error cannot leave the device open and capture later graphics.
# Returns the value of dev.off() on success.
write_scatter_png <- function(path, x, y, ...) {
  png(filename = path)
  device_open <- TRUE
  on.exit(if (device_open) dev.off(), add = TRUE)
  plot(x, y, ...)
  device_open <- FALSE
  dev.off()
}

doff <- write_scatter_png(
  sw_plot,
  data$height,
  data$mass,
  xlab = "height",
  ylab = "mass"
)
logged_plot <- mlflow_log_artifact(
  sw_plot,
  client = client,
  run_id = run$run_uuid
)
2022/11/17 09:31:50 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.png to artifact_path=None
# Export the modelling data as CSV and attach it to the run.
data_csv <- "/tmp/star_wars_characters.csv"
write_csv(x = data, file = data_csv)
logged_csv <- mlflow_log_artifact(
  data_csv,
  run_id = run$run_uuid,
  client = client
)
2022/11/17 09:31:51 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.csv to artifact_path=None
# Close the MLflow run (using mlflow_end_run()'s default status).
run_end <- mlflow_end_run(client = client, run_id = run$run_uuid)

Session Information

R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] ranger_0.14.1      rmarkdown_2.17     data.table_1.14.4  forcats_0.5.2     
 [5] stringr_1.4.1      readr_2.1.3        tidyverse_1.3.2    yardstick_1.1.0   
 [9] workflowsets_1.0.0 workflows_1.1.0    tune_1.0.1         tidyr_1.2.1       
[13] tibble_3.1.8       rsample_1.1.0      recipes_1.0.2      purrr_0.3.5       
[17] parsnip_1.0.2      modeldata_1.0.1    infer_1.0.3        ggplot2_3.3.6     
[21] dplyr_1.0.10       dials_1.0.0        scales_1.2.1       broom_1.0.1       
[25] tidymodels_1.0.0   mlflow_1.30.0      reticulate_1.26    knitr_1.40        
[29] DataExplorer_0.8.2 carrier_0.1.0     

loaded via a namespace (and not attached):
  [1] readxl_1.4.1        backports_1.4.1     plyr_1.8.7         
  [4] igraph_1.3.5        splines_4.2.1       listenv_0.8.0      
  [7] digest_0.6.30       foreach_1.5.2       htmltools_0.5.3    
 [10] fansi_1.0.3         magrittr_2.0.3      googlesheets4_1.0.1
 [13] tzdb_0.3.0          globals_0.16.1      modelr_0.1.9       
 [16] gower_1.0.0         vroom_1.6.0         askpass_1.1        
 [19] hardhat_1.2.0       colorspace_2.0-3    rvest_1.0.3        
 [22] rappdirs_0.3.3      haven_2.5.1         xfun_0.34          
 [25] crayon_1.5.2        jsonlite_1.8.3      zeallot_0.1.0      
 [28] survival_3.3-1      iterators_1.0.14    glue_1.6.2         
 [31] gtable_0.3.1        gargle_1.2.1        ipred_0.9-13       
 [34] future.apply_1.9.1  DBI_1.1.3           Rcpp_1.0.9         
 [37] bit_4.0.4           GPfit_1.0-8         lava_1.7.0         
 [40] prodlim_2019.11.13  htmlwidgets_1.5.4   httr_1.4.4         
 [43] ellipsis_0.3.2      pkgconfig_2.0.3     farver_2.1.1       
 [46] nnet_7.3-17         sass_0.4.2          dbplyr_2.2.1       
 [49] utf8_1.2.2          reshape2_1.4.4      tidyselect_1.2.0   
 [52] labeling_0.4.2      rlang_1.0.6         DiceDesign_1.9     
 [55] later_1.3.0         munsell_0.5.0       cellranger_1.1.0   
 [58] tools_4.2.1         cachem_1.0.6        cli_3.4.1          
 [61] generics_0.1.3      evaluate_0.17       fastmap_1.1.0      
 [64] yaml_2.3.6          bit64_4.0.5         processx_3.8.0     
 [67] fs_1.5.2            forge_0.2.0         future_1.28.0      
 [70] xml2_1.3.3          compiler_4.2.1      rstudioapi_0.14    
 [73] curl_4.3.3          png_0.1-7           reprex_2.0.2       
 [76] lhs_1.1.5           bslib_0.4.0         stringi_1.7.8      
 [79] highr_0.9           ps_1.7.2            lattice_0.20-45    
 [82] Matrix_1.4-1        vctrs_0.5.0         pillar_1.8.1       
 [85] lifecycle_1.0.3     networkD3_0.4       furrr_0.3.1        
 [88] jquerylib_0.1.4     ini_0.3.1           httpuv_1.6.6       
 [91] R6_2.5.1            promises_1.2.0.1    gridExtra_2.3      
 [94] parallelly_1.32.1   codetools_0.2-18    MASS_7.3-57        
 [97] assertthat_0.2.1    openssl_2.0.4       withr_2.5.0        
[100] swagger_3.33.1      parallel_4.2.1      hms_1.1.2          
[103] grid_4.2.1          rpart_4.1.16        timeDate_4021.106  
[106] class_7.3-20        googledrive_2.0.0   lubridate_1.8.0    
[109] base64enc_0.1-3