Source code for towbintools.plotting.plotting_structure

import os
from collections import defaultdict

import numpy as np
import polars as pl
import yaml

from towbintools.data_analysis import compute_series_at_time_classified
from towbintools.foundation.file_handling import read_filemap
from towbintools.foundation.utils import find_best_string_match
from towbintools.foundation.worm_features import get_features_to_compute_at_molt

# Worm-feature name substrings returned by get_features_to_compute_at_molt(),
# evaluated once at import time. In this module, a filemap column whose name
# contains one of these substrings is treated as a feature column: it is
# excluded from the custom pass-through columns and, per organ, its values at
# each ecdysis event are extracted (and filled in when missing).
FEATURES_TO_COMPUTE_AT_MOLT = get_features_to_compute_at_molt()

# THIS PART HANDLES THE PROCESSING OF THE EXPERIMENT FILEMAP AND THE CREATION OF THE PLOTTING STRUCTURE


[docs] def build_conditions(conditions_yaml: str | dict) -> list: """ Process a conditions YAML file structured in a factorized way. Refer to the documentation for the expected format. The function will return a list of dictionaries, each representing a condition with its parameters. Parameters: conditions_yaml (str or dict): Path to the YAML file or a dictionary containing the conditions. Returns: list: A list of dictionaries, each representing a condition with its parameters. Raises: ValueError: If the conditions YAML file is not structured correctly. """ if isinstance(conditions_yaml, str): with open(conditions_yaml) as file: conditions_dict = yaml.safe_load(file) file.close() elif isinstance(conditions_yaml, dict): conditions_dict = conditions_yaml conditions = [] condition_id = 0 for condition in conditions_dict["conditions"]: condition = { key: [val] if not isinstance(val, list) else val for key, val in condition.items() } lengths = {len(val) for val in condition.values()} if len(lengths) > 2 or (len(lengths) == 2 and 1 not in lengths): raise ValueError( "All lists in the condition must have the same length or be of length 1." ) max_length = max(lengths) for i in range(max_length): condition_dict = { key: val[0] if len(val) == 1 else val[i] for key, val in condition.items() } condition_dict["condition_id"] = condition_id conditions.append(condition_dict) condition_id += 1 return conditions
def _add_conditions_to_filemap(
    experiment_filemap: pl.DataFrame,
    conditions: list[dict],
) -> pl.DataFrame:
    """
    Add condition metadata columns to the experiment filemap.

    Rows are matched to conditions via ``"point_range"`` (inclusive integer
    range ``[start, end]``, or a list of such ranges) or ``"pad"`` (exact pad
    string) keys in each condition dict. Every other key of a condition,
    including ``"condition_id"``, becomes a column whose value is set on the
    matched rows. Unmatched rows receive ``None`` for the added columns.
    Conditions with neither key are skipped with a printed message.

    Parameters:
        experiment_filemap (polars.DataFrame) : Filemap with at least
            ``"Point"`` and optionally ``"Pad"`` columns.
        conditions (list[dict]) : List of condition dicts as returned by
            ``build_conditions``; each should contain ``"point_range"`` or
            ``"pad"``.

    Returns:
        polars.DataFrame : Filemap with one new column per condition metadata key.
    """
    # Build the metadata columns on a narrow frame that only holds the columns
    # used for matching, then graft them onto the full filemap at the end.
    if "Pad" in experiment_filemap.columns:
        new_columns = experiment_filemap.select(pl.col("Point"), pl.col("Pad"))
    else:
        new_columns = experiment_filemap.select(pl.col("Point"))

    for condition in conditions:
        if "point_range" in condition:
            point_range = condition["point_range"]
            if isinstance(point_range[0], list):
                # Several ranges: OR the per-range masks together.
                mask_exprs = [
                    (pl.col("Point") >= pr[0]) & (pl.col("Point") <= pr[1])
                    for pr in point_range
                ]
                condition_mask = mask_exprs[0]
                for expr in mask_exprs[1:]:
                    condition_mask = condition_mask | expr
            else:
                condition_mask = (pl.col("Point") >= point_range[0]) & (
                    pl.col("Point") <= point_range[1]
                )
            filter_key = "point_range"
        elif "pad" in condition:
            condition_mask = pl.col("Pad") == condition["pad"]
            filter_key = "pad"
        else:
            print(
                "Condition does not contain 'point_range' or 'pad' key, impossible to add condition to filemap, skipping."
            )
            continue

        column_ops = []
        for key, val in condition.items():
            if key == filter_key:
                continue
            # New columns start all-null so `otherwise` preserves values that
            # earlier conditions already wrote.
            if key not in new_columns.columns:
                new_columns = new_columns.with_columns(pl.lit(None).alias(key))
            column_ops.append(
                pl.when(condition_mask)
                .then(pl.lit(val))
                .otherwise(pl.col(key))
                .alias(key)
            )
        if column_ops:
            new_columns = new_columns.with_columns(column_ops)

    if "Pad" in new_columns.columns:
        new_columns = new_columns.drop(["Point", "Pad"])
    else:
        new_columns = new_columns.drop(["Point"])

    if not new_columns.is_empty():
        experiment_filemap = experiment_filemap.with_columns(new_columns)
    return experiment_filemap


