Making Line Plots Delightful with Optimized Direct Labeling
Improve your Matplotlib line plots by replacing the default legend
Every time your reader glances back and forth between a line plot and its legend, friction creeps in between the data and their understanding of it. Direct labeling eliminates this cognitive overhead and makes reading your data visualization a more efficient and more delightful experience.
When you have two or more lines in the same plot, it is considered good practice to include a legend explaining what each line represents (as in Figure 1). However, even better is to use direct labeling, placing the label for a line right next to it (as in Figure 2). This means less redundant ink and, more importantly, less effort for the consumer of the plot:
- Less redundant ink: direct labeling saves you from drawing the small line sections that are part of a legend.
- Less effort for the consumer trying to match the visualized data with their legend:
  - No back and forth between lines and labels, as they are directly juxtaposed.
  - No cognitive friction due to different vertical ordering in the legend and the lines (notice how the topmost line in the example plot of Figure 1 ends up at the bottom of the default legend).
How much more effort will it take you to create a plot with direct labeling, as compared to one with the default legend? From now on, the effort will be minimal, because I am hereby offering you the code that will allow you to transform the legend of any Matplotlib line plot into direct labels with this single line of Python:
do_direct_line_labeling_with_optimized_positions(plt.gca())
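To show where that call fits, here is a minimal sketch of a plot that starts out with a regular legend. The helper name `make_demo_line_plot`, its `label_fn` parameter, the line names and the slopes are all made up for illustration, and the `Agg` backend is only selected so that the sketch runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def make_demo_line_plot(label_fn=None):
    """Build a small line plot with a default legend.

    label_fn is a hypothetical hook: any function that takes the axes,
    for instance the direct labeling function from this post.
    """
    x = np.linspace(0, 10, 50)
    fig, ax = plt.subplots()
    # Two of the three lines end close to each other on purpose
    for name, slope in [("slow", 0.5), ("medium", 0.55), ("fast", 1.0)]:
        ax.plot(x, slope * x, label=name)
    ax.legend()
    if label_fn is not None:
        label_fn(ax)  # swap the legend for direct labels
    return fig, ax


fig, ax = make_demo_line_plot()
```

Once the functions from the code section below are defined, calling `make_demo_line_plot(do_direct_line_labeling_with_optimized_positions)` should give the same plot with direct labels instead of the legend.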
When to use this
Use this direct labeling scheme for line plots with between 2 and 8 lines.
- If your line plot only has one line, you probably do not need a legend.
- If your line plot has more than 8 lines, pause and consider if you really need all these lines or if you cannot find an alternative to what one could probably describe as a “spaghetti plot”.
The proposed direct labeling might not work well for rare configurations of lines that all converge at the end.
Very long labels may be a challenge, but this is also the case for legends.
The code
"""Functions for direct labeling of line plots"""
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import minimize
from scipy.spatial import distance_matrix

def do_direct_line_labeling_with_optimized_positions(
    ax, height_space_factor=1.25, dx_factor=0.05
):
    """Do direct line labeling with optimized vertical positions

    Args:
        ax (matplotlib.axes.Axes): axes to plot on
        height_space_factor (float): factor to control the minimum space between labels
            1.0 means a minimum space between labels equal to the height of the text
        dx_factor (float): horizontal distance between line and label,
            as a fraction of the x-range of the last line
    """
    lines = ax.get_lines()
    line_data_list = []

    # Collect data from existing lines
    for line in lines:
        x_data, y_data = line.get_data()
        if isinstance(y_data, np.ma.MaskedArray):
            y_data = y_data.data
        y_last = y_data[~np.isnan(y_data)][-1]
        line_data_list.append(
            {
                "label": line.get_label(),
                "x_last": x_data[-1],
                "y_last": y_last,
                "color": line.get_color(),
            }
        )

    # Optimize label positions
    ref_positions = np.array([line_data["y_last"] for line_data in line_data_list])
    height_data_units = get_text_height(ax)
    opt_label_positions = optimize_1d_label_positions_simple(
        ref_positions, min_interlabel_distance=height_data_units * height_space_factor
    )

    # Plot labels
    # Horizontal offset between line end and label, as a fraction of the x-range
    dx_txt = dx_factor * (x_data[-1] - x_data[0])
    for i_line, line_data in enumerate(line_data_list):
        line_color = line_data["color"]
        x_last = line_data["x_last"]
        y_last = line_data["y_last"]
        txt_color = line_color
        ax.text(
            x_last + dx_txt,
            opt_label_positions[i_line],
            line_data["label"],
            ha="left",
            va="center",
            color=txt_color,
        )
        # Thin dashed connector from the end of the line to its label
        x_link = [x_last, x_last + dx_txt]
        y_link = [y_last, opt_label_positions[i_line]]
        x_link, y_link = cubic_spline(x_link, y_link, n_segments=20)
        ax.plot(
            x_link,
            y_link,
            color=line_color,
            ls="--",
            lw=0.5,
        )

    # Remove the legend only if the plot has one
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_xlim([ax.get_xlim()[0], x_last + dx_txt])

def do_simple_direct_line_labeling(ax):
    """Do direct line labeling in the simplest way possible

    Writing the label on the last point of the corresponding line.
    This may result in overlapping labels if multiple lines end near each other.
    """
    lines = ax.get_lines()
    for line in lines:
        x_data, y_data = line.get_data()
        if isinstance(y_data, np.ma.MaskedArray):
            y_data = y_data.data
        y_last = y_data[~np.isnan(y_data)][-1]
        line_color = line.get_color()
        txt_color = line_color
        ax.text(
            x_data[-1],
            y_last,
            " " + line.get_label(),
            ha="left",
            va="center",
            color=txt_color,
        )
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

def optimize_1d_label_positions_simple(
    point_positions, min_interlabel_distance, label_dist_weight=30.0
):
    """Optimize label positions along a line

    Args:
        point_positions (np.ndarray): array of point positions with shape (n_points,)
        min_interlabel_distance (float): minimum interlabel distance
            a strong penalty is applied to interlabel distances lower than this
            ideally this should be slightly higher than the height of labels
            (see get_text_height function)
        label_dist_weight (float): weight of the penalty for distance between labels and points
            tune this parameter to balance the two penalties if needed

    Returns:
        opt_label_positions: optimized label positions, a numpy ndarray of shape (n_points,)
    """
    there_are_overlaps = len(set(point_positions)) < len(point_positions)
    if there_are_overlaps:
        print(
            "Warning: some points are exactly on top of each other",
            "it is recommended to avoid that, for instance by adding some small noise",
        )

    # work with normalized positions
    coord_range = point_positions.max(axis=0) - point_positions.min(axis=0)
    if coord_range == 0:
        # all points coincide (or there is a single point): skip the rescaling
        coord_range = 1.0
    normed_label_positions = point_positions / coord_range
    rel_min_interlabel_distance = min_interlabel_distance / coord_range
    n_pts = normed_label_positions.shape[0]

    def high_when_below(min_value, value):
        """Return a high value when below a certain threshold, otherwise a low value"""
        return np.exp(5 * (min_value - value))

    def fun_to_minimize(x_1d):
        """Objective function to minimize

        Args:
            x_1d (np.ndarray): array of label positions with shape (n_points,)

        Returns:
            float: value of the objective, the lower the better
        """
        # a) distances between points and the respective labels
        label_distance = label_dist_weight * np.sum(
            (x_1d - normed_label_positions) ** 2
        )
        # b) pairwise distances between labels
        interlabel_dist_matrix = distance_matrix(
            x_1d.reshape(-1, 1), x_1d.reshape(-1, 1)
        )
        interlabel_dist_array = interlabel_dist_matrix[np.triu_indices(n_pts, 1)]
        interlabel_dist_loss = np.sum(
            high_when_below(rel_min_interlabel_distance, interlabel_dist_array)
        )
        return label_distance + interlabel_dist_loss

    # start from the line endpoints and let the optimizer push the labels apart
    x_init = normed_label_positions.reshape(-1)
    res = minimize(fun_to_minimize, x_init, method="L-BFGS-B")
    opt_label_positions = res.x * coord_range
    return opt_label_positions

def get_text_height(ax, **text_kwargs):
    """Get the height of a text object in data units"""
    # Use the figure that owns the axes, which may differ from the current figure
    fig = ax.figure
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    text_obj = ax.text(
        0.5 * (xlim[0] + xlim[1]),
        0.5 * (ylim[0] + ylim[1]),
        "This is a dummy text",
        **text_kwargs
    )
    # Draw the canvas to make sure the text object is rendered
    fig.canvas.draw()
    # Get the bounding box of the text in display coordinates (pixels)
    bbox_display = text_obj.get_window_extent(renderer=fig.canvas.get_renderer())
    # Convert the bbox from display coordinates to data coordinates
    inv = ax.transData.inverted()
    bbox_data = mtransforms.Bbox(inv.transform(bbox_display))
    height_data_units = bbox_data.height
    text_obj.remove()
    return height_data_units

def cubic_spline(
    x_from_to,
    y_from_to,
    n_segments: int = 20,
):
    """Get the coordinates of a (horizontally oriented) cubic spline
    from a point to another point

    Args:
        x_from_to: a numpy array with two x coordinates
        y_from_to: a numpy array with two y coordinates
        n_segments: number of straight segments with which to draw the spline

    Returns:
        two numpy arrays
    """
    x_t = np.linspace(0, 1, n_segments + 1)
    y_coord = y_from_to[0] + (y_from_to[1] - y_from_to[0]) * (x_t**2 * (3 - 2 * x_t))
    x_coord = x_from_to[0] + (x_from_to[1] - x_from_to[0]) * x_t
    return x_coord, y_coord
How it works
The provided code enhances Matplotlib line plots by replacing traditional legends with direct labels positioned optimally next to the lines. For those of you who like to understand the code you run, here is a short summary of the approach, function by function.
- Direct labeling the naive way: The do_simple_direct_line_labeling function shows the naive way to do direct labeling, placing labels at the end of each line (as illustrated in Figure 3). This can result in overlaps between labels when the endpoints of multiple lines are too close.
- Vertical position optimization: The optimize_1d_label_positions_simple function optimizes the labels’ vertical positions, keeping each label as close as possible to the end of its line while pushing apart labels that would otherwise overlap.
- Text height calculation: The get_text_height function calculates the height of a text object in data units, which sets the minimum spacing needed to avoid vertical overlap.
- Cubic spline: The cubic_spline function creates smooth curves connecting lines to their labels, adding a small touch of visual appeal.
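To make the connector shape concrete, here is a small sketch that reuses the same polynomial as cubic_spline in the listing above, 3t² − 2t³, and checks that the curve is nearly flat at both ends and steepest in the middle. The function name `smoothstep_connector` and the coordinates are illustrative only.

```python
import numpy as np


def smoothstep_connector(x_from_to, y_from_to, n_segments=20):
    """Same cubic as cubic_spline: y follows 3t^2 - 2t^3, x is linear in t."""
    t = np.linspace(0, 1, n_segments + 1)
    y = y_from_to[0] + (y_from_to[1] - y_from_to[0]) * (t**2 * (3 - 2 * t))
    x = x_from_to[0] + (x_from_to[1] - x_from_to[0]) * t
    return x, y


# Illustrative values: a line ending at (10, 2) with its label placed at (10.5, 3)
x, y = smoothstep_connector([10.0, 10.5], [2.0, 3.0])

# Slope of each straight segment of the connector
slopes = np.diff(y) / np.diff(x)
first, middle, last = slopes[0], slopes[len(slopes) // 2], slopes[-1]
print(f"first: {first:.2f}, middle: {middle:.2f}, last: {last:.2f}")
```

The connector leaves the end of the line almost horizontally and arrives at the label almost horizontally, which is why it reads as a gentle S-curve rather than a kinked straight segment.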
More on the optimization magic
A key aspect of the proposed approach is to treat label positioning as an optimization problem.
We want labels close to their corresponding line endpoints, but we also need sufficient spacing between labels to avoid visual clutter.
We solve this by minimizing a cost function with the following two competing objectives:
- Fidelity penalty: penalizing labels straying too far from their “natural” position (the line endpoint)
- Overlap penalty: penalizing labels getting too close to each other
For the overlap penalty, we use an exponential penalty function that becomes increasingly harsh as labels approach each other.
As a result of combining the two objectives, the labels repel each other while being pulled toward their target positions.
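The tug-of-war between the two penalties can be seen on a toy case with two labels. The sketch below rebuilds the cost with the same constants as the listing above (weight 30 for fidelity, factor 5 in the exponential), but the function name `label_cost`, the endpoint values and the minimum distance are made up for illustration, and a brute-force scan over a symmetric shift stands in for scipy.optimize.minimize.

```python
import numpy as np


def label_cost(labels, targets, min_dist, weight=30.0):
    """Fidelity term plus exponential overlap term, on normalized positions."""
    labels = np.asarray(labels, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Fidelity: squared distance between each label and its line endpoint
    fidelity = weight * np.sum((labels - targets) ** 2)
    # Overlap: exponential penalty on every pairwise gap between labels
    i, j = np.triu_indices(len(labels), 1)
    gaps = np.abs(labels[i] - labels[j])
    overlap = np.sum(np.exp(5 * (min_dist - gaps)))
    return fidelity + overlap


targets = np.array([0.50, 0.52])  # two line endpoints that nearly coincide
min_dist = 0.10                   # assumed label height, in normalized units

cost_on_endpoints = label_cost(targets, targets, min_dist)
cost_pushed_apart = label_cost([0.46, 0.56], targets, min_dist)

# Brute-force scan: move the two labels apart symmetrically by `shift`
shifts = np.linspace(0.0, 0.2, 201)
costs = [label_cost(targets + np.array([-s, s]), targets, min_dist) for s in shifts]
best_shift = shifts[int(np.argmin(costs))]

print(f"cost on endpoints: {cost_on_endpoints:.3f}")
print(f"cost pushed apart: {cost_pushed_apart:.3f}")
print(f"best symmetric shift: {best_shift:.3f}")
```

Leaving both labels on their endpoints costs more than pushing them apart, and the cheapest shift is a compromise: far enough to clear the overlap, but not so far that the fidelity term takes over.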
For fewer than 10 lines, the time scipy.optimize.minimize takes to converge should be negligible.
Further reading
An alternative to the direct labeling proposed here is “inline labeling” directly on the lines, as implemented in the matplotlib-label-lines package.
I am quite a believer in the potential of mathematical optimization for enhancing data visualization. If the present post caught your interest, you might also want to read about optimizing the label positions of scatter plots.