Source code for ax.modelbridge.transforms.search_space_to_choice

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import List, Optional

from ax.core.arm import Arm
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform


class SearchSpaceToChoice(Transform):
    """Replaces the search space with a single choice parameter, whose values
    are the signatures of the arms observed in the data.

    This transform is meant to be used with ThompsonSampler.

    Transform is done in-place: ``transform_observation_features`` and
    ``untransform_observation_features`` reassign ``parameters`` on the
    ObservationFeatures objects they receive and return the same list.

    ObservationFeatures whose ``parameters`` dict is empty do not identify
    any arm, so they are passed through unchanged in both directions.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        """Build the arm-signature -> parameterization lookup table.

        Args:
            search_space: Not used by this transform.
            observation_features: Observed features; the parameters of each
                one define an arm whose signature becomes a choice value.
            observation_data: Not used by this transform.
            config: Not used by this transform.
        """
        self.parameter_name = "arms"
        # Store copies rather than the caller's dict objects, so later
        # in-place edits to those ObservationFeatures (or to dicts handed out
        # by untransform_observation_features) cannot corrupt this table.
        self.signature_to_parameterization = {
            Arm(parameters=obsf.parameters).signature: dict(obsf.parameters)
            for obsf in observation_features
        }

    def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Return a new search space containing only the "arms" parameter.

        The parameter is a string-valued ChoiceParameter over all observed
        arm signatures, or a FixedParameter when exactly one arm was
        observed. The incoming ``search_space`` is ignored.

        Raises:
            IndexError: If no observation features were given at
                construction, so there is no arm to build a parameter from.
        """
        values = list(self.signature_to_parameterization.keys())
        if not values:
            # Kept as IndexError (what ``values[0]`` used to raise) for
            # backward compatibility, but with an actionable message.
            raise IndexError(
                "SearchSpaceToChoice requires at least one observed arm to "
                "build a search space."
            )
        if len(values) > 1:
            parameter = ChoiceParameter(
                name=self.parameter_name,
                parameter_type=ParameterType.STRING,
                values=values,
            )
        else:
            # Single observed arm: use a FixedParameter instead of a choice.
            # NOTE(review): presumably because ChoiceParameter expects more
            # than one value — confirm against ax.core.parameter.
            parameter = FixedParameter(
                name=self.parameter_name,
                parameter_type=ParameterType.STRING,
                value=values[0],
            )
        return SearchSpace(parameters=[parameter])

    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Replace each feature's parameters with ``{"arms": <signature>}``.

        Features with an empty ``parameters`` dict are left untouched.
        Converting them would yield the signature of an empty arm, which is
        in general not one of the choice values. Untransforming it would
        then fail.

        Args:
            observation_features: Features to transform; modified in place.

        Returns:
            The same list that was passed in.
        """
        for obsf in observation_features:
            if not obsf.parameters:
                continue
            obsf.parameters = {
                self.parameter_name: Arm(parameters=obsf.parameters).signature
            }
        return observation_features

    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        """Map each ``{"arms": <signature>}`` back to the arm's parameters.

        Features with an empty ``parameters`` dict are left untouched,
        mirroring ``transform_observation_features``.

        Args:
            observation_features: Features to untransform; modified in place.

        Returns:
            The same list that was passed in.

        Raises:
            KeyError: If a non-empty feature lacks the "arms" entry, or its
                signature was not among the arms seen at construction.
        """
        for obsf in observation_features:
            if not obsf.parameters:
                continue
            signature = obsf.parameters[self.parameter_name]
            # Hand out a fresh copy so downstream in-place edits to the
            # returned parameters cannot alter the lookup table.
            obsf.parameters = dict(self.signature_to_parameterization[signature])
        return observation_features