
Scatter Plot

This module provides functionality for creating scatter plots and bubble charts from pandas DataFrames.

It visualizes relationships between variables, highlights distributions, and compares categories using scatter points, with optional variable point sizes for bubble charts.

Core Features

  • Flexible X-Axis Handling: Uses an index or a specified x-axis column (x_col) for plotting.
  • Multiple Scatter Groups: Plots multiple columns (a list passed as value_col) or multiple groups (group_col). The two cannot be combined; passing both raises ValueError.
  • Bubble Chart Support: Variable point sizes via size_col and size_scale parameters.
  • Point Labels: Text labels with automatic positioning to avoid overlaps.
  • Dynamic Color Mapping: Automatically selects a colormap based on the number of groups.

Use Cases

  • Category-Based Scatter Plots: Compare different categories using scatter points.
  • Bubble Charts: Visualize three dimensions of data with x, y positions and point sizes.
  • Labeled Scatter Plots: Identify specific data points with text labels (e.g., product names, store IDs).

Limitations and Warnings

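  • Pre-Aggregated Data Required: Aggregate the data before passing it to the function, which plots rows as given. When group_col is set, the data is reshaped with DataFrame.pivot, which raises ValueError if an x/group combination appears more than once.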
  • Label Limitations: Point labels are not supported when value_col is a list (raises ValueError).
  • Size Column Requirements: size_col must contain numeric, non-negative values.
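
The pre-aggregation requirement can be met with an ordinary pandas groupby before plotting. The following is a minimal sketch under assumed data: the `transactions` frame and its `week`, `category`, and `sales` columns are hypothetical and chosen only for illustration.

```python
import pandas as pd

# Hypothetical transaction-level data: several rows share a (week, category) pair.
transactions = pd.DataFrame(
    {
        "week": [1, 1, 1, 2, 2, 2],
        "category": ["Food", "Food", "Toys", "Food", "Toys", "Toys"],
        "sales": [10.0, 15.0, 7.0, 20.0, 3.0, 4.0],
    }
)

# Collapse to one row per (week, category) so each point is defined once.
aggregated = transactions.groupby(["week", "category"], as_index=False)["sales"].sum()

# Sanity check before plotting: no (week, category) pair appears twice.
has_duplicates = bool(aggregated.duplicated(["week", "category"]).any())
```

The resulting `aggregated` frame could then be passed with x_col="week", value_col="sales", and group_col="category".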

plot(df, value_col, x_label=None, y_label=None, title=None, x_col=None, group_col=None, size_col=None, size_scale=1.0, ax=None, source_text=None, legend_title=None, move_legend_outside=False, label_col=None, label_kwargs=None, **kwargs)

Plots a scatter chart for the given value_col over x_col or index, with optional grouping by group_col.
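
A call to this function for a bubble chart might look like the sketch below. The `stores` frame, its column names, and the `draw_store_bubbles` wrapper are hypothetical; the import path is inferred from the source location openretailscience/plots/scatter.py and should be checked against the installed package.

```python
import pandas as pd

# Hypothetical, already aggregated data: one row per store.
stores = pd.DataFrame(
    {
        "footfall": [1200, 950, 1800, 400],
        "revenue": [30500.0, 21000.0, 52000.0, 9800.0],
        "staff": [12, 9, 20, 4],
    }
)


def draw_store_bubbles(df: pd.DataFrame):
    """Plot revenue against footfall, sizing each point by staff count."""
    # Imported lazily so the data preparation above runs without the library.
    # Assumption: the module is importable from this path.
    from openretailscience.plots import scatter

    return scatter.plot(
        df,
        value_col="revenue",
        x_col="footfall",
        size_col="staff",
        size_scale=20.0,  # documented formula: size = staff * size_scale
        title="Revenue vs. footfall",
        x_label="Footfall",
        y_label="Revenue",
    )
```

Calling `draw_store_bubbles(stores)` returns the axes object, which can be styled further or saved through its figure.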

Parameters:

  • df (DataFrame or Series, required): The dataframe or series to plot.
  • value_col (str or list of str, required): The column(s) to plot.
  • x_label (str, default None): The x-axis label.
  • y_label (str, default None): The y-axis label.
  • title (str, default None): The title of the plot.
  • x_col (str, default None): The column to be used as the x-axis. If None, the index is used.
  • group_col (str, default None): The column used to define different scatter groups.
  • size_col (str, default None): The column name containing values to determine point sizes. If None, all points have uniform size. Creates bubble charts when specified. When used with multiple value_col columns, the same size values apply to all series.
  • size_scale (float, default 1.0): Scaling factor for point sizes. Actual size = size_col_value * size_scale.
  • ax (Axes, default None): Matplotlib axes object to plot on.
  • source_text (str, default None): The source text to add to the plot.
  • legend_title (str, default None): The title of the legend.
  • move_legend_outside (bool, default False): Move the legend outside the plot.
  • label_col (str, default None): Column name containing text labels for each point. Not supported when value_col is a list.
  • label_kwargs (dict, default None): Keyword arguments passed to textalloc.allocate(). Common options: textsize, nbr_candidates, min_distance, max_distance, draw_lines. By default, draw_lines=False to avoid lines connecting labels to points.
  • **kwargs (Any, default {}): Additional keyword arguments for matplotlib scatter function.

Returns:

  • SubplotBase: The matplotlib axes object.

Raises:

  • ValueError: If value_col is a list and group_col is provided (which causes ambiguity in plotting).
  • ValueError: If label_col is provided when value_col is a list.
  • KeyError: If label_col doesn't exist in DataFrame.
  • KeyError: If size_col doesn't exist in DataFrame.
  • ValueError: If size_col contains non-numeric or negative values.
  • ValueError: If size_scale is not positive when size_col is specified.
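
The size rules listed here (numeric, non-negative values and a positive scale) together with the documented formula size = size_col_value * size_scale can be illustrated in plain Python. This is a sketch of the documented behaviour only; `scaled_sizes` is a hypothetical helper and is not the library's internal `_handle_size_params` or `_process_size_data`.

```python
from numbers import Real


def scaled_sizes(values, size_scale=1.0):
    """Return value * size_scale for each value, enforcing the documented rules."""
    # bool is a subclass of int, so it is rejected explicitly.
    if isinstance(size_scale, bool) or not isinstance(size_scale, Real) or size_scale <= 0:
        raise ValueError("size_scale must be a positive number")

    sizes = []
    for value in values:
        if isinstance(value, bool) or not isinstance(value, Real):
            raise ValueError(f"size values must be numeric, got {value!r}")
        if value < 0:
            raise ValueError(f"size values must be non-negative, got {value!r}")
        sizes.append(value * size_scale)
    return sizes
```

For example, `scaled_sizes([1, 2.5, 0], size_scale=2.0)` yields the three scaled sizes, while a negative value or a zero scale raises ValueError.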

Source code in openretailscience/plots/scatter.py
def plot(
    df: pd.DataFrame | pd.Series,
    value_col: str | list[str],
    x_label: str | None = None,
    y_label: str | None = None,
    title: str | None = None,
    x_col: str | None = None,
    group_col: str | None = None,
    size_col: str | None = None,
    size_scale: float = 1.0,
    ax: Axes | None = None,
    source_text: str | None = None,
    legend_title: str | None = None,
    move_legend_outside: bool = False,
    label_col: str | None = None,
    label_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,  # noqa: ANN401
) -> SubplotBase:
    """Plots a scatter chart for the given `value_col` over `x_col` or index, with optional grouping by `group_col`.

    Args:
        df (pd.DataFrame or pd.Series): The dataframe or series to plot.
        value_col (str or list of str): The column(s) to plot.
        x_label (str, optional): The x-axis label.
        y_label (str, optional): The y-axis label.
        title (str, optional): The title of the plot.
        x_col (str, optional): The column to be used as the x-axis. If None, the index is used.
        group_col (str, optional): The column used to define different scatter groups.
        size_col (str, optional): The column name containing values to determine point sizes.
            If None, all points have uniform size. Creates bubble charts when specified.
            When used with multiple value_col columns, the same size values apply to all series.
        size_scale (float, optional): Scaling factor for point sizes. Default: 1.0.
            Actual size = size_col_value * size_scale.
        ax (Axes, optional): Matplotlib axes object to plot on.
        source_text (str, optional): The source text to add to the plot.
        legend_title (str, optional): The title of the legend.
        move_legend_outside (bool, optional): Move the legend outside the plot.
        label_col (str, optional): Column name containing text labels for each point.
            Not supported when value_col is a list. Defaults to None.
        label_kwargs (dict, optional): Keyword arguments passed to textalloc.allocate().
            Common options: textsize, nbr_candidates, min_distance, max_distance, draw_lines.
            By default, draw_lines=False to avoid lines connecting labels to points.
            Defaults to None.
        **kwargs: Additional keyword arguments for matplotlib scatter function.

    Returns:
        SubplotBase: The matplotlib axes object.

    Raises:
        ValueError: If `value_col` is a list and `group_col` is provided (which causes ambiguity in plotting).
        ValueError: If `label_col` is provided when `value_col` is a list.
        KeyError: If `label_col` doesn't exist in DataFrame.
        KeyError: If `size_col` doesn't exist in DataFrame.
        ValueError: If `size_col` contains non-numeric or negative values.
        ValueError: If `size_scale` is not positive when `size_col` is specified.
    """
    if isinstance(df, pd.Series):
        df = df.to_frame()

    if isinstance(value_col, list) and group_col:
        raise ValueError("Cannot use both a list for `value_col` and a `group_col`. Choose one.")

    if label_col is not None:
        if isinstance(value_col, list):
            raise ValueError(
                "label_col is not supported when value_col is a list. "
                "Please use a single value_col or create separate plots.",
            )

        if label_col not in df.columns:
            msg = f"label_col '{label_col}' not found in DataFrame"
            raise KeyError(msg)

    # Validate size parameters and resolve conflicting kwargs
    kwargs = _handle_size_params(df, size_col, size_scale, kwargs)

    if group_col is None:
        pivot_df = df.set_index(x_col if x_col is not None else df.index)[
            [value_col] if isinstance(value_col, str) else value_col
        ]
    else:
        pivot_df = (
            df.pivot(columns=group_col, values=value_col)
            if x_col is None
            else df.pivot(index=x_col, columns=group_col, values=value_col)
        )

    is_multi_scatter = (group_col is not None) or (isinstance(value_col, list) and len(value_col) > 1)

    num_colors = len(pivot_df.columns) if is_multi_scatter else 1
    default_colors = get_plot_colors(num_colors)

    # Handle color parameter - can be single color or list of colors
    color = kwargs.pop("color", default_colors)
    colors = [color] * num_colors if not isinstance(color, list) else color

    # Process size data if size_col is specified
    size_data = _process_size_data(df, size_col, size_scale, x_col, group_col)

    ax = ax or plt.gca()
    alpha = kwargs.pop("alpha", 0.7)

    _create_scatter_plot(ax, pivot_df, colors, size_data, group_col, is_multi_scatter, alpha, **kwargs)

    # Add labels if requested
    if label_col is not None:
        _add_point_labels(
            ax=ax,
            df=df,
            value_col=value_col,
            label_col=label_col,
            x_col=x_col,
            group_col=group_col,
            label_kwargs=label_kwargs,
        )

    ax = gu.standard_graph_styles(
        ax=ax,
        title=title,
        x_label=x_label,
        y_label=y_label,
        legend_title=legend_title,
        move_legend_outside=move_legend_outside,
    )

    if source_text is not None:
        gu.add_source_text(ax=ax, source_text=source_text)

    return gu.standard_tick_styles(ax)
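
The grouped branch of the listing reshapes long-format data into one column per group with DataFrame.pivot, then resolves one color per column. The sketch below reproduces those two steps in isolation; `long_df`, its column names, and the "tab:blue" color are hypothetical, and the library's own `get_plot_colors` helper is not used here.

```python
import pandas as pd

# Hypothetical long-format data: one row per (month, region) pair.
long_df = pd.DataFrame(
    {
        "month": [1, 1, 2, 2],
        "region": ["North", "South", "North", "South"],
        "sales": [100, 80, 120, 90],
    }
)

# Same call shape as the grouped branch: x_col becomes the index,
# group_col becomes the columns, value_col fills the cells.
wide_df = long_df.pivot(index="month", columns="region", values="sales")

# One scatter series per column, so one color is needed per column.
num_colors = len(wide_df.columns)

# A single color is repeated for every series; a list is used as given.
color = "tab:blue"
colors = color if isinstance(color, list) else [color] * num_colors
```

Each column of `wide_df` corresponds to one scatter series, which is why duplicate month/region pairs in the input cannot be reshaped and must be aggregated beforehand.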