# Model hyperparameter tuning with scVI

```{warning}
`scvi.autotune` development is still in progress. The API is subject to change.
```

Finding an effective set of model hyperparameters (e.g., learning rate or number of hidden layers) is an important part of training a model, because its performance can depend heavily on these non-trainable parameters. Manually tuning a model usually involves picking a set of hyperparameters to search over and then evaluating different configurations on a validation set for a desired metric. This process can be time-consuming and often requires prior intuition about the model and dataset pair, which is not always available.
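
To make the cost of manual tuning concrete, here is a minimal sketch of exhaustive grid search in plain Python. `toy_validation_loss` is a hypothetical placeholder for "train a model and measure its validation metric"; it is not part of `scvi-tools`, and its formula is an arbitrary toy loss surface.

```python
import itertools


# Hypothetical placeholder for "train a model with these settings and score it
# on a validation set"; a real run would fit SCVI here. The formula is a toy
# loss surface chosen only so the loop has something to minimize.
def toy_validation_loss(n_hidden: int, n_layers: int) -> float:
    return abs(n_hidden - 128) / 64 + abs(n_layers - 2)


grid = {"n_hidden": [64, 128, 256], "n_layers": [1, 2, 3]}

# Exhaustive search: every combination must be trained and evaluated.
configs = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]
best = min(configs, key=lambda config: toy_validation_loss(**config))

print(len(configs), "configurations evaluated; best:", best)
```

Adding a third hyperparameter with three values would triple the number of training runs to 27, which is one reason the rest of this tutorial samples configurations from a search space instead of enumerating them.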

In this tutorial, we show how to use `scvi`'s [`autotune`](https://docs.scvi-tools.org/en/latest/api/user.html#model-hyperparameter-autotuning) module, which allows us to automatically find a good set of model hyperparameters using [Ray Tune](https://docs.ray.io/en/latest/tune/index.html). We will use `SCVI` and a subsample of the [heart cell atlas](https://www.heartcellatlas.org/#DataSources) for the task of batch correction, but the principles outlined here can be applied to any model and dataset. In particular, we will go through the following steps:

1. Installing required packages
1. Loading and preprocessing the dataset
1. Defining the tuner and discovering hyperparameters
1. Running the tuner
1. Comparing latent spaces
1. Optional: Monitoring progress with Tensorboard
1. Optional: Tuning over integration metrics with `scib-metrics`

## Installing required packages

```{note}
Running the following cell will install tutorial dependencies on Google Colab only. It will have no effect on environments other than Google Colab.
```

```python
!pip install --quiet scvi-colab
from scvi_colab import install

install()
```


```python
import tempfile

import ray
import scanpy as sc
import scvi
import seaborn as sns
import torch
from ray import tune
from scvi import autotune
```

```python
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
```

Seed set to 0


Last run with scvi-tools version: 1.1.0


```{note}
You can modify `save_dir` below to change where the data files for this tutorial are saved.
```

```python
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()
scvi.settings.logging_dir = save_dir.name

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"
```

## Loading and preprocessing the dataset

```python
adata = scvi.data.heart_cell_atlas_subsampled(save_path=save_dir.name)
adata
```



AnnData object with n_obs × n_vars = 18641 × 26662
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts'
    uns: 'cell_type_colors'

The only preprocessing step we perform here is subsetting the dataset to the 2,000 most highly variable genes with `scanpy`, which makes model training faster.

```python
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=True)
adata
```

AnnData object with n_obs × n_vars = 18641 × 2000
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'cell_type_colors', 'hvg'
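
The idea behind this subsetting step can be sketched in plain Python: score each gene's variability, keep the top `n_top_genes`, and preserve the original gene order. This is a deliberately simplified stand-in that ranks genes by raw per-gene variance. The `seurat_v3` flavor used above instead expects raw counts and ranks genes by a variance normalized against a fitted mean-variance trend (stored in `variances_norm`). `top_variable_genes` is a hypothetical helper, not a `scanpy` function.

```python
import statistics


def top_variable_genes(counts: list[list[float]], n_top: int) -> list[int]:
    """Return column indices of the n_top genes with the largest variance.

    Simplified stand-in for illustration only: the seurat_v3 flavor instead
    ranks genes by a variance normalized against a fitted mean-variance trend.
    """
    n_genes = len(counts[0])
    variances = [statistics.pvariance([cell[g] for cell in counts]) for g in range(n_genes)]
    ranked = sorted(range(n_genes), key=lambda g: variances[g], reverse=True)
    # Keep the original gene order among the selected genes, as subsetting does.
    return sorted(ranked[:n_top])


# Toy count matrix: 4 cells (rows) x 3 genes (columns).
counts = [[0, 5, 1], [0, 0, 1], [1, 9, 3], [0, 1, 1]]
print(top_variable_genes(counts, n_top=2))
```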

## Defining the tuner and discovering hyperparameters

The first part of our workflow is the same as the standard `scvi-tools` workflow: we start with our desired model class, and we register our dataset with it using `setup_anndata`. All datasets must be registered prior to hyperparameter tuning.

```python
model_cls = scvi.model.SCVI
model_cls.setup_anndata(adata)
```





Our main entry point to the `autotune` module is the `ModelTuner` class, a wrapper around [`ray.tune.Tuner`](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tuner) with additional functionality specific to `scvi-tools`. We can define a new `ModelTuner` by providing it with our model class.

```python
scvi_tuner = autotune.ModelTuner(model_cls)
```

`ModelTuner` registers all tunable hyperparameters in `SCVI`; these can be viewed by calling `info()`. By default, this method displays three tables:

1. **Tunable hyperparameters**: The names of hyperparameters that can be tuned, their default values, and the internal classes they are defined in.
1. **Available metrics**: The metrics that can be used to evaluate the performance of the model. One of these must be provided when running the tuner.
1. **Default search space**: The default search space for the model class, which will be used if no search space is provided by the user.

```python
scvi_tuner.info()
```

## Running the tuner

Now that we know which hyperparameters are available, we can define a search space using the [search space API](https://docs.ray.io/en/latest/tune/api/search_space.html) in `ray.tune`. For this tutorial, we choose a simple search space with two model hyperparameters (`n_hidden` and `n_layers`) and one training plan hyperparameter (`lr`). All three are combined into a single dictionary that we pass to the `fit` method.

```python
search_space = {
    "n_hidden": tune.choice([64, 128, 256]),
    "n_layers": tune.choice([1, 2, 3]),
    "lr": tune.loguniform(1e-4, 1e-2),
}
```
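
As a rough mental model of what this search space describes, the sketch below draws configurations with plain-Python stand-ins for `tune.choice` and `tune.loguniform`. The helpers `choice`, `loguniform`, and `sample_configs` are hypothetical and are not Ray's implementation; Ray's search-space objects also integrate with search algorithms, which this sketch ignores.

```python
import math
import random


# Hypothetical stand-ins for tune.choice and tune.loguniform. They only show
# how a single value is drawn for each hyperparameter.
def choice(options):
    return lambda rng: rng.choice(options)


def loguniform(low, high):
    # Uniform in log space: 1e-4..1e-3 and 1e-3..1e-2 are equally likely.
    return lambda rng: math.exp(rng.uniform(math.log(low), math.log(high)))


toy_space = {
    "n_hidden": choice([64, 128, 256]),
    "n_layers": choice([1, 2, 3]),
    "lr": loguniform(1e-4, 1e-2),
}


def sample_configs(space, num_samples, seed=0):
    """Draw num_samples independent configurations, one per trial."""
    rng = random.Random(seed)
    return [{name: draw(rng) for name, draw in space.items()} for _ in range(num_samples)]


for config in sample_configs(toy_space, num_samples=5):
    print(config)
```

With a plain uniform distribution on `[1e-4, 1e-2]`, roughly 91% of draws would land above `1e-3`; sampling in log space spreads trials evenly across orders of magnitude, which is why it is a common choice for learning rates.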

There are a few more arguments we should be aware of before fitting the tuner:

- `num_samples`: The total number of hyperparameter sets to sample from `search_space`. This is the total number of models that will be trained.

  For example, if we set `num_samples=2`, we might sample two models with the following hyperparameter configurations:

  ```python
  model1 = {
      "n_hidden": 64,
      "n_layers": 1,
      "lr": 0.001,
  }
  model2 = {
      "n_hidden": 64,
      "n_layers": 3,
      "lr": 0.0001,
  }
  ```

- `max_epochs`: The maximum number of epochs to train each model for.

  Note: This does not mean that each model will be trained for `max_epochs`. Depending on the scheduler used, some trials are likely to be stopped early.

- `resources`: A dictionary of maximum resources to allocate for the whole experiment. This allows us to run concurrent trials on limited hardware.
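
Early stopping is what keeps a sweep affordable when `max_epochs` is large. The sketch below illustrates successive halving, the idea behind Ray Tune's ASHA scheduler: train all trials briefly, keep only the best fraction, and give the survivors more epochs. It is a simplified synchronous version written for illustration, not Ray's `ASHAScheduler`; the learning curves, trial names, and the `grace_period` and `reduction_factor` values are hypothetical, and which scheduler and settings `ModelTuner` uses is not shown in this tutorial.

```python
# Simplified, synchronous successive halving. This illustrates the idea behind
# ASHA-style early stopping; it is not Ray's ASHAScheduler implementation.
def successive_halving(curves, max_epochs, grace_period=1, reduction_factor=2):
    """curves maps a trial name to a function epoch -> validation loss.

    Returns the number of epochs each trial was allowed to train.
    """
    alive = list(curves)
    budget = {}
    rung = grace_period
    while rung < max_epochs and len(alive) > 1:
        # Rank surviving trials at this rung; keep the best 1/reduction_factor.
        alive.sort(key=lambda name: curves[name](rung))
        keep = max(1, len(alive) // reduction_factor)
        for name in alive[keep:]:
            budget[name] = rung  # stopped early
        alive = alive[:keep]
        rung *= reduction_factor
    for name in alive:
        budget[name] = max_epochs  # survivors train to the end
    return budget


# Hypothetical learning curves (loss falls toward a trial-specific floor).
curves = {
    "a": lambda epoch: 400 + 400 / epoch,
    "b": lambda epoch: 450 + 320 / epoch,
    "c": lambda epoch: 700 + 150 / epoch,
    "d": lambda epoch: 750 + 200 / epoch,
}
budget = successive_halving(curves, max_epochs=100)
print(budget)
print(sum(budget.values()), "epochs in total instead of", len(curves) * 100)
```

Here only one trial trains for the full 100 epochs. Note that trial `b`, whose loss floor is the second lowest, is still stopped after two epochs: aggressive early stopping trades some risk of discarding a good configuration for a much smaller compute budget.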

Now, we can call `fit` on the tuner to start the hyperparameter sweep. This will return a `TuneAnalysis` dataclass, which will contain the best set of hyperparameters, as well as other information.

```python
ray.init(log_to_driver=False)
results = scvi_tuner.fit(
    adata,
    metric="validation_loss",
    search_space=search_space,
    num_samples=5,
    max_epochs=100,
    resources={"cpu": 10, "gpu": 1},
)
```

| Status | Value |
|---|---|
| Current time | 2024-02-12 22:37:57 |
| Running for | 00:02:21.96 |
| Memory | 8.0/125.7 GiB |

| Trial name | Status | Location | `n_hidden` | `n_layers` | `lr` | `validation_loss` |
|---|---|---|---|---|---|---|
| `_trainable_8f41734d` | TERMINATED | 172.29.0.2:1469 | 128 | 3 | 0.00148765 | 457.597 |
| `_trainable_f42dd38f` | TERMINATED | 172.29.0.2:1469 | 64 | 3 | 0.000455217 | 748.298 |
| `_trainable_a19ea641` | TERMINATED | 172.29.0.2:1469 | 256 | 1 | 0.000388237 | 525.478 |
| `_trainable_619aa78d` | TERMINATED | 172.29.0.2:1469 | 64 | 2 | 0.000326625 | 758.515 |
| `_trainable_681ab3ba` | TERMINATED | 172.29.0.2:1469 | 256 | 3 | 0.00392304 | 457.508 |
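
Because `validation_loss` is a loss, the best trial is the one with the lowest value. The sketch below applies that rule to the five rows of the trial table above (copied by hand, with `lr` as displayed, i.e. rounded); it picks `_trainable_681ab3ba`, which agrees with the hyperparameters reported by `results` below. This is plain Python over the printed numbers, not a call into the `TuneAnalysis` object.

```python
# Rows copied from the trial table above:
# (trial name, n_hidden, n_layers, lr, validation_loss)
trials = [
    ("_trainable_8f41734d", 128, 3, 0.00148765, 457.597),
    ("_trainable_f42dd38f", 64, 3, 0.000455217, 748.298),
    ("_trainable_a19ea641", 256, 1, 0.000388237, 525.478),
    ("_trainable_619aa78d", 64, 2, 0.000326625, 758.515),
    ("_trainable_681ab3ba", 256, 3, 0.00392304, 457.508),
]

# metric="validation_loss" is minimized, so the best trial has the lowest value.
ranked = sorted(trials, key=lambda row: row[-1])
name, n_hidden, n_layers, lr, loss = ranked[0]
print(name, {"n_hidden": n_hidden, "n_layers": n_layers, "lr": lr}, loss)

# Gap between the best and second-best trials (small here, under 0.1).
gap = ranked[1][-1] - ranked[0][-1]
print(round(gap, 3))
```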




```python
print(results.model_kwargs)
print(results.train_kwargs)
```

{'n_hidden': 256, 'n_layers': 3}
{'plan_kwargs': {'lr': 0.003923035750529307}}
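
The two dictionaries above separate hyperparameters by where they are consumed: `n_hidden` and `n_layers` go to the model constructor, while `lr` goes to the training plan under `plan_kwargs`. The sketch below reproduces that split from a flat configuration. `split_config` and `TRAINING_PLAN_KEYS` are hypothetical names for illustration, and the assumption that only `lr` is a training-plan argument holds for this search space only.

```python
# Hypothetical helper: route tuned hyperparameters either to the model
# constructor or to the training plan, mirroring the structure printed above.
TRAINING_PLAN_KEYS = {"lr"}  # assumption: only lr is a training-plan argument here


def split_config(config):
    model_kwargs = {k: v for k, v in config.items() if k not in TRAINING_PLAN_KEYS}
    plan_kwargs = {k: v for k, v in config.items() if k in TRAINING_PLAN_KEYS}
    return model_kwargs, {"plan_kwargs": plan_kwargs}


best_config = {"n_hidden": 256, "n_layers": 3, "lr": 0.003923035750529307}
model_kwargs, train_kwargs = split_config(best_config)
print(model_kwargs)
print(train_kwargs)
```

Shaped this way, the tuned values should be usable in the standard workflow, for example `scvi.model.SCVI(adata, **results.model_kwargs)` followed by `model.train(**results.train_kwargs)`; this retraining step is not run in this tutorial.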


## Comparing latent spaces

Work in progress: please check back in the next release!

## Optional: Monitoring progress with Tensorboard

Work in progress: please check back in the next release!

## Optional: Tuning over integration metrics with `scib-metrics`

Work in progress: please check back in the next release!