"""Drawing functions that accept only Layer 1 objects."""
from typing import TYPE_CHECKING
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from matplotlib.axes import Axes
from plotsmith.objects.views import (
BandView,
BarView,
BoxView,
DumbbellView,
FigureSpec,
HeatmapView,
HistogramView,
LollipopView,
MetricView,
RangeView,
ScatterView,
SeriesView,
SlopeView,
ViolinView,
WaffleView,
WaterfallView,
)
[docs]
def minimal_axes(ax: "Axes") -> None:
    """Give the axes a minimalist look.

    The top and right spines are hidden while the left and bottom spines are
    kept visible. The title, both axis labels and the tick labels are switched
    to a serif font family.

    Args:
        ax: Matplotlib axes to style.
    """
    visibility = {"top": False, "right": False, "left": True, "bottom": True}
    for side, shown in visibility.items():
        ax.spines[side].set_visible(shown)

    # Tick labels are fetched once, at call time.
    texts = [ax.title, ax.xaxis.label, ax.yaxis.label]
    texts.extend(ax.get_xticklabels())
    texts.extend(ax.get_yticklabels())
    for text in texts:
        text.set_fontfamily("serif")
[docs]
def draw_series(ax: "Axes", view: SeriesView) -> None:
    """Plot the line series described by ``view`` onto ``ax``.

    The label is always forwarded. The optional styling attributes (marker,
    linewidth, alpha) are forwarded to ``Axes.plot`` only when they are not
    ``None``.

    Args:
        ax: Matplotlib axes to draw on.
        view: SeriesView holding the x/y data and optional styling.
    """
    optional = {
        "marker": view.marker,
        "linewidth": view.linewidth,
        "alpha": view.alpha,
    }
    style = {name: value for name, value in optional.items() if value is not None}
    ax.plot(view.x, view.y, label=view.label, **style)  # type: ignore[arg-type]
[docs]
def draw_band(ax: "Axes", view: BandView) -> None:
    """Shade the region between ``view.y_lower`` and ``view.y_upper``.

    Args:
        ax: Matplotlib axes to draw on.
        view: BandView with x values, lower/upper bounds, a label and an
            optional alpha (forwarded only when not ``None``).
    """
    extra: dict[str, float] = {} if view.alpha is None else {"alpha": view.alpha}
    ax.fill_between(  # type: ignore[arg-type]
        view.x,
        view.y_lower,
        view.y_upper,
        label=view.label,
        **extra,
    )
[docs]
def draw_scatter(ax: "Axes", view: ScatterView) -> None:
    """Draw the scatter plot described by ``view`` onto ``ax``.

    ``marker``, ``alpha``, ``s`` (sizes) and ``c`` (colours) are passed to
    ``Axes.scatter`` only when they are not ``None``. The label is always
    passed.

    Args:
        ax: Matplotlib axes to draw on.
        view: ScatterView containing the points and optional styling.
    """
    candidates = (
        ("marker", view.marker),
        ("alpha", view.alpha),
        ("s", view.s),
        ("c", view.c),
    )
    options = {key: val for key, val in candidates if val is not None}
    ax.scatter(view.x, view.y, label=view.label, **options)  # type: ignore[arg-type]
[docs]
def draw_histogram(ax: "Axes", view: HistogramView) -> None:
    """Draw a histogram of ``view.values`` onto ``ax``.

    ``bins``, ``color``, ``edgecolor``, ``alpha`` and ``label`` are forwarded
    to ``Axes.hist`` only when set (not ``None``).

    Args:
        ax: Matplotlib axes to draw on.
        view: HistogramView containing the values and optional styling.
    """
    settings = (
        ("bins", view.bins),
        ("color", view.color),
        ("edgecolor", view.edgecolor),
        ("alpha", view.alpha),
        ("label", view.label),
    )
    hist_kwargs = {key: val for key, val in settings if val is not None}
    ax.hist(view.values, **hist_kwargs)  # type: ignore[arg-type]
[docs]
def draw_bar(ax: "Axes", view: BarView) -> None:
    """Draw a vertical or horizontal bar chart onto ``ax``.

    Uses ``Axes.barh`` when ``view.horizontal`` is truthy, otherwise
    ``Axes.bar``. ``color``, ``edgecolor``, ``alpha`` and ``label`` are
    forwarded only when not ``None``.

    Args:
        ax: Matplotlib axes to draw on.
        view: BarView with bar positions (``x``), lengths (``height``) and
            optional styling.
    """
    candidates = {
        "color": view.color,
        "edgecolor": view.edgecolor,
        "alpha": view.alpha,
        "label": view.label,
    }
    bar_kwargs = {key: val for key, val in candidates.items() if val is not None}
    plot_bars = ax.barh if view.horizontal else ax.bar
    plot_bars(view.x, view.height, **bar_kwargs)  # type: ignore[arg-type]