def _get_custom_columns(filemap: pl.DataFrame) -> list[str]:
    """
    Identify non-standard columns in the filemap for pass-through to the plotting structure.

    A column is standard (and excluded) when it is one of the fixed time /
    ecdysis / status columns, or when its name contains ``"qc"``,
    ``"analysis"`` or any entry of ``FEATURES_TO_COMPUTE_AT_MOLT``. Every other
    column is a custom column.

    Parameters:
        filemap (polars.DataFrame) : Experiment filemap.

    Returns:
        list[str] : Names of custom columns, in filemap column order.
    """
    standard_columns = {
        "Time",
        "ExperimentTime",
        "Point",
        "raw",
        "HatchTime",
        "M1",
        "M2",
        "M3",
        "M4",
        "Arrest",
        "Ignore",
        "Death",
        "Dead",
    }

    def _is_standard(column: str) -> bool:
        # Fixed names, QC/analysis outputs and worm-feature columns are standard.
        return (
            column in standard_columns
            or "qc" in column
            or "analysis" in column
            or any(feature in column for feature in FEATURES_TO_COMPUTE_AT_MOLT)
        )

    return [column for column in filemap.columns if not _is_standard(column)]


# TODO: Instead of doing it for each condition, do it for all conditions at once, then split
def _process_condition_id_plotting_structure(
    experiment_dir: str,
    experiment_filemap: pl.DataFrame,
    filemap_path: str,
    organ_channels: dict[str, str],
    conditions_keys: list[str],
    condition_id: int,
    custom_columns: list[str] | None = None,
    recompute_values_at_molt: bool = False,
    rescale_n_points: int = 100,
) -> dict:
    """
    Build the condition dict for a single condition ID.

    Extracts time series, ecdysis events, QC classifications, worm feature
    arrays, and optional custom columns for all points belonging to
    ``condition_id``.

    For each organ, the filemap columns starting with the organ's channel
    string (excluding ``*_at_*`` columns) are split into QC columns (name
    contains ``"qc"`` or ``"worm_type"``) and feature columns (name contains an
    entry of ``FEATURES_TO_COMPUTE_AT_MOLT``). Output keys are obtained by
    replacing the channel string with the organ name (and ``"worm_type"`` with
    ``"qc"`` for QC keys). An organ without a QC column of its own uses the
    first QC column of the filemap.

    Parameters:
        experiment_dir (str) : Path to the experiment directory.
        experiment_filemap (polars.DataFrame) : Full filemap with condition
            metadata already merged in.
        filemap_path (str) : Path to the CSV filemap file (stored in the output dict).
        organ_channels (dict[str, str]) : Mapping of organ name → filemap
            channel prefix (e.g. ``{"body": "ch2", "pharynx": "ch1"}``).
        conditions_keys (list[str]) : Metadata keys from the conditions YAML to
            copy into the condition dict; must include ``"condition_id"``.
        condition_id (int) : Condition identifier to process.
        custom_columns (list[str] or None) : Additional filemap columns to
            include. Defaults to ``None``.
        recompute_values_at_molt (bool) : If ``True``, recompute at-molt values
            even when they are already present in the filemap. Defaults to ``False``.
        rescale_n_points (int) : Unused; reserved for future use. Defaults to ``100``.

    Returns:
        dict : Condition dict containing time arrays, ecdysis indices, QC,
            feature series, and all metadata keys.

    Raises:
        ValueError : If an organ has no QC column and the filemap contains no
            QC column to fall back on.
    """
    condition_df = experiment_filemap.filter(pl.col("condition_id") == condition_id)

    # Metadata is constant within a condition: read it from the first row.
    condition_dict = {
        key: condition_df.select(pl.col(key))[0].item() for key in conditions_keys
    }

    (
        time,
        experiment_time,
        ecdysis_index,
        ecdysis_time_step,
        ecdysis_experiment_time,
        larval_stage_durations_time_step,
        larval_stage_durations_experiment_time,
    ) = _get_time_ecdysis_and_durations(condition_df)
    death, arrest = _get_death_and_arrest(condition_df)
    n_points = condition_df.select(pl.col("Point")).n_unique()

    condition_dict.update(
        {
            "condition_id": int(condition_dict["condition_id"]),
            "ecdysis_index": ecdysis_index,
            "ecdysis_time_step": ecdysis_time_step,
            "larval_stage_durations_time_step": larval_stage_durations_time_step,
            "ecdysis_experiment_time": ecdysis_experiment_time,
            "ecdysis_experiment_time_hours": ecdysis_experiment_time / 3600,
            "larval_stage_durations_experiment_time": larval_stage_durations_experiment_time,
            "larval_stage_durations_experiment_time_hours": larval_stage_durations_experiment_time
            / 3600,
            "experiment": np.full((n_points, 1), experiment_dir),
            "filemap_path": np.full((n_points, 1), filemap_path),
            "point": condition_df.select(
                pl.col("Point").unique(maintain_order=True)
            ).to_numpy(),
            "time": time,
            "experiment_time": experiment_time,
            "experiment_time_hours": experiment_time / 3600,
            "death": death,
            "arrest": arrest,
        }
    )

    # QC-like columns of any channel; the first one is the fallback for organs
    # that have no QC column of their own.
    qc_columns = [
        col for col in condition_df.columns if "qc" in col or "worm_type" in col
    ]

    for organ, organ_channel in organ_channels.items():
        # Per-organ time-series columns; "*_at_*" (precomputed at-molt) columns
        # are read separately by _get_values_at_molt.
        organ_columns = [
            col
            for col in condition_df.columns
            if col.startswith(organ_channel) and "_at_" not in col
        ]
        organ_qc_columns = [
            col for col in organ_columns if "qc" in col or "worm_type" in col
        ]
        if not organ_qc_columns:
            if not qc_columns:
                raise ValueError(
                    f"No QC column found for organ '{organ}' (channel "
                    f"'{organ_channel}') and no QC column to fall back on in the filemap."
                )
            organ_qc_columns.append(qc_columns[0])
        renamed_organ_qc_columns = [
            col.replace(organ_channel, organ).replace("worm_type", "qc")
            for col in organ_qc_columns
        ]

        for column, renamed_column in zip(organ_qc_columns, renamed_organ_qc_columns):
            condition_dict[renamed_column] = separate_column_by_point(
                condition_df, column
            )
            # QC columns must not be treated as feature columns below.
            if column in organ_columns:
                organ_columns.remove(column)

        organ_feature_columns = [
            col
            for feature in FEATURES_TO_COMPUTE_AT_MOLT
            for col in organ_columns
            if feature in col
        ]

        for organ_feature_column in organ_feature_columns:
            renamed_feature_column = organ_feature_column.replace(organ_channel, organ)
            # Pair each feature with the most similarly named QC column.
            qc_key = find_best_string_match(
                renamed_feature_column, renamed_organ_qc_columns
            )
            qc = condition_dict[qc_key]

            condition_dict[renamed_feature_column] = separate_column_by_point(
                condition_df, organ_feature_column
            )
            condition_dict[f"{renamed_feature_column}_at_ecdysis"] = _get_values_at_molt(
                condition_df, organ_feature_column, ecdysis_time_step
            )
            condition_dict = _compute_values_at_molt(
                condition_dict,
                renamed_feature_column,
                qc,
                recompute_values_at_molt=recompute_values_at_molt,
            )

    if custom_columns is not None:
        for custom_column in custom_columns:
            if custom_column in condition_df.columns:
                condition_dict[custom_column] = separate_column_by_point(
                    condition_df, custom_column
                )

    return condition_dict
