# $$_ Lines starting with # $$_* autogenerated by jup_mini. Do not modify these
# $$_markdown
# # Interactive Plot
# $$_end_markdown
# $$_markdown
# # Description
#
# Allows interactive plotting of multidimensional data
# $$_end_markdown
# $$_code
# $$_ %%checkall
from __future__ import annotations
import os
import sys
import math
import colorsys
from dataclasses import dataclass
import unittest
import doctest
import pandas as pd
import numpy as np
from IPython.display import display, clear_output
from ipywidgets import widgets
import plotly
import plotly.callbacks
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Callable, Any
from collections.abc import Sequence
import traitlets
from pyqstrat.pq_utils import bootstrap_ci, get_child_logger
ROOT_DIR = os.path.join(sys.path[0])
sys.path.insert(1, ROOT_DIR)
_logger = get_child_logger(__name__)
DEFAULT_PLOTLY_COLORS = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)',
'rgb(44, 160, 44)', 'rgb(214, 39, 40)',
'rgb(148, 103, 189)', 'rgb(140, 86, 75)',
'rgb(227, 119, 194)', 'rgb(127, 127, 127)',
'rgb(188, 189, 34)', 'rgb(23, 190, 207)']
LineDataType = tuple[str, pd.DataFrame, dict[Any, pd.DataFrame]]
DimensionFilterType = Callable[[
pd.DataFrame,
str,
list[tuple[str, Any]]],
np.ndarray]
DataFilterType = Callable[[
pd.DataFrame,
list[tuple[str, Any]]],
pd.DataFrame]
StatFuncType = Callable[[pd.DataFrame, str, str, str], list[LineDataType]]
DetailDisplayType = Callable[[
widgets.Widget,
pd.DataFrame,
bool],
None]
PlotFuncType = Callable[[str, str, list[LineDataType]], list[widgets.Widget]]
DataFrameTransformFuncType = Callable[[pd.DataFrame], pd.DataFrame]
SeriesTransformFuncType = Callable[[pd.Series], pd.Series]
DisplayFormFuncType = Callable[[Sequence[widgets.Widget], bool], None]
UpdateFormFuncType = Callable[[int], None]
CreateSelectionWidgetsFunctype = Callable[[dict[str, str], dict[str, str], UpdateFormFuncType], dict[str, Any]]
[docs]
def percentile_buckets(a: np.ndarray, n=10) -> np.ndarray:
'''
>>> np.random.seed(0)
>>> a = np.random.uniform(size=10000)
>>> assert np.allclose(np.unique(percentile_buckets(a)), np.arange(0.05, 1, 0.1), atol=0.01)
'''
if not len(a): return np.empty(0)
pctiles = np.arange(0, 100, int(round(100 / n)))
buckets = np.nanpercentile(a, pctiles)
conditions: list[Any] = []
for i, bucket in enumerate(buckets[:-1]):
if buckets[i] == buckets[i + 1]:
conditions.append(a == buckets[i])
else:
conditions.append((a >= buckets[i]) & (a < buckets[i + 1]))
conditions.append((a >= buckets[-1]))
b = [np.mean(a[cond]) for cond in conditions]
ret = np.select(conditions, b)
return ret
[docs]
def simple_dimension_filter(data: pd.DataFrame, dim_name: str, selected_values: list[tuple[str, Any]]) -> np.ndarray:
'''
Produces a list to put into a dropdown for selecting a dimension value
'''
mask = np.full(len(data), True)
for name, value in selected_values:
if value == 'All': continue
mask &= (data[name] == value)
values = np.unique(data[mask][dim_name].values) # will sort values before returning them
return ['All'] + values.tolist()
[docs]
def simple_data_filter(data: pd.DataFrame, selected_values: list[tuple[str, Any]]) -> pd.DataFrame:
'''
Filters a dataframe based on the selected values
'''
mask = np.full(len(data), True)
for name, value in selected_values:
if value == 'All': continue
mask &= (data[name] == value)
return data[mask]
[docs]
class MeanWithCI:
'''
Computes mean (or median) and optionally confidence intervals for plotting
'''
[docs]
def __init__(self, mean_func: Callable[[np.ndarray], np.ndarray] = np.nanmean, ci_level: int = 0) -> None: # type: ignore
'''
Args:
mean: The function to compute ci for
ci_level: Set to 0 for no confidence intervals, or the level you want.
For example, set to 95 to compute a 95% confidence interval. Default 0
'''
self.mean_func = mean_func
self.ci_level = ci_level
[docs]
def __call__(self, filtered_data: pd.DataFrame, xcol: str, ycol: str, zcol: str) -> list[LineDataType]:
'''
For each unique value of x and z, compute mean (and optionally ci) of y.
Return:
x, y data for plotting lines of the mean of y versus x for each z and the data used to compute the mean
'''
zvals = np.unique(filtered_data[zcol])
cols = [col for col in filtered_data.columns if col not in [xcol, ycol, zcol]]
df = filtered_data[[xcol, ycol, zcol] + cols]
ret = []
columns = [xcol, ycol] if not self.ci_level else [xcol, ycol, f'ci_d_{self.ci_level}', f'ci_u_{self.ci_level}']
for zvalue in zvals:
df = filtered_data[filtered_data[zcol] == zvalue]
plt_data: list[Any] = []
for x, yseries in df.groupby(xcol)[ycol]:
y = yseries.values
if not len(y): raise Exception(y)
mean = self.mean_func(y)
if self.ci_level:
ci_up, ci_down = bootstrap_ci(y, ci_level=self.ci_level / 100)
plt_data.append((x, mean, ci_down, ci_up))
else:
plt_data.append((x, mean))
line = pd.DataFrame.from_records(plt_data, columns=columns)
ret.append((zvalue, line, df))
return ret
[docs]
class SimpleDetailTable:
'''
Displays a pandas DataFrame under a plot that contains the data used to compute a statistic of y for each x, y pair
'''
[docs]
def __init__(self,
colnames: list[str] | None = None,
float_format: str = '{:.4g}',
min_rows: int = 100,
copy_to_clipboard: bool = True) -> None:
'''
Args:
colnames: list of column names to display. If None we display all columns. Default None
float_format: Format for each floating point column. Default {:.4g}
min_rows: Do not truncate the display of the table before this many rows. Default 100
copy_to_clipboard: If set, we copy the dataframe to the clipboard. On linux, you must install xclip for this to work
'''
self.colnames = colnames
self.float_format = float_format
self.min_rows = min_rows
self.copy_to_clipboard = True
[docs]
def __call__(self, detail_widget: widgets.Widget, data: pd.DataFrame, debug=False) -> None:
'''
Args:
detail_widget: The widget to display the data in
data: The dataframe to display
'''
if self.float_format:
orig_float_format = pd.options.display.float_format
pd.options.display.float_format = (self.float_format).format
if self.min_rows:
orig_min_rows = pd.options.display.min_rows
pd.options.display.min_rows = self.min_rows
with detail_widget:
if not debug: clear_output()
if self.colnames: data = data[self.colnames]
data = data.reset_index(drop=True)
display(data)
if self.copy_to_clipboard: data.to_clipboard(index=False)
if self.float_format: pd.options.display.float_format = orig_float_format
if self.min_rows: pd.options.display.min_rows = orig_min_rows
[docs]
def create_selection_dropdowns(dims: dict[str, str], labels: dict[str, str], update_form_func: UpdateFormFuncType) -> dict[str, Any]:
'''
Create a list of selection widgets
'''
selection_widgets: dict[str, widgets.Widget] = {}
for name in dims.keys():
label = labels[name] if name in labels else name
widget = widgets.Dropdown(description=label, style={'description_width': 'initial'})
selection_widgets[name] = widget
for widget in selection_widgets.values():
widget.observe(lambda x: on_widgets_updated(x, update_form_func, selection_widgets), names='value')
return selection_widgets
[docs]
@dataclass
class LineConfig:
color: str | None = None
thickness: float = math.nan
secondary_y: bool = False
marker_mode: str = 'lines+markers'
show_detail: bool = True
def _plotly_color_to_rgb(plotly_color: str) -> tuple[int, int, int]:
'''
Convert plotly color which is a string into r, g, b values
>>> assert _plotly_color_to_rgb('rgb(31, 119, 180)') == (31, 119, 180)
'''
plotly_color = plotly_color.replace('rgb(', '').replace(')', '')
s = plotly_color.split(',')
r, g, b = int(s[0]), int(s[1]), int(s[2])
return r, g, b
def _lighten_color(r: int, g: int, b: int) -> tuple[int, int, int]:
'''
Lighten color so we can show confidence intervals in a lighter shade than the line itself
We convert to hls and increase lightness and decrease saturation
>>> assert _lighten_color(31, 119, 180) == (102, 168, 214)
'''
hls = colorsys.rgb_to_hls(r, g, b)
light_hls = (hls[0], hls[1] * 1.5, hls[2] * 0.5)
rgb = colorsys.hls_to_rgb(*light_hls)
rgb = (int(round(rgb[0])), int(round(rgb[1])), int(round(rgb[2])))
return rgb
[docs]
def foo(name, old, new):
import datetime
print(f'hello: {datetime.datetime.now()} {name} {old} {new}')
[docs]
class LineGraphWithDetailDisplay:
'''
Draws line graphs and also includes a detail pane.
When you click on a point on the line graph, the detail pane shows the data used to compute that point.
'''
[docs]
def __init__(self,
display_detail_func: DetailDisplayType = SimpleDetailTable(),
line_configs: dict[str, LineConfig] = {},
title: str | None = None,
hovertemplate: str | None = None,
debug=False) -> None:
'''
Args:
display_detail_func: A function that displays the data on the detail pane. Default SimpleDetailTable
line_configs: Configuration of each line. The key in this dict is the zvalue for that line. Default {}
title: Title of the graph. Default None
hovertemplate: What to display when we hover over a point on the graph. See plotly hovertemplate
'''
self.display_detail_func = display_detail_func
self.line_configs = line_configs
self.title = title
self.hovertemplate = hovertemplate
self.debug = debug
self.default_line_config = LineConfig()
self.detail_data: dict[Any, pd.DataFrame] = {}
self.xcol = ''
self.zvalues: dict[int, Any] = {} # trace index by zvalue
[docs]
def __call__(self, xaxis_title: str, yaxis_title: str, line_data: list[LineDataType]) -> list[widgets.Widget]:
'''
Draw the plot and also set it up so if you click on a point, we display the data used to compute that point.
Args:
line_data: The zvalue, plot data, and detail data for each line to draw. The plot data must have
x as the first column and y as the second column
Return:
A list of widgets to draw. In this case, a figure widget and a output widget which contains the detail display
'''
if not len(line_data): return []
self.detail_data.clear()
secondary_y = any([lc.secondary_y for lc in self.line_configs.values()])
fig_widget = go.FigureWidget(make_subplots(specs=[[{"secondary_y": secondary_y}]]))
# fig_widget.on_trait_change(foo, '_js2py_restyle')
# fig_widget.on_trait_change()
detail_widget = widgets.Output()
trace_num = 0
for line_num, (zvalue, line_df, _detail_data) in enumerate(line_data):
x = line_df.iloc[:, 0].values
self.xcol = line_df.columns[0]
y = line_df.iloc[:, 1].values
self.detail_data[zvalue] = _detail_data
line_config = self.line_configs[zvalue] if zvalue in self.line_configs else self.default_line_config
marker_mode = line_config.marker_mode
color = line_config.color if line_config.color else DEFAULT_PLOTLY_COLORS[line_num]
hovertemplate = self.hovertemplate
customdata = None
if hovertemplate is None:
unique, counts = np.unique(_detail_data[self.xcol].values, return_counts=True)
customdata = counts[np.searchsorted(unique, x)]
hovertemplate = 'N: %{customdata}' # number of entries used to compute each x
hovertemplate += f' Series: {zvalue} {xaxis_title}: ' + '%{x:.4g} ' + f'{yaxis_title}: ' + '%{y:.4g}'
trace = go.Scatter(
x=x,
y=y,
customdata=customdata,
mode=marker_mode,
name=str(zvalue),
line=dict(color=color),
hovertemplate=hovertemplate
)
self.zvalues[trace_num] = zvalue
fig_widget.add_trace(trace, secondary_y=line_config.secondary_y)
if line_config.show_detail:
fig_widget.data[trace_num].on_click(self._on_graph_click, append=True)
trace_num += 1
if len(line_df.columns) > 2: # x, y, ci up and ci down
fill_color = _plotly_color_to_rgb(color)
fill_color = _lighten_color(*fill_color)
# we set transparency to 0.5 so we can see lines under the ci fill
fill_color_str = f'rgba({fill_color[0]},{fill_color[1]},{fill_color[2]},0.5)'
ci_down = line_df.iloc[:, 2].values
ci_up = line_df.iloc[:, 3].values
ci_trace = go.Scatter(
x=np.concatenate([x, x[::-1]]), # x, then x reversed
y=np.concatenate([ci_up, ci_down[::-1]]), # upper, then lower reversed
fill='toself',
fillcolor=fill_color_str,
line=dict(color='rgba(255,255,255,0)'),
hoverinfo="skip",
showlegend=False)
fig_widget.add_trace(ci_trace, secondary_y=line_config.secondary_y)
trace_num += 1
fig_widget.update_layout(title=self.title, xaxis_title=xaxis_title)
fig_widget.update_layout(yaxis_title=yaxis_title)
if secondary_y:
fig_widget.update_yaxes(title_text=yaxis_title, secondary_y=True)
self.fig_widget = fig_widget
self.detail_widget = detail_widget
self.line_data = line_data
return [self.fig_widget, self.detail_widget]
def _on_graph_click(self,
trace: go.Trace,
points: plotly.callbacks.Points,
selector: plotly.callbacks.InputDeviceState) -> None:
'''
Callback called by plotly when you click on a point on the graph.
When you click on a point, we display the dataframe with the data we used to compute that point.
'''
if not len(points.xs): return
trace_idx = points.trace_index
zvalue = self.zvalues[trace_idx]
_detail_data = self.detail_data[zvalue]
_detail_data = _detail_data[_detail_data[self.xcol].values == points.xs[0]]
self.display_detail_func(self.detail_widget, _detail_data, self.debug)
[docs]
class InteractivePlot:
'''
Creates a multidimensional interactive plot off a dataframe.
'''
[docs]
def __init__(self,
data: pd.DataFrame,
labels: dict[str, str] | None = None,
transform_func: DataFrameTransformFuncType = SimpleTransform(),
create_selection_widgets_func: CreateSelectionWidgetsFunctype = create_selection_dropdowns,
dim_filter_func: DimensionFilterType = simple_dimension_filter,
data_filter_func: DataFilterType = simple_data_filter,
stat_func: StatFuncType = MeanWithCI(),
plot_func: PlotFuncType = LineGraphWithDetailDisplay(),
display_form_func: DisplayFormFuncType = display_form,
debug=False) -> None:
'''
Args:
data: The pandas dataframe to use for plotting
labels: A dict where column names from the dataframe are mapped to user friendly labels. For any column names
not found as keys in this dict, we use the column name as the label. Default None
dim_filter_func: A function that generates the values of a dimension based on other dimensions. For example, if
the user chooses "Put Option" in a put/call dropdown, the valid strikes could change in a Strike
dropdown that follows. Default simple_dimension_filter
data_filter_func: A function that filters the data to plot. For example, if the user chooses "Put Option" in a put/call dropdown,
we could filter the dataframe to only include put options. Default simple_data_filter
stat_func: Once we have filtered the data, we may need to plot some statistics, such as mean and confidence intervals.
In this function, we compute these statistics. Default MeanWithCI()
plot_func: A function that plots the data. This could also display detail data used to compute the statistics associated
with each data point.
display_form_func: A function that displays the form given a list of plotly widgets (including the graph widget)
debug: Dont clear forms if this is true so we can see print output
'''
self.data = data
self.transform_func = transform_func
self.create_selection_widgets_func = create_selection_widgets_func
if labels is None: labels = {}
self.labels = labels
self.dim_filter_func = dim_filter_func
self.data_filter_func = data_filter_func
self.stat_func = stat_func
self.plot_func = plot_func
self.display_form_func = display_form_func
self.selection_widgets: dict[str, Any] = {}
self.debug = debug
[docs]
def create_pivot(self, xcol: str, ycol: str, zcol: str, dimensions: dict[str, Any]) -> None:
'''
Create the initial pivot
Args:
xcol: Column name to use as the x axis in the DataFrame
ycol: Column name to use as the y axis in the DataFrame
zcol: Column name to use for z-values. Each zvalue can be used for a different trace within this plot. For example, a column
called "option_type" could contain the values "American", "European", "Bermudan" and we could plot the data for each type
in a separate trace
dimensions: The column names used for filter dimensions. For example, we may want to filter by days to expiration and put/call
The key the column name and the value is the initial value for that column. For example, in a
dropdown for Put/Call we may want "Put" to be the initial value set in the dropdown. Set to None if you
don't care what initial value is chosen.
'''
self.xlabel = xcol if xcol not in self.labels else self.labels[xcol]
self.ylabel = ycol if ycol not in self.labels else self.labels[ycol]
self.zcol = zcol
self.xcol = xcol
self.ycol = ycol
self.selection_widgets = self.create_selection_widgets_func(dimensions, self.labels, self.update)
self.update()
[docs]
def update(self, owner_idx: int = -1) -> None:
'''
Redraw the form using the values of all widgets above and including the one with index owner_idx.
If owner_idx is -1, we redraw everything.
'''
select_conditions = [(name, widget.value) for name, widget in self.selection_widgets.items()]
if owner_idx == -1:
dim_select_conditions = []
else:
dim_select_conditions = select_conditions[:owner_idx + 1] # for selecting lower widget options, use value of widgets above
for name in list(self.selection_widgets.keys())[owner_idx + 1:]:
widget = self.selection_widgets[name]
widget_options = self.dim_filter_func(self.data, name, dim_select_conditions)
_logger.debug(f'setting values: {widget_options} on widget: {name}')
widget.options = self.dim_filter_func(self.data, name, dim_select_conditions)
if owner_idx == -1: return
filtered_data = self.data_filter_func(self.data, select_conditions)
transformed_data = self.transform_func(filtered_data)
lines = self.stat_func(transformed_data, self.xcol, self.ycol, self.zcol)
plot_widgets = self.plot_func(self.xlabel, self.ylabel, lines)
self.display_form_func(list(self.selection_widgets.values()) + plot_widgets, self.debug)
# unit tests
[docs]
class TestInteractivePlot(unittest.TestCase):
[docs]
def test_interactive_plot(self):
np.random.seed(0)
size = 1000
dte = np.random.randint(5, 10, size)
put_call = np.random.choice(['put', 'call'], size)
year = np.random.choice([2018, 2019, 2020, 2021], size)
delta = np.random.uniform(0, 0.5, size)
delta = np.where(put_call == 'call', delta, -delta)
premium = np.abs(delta * 10) * dte + np.random.normal(size=size) * dte / 10
data = pd.DataFrame({'dte': dte, 'put_call': put_call, 'year': year, 'delta': delta, 'premium': premium})
labels = {'premium': 'Premium $', 'year': 'Year', 'dte': 'Days to Expiry', 'delta_rnd': 'Delta'}
secy_line_config = LineConfig(secondary_y=True)
ip = InteractivePlot(data,
labels,
transform_func=self.transform,
stat_func=MeanWithCI(ci_level=95),
plot_func=LineGraphWithDetailDisplay(line_configs={'put': secy_line_config}), debug=True)
ip.create_pivot('delta_rnd', 'premium', 'put_call', dimensions={'year': 2018, 'dte': None})
if __name__ == '__main__':
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS)
unittest.main(argv=['first-arg-is-ignored'], exit=False)
print('done')
# $$_end_code