Source code for ax.telemetry.generation_strategy
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from dataclasses import dataclass
from math import inf
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS
[docs]@dataclass(frozen=True)
class GenerationStrategyCreatedRecord:
"""
Record of the GenerationStrategy creation event. This can be used for telemetry in
settings where many GenerationStrategy are being created either manually or
programatically. In order to facilitate easy serialization only include simple
types: numbers, strings, bools, and None.
"""
generation_strategy_name: str
# -1 indicates unlimited trials requested, 0 indicates no trials requested
num_requested_initialization_trials: int # Typically the number of Sobol trials
num_requested_bayesopt_trials: int
num_requested_other_trials: int
# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
max_parallelism: int
[docs] @classmethod
def from_generation_strategy(
cls, generation_strategy: GenerationStrategy
) -> GenerationStrategyCreatedRecord:
# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
true_max_parallelism = min(
step.max_parallelism or inf for step in generation_strategy._steps
)
return cls(
generation_strategy_name=generation_strategy.name,
num_requested_initialization_trials=sum(
step.num_trials
for step in generation_strategy._steps
if step.model in INITIALIZATION_MODELS
),
num_requested_bayesopt_trials=sum(
step.num_trials
for step in generation_strategy._steps
if step.model not in INITIALIZATION_MODELS + OTHER_MODELS
),
num_requested_other_trials=sum(
step.num_trials
for step in generation_strategy._steps
if step.model in OTHER_MODELS
),
max_parallelism=(
true_max_parallelism if isinstance(true_max_parallelism, int) else -1
),
)