[docs]
def draw_heatmap(ax: "Axes", view: HeatmapView) -> None:
    """Draw a heatmap with optional tick labels, cell annotations and a colorbar.

    Args:
        ax: Matplotlib axes to draw on.
        view: HeatmapView with the matrix in ``data``, optional
            ``cmap``/``vmin``/``vmax`` (forwarded to ``imshow`` when not
            ``None``), optional ``x_labels``/``y_labels``, and the
            ``annotate``/``fmt`` settings for per-cell text.
    """
    kwargs: dict[str, str | float] = {}
    if view.cmap is not None:
        kwargs["cmap"] = view.cmap
    if view.vmin is not None:
        kwargs["vmin"] = view.vmin
    if view.vmax is not None:
        kwargs["vmax"] = view.vmax
    im = ax.imshow(view.data, aspect="auto", **kwargs)  # type: ignore[arg-type]

    # Tick labels: one tick per column/row index.
    if view.x_labels is not None:
        ax.set_xticks(range(len(view.x_labels)))
        ax.set_xticklabels(view.x_labels)
    if view.y_labels is not None:
        ax.set_yticks(range(len(view.y_labels)))
        ax.set_yticklabels(view.y_labels)

    if view.annotate:
        # Normalise so list-of-lists input supports .shape and [i, j]
        # indexing; asanyarray keeps ndarray subclasses (e.g. masked arrays).
        data = np.asanyarray(view.data)
        fmt = view.fmt or ".2f"
        rows, cols = data.shape
        # Python loop over every cell (one text artist per cell).
        for i, j in np.ndindex(rows, cols):
            ax.text(
                j,
                i,
                f"{data[i, j]:{fmt}}",
                ha="center",
                va="center",
                color="black",
                fontweight="bold",
            )

    # Attach the colorbar through the axes' own figure rather than pyplot's
    # "current figure", which may be a different (or newly created) figure.
    ax.figure.colorbar(im, ax=ax)
[docs]
def apply_axes_style(ax: "Axes", spec: FigureSpec) -> None:
    """Copy the title, axis labels and axis limits from ``spec`` onto ``ax``.

    Each attribute is applied only when it is not ``None``. The order is
    title, xlabel, ylabel, xlim, ylim.

    Args:
        ax: Matplotlib axes to style.
        spec: FigureSpec carrying the optional title/labels/limits.
    """
    setters = (
        (spec.title, ax.set_title),
        (spec.xlabel, ax.set_xlabel),
        (spec.ylabel, ax.set_ylabel),
        (spec.xlim, ax.set_xlim),
        (spec.ylim, ax.set_ylim),
    )
    for value, apply in setters:
        if value is not None:
            apply(value)
[docs]
def tidy_axes(ax: "Axes") -> None:
    """Apply a quiet, minimalist frame to ``ax``.

    The top and right spines are hidden. The left and bottom spines are drawn
    in gray ``"0.35"`` at width 0.8. Ticks use colour ``"0.25"``, width 0.8
    and length 3.

    Args:
        ax: Matplotlib axes to style.
    """
    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for kept in ("left", "bottom"):
        spine = ax.spines[kept]
        spine.set_color("0.35")
        spine.set_linewidth(0.8)
    ax.tick_params(colors="0.25", width=0.8, length=3)
[docs]
def force_bar_zero(ax: "Axes") -> None:
    """Anchor a bar chart's y-axis at zero, preserving honest scale.

    - If the current upper limit is positive, the lower limit becomes 0.
    - If the upper limit is not positive but the lower limit is negative
      (all-negative bars), the upper limit becomes 0 instead. The bars stay
      visible and still grow from zero.
    - Otherwise (degenerate/inverted limits) the axis falls back to ``(0, 1)``.

    Args:
        ax: Matplotlib axes to adjust.
    """
    lo, hi = ax.get_ylim()
    if hi > 0:
        ax.set_ylim(bottom=0, top=hi)
    elif lo < 0:
        # Forcing (0, 1) here would push every below-zero bar out of view.
        ax.set_ylim(bottom=lo, top=0)
    else:
        ax.set_ylim(bottom=0, top=1.0)
[docs]
def _range_frame_bounds(values) -> "tuple[float, float] | None":
    """Return ``(min, max)`` of ``values`` as floats, or ``None`` if unusable.

    NaNs in floating-point input are ignored. ``None`` is returned for empty
    input, or when every value is NaN.
    """
    arr = np.asarray(values)
    if arr.size == 0:
        return None
    if np.issubdtype(arr.dtype, np.floating):
        # NaN gaps are common in plotted data; they must not poison the span.
        arr = arr[~np.isnan(arr)]
        if arr.size == 0:
            return None
    return float(arr.min()), float(arr.max())


