Source code for ax.modelbridge.transforms.ordered_choice_encode
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Dict, List, Optional
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TParamValue
from ax.modelbridge.transforms.base import Transform
class OrderedChoiceEncode(Transform):
    """Convert ordered ChoiceParameters to integer RangeParameters.

    Each ordered, non-task ChoiceParameter with ``n_choices`` values is replaced
    by an integer RangeParameter over the contiguous range
    ``[0, n_choices - 1]``; the i-th entry of ``ChoiceParameter.values`` is
    encoded as the integer ``i``. Task parameters and unordered
    ChoiceParameters are not transformed.

    In the inverse transform, parameters are mapped back onto the original
    domain.

    Transform is done in-place.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        """Build the forward and inverse encodings from ``search_space``.

        Args:
            search_space: Search space whose ordered, non-task ChoiceParameters
                are to be encoded.
            observation_features: Not used by this transform.
            observation_data: Not used by this transform.
            config: Not used by this transform.
        """
        # Forward map per encoded parameter: original choice value -> index.
        self.encoded_parameters: Dict[str, Dict[TParamValue, int]] = {}
        for p in search_space.parameters.values():
            if isinstance(p, ChoiceParameter) and p.is_ordered and not p.is_task:
                self.encoded_parameters[p.name] = {
                    original_value: transformed_value
                    for transformed_value, original_value in enumerate(p.values)
                }
        # Inverse map per encoded parameter: index -> original choice value.
        self.encoded_parameters_inverse: Dict[str, Dict[int, TParamValue]] = {
            p_name: {
                transformed_value: original_value
                for original_value, transformed_value in transforms.items()
            }
            for p_name, transforms in self.encoded_parameters.items()
        }

    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Replace encoded parameter values with their integer indices.

        Parameters absent from a given ``ObservationFeatures`` are left
        untouched. Values are modified in place and the same list is returned.

        Raises:
            KeyError: If a present value is not one of the parameter's choices.
        """
        for obsf in observation_features:
            for p_name, encoding in self.encoded_parameters.items():
                if p_name in obsf.parameters:
                    obsf.parameters[p_name] = encoding[obsf.parameters[p_name]]
        return observation_features

    def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a new SearchSpace with encoded ChoiceParameters replaced.

        Each encoded parameter becomes an INT RangeParameter with bounds
        ``[0, len(values) - 1]``; all other parameters are carried over as-is
        and parameter constraints are cloned.
        """
        transformed_parameters: Dict[str, Parameter] = {}
        for p in search_space.parameters.values():
            if p.name in self.encoded_parameters:
                # TypeAssert. Only ChoiceParameters present here.
                # pyre: p_ is declared to have type `ChoiceParameter` but is
                # pyre-fixme[9]: used as type `Parameter`.
                p_: ChoiceParameter = p
                # Choice(|K|) => Range(0, K-1)
                transformed_parameters[p.name] = RangeParameter(
                    name=p_.name,
                    parameter_type=ParameterType.INT,
                    lower=0,
                    upper=len(p_.values) - 1,
                )
            else:
                transformed_parameters[p.name] = p
        return SearchSpace(
            parameters=list(transformed_parameters.values()),
            parameter_constraints=[
                pc.clone() for pc in search_space.parameter_constraints
            ],
        )

    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Map integer indices back onto the original choice values.

        As in ``transform_observation_features``, parameters absent from a
        given ``ObservationFeatures`` are skipped rather than raising
        ``KeyError``. Values are modified in place and the same list is
        returned.

        Raises:
            KeyError: If a present value is not a valid index for its parameter.
        """
        for obsf in observation_features:
            for p_name, reverse_transform in self.encoded_parameters_inverse.items():
                # Mirror the forward transform's guard so that features which
                # lack an encoded parameter round-trip without error.
                if p_name not in obsf.parameters:
                    continue
                # pyre: pval is declared to have type `int` but is used as
                # pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`.
                pval: int = obsf.parameters[p_name]
                obsf.parameters[p_name] = reverse_transform[pval]
        return observation_features