def build_plotting_struct(
    experiment_dir: str,
    filemap_path: str,
    conditions_yaml_path: str | dict,
    organ_channels: dict[str, str] = {"body": "ch2", "pharynx": "ch1"},
    recompute_values_at_molt: bool = False,
    rescale_n_points: int = 100,
) -> tuple[list[dict], list[dict]]:
    """
    Build the plotting structure for a single experiment.

    Reads the filemap CSV and the conditions YAML, merges condition metadata
    into the filemap, and produces a list of condition dicts ready for
    plotting. Rows matching no condition are dropped. When an ``"Ignore"``
    column is present, molt times falling on ignored rows are cleared and the
    rows with ``Ignore == True`` are removed (rows with a null ``Ignore`` are kept).

    Parameters:
        experiment_dir (str) : Path to the experiment directory.
        filemap_path (str) : Path to the CSV filemap file.
        conditions_yaml_path (str or dict) : Path to the conditions YAML file or
            an already-parsed conditions dict.
        organ_channels (dict[str, str]) : Mapping of organ name → filemap channel
            prefix. Defaults to ``{"body": "ch2", "pharynx": "ch1"}``.
        recompute_values_at_molt (bool) : If ``True``, recompute at-molt values
            even when they are already stored in the filemap. Defaults to ``False``.
        rescale_n_points (int) : Unused; reserved for future use. Defaults to ``100``.

    Returns:
        tuple[list[dict], list[dict]] : ``(conditions_struct, conditions_info)``
            where ``conditions_struct`` is the full list of condition dicts
            sorted by ``condition_id``, and ``conditions_info`` holds, in the
            same order, only the metadata keys from the YAML (the union over all
            conditions, excluding ``point_range`` and ``pad``).

    Raises:
        ValueError : If the conditions YAML defines no condition.
    """
    experiment_filemap = read_filemap(filemap_path)
    # Computed on the raw filemap, before any condition metadata column exists.
    custom_columns = _get_custom_columns(experiment_filemap) or None

    conditions = build_conditions(conditions_yaml_path)
    if not conditions:
        raise ValueError("The conditions YAML does not define any condition.")

    # Metadata keys over *all* conditions, in first-seen order. 'point_range'
    # and 'pad' only select rows and are not metadata.
    conditions_keys = [
        key
        for key in dict.fromkeys(
            name for condition in conditions for name in condition
        )
        if key not in ("point_range", "pad")
    ]

    experiment_filemap = _add_conditions_to_filemap(experiment_filemap, conditions)
    # Keys belonging only to skipped conditions never became columns.
    conditions_keys = [
        key for key in conditions_keys if key in experiment_filemap.columns
    ]

    if "ExperimentTime" not in experiment_filemap.columns:
        experiment_filemap = experiment_filemap.with_columns(
            pl.lit(np.nan).alias("ExperimentTime")
        )

    # Drop rows that belong to no condition.
    experiment_filemap = experiment_filemap.filter(pl.col("condition_id").is_not_null())

    if "Ignore" in experiment_filemap.columns:
        # Clear molts that fall on ignored time points, then drop those rows.
        experiment_filemap = remove_ignored_molts(experiment_filemap)
        experiment_filemap = experiment_filemap.filter(
            ~pl.col("Ignore").fill_null(False)
        )

    # A plain list stays iterable when there is a single condition, unlike
    # `.to_numpy().squeeze()`, which collapses to a 0-d array.
    condition_ids = (
        experiment_filemap.get_column("condition_id")
        .unique(maintain_order=True)
        .to_list()
    )

    conditions_struct = [
        _process_condition_id_plotting_structure(
            experiment_dir,
            experiment_filemap,
            filemap_path,
            organ_channels,
            conditions_keys,
            condition_id,
            custom_columns=custom_columns,
            recompute_values_at_molt=recompute_values_at_molt,
            rescale_n_points=rescale_n_points,
        )
        for condition_id in condition_ids
    ]
    conditions_struct.sort(key=lambda condition: condition["condition_id"])

    # Derived after sorting so conditions_info[i] describes conditions_struct[i].
    conditions_info = [
        {key: condition[key] for key in conditions_keys}
        for condition in conditions_struct
    ]
    return conditions_struct, conditions_info
