Source code for ax.analysis.plotly.surface.contour

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math
from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.surface.utils import (
    get_parameter_values,
    is_axis_log_scale,
    select_fixed_value,
)
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go
from pyre_extensions import none_throws


class ContourPlot(PlotlyAnalysis):
    """
    Plot a 2D surface of the surrogate model's predicted outcomes for a given
    pair of parameters, where all other parameters are held fixed at their
    status-quo value, or at their mean if no status quo is available.

    The DataFrame computed will contain the following columns:
        - PARAMETER_NAME: The value of the x parameter specified
        - PARAMETER_NAME: The value of the y parameter specified
        - METRIC_NAME_mean: The predicted mean of the metric specified
    """

    def __init__(
        self,
        x_parameter_name: str,
        y_parameter_name: str,
        metric_name: str | None = None,
    ) -> None:
        """
        Args:
            x_parameter_name: The name of the parameter to plot on the x-axis.
            y_parameter_name: The name of the parameter to plot on the y-axis.
            metric_name: The name of the metric to plot.
        """
        # TODO: Add a flag to specify whether or not to plot markers at the (x, y)
        # coordinates of arms (with hover text). This is fine to exclude for now
        # because this Analysis is only used in the InteractionPlot, but when we
        # want to use it a la carte we should add this feature.
        self.x_parameter_name = x_parameter_name
        self.y_parameter_name = y_parameter_name
        self.metric_name = metric_name

    def compute(
        self,
        experiment: Optional[Experiment] = None,
        generation_strategy: Optional[GenerationStrategyInterface] = None,
    ) -> PlotlyAnalysisCard:
        if experiment is None:
            raise UserInputError("ContourPlot requires an Experiment")

        if not isinstance(generation_strategy, GenerationStrategy):
            raise UserInputError("ContourPlot requires a GenerationStrategy")

        if generation_strategy.model is None:
            generation_strategy._fit_current_model(None)

        metric_name = self.metric_name or select_metric(experiment=experiment)

        df = _prepare_data(
            experiment=experiment,
            model=none_throws(generation_strategy.model),
            x_parameter_name=self.x_parameter_name,
            y_parameter_name=self.y_parameter_name,
            metric_name=metric_name,
        )

        fig = _prepare_plot(
            df=df,
            x_parameter_name=self.x_parameter_name,
            y_parameter_name=self.y_parameter_name,
            metric_name=metric_name,
            log_x=is_axis_log_scale(
                parameter=experiment.search_space.parameters[self.x_parameter_name]
            ),
            log_y=is_axis_log_scale(
                parameter=experiment.search_space.parameters[self.y_parameter_name]
            ),
        )

        return self._create_plotly_analysis_card(
            title=(
                f"{self.x_parameter_name}, {self.y_parameter_name} vs. {metric_name}"
            ),
            subtitle=(
                "2D contour of the surrogate model's predicted outcomes for "
                f"{metric_name}"
            ),
            level=AnalysisCardLevel.LOW,
            df=df,
            fig=fig,
        )

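# Example usage (a sketch, not part of the original module): assuming an
# `experiment` and a fitted `generation_strategy` already exist from an Ax
# optimization loop, and that "learning_rate", "dropout", and "accuracy" are
# hypothetical parameter and metric names, a contour card could be computed
# roughly as follows:
#
#     analysis = ContourPlot(
#         x_parameter_name="learning_rate",
#         y_parameter_name="dropout",
#         metric_name="accuracy",
#     )
#     card = analysis.compute(
#         experiment=experiment, generation_strategy=generation_strategy
#     )
#
# The returned PlotlyAnalysisCard bundles the computed DataFrame and the Plotly
# contour figure produced by the `_prepare_data` and `_prepare_plot` helpers below.
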
def _prepare_data(
    experiment: Experiment,
    model: ModelBridge,
    x_parameter_name: str,
    y_parameter_name: str,
    metric_name: str,
) -> pd.DataFrame:
    # Choose which parameter values to predict points for.
    xs = get_parameter_values(
        parameter=experiment.search_space.parameters[x_parameter_name], density=10
    )
    ys = get_parameter_values(
        parameter=experiment.search_space.parameters[y_parameter_name], density=10
    )

    # Construct observation features for each parameter value previously chosen by
    # fixing all other parameters to their status-quo value or mean.
    features = [
        ObservationFeatures(
            parameters={
                x_parameter_name: x,
                y_parameter_name: y,
                **{
                    parameter.name: select_fixed_value(parameter=parameter)
                    for parameter in experiment.search_space.parameters.values()
                    if not (
                        parameter.name == x_parameter_name
                        or parameter.name == y_parameter_name
                    )
                },
            }
        )
        for x in xs
        for y in ys
    ]

    predictions = model.predict(observation_features=features)

    return pd.DataFrame.from_records(
        [
            {
                x_parameter_name: features[i].parameters[x_parameter_name],
                y_parameter_name: features[i].parameters[y_parameter_name],
                f"{metric_name}_mean": predictions[0][metric_name][i],
            }
            for i in range(len(features))
        ]
    )


def _prepare_plot(
    df: pd.DataFrame,
    x_parameter_name: str,
    y_parameter_name: str,
    metric_name: str,
    log_x: bool,
    log_y: bool,
) -> go.Figure:
    z_grid = df.pivot(
        index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean"
    )

    fig = go.Figure(
        data=go.Contour(
            z=z_grid.values,
            x=z_grid.columns.values,
            y=z_grid.index.values,
            contours_coloring="heatmap",
            showscale=False,
        ),
        layout=go.Layout(
            xaxis_title=x_parameter_name,
            yaxis_title=y_parameter_name,
        ),
    )

    # Set the x-axis scale to log if relevant
    if log_x:
        fig.update_xaxes(
            type="log",
            range=[
                math.log10(df[x_parameter_name].min()),
                math.log10(df[x_parameter_name].max()),
            ],
        )
    else:
        fig.update_xaxes(
            range=[df[x_parameter_name].min(), df[x_parameter_name].max()]
        )

    if log_y:
        fig.update_yaxes(
            type="log",
            range=[
                math.log10(df[y_parameter_name].min()),
                math.log10(df[y_parameter_name].max()),
            ],
        )
    else:
        fig.update_yaxes(
            range=[df[y_parameter_name].min(), df[y_parameter_name].max()]
        )

    return fig
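
# Note on the pivot in `_prepare_plot` (an illustrative sketch only; the column
# names below are hypothetical): `_prepare_data` returns a long-format frame with
# one row per (x, y) grid point, and pivoting with index=y and columns=x yields
# the 2D grid of predicted means that `go.Contour` expects:
#
#     long = pd.DataFrame(
#         {
#             "x_param": [0.0, 0.0, 1.0, 1.0],
#             "y_param": [0.0, 1.0, 0.0, 1.0],
#             "objective_mean": [0.1, 0.2, 0.3, 0.4],
#         }
#     )
#     grid = long.pivot(index="y_param", columns="x_param", values="objective_mean")
#     # grid.values is a 2x2 array whose rows are indexed by y values and whose
#     # columns are indexed by x values.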