ax.utils

Common

Base

class ax.utils.common.base.Base[source]

Bases: object

Base class for core Ax classes. Provides an equality check and a db_id property for SQA storage.

property db_id
class ax.utils.common.base.SortableBase[source]

Bases: ax.utils.common.base.Base

Extension to the base class that also provides an inequality check.

Constants

class ax.utils.common.constants.Keys(value)[source]

Bases: str, enum.Enum

Enum of reserved keys in options dicts etc, alphabetized.

NOTE: Useful for keys in dicts that correspond to kwargs to classes or functions and/or are used in multiple places.

ACQF_KWARGS = 'acquisition_function_kwargs'
BATCH_INIT_CONDITIONS = 'batch_initial_conditions'
CANDIDATE_SET = 'candidate_set'
CANDIDATE_SIZE = 'candidate_size'
COST_AWARE_UTILITY = 'cost_aware_utility'
COST_INTERCEPT = 'cost_intercept'
CURRENT_VALUE = 'current_value'
EXPAND = 'expand'
EXPECTED_ACQF_VAL = 'expected_acquisition_value'
FIDELITY_FEATURES = 'fidelity_features'
FIDELITY_WEIGHTS = 'fidelity_weights'
FRAC_RANDOM = 'frac_random'
FULL_PARAMETERIZATION = 'full_parameterization'
IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = 'immutable_search_space_and_opt_config'
MAXIMIZE = 'maximize'
METADATA = 'metadata'
METRIC_NAMES = 'metric_names'
NUM_FANTASIES = 'num_fantasies'
NUM_INNER_RESTARTS = 'num_inner_restarts'
NUM_RESTARTS = 'num_restarts'
NUM_TRACE_OBSERVATIONS = 'num_trace_observations'
OBJECTIVE = 'objective'
OPTIMIZER_KWARGS = 'optimizer_kwargs'
PREFERENCE_DATA = 'preference_data'
PROJECT = 'project'
QMC = 'qmc'
RAW_INNER_SAMPLES = 'raw_inner_samples'
RAW_SAMPLES = 'raw_samples'
REFIT_ON_UPDATE = 'refit_on_update'
SAMPLER = 'sampler'
SEED_INNER = 'seed_inner'
SEQUENTIAL = 'sequential'
STATE_DICT = 'state_dict'
SUBCLASS = 'subclass'
SUBSET_MODEL = 'subset_model'
TASK_FEATURES = 'task_features'
TRIAL_COMPLETION_TIMESTAMP = 'trial_completion_timestamp'
WARM_START_REFITTING = 'warm_start_refitting'
X_BASELINE = 'X_baseline'

Docutils

Support functions for Sphinx et al.

ax.utils.common.docutils.copy_doc(src: Callable[[], Any])Callable[[_T], _T][source]

A decorator that copies the docstring of another object

Since sphinx actually loads the python modules to grab the docstrings this works with both sphinx and the help function.

class Cat(Mammal):

  @property
  @copy_doc(Mammal.is_feline)
  def is_feline(self) -> bool:
      ...
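To make the example above concrete, here is a minimal, self-contained sketch of how a docstring-copying decorator could work. The `copy_doc` defined below is a simplified stand-in written for illustration, not the Ax implementation, and the `Mammal` base class is hypothetical.

```python
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """Return a decorator that copies the docstring of ``src`` (sketch only)."""

    def decorator(dst: _T) -> _T:
        # Overwrite the decorated object's docstring with the source docstring.
        dst.__doc__ = src.__doc__
        return dst

    return decorator


class Mammal:  # hypothetical base class for the example
    @property
    def is_feline(self) -> bool:
        """Whether this mammal is a feline."""
        return False


class Cat(Mammal):
    @property
    @copy_doc(Mammal.is_feline)
    def is_feline(self) -> bool:
        return True
```

Because `property` takes its docstring from the wrapped getter, `Cat.is_feline.__doc__` ends up matching `Mammal.is_feline.__doc__` in this sketch.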

Equality

ax.utils.common.equality.dataframe_equals(df1: pandas.DataFrame, df2: pandas.DataFrame)bool[source]

Compare equality of two pandas dataframes.

ax.utils.common.equality.datetime_equals(dt1: Optional[datetime.datetime], dt2: Optional[datetime.datetime])bool[source]

Compare equality of two datetimes, ignoring microseconds.
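As a rough illustration of comparing datetimes while ignoring microseconds, the sketch below truncates both values to whole seconds before comparing. The function name and the handling of `None` (two missing values compare equal) are assumptions of this sketch rather than documented Ax behavior.

```python
from datetime import datetime
from typing import Optional


def datetime_equals_sketch(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Compare two optional datetimes at whole-second resolution (sketch only)."""
    if dt1 is None or dt2 is None:
        # Assumption: two missing values are equal, one missing value is not.
        return dt1 is None and dt2 is None
    return dt1.replace(microsecond=0) == dt2.replace(microsecond=0)


a = datetime(2021, 5, 1, 12, 0, 0, 123)
b = datetime(2021, 5, 1, 12, 0, 0, 999)
same_second = datetime_equals_sketch(a, b)
```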

ax.utils.common.equality.equality_typechecker(eq_func: Callable)Callable[source]

A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type.

ax.utils.common.equality.object_attribute_dicts_equal(one_dict: Dict[str, Any], other_dict: Dict[str, Any])bool[source]

Utility to check if all items in attribute dicts of two Ax objects are the same.

NOTE: Special-cases some Ax object attributes, like “_experiment” or “_model”, where full equality is hard to check.

ax.utils.common.equality.object_attribute_dicts_find_unequal_fields(one_dict: Dict[str, Any], other_dict: Dict[str, Any], fast_return: bool = True, skip_db_id_check: bool = False)Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Tuple[Any, Any]]][source]

Utility for finding out what attributes of two objects’ attribute dicts are unequal.

Parameters
  • one_dict – First object’s attribute dict (obj.__dict__).

  • other_dict – Second object’s attribute dict (obj.__dict__).

  • fast_return – Boolean representing whether to return as soon as a single unequal attribute was found or to iterate over all attributes and collect all unequal ones.

Returns

  • attribute name to attribute values of unequal type (as a tuple),

  • attribute name to attribute values of unequal value (as a tuple).

Return type

Two dictionaries
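The two returned dictionaries can be pictured with the following sketch. It is a simplified illustration that assumes both attribute dicts have the same keys and skips the Ax-specific special cases; the function name is hypothetical.

```python
from typing import Any, Dict, Tuple

FieldDiff = Dict[str, Tuple[Any, Any]]


def find_unequal_fields_sketch(
    one_dict: Dict[str, Any],
    other_dict: Dict[str, Any],
    fast_return: bool = True,
) -> Tuple[FieldDiff, FieldDiff]:
    """Collect fields that differ by type and fields that differ by value."""
    unequal_type: FieldDiff = {}
    unequal_value: FieldDiff = {}
    for field, one_val in one_dict.items():
        other_val = other_dict.get(field)
        if type(one_val) is not type(other_val):
            unequal_type[field] = (one_val, other_val)
        elif one_val != other_val:
            unequal_value[field] = (one_val, other_val)
        else:
            continue
        if fast_return:
            # Stop at the first mismatch instead of scanning every field.
            break
    return unequal_type, unequal_value


by_type, by_value = find_unequal_fields_sketch(
    {"name": "a", "size": 1, "tag": "x"},
    {"name": "b", "size": "1", "tag": "x"},
    fast_return=False,
)
```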

ax.utils.common.equality.same_elements(list1: List[Any], list2: List[Any])bool[source]

Compare equality of two lists of core Ax objects.

Assumptions:

– The contents of each list are types that implement __eq__

– The lists do not contain duplicates

Checking equality is then the same as checking that the lists are the same length, and that one is a subset of the other.
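Under the two assumptions stated above, the length-plus-subset check can be sketched as follows. This is an illustrative re-implementation with a hypothetical name; it relies only on `==`, so the elements do not need to be hashable.

```python
from typing import Any, List


def same_elements_sketch(list1: List[Any], list2: List[Any]) -> bool:
    """Order-insensitive list equality for duplicate-free lists (sketch only)."""
    if len(list1) != len(list2):
        return False
    # With equal lengths and no duplicates, "subset" implies "same elements".
    return all(any(item1 == item2 for item2 in list2) for item1 in list1)


reordered = same_elements_sketch([1, 2, 3], [3, 1, 2])
different = same_elements_sketch([1, 2, 3], [1, 2, 4])
```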

Executils

ax.utils.common.executils.handle_exceptions_in_retries(no_retry_exceptions: Tuple[Type[Exception], ], retry_exceptions: Tuple[Type[Exception], ], suppress_errors: bool, check_message_contains: Optional[str], last_retry: bool, logger: Optional[logging.Logger], wrap_error_message_in: Optional[str])Generator[None, None, None][source]
ax.utils.common.executils.retry_on_exception(exception_types: Optional[Tuple[Type[Exception], ]] = None, no_retry_on_exception_types: Optional[Tuple[Type[Exception], ]] = None, check_message_contains: Optional[List[str]] = None, retries: int = 3, suppress_all_errors: bool = False, logger: Optional[logging.Logger] = None, default_return_on_suppression: Optional[Any] = None, wrap_error_message_in: Optional[str] = None, initial_wait_seconds: Optional[int] = None)Optional[Any][source]

A decorator for instance methods or standalone functions that makes them retry on failure and lets the caller specify which types of exceptions the function should and should not retry on.

NOTE: If the argument suppress_all_errors is supplied and set to True, the error will be suppressed and default value returned.

Parameters
  • exception_types – A tuple of exception(s) types to catch in the decorated function. If none is provided, baseclass Exception will be used.

  • no_retry_on_exception_types – Exception types to consider non-retryable even if their supertype appears in exception_types or the only exceptions to not retry on if no exception_types are specified.

  • check_message_contains – A list of strings, against which to match error messages. If the error message contains any one of these strings, the exception will cause a retry. NOTE: This argument works in addition to exception_types; if those are specified, only the specified types of exceptions will be caught and retried on if they contain the strings provided as check_message_contains.

  • retries – Number of retries to perform.

  • suppress_all_errors – If true, after all the retries are exhausted, the error will still be suppressed and default_return_on_suppression will be returned from the function. NOTE: If using this argument, the decorated function may not actually get fully executed, if it consistently raises an exception.

  • logger – A handle for the logger to be used.

  • default_return_on_suppression – If the error is suppressed after all the retries, then this default value will be returned from the function. Defaults to None.

  • wrap_error_message_in – If raising the error message after all the retries, a string wrapper for the error message (useful for making error messages more user-friendly). NOTE: Format of resulting error will be: “<wrap_error_message_in>: <original_error_type>: <original_error_msg>”, with the stack trace of the original message.

  • initial_wait_seconds – Initial length of time to wait between failures, doubled after each failure up to a maximum of 10 minutes. If unspecified then there is no wait between retries.
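The interplay of the parameters above may be easier to follow with a small sketch. The decorator below is a simplified illustration, not the Ax implementation: it supports only a subset of the arguments, and it assumes that `retries` counts the total number of attempts.

```python
import functools
import time
from typing import Any, Callable, Optional, Tuple, Type


def retry_sketch(
    exception_types: Tuple[Type[Exception], ...] = (Exception,),
    retries: int = 3,
    suppress_all_errors: bool = False,
    default_return_on_suppression: Optional[Any] = None,
    initial_wait_seconds: Optional[float] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            wait = initial_wait_seconds
            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except exception_types:
                    if attempt == retries - 1:
                        # Out of attempts: suppress or re-raise.
                        if suppress_all_errors:
                            return default_return_on_suppression
                        raise
                    if wait is not None:
                        time.sleep(wait)
                        wait = min(wait * 2, 600)  # double, capped at 10 minutes

        return wrapper

    return decorator


calls = {"n": 0}


@retry_sketch(exception_types=(ValueError,), retries=3)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("transient failure")
    return "ok"


@retry_sketch(retries=2, suppress_all_errors=True, default_return_on_suppression="fallback")
def always_fails() -> str:
    raise RuntimeError("permanent failure")


result = flaky()
```

In this sketch an exception whose type is not listed in `exception_types` propagates immediately, without any retry.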

Kwargs

ax.utils.common.kwargs.consolidate_kwargs(kwargs_iterable: Iterable[Optional[Dict[str, Any]]], keywords: Iterable[str])Dict[str, Any][source]

Combine an iterable of kwargs into a single dict of kwargs, where kwargs with duplicate keys that appear later in the iterable take priority over the ones that appear earlier, and only kwargs referenced in keywords are used. This makes it possible to combine somewhat redundant sets of kwargs, where a user-set kwarg, for instance, needs to override a default kwarg.

>>> consolidate_kwargs(
...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
...     keywords=['a', 'b', 'd']
... )
{'a': 1, 'b': 3, 'd': 5}
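The merging rule in the example above (later dicts win, unlisted keys are dropped) can be reproduced with a few lines of plain Python. This is an illustrative sketch with a hypothetical name, not the library source.

```python
from typing import Any, Dict, Iterable, Optional


def consolidate_kwargs_sketch(
    kwargs_iterable: Iterable[Optional[Dict[str, Any]]],
    keywords: Iterable[str],
) -> Dict[str, Any]:
    """Merge kwargs dicts in order, keeping only the allowed keywords."""
    allowed = set(keywords)
    merged: Dict[str, Any] = {}
    for kwargs in kwargs_iterable:
        if kwargs is None:
            continue  # The signature allows None entries, so skip them.
        # Later dicts overwrite earlier ones for duplicate keys.
        merged.update({k: v for k, v in kwargs.items() if k in allowed})
    return merged


consolidated = consolidate_kwargs_sketch(
    kwargs_iterable=[{"a": 1, "b": 2}, None, {"b": 3, "c": 4, "d": 5}],
    keywords=["a", "b", "d"],
)
```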
ax.utils.common.kwargs.filter_kwargs(function: Callable, **kwargs: Any)Any[source]

Filter out kwargs that are not applicable for a given function. Return a copy of given kwargs dict with only the required kwargs.
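One plausible way to implement this kind of filtering is to inspect the target function's signature. The sketch below uses the standard `inspect` module; the helper name and the `scale` function are hypothetical and exist only for this example.

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs_sketch(function: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """Keep only the kwargs that appear in the function's signature."""
    accepted = inspect.signature(function).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}


def scale(x: float, factor: float = 2.0) -> float:  # hypothetical target function
    return x * factor


filtered = filter_kwargs_sketch(scale, x=3.0, factor=1.5, unused="dropped")
scaled = scale(**filtered)
```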

ax.utils.common.kwargs.get_function_argument_names(function: Callable, omit: Optional[List[str]] = None)List[str][source]

Extract parameter names from function signature.

ax.utils.common.kwargs.get_function_default_arguments(function: Callable)Dict[str, Any][source]

Extract default arguments from function signature.

ax.utils.common.kwargs.validate_kwarg_typing(typed_callables: List[Callable], **kwargs: Any)None[source]

Check if keywords in kwargs exist in any of the typed_callables and if the type of each keyword value matches the type of corresponding arg in one of the callables

Note: this function expects the typed callables to have unique keywords for the arguments and will raise an error if repeat keywords are found.

ax.utils.common.kwargs.warn_on_kwargs(callable_with_kwargs: Callable, **kwargs: Any)None[source]

Log a warning when a decoder function receives unexpected kwargs.

NOTE: This mainly caters to the use case where an older version of Ax is used to decode objects, serialized to JSON by a newer version of Ax (and therefore potentially containing new fields). In that case, the decoding function should not fail when encountering those additional fields, but rather just ignore them and log a warning using this function.

Logger

class ax.utils.common.logger.AxOutputNameFilter(name='')[source]

Bases: logging.Filter

This is a filter which sets the record’s output_name, if not configured

filter(record: logging.LogRecord)bool[source]

Determine if the specified record is to be logged.

Is the specified record to be logged? Returns 0 for no, nonzero for yes. If deemed appropriate, the record may be modified in-place.

ax.utils.common.logger.build_file_handler(filepath: str, level: int = 20)logging.StreamHandler[source]

Build a file handler that logs entries to the given file, using the same formatting as the stream handler.

Parameters
  • filepath – Location of the file to log output to. If the file exists, output will be appended. If it does not exist, a new file will be created.

  • level – The log level. By default, sets level to INFO

Returns

A logging.FileHandler instance

ax.utils.common.logger.build_stream_handler(level: int = 20)logging.StreamHandler[source]

Build the default stream handler used for most Ax logging. Sets default level to INFO, instead of WARNING.

Parameters

level – The log level. By default, sets level to INFO

Returns

A logging.StreamHandler instance

class ax.utils.common.logger.disable_logger(name: str, level: int = 40)[source]

Bases: object

decorate_callable(func: Callable[[], T])Callable[[], T][source]
decorate_class(klass: T)T[source]
ax.utils.common.logger.get_logger(name: str)logging.Logger[source]

Get an Ax logger.

To set a human-readable “output_name” that appears in logger outputs, add {“output_name”: “[MY_OUTPUT_NAME]”} to the logger’s contextual information. By default, we use the logger’s name.

NOTE: To change the log level on particular outputs (e.g. STDERR logs), set the proper log level on the relevant handler instead of on the logger, e.g. logger.handlers[0].setLevel(INFO).

Parameters

name – The name of the logger.

Returns

The logging.Logger object.
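The distinction between logger-level and handler-level filtering can be shown with the standard `logging` module alone. The sketch below does not call Ax; the logger name is hypothetical, and the handler is attached manually to mimic a logger that already has a stream handler.

```python
import logging

logger = logging.getLogger("ax_logging_sketch")  # hypothetical logger name
logger.setLevel(logging.DEBUG)  # the logger itself lets everything through
logger.propagate = False

handler = logging.StreamHandler()  # writes to STDERR by default
handler.setLevel(logging.INFO)
logger.addHandler(handler)

# Quiet only this output by raising the level on the handler, not the logger.
logger.handlers[0].setLevel(logging.WARNING)

logger.info("dropped by the handler")
logger.warning("still emitted")
```

Other handlers attached to the same logger, such as a file handler, keep their own levels and are unaffected by this change.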

ax.utils.common.logger.make_indices_str(indices: Iterable[int])str[source]

Generate a string representation of an iterable of indices; if indices are contiguous, returns a string formatted like ‘<min_idx> - <max_idx>’, otherwise a string formatted like ‘[idx_1, idx_2, …, idx_n]’.
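The two output formats can be illustrated with a short sketch. The helper below is a hypothetical re-implementation; in particular, treating a single index as non-contiguous is an assumption of this sketch.

```python
from typing import Iterable


def make_indices_str_sketch(indices: Iterable[int]) -> str:
    """Format indices as a range if contiguous, else as a list (sketch only)."""
    idcs = sorted(indices)
    contiguous = len(idcs) > 1 and idcs == list(range(idcs[0], idcs[-1] + 1))
    if contiguous:
        return f"{idcs[0]} - {idcs[-1]}"
    return f"{idcs}"


contiguous_str = make_indices_str_sketch([2, 0, 1, 3])
sparse_str = make_indices_str_sketch([0, 2, 5])
```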

ax.utils.common.logger.set_stderr_log_level(level: int)None[source]

Set the log level for stream handler, such that logs of given level are printed to STDERR by the root logger

Serialization

ax.utils.common.serialization.callable_from_reference(path: str)Callable[source]

Retrieves a callable by its path.

ax.utils.common.serialization.callable_to_reference(callable: Callable)str[source]

Obtains path to the callable of form <module>.<name>.
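A round trip between a callable and its `<module>.<name>` path can be sketched with `importlib`. The helper names below are hypothetical, and the sketch only handles callables defined at module top level.

```python
import importlib
import math
from typing import Any, Callable


def callable_to_reference_sketch(fn: Callable[..., Any]) -> str:
    """Build a '<module>.<name>' path for a top-level callable."""
    return f"{fn.__module__}.{fn.__qualname__}"


def callable_from_reference_sketch(path: str) -> Callable[..., Any]:
    """Import the module part of the path and look up the callable by name."""
    module_path, name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), name)


reference = callable_to_reference_sketch(math.sqrt)
restored = callable_from_reference_sketch(reference)
```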

ax.utils.common.serialization.extract_init_args(args: Dict[str, Any], class_: Type)Dict[str, Any][source]

Given a dictionary, extract the arguments required for the given class’s constructor.

ax.utils.common.serialization.named_tuple_to_dict(data: Any)Any[source]

Recursively convert NamedTuples to dictionaries.

ax.utils.common.serialization.serialize_init_args(object: Any, exclude_fields: Optional[List[str]] = None)Dict[str, Any][source]

Given an object, return a dictionary of the arguments that are needed by its constructor.

Testutils

Support functions for tests

class ax.utils.common.testutils.TestCase(methodName: str = 'runTest')[source]

Bases: unittest.case.TestCase

The base Ax test case, contains various helper functions to write unittests.

assertEqual(first: Any, second: Any, msg: Optional[str] = None)None[source]

Fail if the two objects are unequal as determined by the ‘==’ operator.

assertRaisesOn(exc: Type[Exception], line: Optional[str] = None, regex: Optional[str] = None)AbstractContextManager[None][source]

Assert that an exception is raised on a specific line.

static silence_stderr()Generator[None, None, None][source]

A context manager that silences stderr for part of a test.

If any exception passes through this context manager the stderr will be printed, otherwise it will be discarded.

Timeutils

ax.utils.common.timeutils.current_timestamp_in_millis()int[source]

Grab current timestamp in milliseconds as an int.

ax.utils.common.timeutils.timestamps_in_range(start: datetime.datetime, end: datetime.datetime, delta: datetime.timedelta)Generator[datetime.datetime, None, None][source]

Generator of timestamps in range [start, end], at intervals delta.
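The closed range `[start, end]` means that `end` itself is yielded when it falls exactly on a step. A minimal sketch of such a generator, assuming a positive `delta` and using a hypothetical name, might look like this:

```python
from datetime import datetime, timedelta
from typing import Generator


def timestamps_in_range_sketch(
    start: datetime, end: datetime, delta: timedelta
) -> Generator[datetime, None, None]:
    """Yield start, start + delta, ... up to and including end (sketch only)."""
    if delta <= timedelta(0):
        raise ValueError("delta must be positive in this sketch")
    current = start
    while current <= end:
        yield current
        current += delta


days = list(
    timestamps_in_range_sketch(
        datetime(2020, 1, 1), datetime(2020, 1, 3), timedelta(days=1)
    )
)
```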

ax.utils.common.timeutils.to_ds(ts: datetime.datetime)str[source]

Convert a datetime to a DS string.

ax.utils.common.timeutils.to_ts(ds: str)datetime.datetime[source]

Convert a DS string to a datetime.

Typeutils

ax.utils.common.typeutils.checked_cast(typ: Type[T], val: V)T[source]

Cast a value to a type (with a runtime safety check).

Returns the value unchanged and checks its type at runtime. This signals to the typechecker that the value has the designated type.

Like typing.cast, checked_cast performs no runtime conversion on its argument, but, unlike typing.cast, checked_cast will throw an error if the value is not of the expected type. The type passed as an argument should be a Python class.

Parameters
  • typ – the type to cast to

  • val – the value that we are casting

Returns

the val argument, unchanged
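The contrast with `typing.cast` can be seen in a small sketch: the value is returned as the same object, but a mismatched type fails at runtime. The helper below is a hypothetical stand-in, and raising `ValueError` on a mismatch is an assumption of this sketch.

```python
from typing import Any, Type, TypeVar, cast

T = TypeVar("T")


def checked_cast_sketch(typ: Type[T], val: Any) -> T:
    """Return val unchanged after verifying its runtime type (sketch only)."""
    if not isinstance(val, typ):
        raise ValueError(f"Value was not of type {typ}: {val!r}")
    return val


number = checked_cast_sketch(int, 3)   # passes the runtime check
unchecked = cast(int, "3")             # typing.cast performs no check at all
```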

ax.utils.common.typeutils.checked_cast_complex(typ: Type[T], val: V, message: Optional[str] = None)T[source]

Cast a value to a type (with a runtime safety check). Used for subscripted generics which isinstance cannot run against.

Returns the value unchanged and checks its type at runtime. This signals to the typechecker that the value has the designated type.

Like typing.cast, checked_cast_complex performs no runtime conversion on its argument, but, unlike typing.cast, checked_cast_complex will throw an error if the value is not of the expected type.

Parameters
  • typ – the type to cast to

  • val – the value that we are casting

  • message – message to print on error

Returns

the val argument cast to typ

ax.utils.common.typeutils.checked_cast_dict(key_typ: Type[K], value_typ: Type[V], d: Dict[X, Y])Dict[K, V][source]

Calls checked_cast on all keys and values in the dictionary.

ax.utils.common.typeutils.checked_cast_list(typ: Type[T], old_l: List[V])List[T][source]

Calls checked_cast on all items in a list.

ax.utils.common.typeutils.checked_cast_optional(typ: Type[T], val: Optional[V])Optional[T][source]

Calls checked_cast only if value is not None.

ax.utils.common.typeutils.checked_cast_to_tuple(typ: Tuple[Type[V], ], val: V)T[source]

Cast a value to a union of multiple types (with a runtime safety check). This function is similar to checked_cast, but allows for the type to be defined as a tuple of types, in which case the value is cast as a union of the types in the tuple.

Parameters
  • typ – the tuple of types to cast to

  • val – the value that we are casting

Returns

the val argument, unchanged

ax.utils.common.typeutils.not_none(val: Optional[T], message: Optional[str] = None)T[source]

Unbox an optional type.

Parameters
  • val – the value to cast to a non None type

  • message – optional override of the default error message

Returns

val when val is not None

Return type

T

Throws:

ValueError if val is None
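Unboxing an optional value amounts to a `None` check that either returns the value or raises. A minimal sketch with a hypothetical name and a hypothetical default error message:

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none_sketch(val: Optional[T], message: Optional[str] = None) -> T:
    """Return val if it is not None, otherwise raise ValueError (sketch only)."""
    if val is None:
        raise ValueError(message or "Argument to not_none was None.")
    return val


maybe_name: Optional[str] = "trial_0"
name_length = len(not_none_sketch(maybe_name))
```

Note that falsy values such as `0` or an empty string pass through, because only `None` is rejected.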

ax.utils.common.typeutils.numpy_type_to_python_type(value: Any)Any[source]

If value is a Numpy int or float, coerce to a Python int or float. This is necessary because some of our transforms return Numpy values.

ax.utils.common.typeutils.torch_type_from_str(identifier: str, type_name: str)Union[torch.dtype, torch.device][source]
ax.utils.common.typeutils.torch_type_to_str(value: Any)str[source]

Converts torch types, commonly used in Ax, to string representations.

Flake8 Plugins

Docstring Checker

ax.utils.flake8_plugins.docstring_checker.A000(node: _ast.AST)ax.utils.flake8_plugins.docstring_checker.Error
class ax.utils.flake8_plugins.docstring_checker.DocstringChecker(tree, filename)[source]

Bases: object

A flake8 plug-in that makes sure all public functions have a docstring

filename: str
name: str = 'docstring checker'
run()[source]
tree: _ast.Module
version: str = '1.0.0'
class ax.utils.flake8_plugins.docstring_checker.DocstringCheckerVisitor[source]

Bases: ast.NodeVisitor

check_A000(node: _ast.AST)None[source]
errors: List[ax.utils.flake8_plugins.docstring_checker.Error]
visit_AsyncFunctionDef(node: _ast.ClassDef)None[source]
visit_ClassDef(node: _ast.ClassDef)None[source]
visit_FunctionDef(node: _ast.FunctionDef)None[source]
class ax.utils.flake8_plugins.docstring_checker.Error(lineno, col, message, type)[source]

Bases: tuple

property col

Alias for field number 1

property lineno

Alias for field number 0

property message

Alias for field number 2

property type

Alias for field number 3

ax.utils.flake8_plugins.docstring_checker.is_copy_doc_call(c)[source]

Tries to guess if this is a call to the copy_doc decorator. This is a purely syntactic check, so if the decorator was aliased as another name or wrapped in another function we will fail.

ax.utils.flake8_plugins.docstring_checker.new_error(errorid: str, msg: str)Callable[[_ast.AST], ax.utils.flake8_plugins.docstring_checker.Error][source]
ax.utils.flake8_plugins.docstring_checker.should_check(filename)[source]

Measurement

Synthetic Functions

class ax.utils.measurement.synthetic_functions.Aug_Branin[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Augmented Branin function (3-dimensional with infinitely many global minima).

class ax.utils.measurement.synthetic_functions.Aug_Hartmann6[source]

Bases: ax.utils.measurement.synthetic_functions.Hartmann6

Augmented Hartmann6 function (7-dimensional with 1 global minimum).

class ax.utils.measurement.synthetic_functions.Branin[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Branin function (2-dimensional with 3 global minima).
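For reference, the standard Branin test function and its three global minima can be written in a few lines of plain Python. This sketch uses the usual textbook constants and evaluates a single point; it does not call the Ax class.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard Branin test function with the usual constants."""
    a = 1.0
    b = 5.1 / (4.0 * math.pi ** 2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8.0 * math.pi)
    return a * (x2 - b * x1 ** 2 + c * x1 - r) ** 2 + s * (1.0 - t) * math.cos(x1) + s


# The three global minimizers, all with value 10 / (8 * pi), roughly 0.397887.
minimizers = [(-math.pi, 12.275), (math.pi, 2.275), (3.0 * math.pi, 2.475)]
minimum_values = [branin(x1, x2) for x1, x2 in minimizers]
```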

class ax.utils.measurement.synthetic_functions.FromBotorch(botorch_synthetic_function: botorch.test_functions.synthetic.SyntheticTestFunction)[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

property name
class ax.utils.measurement.synthetic_functions.Hartmann6[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Hartmann6 function (6-dimensional with 1 global minimum).

class ax.utils.measurement.synthetic_functions.SyntheticFunction[source]

Bases: abc.ABC

property domain
f(X: numpy.ndarray)Union[float, numpy.ndarray][source]

Synthetic function implementation.

Parameters

X (numpy.ndarray) – an n by d array, where n represents the number of observations and d is the dimensionality of the inputs.

Returns

an n-dimensional array.

Return type

numpy.ndarray

property fmax
property fmin
property maximums
property minimums
property name
property required_dimensionality
ax.utils.measurement.synthetic_functions.from_botorch(botorch_synthetic_function: botorch.test_functions.synthetic.SyntheticTestFunction)ax.utils.measurement.synthetic_functions.SyntheticFunction[source]

Utility to generate Ax synthetic functions from BoTorch synthetic functions.

ax.utils.measurement.synthetic_functions.informative_failure_on_none(func: Callable)Any[source]

Notebook

Plotting

ax.utils.notebook.plotting.init_notebook_plotting(offline=False)[source]

Initialize plotting in notebooks, either in online or offline mode.

ax.utils.notebook.plotting.render(plot_config: ax.plot.base.AxPlotConfig, inject_helpers=False)None[source]

Render plot config.

Report

Render

ax.utils.report.render.h2_html(text: str)str[source]

Embed text in subheading tag.

ax.utils.report.render.h3_html(text: str)str[source]

Embed text in subsubheading tag.

Embed text and reference address into link tag.

ax.utils.report.render.list_item_html(text: str)str[source]

Embed text in list element tag.

ax.utils.report.render.p_html(text: str)str[source]

Embed text in paragraph tag.

ax.utils.report.render.render_report_elements(experiment_name: str, html_elements: List[str], header: bool = True, offline: bool = False, notebook_env: bool = False)str[source]

Generate Ax HTML report for a given experiment from HTML elements.

Uses Jinja2 for template. Injects Plotly JS for graph rendering.

Example:

html_elements = [
    h2_html("Subsection with plot"),
    p_html("This is an example paragraph."),
    plot_html(plot_fitted(gp_model, 'perf_metric')),
    h2_html("Subsection with table"),
    pandas_html(data.df),
]
html = render_report_elements('My experiment', html_elements)
Parameters
  • experiment_name – the name of the experiment to use for title.

  • html_elements – list of HTML strings to render in report body.

  • header – if True, render experiment title as a header. Meant to be used for standalone reports (e.g. via email), as opposed to served on the front-end.

  • offline – if True, entire Plotly library is bundled with report.

  • notebook_env – if True, caps the report width to 700px for viewing in a notebook environment.

Returns

HTML string.

Return type

str

ax.utils.report.render.table_cell_html(text: str, width: Optional[str] = None)str[source]

Embed text or an HTML element into table cell tag.

ax.utils.report.render.table_heading_cell_html(text: str)str[source]

Embed text or an HTML element into table heading cell tag.

ax.utils.report.render.table_html(table_rows: List[str])str[source]

Embed list of HTML elements into table tag.

ax.utils.report.render.table_row_html(table_cells: List[str])str[source]

Embed list of HTML elements into table row tag.

ax.utils.report.render.unordered_list_html(list_items: List[str])str[source]

Embed list of html elements into an unordered list tag.

Stats

Statstools

ax.utils.stats.statstools.agresti_coull_sem(n_numer: Union[pandas.Series, numpy.ndarray, int], n_denom: Union[pandas.Series, numpy.ndarray, int], prior_successes: int = 2, prior_failures: int = 2)Union[numpy.ndarray, float][source]

Compute the Agresti-Coull style standard error for a binomial proportion.

Reference: Agresti, Alan, and Brent A. Coull. “Approximate Is Better than ‘Exact’ for Interval Estimation of Binomial Proportions.” The American Statistician, vol. 52, no. 2, 1998, pp. 119-126. JSTOR, www.jstor.org/stable/2685469.
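As a rough guide to what an Agresti-Coull style standard error looks like, the sketch below implements the textbook version: pseudo-counts are added to successes and failures, and the adjusted proportion and adjusted sample size are plugged into the usual binomial formula. Whether the Ax function uses the adjusted or the raw sample size in the denominator is not stated here, so treat that detail as an assumption of this sketch.

```python
import math


def agresti_coull_sem_sketch(
    n_numer: int, n_denom: int, prior_successes: int = 2, prior_failures: int = 2
) -> float:
    """Textbook Agresti-Coull standard error for one proportion (sketch only)."""
    n_adjusted = n_denom + prior_successes + prior_failures
    p_adjusted = (n_numer + prior_successes) / n_adjusted
    return math.sqrt(p_adjusted * (1.0 - p_adjusted) / n_adjusted)


# 8 successes out of 16: adjusted proportion 10 / 20 = 0.5, adjusted n = 20.
sem = agresti_coull_sem_sketch(8, 16)
```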

ax.utils.stats.statstools.inverse_variance_weight(means: numpy.ndarray, variances: numpy.ndarray, conflicting_noiseless: str = 'warn')Tuple[float, float][source]

Perform inverse variance weighting.

Parameters
  • means – The means of the observations.

  • variances – The variances of the observations.

  • conflicting_noiseless – How to handle the case of multiple observations with zero variance but different means. Options are “warn” (default), “ignore” or “raise”.
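Inverse variance weighting averages the means with weights proportional to 1 / variance, and the variance of the combined estimate is the reciprocal of the summed weights. The sketch below shows that arithmetic for strictly positive variances only; the zero-variance handling controlled by conflicting_noiseless is deliberately left out, and the function name is hypothetical.

```python
from typing import Sequence, Tuple


def inverse_variance_weight_sketch(
    means: Sequence[float], variances: Sequence[float]
) -> Tuple[float, float]:
    """Combine estimates by inverse variance weighting (positive variances only)."""
    if any(v <= 0 for v in variances):
        raise ValueError("This sketch requires strictly positive variances.")
    weights = [1.0 / v for v in variances]
    total_weight = sum(weights)
    combined_mean = sum(w * m for w, m in zip(weights, means)) / total_weight
    combined_variance = 1.0 / total_weight
    return combined_mean, combined_variance


equal_mean, equal_var = inverse_variance_weight_sketch([1.0, 3.0], [1.0, 1.0])
skewed_mean, skewed_var = inverse_variance_weight_sketch([0.0, 10.0], [1.0, 4.0])
```

In the second call the noisier observation gets a quarter of the weight, so the combined mean sits much closer to the precise one.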

ax.utils.stats.statstools.marginal_effects(df: pandas.DataFrame)pandas.DataFrame[source]

This method calculates the relative (in %) change in the outcome achieved by using any individual factor level versus randomizing across all factor levels. It does this by estimating a baseline under the experiment by marginalizing over all factors/levels. For each factor level, then, it conditions on that level for the individual factor and then marginalizes over all levels for all other factors.

Parameters

df – Dataframe containing columns named mean and sem. All other columns are assumed to be factors for which to calculate marginal effects.

Returns

A dataframe containing columns “Name”, “Level”, “Beta” and “SE” corresponding to the factor, level, effect and standard error. Results are relativized as percentage changes.

ax.utils.stats.statstools.positive_part_james_stein(means: Union[numpy.ndarray, List[float]], sems: Union[numpy.ndarray, List[float]])Tuple[numpy.ndarray, numpy.ndarray][source]

Estimation method for Positive-part James-Stein estimator.

This method takes a vector of K means (y_i) and standard errors (sigma_i) and calculates the positive-part James Stein estimator.

Resulting estimates are the shrunk means and standard errors. The positive part James-Stein estimator shrinks each constituent average to the grand average:

mu_hat_i = y_i - phi_i * y_i + phi_i * ybar

The variable phi_i determines the amount of shrinkage. For phi_i = 1, mu_hat is equal to ybar (the mean of all y_i), while for phi_i = 0, mu_hat is equal to y_i. It can be shown that restricting phi_i <= 1 dominates the unrestricted estimator, so this method restricts phi_i in this manner. The amount of shrinkage, phi_i, is determined by:

(K - 3) * sigma2_i / s2

That is, less shrinkage is applied when individual means are estimated with greater precision, and more shrinkage is applied when individual means are very tightly clustered together. We also restrict phi_i to never be larger than 1.

The variance of the mean estimator is:

(1 - phi_i) * sigma2_i + phi_i * sigma2_i / K + 2 * phi_i^2 * (y_i - ybar)^2 / (K - 3)

The first term is the variance component from y_i, the second term is the contribution from the mean of all y_i, and the third term is the contribution from the uncertainty in the sum of squared deviations of y_i from the mean of all y_i.

For more information, see https://ax.dev/docs/models.html#empirical-bayes-and-thompson-sampling.
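The formulas above can be traced through with a small pure-Python sketch. It interprets s2 as the sum of squared deviations of the y_i from ybar, which is an assumption since the text does not define s2, and it requires more than three arms so that K - 3 is positive. The function name is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def positive_part_james_stein_sketch(
    means: Sequence[float], sems: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Shrink each mean toward the grand mean, following the formulas above."""
    K = len(means)
    if K <= 3:
        raise ValueError("This sketch needs more than 3 arms.")
    ybar = sum(means) / K
    # Assumption: s2 is the sum of squared deviations from the grand mean.
    s2 = sum((y - ybar) ** 2 for y in means)
    mu_hat, sem_hat = [], []
    for y, sem in zip(means, sems):
        sigma2 = sem ** 2
        # Positive-part restriction: phi_i is capped at 1 (full shrinkage).
        phi = 1.0 if s2 == 0 else min(1.0, (K - 3) * sigma2 / s2)
        mu_hat.append(y - phi * y + phi * ybar)
        variance = (
            (1 - phi) * sigma2
            + phi * sigma2 / K
            + 2 * phi ** 2 * (y - ybar) ** 2 / (K - 3)
        )
        sem_hat.append(math.sqrt(variance))
    return mu_hat, sem_hat


# Five arms with unit standard errors: ybar = 3, s2 = 10, so phi_i = 0.2.
mu_hat, sem_hat = positive_part_james_stein_sketch([1.0, 2.0, 3.0, 4.0, 5.0], [1.0] * 5)
```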

Parameters
  • means – Means of each arm

  • sems – Standard errors of each arm

Returns

  • mu_hat_i – Empirical Bayes estimate of each arm’s mean

  • sem_i – Empirical Bayes estimate of each arm’s sem

Return type

Tuple[numpy.ndarray, numpy.ndarray]

ax.utils.stats.statstools.relativize(means_t: Union[numpy.ndarray, List[float], float], sems_t: Union[numpy.ndarray, List[float], float], mean_c: float, sem_c: float, bias_correction: bool = True, cov_means: Union[numpy.ndarray, List[float], float] = 0.0, as_percent: bool = False)Tuple[numpy.ndarray, numpy.ndarray][source]

Ratio estimator based on the delta method.

This uses the delta method (i.e. a Taylor series approximation) to estimate the mean and standard deviation of the sampling distribution of the ratio between test and control – that is, the sampling distribution of an estimator of the true population value under the assumption that the means in test and control have a known covariance:

(mu_t / mu_c) - 1.

Under a second-order Taylor expansion, the sampling distribution of the relative change in empirical means, which is m_t / m_c - 1, is approximately normally distributed with mean

[(mu_t - mu_c) / mu_c] - [(sigma_c)^2 * mu_t] / (mu_c)^3

and variance

(sigma_t / mu_c)^2 - 2 * mu_t * sigma_tc / mu_c^3 + [(sigma_c * mu_t)^2 / (mu_c)^4]

as the higher terms are assumed to be close to zero in the full Taylor series. To estimate these parameters, we plug in the empirical means and standard errors. This gives us the estimators:

[(m_t - m_c) / m_c] - [(s_c)^2 * m_t] / (m_c)^3

and

(s_t / m_c)^2 - 2 * m_t * s_tc / m_c^3 + [(s_c * m_t)^2 / (m_c)^4]

Note that the delta method does NOT take as input the empirical standard deviation of a metric, but rather the standard error of the mean of that metric – that is, the standard deviation of the metric after division by the square root of the total number of observations.
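The plug-in estimators above translate directly into code. The scalar sketch below follows those two formulas term by term and assumes a positive control mean. Applying the second-order term only when `bias_correction` is true is an assumption of this sketch, and the function name is hypothetical.

```python
import math
from typing import Tuple


def relativize_sketch(
    m_t: float,
    s_t: float,
    m_c: float,
    s_c: float,
    s_tc: float = 0.0,
    bias_correction: bool = True,
    as_percent: bool = False,
) -> Tuple[float, float]:
    """Delta-method estimate of the relative change m_t / m_c - 1 (sketch only)."""
    rel_hat = (m_t - m_c) / m_c
    if bias_correction:
        # Second-order correction term: (s_c)^2 * m_t / (m_c)^3.
        rel_hat -= s_c ** 2 * m_t / m_c ** 3
    variance = (
        (s_t / m_c) ** 2
        - 2 * m_t * s_tc / m_c ** 3
        + (s_c * m_t) ** 2 / m_c ** 4
    )
    sem_hat = math.sqrt(variance)
    scale = 100.0 if as_percent else 1.0
    return rel_hat * scale, sem_hat * scale


# Noiseless control: the relative change is 10% with a standard error of 1%.
rel_a, sem_a = relativize_sketch(m_t=110.0, s_t=1.0, m_c=100.0, s_c=0.0)
# Noisy control: the bias correction lowers the estimate from 1.0 to 0.98.
rel_b, sem_b = relativize_sketch(m_t=2.0, s_t=0.0, m_c=1.0, s_c=0.1)
```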

Parameters
  • means_t – Sample means (test)

  • sems_t – Sample standard errors of the means (test)

  • mean_c – Sample mean (control)

  • sem_c – Sample standard error of the mean (control)

  • cov_means – Sample covariance between test and control

  • as_percent – If true, return results in percent (* 100)

Returns

  • rel_hat – Inferred means of the sampling distribution of the relative change (mean_t / mean_c) - 1

  • sem_hat – Inferred standard deviation of the sampling distribution of rel_hat, i.e. the standard error.

Return type

Tuple[numpy.ndarray, numpy.ndarray]

ax.utils.stats.statstools.relativize_data(data: ax.core.data.Data, status_quo_name: str = 'status_quo', as_percent: bool = False, include_sq: bool = False)ax.core.data.Data[source]

Relativize a data object w.r.t. a status_quo arm.

Parameters
  • data – The data object to be relativized.

  • status_quo_name – The name of the status_quo arm.

  • as_percent – If True, return results as percentage change.

  • include_sq – Include status quo in final df.

Returns

The new data object with the relativized metrics (excluding the status_quo arm)

ax.utils.stats.statstools.total_variance(means: numpy.ndarray, variances: numpy.ndarray, sample_sizes: numpy.ndarray)float[source]

Compute total variance.

Storage

Deletion

Testing

Backend Scheduler

Backend Simulator

class ax.utils.testing.backend_simulator.BackendSimulator(options: Optional[ax.utils.testing.backend_simulator.BackendSimulatorOptions] = None, queued: Optional[List[ax.utils.testing.backend_simulator.SimTrial]] = None, running: Optional[List[ax.utils.testing.backend_simulator.SimTrial]] = None, failed: Optional[List[ax.utils.testing.backend_simulator.SimTrial]] = None, completed: Optional[List[ax.utils.testing.backend_simulator.SimTrial]] = None, verbose_logging: bool = True)[source]

Bases: object

Simulator for a backend deployment with concurrent dispatch and a queue.

property all_trials

All trials on the simulator.

classmethod from_state(state: ax.utils.testing.backend_simulator.BackendSimulatorState)[source]

Construct a simulator from a state.

Parameters

state – A BackendSimulatorState to set the simulator to.

Returns

A BackendSimulator with the desired state.

get_sim_trial_by_index(trial_index: int)Optional[ax.utils.testing.backend_simulator.SimTrial][source]

Get a SimTrial by trial_index.

Parameters

trial_index – The index of the trial to return.

Returns

A SimTrial with the index trial_index or None if not found.

lookup_trial_index_status(trial_index: int)Optional[ax.core.base_trial.TrialStatus][source]

Lookup the trial status of a trial_index.

Parameters

trial_index – The index of the trial to check.

Returns

A TrialStatus.

new_trial(trial: ax.utils.testing.backend_simulator.SimTrial, status: ax.core.base_trial.TrialStatus)None[source]

Register a trial into the simulator.

Parameters
  • trial – A new trial to add.

  • status – The status of the new trial, either STAGED (add to self._queued) or RUNNING (add to self._running).

property num_completed

The number of completed trials.

property num_failed

The number of failed trials.

property num_queued

The number of queued trials (to run as soon as capacity is available).

property num_running

The number of currently running trials.

reset()None[source]

Reset the simulator.

run_trial(trial_index: int, runtime: float)None[source]

Run a simulated trial.

Parameters
  • trial_index – The index of the trial (usually the Ax trial index)

  • runtime – The runtime of the simulation. Typically sampled from the runtime model of a simulation model.

Internally, the runtime is scaled by the time_scaling factor, so that the simulation can run arbitrarily faster than the underlying evaluation.
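
The scaling mentioned here matches the relation given for the time_scaling option of BackendSimulatorOptions (simulation time = runtime / time_scaling). A minimal sketch of that arithmetic follows; the helper name scaled_runtime is hypothetical and is not part of Ax:

```python
def scaled_runtime(runtime: float, time_scaling: float = 1.0) -> float:
    """Return the simulated duration of a trial whose real runtime is `runtime`.

    Hypothetical helper, not part of Ax: it only mirrors the documented
    relation simulation_time = runtime / time_scaling.
    """
    if time_scaling <= 0:
        raise ValueError("time_scaling must be positive")
    return runtime / time_scaling


# A one-hour evaluation simulated 60x faster occupies one minute of simulation time.
one_minute = scaled_runtime(3600.0, time_scaling=60.0)
```

With the default time_scaling of 1.0 the simulated duration equals the real runtime.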

state()ax.utils.testing.backend_simulator.BackendSimulatorState[source]

Return a BackendSimulatorState containing the state of the simulator.

status()ax.utils.testing.backend_simulator.SimStatus[source]

Return the internal status of the simulator.

Returns

A SimStatus object representing the current simulator state.

stop_trial(trial_index: int)None[source]

Stop a simulated trial by setting the completed time to the current time.

Parameters

trial_index – The index of the trial to stop.

property time

The current time.

update()None[source]

Update the state of the simulator.

property use_internal_clock

Whether or not we are using the internal clock.
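
To illustrate what "concurrent dispatch and a queue" can mean in practice, here is a toy model written with the standard library only. It is a sketch under stated assumptions, not the BackendSimulator implementation: the class names ToyTrial and ToySimulator are hypothetical, the clock is advanced by passing a time to update, and a promoted trial starts at the time of the update.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyTrial:
    # Hypothetical stand-in for SimTrial; not the Ax class.
    trial_index: int
    sim_runtime: float
    sim_start_time: Optional[float] = None
    sim_completed_time: Optional[float] = None


@dataclass
class ToySimulator:
    # Hypothetical toy model of a capacity-limited backend with a queue.
    max_concurrency: int = 1
    time: float = 0.0
    queued: List[ToyTrial] = field(default_factory=list)
    running: List[ToyTrial] = field(default_factory=list)
    completed: List[ToyTrial] = field(default_factory=list)

    def run_trial(self, trial_index: int, runtime: float) -> None:
        """Start the trial if capacity allows, otherwise queue it."""
        trial = ToyTrial(trial_index, runtime)
        if len(self.running) < self.max_concurrency:
            trial.sim_start_time = self.time
            self.running.append(trial)
        else:
            self.queued.append(trial)

    def update(self, now: float) -> None:
        """Advance the clock, retire finished trials, then promote queued ones."""
        self.time = now
        still_running = []
        for trial in self.running:
            end_time = trial.sim_start_time + trial.sim_runtime
            if end_time <= now:
                trial.sim_completed_time = end_time
                self.completed.append(trial)
            else:
                still_running.append(trial)
        self.running = still_running
        # Assumption: a promoted trial starts at the time of this update.
        while self.queued and len(self.running) < self.max_concurrency:
            promoted = self.queued.pop(0)
            promoted.sim_start_time = now
            self.running.append(promoted)
```

With max_concurrency=1, a second call to run_trial places the trial in the queue, and it only begins running once an update observes that the first trial has finished.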

class ax.utils.testing.backend_simulator.BackendSimulatorOptions(max_concurrency: int = 1, time_scaling: float = 1.0, failure_rate: float = 0.0, internal_clock: Optional[float] = None, use_update_as_start_time: bool = False)[source]

Bases: object

Settings for the BackendSimulator.

Parameters
  • max_concurrency – The maximum number of trials that can be run in parallel.

  • time_scaling – The factor to scale down the runtime of the tasks by. If runtime is the actual runtime of a trial, the simulation time will be runtime / time_scaling.

  • failure_rate – The rate at which trials fail. For now, trials fail independently, each decided by a coin flip based on that rate.

  • internal_clock – The initial state of the internal clock. If None, the simulator uses time.time() as the clock.

  • use_update_as_start_time – Whether the start time of a new trial should be logged as the current time (at time of update) or as the end time of the previous trial. This makes sense when using the internal clock and the BackendSimulator is simulated forward by an external process (such as Scheduler).

failure_rate: float = 0.0
internal_clock: Optional[float] = None
max_concurrency: int = 1
time_scaling: float = 1.0
use_update_as_start_time: bool = False
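
The independent coin flip described for failure_rate can be sketched as one Bernoulli draw per trial. This is an illustration only: the function name simulate_failures and the seed argument are hypothetical, and the way Ax draws its random numbers may differ.

```python
import random
from typing import List


def simulate_failures(num_trials: int, failure_rate: float, seed: int = 0) -> List[bool]:
    """Return one independent draw per trial; True means the trial failed.

    Hypothetical helper, not part of Ax.
    """
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    # rng.random() lies in [0, 1), so a rate of 0.0 never fails and 1.0 always fails.
    return [rng.random() < failure_rate for _ in range(num_trials)]
```

Because each draw is independent, the outcome of one trial carries no information about the next.
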
class ax.utils.testing.backend_simulator.BackendSimulatorState(options: ax.utils.testing.backend_simulator.BackendSimulatorOptions, verbose_logging: bool, queued: List[Dict[str, Optional[float]]], running: List[Dict[str, Optional[float]]], failed: List[Dict[str, Optional[float]]], completed: List[Dict[str, Optional[float]]])[source]

Bases: object

State of the BackendSimulator.

Parameters
  • options – The BackendSimulatorOptions associated with this simulator.

  • verbose_logging – Whether the simulator is using verbose logging.

  • queued – Currently queued trials.

  • running – Currently running trials.

  • failed – Currently failed trials.

  • completed – Currently completed trials.

completed: List[Dict[str, Optional[float]]]
failed: List[Dict[str, Optional[float]]]
options: ax.utils.testing.backend_simulator.BackendSimulatorOptions
queued: List[Dict[str, Optional[float]]]
running: List[Dict[str, Optional[float]]]
verbose_logging: bool
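
The state stores each trial list as plain dictionaries (List[Dict[str, Optional[float]]]), which suggests a snapshot-and-restore round trip. The sketch below shows one way such a round trip can work with dataclasses; ToyTrial, snapshot, and restore are hypothetical names, and this is not how Ax necessarily serializes its trials.

```python
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional


@dataclass
class ToyTrial:
    # Hypothetical stand-in for SimTrial with a reduced set of fields.
    trial_index: int
    sim_runtime: float
    sim_start_time: Optional[float] = None


def snapshot(trials: List[ToyTrial]) -> List[Dict[str, Optional[float]]]:
    """Convert trial objects into plain dictionaries."""
    return [asdict(trial) for trial in trials]


def restore(records: List[Dict[str, Optional[float]]]) -> List[ToyTrial]:
    """Rebuild trial objects from the dictionaries produced by snapshot."""
    return [ToyTrial(**record) for record in records]
```

Keeping the snapshot free of object references makes it straightforward to store and to rebuild later.
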
class ax.utils.testing.backend_simulator.SimStatus(queued: List[int], running: List[int], failed: List[int], time_remaining: List[float], completed: List[int])[source]

Bases: object

Container for status of the simulation.

queued

List of indices of queued trials.

Type

List[int]

running

List of indices of running trials.

Type

List[int]

failed

List of indices of failed trials.

Type

List[int]

time_remaining

List of sim time remaining for running trials.

Type

List[float]

completed

List of indices of completed trials.

Type

List[int]

completed: List[int]
failed: List[int]
queued: List[int]
running: List[int]
time_remaining: List[float]
class ax.utils.testing.backend_simulator.SimTrial(trial_index: int, sim_runtime: float, sim_start_time: Optional[float] = None, sim_queued_time: Optional[float] = None, sim_completed_time: Optional[float] = None)[source]

Bases: object

Container for the simulation tasks.

trial_index

The index of the trial (should match Ax trial index).

Type

int

sim_runtime

The runtime of the trial (sampled at creation).

Type

float

sim_start_time

When the trial started running (or exited the queued state).

Type

Optional[float]

sim_queued_time

When the trial was initially queued.

Type

Optional[float]

sim_completed_time

When the trial was marked as completed. Currently, this is used by an early-stopper via stop_trial.

Type

Optional[float]

sim_completed_time: Optional[float] = None
sim_queued_time: Optional[float] = None
sim_runtime: float
sim_start_time: Optional[float] = None
trial_index: int
ax.utils.testing.backend_simulator.format(trial_list: List[Dict[str, Optional[float]]])str[source]

Helper function for formatting a list.

Benchmark Stubs

Core Stubs

class ax.utils.testing.core_stubs.DummyEarlyStoppingStrategy(early_stop_trials: Optional[Dict[int, str]] = None)[source]

Bases: ax.early_stopping.strategies.base.BaseEarlyStoppingStrategy

should_stop_trials_early(trial_indices: Set[int], experiment: ax.core.experiment.Experiment, **kwargs: Dict[str, Any])Dict[int, Optional[str]][source]

Decide whether to complete trials before evaluation is fully concluded.

Typical examples include stopping a machine learning model’s training, or halting the gathering of samples before some planned number are collected.

Parameters
  • trial_indices – Indices of candidate trials to stop early.

  • experiment – Experiment that contains the trials and other contextual data.

Returns

A dictionary mapping trial indices that should be early stopped to (optional) messages with the associated reason.
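
The shape of the return value can be illustrated with a standalone function. This is a sketch under an assumption: it stops exactly those candidate trials that appear in a preconfigured early_stop_trials mapping. Whether DummyEarlyStoppingStrategy applies this rule is not stated here.

```python
from typing import Dict, Optional, Set


def should_stop_trials_early(
    trial_indices: Set[int],
    early_stop_trials: Optional[Dict[int, str]] = None,
) -> Dict[int, Optional[str]]:
    """Map each candidate index that should stop to its reason.

    Hypothetical rule, not the Ax implementation: a candidate stops only if
    it is listed in `early_stop_trials`.
    """
    configured = early_stop_trials or {}
    return {i: configured[i] for i in sorted(trial_indices) if i in configured}
```

Trials that should keep running are simply absent from the returned dictionary.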

class ax.utils.testing.core_stubs.DummyGlobalStoppingStrategy(min_trials: int, trial_to_stop: int)[source]

Bases: ax.global_stopping.strategies.base.BaseGlobalStoppingStrategy

A dummy Global Stopping Strategy which stops the optimization after a pre-specified number of trials are completed.

should_stop_optimization(experiment: ax.core.experiment.Experiment, **kwargs: Dict[str, Any])[source]

Decide whether to stop optimization.

Typical examples include stopping the optimization loop when the objective appears to not improve anymore.

Parameters

experiment – Experiment that contains the trials and other contextual data.

Returns

A Tuple with a boolean determining whether the optimization should stop, and a str declaring the reason for stopping.
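
A stop-after-N-completed-trials rule returning the (bool, str) pair described above might look like the following. The sketch is a simplification: it takes the number of completed trials directly rather than an Experiment, and it ignores the trial_to_stop argument, whose meaning is not described here.

```python
from typing import Tuple


def should_stop_optimization(num_completed: int, min_trials: int) -> Tuple[bool, str]:
    """Return (stop, reason) for a simple completed-trial threshold.

    Hypothetical standalone function, not the Ax method.
    """
    if num_completed >= min_trials:
        reason = f"Stopping: {num_completed} trials completed (threshold {min_trials})."
        return True, reason
    return False, ""
```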

class ax.utils.testing.core_stubs.TestTrial(experiment: core.experiment.Experiment, trial_type: Optional[str] = None, ttl_seconds: Optional[int] = None, index: Optional[int] = None)[source]

Bases: ax.core.base_trial.BaseTrial

Trial class to test unsupported trial type error

abandoned_arms()str[source]

All abandoned arms, associated with this trial.

property arms
arms_by_name()str[source]
generator_runs()str[source]

All generator runs associated with this trial.

ax.utils.testing.core_stubs.get_abandoned_arm()ax.core.batch_trial.AbandonedArm[source]
ax.utils.testing.core_stubs.get_acquisition_function_type()Type[botorch.acquisition.acquisition.AcquisitionFunction][source]
ax.utils.testing.core_stubs.get_acquisition_type()Type[ax.models.torch.botorch_modular.acquisition.Acquisition][source]
ax.utils.testing.core_stubs.get_and_early_stopping_strategy()ax.early_stopping.strategies.logical.AndEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_arm()ax.core.arm.Arm[source]
ax.utils.testing.core_stubs.get_arm_weights1()MutableMapping[ax.core.arm.Arm, float][source]
ax.utils.testing.core_stubs.get_arm_weights2()MutableMapping[ax.core.arm.Arm, float][source]
ax.utils.testing.core_stubs.get_arms()List[ax.core.arm.Arm][source]
ax.utils.testing.core_stubs.get_arms_from_dict(arm_weights_dict: MutableMapping[ax.core.arm.Arm, float])List[ax.core.arm.Arm][source]
ax.utils.testing.core_stubs.get_augmented_branin_metric(name='aug_branin')ax.metrics.branin.AugmentedBraninMetric[source]
ax.utils.testing.core_stubs.get_augmented_branin_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_augmented_branin_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_augmented_hartmann_metric(name='aug_hartmann')ax.metrics.hartmann6.AugmentedHartmann6Metric[source]
ax.utils.testing.core_stubs.get_augmented_hartmann_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_augmented_hartmann_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_batch_trial(abandon_arm: bool = True, experiment: Optional[ax.core.experiment.Experiment] = None)ax.core.batch_trial.BatchTrial[source]
ax.utils.testing.core_stubs.get_batch_trial_with_repeated_arms(num_repeated_arms: int)ax.core.batch_trial.BatchTrial[source]

Create a batch that contains both new arms and N arms from the last existing trial in the experiment, where N is equal to the input argument ‘num_repeated_arms’.

ax.utils.testing.core_stubs.get_botorch_model()ax.models.torch.botorch_modular.model.BoTorchModel[source]
ax.utils.testing.core_stubs.get_botorch_model_with_default_acquisition_class()ax.models.torch.botorch_modular.model.BoTorchModel[source]
ax.utils.testing.core_stubs.get_branin_arms(n: int, seed: int)List[ax.core.arm.Arm][source]
ax.utils.testing.core_stubs.get_branin_data(trial_indices: Optional[Iterable[int]] = None, trials: Optional[Iterable[ax.core.trial.Trial]] = None)ax.core.data.Data[source]
ax.utils.testing.core_stubs.get_branin_data_batch(batch: ax.core.batch_trial.BatchTrial)ax.core.data.Data[source]
ax.utils.testing.core_stubs.get_branin_data_multi_objective(trial_indices: Optional[Iterable[int]] = None, num_objectives: int = 2)ax.core.data.Data[source]
ax.utils.testing.core_stubs.get_branin_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_trial: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False, search_space: Optional[ax.core.search_space.SearchSpace] = None, minimize: bool = False, named: bool = True, with_completed_trial: bool = False)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_branin_experiment_with_multi_objective(has_optimization_config: bool = True, has_objective_thresholds: bool = False, with_batch: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False, num_objectives: int = 2)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_branin_experiment_with_timestamp_map_metric(with_status_quo: bool = False, rate: Optional[float] = None)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_branin_metric(name='branin')ax.metrics.branin.BraninMetric[source]
ax.utils.testing.core_stubs.get_branin_multi_objective(num_objectives: int = 2)ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_branin_multi_objective_optimization_config(has_objective_thresholds: bool = False, num_objectives: int = 2)ax.core.optimization_config.MultiObjectiveOptimizationConfig[source]
ax.utils.testing.core_stubs.get_branin_objective(minimize: bool = False)ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_branin_optimization_config(minimize: bool = False)ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_branin_outcome_constraint()ax.core.outcome_constraint.OutcomeConstraint[source]
ax.utils.testing.core_stubs.get_branin_search_space(with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, with_str_choice_param: bool = False)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_branin_with_multi_task(with_multi_objective: bool = False)[source]
ax.utils.testing.core_stubs.get_choice_parameter()ax.core.parameter.ChoiceParameter[source]
ax.utils.testing.core_stubs.get_data(metric_name: str = 'ax_test_metric', trial_index: int = 0, num_non_sq_arms: int = 4, include_sq: bool = True)ax.core.data.Data[source]
ax.utils.testing.core_stubs.get_default_scheduler_options()ax.service.utils.scheduler_options.SchedulerOptions[source]
ax.utils.testing.core_stubs.get_discrete_search_space()ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_experiment(with_status_quo: bool = True)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_batch_and_single_trial()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_batch_trial()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_data()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_map_data()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_map_data_type()[source]
ax.utils.testing.core_stubs.get_experiment_with_multi_objective()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_repeated_arms(num_repeated_arms: int)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_scalarized_objective_and_outcome_constraint()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_experiment_with_trial_with_ttl()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_factorial_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False)ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_factorial_metric(name: str = 'success_metric')ax.metrics.factorial.FactorialMetric[source]
ax.utils.testing.core_stubs.get_factorial_search_space()ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_fixed_parameter()ax.core.parameter.FixedParameter[source]
ax.utils.testing.core_stubs.get_gamma_prior()gpytorch.priors.torch_priors.GammaPrior[source]
ax.utils.testing.core_stubs.get_generator_run()ax.core.generator_run.GeneratorRun[source]
ax.utils.testing.core_stubs.get_generator_run2()ax.core.generator_run.GeneratorRun[source]
ax.utils.testing.core_stubs.get_hartmann_metric(name='hartmann')ax.metrics.hartmann6.Hartmann6Metric[source]
ax.utils.testing.core_stubs.get_hartmann_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_hartmann_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_hartmann_search_space(with_fidelity_parameter: bool = False)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_hierarchical_search_space()ax.core.search_space.HierarchicalSearchSpace[source]
ax.utils.testing.core_stubs.get_hierarchical_search_space_experiment()ax.core.experiment.Experiment[source]
ax.utils.testing.core_stubs.get_interval()gpytorch.constraints.constraints.Interval[source]
ax.utils.testing.core_stubs.get_l2_reg_weight_parameter()ax.core.parameter.RangeParameter[source]
ax.utils.testing.core_stubs.get_large_factorial_search_space()ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_large_ordinal_search_space(n_ordinal_choice_parameters, n_continuous_range_parameters)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_list_surrogate()ax.models.torch.botorch_modular.surrogate.Surrogate[source]
ax.utils.testing.core_stubs.get_lr_parameter()ax.core.parameter.RangeParameter[source]
ax.utils.testing.core_stubs.get_map_data(trial_index: int = 0)ax.core.map_data.MapData[source]
ax.utils.testing.core_stubs.get_map_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_map_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_metric()ax.core.metric.Metric[source]
ax.utils.testing.core_stubs.get_mll_type()Type[gpytorch.mlls.marginal_log_likelihood.MarginalLogLikelihood][source]
ax.utils.testing.core_stubs.get_model_covariance()Dict[str, Dict[str, List[float]]][source]
ax.utils.testing.core_stubs.get_model_mean()Dict[str, List[float]][source]
ax.utils.testing.core_stubs.get_model_parameter()ax.core.parameter.ChoiceParameter[source]
ax.utils.testing.core_stubs.get_model_predictions()Tuple[Dict[str, List[float]], Dict[str, Dict[str, List[float]]]][source]
ax.utils.testing.core_stubs.get_model_predictions_per_arm()Dict[str, Tuple[Dict[str, float], Optional[Dict[str, Dict[str, float]]]]][source]
ax.utils.testing.core_stubs.get_model_type()Type[botorch.models.model.Model][source]
ax.utils.testing.core_stubs.get_multi_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_multi_objective_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_multi_type_experiment(add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10)ax.core.multi_type_experiment.MultiTypeExperiment[source]
ax.utils.testing.core_stubs.get_multi_type_experiment_with_multi_objective(add_trials: bool = False)ax.core.multi_type_experiment.MultiTypeExperiment[source]
ax.utils.testing.core_stubs.get_non_monolithic_branin_moo_data()ax.core.data.Data[source]
ax.utils.testing.core_stubs.get_num_boost_rounds_parameter()ax.core.parameter.RangeParameter[source]
ax.utils.testing.core_stubs.get_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_objective_threshold(metric_name: str = 'm1', bound=-0.25, comparison_op: ax.core.types.ComparisonOp = <ComparisonOp.GEQ: 0>)ax.core.outcome_constraint.ObjectiveThreshold[source]
ax.utils.testing.core_stubs.get_optimization_config()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_optimization_config_no_constraints()ax.core.optimization_config.OptimizationConfig[source]
ax.utils.testing.core_stubs.get_or_early_stopping_strategy()ax.early_stopping.strategies.logical.OrEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_order_constraint()ax.core.parameter_constraint.OrderConstraint[source]
ax.utils.testing.core_stubs.get_ordered_choice_parameter()ax.core.parameter.ChoiceParameter[source]
ax.utils.testing.core_stubs.get_outcome_constraint()ax.core.outcome_constraint.OutcomeConstraint[source]
ax.utils.testing.core_stubs.get_parameter_constraint(param_x: str = 'x', param_y: str = 'w')ax.core.parameter_constraint.ParameterConstraint[source]
ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy()ax.early_stopping.strategies.percentile.PercentileEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy_with_non_objective_metric_name()ax.early_stopping.strategies.percentile.PercentileEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_percentile_early_stopping_strategy_with_true_objective_metric_name()ax.early_stopping.strategies.percentile.PercentileEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_range_parameter()ax.core.parameter.RangeParameter[source]
ax.utils.testing.core_stubs.get_range_parameter2()ax.core.parameter.RangeParameter[source]
ax.utils.testing.core_stubs.get_robust_search_space(lb: float = 0.0, ub: float = 5.0, multivariate: bool = False, use_discrete: bool = False)ax.core.search_space.RobustSearchSpace[source]
ax.utils.testing.core_stubs.get_scalarized_objective()ax.core.objective.Objective[source]
ax.utils.testing.core_stubs.get_scalarized_outcome_constraint()ax.core.outcome_constraint.ScalarizedOutcomeConstraint[source]
ax.utils.testing.core_stubs.get_scheduler_options_batch_trial()ax.service.utils.scheduler_options.SchedulerOptions[source]
ax.utils.testing.core_stubs.get_search_space()ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_search_space_for_range_value(min: float = 3.0, max: float = 6.0)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_search_space_for_range_values(min: float = 3.0, max: float = 6.0)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_search_space_for_value(val: float = 3.0)ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_small_discrete_search_space()ax.core.search_space.SearchSpace[source]
ax.utils.testing.core_stubs.get_status_quo()ax.core.arm.Arm[source]
ax.utils.testing.core_stubs.get_sum_constraint1()ax.core.parameter_constraint.SumConstraint[source]
ax.utils.testing.core_stubs.get_sum_constraint2()ax.core.parameter_constraint.SumConstraint[source]
ax.utils.testing.core_stubs.get_surrogate()ax.models.torch.botorch_modular.surrogate.Surrogate[source]
ax.utils.testing.core_stubs.get_synthetic_runner()ax.runners.synthetic.SyntheticRunner[source]
ax.utils.testing.core_stubs.get_task_choice_parameter()ax.core.parameter.ChoiceParameter[source]
ax.utils.testing.core_stubs.get_threshold_early_stopping_strategy()ax.early_stopping.strategies.threshold.ThresholdEarlyStoppingStrategy[source]
ax.utils.testing.core_stubs.get_trial()ax.core.trial.Trial[source]
ax.utils.testing.core_stubs.get_weights()List[float][source]
ax.utils.testing.core_stubs.get_weights_from_dict(arm_weights_dict: MutableMapping[ax.core.arm.Arm, float])List[float][source]
ax.utils.testing.core_stubs.get_winsorization_config()ax.modelbridge.transforms.winsorize.WinsorizationConfig[source]

Modeling Stubs

ax.utils.testing.modeling_stubs.get_experiment_for_value()ax.core.experiment.Experiment[source]
ax.utils.testing.modeling_stubs.get_generation_strategy(with_experiment: bool = False, with_callable_model_kwarg: bool = True)ax.modelbridge.generation_strategy.GenerationStrategy[source]
ax.utils.testing.modeling_stubs.get_observation(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation1(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation1trans(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation2(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation2trans(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation_features()ax.core.observation.ObservationFeatures[source]
ax.utils.testing.modeling_stubs.get_observation_status_quo0(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_observation_status_quo1(first_metric_name: str = 'a', second_metric_name='b')ax.core.observation.Observation[source]
ax.utils.testing.modeling_stubs.get_transform_type()Type[ax.modelbridge.transforms.base.Transform][source]
class ax.utils.testing.modeling_stubs.transform_1(search_space: Optional[SearchSpace], observation_features: List[ObservationFeatures], observation_data: List[ObservationData], modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, config: Optional[TConfig] = None)[source]

Bases: ax.modelbridge.transforms.base.Transform

config: TConfig
modelbridge: Optional[modelbridge_module.base.ModelBridge]
transform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationData][source]

Transform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

This takes in observation_features, so that data transforms can be conditional on features, but observation_features are not mutated.

Parameters
  • observation_data – Observation data

  • observation_features – Corresponding observation features

Returns: transformed observation data

transform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationFeatures][source]

Transform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features – Observation features

Returns: transformed observation features

transform_optimization_config(optimization_config: ax.core.optimization_config.OptimizationConfig, modelbridge: Optional[ax.modelbridge.base.ModelBridge], fixed_features: ax.core.observation.ObservationFeatures)ax.core.optimization_config.OptimizationConfig[source]

Transform optimization config.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

optimization_config – The optimization config

Returns: transformed optimization config.

transform_search_space(search_space: ax.core.search_space.SearchSpace)ax.core.search_space.SearchSpace[source]

Transform search space.

The transforms are typically done in-place. This calls two private methods, _transform_search_space, which transforms the core search space attributes, and _transform_parameter_distributions, which transforms the distributions when using a RobustSearchSpace.

Parameters

search_space – The search space

Returns: transformed search space.

untransform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationData][source]

Untransform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters
  • observation_data – Observation data, in transformed space

  • observation_features – Corresponding observation features, in same space.

Returns: observation data in original space.

untransform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationFeatures][source]

Untransform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features – Observation features in the transformed space

Returns: observation features in the original space

class ax.utils.testing.modeling_stubs.transform_2(search_space: Optional[SearchSpace], observation_features: List[ObservationFeatures], observation_data: List[ObservationData], modelbridge: Optional[modelbridge_module.base.ModelBridge] = None, config: Optional[TConfig] = None)[source]

Bases: ax.modelbridge.transforms.base.Transform

config: TConfig
modelbridge: Optional[modelbridge_module.base.ModelBridge]
transform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationData][source]

Transform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

This takes in observation_features, so that data transforms can be conditional on features, but observation_features are not mutated.

Parameters
  • observation_data – Observation data

  • observation_features – Corresponding observation features

Returns: transformed observation data

transform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationFeatures][source]

Transform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features – Observation features

Returns: transformed observation features

transform_optimization_config(optimization_config: ax.core.optimization_config.OptimizationConfig, modelbridge: Optional[ax.modelbridge.base.ModelBridge], fixed_features: ax.core.observation.ObservationFeatures)ax.core.optimization_config.OptimizationConfig[source]

Transform optimization config.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

optimization_config – The optimization config

Returns: transformed optimization config.

transform_search_space(search_space: ax.core.search_space.SearchSpace)ax.core.search_space.SearchSpace[source]

Transform search space.

The transforms are typically done in-place. This calls two private methods, _transform_search_space, which transforms the core search space attributes, and _transform_parameter_distributions, which transforms the distributions when using a RobustSearchSpace.

Parameters

search_space – The search space

Returns: transformed search space.

untransform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationData][source]

Untransform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters
  • observation_data – Observation data, in transformed space

  • observation_features – Corresponding observation features, in same space.

Returns: observation data in original space.

untransform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures])List[ax.core.observation.ObservationFeatures][source]

Untransform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features – Observation features in the transformed space

Returns: observation features in the original space
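
The pattern described for these methods (an identity default, with work "typically done in-place") can be sketched on plain lists of floats. The classes IdentityTransform and ShiftTransform are hypothetical and much simpler than the Ax Transform interface; they only show how a transform and its matching untransform round-trip a value.

```python
from typing import List


class IdentityTransform:
    """Hypothetical base class: every method returns its input unchanged."""

    def transform(self, values: List[float]) -> List[float]:
        return values

    def untransform(self, values: List[float]) -> List[float]:
        return values


class ShiftTransform(IdentityTransform):
    """Hypothetical subclass that shifts values in place by a fixed offset."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def transform(self, values: List[float]) -> List[float]:
        for i in range(len(values)):
            values[i] += self.offset  # mutate the caller's list in place
        return values

    def untransform(self, values: List[float]) -> List[float]:
        for i in range(len(values)):
            values[i] -= self.offset  # undo the shift, again in place
        return values
```

Because the work happens in place, the returned list is the same object that was passed in; callers who need the original values should copy them first.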

Mocking

ax.utils.testing.mock.fast_botorch_optimize(f: Callable)Callable[source]

Wraps f in the fast_botorch_optimize_context_manager for use as a decorator.

ax.utils.testing.mock.fast_botorch_optimize_context_manager()Generator[None, None, None][source]

A context manager to force botorch to speed up optimization. Currently, the primary tactic is to force the underlying scipy methods to stop after just one iteration.
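
The relationship between the two functions above (a context manager, plus a decorator that wraps a function in it) is a general Python pattern. The sketch below reproduces the pattern with the standard library only; SETTINGS, fast_optimize_context, and fast_optimize are hypothetical stand-ins and do not touch botorch or scipy.

```python
import functools
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator

# Hypothetical setting standing in for an optimizer's iteration cap.
SETTINGS: Dict[str, int] = {"maxiter": 200}


@contextmanager
def fast_optimize_context() -> Generator[None, None, None]:
    """Temporarily cap iterations at one, restoring the old value on exit."""
    previous = SETTINGS["maxiter"]
    SETTINGS["maxiter"] = 1
    try:
        yield
    finally:
        SETTINGS["maxiter"] = previous


def fast_optimize(f: Callable) -> Callable:
    """Wrap f so that every call runs inside fast_optimize_context."""

    @functools.wraps(f)
    def inner(*args: Any, **kwargs: Any) -> Any:
        with fast_optimize_context():
            return f(*args, **kwargs)

    return inner


@fast_optimize
def current_cap() -> int:
    # Reads the setting while the context manager is active.
    return SETTINGS["maxiter"]
```

The try/finally in the context manager ensures the original setting is restored even if the wrapped function raises.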

Test Init Files

class ax.utils.testing.test_init_files.InitTest(methodName: str = 'runTest')[source]

Bases: ax.utils.common.testutils.TestCase

testInitFiles()None[source]

__init__.py files are necessary when not using buck targets

Torch Stubs

ax.utils.testing.torch_stubs.get_optimizer_kwargs()Dict[str, int][source]
ax.utils.testing.torch_stubs.get_torch_test_data(dtype=torch.float32, cuda: bool = False, constant_noise: bool = True, task_features=None, offset: float = 0.0)[source]

Unittest Conventions

class ax.utils.testing.unittest_conventions.TestUnittestConventions(methodName: str = 'runTest')[source]

Bases: ax.utils.common.testutils.TestCase

test_uses_ae_unittest()[source]

Check that all of our tests are inheriting from our own base class

Our base class does a bit more (like making sure we don’t use any of python’s deprecated assert functions) so we want to enforce its usage everywhere.

ax.utils.testing.unittest_conventions.get_all_subclasses(cls)[source]

Recursively get all the subclasses of cls
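
A recursive subclass walk can be written with the built-in __subclasses__ method. This is a minimal sketch, not the Ax source: the function name all_subclasses is hypothetical, and the choice to return a list is an assumption.

```python
from typing import List


def all_subclasses(cls: type) -> List[type]:
    """Collect direct subclasses of cls, then recurse into each of them."""
    found: List[type] = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


# Small hierarchy used only for illustration.
class Base:
    pass


class Child(Base):
    pass


class GrandChild(Child):
    pass
```

Here all_subclasses(Base) finds both Child and GrandChild, whereas Base.__subclasses__() alone would report only Child.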

Test Metrics

Backend Simulator Map

class ax.utils.testing.metrics.backend_simulator_map.BackendSimulatorTimestampMapMetric(name: str, param_names: Iterable[str], map_key_infos: Iterable[ax.core.map_data.MapKeyInfo], noise_sd: float = 0.0, lower_is_better: Optional[bool] = None, cache_evaluations: bool = True)[source]

Bases: ax.metrics.noisy_function_map.NoisyFunctionMapMetric

A metric that interfaces with an underlying BackendSimulator and returns timestamp map data.

convert_to_timestamps(start_time: float, end_time: float)List[float][source]

Given a starting and current time, get the list of intermediate timestamps at which we have observations.

fetch_trial_data(trial: ax.core.base_trial.BaseTrial, noisy: bool = True, **kwargs: Any)ax.core.map_data.MapData[source]

Fetch data for one trial.

Branin Backend Map

class ax.utils.testing.metrics.branin_backend_map.BraninBackendMapMetric(name: str, param_names: List[str], map_key_infos: Optional[Iterable[ax.core.map_data.MapKeyInfo]] = None, noise_sd: float = 0.0, lower_is_better: Optional[bool] = True, cache_evaluations: bool = True, rate: float = 0.5, delta_t: float = 1.0)[source]

Bases: ax.utils.testing.metrics.backend_simulator_map.BackendSimulatorTimestampMapMetric, ax.metrics.branin_map.BraninTimestampMapMetric

A Branin BackendSimulatorTimestampMapMetric with a multiplicative factor of 1 - exp(-rate * t) where t is the runtime of the trial.
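As a rough illustration of the factor 1 - exp(-rate * t), the sketch below evaluates it for a few runtimes with the default rate of 0.5. The function name `saturation_factor` is hypothetical, and how the metric combines the factor with the Branin value is not shown here.

```python
import math


def saturation_factor(runtime: float, rate: float = 0.5) -> float:
    """Compute 1 - exp(-rate * t) for a trial that has run for ``runtime``."""
    return 1.0 - math.exp(-rate * runtime)


# The factor starts at 0 for a trial that has just started and approaches 1
# as the runtime grows.
for t in (0.0, 1.0, 5.0, 20.0):
    print(f"t={t:5.1f}  factor={saturation_factor(t):.4f}")
```

A larger rate makes the factor approach 1 sooner, so the reported value settles after a shorter runtime.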

convert_to_timestamps(start_time: Optional[float], end_time: float)List[float][source]

Given a starting and current time, get the list of intermediate timestamps at which we have observations.
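One plausible reading of this method, sketched under stated assumptions: observations are taken every delta_t time units starting at start_time, and every observation time up to and including end_time is returned. The function name `intermediate_timestamps`, the inclusive endpoint, and the use of absolute rather than relative times are assumptions for illustration, not confirmed behavior of the Ax implementation.

```python
import math
from typing import List


def intermediate_timestamps(
    start_time: float, end_time: float, delta_t: float = 1.0
) -> List[float]:
    """List observation times spaced ``delta_t`` apart (hypothetical sketch)."""
    if delta_t <= 0:
        raise ValueError("delta_t must be positive")
    if end_time < start_time:
        # Nothing has been observed before the trial started.
        return []
    # Count whole steps instead of accumulating floats to limit rounding drift.
    n_steps = math.floor((end_time - start_time) / delta_t)
    return [start_time + i * delta_t for i in range(n_steps + 1)]


print(intermediate_timestamps(0.0, 3.5, delta_t=1.0))
```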

Tutorials

Neural Net

class ax.utils.tutorials.cnn_utils.CNN[source]

Bases: torch.nn.modules.module.Module

Convolutional Neural Network.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
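The distinction drawn in the note can be illustrated with a simplified stand-in written in plain Python. The class `TinyModule` below is hypothetical and is not torch's implementation; it only mimics the idea that calling the instance runs the registered hooks while calling forward directly does not.

```python
from typing import Any, Callable, List


class TinyModule:
    """Simplified stand-in for a module with forward hooks (not torch code)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[["TinyModule", Any, Any], None]] = []

    def register_forward_hook(
        self, hook: Callable[["TinyModule", Any, Any], None]
    ) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        # The computation itself is defined here.
        return 2.0 * x

    def __call__(self, x: float) -> float:
        # Calling the instance runs forward and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen: List[float] = []
module = TinyModule()
module.register_forward_hook(lambda mod, inp, out: seen.append(out))

module(3.0)          # runs forward and then the hook
module.forward(4.0)  # runs forward only; the hook is skipped
print(seen)
```

With a real torch module the same guidance applies: prefer net(x) over net.forward(x).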

training: bool
ax.utils.tutorials.cnn_utils.evaluate(net: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader, dtype: torch.dtype, device: torch.device)float[source]

Compute classification accuracy on provided dataset.

Parameters
  • net – trained model

  • data_loader – DataLoader containing the evaluation set

  • dtype – torch dtype

  • device – torch device

Returns

classification accuracy

Return type

float
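Classification accuracy is the fraction of samples whose predicted label matches the true label, accumulated over all batches. The sketch below shows that bookkeeping on plain Python sequences; the name `classification_accuracy` and the batch layout of (predicted labels, true labels) are assumptions for illustration, and the real function works on torch tensors drawn from a DataLoader.

```python
from typing import Iterable, Sequence, Tuple

# Each batch is a pair: (predicted labels, true labels).
Batch = Tuple[Sequence[int], Sequence[int]]


def classification_accuracy(batches: Iterable[Batch]) -> float:
    """Return the fraction of correctly predicted labels across all batches."""
    correct = 0
    total = 0
    for predicted, labels in batches:
        if len(predicted) != len(labels):
            raise ValueError("predictions and labels must have the same length")
        correct += sum(int(p == y) for p, y in zip(predicted, labels))
        total += len(labels)
    if total == 0:
        raise ValueError("cannot compute accuracy without samples")
    return correct / total


batches = [([1, 2, 3], [1, 2, 0]), ([4], [4])]
print(classification_accuracy(batches))
```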

ax.utils.tutorials.cnn_utils.get_partition_data_loaders(train_valid_set: torch.utils.data.dataset.Dataset, test_set: torch.utils.data.dataset.Dataset, downsample_pct: float = 0.5, train_pct: float = 0.8, batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None)Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader][source]

Helper function for partitioning training data into training and validation sets, downsampling data, and initializing DataLoaders for each partition.

Parameters
  • train_valid_set – torch.dataset

  • downsample_pct – the proportion of the dataset to use for training and validation

  • train_pct – the proportion of the downsampled data to use for training

  • batch_size – how many samples per batch to load

  • num_workers – number of workers (subprocesses) for loading data

  • deterministic_partitions – whether to partition data in a deterministic fashion

  • downsample_pct_test – the proportion of the dataset to use for test, default to be equal to downsample_pct

Returns

DataLoaders for the training data, the validation data, and the test data, in that order

Return type

Tuple[DataLoader, DataLoader, DataLoader]
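The interaction of downsample_pct and train_pct can be made concrete with a little arithmetic. The sketch below assumes that the dataset is first downsampled and that the kept portion is then divided into training and validation parts, with fractional sizes truncated to integers. The helper name `partition_sizes` and the truncation rule are assumptions, not confirmed details of the Ax code.

```python
from typing import Tuple


def partition_sizes(
    n_samples: int, downsample_pct: float = 0.5, train_pct: float = 0.8
) -> Tuple[int, int]:
    """Return (n_train, n_valid) after downsampling (hypothetical sketch)."""
    for name, value in (("downsample_pct", downsample_pct), ("train_pct", train_pct)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must lie in [0, 1]")
    # Keep only a fraction of the full training/validation set.
    n_kept = int(downsample_pct * n_samples)
    # Split the kept samples into training and validation parts.
    n_train = int(train_pct * n_kept)
    n_valid = n_kept - n_train
    return n_train, n_valid


n_train, n_valid = partition_sizes(60000)
print(n_train, n_valid)
```

With the default settings and 60,000 samples, half of the data is kept, and 80% of that half goes to training.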

ax.utils.tutorials.cnn_utils.load_mnist(downsample_pct: float = 0.5, train_pct: float = 0.8, data_path: str = './data', batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None)Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader][source]

Load MNIST dataset (download if necessary) and split data into training, validation, and test sets.

Parameters
  • downsample_pct – the proportion of the dataset to use for training, validation, and test

  • train_pct – the proportion of the downsampled data to use for training

  • data_path – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.

  • batch_size – how many samples per batch to load

  • num_workers – number of workers (subprocesses) for loading data

  • deterministic_partitions – whether to partition data in a deterministic fashion

  • downsample_pct_test – the proportion of the dataset to use for test, default to be equal to downsample_pct

Returns

DataLoaders for the training data, the validation data, and the test data, in that order

Return type

Tuple[DataLoader, DataLoader, DataLoader]

ax.utils.tutorials.cnn_utils.split_dataset(dataset: torch.utils.data.dataset.Dataset, lengths: List[int], deterministic_partitions: bool = False)List[torch.utils.data.dataset.Dataset][source]

Split a dataset either randomly or deterministically.

Parameters
  • dataset – the dataset to split

  • lengths – the lengths of each partition

  • deterministic_partitions – whether to partition data in a deterministic fashion

Returns

split datasets

Return type

List[Dataset]
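The difference between the two splitting modes can be sketched on plain index lists. In this illustration a deterministic split takes consecutive index ranges, while a random split shuffles the indices first. The helper name `split_indices` and its `seed` argument are hypothetical, and the real function operates on torch Dataset objects.

```python
import random
from typing import List, Optional, Sequence


def split_indices(
    n_items: int,
    lengths: Sequence[int],
    deterministic: bool = False,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Split range(n_items) into parts of the given lengths (hypothetical sketch)."""
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to the number of items")
    indices = list(range(n_items))
    if not deterministic:
        # Random mode: shuffle before cutting the index list into parts.
        random.Random(seed).shuffle(indices)
    parts: List[List[int]] = []
    start = 0
    for length in lengths:
        parts.append(indices[start:start + length])
        start += length
    return parts


print(split_indices(10, [6, 4], deterministic=True))
```

In both modes every index lands in exactly one part, so the parts never overlap.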

ax.utils.tutorials.cnn_utils.train(net: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, parameters: Dict[str, float], dtype: torch.dtype, device: torch.device)torch.nn.modules.module.Module[source]

Train CNN on provided data set.

Parameters
  • net – initialized neural network

  • train_loader – DataLoader containing training set

  • parameters – dictionary containing parameters to be passed to the optimizer:
      - lr: default (0.001)
      - momentum: default (0.0)
      - weight_decay: default (0.0)
      - num_epochs: default (1)

  • dtype – torch dtype

  • device – torch device

Returns

trained CNN.

Return type

nn.Module
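Because every entry of the parameters dictionary has a default, callers only need to pass the values they want to change. A minimal sketch of that lookup, using the defaults listed for parameters above; the helper name `resolve_hyperparameters` and the `DEFAULTS` constant are hypothetical and are not part of the Ax API.

```python
from typing import Dict, Union

Number = Union[int, float]

# Defaults as documented for the ``parameters`` argument.
DEFAULTS: Dict[str, Number] = {
    "lr": 0.001,
    "momentum": 0.0,
    "weight_decay": 0.0,
    "num_epochs": 1,
}


def resolve_hyperparameters(parameters: Dict[str, Number]) -> Dict[str, Number]:
    """Fill in the documented default for every key the caller omitted."""
    resolved = {name: parameters.get(name, default) for name, default in DEFAULTS.items()}
    # The epoch count drives a loop, so coerce it to an integer.
    resolved["num_epochs"] = int(resolved["num_epochs"])
    return resolved


print(resolve_hyperparameters({"lr": 0.01, "num_epochs": 3.0}))
```

This is why a hyperparameter search can pass a dictionary containing only the parameters it tunes.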