def _compute_larval_stage_duration(ecdysis_array: np.ndarray) -> np.ndarray:
    """
    Compute the duration of each larval stage from consecutive ecdysis times.

    Parameters:
        ecdysis_array (array-like) : Ecdysis time values for one worm
            (length 5: Hatch, M1, M2, M3, M4).

    Returns:
        np.ndarray : Float durations of shape ``(len(ecdysis_array) - 1,)``
            (L1–L4 for five events); NaN where either boundary event is NaN.
    """
    # NaN boundaries propagate through the subtraction, yielding NaN durations.
    return np.diff(np.asarray(ecdysis_array, dtype=float))


def _get_time_ecdysis_and_durations(filemap: pl.DataFrame) -> tuple:
    """
    Extract per-point time arrays, ecdysis indices, and larval stage durations.

    Points are processed in order of first appearance in ``filemap``.

    Parameters:
        filemap (polars.DataFrame) : Condition filemap containing ``"Point"``,
            ``"Time"``, ``"ExperimentTime"``, and the five ecdysis columns
            (``"HatchTime"``, ``"M1"``–``"M4"``).

    Returns:
        tuple :
            - time (np.ndarray) : Shape ``(n_points, n_frames)`` time-step values.
            - experiment_time (np.ndarray) : Shape ``(n_points, n_frames)``
              experiment time in seconds.
            - ecdysis_index (np.ndarray) : Shape ``(n_points, 5)`` positions
              (along the frame axis) of the ecdysis events; NaN if the event
              time is not among the point's ``Time`` values.
            - ecdysis_time_step (np.ndarray) : Shape ``(n_points, 5)`` ecdysis
              times in raw time steps.
            - ecdysis_experiment_time (np.ndarray) : Shape ``(n_points, 5)``
              ecdysis times in seconds.
            - larval_stage_durations_time_step (np.ndarray) : Shape
              ``(n_points, 4)`` stage durations in raw time steps.
            - larval_stage_durations_experiment_time (np.ndarray) : Shape
              ``(n_points, 4)`` stage durations in seconds.
    """
    ecdysis_columns = ["HatchTime", "M1", "M2", "M3", "M4"]
    column_list = ["Point", "Time", "ExperimentTime"] + ecdysis_columns
    filemap = filemap.select(pl.col(column_list))

    # Ecdysis times are repeated on every row of a point: keep the first one.
    ecdysis_values = (
        filemap.group_by("Point", maintain_order=True)
        .agg(pl.col(ecdysis_columns).first())
        .drop("Point")
        .cast(pl.Float64)
    )
    time = separate_column_by_point(filemap, "Time").astype(float)
    experiment_time = separate_column_by_point(filemap, "ExperimentTime").astype(float)

    all_ecdysis_time_step = []
    all_ecdysis_index = []
    all_ecdysis_experiment_time = []
    all_durations_time_step = []
    all_durations_experiment_time = []
    for i in range(len(ecdysis_values)):
        ecdysis = ecdysis_values[i].to_numpy().squeeze()
        time_of_point = time[i]
        experiment_time_of_point = experiment_time[i]

        # Position of each event's time step in the point's Time array.
        ecdysis_index = [
            float(np.where(time_of_point == event_time)[0][0])
            if event_time in time_of_point
            else np.nan
            for event_time in ecdysis
        ]
        ecdysis_experiment_time = [
            experiment_time_of_point[int(index)] if not np.isnan(index) else np.nan
            for index in ecdysis_index
        ]

        all_ecdysis_time_step.append(ecdysis)
        all_ecdysis_index.append(ecdysis_index)
        all_ecdysis_experiment_time.append(ecdysis_experiment_time)
        all_durations_time_step.append(_compute_larval_stage_duration(ecdysis))
        all_durations_experiment_time.append(
            _compute_larval_stage_duration(ecdysis_experiment_time)
        )

    return (
        time,
        experiment_time,
        np.array(all_ecdysis_index),
        np.array(all_ecdysis_time_step),
        np.array(all_ecdysis_experiment_time),
        np.array(all_durations_time_step),
        np.array(all_durations_experiment_time),
    )


