ax.utils
Common
Docutils
Support functions for Sphinx et al.
- ax.utils.common.docutils.copy_doc(src: Callable[..., Any]) → Callable[[_T], _T]
A decorator that copies the docstring of another object.
Since Sphinx actually loads the Python modules to grab the docstrings, this works with both Sphinx and the help function:

    class Cat(Mammal):
        @property
        @copy_doc(Mammal.is_feline)
        def is_feline(self) -> bool:
            ...
Equality
- ax.utils.common.equality.dataframe_equals(df1: pandas.DataFrame, df2: pandas.DataFrame) → bool
Compare equality of two pandas DataFrames.
- ax.utils.common.equality.datetime_equals(dt1: Optional[datetime.datetime], dt2: Optional[datetime.datetime]) → bool
Compare equality of two datetimes, ignoring microseconds.
- ax.utils.common.equality.equality_typechecker(eq_func: Callable) → Callable
A decorator to wrap all __eq__ methods to ensure that the inputs are of the right type.
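For illustration, a minimal sketch of decorating an __eq__ method (the Point class is invented for this example, and the exact behavior on a type mismatch is an assumption):

    from ax.utils.common.equality import equality_typechecker

    class Point:
        def __init__(self, x: float) -> None:
            self.x = x

        @equality_typechecker
        def __eq__(self, other: "Point") -> bool:
            # The body may assume `other` is a Point; the decorator is
            # expected to short-circuit when the types do not match.
            return self.x == other.x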
- ax.utils.common.equality.object_attribute_dicts_equal(one_dict: Dict[str, Any], other_dict: Dict[str, Any]) → bool
Utility to check if all items in attribute dicts of two Ax objects are the same.
NOTE: Special-cases some Ax object attributes, like "_experiment" or "_model", where full equality is hard to check.
- ax.utils.common.equality.object_attribute_dicts_find_unequal_fields(one_dict: Dict[str, Any], other_dict: Dict[str, Any], fast_return: bool = True) → Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Tuple[Any, Any]]]
Utility for finding which attributes in two objects' attribute dicts are unequal.
- Parameters
one_dict – First object's attribute dict (obj.__dict__).
other_dict – Second object's attribute dict (obj.__dict__).
fast_return – Whether to return as soon as a single unequal attribute is found, or to iterate over all attributes and collect all unequal ones.
- Returns
Two dictionaries: attribute name to attribute values of unequal type (as a tuple), and attribute name to attribute values of unequal value (as a tuple).
- ax.utils.common.equality.same_elements(list1: List[Any], list2: List[Any]) → bool
Compare equality of two lists of core Ax objects.
- Assumptions:
– The contents of each list are types that implement __eq__
– The lists do not contain duplicates
Checking equality is then the same as checking that the lists are the same length, and that one is a subset of the other. A minimal sketch of this logic follows.
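An illustrative reimplementation of that check (not Ax's exact code):

    from typing import Any, List

    def same_elements_sketch(list1: List[Any], list2: List[Any]) -> bool:
        # Given no duplicates, equal length plus one-way containment
        # implies the lists hold the same elements.
        if len(list1) != len(list2):
            return False
        return all(item in list2 for item in list1)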
Kwargs
- ax.utils.common.kwargs.consolidate_kwargs(kwargs_iterable: Iterable[Optional[Dict[str, Any]]], keywords: Iterable[str]) → Dict[str, Any]
Combine an iterable of kwargs dicts into a single dict, where for duplicate keys the kwargs that appear later in the iterable take priority over the ones that appear earlier, and only kwargs referenced in keywords are used. This allows combining somewhat redundant sets of kwargs, where a user-set kwarg, for instance, needs to override a default kwarg.

    >>> consolidate_kwargs(
    ...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
    ...     keywords=['a', 'b', 'd']
    ... )
    {'a': 1, 'b': 3, 'd': 5}
- ax.utils.common.kwargs.get_function_argument_names(function: Callable, omit: Optional[List[str]] = None) → List[str]
Extract parameter names from function signature.
- ax.utils.common.kwargs.get_function_default_arguments(function: Callable) → Dict[str, Any]
Extract default arguments from function signature.
- ax.utils.common.kwargs.validate_kwarg_typing(typed_callables: List[Callable], **kwargs: Any) → None
Check that each keyword in kwargs exists in at least one of the typed_callables, and that the type of each keyword value matches the type of the corresponding argument in one of the callables.
Note: this function expects the typed callables to have unique keyword arguments, and will raise an error if repeated keywords are found.
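A hypothetical usage sketch (the callables are invented for the example):

    from ax.utils.common.kwargs import validate_kwarg_typing

    def fit(lr: float) -> None: ...
    def predict(n_samples: int) -> None: ...

    # Each keyword below exists in exactly one callable and has the
    # annotated type, so validation is expected to pass.
    validate_kwarg_typing([fit, predict], lr=0.1, n_samples=10)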
Logger
- class ax.utils.common.logger.AxOutputNameFilter(name='')
Bases: logging.Filter
A filter that sets the record's output_name, if not already configured.
- ax.utils.common.logger.build_file_handler(filepath: str, level: int = 20) → logging.StreamHandler
Build a file handler that logs entries to the given file, using the same formatting as the stream handler.
- Parameters
filepath – Location of the file to log output to. If the file exists, output will be appended. If it does not exist, a new file will be created.
level – The log level. By default, sets the level to INFO.
- Returns
A logging.FileHandler instance
- ax.utils.common.logger.build_stream_handler(level: int = 20) → logging.StreamHandler
Build the default stream handler used for most Ax logging. Sets the default level to INFO, instead of WARNING.
- Parameters
level – The log level. By default, sets the level to INFO.
- Returns
A logging.StreamHandler instance
- ax.utils.common.logger.get_logger(name: str) → logging.Logger
Get an Ax logger.
To set a human-readable "output_name" that appears in logger outputs, add {"output_name": "[MY_OUTPUT_NAME]"} to the logger's contextual information. By default, the logger's name is used.
- Parameters
name – The name of the logger.
- Returns
The logging.Logger object.
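A usage sketch; passing contextual information through the standard logging `extra` dict is an assumption about how the output_name reaches the AxOutputNameFilter above:

    from ax.utils.common.logger import get_logger

    logger = get_logger(__name__)
    logger.info("Plain message")  # output_name defaults to the logger's name
    logger.info("Renamed message", extra={"output_name": "MY_OUTPUT_NAME"})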
Serialization
- ax.utils.common.serialization.callable_from_reference(path: str) → Callable
Retrieves a callable by its path.
- ax.utils.common.serialization.callable_to_reference(callable: Callable) → str
Obtains the path to the callable, of the form <module>.<name>.
- ax.utils.common.serialization.extract_init_args(args: Dict[str, Any], class_: Type) → Dict[str, Any]
Given a dictionary, extract the arguments required for the given class's constructor.
Testutils
Support functions for tests
- class ax.utils.common.testutils.TestCase(methodName: str = 'runTest')
Bases: unittest.case.TestCase
The base Ax test case; contains various helper functions for writing unit tests.
- assertEqual(first: Any, second: Any, msg: Optional[str] = None) → None
Fail if the two objects are unequal as determined by the '==' operator.
Timeutils
- ax.utils.common.timeutils.current_timestamp_in_millis() → int
Grab the current timestamp in milliseconds, as an int.
- ax.utils.common.timeutils.timestamps_in_range(start: datetime.datetime, end: datetime.datetime, delta: datetime.timedelta) → Generator[datetime.datetime, None, None]
Generator of timestamps in the range [start, end], at intervals of delta.
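For example, assuming inclusive endpoints as documented:

    from datetime import datetime, timedelta
    from ax.utils.common.timeutils import timestamps_in_range

    # Yields 00:00, 01:00, 02:00, and 03:00 on 2020-01-01.
    for ts in timestamps_in_range(
        start=datetime(2020, 1, 1, 0),
        end=datetime(2020, 1, 1, 3),
        delta=timedelta(hours=1),
    ):
        print(ts)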
Typeutils
- ax.utils.common.typeutils.checked_cast(typ: Type[T], val: V) → T
Cast a value to a type (with a runtime safety check).
Returns the value unchanged and checks its type at runtime. This signals to the typechecker that the value has the designated type.
Like typing.cast, checked_cast performs no runtime conversion on its argument; but, unlike typing.cast, checked_cast will throw an error if the value is not of the expected type. The type passed as an argument should be a Python class.
- Parameters
typ – the type to cast to
val – the value that we are casting
- Returns
the val argument, unchanged
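For example (per the docstring, the value passes through unchanged or an error is raised):

    from ax.utils.common.typeutils import checked_cast

    x: object = 3
    i = checked_cast(int, x)  # returns 3; the typechecker now sees an int
    s = checked_cast(str, x)  # raises an error, since 3 is not a str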
- ax.utils.common.typeutils.checked_cast_dict(key_typ: Type[K], value_typ: Type[V], d: Dict[X, Y]) → Dict[K, V]
Calls checked_cast on all keys and values in the dictionary.
- ax.utils.common.typeutils.checked_cast_list(typ: Type[T], old_l: List[V]) → List[T]
Calls checked_cast on all items in a list.
- ax.utils.common.typeutils.checked_cast_optional(typ: Type[T], val: Optional[V]) → Optional[T]
Calls checked_cast only if the value is not None.
- ax.utils.common.typeutils.checked_cast_to_tuple(typ: Tuple[Type[V], ...], val: V) → T
Cast a value to a union of multiple types (with a runtime safety check). This function is similar to checked_cast, but allows for the type to be defined as a tuple of types, in which case the value is cast as a union of the types in the tuple.
- Parameters
typ – the tuple of types to cast to
val – the value that we are casting
- Returns
the val argument, unchanged
- ax.utils.common.typeutils.not_none(val: Optional[T]) → T
Unbox an optional type.
- Parameters
val – the value to cast to a non-None type
- Returns
val, when val is not None
- Return type
T
- Throws:
ValueError if val is None
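For example:

    from typing import Optional
    from ax.utils.common.typeutils import not_none

    def lookup(key: str) -> Optional[int]:
        return {"a": 1}.get(key)

    v = not_none(lookup("a"))  # typed as int, value 1
    not_none(lookup("z"))      # raises ValueError, since the value is None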
- ax.utils.common.typeutils.numpy_type_to_python_type(value: Any) → Any
If value is a NumPy int or float, coerce it to a Python int or float. This is necessary because some of our transforms return NumPy values.
Measurement
Synthetic Functions
- class ax.utils.measurement.synthetic_functions.Aug_Branin
Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction
Augmented Branin function (3-dimensional with infinitely many global minima).
- class ax.utils.measurement.synthetic_functions.Aug_Hartmann6
Bases: ax.utils.measurement.synthetic_functions.Hartmann6
Augmented Hartmann6 function (7-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.Branin
Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction
Branin function (2-dimensional with 3 global minima).
- class ax.utils.measurement.synthetic_functions.FromBotorch(botorch_synthetic_function: botorch.test_functions.synthetic.SyntheticTestFunction)
Bases: ax.utils.measurement.synthetic_functions.SyntheticFunction
- property name
class
ax.utils.measurement.synthetic_functions.
Hartmann6
[source]¶ Bases:
ax.utils.measurement.synthetic_functions.SyntheticFunction
Hartmann6 function (6-dimensional with 1 global minimum).
- class ax.utils.measurement.synthetic_functions.SyntheticFunction
Bases: abc.ABC
- property domain
- f(X: numpy.ndarray) → Union[float, numpy.ndarray]
Synthetic function implementation.
- Parameters
X (numpy.ndarray) – an n by d array, where n represents the number of observations and d is the dimensionality of the inputs.
- Returns
an n-dimensional array.
- Return type
numpy.ndarray
- property fmax
- property fmin
- property maximums
- property minimums
- property name
- property required_dimensionality
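A brief usage sketch of the synthetic function interface documented above (the evaluation point is the well-known Branin minimizer):

    import numpy as np
    from ax.utils.measurement.synthetic_functions import Branin

    branin = Branin()
    X = np.array([[np.pi, 2.275]])  # one of Branin's three global minimizers
    print(branin.f(X))              # approximately 0.397887, i.e. branin.fmin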
Notebook
Report
Render
- ax.utils.report.render.link_html(text: str, href: str) → str
Embed text and reference address into a link tag.
- ax.utils.report.render.render_report_elements(experiment_name: str, html_elements: List[str], header: bool = True, offline: bool = False, notebook_env: bool = False) → str
Generate an Ax HTML report for a given experiment from HTML elements.
Uses Jinja2 for templating. Injects Plotly JS for graph rendering.
Example:

    html_elements = [
        h2_html("Subsection with plot"),
        p_html("This is an example paragraph."),
        plot_html(plot_fitted(gp_model, 'perf_metric')),
        h2_html("Subsection with table"),
        pandas_html(data.df),
    ]
    html = render_report_elements('My experiment', html_elements)

- Parameters
experiment_name – the name of the experiment, used for the title.
html_elements – list of HTML strings to render in the report body.
header – if True, render the experiment title as a header. Meant for standalone reports (e.g. sent via email), as opposed to reports served on the front end.
offline – if True, the entire Plotly library is bundled with the report.
notebook_env – if True, caps the report width to 700px for viewing in a notebook environment.
- Returns
HTML string.
- Return type
str
- ax.utils.report.render.table_cell_html(text: str, width: Optional[str] = None) → str
Embed text or an HTML element into a table cell tag.
- ax.utils.report.render.table_heading_cell_html(text: str) → str
Embed text or an HTML element into a table heading cell tag.
- ax.utils.report.render.table_html(table_rows: List[str]) → str
Embed a list of HTML elements into a table tag.
Stats
Statstools
- ax.utils.stats.statstools.agresti_coull_sem(n_numer: Union[pandas.Series, numpy.ndarray, int], n_denom: Union[pandas.Series, numpy.ndarray, int], prior_successes: int = 2, prior_failures: int = 2) → Union[numpy.ndarray, float]
Compute the Agresti-Coull style standard error for a binomial proportion.
Reference: Agresti, Alan, and Brent A. Coull. "Approximate Is Better than 'Exact' for Interval Estimation of Binomial Proportions." The American Statistician, vol. 52, no. 2, 1998, pp. 119-126. JSTOR, www.jstor.org/stable/2685469.
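A sketch of the underlying computation, assuming the standard Agresti-Coull adjustment (an illustrative reimplementation, not Ax's exact code):

    import numpy as np

    def agresti_coull_sem_sketch(
        n_numer: int,
        n_denom: int,
        prior_successes: int = 2,
        prior_failures: int = 2,
    ) -> float:
        # Add pseudo-observations to the counts, then compute the usual
        # standard error of the adjusted proportion.
        n_tilde = n_denom + prior_successes + prior_failures
        p_tilde = (n_numer + prior_successes) / n_tilde
        return float(np.sqrt(p_tilde * (1 - p_tilde) / n_tilde))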
- ax.utils.stats.statstools.inverse_variance_weight(means: numpy.ndarray, variances: numpy.ndarray, conflicting_noiseless: str = 'warn') → Tuple[float, float]
Perform inverse variance weighting.
- Parameters
means – The means of the observations.
variances – The variances of the observations.
conflicting_noiseless – How to handle the case of multiple observations with zero variance but different means. Options are "warn" (default), "ignore", or "raise".
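For strictly positive variances, inverse variance weighting reduces to the familiar formulas below (an illustrative sketch that omits the zero-variance handling):

    from typing import Tuple
    import numpy as np

    def inverse_variance_weight_sketch(
        means: np.ndarray, variances: np.ndarray
    ) -> Tuple[float, float]:
        # Weight each observation by the inverse of its variance; the
        # variance of the combined estimate is the inverse of the total weight.
        weights = 1.0 / variances
        pooled_var = 1.0 / weights.sum()
        pooled_mean = float((weights * means).sum() * pooled_var)
        return pooled_mean, float(pooled_var)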
- ax.utils.stats.statstools.marginal_effects(df: pandas.DataFrame) → pandas.DataFrame
This method calculates the relative (in %) change in the outcome achieved by using any individual factor level versus randomizing across all factor levels. It does this by estimating a baseline under the experiment by marginalizing over all factors/levels. For each factor level, it then conditions on that level for the individual factor and marginalizes over all levels for all other factors.
- Parameters
df – Dataframe containing columns named mean and sem. All other columns are assumed to be factors for which to calculate marginal effects.
- Returns
A dataframe containing columns "Name", "Level", "Beta", and "SE", corresponding to the factor, level, effect, and standard error. Results are relativized as percentage changes.
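A hypothetical call (the factor column name and values are invented for the example):

    import pandas as pd
    from ax.utils.stats.statstools import marginal_effects

    df = pd.DataFrame({
        "mean": [1.0, 1.2, 0.9, 1.1],
        "sem": [0.1, 0.1, 0.1, 0.1],
        "color": ["red", "red", "blue", "blue"],  # hypothetical factor
    })
    effects = marginal_effects(df)  # columns: Name, Level, Beta, SE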
- ax.utils.stats.statstools.positive_part_james_stein(means: Union[numpy.ndarray, List[float]], sems: Union[numpy.ndarray, List[float]]) → Tuple[numpy.ndarray, numpy.ndarray]
Estimation method for the positive-part James-Stein estimator.
This method takes a vector of K means (y_i) and standard errors (sigma_i) and calculates the positive-part James-Stein estimator.
The resulting estimates are the shrunk means and standard errors. The positive-part James-Stein estimator shrinks each constituent average toward the grand average:

y_i - phi_i * y_i + phi_i * ybar

The variable phi_i determines the amount of shrinkage. For phi_i = 1, mu_hat is equal to ybar (the mean of all y_i), while for phi_i = 0, mu_hat is equal to y_i. It can be shown that restricting phi_i <= 1 dominates the unrestricted estimator, so this method restricts phi_i in this manner. The amount of shrinkage, phi_i, is determined by:

(K - 3) * sigma2_i / s2

where sigma2_i is the variance of y_i and s2 is the sum of squared deviations of the y_i from the grand mean ybar. That is, less shrinkage is applied when individual means are estimated with greater precision, and more shrinkage is applied when individual means are very tightly clustered together. We also restrict phi_i to never be larger than 1.
The variance of the mean estimator is:

(1 - phi_i) * sigma2_i + phi_i * sigma2_i / K + 2 * phi_i ** 2 * (y_i - ybar)^2 / (K - 3)

The first term is the variance component from y_i, the second term is the contribution from the mean of all y_i, and the third term is the contribution from the uncertainty in the sum of squared deviations of y_i from the mean of all y_i.
For more information, see https://fburl.com/empirical_bayes.
- Parameters
means – Means of each arm
sems – Standard errors of each arm
- Returns
mu_hat_i: Empirical Bayes estimate of each arm's mean; sem_i: Empirical Bayes estimate of each arm's sem.
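A sketch of the estimator as documented above (illustrative, not Ax's exact code; edge-case handling for K <= 3 is omitted):

    from typing import Tuple
    import numpy as np

    def positive_part_james_stein_sketch(
        means: np.ndarray, sems: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        K = len(means)
        ybar = means.mean()
        s2 = ((means - ybar) ** 2).sum()  # sum of squared deviations
        sigma2 = sems ** 2
        # Per-arm shrinkage factor, capped at 1 (the "positive part").
        phi = np.minimum(1.0, (K - 3) * sigma2 / s2)
        mu_hat = (1 - phi) * means + phi * ybar
        # Variance of the shrunk estimator, per the formula above.
        var_hat = (
            (1 - phi) * sigma2
            + phi * sigma2 / K
            + 2 * phi ** 2 * (means - ybar) ** 2 / (K - 3)
        )
        return mu_hat, np.sqrt(var_hat)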
- ax.utils.stats.statstools.relativize(means_t: Union[numpy.ndarray, List[float], float], sems_t: Union[numpy.ndarray, List[float], float], mean_c: float, sem_c: float, bias_correction: bool = True, cov_means: Union[numpy.ndarray, List[float], float] = 0.0, as_percent: bool = False) → Tuple[numpy.ndarray, numpy.ndarray]
Ratio estimator based on the delta method.
This uses the delta method (i.e. a Taylor series approximation) to estimate the mean and standard deviation of the sampling distribution of the ratio between test and control – that is, the sampling distribution of an estimator of the true population value under the assumption that the means in test and control have a known covariance:

(mu_t / mu_c) - 1.

Under a second-order Taylor expansion, the sampling distribution of the relative change in empirical means, m_t / m_c - 1, is approximately normally distributed with mean

[(mu_t - mu_c) / mu_c] - [(sigma_c)^2 * mu_t] / (mu_c)^3

and variance

(sigma_t / mu_c)^2 - 2 * mu_t * sigma_tc / mu_c^3 + [(sigma_c * mu_t)^2 / (mu_c)^4]

as the higher-order terms are assumed to be close to zero in the full Taylor series. To estimate these parameters, we plug in the empirical means and standard errors. This gives the estimators:

[(m_t - m_c) / m_c] - [(s_c)^2 * m_t] / (m_c)^3

and

(s_t / m_c)^2 - 2 * m_t * s_tc / m_c^3 + [(s_c * m_t)^2 / (m_c)^4]

Note that the delta method does NOT take as input the empirical standard deviation of a metric, but rather the standard error of the mean of that metric – that is, the standard deviation of the metric after division by the square root of the total number of observations.
- Parameters
means_t – Sample means (test)
sems_t – Sample standard errors of the means (test)
mean_c – Sample mean (control)
sem_c – Sample standard error of the mean (control)
cov_means – Sample covariance between test and control
as_percent – If true, return results in percent (* 100)
- Returns
rel_hat: Inferred means of the sampling distribution of the relative change (mean_t / mean_c) - 1; sem_hat: Inferred standard deviation of the sampling distribution of rel_hat – i.e. the standard error.
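A sketch of the plug-in estimators above (illustrative; ignores the bias_correction and as_percent options):

    from typing import Tuple
    import numpy as np

    def relativize_sketch(
        m_t: np.ndarray, s_t: np.ndarray, m_c: float, s_c: float, s_tc: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        # Second-order delta-method mean of (m_t / m_c) - 1, including
        # the bias term driven by the control variance.
        rel_hat = (m_t - m_c) / m_c - (s_c ** 2) * m_t / m_c ** 3
        # Delta-method variance of the ratio estimator.
        var_hat = (
            (s_t / m_c) ** 2
            - 2 * m_t * s_tc / m_c ** 3
            + (s_c * m_t) ** 2 / m_c ** 4
        )
        return rel_hat, np.sqrt(var_hat)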
- ax.utils.stats.statstools.relativize_data(data: ax.core.data.Data, status_quo_name: str = 'status_quo', as_percent: bool = False, include_sq: bool = False) → ax.core.data.Data
Relativize a data object w.r.t. a status_quo arm.
- Parameters
data – The data object to be relativized.
status_quo_name – The name of the status_quo arm.
as_percent – If True, return results as percentage change.
include_sq – Include the status quo in the final df.
- Returns
The new data object with the relativized metrics (excluding the status_quo arm).
Testing
Core Stubs
- ax.utils.testing.core_stubs.get_acquisition_function_type() → Type[botorch.acquisition.acquisition.AcquisitionFunction]
- ax.utils.testing.core_stubs.get_acquisition_type() → Type[ax.models.torch.botorch_modular.acquisition.Acquisition]
- ax.utils.testing.core_stubs.get_arms_from_dict(arm_weights_dict: MutableMapping[ax.core.arm.Arm, float]) → List[ax.core.arm.Arm]
- ax.utils.testing.core_stubs.get_augmented_branin_metric(name='aug_branin') → ax.metrics.branin.AugmentedBraninMetric
- ax.utils.testing.core_stubs.get_augmented_branin_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_augmented_hartmann_metric(name='aug_hartmann') → ax.metrics.hartmann6.AugmentedHartmann6Metric
- ax.utils.testing.core_stubs.get_augmented_hartmann_objective() → ax.core.objective.Objective
- ax.utils.testing.core_stubs.get_augmented_hartmann_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_batch_trial(abandon_arm: bool = True) → ax.core.batch_trial.BatchTrial
- ax.utils.testing.core_stubs.get_batch_trial_with_repeated_arms(num_repeated_arms: int) → ax.core.batch_trial.BatchTrial
Create a batch that contains both new arms and N arms from the last existing trial in the experiment, where N is equal to the input argument 'num_repeated_arms'.
- ax.utils.testing.core_stubs.get_botorch_model() → ax.models.torch.botorch_modular.model.BoTorchModel
- ax.utils.testing.core_stubs.get_botorch_model_with_default_acquisition_class() → ax.models.torch.botorch_modular.model.BoTorchModel
- ax.utils.testing.core_stubs.get_branin_data(trial_indices: Optional[Iterable[int]] = None) → ax.core.data.Data
- ax.utils.testing.core_stubs.get_branin_data_batch(batch: ax.core.batch_trial.BatchTrial) → ax.core.data.Data
- ax.utils.testing.core_stubs.get_branin_data_multi_objective(trial_indices: Optional[Iterable[int]] = None) → ax.core.data.Data
- ax.utils.testing.core_stubs.get_branin_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False, with_choice_parameter: bool = False, search_space: Optional[ax.core.search_space.SearchSpace] = None) → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_branin_experiment_with_multi_objective(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False, with_fidelity_parameter: bool = False) → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_branin_metric(name='branin') → ax.metrics.branin.BraninMetric
- ax.utils.testing.core_stubs.get_branin_multi_objective_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_branin_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_branin_outcome_constraint() → ax.core.outcome_constraint.OutcomeConstraint
- ax.utils.testing.core_stubs.get_branin_search_space(with_fidelity_parameter: bool = False, with_choice_parameter: bool = False) → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_experiment_with_batch_and_single_trial() → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_experiment_with_batch_trial() → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_experiment_with_multi_objective() → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_experiment_with_repeated_arms(num_repeated_arms: int) → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_experiment_with_scalarized_objective() → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_experiment_with_trial_with_ttl() → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_factorial_experiment(has_optimization_config: bool = True, with_batch: bool = False, with_status_quo: bool = False) → ax.core.experiment.Experiment
- ax.utils.testing.core_stubs.get_factorial_metric(name: str = 'success_metric') → ax.metrics.factorial.FactorialMetric
- ax.utils.testing.core_stubs.get_factorial_search_space() → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_hartmann_metric(name='hartmann') → ax.metrics.hartmann6.Hartmann6Metric
- ax.utils.testing.core_stubs.get_hartmann_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_hartmann_search_space(with_fidelity_parameter: bool = False) → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_list_surrogate() → ax.models.torch.botorch_modular.surrogate.Surrogate
- ax.utils.testing.core_stubs.get_mll_type() → Type[gpytorch.mlls.marginal_log_likelihood.MarginalLogLikelihood]
- ax.utils.testing.core_stubs.get_model_predictions() → Tuple[Dict[str, List[float]], Dict[str, Dict[str, List[float]]]]
- ax.utils.testing.core_stubs.get_model_predictions_per_arm() → Dict[str, Tuple[Dict[str, float], Optional[Dict[str, Dict[str, float]]]]]
- ax.utils.testing.core_stubs.get_multi_objective_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_multi_type_experiment(add_trial_type: bool = True, add_trials: bool = False) → ax.core.multi_type_experiment.MultiTypeExperiment
- ax.utils.testing.core_stubs.get_objective_threshold() → ax.core.outcome_constraint.ObjectiveThreshold
- ax.utils.testing.core_stubs.get_optimization_config() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_optimization_config_no_constraints() → ax.core.optimization_config.OptimizationConfig
- ax.utils.testing.core_stubs.get_order_constraint() → ax.core.parameter_constraint.OrderConstraint
- ax.utils.testing.core_stubs.get_outcome_constraint() → ax.core.outcome_constraint.OutcomeConstraint
- ax.utils.testing.core_stubs.get_parameter_constraint() → ax.core.parameter_constraint.ParameterConstraint
- ax.utils.testing.core_stubs.get_search_space_for_range_value(min: float = 3.0, max: float = 6.0) → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_search_space_for_range_values(min: float = 3.0, max: float = 6.0) → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_search_space_for_value(val: float = 3.0) → ax.core.search_space.SearchSpace
- ax.utils.testing.core_stubs.get_simple_experiment() → ax.core.simple_experiment.SimpleExperiment
- ax.utils.testing.core_stubs.get_simple_experiment_with_batch_trial() → ax.core.simple_experiment.SimpleExperiment
- ax.utils.testing.core_stubs.get_sum_constraint1() → ax.core.parameter_constraint.SumConstraint
- ax.utils.testing.core_stubs.get_sum_constraint2() → ax.core.parameter_constraint.SumConstraint
Modeling Stubs
- ax.utils.testing.modeling_stubs.get_generation_strategy(with_experiment: bool = False, with_callable_model_kwarg: bool = True) → ax.modelbridge.generation_strategy.GenerationStrategy
- ax.utils.testing.modeling_stubs.get_observation_features() → ax.core.observation.ObservationFeatures
- ax.utils.testing.modeling_stubs.get_observation_status_quo0() → ax.core.observation.Observation
- ax.utils.testing.modeling_stubs.get_observation_status_quo1() → ax.core.observation.Observation
- ax.utils.testing.modeling_stubs.get_transform_type() → Type[ax.modelbridge.transforms.base.Transform]
- class ax.utils.testing.modeling_stubs.transform_1(search_space: Optional[ax.core.search_space.SearchSpace], observation_features: List[ax.core.observation.ObservationFeatures], observation_data: List[ax.core.observation.ObservationData], config: Optional[Dict[str, Union[int, float, str, botorch.acquisition.acquisition.AcquisitionFunction, Dict[str, Any]]]] = None)
Bases: ax.modelbridge.transforms.base.Transform
- config = None
- transform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationData]
Transform observation data.
This is typically done in-place. This class implements the identity transform (it does nothing).
This takes in observation_features so that data transforms can be conditional on features, but observation_features are not mutated.
- Parameters
observation_data – Observation data
observation_features – Corresponding observation features
Returns: transformed observation data
- transform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationFeatures]
Transform observation features.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_features – Observation features
Returns: transformed observation features
- transform_optimization_config(optimization_config: ax.core.optimization_config.OptimizationConfig, modelbridge: Optional[ax.modelbridge.base.ModelBridge], fixed_features: ax.core.observation.ObservationFeatures) → ax.core.optimization_config.OptimizationConfig
Transform optimization config.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
optimization_config – The optimization config
Returns: transformed optimization config.
- transform_search_space(search_space: ax.core.search_space.SearchSpace) → ax.core.search_space.SearchSpace
Transform search space.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
search_space – The search space
Returns: transformed search space.
- untransform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationData]
Untransform observation data.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_data – Observation data, in transformed space
observation_features – Corresponding observation features, in the same space.
Returns: observation data in original space.
- untransform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationFeatures]
Untransform observation features.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_features – Observation features in the transformed space
Returns: observation features in the original space
- class ax.utils.testing.modeling_stubs.transform_2(search_space: Optional[ax.core.search_space.SearchSpace], observation_features: List[ax.core.observation.ObservationFeatures], observation_data: List[ax.core.observation.ObservationData], config: Optional[Dict[str, Union[int, float, str, botorch.acquisition.acquisition.AcquisitionFunction, Dict[str, Any]]]] = None)
Bases: ax.modelbridge.transforms.base.Transform
- config = None
- transform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationData]
Transform observation data.
This is typically done in-place. This class implements the identity transform (it does nothing).
This takes in observation_features so that data transforms can be conditional on features, but observation_features are not mutated.
- Parameters
observation_data – Observation data
observation_features – Corresponding observation features
Returns: transformed observation data
- transform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationFeatures]
Transform observation features.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_features – Observation features
Returns: transformed observation features
- transform_optimization_config(optimization_config: ax.core.optimization_config.OptimizationConfig, modelbridge: Optional[ax.modelbridge.base.ModelBridge], fixed_features: ax.core.observation.ObservationFeatures) → ax.core.optimization_config.OptimizationConfig
Transform optimization config.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
optimization_config – The optimization config
Returns: transformed optimization config.
- transform_search_space(search_space: ax.core.search_space.SearchSpace) → ax.core.search_space.SearchSpace
Transform search space.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
search_space – The search space
Returns: transformed search space.
- untransform_observation_data(observation_data: List[ax.core.observation.ObservationData], observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationData]
Untransform observation data.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_data – Observation data, in transformed space
observation_features – Corresponding observation features, in the same space.
Returns: observation data in original space.
- untransform_observation_features(observation_features: List[ax.core.observation.ObservationFeatures]) → List[ax.core.observation.ObservationFeatures]
Untransform observation features.
This is typically done in-place. This class implements the identity transform (it does nothing).
- Parameters
observation_features – Observation features in the transformed space
Returns: observation features in the original space
Tutorials
Neural Net
- class ax.utils.tutorials.cnn_utils.CNN
Bases: torch.nn.modules.module.Module
Convolutional Neural Network.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training = None
- ax.utils.tutorials.cnn_utils.evaluate(net: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader, dtype: torch.dtype, device: torch.device) → float
Compute classification accuracy on the provided dataset.
- Parameters
net – trained model
data_loader – DataLoader containing the evaluation set
dtype – torch dtype
device – torch device
- Returns
classification accuracy
- Return type
float
- ax.utils.tutorials.cnn_utils.get_partition_data_loaders(train_valid_set: torch.utils.data.dataset.Dataset, test_set: torch.utils.data.dataset.Dataset, downsample_pct: float = 0.5, train_pct: float = 0.8, batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None) → Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader]
Helper function for partitioning training data into training and validation sets, downsampling data, and initializing DataLoaders for each partition.
- Parameters
train_valid_set – torch Dataset containing the training and validation data
test_set – torch Dataset containing the test data
downsample_pct – the proportion of the dataset to use for training and validation
train_pct – the proportion of the downsampled data to use for training
batch_size – how many samples per batch to load
num_workers – number of workers (subprocesses) for loading data
deterministic_partitions – whether to partition data in a deterministic fashion
downsample_pct_test – the proportion of the dataset to use for the test set; defaults to downsample_pct
- Returns
Tuple of (training DataLoader, validation DataLoader, test DataLoader)
- ax.utils.tutorials.cnn_utils.load_mnist(downsample_pct: float = 0.5, train_pct: float = 0.8, data_path: str = './data', batch_size: int = 128, num_workers: int = 0, deterministic_partitions: bool = False, downsample_pct_test: Optional[float] = None) → Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader]
Load the MNIST dataset (downloading it if necessary) and split the data into training, validation, and test sets.
- Parameters
downsample_pct – the proportion of the dataset to use for training, validation, and test
train_pct – the proportion of the downsampled data to use for training
data_path – Root directory of the dataset, where MNIST/processed/training.pt and MNIST/processed/test.pt exist.
batch_size – how many samples per batch to load
num_workers – number of workers (subprocesses) for loading data
deterministic_partitions – whether to partition data in a deterministic fashion
downsample_pct_test – the proportion of the dataset to use for the test set; defaults to downsample_pct
- Returns
Tuple of (training DataLoader, validation DataLoader, test DataLoader)
- ax.utils.tutorials.cnn_utils.split_dataset(dataset: torch.utils.data.dataset.Dataset, lengths: List[int], deterministic_partitions: bool = False) → List[torch.utils.data.dataset.Dataset]
Split a dataset either randomly or deterministically.
- Parameters
dataset – the dataset to split
lengths – the lengths of each partition
deterministic_partitions – whether to partition data in a deterministic fashion
- Returns
split datasets
- Return type
List[Dataset]
- ax.utils.tutorials.cnn_utils.train(net: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, parameters: Dict[str, float], dtype: torch.dtype, device: torch.device) → torch.nn.modules.module.Module
Train a CNN on the provided data set.
- Parameters
net – initialized neural network
train_loader – DataLoader containing the training set
parameters – dictionary of parameters to be passed to the optimizer: lr (default 0.001), momentum (default 0.0), weight_decay (default 0.0), num_epochs (default 1)
dtype – torch dtype
device – torch device
- Returns
trained CNN.
- Return type
nn.Module
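Putting the tutorial utilities together, a minimal end-to-end sketch (hyperparameter values are illustrative):

    import torch
    from ax.utils.tutorials.cnn_utils import CNN, evaluate, load_mnist, train

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, valid_loader, test_loader = load_mnist(batch_size=128)

    # Train for one epoch with the documented optimizer parameters.
    net = train(
        net=CNN(),
        train_loader=train_loader,
        parameters={"lr": 0.001, "momentum": 0.9, "num_epochs": 1},
        dtype=torch.float,
        device=device,
    )
    accuracy = evaluate(net, valid_loader, dtype=torch.float, device=device)
    print(f"Validation accuracy: {accuracy:.3f}")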