Source code for ax.service.utils.best_point_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.core.experiment import Experiment
from pyre_extensions import none_throws

BASELINE_ARM_NAME = "baseline_arm"


def select_baseline_name_default_first_trial(
    experiment: Experiment, baseline_arm_name: str | None
) -> tuple[str, bool]:
    """Pick the name of the arm to use as the baseline for ``experiment``.

    Candidates are tried in this order and the first match wins:

    1. ``baseline_arm_name``, when a non-empty name is passed. It must be
       one of the experiment's arms; otherwise a ``ValueError`` is raised
       immediately and no fallback is attempted.
    2. The name of ``experiment.status_quo``, when a status quo is set and
       its name is among the experiment's arms.
    3. The name of the first arm of ``experiment.trials[0]``, when that arm
       is among the experiment's arms.

    Args:
        experiment: Experiment whose arms are searched.
        baseline_arm_name: Explicitly requested baseline arm name, or
            ``None`` (or an empty string) to fall back to rules 2 and 3.

    Returns:
        A ``(name, from_first_trial)`` tuple. ``from_first_trial`` is
        ``True`` only when rule 3 produced the name.

    Raises:
        ValueError: If ``baseline_arm_name`` is given but is not an arm on
            the experiment, or if none of the rules yields an arm.
    """
    known_arms = experiment.arms_by_name

    # Rule 1: an explicit request is either honoured or rejected outright.
    if baseline_arm_name:
        if baseline_arm_name in known_arms:
            return baseline_arm_name, False
        raise ValueError(f"Arm by name {baseline_arm_name=} not found.")

    # Rule 2: the truthiness check already rules out ``None``, so the
    # status quo can be dereferenced directly.
    status_quo = experiment.status_quo
    if status_quo and status_quo.name in known_arms:
        return status_quo.name, False

    # Rule 3: fall back to the first arm of trial 0, if there is one.
    trials = experiment.trials
    first_trial_arms = trials[0].arms if trials else None
    if first_trial_arms and first_trial_arms[0].name in known_arms:
        return first_trial_arms[0].name, True

    raise ValueError("Could not find valid baseline arm.")