def add_range_frame(ax: "Axes", x_data=None, y_data=None) -> None:
    """Shorten axes spines so they span only the data range.

    Empty inputs, or inputs that are entirely NaN, leave the corresponding
    spine untouched instead of raising.

    Args:
        ax: Matplotlib axes to adjust.
        x_data: Optional iterable of x data values (bounds the bottom spine).
        y_data: Optional iterable of y data values (bounds the left spine).
    """
    for values, spine in ((x_data, "bottom"), (y_data, "left")):
        if values is None:
            continue
        bounds = _range_frame_bounds(values)
        if bounds is not None:
            ax.spines[spine].set_bounds(*bounds)
[docs]
def style_line_plot(ax: "Axes", *, emphasize_last: bool = False) -> None:
    """Apply a restrained gray-scale hierarchy to every line on ``ax``.

    The axes frame is tidied first. Lines then cycle through black/dark
    gray/gray colours and solid/dashed/dotted styles at width 1.2. With
    ``emphasize_last`` the final line is drawn black at width 2.0; it keeps
    its cycled line style.

    Args:
        ax: Matplotlib axes to style.
        emphasize_last: If True, emphasize the last line with thicker width.
    """
    tidy_axes(ax)
    palette = ("#000000", "#404040", "#808080")  # black, dark gray, gray
    dashes = ("-", "--", ":")
    lines = ax.get_lines()
    last_index = len(lines) - 1
    for index, line in enumerate(lines):
        highlight = emphasize_last and index == last_index
        line.set_color("#000000" if highlight else palette[index % len(palette)])
        line.set_linestyle(dashes[index % len(dashes)])
        line.set_linewidth(2.0 if highlight else 1.2)
[docs]
def style_scatter_plot(ax: "Axes") -> None:
    """Style scatter points as hollow markers with black edges.

    The axes frame is tidied first. Only ``PathCollection`` artists (what
    ``Axes.scatter`` produces) are restyled. Other collections on the same
    axes, such as shaded bands from ``fill_between`` or violin bodies, keep
    their own fill colours.

    Args:
        ax: Matplotlib axes to style.
    """
    from matplotlib.collections import PathCollection

    tidy_axes(ax)
    for collection in ax.collections:
        if not isinstance(collection, PathCollection):
            continue
        collection.set_edgecolor("#000000")
        collection.set_facecolor("#FFFFFF")
        collection.set_linewidth(1.0)
[docs]
def style_bar_plot(ax: "Axes", *, horizontal: bool = False) -> None:
    """Give every patch on ``ax`` a clean outlined look.

    The axes frame is tidied first. Patches alternate between white and light
    gray fills (starting with white) and get black 1.0-wide edges. For
    horizontal bars the left spine and left ticks are removed as well.

    Args:
        ax: Matplotlib axes to style.
        horizontal: If True, style for horizontal bars.
    """
    tidy_axes(ax)
    fills = ("#FFFFFF", "#E8E8E8")
    for index, bar in enumerate(ax.patches):
        bar.set_facecolor(fills[index % 2])
        bar.set_edgecolor("#000000")
        bar.set_linewidth(1.0)
    if not horizontal:
        return
    ax.spines["left"].set_visible(False)
    ax.tick_params(left=False)
# Accent color for emphasis.
# Default highlight colour for the annotation helpers in this module:
# direct_label (when use_accent=True), emphasize_last, accent_point and
# event_line (when no explicit colour is given).
ACCENT = "#d62728"  # muted red, readable in print
[docs]
def direct_label(
    ax: "Axes",
    x: float,
    y: float,
    text: str,
    *,
    dx: float = 0.0,
    dy: float = 0.0,
    use_accent: bool = False,
    ha: str = "left",
    va: str = "center",
    **kwargs,
) -> None:
    """Place a direct label near a data point.

    Typical use is to label the last point of a line instead of using a legend.

    Args:
        ax: Matplotlib axes to label.
        x: X coordinate of the point.
        y: Y coordinate of the point.
        text: Label text.
        dx: X offset for label position (data units).
        dy: Y offset for label position (data units).
        use_accent: If True, use the module accent color; otherwise near-black
            ``"0.1"``.
        ha: Horizontal alignment.
        va: Vertical alignment.
        **kwargs: Additional arguments passed to ax.text(). ``color`` and
            ``fontsize`` given here override the defaults.
    """
    text_kwargs = {
        "ha": ha,
        "va": va,
        "color": ACCENT if use_accent else "0.1",
        "fontsize": plt.rcParams["font.size"],
    }
    # Merge caller kwargs last so an explicit color/fontsize overrides the
    # default instead of raising "got multiple values for keyword argument".
    text_kwargs.update(kwargs)
    ax.text(x + dx, y + dy, text, **text_kwargs)
