Source code for ax.benchmark.problems.hpo.pytorch_cnn

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Iterable, Set

import pandas as pd
import torch
from ax.benchmark.benchmark_problem import SingleObjectiveBenchmarkProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.result import Err, Ok
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


class PyTorchCNNBenchmarkProblem(SingleObjectiveBenchmarkProblem):
    """Hyperparameter-optimization benchmark that tunes SGD settings for a CNN.

    The problem maximizes the ``accuracy`` metric reported by
    ``PyTorchCNNMetric`` over five optimizer/scheduler hyperparameters.
    """

    @equality_typechecker
    def __eq__(self, other: Base) -> bool:
        """Compare two problems by name only.

        Comparing the underlying datasets element by element would be too
        expensive to be worthwhile, so the problem name stands in for them.
        """
        return (
            isinstance(other, PyTorchCNNBenchmarkProblem) and self.name == other.name
        )

    @classmethod
    def from_datasets(
        cls, name: str, num_trials: int, train_set: Dataset, test_set: Dataset
    ) -> "PyTorchCNNBenchmarkProblem":
        """Build a benchmark problem around a train/test dataset pair.

        Args:
            name: Dataset label; the problem is named ``HPO_PyTorchCNN_{name}``
                and the same label is passed to the runner.
            num_trials: Number of trials, forwarded to the problem constructor.
            train_set: Dataset handed to the runner for training.
            test_set: Dataset handed to the runner for evaluation.

        Returns:
            A problem whose search space covers ``lr``, ``momentum``,
            ``weight_decay``, ``step_size`` and ``gamma``, with ``accuracy``
            as the objective to maximize and an optimal value of 1.
        """

        def unit_float(param_name: str) -> RangeParameter:
            # momentum, weight_decay and gamma all share the same [0, 1] range.
            return RangeParameter(
                name=param_name,
                parameter_type=ParameterType.FLOAT,
                lower=0,
                upper=1,
            )

        # Order matters: it defines the parameter order of the search space.
        parameters = [
            RangeParameter(
                name="lr", parameter_type=ParameterType.FLOAT, lower=1e-6, upper=0.4
            ),
            unit_float("momentum"),
            unit_float("weight_decay"),
            RangeParameter(
                name="step_size",
                parameter_type=ParameterType.INT,
                lower=1,
                upper=100,
            ),
            unit_float("gamma"),
        ]
        search_space = SearchSpace(parameters=parameters)

        objective = Objective(metric=PyTorchCNNMetric(), minimize=False)
        optimization_config = OptimizationConfig(objective=objective)

        runner = PyTorchCNNRunner(name=name, train_set=train_set, test_set=test_set)

        return cls(
            name=f"HPO_PyTorchCNN_{name}",
            optimal_value=1,
            search_space=search_space,
            optimization_config=optimization_config,
            runner=runner,
            num_trials=num_trials,
        )
