lace.plot module
Plotting utilities.
- lace.plot.diagnostics(engine: Engine, name: str = 'score', log_x: bool = False) → Figure
Plot state diagnostics.
- Parameters:
engine (lace.Engine) – The engine whose diagnostics to plot
name (str, optional) – The name of the diagnostic to plot
log_x (bool, optional) – If True, plot on a log x-axis
- Returns:
The figure handle
- Return type:
plotly.graph_objects.Figure
Examples
Plot the score over iterations for satellites data
>>> from lace.examples import Satellites
>>> from lace.plot import diagnostics
>>> diagnostics(Satellites(), log_x=True).show()
{...}
- lace.plot.prediction_explanation(engine: Engine, target: int | str, given: dict[Union[str, int], Any], *, method: str | None = None, cmap=None)
Plot prediction explanations.
- Parameters:
engine (lace.Engine) – The source engine
target (str, int) – The target variable – the variable to predict
given (Dict[index, value]) – A dictionary mapping column indices or names to values, which specifies conditions on the observations.
method (str, optional (default: 'ablative-err')) – The method to use for explanation:
* 'ablative-err' (default): computes the difference between p(y|X) and p(y|X - xᵢ) for each predictor xᵢ in the conditions X supplied via given.
* 'ablative-dist': computes the error between the predictions (argmax) of p(y|X) and p(y|X - xᵢ) for each predictor xᵢ in the conditions X supplied via given. Note that this method does not support categorical targets.
cmap (plotly color_continuous_scale (default: 'picnic')) – Argument forwarded to the color_continuous_scale argument of plotly.express.bar.
- Returns:
data (pandas.Series) – The column importances
fig (plotly.Figure) – The figure
Examples
>>> import polars as pl
>>> from lace.examples import Satellites
>>> from lace.plot import prediction_explanation
>>> engine = Satellites()
Define a target
>>> target = 'Period_minutes'
We’ll use a row from the data
>>> row = engine[5, :].to_dicts()[0]
>>> ix = row.pop('index')
>>> _ = row.pop(target)
>>> given = {
...     k: v for k, v in row.items() if v is not None
... }
Plot the explanation using the ‘ablative-dist’ method
>>> data, fig = prediction_explanation(
...     engine,
...     target,
...     given,
...     method='ablative-dist'
... )
>>> fig.show()
{...}
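The ablation idea behind the method parameter (compare a prediction conditioned on all of X with one conditioned on X - xᵢ, for each predictor xᵢ) can be sketched in plain Python. This is a hypothetical illustration, not lace's implementation: ablation_importances and toy_predict are made-up names, and predict stands in for whatever quantity lace derives from p(y|X). The exact error or distance that 'ablative-err' and 'ablative-dist' compute is not reproduced here.

```python
from typing import Any, Callable, Dict, Hashable


def ablation_importances(
    predict: Callable[[Dict[Hashable, Any]], float],
    given: Dict[Hashable, Any],
) -> Dict[Hashable, float]:
    """Hypothetical helper: change in a prediction when each predictor is dropped.

    `predict` is a stand-in for a quantity computed from p(y | X); it is NOT a lace API.
    """
    full = predict(given)
    importances = {}
    for name in given:
        # X - x_i: the same conditions with one predictor removed
        reduced = {k: v for k, v in given.items() if k != name}
        importances[name] = full - predict(reduced)
    return importances


def toy_predict(given: Dict[Hashable, Any]) -> float:
    # Purely illustrative linear "prediction" of y from the conditions
    return 2.0 * given.get("a", 0.0) + 0.5 * given.get("b", 0.0)


print(ablation_importances(toy_predict, {"a": 1.0, "b": 2.0}))
# → {'a': 2.0, 'b': 1.0}
```

A larger value means that removing the predictor moves the prediction further, which is the sense in which the returned Series ranks column importances.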
- lace.plot.prediction_uncertainty(engine: Engine, target: str | int, given: Dict[str | int, object] | None = None, xs: polars.Series | pandas.Series | None = None, n_points: int = 1000, mass: float = 0.99)
Visualize prediction uncertainty.
- Parameters:
engine (Engine) – The Engine from which to predict
target (column index) – The column to predict
given (Dict[column index, value], optional) – Column -> value dictionary describing observations. Columns can be given either as indices (int) or as names (str).
xs (polars.Series or pandas.Series, optional) – The values over which to visualize uncertainty. If None (default), values are computed automatically: for categorical columns, the value map is used; for continuous and count columns, values within range_stds standard deviations of the mean are used.
Examples
Visualize uncertainty for a continuous target
>>> from lace.examples import Satellites
>>> from lace.plot import prediction_uncertainty
>>> satellites = Satellites()
>>> fig = prediction_uncertainty(
...     satellites,
...     "Period_minutes",
...     given={"Class_of_Orbit": "GEO"},
... )
>>> fig.show()
{...}
Narrow down the range for visualization
>>> import numpy as np
>>> import polars as pl
>>> fig = prediction_uncertainty(
...     satellites,
...     "Period_minutes",
...     given={"Class_of_Orbit": "GEO"},
...     xs=pl.Series("Period_minutes", np.linspace(1350, 1500, 500)),
... )
>>> fig.show()
{...}
Visualize uncertainty for a categorical target
>>> fig = prediction_uncertainty(
...     satellites,
...     "Class_of_Orbit",
...     given={"Period_minutes": 1326.0},
... )
>>> fig.show()
{...}
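The xs description says that, when xs is None, continuous and count targets are evaluated over the mean plus or minus range_stds standard deviations. A minimal sketch of such a grid for a continuous column follows, assuming the mean and standard deviation come from observed values. This is an assumption: the signature above does not list range_stds, so the real default may be driven by other parameters such as mass. default_xs is a hypothetical helper, not part of lace.plot.

```python
import numpy as np


def default_xs(values, range_stds=3.0, n_points=1000):
    """Hypothetical: evenly spaced grid over mean +/- range_stds standard deviations."""
    values = np.asarray(values, dtype=float)
    mean = values.mean()
    std = values.std()  # population std (ddof=0); this choice is an assumption
    return np.linspace(mean - range_stds * std, mean + range_stds * std, n_points)


print(default_xs([1.0, 3.0], range_stds=2.0, n_points=5))
# → [0. 1. 2. 3. 4.]
```

A grid built this way could be wrapped in pl.Series and passed as xs, in the same way as the "narrow down the range" example.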
- lace.plot.state(engine: Engine, state_ix: int, *, cmap: str | None = None, missing_color=None, cat_gap: float | int = 0.1, view_gap: float | int = 0.2, show_index: bool = True, show_columns: bool = True, min_height: int = 0, min_width: int = 0, aspect=None, ax=None)
Plot a Lace state.
Views are sorted from largest (most columns) to smallest. Within views, columns are sorted from highest (left) to lowest total likelihood. Categories are sorted from largest (most rows) to smallest. Within categories, rows are sorted from highest (top) to lowest log likelihood.
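The ordering rules above amount to four nested descending sorts. The sketch below applies them to a toy in-memory structure; the Category and View dataclasses and sort_for_render are hypothetical stand-ins, not lace's internal state representation, and the likelihood values are made up for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Category:
    # (row name, row log likelihood)
    rows: List[Tuple[str, float]]


@dataclass
class View:
    # (column name, column total likelihood)
    columns: List[Tuple[str, float]]
    categories: List[Category]


def sort_for_render(views: List[View]) -> List[View]:
    """Return new View objects in the order the state render describes."""
    ordered = []
    # Views: most columns first
    for view in sorted(views, key=lambda v: len(v.columns), reverse=True):
        # Columns within a view: highest total likelihood first (leftmost)
        columns = sorted(view.columns, key=lambda c: c[1], reverse=True)
        # Categories: most rows first; rows within each: highest log likelihood first (top)
        categories = [
            Category(rows=sorted(cat.rows, key=lambda r: r[1], reverse=True))
            for cat in sorted(view.categories, key=lambda c: len(c.rows), reverse=True)
        ]
        ordered.append(View(columns=columns, categories=categories))
    return ordered


small = View(columns=[("a", -3.0)],
             categories=[Category(rows=[("r1", -1.0), ("r2", -2.0), ("r3", -0.5)])])
big = View(columns=[("x", -5.0), ("y", -1.0)],
           categories=[Category(rows=[("r1", -2.0)]),
                       Category(rows=[("r2", -3.0), ("r3", -1.0)])])
print([name for name, _ in sort_for_render([small, big])[0].columns])
# → ['y', 'x']
```

Python's sort is stable, so views, columns, categories, or rows that tie keep their original relative order in this sketch.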
- Parameters:
engine (Engine) – The engine containing the states to plot
state_ix (int) – The index of the state to plot
cmap (str, optional, default: gray_r) – The color map to use for present data
missing_color (optional, default: red) – The RGBA array representation ([float, float, float, float]) of the color to use to represent missing data
cat_gap (int or float, default: 0.1) – The vertical spacing (in cells if int or fraction of table height if float) between categories
view_gap (int or float, default: 0.2) – The horizontal spacing (in cells if int or fraction of table width if float) between views
show_index (bool, default: True) – If True (default), will show row names next to rows in each view
show_columns (bool, default: True) – If True (default), will show columns names above each column
min_height (int, default: 0) – The minimum height in cells of the state render. Padding will be added to the lower part of the image.
min_width (int, default: 0) – The minimum width in cells of the state render. Padding will be added to the right of the image.
aspect ({'equal', 'auto'} or float or None, default: None) – The aspect ratio, as accepted by the aspect argument of matplotlib's imshow
ax (matplotlib.axes.Axes, optional) – The axes on which to plot
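cat_gap and view_gap accept either a cell count (int) or a fraction of the table height or width (float). The sketch below shows one way to resolve such a value into cells. gap_in_cells is a hypothetical helper, and rounding to the nearest whole cell is an assumption of this sketch, not documented lace behavior.

```python
def gap_in_cells(gap, table_extent):
    """Hypothetical: convert a cat_gap/view_gap value into a number of cells.

    int   -> already a cell count
    float -> fraction of the table extent (height for cat_gap, width for view_gap)
    """
    if isinstance(gap, int):
        return gap
    # Rounding to the nearest whole cell is an assumption of this sketch
    return int(round(gap * table_extent))


print(gap_in_cells(0.1, 200), gap_in_cells(2, 200))
# → 20 2
```

Under this reading, the Satellites example below (view_gap=2, cat_gap=100) passes ints, so both gaps are taken as fixed cell counts rather than fractions.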
Examples
Render an Animals state
>>> import matplotlib.pyplot as plt
>>> from lace.examples import Animals
>>> from lace import plot
>>> engine = Animals()
>>> fig = plt.figure(tight_layout=True, facecolor="#00000000")
>>> ax = plt.gca()
>>> plot.state(engine, 7, ax=ax)
>>> _ = plt.axis("off")
>>> plt.show()
Render a Satellites state, which has continuous, categorical, and missing data
>>> from lace.examples import Satellites
>>> engine = Satellites()
>>> fig = plt.figure(tight_layout=True, facecolor="#00000000")
>>> ax = plt.gca()
>>> plot.state(
...     engine,
...     1,
...     view_gap=2,
...     cat_gap=100,
...     show_index=False,
...     show_columns=False,
...     ax=ax,
...     cmap="YlGnBu",
...     aspect="auto"
... )
>>> _ = plt.axis("off")
>>> plt.show()