[docs]
def note(ax: "Axes", x: float, y: float, text: str, **kwargs) -> None:
    """Attach a short note with a simple arrow, avoiding legends.

    By default the text sits 10 points right of and above the point, connected
    by a plain gray line. It is drawn in dark gray ``"0.15"``.

    Args:
        ax: Matplotlib axes to annotate.
        x: X coordinate of the point.
        y: Y coordinate of the point.
        text: Note text.
        **kwargs: Additional arguments passed to ax.annotate(). ``xytext``,
            ``textcoords``, ``arrowprops`` and ``color`` given here override
            the defaults.
    """
    params = {
        "xytext": (10, 10),
        "textcoords": "offset points",
        "arrowprops": {"arrowstyle": "-", "color": "0.35", "linewidth": 0.8},
        "color": "0.15",
    }
    # Caller kwargs win; previously any overlap raised a duplicate-keyword
    # TypeError.
    params.update(kwargs)
    ax.annotate(text, xy=(x, y), **params)
[docs]
def emphasize_last(
    ax: "Axes", x: float, y: float, *, size: float = 30.0, **kwargs
) -> None:
    """Emphasize a final point in a series using the accent color.

    Args:
        ax: Matplotlib axes to modify.
        x: X coordinate of the point.
        y: Y coordinate of the point.
        size: Marker size (``s`` of ``Axes.scatter``).
        **kwargs: Additional arguments passed to ax.scatter(). ``color`` and
            ``zorder`` given here override the defaults (accent colour,
            zorder 3). Passing ``c`` suppresses the default colour.
    """
    defaults: dict = {"zorder": 3}
    # scatter rejects receiving both `c` and `color`, so only default the
    # colour when the caller has not supplied `c`.
    if "c" not in kwargs:
        defaults["color"] = ACCENT
    ax.scatter([x], [y], s=size, **{**defaults, **kwargs})
[docs]
def accent_point(
    ax: "Axes",
    x: float,
    y: float,
    *,
    label: str | None = None,
    color: str | None = None,
    size: float = 30.0,
    zorder: int = 3,
    **kwargs,
) -> None:
    """Highlight a single point, by default in the accent color.

    If ``label`` is non-empty, it is written in the same colour with its
    lower-left corner anchored at the point (``ha="left"``, ``va="bottom"``).

    Args:
        ax: Matplotlib axes to modify.
        x: X coordinate of the point.
        y: Y coordinate of the point.
        label: Optional label text.
        color: Optional color override (a falsy value means the accent color).
        size: Marker size.
        zorder: Z-order for the marker.
        **kwargs: Additional arguments passed to ax.scatter().
    """
    marker_color = color if color else ACCENT
    ax.scatter([x], [y], s=size, color=marker_color, zorder=zorder, **kwargs)
    if not label:
        return
    ax.text(
        x,
        y,
        label,
        ha="left",
        va="bottom",
        color=marker_color,
        fontsize=plt.rcParams["font.size"],
    )
[docs]
def event_line(
    ax: "Axes",
    x: float,
    *,
    text: str | None = None,
    y_text: float = 0.9,
    color: str | None = None,
    linewidth: float = 1.0,
    linestyle: str = "--",
    **kwargs,
) -> None:
    """Draw a vertical event marker with optional label.

    Args:
        ax: Matplotlib axes to modify.
        x: X coordinate for the vertical line.
        text: Optional label text, drawn just right of the line.
        y_text: Vertical label position. Values in [0, 1] are a fraction of
            the axes height. They stay correct if the y-limits change later
            and on non-linear (e.g. log) y scales. Any other value is a y
            data coordinate.
        color: Optional color override (defaults to accent color).
        linewidth: Line width.
        linestyle: Line style.
        **kwargs: Additional arguments passed to ax.axvline().
    """
    c = color or ACCENT
    ax.axvline(x, color=c, linewidth=linewidth, linestyle=linestyle, **kwargs)
    if text is None:
        return

    text_kwargs: dict = {
        "ha": "left",
        "va": "center",
        "color": c,
        "fontsize": plt.rcParams["font.size"],
    }
    if 0.0 <= y_text <= 1.0:
        # Blended transform: x in data coordinates, y in axes fraction, so the
        # label is not frozen to the y-limits in effect at call time.
        text_kwargs["transform"] = ax.get_xaxis_transform()
    ax.text(x, y_text, text, **text_kwargs)
