MLFlow Linear Regression

Star Wars Dataset

We will demonstrate how to fit a linear regression model (predicting a character's height from their mass) to the Star Wars dataset.

Subsequently, we will track the experiment and save the model to an MLFlow server. We use the carrier package to bundle the fitted model together with its prediction function into a self-contained "crate", so that it can be serialized (written from memory to a file).

# Constants

# URL of the MLflow tracking server. A non-empty MLFLOW_TRACKING_URI
# environment variable (the same variable exported in the model-serving
# example further down) takes precedence; otherwise fall back to the default
# in-cluster address.
MLFLOW_URL = Sys.getenv("MLFLOW_TRACKING_URI", unset = "")
if (!nzchar(MLFLOW_URL)) {
  MLFLOW_URL = "http://mlflow:5000"
}

# Imports

library(carrier)       # carrier::crate() packages the model further down
library(DataExplorer)  # create_report() builds the HTML data report
library(knitr)
library(reticulate)
# Select the conda environment for MLflow (named for MLflow 1.30.0).
# NOTE(review): this presumably must run before anything initialises Python
# through reticulate, so keep it ahead of library(mlflow) — confirm.
use_condaenv("r-mlflow-1.30.0")
library(mlflow)
# Send all subsequent tracking/registry calls to the server in MLFLOW_URL.
mlflow::mlflow_set_tracking_uri(MLFLOW_URL)
library(tidyverse)     # dplyr/tidyr/purrr/readr verbs used throughout

# Load data

# Load the Star Wars characters and replace missing values in every numeric
# column (height, mass, birth_year) with 0. List and character columns are
# left untouched.
# NOTE(review): characters whose height or mass is unknown therefore enter the
# height ~ mass regression below with a value of 0, which can distort the fit;
# dropping or imputing those rows may be preferable — confirm intent.
data = dplyr::starwars |>
  mutate(across(where(is.numeric), \(x) replace_na(x, 0)))

# Preview the first rows of the cleaned data.
head(data)
# A tibble: 6 × 14
  name         height  mass hair_…¹ skin_…² eye_c…³ birth…⁴ sex   gender homew…⁵
  <chr>         <int> <dbl> <chr>   <chr>   <chr>     <dbl> <chr> <chr>  <chr>  
1 Luke Skywal…    172    77 blond   fair    blue       19   male  mascu… Tatooi…
2 C-3PO           167    75 <NA>    gold    yellow    112   none  mascu… Tatooi…
3 R2-D2            96    32 <NA>    white,… red        33   none  mascu… Naboo  
4 Darth Vader     202   136 none    white   yellow     41.9 male  mascu… Tatooi…
5 Leia Organa     150    49 brown   light   brown      19   fema… femin… Aldera…
6 Owen Lars       178   120 brown,… light   blue       52   male  mascu… Tatooi…
# … with 4 more variables: species <chr>, films <list>, vehicles <list>,
#   starships <list>, and abbreviated variable names ¹​hair_color, ²​skin_color,
#   ³​eye_color, ⁴​birth_year, ⁵​homeworld

Minio

Minio is an S3-compatible object store used to keep files in a central location; in this setup it holds the artifacts that MLFlow logs.

We need a Minio bucket named `mlflow`; the code below creates it if it does not already exist. All of our artifacts will be saved into this bucket.


from minio import Minio
import json
import os
from urllib.parse import urlparse

# MLFLOW_S3_ENDPOINT_URL is a full URL (e.g. "http://<host>:<port>"), but
# Minio() expects a bare "host[:port]" endpoint plus a separate TLS flag.
# Parse the URL instead of splitting on "//" so that a trailing path/slash is
# dropped, a value without a scheme still works, and https endpoints use TLS.
_s3_endpoint = urlparse(os.environ['MLFLOW_S3_ENDPOINT_URL'])
if not _s3_endpoint.netloc:
  # No "scheme://" prefix: treat the whole value as host[:port] (plain HTTP).
  _s3_endpoint = urlparse('//' + os.environ['MLFLOW_S3_ENDPOINT_URL'])

minioClient = Minio(
  _s3_endpoint.netloc,
  access_key=os.environ['AWS_ACCESS_KEY_ID'],
  secret_key=os.environ['AWS_SECRET_ACCESS_KEY'],
  secure=_s3_endpoint.scheme == 'https'
)

# All MLflow artifacts live in the "mlflow" bucket; create it on first use.
if not minioClient.bucket_exists('mlflow'):
  minioClient.make_bucket('mlflow')

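Next, we set the bucket policy. Note that this policy allows any client, including anonymous (unauthenticated) ones, to list, read and write objects in the bucket. That is convenient for a local demo but should not be used in production.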


# S3 bucket policy for the "mlflow" bucket. Every statement is an Allow for
# any principal ("*"): bucket location and listing on the bucket itself, and
# object reads and writes on everything inside it.
# NOTE(review): granting s3:PutObject to "*" permits unauthenticated writes to
# the bucket — fine for a local demo only; confirm this is intended.
def _allow_anyone(action, resource):
  """Build one Allow statement granting `action` on `resource` to everyone."""
  return {
    "Sid": "",
    "Effect": "Allow",
    "Principal": {"AWS": "*"},
    "Action": action,
    "Resource": resource,
  }

_bucket_arn = "arn:aws:s3:::mlflow"
_objects_arn = _bucket_arn + "/*"

policy = {
  "Version": "2012-10-17",
  "Statement": [
    _allow_anyone("s3:GetBucketLocation", _bucket_arn),
    _allow_anyone("s3:ListBucket", _bucket_arn),
    _allow_anyone("s3:GetObject", _objects_arn),
    _allow_anyone("s3:PutObject", _objects_arn),
  ],
}

# Apply the policy defined above to the "mlflow" bucket, serialised as JSON.
minioClient.set_bucket_policy('mlflow', json.dumps(policy))

MLFlow

Registering Models

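We will first create a new model named `sw_lm` in the model registry. Any existing model with that name is deleted first, so every execution starts from a fresh registry entry.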

# Recreate the "sw_lm" registered model from scratch: delete any existing
# model of that name, then create it again (the model is later loaded back as
# "models:/sw_lm/1").
client = mlflow_client()
invisible(tryCatch(
  mlflow_delete_registered_model("sw_lm", client = client),
  error = function(e) {
    # Best effort: deletion is expected to fail when the model does not exist
    # yet (e.g. on a first run). Report the reason instead of discarding it so
    # that connection or authentication problems are not hidden.
    message("Could not delete registered model 'sw_lm': ", conditionMessage(e))
  }
))
mlflow_create_registered_model(
  "sw_lm",
  client = client,
  description = "Perform predictions for Star Wars characters using linear regression."
)
$name
[1] "sw_lm"

$creation_timestamp
[1] 1.668548e+12

$last_updated_timestamp
[1] 1.668548e+12

$description
[1] "Perform predictions for Star Wars characters using linear regression."

We will next execute an MLFlow run.

MLFlow Run

# Artifacts of the "sw_lm" experiment are stored under this S3 (Minio) prefix.
s3_bucket = "s3://mlflow/sw_lm"

# Make "sw_lm" the active experiment (artifact_location only takes effect when
# the experiment is created), then open a new run through the explicit client.
experiment = mlflow_set_experiment(
  artifact_location = s3_bucket,
  experiment_name = "sw_lm"
)
run = mlflow_start_run(client = client)

# Fit a simple linear regression of height on mass.
sw_lm = lm(formula = height ~ mass, data = data)

# Bundle a prediction function with the fitted model it needs, using carrier,
# so the resulting crate is self-contained and can be serialized by MLflow.
packaged_sw_lm = carrier::crate(
  function(x) stats::predict.lm(sw_lm, newdata = x),
  sw_lm = sw_lm
)

# Record the fitted coefficients as run parameters and the in-sample mean
# squared error of the residuals as a run metric.
sw_coefs = sw_lm$coefficients
sw_mse = mean(sw_lm$residuals^2)
mlflow_log_param("Intercept", sw_coefs["(Intercept)"],
                 client = client, run_id = run$run_uuid)
mlflow_log_param("mass", sw_coefs["mass"],
                 client = client, run_id = run$run_uuid)
mlflow_log_metric("MSE", sw_mse, client = client, run_id = run$run_uuid)

# Log one metric point per character for the in-sample predictions and for the
# observed heights. predict() returns a vector named by model-frame row, so
# those names (converted to numbers) are used as the step. The heights are
# unnamed and use their position instead. Because missing numeric values were
# replaced with 0 when loading, lm() drops no rows and both series run 1..n.
fitted_heights = predict(sw_lm)
walk2(
  fitted_heights, as.numeric(names(fitted_heights)),
  \(value, step) mlflow_log_metric("prediction", value, step = step,
                                   client = client, run_id = run$run_uuid)
)
walk2(
  data$height, seq_along(data$height),
  \(value, step) mlflow_log_metric("actual", value, step = step,
                                   client = client, run_id = run$run_uuid)
)

# Write the crate to a local MLflow model directory, then upload that
# directory's contents as artifacts of the current run. (Registering it as a
# model version is a separate call.)
crated_model = "/tmp/sw_lm"
saved_model = mlflow_save_model(model = packaged_sw_lm, path = crated_model)
logged_model = mlflow_log_artifact(
  path = crated_model,
  run_id = run$run_uuid,
  client = client
)
2022/11/15 21:36:42 INFO mlflow.store.artifact.cli: Logged artifact from local dir /tmp/sw_lm to artifact_path=None
# Register the run's artifact root (where the model directory was uploaded) as
# a new version of the "sw_lm" registered model.
versioned_model = mlflow_create_model_version(
  name = "sw_lm",
  source = run$artifact_uri,
  run_id = run$run_uuid,
  client = client
)

# Generate a DataExplorer HTML report of the characters and log it as a run
# artifact. List columns (films, vehicles, starships) are dropped first,
# presumably because create_report() cannot summarise them.
# The output location is defined once so the report path and the logged path
# cannot drift apart.
report_dir = "/tmp"
report_file = "star_wars.html"
sw_report = data |>
  select(!where(is.list)) |>
  create_report(
    output_file = report_file,
    output_dir = report_dir,
    report_title = "Star Wars Report",
    quiet = TRUE
  )
logged_report = mlflow_log_artifact(
  file.path(report_dir, report_file),
  client = client,
  run_id = run$run_uuid
)
2022/11/15 21:36:54 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars.html to artifact_path=None
# Draw a scatter plot of the raw data (height on x, mass on y) to a PNG file
# and log the image as a run artifact.
sw_plot = "/tmp/star_wars_characters.png"
png(sw_plot)
plot(x = data$height, y = data$mass)
doff = dev.off()  # close the PNG device so the file is complete on disk
logged_plot = mlflow_log_artifact(sw_plot, run_id = run$run_uuid, client = client)
2022/11/15 21:36:55 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.png to artifact_path=None
# Export the NA-filled tibble as CSV and log the file as a run artifact.
# NOTE(review): unlike the report step, `data` here still contains the list
# columns (films, vehicles, starships); check how they are written to the CSV,
# or drop them first.
data_csv = "/tmp/star_wars_characters.csv"
data |> write_csv(file = data_csv)
logged_csv = mlflow_log_artifact(data_csv, run_id = run$run_uuid, client = client)
2022/11/15 21:36:56 INFO mlflow.store.artifact.cli: Logged artifact from local file /tmp/star_wars_characters.csv to artifact_path=None
# Close the run on the tracking server (using mlflow_end_run's default status).
run_end = mlflow_end_run(client = client, run_id = run$run_uuid)

Loading and Serving Models

Next, we will load the model from the registry.

# Show the in-memory crate. It is then removed from the R environment so that
# the model used afterwards is the copy loaded back from the registry.
packaged_sw_lm |> print()
<crate> 39.87 kB
* function: 26.64 kB
* `sw_lm`: 13.19 kB
function(x) stats::predict.lm(sw_lm)
# Remove the in-memory model so the copy used next comes from the registry.
rm(packaged_sw_lm)

# Load version 1 of the registered "sw_lm" model back into the session and
# print it for comparison with the crate shown before it was removed.
packaged_sw_lm = mlflow_load_model(model_uri = "models:/sw_lm/1")
packaged_sw_lm |> print()
<crate> 43.86 kB
* function: 26.64 kB
* `sw_lm`: 17.18 kB
function(x) stats::predict.lm(sw_lm)

Finally, we will demonstrate how to deploy the model using a model-as-a-service approach. The mlflow::mlflow_rfunc_serve function could be used for this; instead, we will launch the model server from bash.


# Point the MLflow CLI at the tracking server; the "models:/sw_lm/1" URI below
# refers to version 1 of the model registered on that server.
export MLFLOW_TRACKING_URI=http://mlflow:5000

# Serve the model on all interfaces, port 9000; once running, predictions are
# requested from http://0.0.0.0:9000/predict
mlflow models serve -m "models:/sw_lm/1" -h 0.0.0.0 -p 9000 

You can also run the following command to deploy the model in a Docker container: docker compose restart rstudio_mlflow_serve_lm

Grafana

Next, access the Grafana home page. This application will allow you to build your own dashboards.

Go to Configuration -> Data sources -> Add data source.

Grafana Configuration Page

Select PostgreSQL as the data source. Enter the following values into the web form.

  • Host: postgres:5432

  • Database: docker_r_mlops

  • User: user (Default)

  • Password: pass (Default)

  • TLS/SSL Mode: Disable

Grafana Data Source Web Form

Click Save & test.

Next, go to Create -> Dashboard -> Add a new panel. Create the following query.

  • Table (FROM): metrics (the MLFlow metrics table inside the docker_r_mlops database)

  • Time column: step

  • Select: Column:value

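  • Where: remove the $__unixEpochFilter macro (the step column holds row numbers rather than Unix timestamps, so a time-range filter would hide the data)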

Click Zoom to data. Your dashboard should look like the screenshot below. Click Apply to save the dashboard.

Grafana Dashboard

View the dashboard again. Click Table view.

Table View

Congratulations. Feel free to add new panels and experiment with the various plot types supported by Grafana.

Session Information

R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] forcats_0.5.2      stringr_1.4.1      dplyr_1.0.10       purrr_0.3.5       
 [5] readr_2.1.3        tidyr_1.2.1        tibble_3.1.8       ggplot2_3.3.6     
 [9] tidyverse_1.3.2    mlflow_1.30.0      reticulate_1.26    knitr_1.40        
[13] DataExplorer_0.8.2 carrier_0.1.0     

loaded via a namespace (and not attached):
 [1] httr_1.4.4          jsonlite_1.8.3      modelr_0.1.9       
 [4] assertthat_0.2.1    askpass_1.1         googlesheets4_1.0.1
 [7] cellranger_1.1.0    yaml_2.3.6          pillar_1.8.1       
[10] backports_1.4.1     lattice_0.20-45     glue_1.6.2         
[13] digest_0.6.30       promises_1.2.0.1    rvest_1.0.3        
[16] colorspace_2.0-3    htmltools_0.5.3     httpuv_1.6.6       
[19] Matrix_1.4-1        pkgconfig_2.0.3     broom_1.0.1        
[22] haven_2.5.1         scales_1.2.1        processx_3.8.0     
[25] later_1.3.0         tzdb_0.3.0          openssl_2.0.4      
[28] googledrive_2.0.0   generics_0.1.3      ellipsis_0.3.2     
[31] swagger_3.33.1      withr_2.5.0         cli_3.4.1          
[34] crayon_1.5.2        readxl_1.4.1        magrittr_2.0.3     
[37] evaluate_0.17       ps_1.7.2            fs_1.5.2           
[40] fansi_1.0.3         xml2_1.3.3          tools_4.2.1        
[43] data.table_1.14.4   hms_1.1.2           gargle_1.2.1       
[46] lifecycle_1.0.3     reprex_2.0.2        munsell_0.5.0      
[49] networkD3_0.4       compiler_4.2.1      forge_0.2.0        
[52] rlang_1.0.6         grid_4.2.1          rstudioapi_0.14    
[55] rappdirs_0.3.3      htmlwidgets_1.5.4   igraph_1.3.5       
[58] base64enc_0.1-3     rmarkdown_2.17      gtable_0.3.1       
[61] DBI_1.1.3           R6_2.5.1            ini_0.3.1          
[64] gridExtra_2.3       lubridate_1.8.0     fastmap_1.1.0      
[67] utf8_1.2.2          zeallot_0.1.0       stringi_1.7.8      
[70] parallel_4.2.1      Rcpp_1.0.9          vctrs_0.5.0        
[73] png_0.1-7           dbplyr_2.2.1        tidyselect_1.2.0   
[76] xfun_0.34