#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from copy import deepcopy
from enum import Enum
from functools import lru_cache
from math import sqrt
from typing import Any
import numpy.typing as npt
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from ax.utils.common.typeutils import checked_cast
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier, MLPRegressor
class SklearnModelType(Enum):
    RF = "rf"
    NN = "nn"
class SklearnDataset(Enum):
    DIGITS = "digits"
    BOSTON = "boston"
    CANCER = "cancer"
@lru_cache(maxsize=8)
def _get_data(dataset: SklearnDataset) -> dict[str, npt.NDArray]:
    """Return an sklearn dataset, loading and caching it if necessary."""
    if dataset is SklearnDataset.DIGITS:
        return datasets.load_digits()
    elif dataset is SklearnDataset.BOSTON:
        # NOTE: ``load_boston`` was deprecated in scikit-learn 1.0 and removed
        # in 1.2, so this branch requires an older scikit-learn version.
        return datasets.load_boston()
    elif dataset is SklearnDataset.CANCER:
        return datasets.load_breast_cancer()
    else:
        raise NotImplementedError(
            f"{dataset.name} is not a currently supported dataset."
        )
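# Example (a sketch): repeated lookups hit the ``lru_cache`` above rather than
# reloading the dataset from disk:
#
#     data = _get_data(SklearnDataset.DIGITS)  # loads the dataset once
#     X, y = data["data"], data["target"]      # later calls return the cached object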
class SklearnMetric(Metric):
"""A metric that trains and evaluates an sklearn model.
The evaluation metric is the k-fold "score". The scoring function
depends on the model type and task type (e.g. classification/regression),
but higher scores are better.
See sklearn documentation for supported parameters.
    In addition, this metric supports tuning the hidden layer size
    (``hidden_layer_size``) and the number of hidden layers
    (``num_hidden_layers``) of an NN model.
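
    Example (a minimal sketch; attaching the metric to an experiment is
    hypothetical and depends on your Ax setup):

        >>> metric = SklearnMetric(
        ...     name="rf_digits_score",
        ...     model_type=SklearnModelType.RF,
        ...     dataset=SklearnDataset.DIGITS,
        ... )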
"""
def __init__(
self,
name: str,
lower_is_better: bool = False,
model_type: SklearnModelType = SklearnModelType.RF,
dataset: SklearnDataset = SklearnDataset.DIGITS,
observed_noise: bool = False,
num_folds: int = 5,
) -> None:
"""Initialize metric.
Args:
name: Name of the metric.
lower_is_better: Flag for metrics which should be minimized.
            model_type: The type of sklearn model to train.
            dataset: The sklearn dataset to use for training and evaluation.
observed_noise: A boolean indicating whether to return the SE
of the mean k-fold cross-validation score.
num_folds: The number of cross-validation folds.
"""
super().__init__(name=name, lower_is_better=lower_is_better)
self.dataset = dataset
self.model_type = model_type
self.num_folds = num_folds
self.observed_noise = observed_noise
if self.dataset is SklearnDataset.DIGITS:
regression = False
elif self.dataset is SklearnDataset.BOSTON:
regression = True
elif self.dataset is SklearnDataset.CANCER:
regression = False
        else:
            raise NotImplementedError(
                f"{self.dataset.name} is not a currently supported dataset."
            )
        if model_type is SklearnModelType.NN:
            model_cls = MLPRegressor if regression else MLPClassifier
        elif model_type is SklearnModelType.RF:
            model_cls = (
                RandomForestRegressor if regression else RandomForestClassifier
            )
        else:
            raise NotImplementedError(
                f"{model_type.name} is not a currently supported model type."
            )
        self._model_cls: type[Any] = model_cls
    def clone(self) -> SklearnMetric:
        return self.__class__(
            name=self._name,
            lower_is_better=checked_cast(bool, self.lower_is_better),
            model_type=self.model_type,
            dataset=self.dataset,
            observed_noise=self.observed_noise,
            num_folds=self.num_folds,
        )
def fetch_trial_data(
self, trial: BaseTrial, noisy: bool = True, **kwargs: Any
) -> MetricFetchResult:
try:
arm_names = []
means = []
sems = []
for name, arm in trial.arms_by_name.items():
arm_names.append(name)
# TODO: Consider parallelizing evaluation of large batches
# (e.g. via ProcessPoolExecutor)
mean, sem = self.train_eval(arm=arm)
means.append(mean)
sems.append(sem)
df = pd.DataFrame(
{
"arm_name": arm_names,
"metric_name": self._name,
"mean": means,
"sem": sems,
"trial_index": trial.index,
}
)
return Ok(value=Data(df=df))
except Exception as e:
return Err(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)
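    # Example (a sketch; ``trial`` stands for an Ax trial with named arms):
    #
    #     result = metric.fetch_trial_data(trial)
    #     df = result.unwrap().df  # columns: arm_name, metric_name, mean, sem, trial_index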
def train_eval(self, arm: Arm) -> tuple[float, float]:
"""Train and evaluate model.
Args:
arm: An arm specifying the parameters to evaluate.
        Returns:
            A two-element tuple containing:
                - The mean k-fold CV score.
                - The SEM of the mean k-fold CV score if ``observed_noise``
                  is True, and NaN otherwise.
"""
data = _get_data(self.dataset) # cached
X, y = data["data"], data["target"]
params: dict[str, Any] = deepcopy(arm.parameters)
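        # For NN models, translate the tuned scalar ``hidden_layer_size`` and
        # ``num_hidden_layers`` into sklearn's ``hidden_layer_sizes`` argument,
        # e.g. size=64 with 2 hidden layers becomes [64, 64].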
        if self.model_type is SklearnModelType.NN:
hidden_layer_size = params.pop("hidden_layer_size", None)
if hidden_layer_size is not None:
hidden_layer_size = checked_cast(int, hidden_layer_size)
num_hidden_layers = checked_cast(
int, params.pop("num_hidden_layers", 1)
)
params["hidden_layer_sizes"] = [hidden_layer_size] * num_hidden_layers
model = self._model_cls(**params)
cv_scores = cross_val_score(model, X, y, cv=self.num_folds)
mean = cv_scores.mean()
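        # SEM of the mean CV score: sample std of the fold scores divided by
        # sqrt(num_folds); reported as NaN when noise is treated as unobserved.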
sem = (
cv_scores.std() / sqrt(cv_scores.shape[0])
if self.observed_noise
else float("nan")
)
return mean, sem
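

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the Ax API): evaluate a small
    # random forest on the digits dataset via ``train_eval`` directly. The
    # parameter values below are illustrative, not tuned.
    metric = SklearnMetric(
        name="rf_digits_score",
        model_type=SklearnModelType.RF,
        dataset=SklearnDataset.DIGITS,
        observed_noise=True,
    )
    arm = Arm(name="0_0", parameters={"n_estimators": 10, "max_depth": 3})
    mean, sem = metric.train_eval(arm=arm)
    print(f"mean CV score: {mean:.4f} (SEM: {sem:.4f})")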