def _get_values_at_molt(
    filemap: pl.DataFrame,
    column: str,
    ecdysis_time_step: np.ndarray,
) -> np.ndarray:
    """
    Retrieve precomputed at-ecdysis values for a column from the filemap.

    Reads the ``{column}_at_{event}`` columns for each of the five ecdysis
    events; a missing column is treated as all-NaN, so that
    ``_compute_values_at_molt`` can fill it in. Entries corresponding to NaN
    ecdysis times are set to NaN.

    Parameters:
        filemap (polars.DataFrame) : Condition filemap with the at-ecdysis columns.
        column (str) : Base column name (without the ``_at_*`` suffix).
        ecdysis_time_step (np.ndarray) : Ecdysis time array of shape
            ``(n_points, 5)`` used to mask NaN events.

    Returns:
        np.ndarray : At-ecdysis values of shape ``(n_points, 5)``; NaN where the
            corresponding ecdysis event is NaN.
    """
    ecdysis_events = ["HatchTime", "M1", "M2", "M3", "M4"]
    columns_at_ecdysis = [f"{column}_at_{event}" for event in ecdysis_events]

    for col in columns_at_ecdysis:
        if col not in filemap.columns:
            filemap = filemap.with_columns(pl.lit(np.nan).alias(col))
    filemap = filemap.select(pl.col(["Point"] + columns_at_ecdysis))

    values_at_ecdysis = (
        filemap.group_by("Point", maintain_order=True)
        .agg(pl.col(columns_at_ecdysis).first())
        .drop("Point")
        .cast(pl.Float64)
        .to_numpy()
        .squeeze()
    )
    # squeeze() collapses a single point to 1-D; restore the point axis.
    if values_at_ecdysis.ndim == 1:
        values_at_ecdysis = values_at_ecdysis[np.newaxis, :]

    values_at_ecdysis[np.isnan(ecdysis_time_step)] = np.nan
    return values_at_ecdysis


