Source code for ax.modelbridge.transforms.trial_as_task
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Dict, List, Optional
import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import ChoiceParameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform
TRIAL_PARAM = "TRIAL_PARAM"
[docs]class TrialAsTask(Transform):
"""Convert trial to one or more task parameters.
How trial is mapped to parameter is specified with a map like
{parameter_name: {trial_index: level name}}.
For example,
{"trial_param1": {0: "level1", 1: "level1", 2: "level2"},}
will create a choice parameter "trial_param1" with is_task=True.
Observations with trial 0 or 1 will have "trial_param1" set to "level1",
and those with trial 2 will have "trial_param1" set to "level2". Multiple
parameter names and mappings can be specified in this dict.
The trial level mapping can be specified in config["trial_level_map"]. If
not specified, defaults to a parameter with a level for every trial index.
For the reverse transform, if there are multiple mappings in the transform
the trial will not be set.
Will raise if trial not specified for every point in the training data.
Transform is done in-place.
"""
def __init__(
    self,
    search_space: SearchSpace,
    observation_features: List[ObservationFeatures],
    observation_data: List[ObservationData],
    config: Optional[TConfig] = None,
) -> None:
    """Build the trial-to-level map and, when unambiguous, its inverse.

    Args:
        search_space: Not read by this constructor.
        observation_features: Observation features of the training data;
            each entry must have ``trial_index`` set.
        observation_data: Not read by this constructor.
        config: Optional transform config. If it contains the key
            "trial_level_map", that value is used as
            {parameter_name: {trial_index: level name}}. Trial indices
            are cast to ``int``, so maps keyed by strings (e.g. after a
            JSON round-trip) are accepted. Otherwise a single parameter
            named by ``TRIAL_PARAM`` is created with one level per trial
            index found in ``observation_features``.

    Raises:
        ValueError: If any observation has no trial index, or if a
            supplied level map does not cover every trial in the data.
            A key that ``int()`` cannot convert propagates the error
            raised by that conversion.
    """
    # Identify values of trial.
    trials = {obsf.trial_index for obsf in observation_features}
    if None in trials:
        raise ValueError(
            "Unable to use trial as task since not all observations have "
            "trial specified."
        )
    # Get trial level map
    if config is not None and "trial_level_map" in config:
        # The declared value type of TConfig does not describe a nested
        # dict, hence the type-checker suppression on this read.
        # pyre-fixme[9]: Nested dict is outside TConfig's declared value type.
        raw_level_map: Dict[str, Dict[int, str]] = config["trial_level_map"]
        self.trial_level_map: Dict[str, Dict[int, str]] = {}
        for p_name, raw_level_dict in raw_level_map.items():
            # Normalise trial indices to int so the map agrees with the
            # default branch below and with the indices found in the data,
            # even when the config supplied them as strings.
            level_dict = {
                int(trial_index): level
                for trial_index, level in raw_level_dict.items()
            }
            # Check that trials match those in data
            level_map = set(level_dict)
            if not trials.issubset(level_map):
                raise ValueError(
                    f"Not all trials in data ({trials}) contained "
                    f"in trial level map for {p_name} ({level_map})"
                )
            self.trial_level_map[p_name] = level_dict
    else:
        # Set TRIAL_PARAM for each trial to the corresponding trial_index.
        self.trial_level_map = {TRIAL_PARAM: {int(b): str(b) for b in trials}}
    if len(self.trial_level_map) == 1:
        # Exactly one task parameter: a level name can be mapped back to a
        # trial index. If several trials share a level, the last one in
        # the dict's iteration order wins.
        level_dict = next(iter(self.trial_level_map.values()))
        self.inverse_map: Optional[Dict[str, int]] = {
            v: k for k, v in level_dict.items()
        }
    else:
        # With multiple mappings there is no single level-to-trial inverse.
        self.inverse_map = None