Source code for ax.benchmark.problems.hpo.torchvision

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from ax.benchmark.problems.hpo.pytorch_cnn import (
    PyTorchCNNBenchmarkProblem,
    PyTorchCNNRunner,
)
from ax.exceptions.core import UserInputError
from ax.utils.common.typeutils import checked_cast

# TorchVision is an optional dependency: when it is absent, the module must
# still import cleanly, so placeholders are bound and the registry is empty.
try:
    from torchvision import datasets, transforms
except ModuleNotFoundError:
    transforms = None
    datasets = None
    _REGISTRY = {}
else:
    # Maps a user-facing dataset name to the TorchVision dataset class that
    # is instantiated to load it.
    _REGISTRY = {"MNIST": datasets.MNIST, "FashionMNIST": datasets.FashionMNIST}


class PyTorchCNNTorchvisionBenchmarkProblem(PyTorchCNNBenchmarkProblem):
    """PyTorch CNN HPO benchmark problem whose data comes from TorchVision.

    Instances are built by dataset name through ``from_dataset_name``, which
    looks the name up in the module-level ``_REGISTRY``.
    """

    @classmethod
    def from_dataset_name(
        cls,
        name: str,
        num_trials: int,
        infer_noise: bool = True,
    ) -> "PyTorchCNNTorchvisionBenchmarkProblem":
        """Build a benchmark problem for the TorchVision dataset ``name``.

        Both the train and test splits are loaded into ``./data``, with
        ``download=True`` passed to the dataset constructor. The base problem
        is built with ``from_datasets``. It is then re-wrapped with a
        ``PyTorchCNNTorchvisionRunner``, which serializes by dataset name.

        Args:
            name: Registry key of the dataset, e.g. ``"MNIST"``.
            num_trials: Forwarded to ``from_datasets`` and to the constructor.
            infer_noise: Forwarded to ``from_datasets`` and to the constructor.

        Returns:
            A problem named ``"HPO_PyTorchCNN_Torchvision::<name>"`` that reuses
            the search space, optimization config and optimal value of the
            problem built by ``from_datasets``.

        Raises:
            UserInputError: If TorchVision is not installed, or if ``name`` is
                not a registered dataset.
        """
        # The import guard binds ``transforms`` to None (and leaves the
        # registry empty) when TorchVision is missing; report that cause
        # directly instead of an "unrecognized dataset" error.
        if transforms is None:
            raise UserInputError(
                f"Cannot load torchvision dataset {name}: TorchVision is not "
                "installed. Please install `torchvision` to use "
                "PyTorchCNNTorchvisionBenchmarkProblem."
            )
        if name not in _REGISTRY:
            raise UserInputError(
                f"Unrecognized torchvision dataset {name}. Please ensure it is "
                "listed in the PyTorchCNNTorchvisionBenchmarkProblem registry "
                f"(available: {sorted(_REGISTRY)})."
            )

        dataset_fn = _REGISTRY[name]
        train_set = dataset_fn(
            root="./data",
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        test_set = dataset_fn(
            root="./data",
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )

        problem = cls.from_datasets(
            name=name,
            num_trials=num_trials,
            train_set=train_set,
            test_set=test_set,
            infer_noise=infer_noise,
        )
        # Swap in the TorchVision-specific runner so that only the dataset
        # name needs to be stored when the runner is serialized.
        runner = PyTorchCNNTorchvisionRunner(
            name=name, train_set=train_set, test_set=test_set
        )

        return cls(
            name=f"HPO_PyTorchCNN_Torchvision::{name}",
            search_space=problem.search_space,
            optimization_config=problem.optimization_config,
            runner=runner,
            num_trials=num_trials,
            infer_noise=infer_noise,
            optimal_value=problem.optimal_value,
        )
class PyTorchCNNTorchvisionRunner(PyTorchCNNRunner):
    """Runner variant whose serialized form is just the dataset name.

    Rather than persisting the train/test data, only ``name`` is stored. The
    splits are re-created from TorchVision, via the module-level
    ``_REGISTRY``, when the runner is deserialized.
    """

    @classmethod
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        """Return the init kwargs to persist: only the runner's ``name``.

        ``obj`` is checked to be a ``PyTorchCNNRunner`` via ``checked_cast``.
        """
        runner = checked_cast(PyTorchCNNRunner, obj)
        return dict(name=runner.name)

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        """Rebuild init kwargs by reloading both splits of the named dataset.

        The train split is loaded first, then the test split. Both go into
        ``./data``, with ``download=True`` passed to the dataset constructor.
        A ``KeyError`` propagates if ``args["name"]`` is not in the registry.
        That includes the case where TorchVision is unavailable and the
        registry is empty.
        """
        dataset_name = args["name"]
        dataset_cls = _REGISTRY[dataset_name]
        # Unpacking the generator evaluates the train split before the test
        # split, matching the order the splits are loaded in elsewhere.
        train_split, test_split = (
            dataset_cls(
                root="./data",
                train=is_train,
                download=True,
                transform=transforms.ToTensor(),
            )
            for is_train in (True, False)
        )
        return {
            "name": dataset_name,
            "train_set": train_split,
            "test_set": test_split,
        }