
BONSAI + MAP-SAAS Tutorial

This tutorial demonstrates how to use BONSAI (Bayesian Optimization with Natural Simplicity and Interpretability) together with MAP-SAAS to optimize the high-dimensional Hartmann50 benchmark problem while simplifying proposals so that they only make necessary changes from the default (status quo) parameter values.

Overview

  • Hartmann50: A 50-dimensional synthetic benchmark in which only 6 dimensions are relevant (the true Hartmann function); the remaining 44 are irrelevant "dummy" dimensions.
  • BONSAI: A Bayesian optimization method that removes irrelevant parameter changes from the proposals Ax generates. The simplified proposals change fewer parameters, which makes them more interpretable and less likely to regress metrics not captured in the optimization objective.
  • MAP-SAAS: A fast Gaussian process model that has a SAAS (sparsity) prior.

This combination is particularly powerful for high-dimensional problems with low effective dimensionality.

These methods were proposed in Daulton et al., BONSAI: Bayesian Optimization with Natural Simplicity and Interpretability, arXiv, 2026.

1. Imports

import numpy as np
import torch

from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.api.utils.generation_strategy_dispatch import choose_generation_strategy
from ax.api.utils.structs import GenerationStrategyDispatchStruct

# Model configuration
from ax.generators.torch.botorch_modular.surrogate import ModelConfig

# BoTorch model (the key component for BONSAI)
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP

print(f"Using torch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
Output:
Using torch version: 2.10.0+cu128
CUDA available: False

2. Understanding the Components

2.1 Hartmann50 Problem

The Hartmann50 problem is a 50-dimensional optimization problem where:

  • The first 6 dimensions contain the actual Hartmann function (which has 6 local minima)
  • The remaining 44 dimensions are "dummy" and do not affect the objective value
  • This makes it an ideal test case for algorithms that can identify and focus on relevant dimensions
  • The global minimum is approximately -3.32237

2.2 The MAP-SAAS Model (EnsembleMapSaasSingleTaskGP)

This is a Gaussian process model that uses an ensemble of independent GPs, each with a different sampled value of the global sparsity level (thereby approximately integrating over the global sparsity level). Each ensemble member is fit via Maximum A Posteriori (MAP) estimation, which leverages the same sparsity prior as SAASBO but is significantly faster than SAASBO's MCMC inference.
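
Outside of Ax, the same model class can also be used directly with BoTorch. Below is a minimal sketch, assuming EnsembleMapSaasSingleTaskGP follows the standard BoTorch constructor pattern of (train_X, train_Y); in this tutorial, model fitting is handled internally by Ax, so here we only construct the model and query its posterior at a few test points.

import torch

from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP

# Toy data: 20 points in [0, 1]^50 where only the first input drives the response
train_X = torch.rand(20, 50, dtype=torch.double)
train_Y = (train_X[:, :1] - 0.5) ** 2

# Construct the ensemble MAP-SAAS GP (constructor signature assumed to match
# other BoTorch single-task models; Ax fits this model for us in the tutorial)
model = EnsembleMapSaasSingleTaskGP(train_X=train_X, train_Y=train_Y)

# Query the posterior at a handful of test points
test_X = torch.rand(5, 50, dtype=torch.double)
posterior = model.posterior(test_X)
print(posterior.mean.shape)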

2.3 BONSAI

BONSAI (Bayesian Optimization with Natural Simplicity and Interpretability) is a technique for post-processing candidates generated by Bayesian optimization to prune irrelevant parameter changes away from the default (status quo or target) values. It is compatible with any acquisition function and is enabled in Ax simply by specifying simplify_parameter_changes=True.
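
To make this concrete, here is a purely illustrative (hypothetical) before/after view of what pruning does to a candidate: parameter changes judged irrelevant are snapped back to the default values, so the final proposal differs from the status quo in only a few coordinates. The parameter names and values below are made up for illustration and are not produced by Ax.

# Hypothetical illustration of the effect of pruning (not BONSAI's internal algorithm)
status_quo = {f"p{i}": 0.5 for i in range(6)}

# A raw candidate from the acquisition optimizer moves every parameter slightly
raw_candidate = {"p0": 0.21, "p1": 0.74, "p2": 0.49, "p3": 0.52, "p4": 0.48, "p5": 0.51}

# After pruning, changes deemed irrelevant (p2-p5) are reset to the status quo
pruned_candidate = {"p0": 0.21, "p1": 0.74, "p2": 0.5, "p3": 0.5, "p4": 0.5, "p5": 0.5}

changed = [name for name, value in pruned_candidate.items() if value != status_quo[name]]
print(f"Parameters changed after pruning: {changed}")  # ['p0', 'p1']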

3. Set Up the Hartmann50 Optimization Problem

We'll create a Client and configure the experiment with 50 parameters.

# Create a client
client = Client()

# Define 50 parameters (x0 through x49) in the unit hypercube [0, 1]
parameters = [
    RangeParameterConfig(
        name=f"x{i}",
        parameter_type="float",
        bounds=(0.0, 1.0),
    )
    for i in range(50)
]

# Configure the experiment
client.configure_experiment(parameters=parameters)

# Define the center of the search space as the pruning target
# Parameters that are "pruned" will be set to these default values
pruning_target = {f"x{i}": 0.5 for i in range(50)}

# Configure optimization to minimize the objective
metric_name = "hartmann"
objective = f"-{metric_name}"  # Negative sign indicates minimization
client.configure_optimization(
    objective=objective,
    pruning_target_parameterization=pruning_target,
)

print(f"Experiment configured with {len(parameters)} parameters")
print(f"Objective: minimize {metric_name}")
print("Pruning target: center of search space (0.5 for all parameters)")
Output:
Experiment configured with 50 parameters
Objective: minimize hartmann
Pruning target: center of search space (0.5 for all parameters)

4. Define the Hartmann50 Objective Function

The Hartmann50 function uses the 6D Hartmann function on the first 6 dimensions, with 44 dummy dimensions that don't affect the output.

from botorch.test_functions import Hartmann

# Create the 6D Hartmann function
hartmann_6d: Hartmann = Hartmann(dim=6, negate=False)


def hartmann50(**parameters) -> float:
    """Evaluate the Hartmann50 function.

    Only the first 6 dimensions (x0-x5) affect the output.
    The remaining 44 dimensions (x6-x49) are ignored.

    Args:
        **parameters: Dict of parameter values (x0 through x49)

    Returns:
        The Hartmann function value (to be minimized).
    """
    # Extract the first 6 parameters that actually matter
    x = torch.tensor([[parameters[f"x{i}"] for i in range(6)]], dtype=torch.double)
    return hartmann_6d(x).item()


# Test the function
test_params = {f"x{i}": 0.5 for i in range(50)}
print(f"Test evaluation at center: {hartmann50(**test_params):.4f}")
print("Global optimum is approximately: -3.32237")
Output:
Test evaluation at center: -0.5053
Global optimum is approximately: -3.32237
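
As a quick sanity check of the claim from Section 2.1, we can perturb a dummy dimension and a relevant dimension and compare the results (this reuses the test_params and hartmann50 defined above):

# Perturbing a dummy dimension (x20) should not change the objective,
# while perturbing a relevant dimension (x2) should.
perturbed_dummy = dict(test_params, x20=0.9)
perturbed_relevant = dict(test_params, x2=0.9)

print(f"Center:                {hartmann50(**test_params):.6f}")
print(f"x20 -> 0.9 (dummy):    {hartmann50(**perturbed_dummy):.6f}")
print(f"x2  -> 0.9 (relevant): {hartmann50(**perturbed_relevant):.6f}")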

5. Configure the Generation Strategy with BONSAI and MAP-SAAS

We use choose_generation_strategy with GenerationStrategyDispatchStruct(method="custom", simplify_parameter_changes=True) to enable BONSAI, and we set EnsembleMapSaasSingleTaskGP as the BoTorch model class in the ModelConfig to leverage MAP-SAAS.

# Configuration parameters
NUM_SOBOL_TRIALS = 10  # Number of initial quasi-random trials

# Configure the model for BONSAI with MAP-SAAS
model_config = ModelConfig(
    botorch_model_class=EnsembleMapSaasSingleTaskGP,
    name="BONSAI",
)

# Create the BONSAI generation strategy using choose_generation_strategy
generation_strategy = choose_generation_strategy(
    struct=GenerationStrategyDispatchStruct(
        method="custom",
        initialization_budget=NUM_SOBOL_TRIALS,
        initialize_with_center=True,
        simplify_parameter_changes=True,
    ),
    model_config=model_config,
)

# Set the generation strategy on the client
client.set_generation_strategy(generation_strategy=generation_strategy)

print(f"Generation strategy configured: {generation_strategy.name}")
print(" - 1 Center trial")
print(f" - {NUM_SOBOL_TRIALS - 1} Sobol trials")
print(" - BONSAI with MAP-SAAS")
print(" - simplify_parameter_changes=True (pruning irrelevant dimensions)")
Output:
Generation strategy configured: Center+Sobol+MBM:BONSAI
- 1 Center trial
- 9 Sobol trials
- BONSAI with MAP-SAAS
- simplify_parameter_changes=True (pruning irrelevant dimensions)

6. Run the Optimization Loop

import logging

# Set the Ax logger to show only warnings and errors.
logging.getLogger("ax.api.client").setLevel(logging.WARNING)

# Total number of trials
TOTAL_TRIALS = 50

# Track best values for visualization
best_values = []
all_values = []
current_best = float("inf")

print(f"Starting optimization with {TOTAL_TRIALS} trials...")
print("-" * 60)

for trial_idx in range(TOTAL_TRIALS):
    # Get the next trial(s) from the generation strategy
    trials = client.get_next_trials(max_trials=1)

    for index, parameters in trials.items():
        # Evaluate the objective function
        result = hartmann50(**parameters)
        all_values.append(result)

        # Update best value (we're minimizing)
        if result < current_best:
            current_best = result
            improvement_marker = " *NEW BEST*"
        else:
            improvement_marker = ""

        best_values.append(current_best)

        # Report the result back to Ax
        client.complete_trial(
            trial_index=index,
            raw_data={metric_name: result},
        )

    # Determine which phase we're in
    if trial_idx == 0:
        phase = "Center"
    elif trial_idx < NUM_SOBOL_TRIALS:
        phase = "Sobol"
    else:
        phase = "BONSAI"

    # Print progress (every 5 trials or when there's improvement)
    if trial_idx % 5 == 0 or improvement_marker:
        print(
            f"Trial {trial_idx + 1:3d}/{TOTAL_TRIALS} [{phase:6s}]: "
            f"value = {result:8.4f}, best = {current_best:8.4f}{improvement_marker}"
        )

print("-" * 60)
print("Optimization complete!")
print(f"Best value found: {current_best:.4f}")
print("Global optimum: -3.32237")
print(f"Gap to optimum: {current_best - (-3.32237):.4f}")
Output:
Starting optimization with 50 trials...
------------------------------------------------------------
Trial 1/50 [Center]: value = -0.5053, best = -0.5053 *NEW BEST*
Trial 5/50 [Sobol ]: value = -1.8422, best = -1.8422 *NEW BEST*
Trial 6/50 [Sobol ]: value = -0.1485, best = -1.8422
/opt/hostedtoolcache/Python/3.14.2/x64/lib/python3.14/site-packages/linear_operator/utils/cholesky.py:41: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal
Trial 11/50 [BONSAI]: value = -0.4900, best = -1.8422
Trial 16/50 [BONSAI]: value = -0.8209, best = -1.8422
Trial 21/50 [BONSAI]: value = -0.5053, best = -1.8422
Trial 26/50 [BONSAI]: value = -1.2257, best = -1.8422
Trial 28/50 [BONSAI]: value = -2.6131, best = -2.6131 *NEW BEST*
Trial 30/50 [BONSAI]: value = -2.7232, best = -2.7232 *NEW BEST*
Trial 31/50 [BONSAI]: value = -2.7779, best = -2.7779 *NEW BEST*
Trial 32/50 [BONSAI]: value = -3.0603, best = -3.0603 *NEW BEST*
Trial 35/50 [BONSAI]: value = -3.1692, best = -3.1692 *NEW BEST*
Trial 36/50 [BONSAI]: value = -3.2638, best = -3.2638 *NEW BEST*
Trial 39/50 [BONSAI]: value = -3.2915, best = -3.2915 *NEW BEST*
Trial 41/50 [BONSAI]: value = -0.6084, best = -3.2915
Trial 43/50 [BONSAI]: value = -3.3088, best = -3.3088 *NEW BEST*
Trial 44/50 [BONSAI]: value = -3.3154, best = -3.3154 *NEW BEST*
Trial 46/50 [BONSAI]: value = -3.3126, best = -3.3154
------------------------------------------------------------
Optimization complete!
Best value found: -3.3154
Global optimum: -3.32237
Gap to optimum: 0.0069
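
The NumericalWarning shown above comes from linear_operator adding a small jitter to keep a covariance matrix positive definite during model fitting; it repeats throughout the run and is benign. If you want a cleaner log, one option is to filter the warning by its message using the standard warnings module:

import warnings

# Optional: silence the benign "A not p.d., added jitter" warnings from linear_operator
warnings.filterwarnings("ignore", message="A not p.d.")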

7. Visualize Optimization Performance

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 6))

trials_range = range(1, len(best_values) + 1)

# Plot best values over trials (convergence plot)
ax.plot(trials_range, best_values, 'b-', linewidth=2, label='Best value found')
ax.axvline(x=1, color='purple', linestyle='--', alpha=0.5, label='Center')
ax.axvline(x=NUM_SOBOL_TRIALS, color='r', linestyle='--', alpha=0.7, label='Sobol → BONSAI')
ax.axhline(y=-3.32237, color='g', linestyle=':', alpha=0.7, label='Global optimum (-3.32)')

# Scatter plot of all trial values
colors = ['purple'] + ['orange'] * (NUM_SOBOL_TRIALS - 1) + ['blue'] * (TOTAL_TRIALS - NUM_SOBOL_TRIALS)
ax.scatter(trials_range, all_values, c=colors, alpha=0.6, s=50)

ax.set_xlabel('Trial', fontsize=12)
ax.set_ylabel('Objective Value', fontsize=12)
ax.set_title('BONSAI Optimization Progress on Hartmann50', fontsize=14)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
ax.set_xlim([1, TOTAL_TRIALS])

plt.tight_layout()
plt.show()

8. Analyze Objective vs Simplicity trade-offs

One of the key benefits of BONSAI is that it can prune irrelevant parameters (set them to default values). Let's analyze how the number of active parameters relates to the best objective value found.

def count_active_parameters(
    parameters: dict[str, float], default_value: float = 0.5, tol: float = 1e-6
) -> int:
    """Count the number of parameters that differ from the default value."""
    return sum(1 for v in parameters.values() if abs(v - default_value) > tol)


# Get all trials and their parameters
experiment = client._experiment
trials_data = []

for trial_index, trial in experiment.trials.items():
    arm = trial.arm
    if arm is not None:
        params = arm.parameters
        num_active = count_active_parameters(params)
        # Get the objective value for this trial
        trial_data = trial.lookup_data()
        if not trial_data.df.empty:
            obj_value = trial_data.df[trial_data.df["metric_name"] == metric_name][
                "mean"
            ].values[0]
            trials_data.append(
                {
                    "trial_index": trial_index,
                    "num_active_params": num_active,
                    "objective_value": obj_value,
                }
            )

# Convert to arrays for plotting
num_active_params = [d["num_active_params"] for d in trials_data]

print(f"Collected data for {len(trials_data)} trials")
print("\nNumber of active parameters per trial:")
print(f" Min: {min(num_active_params)}")
print(f" Max: {max(num_active_params)}")
print(f" Mean: {np.mean(num_active_params):.1f}")

# Compute the best observed objective value for each number of active parameters
from collections import defaultdict

# Group trials by number of active parameters
params_to_best_value = defaultdict(lambda: float('inf'))
for d in trials_data:
    n_active = d["num_active_params"]
    obj_val = d["objective_value"]
    if obj_val < params_to_best_value[n_active]:
        params_to_best_value[n_active] = obj_val

# Sort by number of active parameters
sorted_n_active = sorted(params_to_best_value.keys())

fig, ax = plt.subplots(figsize=(8, 5))

# Cumulative best: the best objective observed among trials with <= k active parameters
cumulative_best_values = []
current_best = float('inf')
for n in sorted_n_active:
    current_best = min(current_best, params_to_best_value[n])
    cumulative_best_values.append(current_best)

# Line plot: best observed value versus number of active parameters (cumulative)
ax.plot(sorted_n_active, cumulative_best_values, color='steelblue', marker='o', linestyle='-', linewidth=2)
ax.axhline(y=-3.32237, color='g', linestyle=':', linewidth=2, label='Global optimum (-3.32)')
ax.axvline(x=6, color='r', linestyle='--', alpha=0.7, label='True relevant dims (6)')
ax.set_xlabel('Number of Active Parameters (<= k)', fontsize=12)
ax.set_ylabel('Best Objective Value', fontsize=12)
ax.set_title('Best Observed Value by Number of Active Parameters', fontsize=12)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Output:
Collected data for 50 trials
Number of active parameters per trial:
Min: 0
Max: 50
Mean: 14.3
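
Using the trials_data collected above, we can also check how many parameters the best-performing trial actually changed relative to the pruning target:

# Find the trial with the lowest (best) objective value and report its sparsity
best_trial = min(trials_data, key=lambda d: d["objective_value"])
print(
    f"Best trial {best_trial['trial_index']}: "
    f"objective = {best_trial['objective_value']:.4f}, "
    f"active parameters = {best_trial['num_active_params']} / 50"
)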

9. Key Takeaways

Why BONSAI + MAP-SAAS?

  1. MAP-SAAS: MAP-SAAS is a variant of the Sparse Axis-Aligned Subspace (SAAS) prior (Eriksson & Jankowiak, High-dimensional Bayesian optimization with sparse axis-aligned subspaces, UAI, 2021), which places a half-Cauchy prior on the GP's inverse squared lengthscales. As a result, SAAS models encourage sparsity: less relevant inputs are driven toward long lengthscales. This improves performance on high-dimensional tasks and is synergistic with BONSAI. Standard SAAS models use a time-consuming fully Bayesian (MCMC) inference procedure; MAP-SAAS recovers many of the benefits of the fully Bayesian SAAS model by ensembling over just a few models estimated via MAP, at a significantly lower computational cost.

  2. BONSAI: BONSAI prunes irrelevant dimensions via simplify_parameter_changes=True, setting them to the pruning_target_parameterization (the status quo/default/production values, or another target point of interest). The resulting proposals change fewer parameters, making them more interpretable and less likely to regress metrics not captured in the optimization objective.

When to Use This Approach

  • Real-world optimization where simple, interpretable changes are desired.