---
title: "Saving and Using Saved Models"
date: "`r format(Sys.Date(), '%Y-%m-%d')`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 2
    fig_width: 7
    fig_height: 5
    dpi: 600
vignette: >
  %\VignetteIndexEntry{Saving and Using Saved Models}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

This vignette shows how to save a trained CISS-VAE model, reload it in a later R session, and use it to impute new data.

As of rCISSVAE version 0.0.6, the recommended workflow is to save the model as a PyTorch `state_dict` together with an automatically generated architecture config file. This approach is more portable across Python and PyTorch versions, and more robust, than saving the full live Python object.

## Saving a trained CISS-VAE model

```{r}
#| eval: false
library(reticulate)
library(rCISSVAE)

# IMPORTANT: the Python environment must be active so that 'torch' can be imported.

# Train a model (`data` is your data.frame containing missing values)
res <- run_cissvae(data, return_model = TRUE)

# Save the trained model to disk.
# With method = "state_dict", two files are written:
#   - trained_vae.pt             (the model weights)
#   - trained_vae.pt.config.rds  (the model configuration/architecture needed
#                                 to rebuild the model when loading)
save_cissvae_model(res$model, "trained_vae.pt", method = "state_dict")
```

You can still save the full model object as a single `.pt` file by setting `method = "full"`.
## Loading a saved model and imputing data

```{.r}
library(rCISSVAE)
library(reticulate)

## Activate your Python environment
reticulate::use_virtualenv("cissvae_environment", required = TRUE)

## Load the saved model. With method = "state_dict", the model is rebuilt from
## the weights plus the saved config file (trained_vae.pt.config.rds).
model <- load_cissvae_model(
  file = "trained_vae.pt",
  method = "state_dict",  ## or "full" if the model was saved with method = "full"
  device = "cpu"
)

## Impute new data.
## Make sure missing values in `data` are coded as NA and that `clusters`
## holds a cluster label for each row of `data`.
## `val_proportion`, `categorical_column_map`, and `replacement_value` are not
## needed here because we are only imputing, not training.
imputed_df <- impute_with_cissvae(
  model_py = model,
  data = data,
  index_col = "index",
  columns_ignore = NULL,
  clusters = clusters,
  imputable_matrix = NULL,
  binary_feature_mask = NULL,
  batch_size = 4000L,
  seed = 42
)

# `imputed_df` is returned to R as a data.frame
```

If your dataset contains binary variables, define `binary_feature_mask` in your `impute_with_cissvae()` call. After imputation, convert the imputed probabilities for those variables to {0, 1} values using a threshold of your choice.
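The thresholding step can be done in base R once `impute_with_cissvae()` has returned. The sketch below is one way to do it, not part of the rCISSVAE API: the helper `threshold_binary()`, the toy `smoker` column, and the 0.5 cutoff are all hypothetical, and it assumes the binary columns in the returned data.frame hold probabilities in [0, 1].

```{r}
# Hypothetical helper (not an rCISSVAE function): convert imputed
# probabilities in the given binary columns to 0/1 integers.
threshold_binary <- function(df, binary_cols, threshold = 0.5) {
  for (col in binary_cols) {
    # Values at or above the cutoff become 1; values below become 0
    df[[col]] <- as.integer(df[[col]] >= threshold)
  }
  df
}

# Toy stand-in for the data.frame returned by impute_with_cissvae();
# the column names and values are illustrative only
imputed_df <- data.frame(
  index  = 1:4,
  age    = c(52.1, 61.3, 47.8, 70.2),
  smoker = c(0.91, 0.12, 0.50, 0.49)
)

imputed_df <- threshold_binary(imputed_df, binary_cols = "smoker", threshold = 0.5)
imputed_df$smoker
#> [1] 1 0 1 0
```

Observed values that are already 0 or 1 pass through unchanged for any cutoff greater than 0 and at most 1, so the helper can be applied to whole columns. Non-binary columns such as `age` are left untouched.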