Making Line Plots Delightful with Optimized Direct Labeling

June 15, 2025

Improve your Matplotlib line plots by replacing the default legend

Every time your reader glances back and forth between a line plot and its legend, there is friction between seeing the data and understanding it. By using direct labeling, you can eliminate this cognitive overhead and make reading your data visualization a more efficient and more delightful experience.

Figure 1: Example of line plot with a legend (Matplotlib default, imaginary data and image by author).

When you have two or more lines in the same plot, it is considered good practice to include a legend explaining what each line represents (as in Figure 1). However, even better is to use direct labeling, placing the label for a line right next to it (as in Figure 2). This means less redundant ink and, more importantly, less effort for the consumer of the plot:

  • Less redundant ink: direct labeling saves you from drawing the small line sections that are part of a legend
  • Less effort for the consumer trying to match the visualized data with their legend:
    • No back and forth between lines and labels, as they are directly juxtaposed.
    • No cognitive friction due to different vertical ordering in the legend and the lines (notice how the topmost line in the example plot of Figure 1 ends up at the bottom of the default legend).
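If you have to keep a legend, the ordering mismatch described above can at least be reduced by sorting the legend entries by each line's final value. The helper below is a minimal sketch of that idea; `legend_order_by_last_value` is a hypothetical name and not part of the code offered in this post.

```python
import numpy as np


def legend_order_by_last_value(y_series):
    """Return line indices sorted from the highest to the lowest final value.

    y_series: list of 1-D sequences, one per line, in plotting order.
    Trailing NaNs are ignored, so each line is ranked by its last valid value.
    """
    last_values = []
    for y in y_series:
        y = np.asarray(y, dtype=float)
        last_values.append(y[~np.isnan(y)][-1])
    # argsort of the negated values gives a top-to-bottom order
    return [int(i) for i in np.argsort(-np.asarray(last_values), kind="stable")]


# Plotting order: line 0 ends at 1.0, line 1 at 3.0, line 2 at 2.0 (last point missing)
order = legend_order_by_last_value([[0, 1.0], [0, 3.0], [0, 2.0, np.nan]])
print(order)  # [1, 2, 0]
```

With Matplotlib, these indices could be applied to the handles and labels returned by `ax.get_legend_handles_labels()` before passing them to `ax.legend(handles, labels)`, assuming every plotted line has a label so that handles and lines correspond one-to-one. Direct labeling still removes the back-and-forth entirely.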

Figure 2: Example of line plot with direct labeling and optimized label positions, as proposed by the author.

How much more effort will it take you to create a plot with direct labeling, as compared to one with the default legend? From now on, the effort will be minimal, because I am hereby offering you the code that will allow you to transform the legend of any Matplotlib line plot into direct labels with this single line of Python:

do_direct_line_labeling_with_optimized_positions(plt.gca())

When to use this

  • Use this direct labeling scheme for line plots with between 2 and 8 lines.

    • If your line plot only has one line, you probably do not need a legend.
    • If your line plot has more than 8 lines, pause and consider whether you really need all these lines, or whether you could find an alternative to what one might call a “spaghetti plot”.
  • The proposed direct labeling might not work well in the rare configurations where all lines converge to nearly the same endpoint.

  • Very long labels may be a challenge, but this is also the case for legends.

The code

"""Functions for direct labeling of line plots"""

import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import minimize
from scipy.spatial import distance_matrix

def do_direct_line_labeling_with_optimized_positions(
    ax, height_space_factor=1.25, dx_factor=0.05
):
    """Do direct line labeling with optimized vertical positions

    Args:
        ax (matplotlib.axes.Axes): axes to plot on
        height_space_factor (float): factor to control the minimum space between labels
            1.0 means a minimum space between labels equal to the height of the text
        dx_factor (float): controls the horizontal distance between line ends and labels,
            as the x span of the last dx_factor fraction of the points of the last line
    """
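    lines = ax.get_lines()
    line_data_list = []
    # Collect the label, color and last valid (x, y) point of each existing line
    for line in lines:
        x_data, y_data = line.get_data()
        x_data = np.asarray(x_data)
        # Masked y values are treated like NaNs
        y_data = np.ma.filled(np.ma.asarray(y_data, dtype=float), np.nan)
        # Use the same index for x and y, so trailing NaNs do not misalign them
        i_last = np.flatnonzero(~np.isnan(y_data))[-1]
        line_data_list.append(
            {
                "label": line.get_label(),
                "x_last": x_data[i_last],
                "y_last": y_data[i_last],
                "color": line.get_color(),
            }
        )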

    # Optimize label positions
    ref_positions = np.array([line_data["y_last"] for line_data in line_data_list])
    height_data_units = get_text_height(ax)
    opt_label_positions = optimize_1d_label_positions_simple(
        ref_positions, min_interlabel_distance=height_data_units * height_space_factor
    )

    # Plot labels
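    # Horizontal offset of the labels: x span of the last dx_factor fraction of
    # the points of the last line (at least one step, even for short lines)
    n_back = max(2, int(dx_factor * len(x_data)))
    dx_txt = x_data[-1] - x_data[-n_back]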

    for i_line, line_data in enumerate(line_data_list):
        line_color = line_data["color"]
        x_last = line_data["x_last"]
        y_last = line_data["y_last"]
        txt_color = line_color

        ax.text(
            x_last + dx_txt,
            opt_label_positions[i_line],
            line_data["label"],
            ha="left",
            va="center",
            color=txt_color,
        )
        x_link = [x_last, x_last + dx_txt]
        y_link = [y_last, opt_label_positions[i_line]]
        x_link, y_link = cubic_spline(x_link, y_link, n_segments=20)
        ax.plot(
            x_link,
            y_link,
            color=line_color,
            ls="--",
            lw=0.5,
        )
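    # Remove the legend only if there is one (ax.get_legend() returns None otherwise)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    # Move the right x-limit to where the labels start, using the rightmost line end
    x_last_max = max(line_data["x_last"] for line_data in line_data_list)
    ax.set_xlim([ax.get_xlim()[0], x_last_max + dx_txt])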

def do_simple_direct_line_labeling(ax):
    """Do direct line labeling in the simplest way possible

    Writing the label on the last point of the corresponding line.
    This may result in overlapping labels if multiple lines end near each other.
    """
    lines = ax.get_lines()
    for line in lines:
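        x_data, y_data = line.get_data()
        x_data = np.asarray(x_data)
        # Masked y values are treated like NaNs; x and y share the last valid index
        y_data = np.ma.filled(np.ma.asarray(y_data, dtype=float), np.nan)
        i_last = np.flatnonzero(~np.isnan(y_data))[-1]
        line_color = line.get_color()
        txt_color = line_color
        ax.text(
            x_data[i_last],
            y_data[i_last],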
            " " + line.get_label(),
            ha="left",
            va="center",
            color=txt_color,
        )
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


def optimize_1d_label_positions_simple(
    point_positions, min_interlabel_distance, label_dist_weight=30.0
):
    """Optimize label positions along a line

    Args:
        point_positions (np.ndarray): array of point positions with shape (n_points,)
        min_interlabel_distance (float): minimum interlabel distance
            a strong penalty is applied to interlabel distances lower than this
            ideally this should be slightly higher than the height of labels
            (see get_text_height function)
        label_dist_weight (float): weight of the penalty for distance between labels and points
            tune this parameter to balance the two penalties if needed

    Returns:
        opt_label_positions: optimized label positions, a numpy ndarray of shape (n_points,)
    """

    there_are_overlaps = len(set(point_positions)) < len(point_positions)
    if there_are_overlaps:
        print(
            "Warning: some points are exactly on top of each other;",
            "it is recommended to avoid that, for instance by adding some small noise",
        )

    # work with normalized positions
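    coord_range = point_positions.max(axis=0) - point_positions.min(axis=0)
    if coord_range == 0:
        # e.g. a single line or identical endpoints: fall back to the label spacing
        coord_range = min_interlabel_distance if min_interlabel_distance > 0 else 1.0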
    normed_label_positions = point_positions / coord_range
    rel_min_interlabel_distance = min_interlabel_distance / coord_range
    n_pts = normed_label_positions.shape[0]

    def high_when_below(min_value, value):
        """Return a high value when below a certain threshold, otherwise a low value"""
        return np.exp(5 * (min_value - value))

    def fun_to_minimize(x_1d):
        """Objective function to minimize

        Args:
            x_1d (np.ndarray): array of label positions with shape (n_points,)

        Returns:
            float: value of the objective, the lower the better
        """

        # a) distances between points and the respective labels
        label_distance = label_dist_weight * np.sum(
            (x_1d - normed_label_positions) ** 2
        )
        # b) pairwise distances between labels
        interlabel_dist_matrix = distance_matrix(
            x_1d.reshape(-1, 1), x_1d.reshape(-1, 1)
        )
        interlabel_dist_array = interlabel_dist_matrix[np.triu_indices(n_pts, 1)]
        interlabel_dist_loss = np.sum(
            high_when_below(rel_min_interlabel_distance, interlabel_dist_array)
        )

        return label_distance + interlabel_dist_loss

    x_init = normed_label_positions.reshape(-1)
    res = minimize(fun_to_minimize, x_init, method="L-BFGS-B")
    opt_label_positions = res.x * coord_range
    return opt_label_positions


def get_text_height(ax, **text_kwargs):
    """Get the height of a text object in data units"""
    fig = ax.figure  # the figure that owns ax, not necessarily the current figure
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    text_obj = ax.text(
        0.5 * (xlim[0] + xlim[1]),
        0.5 * (ylim[0] + ylim[1]),
        "This is a dummy text",
        **text_kwargs
    )

    # Draw the canvas to make sure the text object is rendered
    fig.canvas.draw()

    # Get the bounding box of the text in display coordinates (pixels)
    bbox_display = text_obj.get_window_extent(renderer=fig.canvas.get_renderer())

    # Convert the bbox from display coordinates to data coordinates
    inv = ax.transData.inverted()
    bbox_data = mtransforms.Bbox(inv.transform(bbox_display))
    height_data_units = bbox_data.height

    text_obj.remove()
    return height_data_units


def cubic_spline(
    x_from_to,
    y_from_to,
    n_segments: int = 20,
):
    """Get the coordinates of a horizontally oriented cubic curve (smoothstep)
        from one point to another, with horizontal tangents at both ends

    Args:
        x_from_to: a sequence (e.g. list or numpy array) with two x coordinates
        y_from_to: a sequence (e.g. list or numpy array) with two y coordinates
        n_segments: number of straight segments with which to draw the spline

    Returns:
        two numpy arrays
    """
    x_t = np.linspace(0, 1, n_segments + 1)
    y_coord = y_from_to[0] + (y_from_to[1] - y_from_to[0]) * (x_t**2 * (3 - 2 * x_t))
    x_coord = x_from_to[0] + (x_from_to[1] - x_from_to[0]) * x_t

    return x_coord, y_coord

How it works

The provided code enhances Matplotlib line plots by replacing traditional legends with direct labels positioned optimally next to the lines. For those of you who like to understand the code they run, here is a short summary of the approach, function by function.

  • Direct labeling the naive way: The do_simple_direct_line_labeling function shows the naive way to do direct labeling, placing labels at the end of each line (as illustrated in Figure 3). This can result in overlaps between labels when the endpoints of multiple lines are too close.
  • Vertical position optimization: The optimize_1d_label_positions_simple function optimizes the labels’ vertical positions by trading off two goals: keeping each label close to the endpoint of its line, and keeping labels far enough apart to avoid overlap.
  • Text height calculation: The get_text_height function calculates the height of text objects in data units to ensure proper spacing and avoid vertical overlap.
  • Cubic spline: The cubic_spline function draws the thin dashed connector from each line’s endpoint to its (possibly shifted) label as a smooth S-shaped curve, adding a tiny bit of visual appeal.
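To see why the connectors look smooth: cubic_spline moves x linearly in a parameter t from 0 to 1, while y follows s(t) = t²(3 - 2t) = 3t² - 2t³, often called smoothstep. Its derivative 6t - 6t² is zero at t = 0 and t = 1, so each connector leaves the line end horizontally and reaches its label horizontally. The standalone function below reproduces that formula so you can inspect the points; `smoothstep_connector` is a hypothetical name used only for illustration.

```python
import numpy as np


def smoothstep_connector(x0, y0, x1, y1, n_segments=20):
    """Points of an S-shaped connector from (x0, y0) to (x1, y1).

    Same formula as cubic_spline: x moves linearly with t, while y follows
    s(t) = 3t^2 - 2t^3, whose slope is 0 at t = 0 and t = 1.
    """
    t = np.linspace(0.0, 1.0, n_segments + 1)
    s = t**2 * (3 - 2 * t)
    return x0 + (x1 - x0) * t, y0 + (y1 - y0) * s


# A label shifted 2 data units above the end of its line, drawn with 4 segments
x, y = smoothstep_connector(10.0, 2.0, 11.0, 4.0, n_segments=4)
print(y.tolist())  # [2.0, 2.3125, 3.0, 3.6875, 4.0]
```

Note how the y steps are small near both ends (0.3125) and large in the middle (0.6875): that is the flat-ends property at work.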

Figure 3: Example of line plot with direct labeling done the naive way, i.e. without label position optimization, which can result in overlaps. Notice how dangerously close the pink and blue labels are.
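To make “dangerously close” concrete: in the naive scheme each label is vertically centred (va="center") on its line’s last value, so two labels overlap when their endpoints are less than one text height apart. The sketch below flags such pairs, assuming the text height is already known in data units (for instance from get_text_height); `find_close_label_pairs` is a hypothetical helper, and its default spacing factor simply mirrors the default height_space_factor of 1.25.

```python
import numpy as np


def find_close_label_pairs(y_last, text_height, height_space_factor=1.25):
    """Return index pairs (i, j) whose naive labels would sit too close.

    Labels centred on their endpoints overlap when the endpoints are less
    than one text height apart; height_space_factor > 1 also flags pairs
    that merely look cramped.
    """
    y_last = np.asarray(y_last, dtype=float)
    min_gap = text_height * height_space_factor
    pairs = []
    for i in range(len(y_last)):
        for j in range(i + 1, len(y_last)):
            if abs(y_last[i] - y_last[j]) < min_gap:
                pairs.append((i, j))
    return pairs


# Endpoints in data units; a label is about 0.8 data units tall here
print(find_close_label_pairs([10.0, 4.0, 4.6, 9.5], text_height=0.8))
# [(0, 3), (1, 2)]
```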

More on the optimization magic

A key aspect of the proposed approach is to treat label positioning as an optimization problem.

We want labels close to their corresponding line endpoints, but we also need sufficient spacing between labels to avoid visual clutter.

We solve this by minimizing a cost function with the following two competing objectives:

  • Fidelity penalty: penalizing labels straying too far from their “natural” position (the line endpoint)

  • Overlap penalty: penalizing labels getting too close to each other
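Written out, the objective minimized in optimize_1d_label_positions_simple looks as follows. This is a transcription of the code above; the symbols are introduced here for convenience. Here y_i is the last y value of line i, u_i is the normalized position of its label, w is label_dist_weight (30 by default), d_min is min_interlabel_distance, and the returned label positions are R u_i*.

```latex
J(u_1, \dots, u_n) = w \sum_{i=1}^{n} (u_i - p_i)^2
  + \sum_{1 \le i < j \le n} \exp\bigl(5\,(\delta - |u_i - u_j|)\bigr),
\quad \text{with} \quad
p_i = \frac{y_i}{R}, \quad \delta = \frac{d_{\min}}{R}, \quad R = \max_k y_k - \min_k y_k
```

The first sum is the fidelity penalty and the second the overlap penalty. A pair of labels exactly d_min apart costs exp(0) = 1, and every further reduction of their normalized gap by 0.2 multiplies that cost by e ≈ 2.7.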

For the overlap penalty, we use an exponential penalty function that becomes increasingly harsh as labels approach each other. As a result of combining the two objectives, the labels repel each other while being pulled toward their target positions. For fewer than 10 lines, scipy.optimize.minimize (here with the L-BFGS-B method) should converge in a negligible amount of time.
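The repel-and-attract behaviour can be checked on a toy case. The snippet below is a condensed, self-contained restatement of the same two-penalty objective, using NumPy broadcasting instead of scipy.spatial.distance_matrix for the pairwise gaps; `repel_labels` is a hypothetical name, it is not a drop-in replacement for optimize_1d_label_positions_simple, and it assumes the endpoints are not all identical (otherwise the normalization divides by zero).

```python
import numpy as np
from scipy.optimize import minimize


def repel_labels(y_last, min_gap, weight=30.0):
    """Toy version of the fidelity + overlap objective (illustration only)."""
    y_last = np.asarray(y_last, dtype=float)
    scale = y_last.max() - y_last.min()  # assumes distinct endpoints
    p = y_last / scale                   # normalized target positions
    delta = min_gap / scale              # normalized minimum gap
    iu = np.triu_indices(len(p), 1)      # each pair (i, j) with i < j, once

    def cost(u):
        fidelity = weight * np.sum((u - p) ** 2)
        gaps = np.abs(u[:, None] - u[None, :])[iu]
        overlap = np.sum(np.exp(5 * (delta - gaps)))
        return fidelity + overlap

    res = minimize(cost, p, method="L-BFGS-B")
    return res.x * scale


# Two of the three endpoints are only 0.2 apart, but labels need 1.0
labels = repel_labels([0.0, 9.8, 10.0], min_gap=1.0)
print(np.round(labels, 2))
```

The two close labels are pushed apart symmetrically until the growing fidelity penalty balances the shrinking overlap penalty, while the isolated label at 0 barely moves.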

Further reading

  • An alternative to the direct labeling proposed here is “inline labeling” directly on the lines, as implemented in the matplotlib-label-lines package.

  • I am quite a believer in the potential of mathematical optimization for enhancing data visualization. If the present post piqued your interest, you might also want to read about optimizing the label positions of scatter plots.