# Source code for towbintools.plotting.curves

import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np

from .utils_plotting import build_legend
from .utils_plotting import get_colors
from .utils_plotting import set_scale
from towbintools.data_analysis import rescale_and_aggregate
from towbintools.data_analysis.time_series import (
    smooth_series_classified,
)
from towbintools.foundation.utils import find_best_string_match


def plot_aggregated_series(
    conditions_struct: list,
    series_column: str | list[str],
    conditions_to_plot: list[int],
    x: str = "time",
    experiment_time: bool = True,
    aggregation: str = "mean",
    n_points: int = 100,
    time_step: int = 10,
    log_scale: bool = True,
    colors: list | dict | None = None,
    legend: dict | None = None,
    x_axis_label: str | None = None,
    y_axis_label: str | None = None,
    xlim: tuple[float, float] | None = None,
) -> matplotlib.figure.Figure:
    """
    Plot the time-rescaled aggregated series with 95% confidence intervals.

    Series are rescaled to a common time axis via ``rescale_and_aggregate``,
    then plotted as a solid line (the aggregated series) with a shaded band
    spanning ``aggregate ± 1.96 * standard error``. If ``series_column`` is a
    list, all columns are overlaid on the current axes; legend entries with
    identical labels are merged.

    Parameters:
        conditions_struct (list) : List of condition dicts.
        series_column (str or list[str]) : Key(s) of the measurement series
            to plot.
        conditions_to_plot (list[int]) : Indices of conditions to include.
        x (str) : X-axis variable. ``"time"`` uses rescaled hours;
            ``"percentage"`` uses development completion (0–100 %).
            Defaults to ``"time"``.
        experiment_time (bool) : If ``True``, use absolute experiment time
            (hours); otherwise use time-step index scaled by ``time_step``.
            Defaults to ``True``.
        aggregation (str) : Aggregation function; ``"mean"`` or ``"median"``.
            Defaults to ``"mean"``.
        n_points (int) : Number of resampled points per larval stage.
            Defaults to ``100``.
        time_step (int) : Minutes per frame, used to convert time-step indices
            to hours when ``experiment_time=False``. Defaults to ``10``.
        log_scale (bool) : If ``True``, set the y-axis to log scale.
            Defaults to ``True``.
        colors (list or dict or None) : Color spec passed to ``get_colors``.
            Defaults to ``None``.
        legend (dict or None) : Legend spec passed to ``build_legend``.
            Defaults to ``None``.
        x_axis_label (str or None) : X-axis label; auto-generated when
            ``None``. Defaults to ``None``.
        y_axis_label (str or None) : Y-axis label; falls back to
            ``series_column`` (comma-joined when it is a list) when ``None``.
            Defaults to ``None``.
        xlim (tuple[float, float] or None) : X-axis limits ``(xmin, xmax)``
            used to crop the plotted range. Defaults to ``None``.

    Returns:
        matplotlib.figure.Figure : The generated figure.

    Raises:
        ValueError : If ``x`` is not ``"time"`` or ``"percentage"``.
    """
    # Reject a bad ``x`` before any data is processed or drawn, so the error
    # does not depend on ``conditions_to_plot`` being non-empty.
    if x not in ("time", "percentage"):
        raise ValueError(f"Invalid x value: {x}. Must be 'time' or 'percentage'.")

    color_palette = get_colors(conditions_to_plot, colors)

    # The time base is the same for every condition, so pick the keys once.
    if experiment_time:
        time_key = "experiment_time_hours"
        durations_key = "larval_stage_durations_experiment_time_hours"
    else:
        time_key = "time"
        durations_key = "larval_stage_durations_time_step"

    def plot_single_series(column: str) -> None:
        """Draw the aggregated ``column`` curve and CI band of each condition."""
        for i, condition_id in enumerate(conditions_to_plot):
            condition_dict = conditions_struct[condition_id]

            # Use the only quality-control key if there is exactly one,
            # otherwise let ``find_best_string_match`` pick one for ``column``.
            qc_keys = [key for key in condition_dict if "qc" in key]
            if len(qc_keys) == 1:
                qc_key = qc_keys[0]
            else:
                qc_key = find_best_string_match(column, qc_keys)

            rescaled_time, aggregated_series, _, ste_series = rescale_and_aggregate(
                condition_dict[column],
                condition_dict[time_key],
                condition_dict["ecdysis_index"],
                condition_dict[durations_key],
                condition_dict[qc_key],
                aggregation=aggregation,
                n_points=n_points,
            )

            # 1.96 standard errors on each side gives the 95% band
            ci_lower = aggregated_series - 1.96 * ste_series
            ci_upper = aggregated_series + 1.96 * ste_series

            if not experiment_time:
                # time-step indices -> hours (``time_step`` is minutes per frame)
                rescaled_time = rescaled_time * time_step / 60

            label = build_legend(condition_dict, legend)

            if x == "time":
                x_values = rescaled_time
            else:  # "percentage", guaranteed by the validation above
                x_values = np.linspace(0, 100, len(rescaled_time))

            if xlim is not None:
                outside_xlim = (x_values < xlim[0]) | (x_values > xlim[1])
                x_values = x_values[~outside_xlim]
                aggregated_series = aggregated_series[~outside_xlim]
                ci_lower = ci_lower[~outside_xlim]
                ci_upper = ci_upper[~outside_xlim]

            plt.plot(x_values, aggregated_series, color=color_palette[i], label=label)
            plt.fill_between(
                x_values, ci_lower, ci_upper, color=color_palette[i], alpha=0.2
            )

    columns = series_column if isinstance(series_column, list) else [series_column]
    for column in columns:
        plot_single_series(column)

    # remove duplicate labels
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys())

    plt.yscale("log" if log_scale else "linear")

    if y_axis_label is not None:
        plt.ylabel(y_axis_label)
    else:
        # A list would otherwise be rendered as its Python repr
        plt.ylabel(", ".join(str(column) for column in columns))

    if x_axis_label is not None:
        plt.xlabel(x_axis_label)
    else:
        plt.xlabel("time (h)" if x == "time" else "development completion (%)")

    fig = plt.gcf()
    plt.show()

    return fig