def draw_waterfall(ax: "Axes", view: WaterfallView) -> None:
    """Draw a waterfall chart on the given axes.

    Each bar's measure comes from ``view.measures`` when provided. A missing
    or ``None`` entry means ``"absolute"`` for the first bar and
    ``"relative"`` otherwise.

    - ``"absolute"``: bar from 0 to the value; resets the running total.
    - ``"total"``: bar from 0 to running total + value; becomes the new
      running total (pass 0 to show the plain running total).
    - anything else (relative): bar starts at the running total and extends
      by the value.

    Colours: ``view.color`` for every bar, else ``view.colors`` when
    non-empty. Otherwise absolute/total bars are blue and relative bars green
    (>= 0) or red (< 0).

    Args:
        ax: Matplotlib axes to draw on.
        view: WaterfallView containing the data to plot.
    """
    categories = np.asarray(view.categories)
    values = np.asarray(view.values)
    n = len(categories)

    # Resolve every bar's measure once so bar geometry and default colours
    # always agree.
    measures = []
    for i in range(n):
        measure = view.measures[i] if view.measures is not None else None
        if measure is None:
            measure = "absolute" if i == 0 else "relative"
        measures.append(measure)

    x_pos = np.arange(n)
    bottoms = np.zeros(n)
    heights = np.zeros(n)
    running = 0.0
    for i, measure in enumerate(measures):
        value = values[i]
        if measure == "absolute":
            bottoms[i] = 0
            heights[i] = value
            running = value
        elif measure == "total":
            bottoms[i] = 0
            heights[i] = running + value
            running = heights[i]
        else:
            bottoms[i] = running
            heights[i] = value
            running = running + value

    if view.color:
        colors_list = [view.color] * n
    elif view.colors is not None and len(view.colors) > 0:
        # Explicit None/len check: `elif view.colors:` raises for numpy arrays.
        colors_list = list(view.colors)
    else:
        anchored = np.array(
            [m in ("absolute", "total") for m in measures], dtype=bool
        )
        colors_list = np.where(
            anchored,
            "#1f77b4",  # blue
            np.where(values >= 0, "#2ca02c", "#d62728"),  # green or red
        ).tolist()

    ax.bar(
        x_pos,
        heights,
        bottom=bottoms,
        color=colors_list,
        edgecolor="black",
        linewidth=0.5,
    )
    ax.set_xticks(x_pos)
    ax.set_xticklabels(categories, rotation=45, ha="right")
def draw_waffle(ax: "Axes", view: WaffleView) -> None:
    """Draw a waffle chart on the given axes.

    Each category contributes ``int(value)`` unit cells (negative counts are
    treated as 0). Cells are filled row by row from the top-left. Grid size:

    - both ``view.rows`` and ``view.columns`` set: used as given;
    - only one set: the other is ``ceil(total / given)``;
    - neither set: a square grid of side ``ceil(sqrt(total))``.

    Each dimension is at least 1. Cells left over after all categories, or
    whose category has no colour, are drawn light gray. Cells that do not fit
    in the grid are dropped.

    Args:
        ax: Matplotlib axes to draw on.
        view: WaffleView containing the data to plot.
    """
    categories = np.asarray(view.categories)
    values = np.asarray(view.values)
    total = np.sum(values)

    if view.rows is not None and view.columns is not None:
        rows, columns = view.rows, view.columns
    elif view.rows is not None:
        rows = view.rows
        columns = int(np.ceil(total / rows))
    elif view.columns is not None:
        columns = view.columns
        rows = int(np.ceil(total / columns))
    else:
        rows = columns = int(np.ceil(np.sqrt(total)))
    # At least one cell per axis so the axis limits below are never singular.
    rows, columns = max(int(rows), 1), max(int(columns), 1)

    # -1 marks an empty cell. Zero-initialising would silently paint the
    # leftover cells with category 0's colour.
    n_cells = rows * columns
    counts = np.clip(values.astype(int), 0, None)
    owners = np.repeat(np.arange(len(counts)), counts)[:n_cells]
    flat = np.full(n_cells, -1, dtype=int)
    flat[: len(owners)] = owners
    grid = flat.reshape(rows, columns)

    if view.colors is not None:
        colors_list = list(view.colors)
    else:
        # Default blue gradient; plt.get_cmap replaces matplotlib.cm.get_cmap,
        # which was removed in matplotlib 3.9.
        n_cats = len(categories)
        shades = np.linspace(0.3, 1.0, n_cats) if n_cats > 1 else np.array([0.65])
        cmap = plt.get_cmap("Blues")
        colors_list = [cmap(shade) for shade in shades]

    for row, col in np.ndindex(rows, columns):
        cat_idx = grid[row, col]
        has_color = 0 <= cat_idx < len(colors_list)
        color = colors_list[cat_idx] if has_color else "#cccccc"
        rect = plt.Rectangle(  # type: ignore[attr-defined]
            (col, rows - row - 1),  # row 0 is drawn at the top
            1,
            1,
            facecolor=color,
            edgecolor="white",
            linewidth=0.5,
        )
        ax.add_patch(rect)

    ax.set_xlim(0, columns)
    ax.set_ylim(0, rows)
    ax.set_aspect("equal")
    ax.axis("off")