class PyTorchCNNMetric(Metric):
    """Metric named ``accuracy`` that reads per-arm results from run metadata."""

    def __init__(self) -> None:
        super().__init__(name="accuracy")

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Build the accuracy data for every arm of ``trial``.

        Values are looked up in ``trial.run_metadata["accuracy"]`` by arm name.

        Args:
            trial: Trial whose run metadata holds the per-arm accuracies.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``Ok`` wrapping a ``Data`` object with one row per arm (``sem`` is
            reported as 0), or ``Err`` wrapping a ``MetricFetchE`` if anything
            goes wrong while reading the metadata or building the frame.
        """
        try:
            # Iterate the mapping's keys once and reuse them for both columns
            # so arm names and accuracies are guaranteed to line up.
            arm_names = list(trial.arms_by_name)
            accuracy = [trial.run_metadata["accuracy"][name] for name in arm_names]

            df = pd.DataFrame(
                {
                    "arm_name": arm_names,
                    "metric_name": self.name,
                    "mean": accuracy,
                    "sem": 0,
                    "trial_index": trial.index,
                }
            )

            return Ok(value=Data(df=df))

        except Exception as e:
            # Deliberately broad: any failure is reported as a fetch error
            # result rather than raised to the caller.
            return Err(
                value=MetricFetchE(
                    message=f"Failed to fetch {self.name} for trial {trial}",
                    exception=e,
                )
            )
class PyTorchCNNRunner(Runner):
    """Runner that trains and evaluates a small CNN once per arm of a trial.

    Training and evaluation happen synchronously inside ``run``; the resulting
    accuracies are returned as run metadata under the ``"accuracy"`` key.
    """

    def __init__(self, name: str, train_set: Dataset, test_set: Dataset) -> None:
        """Wrap the datasets in default ``DataLoader``s and pick a device.

        Args:
            name: Label for this runner.
            train_set: Dataset used for training.
            test_set: Dataset used for evaluation.
        """
        self.name = name
        # pyre-fixme[4]: Attribute must be annotated.
        self.train_loader = DataLoader(train_set)
        # pyre-fixme[4]: Attribute must be annotated.
        self.test_loader = DataLoader(test_set)
        self.results: Dict[int, float] = {}
        self.statuses: Dict[int, TrialStatus] = {}
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class CNN(nn.Module):
        """One conv layer followed by two fully connected layers (10 outputs).

        The flattened size ``8 * 8 * 20`` corresponds to a single-channel
        28x28 input: the 5x5 convolution yields 24x24, and 3x3 max pooling
        with stride 3 reduces that to 8x8 across 20 channels.
        """

        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
            self.fc1 = nn.Linear(8 * 8 * 20, 64)
            self.fc2 = nn.Linear(64, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return per-class log-probabilities for the batch ``x``."""
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 3, 3)
            # Flatten everything but the batch dimension for the linear layers.
            x = x.view(-1, 8 * 8 * 20)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=-1)

    def train_and_evaluate(
        self,
        lr: float,
        momentum: float,
        weight_decay: float,
        step_size: int,
        gamma: float,
    ) -> float:
        """Train a fresh CNN for one pass over the training data and score it.

        Args:
            lr: SGD learning rate.
            momentum: SGD momentum.
            weight_decay: SGD weight decay.
            step_size: ``StepLR`` period. The scheduler is stepped after every
                batch, so this counts batches rather than epochs.
            gamma: ``StepLR`` multiplicative decay factor.

        Returns:
            Fraction of test examples classified correctly.

        Raises:
            ZeroDivisionError: If the test loader yields no examples.
        """
        net = self.CNN()
        net.to(device=self.device)

        # Train
        net.train()
        criterion = nn.NLLLoss(reduction="sum")
        optimizer = optim.SGD(
            net.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=step_size, gamma=gamma
        )

        for inputs, labels in self.train_loader:
            inputs = inputs.to(device=self.device)
            labels = labels.to(device=self.device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Evaluate
        net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                # The model lives on self.device, so the evaluation batches
                # must be moved there too (as in the training loop); otherwise
                # the forward pass fails with a device mismatch on CUDA.
                inputs = inputs.to(device=self.device)
                labels = labels.to(device=self.device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Train and evaluate one model per arm of ``trial``.

        Args:
            trial: Trial whose arms supply the ``lr``, ``momentum``,
                ``weight_decay``, ``step_size`` and ``gamma`` parameters.

        Returns:
            Run metadata of the form ``{"accuracy": {arm_name: accuracy}}``.
        """
        self.statuses[trial.index] = TrialStatus.RUNNING

        accuracy_by_arm = {
            arm.name: self.train_and_evaluate(
                lr=arm.parameters["lr"],  # pyre-ignore[6]
                momentum=arm.parameters["momentum"],  # pyre-ignore[6]
                weight_decay=arm.parameters["weight_decay"],  # pyre-ignore[6]
                step_size=arm.parameters["step_size"],  # pyre-ignore[6]
                gamma=arm.parameters["gamma"],  # pyre-ignore[6]
            )
            for arm in trial.arms
        }

        # Only record completion once every arm has actually been evaluated.
        self.statuses[trial.index] = TrialStatus.COMPLETED

        return {"accuracy": accuracy_by_arm}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report every given trial as completed.

        ``run`` is synchronous, so any trial that has been run is finished by
        the time it can be polled.
        """
        return {TrialStatus.COMPLETED: {t.index for t in trials}}