def plot_growth_curves_individuals(
    conditions_struct: list,
    column: str,
    conditions_to_plot: list[int],
    share_y_axis: bool,
    log_scale: bool | tuple | list = True,
    figsize: tuple[float, float] | None = None,
    legend: dict | None = None,
    y_axis_label: str | None = None,
    cut_after: float | None = None,
) -> matplotlib.figure.Figure:
    """
    Plot smoothed individual-worm growth curves with one subplot per condition.

    Each worm's series is smoothed via ``smooth_series_classified`` and plotted
    from hatch time (the time axis is shifted so that hatch is at 0). Worms
    without a detected hatch event are skipped.

    Parameters:
        conditions_struct (list) : List of condition dicts.
        column (str) : Key of the raw measurement series.
        conditions_to_plot (list[int]) : Indices of conditions to include.
        share_y_axis (bool) : If ``True``, all subplots share the same y-axis
            range.
        log_scale (bool or tuple or list) : Scale spec passed to ``set_scale``.
            Defaults to ``True`` (log y-axis only).
        figsize (tuple[float, float] or None) : Figure size
            ``(width, height)`` in inches. Defaults to
            ``(n_conditions * 8, 10)``.
        legend (dict or None) : Legend spec used to generate subplot titles.
            Defaults to ``None``.
        y_axis_label (str or None) : Y-axis label; falls back to ``column``
            when ``None``. Defaults to ``None``.
        cut_after (float or None) : Experiment time (``experiment_time / 3600``)
            after which worm traces are truncated; the first sample beyond the
            threshold is kept. ``None`` keeps full traces. Defaults to
            ``None``.
            NOTE(review): the threshold is compared with the un-shifted
            experiment time, not with time since hatch — confirm this is the
            intended reference.

    Returns:
        matplotlib.figure.Figure : The generated figure.
    """
    if figsize is None:
        figsize = (len(conditions_to_plot) * 8, 10)

    # ``squeeze=False`` always returns a 2D array of Axes, so the single-
    # condition case needs no special handling (and no exception-driven
    # fallback that could hide real errors raised while plotting).
    fig, axes_grid = plt.subplots(
        1,
        len(conditions_to_plot),
        figsize=figsize,
        sharey=share_y_axis,
        squeeze=False,
    )
    axes = axes_grid[0]

    for i, condition_id in enumerate(conditions_to_plot):
        condition_dict = conditions_struct[condition_id]
        axis = axes[i]

        # Use the only quality-control key if there is exactly one, otherwise
        # let ``find_best_string_match`` pick one for ``column``.
        qc_keys = [key for key in condition_dict if "qc" in key]
        if len(qc_keys) == 1:
            qc_key = qc_keys[0]
        else:
            qc_key = find_best_string_match(column, qc_keys)

        # Built once per condition so the title is defined even when every
        # worm of the condition is skipped.
        label = build_legend(condition_dict, legend)

        for j, data in enumerate(condition_dict[column]):
            hatch = condition_dict["ecdysis_time_step"][j][0]
            if np.isnan(hatch):
                # no hatch detected for this worm
                continue
            hatch = int(hatch)

            time = condition_dict["experiment_time"][j] / 3600
            qc = condition_dict[qc_key][j]
            hatch_experiment_time = (
                condition_dict["ecdysis_experiment_time"][j][0] / 3600
            )

            if cut_after is not None:
                indexes_to_cut = np.where(time > cut_after)[0]
                if len(indexes_to_cut) > 0:
                    # keep everything up to and including the first sample
                    # that exceeds the threshold
                    stop = indexes_to_cut[0] + 1
                    data = data[:stop]
                    time = time[:stop]
                    qc = qc[:stop]

            # Drop pre-hatch samples and express time relative to hatch
            time = time[hatch:] - hatch_experiment_time
            data = data[hatch:]
            qc = qc[hatch:]

            filtered_data = smooth_series_classified(
                data,
                time,
                qc,
            )

            axis.plot(time, filtered_data)
            set_scale(axis, log_scale)

        axis.title.set_text(label)

    # Set labels (only on the first subplot)
    axes[0].set_ylabel(y_axis_label if y_axis_label is not None else column)
    axes[0].set_xlabel("Time (h)")

    plt.show()

    return fig