def _draw_dumbbell_range(
    ax: "Axes",
    categories: np.ndarray,
    values1: np.ndarray,
    values2: np.ndarray,
    colors: list[str] | np.ndarray | None,
    color: str | None,
    orientation: str,
    label1: str | None,
    linewidth: float,
) -> None:
    """Shared drawing logic for dumbbell and range charts.

    For every category a line joins ``values1[i]`` and ``values2[i]``, and
    both ends get a white-edged marker.

    Args:
        ax: Matplotlib axes to draw on.
        categories: Category labels.
        values1: First set of values.
        values2: Second set of values.
        colors: Optional per-category colours. If shorter than the number of
            categories, the last colour is repeated. Takes precedence over
            ``color`` when non-empty.
        color: Optional single color for every category.
        orientation: ``"v"`` draws vertical segments with categories along the
            x-axis. Anything else (``"h"``, the default) draws horizontal
            segments with categories along the y-axis.
        label1: Optional label for the value axis (defaults to "Value").
        linewidth: Line width of the connecting segments.
    """
    n = len(categories)

    if colors is not None and len(colors) > 0:
        # Pad with the last colour using plain lists. np.pad would mangle
        # RGB(A) tuples and fail on an empty list.
        palette = list(colors)
        colors_list = (palette + [palette[-1]] * max(n - len(palette), 0))[:n]
    elif color:
        colors_list = [color] * n
    else:
        colors_list = ["#1f77b4"] * n

    pos = np.arange(n)
    vertical = orientation == "v"
    for pos_val, v1, v2, col in zip(pos, values1, values2, colors_list):
        if vertical:
            xs, ys = [pos_val, pos_val], [v1, v2]
        else:
            xs, ys = [v1, v2], [pos_val, pos_val]
        ax.plot(xs, ys, color=col, linewidth=linewidth, zorder=1)
        ax.scatter(
            xs,
            ys,
            s=100,
            color=col,
            edgecolor="white",
            linewidth=1.5,
            zorder=2,
        )

    if vertical:
        ax.set_xticks(pos)
        ax.set_xticklabels(categories)
        ax.set_ylabel(label1 or "Value")
    else:
        ax.set_yticks(pos)
        ax.set_yticklabels(categories)
        ax.set_xlabel(label1 or "Value")
def draw_dumbbell(ax: "Axes", view: DumbbellView) -> None:
    """Draw a dumbbell chart: two values per category joined by a line.

    Converts the view's categories and value sequences to arrays and
    delegates to ``_draw_dumbbell_range``. It passes the view's colours,
    orientation and label, with a 2.0 line width.

    Args:
        ax: Matplotlib axes to draw on.
        view: DumbbellView containing the data to plot.
    """
    categories, values1, values2 = (
        np.asarray(seq) for seq in (view.categories, view.values1, view.values2)
    )
    _draw_dumbbell_range(
        ax=ax,
        categories=categories,
        values1=values1,
        values2=values2,
        colors=view.colors,
        color=view.color,
        orientation=view.orientation,
        label1=view.label1,
        linewidth=2.0,
    )
def draw_range(ax: "Axes", view: RangeView) -> None:
    """Draw a range chart: a pair of values per category joined by a line.

    Shares its drawing logic with the dumbbell chart. It always uses a single
    colour (``view.color``, falling back to ``"#1f77b4"``) and a thicker 3.0
    line.

    Args:
        ax: Matplotlib axes to draw on.
        view: RangeView containing the data to plot.
    """
    categories, values1, values2 = (
        np.asarray(seq) for seq in (view.categories, view.values1, view.values2)
    )
    single_color = view.color or "#1f77b4"
    _draw_dumbbell_range(
        ax=ax,
        categories=categories,
        values1=values1,
        values2=values2,
        colors=None,
        color=single_color,
        orientation=view.orientation,
        label1=view.label1,
        linewidth=3.0,
    )
