Multi-scale wavelet thresholding
- class byotrack.implementation.detector.wavelet.B3SplineUWT(level=3, dim=2, return_all=True)
Bases: Module
Undecimated Wavelet Transform with B3Spline for nD images.
Also called the à trous wavelet transform.
Let J be the level of the UWT and c_0 the original image. The transform successively computes dilated convolutions of it:
$c_j = h_j * c_{j-1}, \quad 1 \le j \le J$
$w_j = c_{j-1} - c_j, \quad 1 \le j \le J$
where h_j is the filter at scale j: the original filter $h_1 = [1/16, 1/4, 3/8, 1/4, 1/16]$ dilated by a factor $2^{j-1}$, i.e. with $2^{j-1} - 1$ zeros inserted between consecutive coefficients (so h_1 itself has none).
The parameters returned are [w_1, w_2, …, w_J, c_J], or only [w_J] if return_all is False. The original image is easily reconstructed from them with $c_0 = c_J + \sum_{j=1}^{J} w_j$.
Boundary issues are handled with mirror padding.
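To make the recursion concrete, here is a minimal NumPy sketch of the decomposition for a single image without batch dimension. It is not byotrack's torch implementation: the helper names (`dilated_filter`, `mirror_convolve`, `b3spline_uwt`) are hypothetical, mirror padding is assumed to mean NumPy's "reflect" mode, and the convolution with h_j is applied as one 1D pass per axis, which is equivalent assuming (as is standard) that the nD kernel is the outer product of h_j along each axis.

```python
import numpy as np

# h_1 from the text: the B3-spline filter [1/16, 1/4, 3/8, 1/4, 1/16]
H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def dilated_filter(j: int) -> np.ndarray:
    """h_j: h_1 with 2**(j - 1) - 1 zeros inserted between consecutive coefficients."""
    step = 2 ** (j - 1)
    h = np.zeros((len(H1) - 1) * step + 1)
    h[::step] = H1
    return h


def mirror_convolve(x: np.ndarray, h: np.ndarray, axis: int) -> np.ndarray:
    """Same-size 1D convolution of x with the symmetric filter h along one axis."""
    pad = len(h) // 2
    widths = [(0, 0)] * x.ndim
    widths[axis] = (pad, pad)
    padded = np.pad(x, widths, mode="reflect")  # mirror padding (edge not repeated)
    return np.apply_along_axis(np.convolve, axis, padded, h, mode="valid")


def b3spline_uwt(image: np.ndarray, level: int = 3) -> np.ndarray:
    """Return [w_1, ..., w_J, c_J] stacked on a new leading axis: shape (J + 1, [D, ]H, W)."""
    c_prev = image.astype(float)  # c_0
    planes = []
    for j in range(1, level + 1):
        h = dilated_filter(j)
        c = c_prev
        for axis in range(c.ndim):  # separable nD B3 kernel: one 1D pass per axis
            c = mirror_convolve(c, h, axis)
        planes.append(c_prev - c)  # w_j = c_{j-1} - c_j
        c_prev = c
    planes.append(c_prev)  # c_J
    return np.stack(planes)
```

Because w_j = c_{j-1} - c_j, the sum telescopes: `b3spline_uwt(img).sum(axis=0)` gives back the image up to floating-point error.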
- forward(x: Tensor) → Tensor
Compute the parameters of the UWT.
- Parameters:
x (torch.Tensor) – Input multidimensional images without channel dimension. Shape: (B, [D, ]H, W)
- Returns:
- Parameters [w_1, …, w_J, c_J] or only w_J
Shape: (B, J + 1, [D, ]H, W) or (B, [D, ]H, W) if return_all is False
- Return type:
torch.Tensor
- class byotrack.implementation.detector.wavelet.B3SplineUWTApprox1(level=3, dim=2, return_all=True)
Bases: Module
Approximation of Undecimated Wavelet Transform with B3Spline for nD images.
Splits the nD convolution into n 1D convolutions, one along each axis (as done in the original paper). Although this reduces the FLOPs, it can lead to slower runtimes (because of some PyTorch optimizations).
Depending on the GPU, the cuDNN/CUDA kernels, the PyTorch version, and so on, it can sometimes be much faster than its nD convolution counterpart (especially in 3D).
This implementation relies on torch.nn.Conv1d and can handle any number of dimensions n.
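The axis-by-axis strategy can be sketched in NumPy as follows. This is an illustration rather than the library code, and `conv1d_along_axis` and `separable_smooth` are hypothetical names. Each pass moves the target axis last and folds every other axis into a batch of 1D lines, which is the reshaping a batched 1D convolution such as torch's conv1d needs; for a (D, H, W) volume, the pass along W processes D × H lines at once.

```python
import numpy as np


def conv1d_along_axis(x: np.ndarray, h: np.ndarray, axis: int) -> np.ndarray:
    """Convolve every 1D line of x along `axis` with the symmetric filter h.

    All other axes are folded into a batch: (..., n) -> (batch, n).
    """
    pad = len(h) // 2
    moved = np.moveaxis(x, axis, -1)            # target axis last
    lines = moved.reshape(-1, moved.shape[-1])  # (batch, n): batch = product of other sizes
    lines = np.pad(lines, ((0, 0), (pad, pad)), mode="reflect")  # mirror padding
    out = np.stack([np.convolve(line, h, mode="valid") for line in lines])
    return np.moveaxis(out.reshape(moved.shape), -1, axis)  # restore original layout


def separable_smooth(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Apply the separable kernel h x h x ... x h as n successive 1D passes."""
    for axis in range(x.ndim):
        x = conv1d_along_axis(x, h, axis)
    return x
```

Since mirror padding along one axis commutes with filtering along another, the n passes match a direct nD convolution with the outer-product kernel.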
- forward(x: Tensor) → Tensor
Compute the parameters of the UWT.
- Parameters:
x (torch.Tensor) – Input images without channel dimension. Shape: (B, [D, ]H, W)
- Returns:
- Parameters [w_1, …, w_J, c_J] or only [w_J]
Shape: (B, J + 1, [D, ]H, W) or (B, [D, ]H, W) if return_all is False
- Return type:
torch.Tensor
- class byotrack.implementation.detector.wavelet.B3SplineUWTApprox2(level=3, dim=2, return_all=True)
Bases: Module
Approximation of Undecimated Wavelet Transform with B3Spline for nD images.
Splits the nD convolution into n 1D convolutions, one along each axis (as done in the original paper). Although this reduces the FLOPs, it can lead to slower runtimes (because of some PyTorch optimizations).
Depending on the GPU, the cuDNN/CUDA kernels, the PyTorch version, and so on, it can sometimes be much faster than its nD convolution counterpart (especially in 3D).
This implementation relies on nD convolutions (torch.nn.ConvnD, i.e. Conv2d in 2D and Conv3d in 3D) with a reduced kernel that extends along a single axis.
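A NumPy sketch of the reduced-kernel idea follows (hypothetical helper names; the library itself uses torch convolutions, and `sliding_window_view` needs NumPy 1.20 or later). Each pass is an ordinary nD convolution whose kernel has length 5 along one axis and 1 along the others. Chaining the n passes equals one convolution with the full separable kernel, while touching 3 × 5 = 15 taps per voxel in 3D instead of 5³ = 125. The kernel is symmetric, so correlation (what convolution layers compute) and convolution coincide.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def reduced_kernel(h: np.ndarray, axis: int, ndim: int) -> np.ndarray:
    """nD kernel equal to h along `axis` and of size 1 along every other axis."""
    shape = [1] * ndim
    shape[axis] = len(h)
    return h.reshape(shape)


def full_kernel(h: np.ndarray, ndim: int) -> np.ndarray:
    """Dense separable kernel h x h x ... x h (outer product), for comparison."""
    k = h
    for _ in range(ndim - 1):
        k = np.multiply.outer(k, h)
    return k


def convnd_same(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Same-size nD correlation with mirror padding."""
    pads = [(k // 2, k // 2) for k in kernel.shape]
    xp = np.pad(x, pads, mode="reflect")
    windows = sliding_window_view(xp, kernel.shape)  # shape: x.shape + kernel.shape
    window_axes = tuple(range(x.ndim, 2 * x.ndim))
    return np.tensordot(windows, kernel, axes=(window_axes, tuple(range(x.ndim))))


def smooth_with_reduced_kernels(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """n nD convolutions, each with a kernel that is non-trivial along one axis only."""
    for axis in range(x.ndim):
        x = convnd_same(x, reduced_kernel(h, axis, x.ndim))
    return x
```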
- forward(x: Tensor) → Tensor
Compute the parameters of the UWT.
- Parameters:
x (torch.Tensor) – Input images without channel dimension. Shape: (B, [D, ]H, W)
- Returns:
- Parameters [w_1, …, w_J, c_J] or only [w_J]
Shape: (B, J + 1, [D, ]H, W) or (B, [D, ]H, W) if return_all is False
- Return type:
torch.Tensor
- byotrack.implementation.detector.wavelet.filter_small_objects(segmentation: ndarray, min_area: float) → None
Filter small instances from the segmentation in place.
- Parameters:
segmentation (np.ndarray) – Segmentation mask to filter in place. Shape: ([D, ]H, W), dtype: integer
min_area (float) – Minimum number of pixels an instance must have to be kept in the segmentation.
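A minimal NumPy sketch of such an in-place filter, assuming 0 is the background label and instances carry non-negative integer labels. `filter_small_objects_sketch` is a hypothetical name, not the library function; instances are kept when their area is at least min_area.

```python
import numpy as np


def filter_small_objects_sketch(segmentation: np.ndarray, min_area: float) -> None:
    """Zero out, in place, every instance with fewer than min_area pixels.

    Assumes 0 is background and labels are non-negative integers.
    """
    areas = np.bincount(segmentation.ravel())  # areas[label] = number of pixels
    small = areas < min_area
    small[0] = False  # never touch the background
    segmentation[small[segmentation]] = 0  # boolean mask of pixels in small instances
```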
- class byotrack.implementation.detector.wavelet.WaveletDetector(scale=2, k=3.0, min_area=10.0, device: device | None = None, **kwargs)
Bases: BatchDetector
Detection of bright spots using B3SplineUWT
Follows the paper by Olivo-Marin, J.C.: Extraction of spots in biological images using multiscale products. Pattern Recognit. 35, 1989-1996.
It supports 2D and 3D videos.
The algorithm has 4 steps:
- UWT decomposition
- Scale selection
- Noise filtering
- Connected components extraction
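The four steps can be sketched end to end on one single-channel 2D or 3D frame. This NumPy version is illustrative only: function names are hypothetical, `scale` here is 1-indexed to match w_j in the B3SplineUWT formulas (the library's indexing may differ), σ is estimated as the plain standard deviation of the selected coefficients, connected components use face connectivity, and the filtering of small objects (min_area) is left out.

```python
from collections import deque

import numpy as np

H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def smooth(x: np.ndarray, j: int) -> np.ndarray:
    """c_j from c_{j-1}: dilated B3 filter h_j applied axis by axis, mirror padding."""
    step = 2 ** (j - 1)
    h = np.zeros(4 * step + 1)
    h[::step] = H1
    for axis in range(x.ndim):
        widths = [(0, 0)] * x.ndim
        widths[axis] = (2 * step, 2 * step)
        padded = np.pad(x, widths, mode="reflect")
        x = np.apply_along_axis(np.convolve, axis, padded, h, mode="valid")
    return x


def label_components(mask: np.ndarray) -> np.ndarray:
    """Face-connected components via BFS; 0 is background, objects are 1, 2, ..."""
    labels = np.zeros(mask.shape, dtype=int)
    current = 0
    for start in zip(*np.nonzero(mask)):
        if labels[start]:
            continue
        current += 1
        labels[start] = current
        queue = deque([start])
        while queue:
            idx = queue.popleft()
            for axis in range(mask.ndim):
                for delta in (-1, 1):
                    nb = list(idx)
                    nb[axis] += delta
                    nb = tuple(nb)
                    if 0 <= nb[axis] < mask.shape[axis] and mask[nb] and not labels[nb]:
                        labels[nb] = current
                        queue.append(nb)
    return labels


def detect_spots(frame: np.ndarray, scale: int = 2, k: float = 3.0) -> np.ndarray:
    """The four steps on one single-channel frame ([D, ]H, W); requires scale >= 1."""
    c_prev = frame.astype(float)
    for j in range(1, scale + 1):  # 1. UWT decomposition, stopped at the chosen scale
        c = smooth(c_prev, j)
        w = c_prev - c  # 2. scale selection: only w_scale survives the loop
        c_prev = c
    mask = w > k * w.std()  # 3. noise filtering with a k * sigma threshold
    return label_components(mask)  # 4. connected components extraction
```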
The multi-scale behavior (using multiple scales) was implemented, but we decided to drop it: in our experience, it adds complexity without real advantages.
The same algorithm is implemented in the Icy software (SpotDetector). The main differences are:
- nD wavelets (rather than n one-dimensional wavelets). The 1D decomposition was designed to speed up computations, but with torch no gain in time is observed in 2D. This can be switched either by calling optimize, which tries to find the fastest option for your case, or manually by modifying the b3swt attribute.
- Thresholding: we follow the original paper and use k times the standard deviation.
- scale
Scale of the wavelet coefficients used. With small scales, the detector focuses on smaller objects.
- Type:
int
- k
Noise threshold. Following the paper, wavelet coefficients are filtered out if $\text{coef} \le k \sigma$. (The higher k, the fewer spots you retrieve.)
- Type:
float
- device
Device on which to run the B3SplineUWT. Defaults to CPU.
- Type:
torch.device
- b3swt
Undecimated wavelet transform
- Type:
- **kwargs
Additional arguments for BatchDetector (batch_size, add_true_frames).
- detect(batch: ndarray) → List[Detections]
Apply the detection on a batch of frames
By default, frame ids are set from 0 to n-1, with n the size of the batch. The aggregation of batches and the correction of frame ids are handled automatically when calling the run method.
- Parameters:
batch (np.ndarray) – Batch of video frames. Shape: (B, [D, ]H, W, C)
- Returns:
Detections for each given frame
- Return type:
Sequence[byotrack.Detections]
- compute_threshold(coefficients: Tensor) → Tensor
Compute the threshold for the UWT coefficients.
Note: one could use the MAD approximation of sigma rather than the std (or the Icy implementation), but it is almost equivalent (and since k is a parameter to tune, it truly is).
- Parameters:
coefficients (torch.Tensor) – Coefficients of the UWT. Shape: (B, [D, ]H, W)
- Returns:
- Threshold for each scale
Shape: (B, 1, …, 1)
- Return type:
torch.Tensor
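Assuming the threshold is k times a per-image σ taken over every non-batch axis (so that it broadcasts against the (B, [D, ]H, W) coefficients), the std version and the MAD alternative mentioned in the note can be sketched as below. Function names are hypothetical. The constant 0.6745 is the 0.75 quantile of the standard normal distribution, which makes MAD / 0.6745 a consistent estimate of σ for Gaussian noise.

```python
import numpy as np


def threshold_std(coefficients: np.ndarray, k: float = 3.0) -> np.ndarray:
    """k * sigma per image, sigma = std over all non-batch axes. Shape (B, 1, ..., 1)."""
    axes = tuple(range(1, coefficients.ndim))
    return k * coefficients.std(axis=axes, keepdims=True)


def threshold_mad(coefficients: np.ndarray, k: float = 3.0) -> np.ndarray:
    """Same, with the robust estimate sigma ~= MAD / 0.6745. Shape (B, 1, ..., 1)."""
    axes = tuple(range(1, coefficients.ndim))
    med = np.median(coefficients, axis=axes, keepdims=True)
    mad = np.median(np.abs(coefficients - med), axis=axes, keepdims=True)
    return k * mad / 0.6745
```

Coefficients are then kept where `coefficients > threshold`, the (B, 1, …, 1) threshold broadcasting over the spatial axes.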
- optimize(frames: ndarray, repeat=5, warm_up=True, splines=(B3SplineUWT, B3SplineUWTApprox1, B3SplineUWTApprox2)) → WaveletDetector
Find the fastest configuration for the model on the given frames.
This is mainly designed for 3D videos, where 3D convolutions are heavy and poorly optimized with a dilation > 1 (for scales > 0). In particular, we observed cuDNN kernels running 10 times faster on some GPUs when deterministic was set to False for 3D convolutions.
Depending on your hardware, kernels and PyTorch version, one solution may be better than another. This lets you test most of the configurations and use the fastest one.
Warning
With large 3D images, B3SplineUWTApprox1 may not work: it converts the spatial axes into the batch axis, but conv1d does not support very large batch sizes. In that case, you can disable it by setting the splines argument to (B3SplineUWT, B3SplineUWTApprox2).
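The selection logic can be sketched generically: time each candidate on the same frames after an optional warm-up call, and keep the fastest. `pick_fastest` is a hypothetical helper, not byotrack's implementation, and it only measures wall-clock time. With CUDA, kernels run asynchronously, so a real GPU benchmark would also need to call `torch.cuda.synchronize()` before reading the clock.

```python
import time
from typing import Callable, Dict


def pick_fastest(candidates: Dict[str, Callable], frames, repeat: int = 5, warm_up: bool = True) -> str:
    """Return the name of the candidate with the lowest mean run time on `frames`."""
    timings = {}
    for name, fn in candidates.items():
        if warm_up:
            fn(frames)  # first call may pay one-off costs (allocation, kernel selection, caching)
        start = time.perf_counter()
        for _ in range(repeat):
            fn(frames)
        timings[name] = (time.perf_counter() - start) / repeat
    return min(timings, key=timings.get)
```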
- Parameters:
frames (np.ndarray) – Frames of the video on which to test. Shape: (B, [D, ]H, W, C), dtype: float
repeat (int) – Number of times to repeat the computation to measure the timings. Default: 5
warm_up (bool) – Warm up each model before measuring time. Default: True
splines (tuple) – Implementations of B3SplineUWT to test. Reduce this list if, for instance, one of the implementations takes too long to run. Default: (B3SplineUWT, B3SplineUWTApprox1, B3SplineUWTApprox2)
- Returns:
self, with the best b3swt found. It may also modify the PyTorch backend settings for convolution.
- Return type: