from __future__ import annotations
import logging
import pickle
from typing import Literal
from typing import TYPE_CHECKING
import cv2
import numpy as np
import skimage
from numpy.typing import NDArray
from scipy.interpolate import BSpline
from scipy.interpolate import splprep
from scipy.sparse import csgraph
from scipy.spatial import distance
from skimage.feature import hessian_matrix
from skimage.feature import hessian_matrix_eigvals
from skimage.util import img_as_ubyte
if TYPE_CHECKING:
from os import PathLike
from scipy import sparse
# Per-axis voxel spacing (plane axis first, then the two in-plane axes) as seen in
# Ti2 image metadata; used as the default by Warper.rescaled_3D_img
DEFAULT_SPACING = np.array([2.0, 0.325, 0.325])
# Shared OpenCV ridge filter used by _medial_axis_transform(ridge_detector="opencv").
# It is built once at import time, which requires cv2.ximgproc to be available.
RIDGE_DETECTOR: cv2.ximgproc.RidgeDetectionFilter = cv2.ximgproc.RidgeDetectionFilter.create()
class Warper:
    """Generate, store, and apply splines fitted to the midlines of oblong objects.

    The splines drive image/coordinate warping ("straightening"), e.g. of worms or
    worm pharynxes. One spline is fitted to each z-plane independently to account
    for movement during stack acquisition.
    """

    def __init__(self, length: float, width: float, splines: list[BSpline | None]):
        """
        Store the warp geometry and per-plane splines.

        Parameters:
            length (float): Length of the straightened object (in pixels along the
                midline axis).
            width (float): Width of the straightened object (in pixels across the
                midline axis).
            splines (list[BSpline or None]): Fitted spline for each z-plane.
                ``None`` entries correspond to empty or unprocessable planes.
        """
        self.length = length
        self.width = width
        self.splines = splines

    @classmethod
    def from_img(cls, img: NDArray, mask: NDArray[np.bool_]) -> Warper:
        """
        Construct a Warper by extracting midlines from a mask and fitting splines.

        Extracts the midline for each plane of ``mask``, ensures consistent
        head/tail orientation across planes, fits splines, and aligns splines
        using image-based registration.

        Parameters:
            img (NDArray): Grayscale image (2D or 3D stack) used for spline
                alignment across planes.
            mask (NDArray[bool]): Binary mask with the same shape as ``img``.

        Returns:
            Warper: A Warper instance ready for image warping.

        Raises:
            ValueError: If ``img`` and ``mask`` shapes do not match, if the
                mask is invalid (see :func:`validate_mask`), or if no spline could
                be fitted to any plane.
        """
        if img.shape != mask.shape:
            raise ValueError(
                f"Image and Mask shapes don't match: img.shape = {img.shape}, mask.shape = {mask.shape}"
            )
        if img.ndim == 2:
            img = img[np.newaxis, ...]
            mask = mask[np.newaxis, ...]
        validate_mask(mask)
        mask = mask.astype(bool)

        # per-plane midlines and distance transforms; None for empty planes
        mls = []
        dts = []
        for mask_plane in mask:
            if not mask_plane.any():
                mls.append(None)
                dts.append(None)
                continue
            # NOTE(review): extract_midline is not defined in the visible part of this
            # module; confirm it is provided elsewhere.
            ml, dt = extract_midline(mask_plane, return_dt=True)
            mls.append(ml)
            dts.append(dt)
        mls = _handle_flips(mls)

        splines = []
        spline_lengths = []
        # spline parameter "u" a la "uv-mapping"
        spline_us = []
        for ml in mls:
            refined = None
            # need at least 4 points for (cubic) spline fitting
            if ml is not None and len(ml) >= 4:
                # raw spline only smooths and does not account for inter-pixel distances
                s = len(ml) - np.sqrt(2 * len(ml))
                raw_spline = _fit_spline(ml, parametrisation=None, smoothing_factor=s)
                if raw_spline is not None:
                    # refined spline accounts for inter-pixel distances
                    refined = _refine_spline(raw_spline)
            # fitting can fail at either stage; treat such planes like empty ones so
            # splines / spline_lengths / spline_us stay index-aligned with the planes
            if refined is None or refined[0] is None:
                refined = (None, None, None)
            refined_spline, spline_length, refined_us = refined
            splines.append(refined_spline)
            spline_lengths.append(spline_length)
            spline_us.append(refined_us)

        fitted_lengths = [length for length in spline_lengths if length is not None]
        if not fitted_lengths:
            raise ValueError("No spline could be fitted to any plane of the mask")
        # dt.max() is a radius, want diameter
        # TODO: is mean sufficient? ideally it's max, but occasionally poor masks REALLY inflate max of dts
        # I sometimes see clipping of elements wider than the mean (especially pharynx bulbs), so I multiply it by 1.2
        worm_width = (np.mean([dt.max() for dt in dts if dt is not None]) * 1.2) * 2
        worm_length = max(fitted_lengths)

        # if provided image is 2D, no alignment is necessary
        if len(splines) == 1:
            return cls(worm_length, worm_width, splines)

        # if 3D, align only the planes that have a spline
        not_missing = [spline is not None for spline in splines]
        aligned_splines, worm_length = _align_splines(
            img[not_missing],
            worm_width,
            worm_length,
            [spline for spline in splines if spline is not None],
            [us for us in spline_us if us is not None],
        )
        # scatter the aligned splines back to their plane positions
        aligned_iter = iter(aligned_splines)
        splines = [next(aligned_iter) if is_ok else None for is_ok in not_missing]
        missing_planes = [i for i, is_ok in enumerate(not_missing) if not is_ok]
        if missing_planes:
            logging.warning(
                "Masks for planes %s were empty or no splines could be fit. Warped images will be blank",
                missing_planes,
            )
        return cls(worm_length, worm_width, splines)

    def to_pickle(self, file: PathLike):
        """Save this object instance to a pickle file."""
        with open(file, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_pickle(cls, file: PathLike) -> Warper:
        """Load a Warper object from a pickle file.

        WARNING: do not load untrusted pickle files as they can execute arbitrary code.
        """
        with open(file, "rb") as f:
            return pickle.load(f)

    def warp_2D_img(
        self,
        img2D: NDArray,
        spline_i: int,
        scale_factor: float | tuple[float, float] = 1,
        mirror: bool = False,
        interpolation_order: Literal[0, 1, 2, 3, 4, 5] = 1,
        preserve_range: bool = True,
        preserve_dtype: bool = True,
    ) -> NDArray:
        """Warp (straighten) a single 2D plane using ``self.splines[spline_i]``.

        ``self.splines[i]`` is the spline fitted to ``mask[i]`` of the mask provided
        at construction, so ``spline_i`` should usually match the z index of the
        chosen plane of the z-stack.

        Parameters:
            img2D (NDArray): 2D image plane to warp.
            spline_i (int): Index of the spline to use.
            scale_factor (float or tuple[float, float]): How many pixels in the
                straightened image correspond to one pixel in the raw image.
            mirror (bool): Whether to flip the axis perpendicular to the spline
                length. The original MATLAB code flipped; practically, with
                ``mirror=True`` the left-right axis of a worm as seen in the image is
                flipped.
            interpolation_order (int): 0 nearest-neighbour (e.g. masks/labels),
                1 bi-linear (default), 2 bi-quadratic, 3 bi-cubic, 4 bi-quartic,
                5 bi-quintic.
            preserve_range (bool): Images are converted to float for interpolation.
                If True the value range of the input is kept; if False it is
                converted to [0, 1] following image conversion conventions.
            preserve_dtype (bool): If True, cast the output to the input's dtype;
                otherwise return the float interpolation result.

        Returns:
            NDArray: The warped plane (blank if that plane has no spline).
        """
        return _warp_2D_img(
            img2D,
            self.splines[spline_i],
            self.width,
            self.length,
            scale_factor,
            mirror,
            interpolation_order,
            preserve_range,
            preserve_dtype,
        )

    def warp_3D_img(
        self,
        img3D: NDArray,
        scale_factor: float | tuple[float, float] = 1,
        mirror: bool = False,
        interpolation_order: Literal[0, 1, 2, 3, 4, 5] = 1,
        preserve_range: bool = True,
        preserve_dtype: bool = True,
    ) -> NDArray:
        """Warp a full 3D image, plane ``i`` with ``self.splines[i]``.

        Parameters are as for :meth:`warp_2D_img`.

        Raises:
            ValueError: If the number of planes differs from the number of stored splines.
        """
        if len(self.splines) != len(img3D):
            raise ValueError(
                f"Incompatible number of planes in 3D image: {len(img3D)} planes and {len(self.splines)} stored splines"
            )
        return _warp_3D_img(
            img3D,
            self.splines,
            self.width,
            self.length,
            scale_factor,
            mirror,
            interpolation_order,
            preserve_range,
            preserve_dtype,
        )

    def rescaled_3D_img(
        self,
        img3D: NDArray,
        scale_factor: float = 1,
        spacing: tuple[float, float, float] = DEFAULT_SPACING,
        normalise_spacing: bool = True,
        mirror: bool = False,
        interpolation_order: Literal[0, 1, 2, 3, 4, 5] = 1,
        preserve_range: bool = True,
        preserve_dtype: bool = True,
        return_final_spacing: bool = False,
    ) -> NDArray | tuple[NDArray, NDArray]:
        """Warp a raw 3D image and rescale it so the resulting spacing is equal along every axis.

        ``spacing`` is set by the microscope acquisition settings (how moving by one
        pixel along each axis relates to physical units) and can be found in the raw
        image metadata. The default (2.0, 0.325, 0.325) is specific to an experiment
        type, so check it for your data.

        If ``normalise_spacing`` is True, scaling factors are
        ``final_spacing / final_spacing.min()``. This keeps the pixel size of the
        finest (post-warp) axis, so the returned final spacing is that minimum along
        every axis; typically ``spacing.min() / scale_factor``.

        If ``normalise_spacing`` is False, scaling factors are the post-warp spacing
        itself. Each output voxel is then one physical unit (e.g. micrometre) along
        every axis, and the returned final spacing is 1 along every axis.
        NOTE(review): earlier docs claimed 1/scale_factor here; confirm the
        intended behaviour.

        Other parameters are as for :meth:`warp_2D_img`.

        Returns:
            NDArray, or ``(NDArray, final_spacing)`` if ``return_final_spacing``.
        """
        # float dtype so the in-place division below also works for integer spacings
        spacing = np.array(spacing, dtype=float)
        final_spacing = spacing.copy()
        # warping resamples only the in-plane axes
        final_spacing[1:] /= scale_factor
        # don't preserve dtype to keep float as image will be rescaled in next step,
        # and excessive dtype conversions would introduce rounding errors
        warped_img = self.warp_3D_img(
            img3D,
            scale_factor,
            mirror,
            interpolation_order,
            preserve_range,
            preserve_dtype=False,
        )
        if normalise_spacing:
            scale = final_spacing / final_spacing.min()
        else:
            scale = final_spacing
        rescaled = skimage.transform.rescale(
            warped_img,
            scale,
            interpolation_order,
            preserve_range=preserve_range,
        )
        if preserve_dtype:
            rescaled = rescaled.astype(img3D.dtype)
        if return_final_spacing:
            return rescaled, final_spacing / scale
        return rescaled
def validate_mask(mask: NDArray) -> None:
    """
    Check that ``mask`` can be used for computational straightening.

    The provided mask should be:
    - 2D or 3D
    - not empty
    - contain at most a single object (4-connected component) on each 2D plane
      (individual empty planes are allowed).

    Any non-zero value counts as foreground.

    Raises:
        ValueError: If the mask has the wrong dimensionality, is entirely empty,
            or any plane contains more than one connected component.
    """
    mask = np.asarray(mask)
    if mask.ndim < 2:
        raise ValueError(f"Provided mask is {mask.ndim}D, but should either be 2D or 3D")
    if mask.ndim > 3:
        raise ValueError(
            f"Provided mask has {mask.ndim} dimensions, but should either be 2D or 3D"
        )
    if mask.ndim == 2:
        mask = mask[np.newaxis, ...]
    # binarise explicitly so the check works for any dtype
    # (img_as_ubyte rejects float masks outside [-1, 1])
    foreground = mask != 0
    if not foreground.any():
        raise ValueError("Mask is empty so cannot be computationally straightened")
    for i, plane in enumerate(foreground):
        num_labels, _ = cv2.connectedComponents(plane.astype(np.uint8), connectivity=4)
        # label 0 is the background, so more than 2 labels means more than one object
        if num_labels > 2:
            raise ValueError(
                f"2D Plane {i} of provided mask contains more than 1 connected component"
            )
####################
# midline extraction
####################
def _detect_ridges(gray: NDArray, sigma: float = 1) -> tuple[NDArray, NDArray]:
    """Return the two Hessian eigenvalue images of ``gray``, largest eigenvalue first.

    Finds ridge points in a grayscale image. Adapted from
    https://stackoverflow.com/questions/48727914/how-to-use-ridge-detection-filter-in-opencv
    """
    hessian = hessian_matrix(
        gray,
        sigma=sigma,
        order="rc",
        use_gaussian_derivatives=False,
    )
    # hessian_matrix_eigvals returns eigenvalues in decreasing order
    larger_eig, smaller_eig = hessian_matrix_eigvals(hessian)
    return larger_eig, smaller_eig
def _medial_axis_transform(
mask2D: NDArray[np.bool_],
percentile: float = 90,
ridge_detector: str = "scikit",
return_distance_transform: bool = False,
) -> NDArray[np.bool_] | tuple[NDArray[np.bool_], NDArray[np.float_]]:
"""medial axis transform of a mask
MAT calculated by finding pixels where the gradient of the mask's distance transform is discontinuous and thinning the resultant mask to 1 pixel width
percentile (0<percentile<100) controls sensitivity for discontinuity detection by defining a percentile threshold, above which values are kept
e.g. 100-90=10 -> top 10% left over
algorithm from https://doi.org/10.1016/j.apm.2011.05.001
only works for single component 2D masks"""
if mask2D.ndim != 2:
raise ValueError(
f"Mask must be a 2D array. Provided mask dimensionality is {mask2D.ndim}"
)
mask2D = mask2D.astype(bool)
mask2D = img_as_ubyte(mask2D)
distance_transform = cv2.distanceTransform(
mask2D, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
)
if ridge_detector == "scikit":
ridge_response, _ = _detect_ridges(distance_transform * -1)
elif ridge_detector == "opencv":
ridge_response = RIDGE_DETECTOR.getRidgeFilteredImage(distance_transform * -1)
# get top_percentile of non-zero ridges
thresh_ridges = ridge_response > np.percentile(
ridge_response[ridge_response > 0], percentile
)
# thresh_ridges has a lot of noise, so extract the largest component to filter away noise
main_ridge = _largest_mask_component(thresh_ridges)
# thin to one pixel thickness
thinned = cv2.ximgproc.thinning(
img_as_ubyte(main_ridge),
thinningType=cv2.ximgproc.THINNING_GUOHALL,
)
if return_distance_transform:
return thinned, distance_transform
else:
return thinned
def _largest_mask_component(
mask: NDArray[np.bool_], connectivity: Literal[4, 8] = 8
) -> NDArray[np.bool_]:
"""return a new mask of the largest component in mask
connectivity 4: cross-shaped connectivity, i.e. no diagonals
connectivity 8: square-shaped connectivity, i.e. diagonals included"""
num_labels, labels = cv2.connectedComponents(
img_as_ubyte(mask), connectivity=connectivity
)
unique_labels, counts = np.unique(labels, return_counts=True)
largest_label = unique_labels[np.argmax(counts[1:]) + 1]
return labels == largest_label
def _extract_midline(
midline_image: NDArray[np.bool_],
distance_transform: NDArray[np.float_],
combined_coord_array: bool = False,
) -> list[NDArray[np.int_]] | list[tuple[NDArray[np.int_], NDArray[np.int_]]]:
"""
returns a topologically sorted pixel path of the midline
a midline is the longest pixel path of midline image the extended to the boundary
conceptually, it can be the longest axis of an oblong object
combined_coord_array:
False - np.nonzero style indices, a tuple of 2 (M,) arrays is returned for each midline; intended for indexing into skeleton array
True - np.argwhere style coords, a single (M, 2) array is returned for each midline
"""
mask = img_as_ubyte(distance_transform > 0)
if not np.any(mask):
raise ValueError("Skeleton is empty")
# start from 1 as not interested in background
radius = distance_transform.max()
midline = _main_midline_path(
midline_image, connectivity=2, pixel_trim=int(radius * 0.5)
)
midline = _extend_midline_to_boundary(midline, distance_transform)
if combined_coord_array:
midline = np.c_[midline]
return midline
def _main_midline_path(
    skeleton_img: NDArray[np.bool_],
    connectivity: Literal[1, 2] = 2,
    pixel_trim: int = 0,
) -> tuple[NDArray[np.intp], NDArray[np.intp]]:
    """Topologically sorted main-path pixels of a skeleton, trimmed at each end by ``pixel_trim``.

    connectivity 1: + shaped neighbour connectivity; i.e. diagonals excluded
    connectivity 2: square shaped neighbour connectivity; i.e. diagonals included

    Returns ``np.nonzero``-style ``(rows, cols)`` index arrays, which are empty if
    the skeleton has no edges.
    """
    graph, nodes = skimage.graph.pixel_graph(
        skeleton_img.astype(bool), connectivity=connectivity
    )
    # integer dtype so an empty path can still index `nodes`
    path = np.asarray(_longest_path(graph), dtype=np.intp)
    # convert before comparing: a fractional trim in (0, 1) would otherwise become
    # path[0:-0], which is empty
    pixel_trim = int(pixel_trim)
    # only trim when something would remain
    if 0 < pixel_trim and 2 * pixel_trim < len(path):
        path = path[pixel_trim:-pixel_trim]
    return np.unravel_index(nodes[path], shape=skeleton_img.shape)
def _longest_path(graph: sparse.csr_matrix) -> NDArray[np.int_]:
"""returns node order of the longest bfs path in graph"""
if graph.getnnz() == 0:
return np.array([])
bfs_visit_order = csgraph.breadth_first_order(
graph, i_start=0, directed=False, return_predecessors=False
)
furthest_node1 = bfs_visit_order[-1]
bfs_visit_order, predecessors = csgraph.breadth_first_order(
graph, i_start=furthest_node1, directed=False, return_predecessors=True
)
furthest_node2 = bfs_visit_order[-1]
path = _bfs_path(predecessors, furthest_node2)
return path
def _bfs_path(predecessors: NDArray[np.int_], node_j: int) -> NDArray[np.int_]:
"""given bfs predecessors and node_j, find path connecting bfs tree's root->node_j"""
path = []
current_node = node_j
while current_node != -9999:
path.append(current_node)
current_node = predecessors[current_node]
# reverse as path is node_j -> root, but bfs would have been root -> node_j
path.reverse()
return np.array(path)
###############
# tip extension
###############
def _extend_midline_to_boundary(
midline: tuple(NDArray[np.int_], NDArray[np.int_]),
distance_transform: NDArray[np.float_],
) -> NDArray[np.int_]:
"""
adds two points to midline on the distance_transform background boundary
independently for each tip, achieved by projecting a diametric ray, finding where the ray intersects the boundary,
and finding midpoint of the path along boundary connecting the two intersections
the midpoints are then added to respective ends of midline
"""
border_radius = distance_transform.max()
# border consists of pixels that have distance 1 or sqrt(2) to the background
# border then effectively has connectivity 2 (or cv2 8)
border = np.logical_or(
np.isclose(distance_transform, 1), np.isclose(distance_transform, np.sqrt(2))
)
# cv2 connectivity 4 is same as skimage connectivity 1; connectivity of 1 pixel distance i.e cross pattern
# extract the largest component because border may have 'lonely' bits
border = _largest_mask_component(border, connectivity=4)
graph, nodes = skimage.graph.pixel_graph(border.astype(bool), connectivity=1)
# function with preapplied arguments
def tip_handler(tip, samples):
return _handle_tip(border, border_radius, graph, nodes, tip, samples)
midline = np.c_[midline]
# if calculate finite derivative from just two points at each tip, would get either vertical, horizontal, or 45deg derivative
# so calculate finite derivative from the last n points at each midline tip
# cap num_samples to midline size in case midline is too short
num_samples = np.min([border_radius, len(midline)])
num_samples = np.round(num_samples).astype(int)
# get deriv for numsamples from each end of midline
sample1, sample2 = midline[1:num_samples], midline[-num_samples:-1]
tip1, tip2 = midline[[0, -1]]
new_tip1, new_tip2 = tip_handler(tip1, sample1), tip_handler(tip2, sample2)
new_midline = np.vstack(
[points for points in (new_tip1, midline, new_tip2) if points is not None]
)
return new_midline
def _handle_tip(
    border: NDArray[np.bool_],
    border_radius: float,
    border_graph: sparse.csr_matrix,
    graph_nodes: NDArray[np.int_],
    origin: tuple[int, int],
    point_samples: NDArray[np.int_],
) -> tuple[int, int] | None:
    """Find a new tip point on the border beyond ``origin``.

    Steps:
    1. Estimate the midline direction at ``origin`` from ``point_samples``.
    2. Project a diameter perpendicular to that direction, centred at ``origin``.
    3. Find where each half of the diameter first hits ``border``.
    4. BFS the border graph for a path connecting the two hits.
    5. Return the path's midpoint as ``(row, col)``.

    ``graph_nodes`` are the raveled indices of the border pixels, i.e. the node
    indices of ``border_graph``.

    Returns None when a ray does not hit the border (e.g. the mask reaches the
    image edge), or when the path traverses more than half the border.
    """
    parallel_vector = _mean_unit_vector(origin, point_samples)
    normal_vector = np.array([-1, 1]) * parallel_vector[::-1]
    # extend the radius so it crosses the border; otherwise it might barely miss it.
    # The two rays in +normal and -normal directions form a diameter centred at origin.
    ray_length = border_radius * 1.25
    start_node = _first_border_node(border, graph_nodes, origin, normal_vector, ray_length)
    if start_node is None:
        return None
    end_node = _first_border_node(border, graph_nodes, origin, -normal_vector, ray_length)
    if end_node is None:
        return None
    _, predecessors = csgraph.breadth_first_order(
        border_graph, start_node, directed=False, return_predecessors=True
    )
    path = _bfs_path(predecessors, end_node)
    # if path traverses more than half the perimeter of border, new tip point is not going to be good
    if len(path) > np.count_nonzero(border) / 2:
        return None
    tip_node = path[len(path) // 2]
    return np.unravel_index(graph_nodes[tip_node], shape=border.shape)


def _first_border_node(
    border: NDArray[np.bool_],
    graph_nodes: NDArray[np.int_],
    origin: tuple[int, int],
    direction: NDArray[np.float64],
    ray_length: float,
) -> int | None:
    """Index into ``graph_nodes`` of the first border pixel hit by a ray from ``origin``, or None."""
    rows, cols = _generate_ray_points(origin, direction, ray_length, border.shape)
    hits = np.flatnonzero(border[rows, cols])
    if hits.size == 0:
        return None
    first_hit = hits[0]
    # C-order raveled index, the same convention np.unravel_index uses to decode nodes
    flat_index = np.ravel_multi_index((rows[first_hit], cols[first_hit]), border.shape)
    # the hit may not be a graph node, e.g. if it lies on a discarded border fragment
    matches = np.flatnonzero(np.asarray(graph_nodes) == flat_index)
    if matches.size == 0:
        return None
    return matches[0]
def _generate_ray_points(
    origin: tuple[int, int],
    unit_vector: tuple[float, float],
    ray_length: float,
    clipping_shape: tuple[int, int] | None = None,
) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
    """Integer (rows, cols) of the pixels crossed by a ray.

    The ray starts at ``origin`` and runs ``ray_length`` along ``unit_vector``.
    When ``clipping_shape`` is given, pixels outside a grid of that shape are dropped.
    """
    direction = np.array(unit_vector)
    r0, c0 = np.round(origin).astype(int).tolist()
    r1, c1 = np.round(origin + direction * ray_length).astype(int).tolist()
    rows, cols = skimage.draw.line(r0, c0, r1, c1)
    if clipping_shape is None:
        return rows, cols
    height, width = clipping_shape
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    return rows[inside], cols[inside]
def _mean_unit_vector(
main_point: tuple[int, int], point_samples: NDArray[np.int_]
) -> tuple[float, float]:
"""calculates mean unit vector of point_samples -> main_point"""
finite_derivs = point_samples - main_point
# normalise finite_derivs so avoid bias towards points far away from main_point
finite_derivs = finite_derivs / np.linalg.norm(finite_derivs, axis=1)[:, np.newaxis]
mean_deriv = np.mean(finite_derivs, axis=0)
# mean not necessarily unit length, so normalise again
mean_deriv = mean_deriv / np.linalg.norm(mean_deriv)
return mean_deriv
###############################
# spline and coord manipulation
###############################
# There are three sets of coordinates:
# raw_img xy-coordinates
# straightened_img x'y'-coordinates
# and intermediate uv spline coordinates
# spline uv coordinates are just x'y' coordinates whose origin is shifted down to be vertically centred, instead of being in the top-left corner
# the basic idea of the transform is to convert uv coordinates to xy coordinates in the raw image
# u - how far to walk along spline (always +ve)
# v - how far to walk perpendicularly to spline. (+ve and -ve values correspond to which orthogonal to follow)
# so xy = spline(u) + v * orthog
def _from_spline_coords_to_raw(
spline: BSpline, spline_coords: NDArray[np.float_], mirror: bool = False
) -> NDArray[np.float_]:
"""convert from coordinates in spline domain to coordinates in spline range
spline_coords col0 is axis perpendicular to spline, col1 is axis parallel to spline
"""
# left or right orthogonal
if mirror:
# right orthogonal
orthog = np.array([-1, 1])
else:
# left orthogonal
orthog = np.array([1, -1])
# tuple e.g. if np.nonzero() is used on straightened 2D image
if isinstance(spline_coords, tuple):
widths, lengths = spline_coords
# otherwise will be numpy (N, 2) numpy array, e.g. if np.argwhere() is used on straightened 2D image
else:
widths, lengths = spline_coords.T
origins = spline(lengths)
# parallel vectors
derivs = spline.derivative(1)(lengths)
# orthogonal vectors
normals = np.roll(derivs, shift=1, axis=1) * orthog
# normalised orthogonal vectors
normals = normals / np.linalg.norm(normals, axis=1)[..., np.newaxis]
img_yxs = origins + normals * widths[..., np.newaxis]
return img_yxs
def _from_grid_coords_to_spline(
grid_coords: NDArray[np.int_],
length: float,
width: float,
scale_factor: float | tuple[float, float] = np.array([1, 1]),
) -> NDArray[np.float_]:
"""coordinates are scaled according to scale_factor, and translated such that spline's length axis is place in middle of grid axis0
scale_factor: how many units on grid equal one unit in spline domain"""
spline_coords = grid_coords / scale_factor
# spline widths are centred around 0, grid widths are centred around worm_width / 2
offset = np.array([width / 2, 0])
spline_coords = spline_coords - offset
return spline_coords
def _fit_spline(
points: NDArray[np.float_],
parametrisation: list[float],
smoothing_factor: float,
) -> BSpline:
"""fit a spline to the points and return a scipy BSpline object"""
# number of points needs to be greater than the degree of spline, which is 3
if len(points) < 4:
return None
ys, xs = points.T
try:
tck, u = splprep([ys, xs], u=parametrisation, s=smoothing_factor)
except Exception as e:
print(
"Something went wrong in spline fitting, returning None and resuming execution."
)
import traceback
print("Error traceback:")
traceback.print_exception(type(e), e, e.__traceback__)
return None
t, c, k = tck
# splprep returns "c" in wrong shape, unlike splrep
c = np.asarray(c).T
# splprep values guaranteed to be correct, so construct fast without checks
spline = BSpline.construct_fast(t, c, k)
return spline
def _warped_coords(
    width: float,
    length: float,
    spline: BSpline,
    scale_factor: float | tuple[float, float] = 1,
    mirror: bool = False,
) -> NDArray[np.float_]:
    """Build the inverse coordinate map used to straighten an image along ``spline``.

    Returns a [2, M, N] array, where M, N are the warped image dimensions and axis 0
    holds the (row, col) float coordinates into the original image. The
    coordinates are floats because values are interpolated between integer pixels
    of the original image.
    """

    def grid_to_raw(grid_indices: NDArray[np.int_]) -> NDArray[np.float_]:
        # current implementation of skimage.transform.warp_coords has a bug where row,col are swapped relative to what documentation says
        # so have to reverse axis=1 so that the internal code calls it's own dependencies correctly
        # but have to reverse axis=0 afterwards because the internal code assumes row,col when reshaping, when actually it called dependencies with col,row
        as_spline = _from_grid_coords_to_spline(
            grid_indices[:, ::-1], length, width, scale_factor
        )
        return _from_spline_coords_to_raw(spline, as_spline, mirror=mirror)

    output_shape = _warped_shape(width, length, scale_factor)
    # TODO: remove axis reversal when bug is fixed
    coords = skimage.transform.warp_coords(coord_map=grid_to_raw, shape=output_shape)
    return coords[::-1, ...]
def _refine_spline(
smooth_spline: BSpline,
extrapolation_range: float = 0.01,
) -> tuple[BSpline, float, list[float]]:
"""given a smooth spline, refine it such that 1unit in spline domain equals 1unit in spline range.
slightly extrapolates beyond ends of spline by extrapolation_range (as fraction of total length) to be sure to cover full real length.
does not apply further smoothing"""
if smooth_spline is None:
return None
# refines spline by fitting on parametrisation that respects inter-pixel distance
points = smooth_spline(
np.linspace(-extrapolation_range, 1 + extrapolation_range, 1000)
)
bb_dists = np.linalg.norm(points[:-1] - points[1:], axis=1)
parametrisation = np.r_[0, np.cumsum(bb_dists)]
spline_length = parametrisation[-1]
# don't smooth because points already smooth
refined_spline = _fit_spline(points, parametrisation, smoothing_factor=0)
return refined_spline, spline_length, parametrisation
def _align_splines(
    img3D: NDArray,
    worm_width: float,
    worm_length: float,
    splines: list[BSpline],
    spline_parametrisations: list[list[float]],
) -> tuple[list[BSpline], float]:
    """Align per-plane splines along their length using the images they straighten.

    Every unaligned spline domain starts at 0. After alignment, each spline's domain
    starts wherever is needed for the warped image planes to line up with each
    other. Alignment across the width is implied by accurate midline generation in
    the original spline fitting.

    Returns the refitted splines (None where refitting failed) and the length of
    the union of all aligned spline domains.
    """
    # straighten every plane with its own, still unaligned, spline
    straightened = _warp_3D_img(img3D, splines, worm_width, worm_length)
    # vertical-edge response per plane, averaged over the width axis -> one 1D
    # profile along the length per plane
    edge_planes = np.stack(
        [skimage.filters.farid_v(plane) for plane in straightened], axis=0
    )
    profiles = edge_planes.mean(axis=1)
    # shift of each plane's profile relative to the previous plane's
    pairwise_shifts = []
    for previous, current in zip(profiles[:-1], profiles[1:]):
        shift = skimage.registration.phase_cross_correlation(
            previous, current, upsample_factor=20, normalization=None
        )[0]
        pairwise_shifts.append(shift)
    # plane 0 is the reference and has an offset of 0 to itself
    cumulative_shifts = np.cumsum(np.concatenate([[0], *pairwise_shifts]))
    shifted = [
        prm + offset for prm, offset in zip(spline_parametrisations, cumulative_shifts)
    ]
    # move everything so no spline has a negative parametrisation
    earliest_start = min(prm[0] for prm in shifted)
    shifted = [prm - earliest_start for prm in shifted]
    # resample each spline at its old parameters and refit on the shifted ones
    aligned = []
    for spline, old_prm, new_prm in zip(splines, spline_parametrisations, shifted):
        aligned.append(_fit_spline(spline(old_prm), new_prm, smoothing_factor=0))
    # the union of domains starts at 0, so its length is the largest parameter reached
    union_length = max(prm[-1] for prm in shifted)
    return aligned, union_length
def _handle_flips(midlines: list[NDArray | None]) -> list[NDArray | None]:
"""
Ensure consistent head/tail orientation across a list of midlines.
Iterates over the midlines in order, comparing each midline's endpoints to
the previous one. If the endpoints suggest a head/tail swap (i.e. the
flipped assignment has a smaller total Euclidean distance), the midline is
reversed before being added to the output.
Parameters:
midlines (list[np.ndarray or None]): Ordered list of midline coordinate
arrays (each of shape ``(N, 2)``). ``None`` entries for empty planes
are passed through unchanged.
Returns:
list[np.ndarray or None]: Midlines with consistent orientation.
"""
# initialise
for ml in midlines:
if ml is not None:
current_tips = ml[[0, -1]]
break
final_midlines = []
for ml in midlines:
if ml is None:
final_midlines.append(None)
continue
next_tips = ml[[0, -1]]
if _are_tips_flipped(current_tips, next_tips):
new_ml = np.flip(ml, axis=0)
else:
new_ml = ml
final_midlines.append(new_ml)
current_tips = new_ml[[0, -1]]
return final_midlines
def _are_tips_flipped(tip_pair1: NDArray[np.int_], tip_pair2: NDArray[np.int_]) -> bool:
"""Should tip assignment be flipped based on euclidian distance penalty"""
# pair_wise_dists[i, j] distance between tip_pair1[i] and tip_pair2[j]
pair_wise_dists = distance.cdist(tip_pair1, tip_pair2, metric="euclidean")
unflipped_dist = np.sum(pair_wise_dists[[0, 1], [0, 1]])
flipped_dist = np.sum(pair_wise_dists[[0, 1], [1, 0]])
return flipped_dist < unflipped_dist
###############
# image warping
###############
def _warp_2D_img(
    img2D: NDArray,
    spline: BSpline,
    width: float,
    length: float,
    scale_factor: float | tuple[float, float],
    mirror: bool = False,
    interpolation_order: Literal[0, 1, 2, 3, 4, 5] = 1,
    preserve_range: bool = False,
    preserve_dtype: bool = False,
) -> NDArray:
    """
    Warp a single 2D image plane using spline-based coordinate mapping.

    Computes the inverse coordinate map via :func:`_warped_coords` and applies
    ``skimage.transform.warp``. Returns a zero-filled array when ``spline`` is
    ``None``. That array is cast to ``img2D.dtype`` when ``preserve_dtype`` is
    set, so blank planes do not upcast a stack of warped planes.

    Parameters:
        img2D (NDArray): 2D input image plane.
        spline (BSpline or None): Fitted spline defining the midline mapping.
        width (float): Output image width (pixels across the midline).
        length (float): Output image length (pixels along the midline).
        scale_factor (float or tuple[float, float]): Scale factor(s) applied to
            the output grid.
        mirror (bool, optional): If ``True``, use the opposite normal, flipping
            the axis perpendicular to the midline. (default: False)
        interpolation_order (int, optional): Spline interpolation order for
            ``skimage.transform.warp`` (0–5). (default: 1)
        preserve_range (bool, optional): Passed to ``skimage.transform.warp``.
            (default: False)
        preserve_dtype (bool, optional): If ``True``, cast the output back to
            ``img2D.dtype``. (default: False)

    Returns:
        NDArray: Warped 2D image of shape
        ``(ceil(width * scale), ceil(length * scale))``.
    """
    if spline is None:
        blank_dtype = img2D.dtype if preserve_dtype else np.float64
        return np.zeros(_warped_shape(width, length, scale_factor), dtype=blank_dtype)
    warped_coords = _warped_coords(width, length, spline, scale_factor, mirror)
    warped_img = skimage.transform.warp(
        img2D,
        warped_coords,
        order=interpolation_order,
        preserve_range=preserve_range,
    )
    if preserve_dtype:
        warped_img = warped_img.astype(img2D.dtype)
    return warped_img
def _warp_3D_img(
img3D: NDArray,
splines: BSpline,
width: float,
length: float,
scale_factor: float | tuple[float, float] = 1,
mirror: bool = False,
interpolation_order: Literal[0, 1, 2, 3, 4, 5] = 1,
preserve_range: bool = False,
preserve_dtype: bool = False,
) -> NDArray:
"""
Apply :func:`_warp_2D_img` to each plane of a 3D image stack.
Parameters:
img3D (NDArray): 3D input image stack of shape ``(N, H, W)``.
splines (list[BSpline or None]): One spline per plane; must have the same
length as ``img3D``.
width (float): Output image width (pixels across the midline).
length (float): Output image length (pixels along the midline).
scale_factor (float or tuple[float, float], optional): Scale factor(s)
applied to the output grid. (default: 1)
mirror (bool, optional): If ``True``, flip the midline direction.
(default: False)
interpolation_order (int, optional): Spline interpolation order (0–5).
(default: 1)
preserve_range (bool, optional): Passed to ``skimage.transform.warp``.
(default: False)
preserve_dtype (bool, optional): If ``True``, cast each output plane back
to the input plane's dtype. (default: False)
Returns:
NDArray: Warped stack of shape ``(N, width * scale, length * scale)``.
"""
warped_img = [
_warp_2D_img(
plane,
spline,
width,
length,
scale_factor,
mirror,
interpolation_order,
preserve_range,
preserve_dtype,
)
for plane, spline in zip(img3D, splines)
]
warped_img = np.stack(warped_img, axis=0)
return warped_img
def _warped_shape(
width: float,
length: float,
scale_factor: float | tuple[float, float] = 1,
) -> tuple[int, int]:
"""2D grid shape that will contain the full object, scaled by scale_factor
currently, no way to add margins has been implemented"""
shape = np.array([width, length]) * scale_factor
shape = np.ceil(shape).astype(int)
return shape