def draw_lollipop(ax: "Axes", view: LollipopView) -> None:
    """Draw a lollipop chart: a stem from 0 to each value plus a marker head.

    Stems are individual line plots (one per category). The heads are a
    single scatter call. Horizontal charts put categories on the y-axis with
    a reference line at x=0. Vertical charts (the default) put rotated
    category labels on the x-axis with a reference line at y=0.

    Args:
        ax: Matplotlib axes to draw on.
        view: LollipopView containing the data to plot. Falsy ``color``,
            ``marker`` and ``linewidth`` fall back to ``"#1f77b4"``, ``"o"``
            and 1.5 respectively.
    """
    categories = np.asarray(view.categories)
    values = np.asarray(view.values)
    pos = np.arange(len(categories))
    stem_color = view.color or "#1f77b4"
    head = view.marker or "o"
    stem_width = view.linewidth or 1.5
    horizontal = view.horizontal

    for pos_val, val in zip(pos, values):
        if horizontal:
            xs, ys = [0, val], [pos_val, pos_val]
        else:
            xs, ys = [pos_val, pos_val], [0, val]
        ax.plot(xs, ys, color=stem_color, linewidth=stem_width, zorder=1)

    heads_x, heads_y = (values, pos) if horizontal else (pos, values)
    ax.scatter(
        heads_x,
        heads_y,
        s=100,
        color=stem_color,
        marker=head,
        edgecolor="white",
        linewidth=1.5,
        zorder=2,
    )

    if horizontal:
        ax.set_yticks(pos)
        ax.set_yticklabels(categories)
        ax.axvline(x=0, color="black", linewidth=0.5, linestyle="-")
    else:
        ax.set_xticks(pos)
        ax.set_xticklabels(categories, rotation=45, ha="right")
        ax.axhline(y=0, color="black", linewidth=0.5, linestyle="-")
