ax.utils

Common

Docutils

Support functions for Sphinx et al.

ax.utils.common.docutils.copy_doc(src)[source]

A decorator that copies the docstring of another object

Since Sphinx actually imports the Python modules to read the docstrings, this works with both Sphinx and the built-in help function.

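from ax.utils.common.docutils import copy_doc
class Mammal:
    @property
    def is_feline(self) -> bool:
        """Whether this mammal is a feline."""
        return False

class Cat(Mammal):
    @property
    @copy_doc(Mammal.is_feline)  # reuse Mammal.is_feline's docstring
    def is_feline(self) -> bool:
        return True
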
Return type

~_T
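A decorator with this behavior can be sketched in a few lines. This is an illustrative sketch, not Ax's source; the real helper may add checks (for example, refusing to overwrite an existing docstring). The names `greet` and `greet_loudly` are hypothetical.

```python
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")


def copy_doc(src: Callable[..., Any]) -> Callable[[_T], _T]:
    """Return a decorator that copies ``src.__doc__`` onto the decorated object."""

    def decorator(dst: _T) -> _T:
        # Assigning __doc__ is enough for help() and for Sphinx autodoc,
        # since both read the attribute from the imported object.
        dst.__doc__ = src.__doc__
        return dst

    return decorator


def greet() -> None:
    """Print a friendly greeting."""


@copy_doc(greet)
def greet_loudly() -> None:
    pass
```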

Equality

ax.utils.common.equality.dataframe_equals(df1, df2)[source]

Compare equality of two pandas dataframes.

Return type

bool

ax.utils.common.equality.datetime_equals(dt1, dt2)[source]

Compare equality of two datetimes, ignoring microseconds.

Return type

bool
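A minimal sketch of "equal up to microseconds" using `datetime.replace`. How `None` inputs are handled is an assumption made for this sketch only.

```python
from datetime import datetime
from typing import Optional


def datetime_equals(dt1: Optional[datetime], dt2: Optional[datetime]) -> bool:
    """Compare two datetimes while ignoring the microsecond field."""
    if dt1 is None or dt2 is None:
        # Assumption: two missing values are equal; missing vs. present is not.
        return dt1 is dt2
    # Zero out microseconds on both sides before comparing.
    return dt1.replace(microsecond=0) == dt2.replace(microsecond=0)
```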

ax.utils.common.equality.equality_typechecker(eq_func)[source]

A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type.

Return type

Callable
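One way such a decorator might look, sketched under the assumption that a type mismatch should simply compare unequal (returning `False`); raising `TypeError` or returning `NotImplemented` would be alternative designs. `Point` is a hypothetical example class.

```python
import functools
from typing import Any, Callable


def equality_typechecker(eq_func: Callable[[Any, Any], bool]) -> Callable[[Any, Any], bool]:
    """Wrap an __eq__ so that comparing against a different type returns False."""

    @functools.wraps(eq_func)
    def _type_safe_equals(self: Any, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return eq_func(self, other)

    return _type_safe_equals


class Point:
    def __init__(self, x: int) -> None:
        self.x = x

    @equality_typechecker
    def __eq__(self, other: "Point") -> bool:
        # Safe to access other.x: the wrapper already checked the type.
        return self.x == other.x
```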

ax.utils.common.equality.same_elements(list1, list2)[source]

Compare equality of two lists of core Ax objects.

Assumptions:

  • The contents of each list are types that implement __eq__.

  • The lists do not contain duplicates.

Checking equality is then the same as checking that the lists are the same length, and that one is a subset of the other.

Return type

bool
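The length-plus-subset reasoning above translates directly into code. This sketch relies only on `__eq__` (via `in`), so it also works for unhashable objects, at O(n²) cost; it is an illustration, not Ax's implementation.

```python
from typing import Any, List


def same_elements(list1: List[Any], list2: List[Any]) -> bool:
    """Order-insensitive list comparison that relies only on __eq__ (no hashing)."""
    if len(list1) != len(list2):
        return False
    # With no duplicates and equal lengths, "list1 is a subset of list2"
    # implies the two lists hold the same elements.
    return all(item in list2 for item in list1)
```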

Kwargs

ax.utils.common.kwargs.consolidate_kwargs(kwargs_iterable, keywords)[source]

Combine an iterable of kwargs into a single dict of kwargs. When the same key appears more than once, the value that appears later in the iterable takes priority over earlier ones, and only kwargs whose keys are listed in keywords are kept. This makes it possible to combine somewhat redundant sets of kwargs, for instance when a user-set kwarg needs to override a default kwarg.

>>> consolidate_kwargs(
...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
...     keywords=['a', 'b', 'd']
... )
{'a': 1, 'b': 3, 'd': 5}
Return type

Dict[str, Any]
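The "later wins, filter by keywords" rule can be sketched as repeated `dict.update` calls over filtered dicts. This is an illustrative sketch rather than the library source.

```python
from typing import Any, Dict, Iterable


def consolidate_kwargs(
    kwargs_iterable: Iterable[Dict[str, Any]], keywords: Iterable[str]
) -> Dict[str, Any]:
    allowed = set(keywords)
    consolidated: Dict[str, Any] = {}
    for kwargs in kwargs_iterable:
        # Later dicts overwrite earlier ones; keys outside `keywords` are dropped.
        consolidated.update({k: v for k, v in kwargs.items() if k in allowed})
    return consolidated
```

Run on the inputs from the example above, this returns the same `{'a': 1, 'b': 3, 'd': 5}`.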

ax.utils.common.kwargs.get_function_argument_names(function, omit=None)[source]

Extract parameter names from function signature.

Return type

List[str]

ax.utils.common.kwargs.get_function_default_arguments(function)[source]

Extract default arguments from function signature.

Return type

Dict[str, Any]
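Both signature helpers can be sketched with the standard `inspect.signature` API. Whether the real helpers also filter out `self`, `*args`, or `**kwargs` is not stated here, so this sketch keeps every parameter; `example` is a hypothetical function for illustration.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def get_function_argument_names(
    function: Callable[..., Any], omit: Optional[List[str]] = None
) -> List[str]:
    omit = omit or []
    params = inspect.signature(function).parameters
    return [name for name in params if name not in omit]


def get_function_default_arguments(function: Callable[..., Any]) -> Dict[str, Any]:
    params = inspect.signature(function).parameters
    # Parameter.empty marks "no default"; keep only parameters that have one.
    return {
        name: p.default
        for name, p in params.items()
        if p.default is not inspect.Parameter.empty
    }


def example(a, b=2, *, c=3):
    return a + b + c
```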

ax.utils.common.kwargs.validate_kwarg_typing(typed_callables, **kwargs)[source]

Raises a ValueError if the types of any keyword arguments do not match the signatures of the specified typed callables.

Note: this function expects the typed callables to have unique keywords for their arguments and raises an error if repeated keywords are found.

Return type

None

Logger

ax.utils.common.logger.get_logger(name, filepath=None, level=20)[source]

Get an Ax logger.

Sets default level to INFO, instead of WARNING. Adds timestamps to logger messages.

Return type

Logger
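The default `level=20` in the signature is `logging.INFO`. A logger with the described behavior (INFO by default, timestamped messages) can be sketched with the standard `logging` module. The function name `get_logger_sketch` and the format string are hypothetical, and the `filepath` option is omitted.

```python
import logging


def get_logger_sketch(name: str, level: int = logging.INFO) -> logging.Logger:
    """Hypothetical stand-in for get_logger: INFO by default, timestamped output."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(levelname)s %(asctime)s] %(name)s: %(message)s")
        )
        logger.addHandler(handler)
    return logger
```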

Serialization

ax.utils.common.serialization.callable_from_reference(path)[source]

Retrieves a callable by its path.

Return type

Callable

ax.utils.common.serialization.callable_to_reference(callable)[source]

Obtains the path to the callable, in the form <module>.<name>.

Return type

str
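A round trip between a callable and its `<module>.<name>` path can be sketched with `importlib`. This sketch only handles module-level functions and classes (not nested attributes such as methods) and is not Ax's implementation.

```python
import importlib
from typing import Any, Callable


def callable_to_reference(func: Callable[..., Any]) -> str:
    # e.g. json.dumps -> "json.dumps"
    return f"{func.__module__}.{func.__name__}"


def callable_from_reference(path: str) -> Callable[..., Any]:
    # Split "package.module.name" into the module path and the attribute name.
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
```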

ax.utils.common.serialization.named_tuple_to_dict(data)[source]

Recursively convert NamedTuples to dictionaries.

Return type

Any
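A recursive conversion can be sketched by detecting NamedTuple instances through their `_asdict` method and recursing into dicts, lists, and tuples. Which container types the real function walks is an assumption here; `Point` and `nested` are hypothetical example data.

```python
from collections import namedtuple
from typing import Any


def named_tuple_to_dict(data: Any) -> Any:
    """Recursively replace NamedTuples (including nested ones) with plain dicts."""
    if isinstance(data, dict):
        return {key: named_tuple_to_dict(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_asdict"):
        # NamedTuple instances are tuples that also expose _asdict().
        return {key: named_tuple_to_dict(value) for key, value in data._asdict().items()}
    if isinstance(data, (list, tuple)):
        return type(data)(named_tuple_to_dict(item) for item in data)
    return data


# Example input: a list holding a NamedTuple that nests another NamedTuple.
Point = namedtuple("Point", ["x", "y"])
nested = [Point(x=1, y=Point(x=2, y=3))]
```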

Testutils

Support functions for tests

class ax.utils.common.testutils.TestCase(methodName='runTest')[source]

Bases: unittest.case.TestCase

The base test case for Ax; it contains various helper functions for writing unit tests.

assertRaisesOn(exc, line=None, regex=None)[source]

Assert that an exception is raised on a specific line.

Return type

AbstractContextManager[None]

static silence_stderr()[source]

A context manager that silences stderr for part of a test.

If any exception passes through this context manager the stderr will be printed, otherwise it will be discarded.

Return type

Generator[None, None, None]
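The "discard on success, replay on failure" behavior can be sketched with `contextlib.redirect_stderr` and an in-memory buffer. This is a standalone function rather than a static method of TestCase, and it is not Ax's implementation.

```python
import contextlib
import io
import sys
from typing import Iterator


@contextlib.contextmanager
def silence_stderr() -> Iterator[None]:
    buffer = io.StringIO()
    try:
        # Everything written to sys.stderr inside the block goes to the buffer.
        with contextlib.redirect_stderr(buffer):
            yield
    except BaseException:
        # The real stderr is restored at this point; replay the captured output
        # so the failure is debuggable, then let the exception propagate.
        sys.stderr.write(buffer.getvalue())
        raise
```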

Timeutils

ax.utils.common.timeutils.current_timestamp_in_millis()[source]

Grab current timestamp in milliseconds as an int.

Return type

int

ax.utils.common.timeutils.to_ds(ts)[source]

Convert a datetime to a DS string.

Return type

str

ax.utils.common.timeutils.to_ts(ds)[source]

Convert a DS string to a datetime.

Return type

datetime
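The three time helpers can be sketched with `time` and `datetime`. The exact DS format is not stated in this section; the sketch assumes a calendar-date string such as `2020-01-31` (`%Y-%m-%d`), so `to_ds` drops the time of day under that assumption.

```python
import time
from datetime import datetime

# Assumption: a "DS" string is a calendar date such as "2020-01-31".
DS_FORMAT = "%Y-%m-%d"


def current_timestamp_in_millis() -> int:
    return int(round(time.time() * 1000))


def to_ds(ts: datetime) -> str:
    return ts.strftime(DS_FORMAT)


def to_ts(ds: str) -> datetime:
    return datetime.strptime(ds, DS_FORMAT)
```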

Typeutils

ax.utils.common.typeutils.checked_cast(typ, val)[source]

Cast a value to a type (with a runtime safety check).

Returns the value unchanged and checks its type at runtime. This signals to the typechecker that the value has the designated type.

Like typing.cast, checked_cast performs no runtime conversion on its argument, but, unlike typing.cast, it raises an error if the value is not of the expected type. The type passed as an argument should be a Python class.

Parameters
  • typ (Type[~T]) – the type to cast to

  • val (~V) – the value that we are casting

Return type

~T

Returns

the val argument, unchanged

ax.utils.common.typeutils.checked_cast_dict(key_typ, value_typ, d)[source]

Calls checked_cast on all keys and values in the dictionary.

Return type

Dict[~K, ~V]

ax.utils.common.typeutils.checked_cast_list(typ, l)[source]

Calls checked_cast on all items in a list.

Return type

List[~T]

ax.utils.common.typeutils.checked_cast_optional(typ, val)[source]

Calls checked_cast only if value is not None.

Return type

Optional[~T]
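The checked-cast family can be sketched around a single `isinstance` check. The sketch raises `ValueError` on a mismatch; that exception type and its message are assumptions, since this section only says an error is thrown.

```python
from typing import Any, Dict, List, Optional, Type, TypeVar

K = TypeVar("K")
T = TypeVar("T")
V = TypeVar("V")


def checked_cast(typ: Type[T], val: Any) -> T:
    # No conversion happens: the value is returned unchanged if the check passes.
    if not isinstance(val, typ):
        raise ValueError(f"Value was not of type {typ}: {val!r}")
    return val


def checked_cast_list(typ: Type[T], items: List[Any]) -> List[T]:
    return [checked_cast(typ, item) for item in items]


def checked_cast_dict(key_typ: Type[K], value_typ: Type[V], d: Dict[Any, Any]) -> Dict[K, V]:
    return {checked_cast(key_typ, k): checked_cast(value_typ, v) for k, v in d.items()}


def checked_cast_optional(typ: Type[T], val: Optional[Any]) -> Optional[T]:
    return None if val is None else checked_cast(typ, val)
```

Because `isinstance` also accepts a tuple of classes, the same `checked_cast` body handles the tuple-of-types case documented for checked_cast_to_tuple, e.g. `checked_cast((int, str), "x")`.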

ax.utils.common.typeutils.checked_cast_to_tuple(typ, val)[source]

Cast a value to a union of multiple types (with a runtime safety check). This function is similar to checked_cast, but allows for the type to be defined as a tuple of types, in which case the value is cast as a union of the types in the tuple.

Parameters
  • typ (Tuple[Type[~V], …]) – the tuple of types to cast to

  • val (~V) – the value that we are casting

Return type

~T

Returns

the val argument, unchanged

ax.utils.common.typeutils.not_none(val)[source]

Unbox an optional type.

Parameters

val (Optional[~T]) – the value to cast to a non None type

Returns

val, when val is not None.

Raises

ValueError if val is None.

Return type

~T
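A sketch of the unboxing logic. Note the explicit `is None` test: falsy values such as `0` or `""` must pass through unchanged. The error message is illustrative.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(val: Optional[T]) -> T:
    """Unbox an Optional, failing loudly instead of passing None along."""
    # Compare against None explicitly so falsy values such as 0 or "" pass through.
    if val is None:
        raise ValueError("Argument to `not_none` was None.")
    return val
```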

ax.utils.common.typeutils.numpy_type_to_python_type(value)[source]

If value is a NumPy int or float, coerce it to a Python int or float. This is necessary because some of our transforms return NumPy values.

Return type

Any

ax.utils.common.typeutils.torch_type_from_str(identifier, type_name)[source]
Return type

Union[dtype, device]

ax.utils.common.typeutils.torch_type_to_str(value)[source]

Converts torch types, commonly used in Ax, to string representations.

Return type

str

Measurement

Synthetic Functions

class ax.utils.measurement.synthetic_functions.Aug_Branin[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Augmented Branin function (3-dimensional with infinitely many global minima).

class ax.utils.measurement.synthetic_functions.Aug_Hartmann6[source]

Bases: ax.utils.measurement.synthetic_functions.Hartmann6

Augmented Hartmann6 function (7-dimensional with 1 global minimum).

class ax.utils.measurement.synthetic_functions.Branin[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Branin function (2-dimensional with 3 global minima).
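For reference, the standard Branin formula with its usual constants is sketched below, written from the textbook definition rather than copied from Ax. It is scalar, whereas SyntheticFunction.f takes an n by d array. All three global minima share the value 5/(4π) ≈ 0.397887.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard Branin function with its usual constants."""
    a = 1.0
    b = 5.1 / (4.0 * math.pi**2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8.0 * math.pi)
    return a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * math.cos(x1) + s


# The three global minimizers of the standard Branin function.
BRANIN_MINIMIZERS = [(-math.pi, 12.275), (math.pi, 2.275), (3 * math.pi, 2.475)]
```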

class ax.utils.measurement.synthetic_functions.FromBotorch(botorch_synthetic_function)[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

property name
Return type

str

class ax.utils.measurement.synthetic_functions.Hartmann6[source]

Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction

Hartmann6 function (6-dimensional with 1 global minimum).

class ax.utils.measurement.synthetic_functions.SyntheticFunction[source]

Bases: abc.ABC

property domain
Return type

Any

f(X)[source]

Synthetic function implementation.

Parameters

X (numpy.ndarray) – an n by d array, where n represents the number of observations and d is the dimensionality of the inputs.

Returns

an n-dimensional array.

Return type

numpy.ndarray

property fmax
Return type

Any

property fmin
Return type

Any

property maximums
Return type

Any

property minimums
Return type

Any

property name
Return type

Any

property required_dimensionality
Return type

Any

ax.utils.measurement.synthetic_functions.from_botorch(botorch_synthetic_function)[source]

Utility to generate Ax synthetic functions from BoTorch synthetic functions.

Return type

SyntheticFunction

ax.utils.measurement.synthetic_functions.informative_failure_on_none(func)[source]
Return type

Any

Notebook

Plotting

ax.utils.notebook.plotting.init_notebook_plotting(offline=False)[source]

Initialize plotting in notebooks, either in online or offline mode.

ax.utils.notebook.plotting.render(plot_config, inject_helpers=False)[source]

Render plot config.

Return type

None

Report

Render

ax.utils.report.render.h2_html(text)[source]

Embed text in subheading tag.

Return type

str

ax.utils.report.render.h3_html(text)[source]

Embed text in subsubheading tag.

Return type

str

Embed text and reference address into link tag.

Return type

str

ax.utils.report.render.list_item_html(text)[source]

Embed text in list element tag.

Return type

str

ax.utils.report.render.p_html(text)[source]

Embed text in paragraph tag.

Return type

str

ax.utils.report.render.render_report_elements(experiment_name, html_elements, header=True, offline=False, notebook_env=False)[source]

Generate Ax HTML report for a given experiment from HTML elements.

Uses Jinja2 for templating. Injects Plotly JS for graph rendering.

Example:

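from ax.utils.report.render import h2_html, p_html, render_report_elements

# gp_model, data, plot_fitted, plot_html and pandas_html must be defined elsewhere.
html_elements = [
    h2_html("Subsection with plot"),
    p_html("This is an example paragraph."),
    plot_html(plot_fitted(gp_model, 'perf_metric')),
    h2_html("Subsection with table"),
    pandas_html(data.df),
]
html = render_report_elements('My experiment', html_elements)
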
Parameters
  • experiment_name (str) – the name of the experiment to use for title.

  • html_elements (List[str]) – list of HTML strings to render in report body.

  • header (bool) – if True, render experiment title as a header. Meant to be used for standalone reports (e.g. via email), as opposed to reports served on the front end.

  • offline (bool) – if True, entire Plotly library is bundled with report.

  • notebook_env (bool) – if True, caps the report width to 700px for viewing in a notebook environment.

Returns

HTML string.

Return type

str

ax.utils.report.render.table_cell_html(text, width=None)[source]

Embed text or an HTML element into table cell tag.

Return type

str

ax.utils.report.render.table_heading_cell_html(text)[source]

Embed text or an HTML element into table heading cell tag.

Return type

str

ax.utils.report.render.table_html(table_rows)[source]

Embed list of HTML elements into table tag.

Return type

str

ax.utils.report.render.table_row_html(table_cells)[source]

Embed list of HTML elements into table row tag.

Return type

str

ax.utils.report.render.unordered_list_html(list_items)[source]

Embed list of html elements into an unordered list tag.

Return type

str
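Each of these helpers wraps its input in one HTML tag. The sketch below assumes plain tags with no attributes (the real helpers may add styling, and the `width` option of table_cell_html is omitted). Text is inserted without escaping, which is what lets the helpers nest "text or an HTML element".

```python
from typing import List


def _wrap(tag: str, content: str) -> str:
    return f"<{tag}>{content}</{tag}>"


def h2_html(text: str) -> str:
    return _wrap("h2", text)


def p_html(text: str) -> str:
    return _wrap("p", text)


def list_item_html(text: str) -> str:
    return _wrap("li", text)


def unordered_list_html(list_items: List[str]) -> str:
    return _wrap("ul", "".join(list_items))


def table_cell_html(text: str) -> str:
    return _wrap("td", text)


def table_row_html(table_cells: List[str]) -> str:
    return _wrap("tr", "".join(table_cells))


def table_html(table_rows: List[str]) -> str:
    return _wrap("table", "".join(table_rows))
```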

Stats

Statstools

ax.utils.stats.statstools.agresti_coull_sem(n_numer, n_denom, prior_successes=2, prior_failures=2)[source]

Compute the Agresti-Coull style standard error for a binomial proportion.

Reference: Agresti, Alan, and Brent A. Coull. “Approximate Is Better than ‘Exact’ for Interval Estimation of Binomial Proportions.” The American Statistician, vol. 52, no. 2, 1998, pp. 119–126. JSTOR, www.jstor.org/stable/2685469.

Return type

Union[ndarray, float]
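A scalar sketch of the Agresti-Coull idea: add pseudo-counts to the observed successes and failures before computing the proportion, which keeps the standard error positive even with zero successes. Using the observed n_denom (rather than the adjusted count) in the variance denominator is an assumption of this sketch; textbook presentations differ on that detail.

```python
import math


def agresti_coull_sem(
    n_numer: float, n_denom: float, prior_successes: int = 2, prior_failures: int = 2
) -> float:
    # Adjusted proportion: pseudo-counts pull it toward 1/2 when the priors are equal.
    p = (n_numer + prior_successes) / (n_denom + prior_successes + prior_failures)
    # Binomial standard error of the adjusted proportion, using the observed n.
    return math.sqrt(p * (1 - p) / n_denom)
```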

ax.utils.stats.statstools.inverse_variance_weight(means, variances, conflicting_noiseless='warn')[source]

Perform inverse variance weighting.

Parameters
  • means (ndarray) – The means of the observations.

  • variances (ndarray) – The variances of the observations.

  • conflicting_noiseless (str) – How to handle the case of multiple observations with zero variance but different means. Options are “warn” (default), “ignore” or “raise”.

Return type

Tuple[float, float]
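Inverse variance weighting gives each observation weight 1/variance; the pooled mean is the weighted average and the pooled variance is 1 / (sum of weights). A zero-variance observation has infinite weight, so the sketch below returns the noiseless mean directly. Averaging several conflicting noiseless means under "warn" or "ignore" is an assumption, and this pure-Python version takes lists instead of arrays.

```python
import warnings
from typing import List, Tuple


def inverse_variance_weight(
    means: List[float], variances: List[float], conflicting_noiseless: str = "warn"
) -> Tuple[float, float]:
    if conflicting_noiseless not in ("warn", "ignore", "raise"):
        raise ValueError(f"Unknown option: {conflicting_noiseless!r}")
    if len(means) != len(variances):
        raise ValueError("means and variances must have the same length")
    noiseless = [m for m, v in zip(means, variances) if v == 0]
    if noiseless:
        # Zero-variance observations get infinite weight, so they alone decide the mean.
        if len(set(noiseless)) > 1:
            msg = "Noiseless observations have different means."
            if conflicting_noiseless == "raise":
                raise ValueError(msg)
            if conflicting_noiseless == "warn":
                warnings.warn(msg)
        return sum(noiseless) / len(noiseless), 0.0
    weights = [1.0 / v for v in variances]
    total = sum(weights)
    mean = sum(w * m for w, m in zip(weights, means)) / total
    return mean, 1.0 / total
```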

ax.utils.stats.statstools.marginal_effects(df)[source]

This method calculates the relative (in %) change in the outcome achieved by using any individual factor level versus randomizing across all factor levels. It does this by estimating a baseline under the experiment by marginalizing over all factors/levels. For each factor level, then, it conditions on that level for the individual factor and then marginalizes over all levels for all other factors.

Parameters

df (DataFrame) – Dataframe containing columns named mean and sem. All other columns are assumed to be factors for which to calculate marginal effects.

Return type

DataFrame

Returns

A dataframe containing columns “Name”, “Level”, “Beta” and “SE”, corresponding to the factor, level, effect and standard error. Results are relativized as percentage changes.

ax.utils.stats.statstools.positive_part_james_stein(means, sems)[source]

Estimation method for Positive-part James-Stein estimator.

This method takes a vector of K means (y_i) and standard errors (sigma_i) and calculates the positive-part James-Stein estimator.

Resulting estimates are the shrunk means and standard errors. The positive-part James-Stein estimator shrinks each constituent average to the grand average:

y_i - phi_i * y_i + phi_i * ybar

The variable phi_i determines the amount of shrinkage. For phi_i = 1, mu_hat_i is equal to ybar (the mean of all y_i), while for phi_i = 0, mu_hat_i is equal to y_i. It can be shown that restricting phi_i <= 1 dominates the unrestricted estimator, so this method restricts phi_i in this manner. The amount of shrinkage, phi_i, is determined by:

(K - 3) * sigma2_i / s2

where sigma2_i = sigma_i^2 and s2 is the sum of squared deviations of the y_i from ybar. That is, less shrinkage is applied when individual means are estimated with greater precision, and more shrinkage is applied when individual means are very tightly clustered together. We also restrict phi_i to never be larger than 1.

The variance of the mean estimator is:

(1 - phi_i) * sigma2_i + phi_i * sigma2_i / K + 2 * phi_i^2 * (y_i - ybar)^2 / (K - 3)

The first term is the variance component from y_i, the second term is the contribution from the mean of all y_i, and the third term is the contribution from the uncertainty in the sum of squared deviations of y_i from the mean of all y_i.

For more information, see https://fburl.com/empirical_bayes.

Parameters
  • means – Vector of K means (y_i)

  • sems – Vector of K standard errors (sigma_i)

Returns

  • mu_hat_i: Empirical Bayes estimate of each arm’s mean

  • sem_i: Empirical Bayes estimate of each arm’s sem
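The formulas above can be sketched term by term in pure Python. Returning the inputs unchanged when K <= 3 (where the (K - 3) factor breaks down) and applying full shrinkage when s2 = 0 are assumptions of this sketch; it uses lists instead of arrays.

```python
import math
from typing import List, Tuple


def positive_part_james_stein(
    means: List[float], sems: List[float]
) -> Tuple[List[float], List[float]]:
    K = len(means)
    if K <= 3:
        # Assumption: the (K - 3) factor makes shrinkage undefined, so skip it.
        return list(means), list(sems)
    ybar = sum(means) / K
    s2 = sum((y - ybar) ** 2 for y in means)  # sum of squared deviations
    mu_hat, sem_hat = [], []
    for y, sigma in zip(means, sems):
        sigma2 = sigma**2
        # phi_i = min(1, (K - 3) * sigma2_i / s2); full shrinkage if s2 == 0.
        phi = 1.0 if s2 == 0 else min(1.0, (K - 3) * sigma2 / s2)
        mu_hat.append(y - phi * y + phi * ybar)
        var = (
            (1 - phi) * sigma2
            + phi * sigma2 / K
            + 2 * phi**2 * (y - ybar) ** 2 / (K - 3)
        )
        sem_hat.append(math.sqrt(var))
    return mu_hat, sem_hat
```

For means [1, 2, 3, 4, 5] with unit standard errors, s2 = 10 and phi_i = 0.2, so the first mean shrinks from 1 to 1.4.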

ax.utils.stats.statstools.relativize(means_t, sems_t, mean_c, sem_c, bias_correction=True, cov_means=0.0, as_percent=False)[source]

Ratio estimator based on the delta method.

This uses the delta method (i.e. a Taylor series approximation) to estimate the mean and standard deviation of the sampling distribution of the ratio between test and control – that is, the sampling distribution of an estimator of the true population value under the assumption that the means in test and control have a known covariance:

(mu_t / mu_c) - 1.

Under a second-order Taylor expansion, the sampling distribution of the relative change in empirical means, which is m_t / m_c - 1, is approximately normally distributed with mean

[(mu_t - mu_c) / mu_c] - [(sigma_c)^2 * mu_t] / (mu_c)^3

and variance

(sigma_t / mu_c)^2 - 2 * mu_t * sigma_tc / mu_c^3 + [(sigma_c * mu_t)^2 / (mu_c)^4]

as the higher terms are assumed to be close to zero in the full Taylor series. To estimate these parameters, we plug in the empirical means and standard errors. This gives us the estimators:

[(m_t - m_c) / m_c] - [(s_c)^2 * m_t] / (m_c)^3

and

(s_t / m_c)^2 - 2 * m_t * s_tc / m_c^3 + [(s_c * m_t)^2 / (m_c)^4]

Note that the delta method does NOT take as input the empirical standard deviation of a metric, but rather the standard error of the mean of that metric – that is, the standard deviation of the metric after division by the square root of the total number of observations.

Parameters
  • means_t (Union[ndarray, List[float], float]) – Sample means (test)

  • sems_t (Union[ndarray, List[float], float]) – Sample standard errors of the means (test)

  • mean_c (float) – Sample mean (control)

  • sem_c (float) – Sample standard error of the mean (control)

  • cov_means (Union[ndarray, List[float], float]) – Sample covariance between test and control

  • as_percent (bool) – If true, return results in percent (* 100)

Returns

  • rel_hat: Inferred means of the sampling distribution of the relative change (mean_t / mean_c) - 1

  • sem_hat: Inferred standard deviation of the sampling distribution of rel_hat, i.e. the standard error
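The plug-in estimators above, sketched for a single test/control pair. `relativize_scalar` is a hypothetical scalar helper (the documented function also accepts arrays and lists); cov_means plays the role of s_tc.

```python
import math
from typing import Tuple


def relativize_scalar(
    mean_t: float,
    sem_t: float,
    mean_c: float,
    sem_c: float,
    bias_correction: bool = True,
    cov_means: float = 0.0,
    as_percent: bool = False,
) -> Tuple[float, float]:
    """Delta-method estimate of (mean_t / mean_c) - 1 and its standard error."""
    rel_hat = (mean_t - mean_c) / mean_c
    if bias_correction:
        # Second-order term: [(s_c)^2 * m_t] / (m_c)^3
        rel_hat -= sem_c**2 * mean_t / mean_c**3
    var_hat = (
        (sem_t / mean_c) ** 2
        - 2 * mean_t * cov_means / mean_c**3
        + (sem_c * mean_t) ** 2 / mean_c**4
    )
    sem_hat = math.sqrt(var_hat)
    if as_percent:
        return 100 * rel_hat, 100 * sem_hat
    return rel_hat, sem_hat
```

For example, with m_t = 11, m_c = 10 and both standard errors equal to 1, the uncorrected estimate is 0.1 with variance 0.01 + 0.0121 = 0.0221, and the bias correction lowers the point estimate to 0.089.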

ax.utils.stats.statstools.total_variance(means, variances, sample_sizes)[source]

Compute total variance.

Return type

float
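This section does not spell out the formula. A plausible reading, assumed here, is the law of total variance over groups weighted by sample size: the average within-group variance plus the variance of the group means.

```python
from typing import List


def total_variance(means: List[float], variances: List[float], sample_sizes: List[int]) -> float:
    n = sum(sample_sizes)
    # E[Var(X | group)]: size-weighted average of within-group variances.
    within = sum(v * s for v, s in zip(variances, sample_sizes)) / n
    # Var(E[X | group]): size-weighted variance of the group means.
    grand_mean = sum(m * s for m, s in zip(means, sample_sizes)) / n
    between = sum(s * (m - grand_mean) ** 2 for m, s in zip(means, sample_sizes)) / n
    return within + between
```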

Testing

Core Stubs

ax.utils.testing.core_stubs.get_abandoned_arm()[source]
Return type

AbandonedArm

ax.utils.testing.core_stubs.get_arm()[source]
Return type

Arm

ax.utils.testing.core_stubs.get_arm_weights1()[source]
Return type

MutableMapping[Arm, float]

ax.utils.testing.core_stubs.get_arm_weights2()[source]
Return type

MutableMapping[Arm, float]

ax.utils.testing.core_stubs.get_arms()[source]
Return type

List[Arm]

ax.utils.testing.core_stubs.get_arms_from_dict(arm_weights_dict)[source]
Return type

List[Arm]

ax.utils.testing.core_stubs.get_batch_trial(abandon_arm=True)[source]
Return type

BatchTrial

ax.utils.testing.core_stubs.get_batch_trial_with_repeated_arms(num_repeated_arms)[source]

Create a batch trial that contains both new arms and N arms from the last existing trial in the experiment, where N is equal to the input argument num_repeated_arms.

Return type

BatchTrial

ax.utils.testing.core_stubs.get_branin_arms(n, seed)[source]
Return type

List[Arm]

ax.utils.testing.core_stubs.get_branin_data()[source]
Return type

Data

ax.utils.testing.core_stubs.get_branin_experiment(has_optimization_config=True, with_batch=False, with_status_quo=False)[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_branin_metric(name='branin')[source]
Return type

BraninMetric

ax.utils.testing.core_stubs.get_branin_objective()[source]
Return type

Objective

ax.utils.testing.core_stubs.get_branin_optimization_config()[source]
Return type

OptimizationConfig

ax.utils.testing.core_stubs.get_branin_outcome_constraint()[source]
Return type

OutcomeConstraint

ax.utils.testing.core_stubs.get_branin_search_space()[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_choice_parameter()[source]
Return type

ChoiceParameter

ax.utils.testing.core_stubs.get_data(trial_index=0)[source]
Return type

Data

ax.utils.testing.core_stubs.get_discrete_search_space()[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_experiment()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_batch_and_single_trial()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_batch_trial()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_data()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_multi_objective()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_repeated_arms(num_repeated_arms)[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_experiment_with_scalarized_objective()[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_factorial_experiment(has_optimization_config=True, with_batch=False, with_status_quo=False)[source]
Return type

Experiment

ax.utils.testing.core_stubs.get_factorial_metric(name='success_metric')[source]
Return type

FactorialMetric

ax.utils.testing.core_stubs.get_factorial_search_space()[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_fixed_parameter()[source]
Return type

FixedParameter

ax.utils.testing.core_stubs.get_generator_run()[source]
Return type

GeneratorRun

ax.utils.testing.core_stubs.get_generator_run2()[source]
Return type

GeneratorRun

ax.utils.testing.core_stubs.get_hartmann_metric()[source]
Return type

Hartmann6Metric

ax.utils.testing.core_stubs.get_metric()[source]
Return type

Metric

ax.utils.testing.core_stubs.get_model_covariance()[source]
Return type

Dict[str, Dict[str, List[float]]]

ax.utils.testing.core_stubs.get_model_mean()[source]
Return type

Dict[str, List[float]]

ax.utils.testing.core_stubs.get_model_predictions()[source]
Return type

Tuple[Dict[str, List[float]], Dict[str, Dict[str, List[float]]]]

ax.utils.testing.core_stubs.get_model_predictions_per_arm()[source]
Return type

Dict[str, Tuple[Dict[str, float], Optional[Dict[str, Dict[str, float]]]]]

ax.utils.testing.core_stubs.get_multi_objective()[source]
Return type

Objective

ax.utils.testing.core_stubs.get_multi_type_experiment(add_trial_type=True, add_trials=False)[source]
Return type

MultiTypeExperiment

ax.utils.testing.core_stubs.get_objective()[source]
Return type

Objective

ax.utils.testing.core_stubs.get_optimization_config()[source]
Return type

OptimizationConfig

ax.utils.testing.core_stubs.get_optimization_config_no_constraints()[source]
Return type

OptimizationConfig

ax.utils.testing.core_stubs.get_order_constraint()[source]
Return type

OrderConstraint

ax.utils.testing.core_stubs.get_outcome_constraint()[source]
Return type

OutcomeConstraint

ax.utils.testing.core_stubs.get_parameter_constraint()[source]
Return type

ParameterConstraint

ax.utils.testing.core_stubs.get_range_parameter()[source]
Return type

RangeParameter

ax.utils.testing.core_stubs.get_range_parameter2()[source]
Return type

RangeParameter

ax.utils.testing.core_stubs.get_scalarized_objective()[source]
Return type

Objective

ax.utils.testing.core_stubs.get_search_space()[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_search_space_for_range_value(min=3.0, max=6.0)[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_search_space_for_range_values(min=3.0, max=6.0)[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_search_space_for_value(val=3.0)[source]
Return type

SearchSpace

ax.utils.testing.core_stubs.get_simple_experiment()[source]
Return type

SimpleExperiment

ax.utils.testing.core_stubs.get_simple_experiment_with_batch_trial()[source]
Return type

SimpleExperiment

ax.utils.testing.core_stubs.get_status_quo()[source]
Return type

Arm

ax.utils.testing.core_stubs.get_sum_constraint1()[source]
Return type

SumConstraint

ax.utils.testing.core_stubs.get_sum_constraint2()[source]
Return type

SumConstraint

ax.utils.testing.core_stubs.get_synthetic_runner()[source]
Return type

SyntheticRunner

ax.utils.testing.core_stubs.get_trial()[source]
Return type

Trial

ax.utils.testing.core_stubs.get_weights()[source]
Return type

List[float]

ax.utils.testing.core_stubs.get_weights_from_dict(arm_weights_dict)[source]
Return type

List[float]

Modeling Stubs

ax.utils.testing.modeling_stubs.get_experiment_for_value()[source]
Return type

Experiment

ax.utils.testing.modeling_stubs.get_generation_strategy(with_experiment=False)[source]
Return type

GenerationStrategy

ax.utils.testing.modeling_stubs.get_observation()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation1()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation1trans()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation2()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation2trans()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation_features()[source]
Return type

ObservationFeatures

ax.utils.testing.modeling_stubs.get_observation_status_quo0()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_observation_status_quo1()[source]
Return type

Observation

ax.utils.testing.modeling_stubs.get_transform_type()[source]
Return type

Type[Transform]

class ax.utils.testing.modeling_stubs.transform_1(search_space, observation_features, observation_data, config=None)[source]

Bases: ax.modelbridge.transforms.base.Transform

config = None
transform_observation_data(observation_data, observation_features)[source]

Transform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

This takes in observation_features, so that data transforms can be conditional on features, but observation_features are not mutated.

Parameters
  • observation_data (List[ObservationData]) – Observation data

  • observation_features (List[ObservationFeatures]) – Corresponding observation features

Returns: transformed observation data

Return type

List[ObservationData]

transform_observation_features(observation_features)[source]

Transform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features (List[ObservationFeatures]) – Observation features

Returns: transformed observation features

Return type

List[ObservationFeatures]

transform_optimization_config(optimization_config, modelbridge, fixed_features)[source]

Transform optimization config.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

optimization_config (OptimizationConfig) – The optimization config

Returns: transformed optimization config.

Return type

OptimizationConfig

transform_search_space(search_space)[source]

Transform search space.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

search_space (SearchSpace) – The search space

Returns: transformed search space.

Return type

SearchSpace

untransform_observation_data(observation_data, observation_features)[source]

Untransform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

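Parameters
  • observation_data (List[ObservationData]) – Observation data in the transformed space

  • observation_features (List[ObservationFeatures]) – Corresponding observation features

Returns: observation data in the original space.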

Return type

List[ObservationData]

untransform_observation_features(observation_features)[source]

Untransform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features (List[ObservationFeatures]) – Observation features in the transformed space

Returns: observation features in the original space

Return type

List[ObservationFeatures]

class ax.utils.testing.modeling_stubs.transform_2(search_space, observation_features, observation_data, config=None)[source]

Bases: ax.modelbridge.transforms.base.Transform

config = None
transform_observation_data(observation_data, observation_features)[source]

Transform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

This takes in observation_features, so that data transforms can be conditional on features, but observation_features are not mutated.

Parameters
  • observation_data (List[ObservationData]) – Observation data

  • observation_features (List[ObservationFeatures]) – Corresponding observation features

Returns: transformed observation data

Return type

List[ObservationData]

transform_observation_features(observation_features)[source]

Transform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features (List[ObservationFeatures]) – Observation features

Returns: transformed observation features

Return type

List[ObservationFeatures]

transform_optimization_config(optimization_config, modelbridge, fixed_features)[source]

Transform optimization config.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

optimization_config (OptimizationConfig) – The optimization config

Returns: transformed optimization config.

Return type

OptimizationConfig

transform_search_space(search_space)[source]

Transform search space.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

search_space (SearchSpace) – The search space

Returns: transformed search space.

Return type

SearchSpace

untransform_observation_data(observation_data, observation_features)[source]

Untransform observation data.

This is typically done in-place. This class implements the identity transform (does nothing).

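Parameters
  • observation_data (List[ObservationData]) – Observation data in the transformed space

  • observation_features (List[ObservationFeatures]) – Corresponding observation features

Returns: observation data in the original space.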

Return type

List[ObservationData]

untransform_observation_features(observation_features)[source]

Untransform observation features.

This is typically done in-place. This class implements the identity transform (does nothing).

Parameters

observation_features (List[ObservationFeatures]) – Observation features in the transformed space

Returns: observation features in the original space

Return type

List[ObservationFeatures]

Tutorials

Neural Net

class ax.utils.tutorials.cnn_utils.CNN[source]

Bases: torch.nn.modules.module.Module

Convolutional Neural Network.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ax.utils.tutorials.cnn_utils.evaluate(net, data_loader, dtype, device)[source]

Compute classification accuracy on provided dataset.

Parameters
  • net (Module) – trained model

  • data_loader (DataLoader) – DataLoader containing the evaluation set

  • dtype (dtype) – torch dtype

  • device (device) – torch device

Returns

classification accuracy

Return type

float

ax.utils.tutorials.cnn_utils.get_partition_data_loaders(train_valid_set, test_set, downsample_pct=0.5, train_pct=0.8, batch_size=128, num_workers=0, deterministic_partitions=False, downsample_pct_test=None)[source]
Helper function for partitioning training data into training and validation sets, downsampling data, and initializing DataLoaders for each partition.

Parameters
  • train_valid_set (Dataset) – torch.dataset

  • downsample_pct (float) – the proportion of the dataset to use for training and validation

  • train_pct (float) – the proportion of the downsampled data to use for training

  • batch_size (int) – how many samples per batch to load

  • num_workers (int) – number of workers (subprocesses) for loading data

  • deterministic_partitions (bool) – whether to partition data in a deterministic fashion

  • downsample_pct_test (Optional[float]) – the proportion of the dataset to use for test; defaults to downsample_pct

Returns

  • DataLoader: training data

  • DataLoader: validation data

  • DataLoader: test data

Return type

Tuple[DataLoader, DataLoader, DataLoader]

ax.utils.tutorials.cnn_utils.load_mnist(downsample_pct=0.5, train_pct=0.8, data_path='./data', batch_size=128, num_workers=0, deterministic_partitions=False, downsample_pct_test=None)[source]
Load MNIST dataset (download if necessary) and split data into training, validation, and test sets.

Parameters
  • downsample_pct (float) – the proportion of the dataset to use for training, validation, and test

  • train_pct (float) – the proportion of the downsampled data to use for training

  • data_path (str) – Root directory of dataset where MNIST/processed/training.pt and MNIST/processed/test.pt exist.

  • batch_size (int) – how many samples per batch to load

  • num_workers (int) – number of workers (subprocesses) for loading data

  • deterministic_partitions (bool) – whether to partition data in a deterministic fashion

  • downsample_pct_test (Optional[float]) – the proportion of the dataset to use for test; defaults to downsample_pct

Returns

  • DataLoader: training data

  • DataLoader: validation data

  • DataLoader: test data

Return type

Tuple[DataLoader, DataLoader, DataLoader]

ax.utils.tutorials.cnn_utils.split_dataset(dataset, lengths, deterministic_partitions=False)[source]

Split a dataset either randomly or deterministically.

Parameters
  • dataset (Dataset) – the dataset to split

  • lengths (List[int]) – the lengths of each partition

  • deterministic_partitions (bool) – whether to partition data in a deterministic fashion

Returns

split datasets

Return type

List[Dataset]
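The random-versus-deterministic split can be illustrated without torch on a plain sequence. The sketch assumes "deterministic" means contiguous slices in the original order and "random" means slices of a shuffled index list. `split_sequence` and its `seed` parameter are hypothetical; the documented function operates on a torch Dataset.

```python
import random
from itertools import accumulate
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def split_sequence(
    data: Sequence[T],
    lengths: List[int],
    deterministic_partitions: bool = False,
    seed: Optional[int] = None,
) -> List[List[T]]:
    if sum(lengths) != len(data):
        raise ValueError("lengths must sum to the length of the data")
    indices = list(range(len(data)))
    if not deterministic_partitions:
        random.Random(seed).shuffle(indices)
    splits = []
    # Each partition takes the next `length` indices, ending at the running total.
    for end, length in zip(accumulate(lengths), lengths):
        splits.append([data[i] for i in indices[end - length:end]])
    return splits
```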

ax.utils.tutorials.cnn_utils.train(net, train_loader, parameters, dtype, device)[source]

Train CNN on provided data set.

Parameters
  • net (Module) – initialized neural network

  • train_loader (DataLoader) – DataLoader containing training set

  • parameters (Dict[str, float]) – dictionary containing parameters to be passed to the optimizer:

      - lr: default 0.001

      - momentum: default 0.0

      - weight_decay: default 0.0

      - num_epochs: default 1

  • dtype (dtype) – torch dtype

  • device (device) – torch device

Returns

trained CNN.

Return type

nn.Module