lace.plot module

Plotting utilities.

lace.plot.diagnostics(engine: Engine, name: str = 'score', log_x: bool = False) → Figure

Plot state diagnostics.

Parameters:
  • engine (lace.Engine) – The engine whose diagnostics to plot

  • name (str, optional) – The name of the diagnostic to plot

  • log_x (bool, optional) – If True, plot on a log x-axis

Returns:

The figure handle

Return type:

plotly.graph_objects.Figure

Examples

Plot the score over iterations for satellites data

>>> from lace.examples import Satellites
>>> from lace.plot import diagnostics
>>> diagnostics(Satellites(), log_x=True).show()  
{...}
lace.plot.prediction_explanation(engine: Engine, target: int | str, given: dict[Union[str, int], Any], *, method: str | None = None, cmap=None)

Plot prediction explanations.

Parameters:
  • engine (lace.Engine) – The source engine

  • target (str, int) – The target variable – the variable to predict

  • given (Dict[index, value]) – A dictionary mapping column indices or names to values, which specifies conditions on the observations.

  • method (str, optional (default: 'ablative-err')) –

    The method to use for explanation:

      • ‘ablative-err’ (default): computes the difference between p(y|X) and p(y|X - xᵢ) for each predictor xᵢ in given (X).

      • ‘ablative-dist’: computes the error between the predictions (argmax) of p(y|X) and p(y|X - xᵢ) for each predictor xᵢ in given (X). Note that this method does not support categorical targets.

  • cmap (plotly color_continuous_scale (default: 'picnic')) – Forwarded to the color_continuous_scale argument of plotly.express.bar.

Returns:

  • data (pandas.Series) – The column importances

  • fig (plotly.Figure) – The figure
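The ablation idea behind method can be illustrated without Lace. The sketch below is not Lace's implementation: toy_predict is a hypothetical stand-in for p(y|X) over a two-class target, WEIGHTS are made-up effect sizes, and total variation distance is an assumed choice of metric for comparing p(y|X) with p(y|X - xᵢ). Each predictor is scored by how far the prediction moves when that predictor alone is dropped from the conditions.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for p(y | X): maps a dict of conditions to a
# distribution over the target's categories.
Predictive = Callable[[Dict[str, float]], Dict[str, float]]

# Hypothetical effect sizes: "a" matters a lot, "b" a little, "c" not at all.
WEIGHTS = {"a": 2.0, "b": 0.5, "c": 0.0}


def toy_predict(given: Dict[str, float]) -> Dict[str, float]:
    """Two-class logistic toy model over whichever predictors are present."""
    logit = sum(WEIGHTS[k] * v for k, v in given.items())
    p = 1.0 / (1.0 + math.exp(-logit))
    return {"yes": p, "no": 1.0 - p}


def total_variation(p: Dict[str, float], q: Dict[str, float]) -> float:
    """Total variation distance between two discrete distributions."""
    keys = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in keys)


def ablative_importance(
    predict: Predictive, given: Dict[str, float]
) -> List[Tuple[str, float]]:
    """Score each predictor x_i by comparing p(y|X) with p(y|X - x_i)."""
    baseline = predict(given)
    scores = []
    for col in given:
        # X - x_i: the same conditions with one predictor removed
        reduced = {k: v for k, v in given.items() if k != col}
        scores.append((col, total_variation(baseline, predict(reduced))))
    # Largest change first
    return sorted(scores, key=lambda kv: kv[1], reverse=True)


for col, score in ablative_importance(toy_predict, {"a": 1.0, "b": 1.0, "c": 1.0}):
    print(f"{col}: {score:.3f}")
```

A predictor with no effect on the target ("c" here) scores zero, because removing it leaves the predictive distribution unchanged.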

Examples

>>> import polars as pl
>>> from lace.examples import Satellites
>>> from lace.plot import prediction_explanation
>>> engine = Satellites()

Define a target

>>> target = 'Period_minutes'

We’ll use row 5 of the data as the conditions, removing the index, the target, and any missing values

>>> row = engine[5, :].to_dicts()[0]
>>> ix = row.pop('index')
>>> _ = row.pop(target)
>>> given = {k: v for k, v in row.items() if v is not None}

Plot the explanation using the ‘ablative-dist’ method

>>> data, fig = prediction_explanation(
...     engine,
...     target,
...     given,
...     method='ablative-dist'
... )
>>> fig.show()  
{...}
lace.plot.prediction_uncertainty(engine: Engine, target: str | int, given: Dict[str | int, object] | None = None, xs: polars.Series | pandas.Series | None = None, n_points: int = 1000, mass: float = 0.99)

Visualize prediction uncertainty.

Parameters:
  • engine (Engine) – The Engine from which to predict

  • target (column index) – The column to predict

  • given (Dict[column index, value], optional) – A dictionary mapping columns to values, describing the observations to condition on. Columns can be given either as indices (int) or as names (str).

  • xs (polars.Series or pandas.Series, optional) – The values over which to visualize uncertainty. If None (default), the values are computed automatically: for categorical columns, the value map is used; for continuous and count columns, values within +/- range_stds standard deviations of the mean are used.
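A minimal sketch of the default grid described above for continuous columns, assuming a plain list of observed values. default_xs and its range_stds default of 2.0 are hypothetical illustrations, not Lace's API; missing values (None) are skipped before computing the mean and the sample standard deviation.

```python
import statistics
from typing import List, Optional, Sequence


def default_xs(
    values: Sequence[Optional[float]],
    range_stds: float = 2.0,  # hypothetical default
    n_points: int = 1000,
) -> List[float]:
    """Evenly spaced grid covering mean +/- range_stds standard deviations."""
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    observed = [v for v in values if v is not None]  # skip missing entries
    mu = statistics.mean(observed)
    sd = statistics.stdev(observed)  # sample standard deviation
    lo, hi = mu - range_stds * sd, mu + range_stds * sd
    step = (hi - lo) / (n_points - 1)
    return [lo + i * step for i in range(n_points)]


# A small grid around three observations, one entry missing
print(default_xs([1.0, None, 2.0, 3.0], range_stds=1.0, n_points=5))
```

Passing an explicit xs, as in the "Narrow down the range" example, bypasses this kind of automatic range entirely.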

Examples

Visualize uncertainty for a continuous target

>>> from lace.examples import Satellites
>>> from lace.plot import prediction_uncertainty
>>> satellites = Satellites()
>>> fig = prediction_uncertainty(
...     satellites,
...     "Period_minutes",
...     given={"Class_of_Orbit": "GEO"},
... )
>>> fig.show()  
{...}

Narrow down the range for visualization

>>> import numpy as np
>>> import polars as pl
>>> fig = prediction_uncertainty(
...     satellites,
...     "Period_minutes",
...     given={"Class_of_Orbit": "GEO"},
...     xs=pl.Series("Period_minutes", np.linspace(1350, 1500, 500)),
... )
>>> fig.show()  
{...}

Visualize uncertainty for a categorical target

>>> fig = prediction_uncertainty(
...     satellites,
...     "Class_of_Orbit",
...     given={"Period_minutes": 1326.0},
... )
>>> fig.show()  
{...}
lace.plot.state(engine: Engine, state_ix: int, *, cmap: str | None = None, missing_color=None, cat_gap: float | int = 0.1, view_gap: float | int = 0.2, show_index: bool = True, show_columns: bool = True, min_height: int = 0, min_width: int = 0, aspect=None, ax=None)

Plot a Lace state.

Views are sorted from largest (most columns) to smallest. Within views, columns are sorted from highest (left) to lowest total likelihood. Categories are sorted from largest (most rows) to smallest. Within categories, rows are sorted from highest (top) to lowest log likelihood.
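The ordering rules above can be sketched with plain Python data structures. View, Category, and render_order are hypothetical names chosen for illustration; Lace's internal representation of a state is not shown here. Because Python's sort is stable (also with reverse=True), ties keep their input order.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Category:
    # Hypothetical: row name -> log likelihood of that row in this category
    row_loglikes: Dict[str, float]


@dataclass
class View:
    # Hypothetical: column name -> total likelihood score of the column
    column_scores: Dict[str, float]
    categories: List[Category] = field(default_factory=list)


def render_order(views: List[View]) -> List[dict]:
    """Order views, columns, categories, and rows as described above."""
    ordered = []
    # Views: most columns first
    for view in sorted(views, key=lambda v: len(v.column_scores), reverse=True):
        # Columns: highest score first (left to right)
        cols = sorted(view.column_scores, key=view.column_scores.get, reverse=True)
        # Categories: most rows first
        cats = sorted(view.categories, key=lambda c: len(c.row_loglikes), reverse=True)
        # Rows: highest log likelihood first (top to bottom) within each category
        rows = [sorted(c.row_loglikes, key=c.row_loglikes.get, reverse=True) for c in cats]
        ordered.append({"columns": cols, "rows": rows})
    return ordered


views = [
    View({"x": -3.0}, [Category({"r1": -1.0, "r2": -2.0, "r3": -0.3})]),
    View(
        {"a": -10.0, "b": -2.0},
        [Category({"r1": -0.5}), Category({"r2": -4.0, "r3": -1.0})],
    ),
]
for view in render_order(views):
    print(view)
```

Every view partitions the same rows into its own categories, which is why each view in the usage example lists r1, r2, and r3.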

Parameters:
  • engine (Engine) – The engine containing the states to plot

  • state_ix (int) – The index of the state to plot

  • cmap (str, optional, default: gray_r) – The color map to use for present (non-missing) data

  • missing_color (optional, default: red) – The RGBA array representation ([float, float, float, float]) of the color to use to represent missing data

  • cat_gap (int or float, default: 0.1) – The vertical spacing (in cells if int or fraction of table height if float) between categories

  • view_gap (int or float, default: 0.2) – The horizontal spacing (in cells if int or fraction of table width if float) between views

  • show_index (bool, default: True) – If True (default), will show row names next to rows in each view

  • show_columns (bool, default: True) – If True (default), will show column names above each column

  • min_height (int, default: 0) – The minimum height in cells of the state render. Padding will be added to the lower part of the image.

  • min_width (int, default: 0) – The minimum width in cells of the state render. Padding will be added to the right of the image.

  • aspect ({'equal', 'auto'} or float or None, default: None) – matplotlib imshow aspect

  • ax (matplotlib.axes.Axes, optional) – The axes on which to plot
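A sketch of how the cat_gap and min_height settings could combine into the height (in cells) of a single view, assuming that a float gap is a fraction of the number of rows and is rounded to the nearest cell. gap_in_cells and view_height are hypothetical helpers, not part of lace.plot; view_gap and min_width would apply the same way along the horizontal axis.

```python
from typing import List, Union


def gap_in_cells(gap: Union[int, float], extent: int) -> int:
    """int gaps are already in cells; float gaps are a fraction of `extent`."""
    if isinstance(gap, int):
        return gap
    return round(gap * extent)  # assumption: round to the nearest cell


def view_height(
    category_sizes: List[int],
    cat_gap: Union[int, float] = 0.1,
    min_height: int = 0,
) -> int:
    """Rows plus one gap between adjacent categories, padded to min_height."""
    n_rows = sum(category_sizes)
    gap = gap_in_cells(cat_gap, n_rows)
    height = n_rows + gap * (len(category_sizes) - 1)
    # min_height only adds padding (below the image); it never shrinks it
    return max(height, min_height)


print(view_height([30, 20]))                # float gap: a fraction of 50 rows
print(view_height([30, 20], cat_gap=2))     # int gap: 2 cells
print(view_height([30, 20], cat_gap=2, min_height=100))
```

This int-versus-float distinction is why the Satellites example below passes cat_gap=100 and view_gap=2 as integers: they are read as cell counts rather than fractions.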

Examples

Render a state of the Animals example engine

>>> import matplotlib.pyplot as plt
>>> from lace.examples import Animals
>>> from lace import plot
>>> engine = Animals()
>>> fig = plt.figure(tight_layout=True, facecolor="#00000000")
>>> ax = plt.gca()
>>> plot.state(engine, 7, ax=ax)
>>> _ = plt.axis("off")
>>> plt.show()

Render a Satellites state, which has continuous, categorical, and missing data

>>> from lace.examples import Satellites
>>> engine = Satellites()
>>> fig = plt.figure(tight_layout=True, facecolor="#00000000")
>>> ax = plt.gca()
>>> plot.state(
...    engine,
...    1,
...    view_gap=2,
...    cat_gap=100,
...    show_index=False,
...    show_columns=False,
...    ax=ax,
...    cmap="YlGnBu",
...    aspect="auto"
... )
>>> _ = plt.axis("off")
>>> plt.show()