def draw_slope(ax: "Axes", view: SlopeView) -> None:
    """Draw a slope chart: one two-point line per group.

    Each group gets a line with round markers, labelled with the group name.
    Its colour comes from ``view.colors[group]`` when present, otherwise from
    a fixed ten-colour cycle. The x-limits are padded by 10% of the span
    between the two x values.

    Args:
        ax: Matplotlib axes to draw on.
        view: SlopeView containing the data to plot.

    Raises:
        ValueError: If ``view.x`` or any group does not have exactly 2 values.
            Groups earlier in iteration order are already drawn when a later
            group fails this check.
    """
    x = np.asarray(view.x)
    if len(x) != 2:
        raise ValueError("Slope chart requires exactly 2 x values")

    palette = np.array(
        [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
    )
    overrides = view.colors or {}

    for idx, (name, raw_y) in enumerate(view.groups.items()):
        y = np.asarray(raw_y)
        if len(y) != 2:
            raise ValueError(f"Group '{name}' must have exactly 2 y values")
        ax.plot(
            x,
            y,
            color=overrides.get(name, palette[idx % len(palette)]),
            linewidth=2,
            marker="o",
            markersize=8,
            label=name,
            zorder=2,
        )

    pad = 0.1 * (x[1] - x[0])
    ax.set_xlim(x[0] - pad, x[1] + pad)
def draw_metric(ax: "Axes", view: MetricView) -> None:
    """Draw a metric display (title, big value, optional delta) on the axes.

    The axes frame is turned off. Texts are placed in axes-fraction
    coordinates: the title near the top, the value in the middle, and the
    delta near the bottom.

    Numbers (Python or numpy ints/floats) are formatted with thousands
    separators and no decimals. A numeric delta also carries its sign
    (``+1,200`` / ``-350``) and is coloured green when >= 0, red otherwise,
    unless ``view.delta_color`` is set. Non-numeric values/deltas are shown
    via ``str()`` and a non-numeric delta defaults to black. ``view.prefix``
    applies to the value only; ``view.suffix`` applies to both value and delta.

    Args:
        ax: Matplotlib axes to draw on.
        view: MetricView containing the data to display.
    """
    numeric_types = (int, float, np.integer, np.floating)
    ax.axis("off")

    # Format value
    if isinstance(view.value, numeric_types):
        value_str = f"{view.value:,.0f}"
    else:
        value_str = str(view.value)
    if view.prefix:
        value_str = view.prefix + value_str
    if view.suffix:
        value_str = value_str + view.suffix

    # Format delta. Format the signed value directly: taking abs() first and
    # then forcing "+" labelled every decrease as an increase, and abs()
    # raised for non-numeric deltas before the str() fallback was reached.
    delta = view.delta
    delta_is_number = isinstance(delta, numeric_types)
    delta_str = ""
    if delta is not None:
        delta_str = f"{delta:+,.0f}" if delta_is_number else str(delta)
        if view.suffix:
            delta_str = delta_str + view.suffix

    # Determine colors
    value_color = view.value_color if view.value_color else "black"
    if delta is None:
        delta_color = "black"
    elif view.delta_color:
        delta_color = view.delta_color
    elif delta_is_number:
        delta_color = "#2ca02c" if delta >= 0 else "#d62728"
    else:
        delta_color = "black"

    ax.text(
        0.5,
        0.8,
        view.title,
        ha="center",
        va="top",
        fontsize=14,
        color="gray",
        transform=ax.transAxes,
    )
    ax.text(
        0.5,
        0.5,
        value_str,
        ha="center",
        va="center",
        fontsize=32,
        fontweight="bold",
        color=value_color,
        transform=ax.transAxes,
    )
    if delta_str:
        ax.text(
            0.5,
            0.2,
            delta_str,
            ha="center",
            va="bottom",
            fontsize=18,
            color=delta_color,
            transform=ax.transAxes,
        )
def draw_box(ax: "Axes", view: BoxView) -> None:
    """Draw a box plot on the given axes.

    Boxes are filled (``patch_artist=True``) with ``view.color``, or per box
    with ``view.colors`` (extra boxes keep the default fill). Every box gets a
    black 1.0-wide edge at alpha 0.7. Whiskers, fliers, means, medians and
    caps are drawn black at width 1.0.

    Args:
        ax: Matplotlib axes to draw on.
        view: BoxView containing the data to plot.
    """
    # Tick labels are applied after plotting instead of via
    # `boxplot(labels=...)`, which matplotlib 3.9 renamed to `tick_labels`
    # (old name deprecated). This keeps the call valid on old and new versions.
    bp = ax.boxplot(
        view.data,
        positions=view.positions,
        patch_artist=True,
        showmeans=view.show_means,
        showfliers=view.show_outliers,
    )
    if view.labels is not None:
        # boxplot has already placed one tick per box; just relabel them.
        ax.set_xticklabels(view.labels)

    boxes = bp["boxes"]
    if view.color:
        for patch in boxes:
            patch.set_facecolor(view.color)
    elif view.colors is not None:
        # Explicit None check: truth-testing a numpy array raises ValueError.
        for patch, fill in zip(boxes, view.colors):
            patch.set_facecolor(fill)

    for patch in boxes:
        patch.set_edgecolor("black")
        patch.set_linewidth(1.0)
        patch.set_alpha(0.7)

    for element in ("whiskers", "fliers", "means", "medians", "caps"):
        for item in bp.get(element, []):
            item.set_color("black")
            item.set_linewidth(1.0)
def draw_violin(ax: "Axes", view: ViolinView) -> None:
    """Draw a violin plot on the given axes.

    Bodies are filled with ``view.color``, or per body with ``view.colors``.
    When filled this way they get alpha 0.7. All bodies get black 1.0-wide
    edges, and the mean/median/extrema/bar line collections are drawn black.
    When ``view.labels`` is given, ticks are placed at the violin positions
    and relabelled.

    Args:
        ax: Matplotlib axes to draw on.
        view: ViolinView containing the data to plot.
    """
    parts = ax.violinplot(
        view.data,
        positions=view.positions,
        showmeans=view.show_means,
        showmedians=view.show_medians,
    )

    bodies = parts["bodies"]
    if view.color:
        for pc in bodies:  # type: ignore[attr-defined]
            pc.set_facecolor(view.color)  # type: ignore[attr-defined]
            pc.set_alpha(0.7)  # type: ignore[attr-defined]
    elif view.colors is not None:
        # Explicit None check: truth-testing a numpy array raises ValueError.
        for pc, fill in zip(bodies, view.colors):  # type: ignore[arg-type]
            pc.set_facecolor(fill)  # type: ignore[attr-defined]
            pc.set_alpha(0.7)  # type: ignore[attr-defined]

    for pc in bodies:  # type: ignore[attr-defined]
        pc.set_edgecolor("black")  # type: ignore[attr-defined]
        pc.set_linewidth(1.0)  # type: ignore[attr-defined]

    for element in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
        if element in parts:
            parts[element].set_color("black")
            parts[element].set_linewidth(1.0)

    if view.labels is not None:
        # violinplot draws one body per dataset. Counting bodies (not
        # len(view.data)) stays correct when data is a 2-D array, whose
        # columns are the datasets.
        if view.positions is not None:
            positions = view.positions
        else:
            positions = np.arange(1, len(bodies) + 1)
        ax.set_xticks(positions)
        ax.set_xticklabels(view.labels)