def _get_death_and_arrest(filemap: pl.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """
    Extract per-point death time and arrest flag from the filemap.

    If ``"Death"`` or ``"Arrest"`` columns are absent, they are filled with NaN
    and False respectively. The first row of each point is used.

    Parameters:
        filemap (polars.DataFrame) : Condition filemap.

    Returns:
        tuple[np.ndarray, np.ndarray] :
            - death (np.ndarray) : Shape ``(n_points, 1)`` death times (float,
              NaN if alive).
            - arrest (np.ndarray) : Shape ``(n_points, 1)`` arrest flags (bool).
    """
    if "Death" not in filemap.columns:
        filemap = filemap.with_columns(pl.lit(np.nan).alias("Death"))
    if "Arrest" not in filemap.columns:
        filemap = filemap.with_columns(pl.lit(False).alias("Arrest"))
    filemap = filemap.select(pl.col(["Point", "Death", "Arrest"]))

    death_and_arrest = (
        filemap.group_by("Point", maintain_order=True)
        .agg(pl.col(["Death", "Arrest"]).first())
        .drop("Point")
        .to_numpy()
        .squeeze()
    )
    # squeeze() collapses a single point to 1-D; restore the point axis.
    if death_and_arrest.ndim == 1:
        death_and_arrest = death_and_arrest[np.newaxis, :]

    death = death_and_arrest[:, 0].astype(float)
    arrest = death_and_arrest[:, 1].astype(bool)
    return death[:, np.newaxis], arrest[:, np.newaxis]


def _compute_values_at_molt(
    condition_dict: dict,
    column: str,
    worm_types: np.ndarray,
    recompute_values_at_molt: bool = False,
) -> dict:
    """
    Fill missing (or recompute all) at-molt values for a column using the time series.

    Missing values are NaN entries in ``{column}_at_ecdysis`` whose ecdysis
    time is not NaN. ``compute_series_at_time_classified`` is called for them
    with the point's series, the ecdysis times, the time axis and the QC array.

    Time base: experiment time in hours when any experiment time is available;
    otherwise the raw ``Time`` values, paired with ``ecdysis_time_step`` so the
    ecdysis times are expressed in the same units as the time axis.

    Parameters:
        condition_dict (dict) : Condition dict containing the column, its
            at-ecdysis counterpart, time arrays, and ecdysis times.
        column (str) : Base column name.
        worm_types (np.ndarray) : QC classification array of shape
            ``(n_points, n_frames)`` passed to the interpolation.
        recompute_values_at_molt (bool) : If ``True``, recompute all at-molt
            values regardless of existing values. Defaults to ``False``.

    Returns:
        dict : The updated ``condition_dict`` with ``{column}_at_ecdysis`` filled in.
    """
    column_at_molt = f"{column}_at_ecdysis"
    values_at_molt = condition_dict[column_at_molt]
    updated_values_at_molt = values_at_molt.copy()

    if (~np.isnan(condition_dict["experiment_time_hours"])).any():
        time = condition_dict["experiment_time_hours"]
        ecdysis = condition_dict["ecdysis_experiment_time_hours"]
    else:
        # ecdysis_index holds array positions, which differ from Time values as
        # soon as rows are missing (e.g. ignored frames); use time steps instead.
        time = condition_dict["time"]
        ecdysis = condition_dict["ecdysis_time_step"]

    has_ecdysis = ~np.isnan(ecdysis)
    if recompute_values_at_molt:
        to_recompute = has_ecdysis
    else:
        to_recompute = np.isnan(values_at_molt) & has_ecdysis

    for i, row_mask in enumerate(to_recompute):
        event_indices = np.flatnonzero(row_mask)
        if event_indices.size == 0:
            continue
        updated_values_at_molt[i][event_indices] = compute_series_at_time_classified(
            condition_dict[column][i],
            ecdysis[i][event_indices],
            time[i],
            worm_types[i],
        )

    condition_dict[column_at_molt] = updated_values_at_molt
    return condition_dict
def separate_column_by_point(filemap: pl.DataFrame, column: str) -> np.ndarray:
    """
    Pivot a long-format filemap column into a 2-D array indexed by point.

    Each row of the output corresponds to one imaging point, in order of the
    point's first appearance in ``filemap``; columns correspond to time
    frames. Shorter points are right-padded with NaN (numeric) or ``"error"``
    (string) to the length of the longest point.

    Parameters:
        filemap (polars.DataFrame) : Filemap with at least ``"Point"`` and
            ``column`` columns.
        column (str) : Column to pivot.

    Returns:
        np.ndarray : Array of shape ``(n_points, max_n_frames)``; ``object``
            dtype for string columns, ``float64`` otherwise. An empty filemap
            yields an empty ``(0, 0)`` float array.
    """
    point_dataframes = filemap.select(pl.col("Point"), pl.col(column)).partition_by(
        "Point", maintain_order=True
    )
    if not point_dataframes:
        return np.full((0, 0), np.nan)

    # Pick the output dtype from the first non-null value so a leading null
    # does not turn a string column into a float array.
    non_null_values = filemap.get_column(column).drop_nulls()
    sample = non_null_values[0] if len(non_null_values) > 0 else None
    is_string = isinstance(sample, str) or (
        hasattr(sample, "dtype") and np.issubdtype(sample.dtype, np.str_)
    )

    n_points = len(point_dataframes)
    max_height = max(point_df.height for point_df in point_dataframes)
    if is_string:
        result = np.full((n_points, max_height), "error", dtype=object)
    else:
        result = np.full((n_points, max_height), np.nan)

    for i, point_df in enumerate(point_dataframes):
        # ravel (not squeeze) keeps single-frame points 1-D so len() works.
        point_column = point_df.select(pl.col(column)).to_numpy().ravel()
        result[i, : len(point_column)] = point_column
    return result
def remove_ignored_molts(filemap: pl.DataFrame) -> pl.DataFrame:
    """
    Set molt-time columns to null for time points flagged with ``Ignore=True``.

    For each ecdysis column (``"HatchTime"``, ``"M1"``–``"M4"``) present in the
    filemap, any molt time that coincides with an ignored ``(Point, Time)``
    pair is replaced with ``None``. Only rows where ``Ignore`` is true count as
    ignored. Ecdysis columns absent from the filemap are skipped.

    Parameters:
        filemap (polars.DataFrame) : Filemap with ecdysis columns and an
            optional ``"Ignore"`` column.

    Returns:
        polars.DataFrame : Updated filemap; unchanged if ``"Ignore"`` is absent
            or no row is flagged.
    """
    if "Ignore" not in filemap.columns:
        return filemap

    ignored_rows = filemap.filter(pl.col("Ignore"))
    ignored_pairs = set(
        zip(
            ignored_rows.select(pl.col("Point")).to_numpy().flatten(),
            ignored_rows.select(pl.col("Time")).to_numpy().flatten(),
        )
    )
    if not ignored_pairs:
        return filemap

    # Point is never modified below, so extract it once instead of per column.
    points = filemap.select(pl.col("Point")).to_numpy().flatten()
    for col in ("HatchTime", "M1", "M2", "M3", "M4"):
        if col not in filemap.columns:
            continue
        molt_times = filemap.select(pl.col(col)).to_numpy().flatten()
        mask = np.array(
            [
                molt_time is not None
                and not (isinstance(molt_time, float) and np.isnan(molt_time))
                and (point, molt_time) in ignored_pairs
                for point, molt_time in zip(points, molt_times)
            ],
            dtype=bool,
        )
        if mask.any():
            filemap = filemap.with_columns(
                pl.when(pl.Series(mask)).then(None).otherwise(pl.col(col)).alias(col)
            )
    return filemap
def remove_unwanted_info(conditions_info: list[dict]) -> list[dict]:
    """
    Strip bookkeeping keys from condition metadata before conditions are compared.

    ``"description"`` and ``"condition_id"`` are deleted, in place, from every
    dict that contains them; all other keys are left untouched.

    Parameters:
        conditions_info (list[dict]) : Condition metadata dicts (mutated in place).

    Returns:
        list[dict] : The same list object that was passed in.
    """
    unwanted_keys = ("description", "condition_id")
    for info in conditions_info:
        for unwanted in unwanted_keys:
            info.pop(unwanted, None)
    return conditions_info
[docs] def combine_experiments( filemap_paths: list[str], config_paths: list[str], experiment_dirs: list[str] | None = None, organ_channels: list[dict] | dict = [{"body": "ch2", "pharynx": "ch1"}], recompute_values_at_molt: bool = False, rescale_n_points: int = 100, ) -> list[dict]: """ Build and merge the plotting structure from multiple experiments. Each experiment is processed independently via ``build_plotting_struct``. Conditions with identical metadata (after removing ``description`` and ``condition_id``) are merged by concatenating their numpy arrays along axis 0. Condition IDs in the merged structure are reassigned sequentially. Parameters: filemap_paths (list[str]) : Paths to the CSV filemap files, one per experiment. config_paths (list[str]) : Paths to the conditions YAML files, one per experiment. experiment_dirs (list[str] or None) : Experiment directories; defaults to the parent directory of each filemap when ``None``. organ_channels (list[dict] or dict) : Organ-to-channel mapping(s). A single dict is broadcast to all experiments; otherwise one dict per experiment is expected. Defaults to ``[{"body": "ch2", "pharynx": "ch1"}]``. recompute_values_at_molt (bool) : Passed to ``build_plotting_struct``. Defaults to ``False``. rescale_n_points (int) : Passed to ``build_plotting_struct``. Defaults to ``100``. Returns: list[dict] : Merged conditions_struct with sequential condition IDs. Raises: ValueError : If the length of ``organ_channels`` is neither 1 nor equal to the number of experiments. """ all_conditions_struct = [] condition_info_merge_list = [] conditions_info_keys = set() condition_id_counter = 0 if isinstance(organ_channels, dict): organ_channels = [organ_channels] if len(organ_channels) == 1: organ_channels = organ_channels * len(filemap_paths) elif len(organ_channels) != len(filemap_paths): raise ValueError( "Number of organ channels must be equal to the number of experiments." 
) # Process each experiment for i, (filemap_path, config_path, organ_channel) in enumerate( zip(filemap_paths, config_paths, organ_channels) ): experiment_dir = ( experiment_dirs[i] if experiment_dirs else os.path.dirname(filemap_path) ) conditions_struct, conditions_info = build_plotting_struct( experiment_dir, filemap_path, config_path, organ_channels=organ_channel, recompute_values_at_molt=recompute_values_at_molt, rescale_n_points=rescale_n_points, ) # Process conditions for this experiment for condition in conditions_struct: condition["condition_id"] = condition_id_counter condition_id_counter += 1 all_conditions_struct.append(condition) # Process condition info experiment_conditions_info = remove_unwanted_info(conditions_info) condition_info_merge_list.extend(experiment_conditions_info) conditions_info_keys.update( *[condition.keys() for condition in experiment_conditions_info] ) # Merge conditions based on their info condition_dict = defaultdict(list) for i, condition_info in enumerate(condition_info_merge_list): key = frozenset(condition_info.items()) condition_dict[key].append(i) merged_conditions_struct = [] for indices in condition_dict.values(): base_condition = all_conditions_struct[indices[0]] for idx in indices[1:]: for key, value in all_conditions_struct[idx].items(): if key not in conditions_info_keys: if isinstance(value, np.ndarray): if key not in base_condition: base_condition[key] = value continue if value.shape[1] > base_condition[key].shape[1]: base_condition[key] = np.pad( base_condition[key], ( (0, 0), (0, value.shape[1] - base_condition[key].shape[1]), ), mode="constant", constant_values=np.nan, ) elif value.shape[1] < base_condition[key].shape[1]: value = np.pad( value, ( (0, 0), (0, base_condition[key].shape[1] - value.shape[1]), ), mode="constant", constant_values=np.nan, ) try: base_condition[key] = np.concatenate( (base_condition[key], value), axis=0 ) except ValueError as e: print(f"Could not concatenate {key}: {e}") 
merged_conditions_struct.append(base_condition) # # Sort and reassign condition IDs # merged_conditions_struct.sort(key=lambda x: x['condition_id']) for i, condition in enumerate(merged_conditions_struct): condition["condition_id"] = i return merged_conditions_struct