#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Dict, List, Optional, TYPE_CHECKING
import numpy as np
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization, TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.rounding import (
randomized_onehot_round,
strict_onehot_round,
)
from ax.modelbridge.transforms.utils import construct_new_search_space
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
OH_PARAM_INFIX = "_OH_PARAM_"
class OneHotEncoder:
    """OneHot encodes a list of labels.

    Holds the labels that were supplied and the number of encoded columns
    they require.
    """

    def __init__(self, values: List[TParamValue]) -> None:
        """Store the labels and derive the width of their encoding.

        Args:
            values: Labels to encode. At least two are expected; this is
                enforced with an ``assert``.
        """
        label_count = len(values)
        assert label_count >= 2
        # The list is stored as given (not copied).
        self.values: List[TParamValue] = values
        # Exactly two labels need a single column; any other count needs one
        # column per label.
        self.encoded_len: int = label_count if label_count != 2 else 1
[docs]class OneHot(Transform):
"""Convert categorical parameters (unordered ChoiceParameters) to
one-hot-encoded parameters.
Does not convert task parameters.
Parameters will be one-hot-encoded, yielding a set of RangeParameters,
of type float, on [0, 1]. If there are two values, one single RangeParameter
will be yielded, otherwise there will be a new RangeParameter for each
ChoiceParameter value.
In the reverse transform, floats can be converted to a one-hot encoded vector
using one of two methods:
Strict rounding: Choose the maximum value. With levels ['a', 'b', 'c'] and
float values [0.2, 0.4, 0.3], the restored parameter would be set to 'b'.
Ties are broken randomly, so values [0.2, 0.4, 0.4] are randomly set to 'b'
or 'c'.
Randomized rounding: Sample from the distribution. Float values
[0.2, 0.4, 0.3] are transformed to 'a' w.p.
0.2/0.9, 'b' w.p. 0.4/0.9, or 'c' w.p. 0.3/0.9.
Type of rounding can be set using transform_config['rounding'] to either
'strict' or 'randomized'. Defaults to strict.
Transform is done in-place.
"""
def __init__(
    self,
    search_space: Optional[SearchSpace] = None,
    observations: Optional[List[Observation]] = None,
    modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
    config: Optional[TConfig] = None,
) -> None:
    """Record the rounding mode and plan the encoding of each categorical.

    For every unordered, non-task ``ChoiceParameter`` in ``search_space``
    this stores its values, a ``OneHotEncoder`` built from those values, and
    the names of the encoded parameters that will stand in for it.

    Args:
        search_space: Search space whose parameters are scanned. Required,
            despite the ``None`` default.
        observations: Not read by this constructor.
        modelbridge: Not read by this constructor.
        config: Optional transform config. Its "rounding" entry, when
            present, replaces the default of "strict".

    Raises:
        AssertionError: If ``search_space`` is None.
    """
    assert search_space is not None, "OneHot requires search space"
    # pyre-fixme[4]: Attribute must be annotated.
    self.rounding = "strict" if config is None else config.get("rounding", "strict")
    self.encoder: Dict[str, OneHotEncoder] = {}
    self.encoded_parameters: Dict[str, List[str]] = {}
    self.encoded_values: Dict[str, List[TParamValue]] = {}
    # Identify parameters that should be transformed: only unordered,
    # non-task choice parameters are encoded.
    for param in search_space.parameters.values():
        if not isinstance(param, ChoiceParameter):
            continue
        if param.is_ordered or param.is_task:
            continue
        name = param.name
        self.encoded_values[name] = param.values
        one_hot = OneHotEncoder(param.values)
        self.encoder[name] = one_hot
        width = one_hot.encoded_len
        if width == 1:
            # Two levels are handled by a single encoded parameter.
            encoded_names = [name + OH_PARAM_INFIX]
        else:
            # Otherwise one encoded parameter per level, suffixed by index.
            encoded_names = [f"{name}{OH_PARAM_INFIX}_{i}" for i in range(width)]
        self.encoded_parameters[name] = encoded_names
def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
    """Swap each encoded choice parameter for float range parameters on [0, 1].

    Parameters that were not selected for encoding at construction time are
    carried over unchanged. Parameter constraints are cloned against the
    resulting parameters.

    Args:
        search_space: The search space to transform.

    Returns:
        The search space produced by ``construct_new_search_space`` from the
        transformed parameters and cloned constraints.

    Raises:
        ValueError: If an encoded parameter is a fidelity parameter, or if it
            has values that were not among the values seen when the
            transform was initialized.
    """
    new_params: Dict[str, Parameter] = {}
    for name, param in search_space.parameters.items():
        if name not in self.encoded_parameters:
            # Not one-hot encoded: pass through untouched.
            new_params[name] = param
            continue

        choice = checked_cast(ChoiceParameter, param)
        if choice.is_fidelity:
            raise ValueError(f"Cannot one-hot-encode fidelity parameter {name}")

        known_values = self.encoded_values[name]
        if not set(choice.values).issubset(known_values):
            raise ValueError(
                f"{name} has values {choice.values} which are not a subset of "
                f"the original values {known_values} used to "
                "initialize the transform."
            )

        encoded_names = self.encoded_parameters[name]
        if len(encoded_names) > 1:
            # Keep only the encoded parameters whose level is present in the
            # search space being transformed. The transform may have been
            # initialized with a larger search space, and levels missing
            # here must not be selectable.
            encoded_names = [
                encoded_names[known_values.index(value)] for value in choice.values
            ]

        for encoded_name in encoded_names:
            new_params[encoded_name] = RangeParameter(
                name=encoded_name,
                parameter_type=ParameterType.FLOAT,
                lower=0,
                upper=1,
            )

    parameter_list = list(new_params.values())
    cloned_constraints = [
        constraint.clone_with_transformed_parameters(
            transformed_parameters=new_params
        )
        for constraint in search_space.parameter_constraints
    ]
    return construct_new_search_space(
        search_space=search_space,
        parameters=parameter_list,
        parameter_constraints=cloned_constraints,
    )