Multi-scale wavelet thresholding

Wavelet thresholding detector
class byotrack.implementation.detector.wavelet.B3SplineUWT(level: int = 3)

Bases: Module

Undecimated Wavelet Transform with B3Spline for 2D images

Also called the à trous wavelet transform.

Let $J$ be the level of the UWD and let $c_0$ be the original image. The transform computes successive dilated convolutions of $c_0$:

$$c_j = h_j * c_{j-1}, \quad 1 \le j \le J$$
$$w_j = c_{j-1} - c_j, \quad 1 \le j \le J$$

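where $*$ denotes convolution and $h_j$ is the filter at scale $j$: the original B3-spline filter $h_1 = [1/16, 1/4, 3/8, 1/4, 1/16]$ dilated by inserting $2^{j-1} - 1$ zeros between consecutive coefficients. In 2D, the filter is applied along both image axes (the 2D kernel is the outer product $h_j^\top h_j$).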
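A minimal NumPy sketch of how the dilated filters can be built, assuming the standard à trous convention in which $h_j$ has $2^{j-1} - 1$ zeros between taps (so that $h_1$ is the unmodified B3-spline filter). `dilated_filter` is a hypothetical helper for illustration, not part of byotrack.

```python
import numpy as np

# B3-spline filter h_1 from the text above
H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def dilated_filter(j: int) -> np.ndarray:
    """Return h_j: h_1 with 2**(j - 1) - 1 zeros inserted between consecutive taps."""
    step = 2 ** (j - 1)  # distance between two non-zero taps
    h = np.zeros(4 * step + 1)  # 5 taps spread over 4 gaps of size `step`
    h[::step] = H1
    return h


for j in (1, 2, 3):
    print(j, dilated_filter(j))
```

For $j = 2$ this gives $[1/16, 0, 1/4, 0, 3/8, 0, 1/4, 0, 1/16]$; in general $h_j$ spans $2^{j+1} + 1$ samples while keeping the same five non-zero weights, which still sum to 1.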

The returned parameters are $[w_1, w_2, \dots, w_J, c_J]$. The original image is easily reconstructed from them: $c_0 = c_J + \sum_{j=1}^{J} w_j$.

Boundary issues are handled using mirror padding.
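The recursion above can be sketched in NumPy as follows. This is not byotrack's torch implementation: `b3spline_uwt` and `smooth` are hypothetical names, and NumPy's `"reflect"` padding mode is assumed to stand in for the mirror padding. Like B3SplineUWT, it applies the full 5×5 kernel outer(h_j, h_j) directly and returns the planes in the documented order and shape (B, J + 1, H, W).

```python
import numpy as np

# B3-spline filter h_1
H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def smooth(c: np.ndarray, j: int) -> np.ndarray:
    """Compute c_j = h_j * c_{j-1} on the last two axes, with reflect padding."""
    step = 2 ** (j - 1)  # spacing between the taps of h_j
    pad = 2 * step  # half-width of the dilated 5-tap filter
    padded = np.pad(c, [(0, 0)] * (c.ndim - 2) + [(pad, pad), (pad, pad)], mode="reflect")
    h, w = c.shape[-2:]
    out = np.zeros_like(c)
    # 2D kernel = outer(h_j, h_j): accumulate one shifted copy of the image per tap
    for a, ha in enumerate(H1):
        for b, hb in enumerate(H1):
            out += ha * hb * padded[..., a * step:a * step + h, b * step:b * step + w]
    return out


def b3spline_uwt(x: np.ndarray, level: int = 3) -> np.ndarray:
    """(B, H, W) images -> (B, level + 1, H, W) planes [w_1, ..., w_J, c_J]."""
    c = x.astype(float)
    planes = []
    for j in range(1, level + 1):
        c_next = smooth(c, j)
        planes.append(c - c_next)  # w_j = c_{j-1} - c_j
        c = c_next
    planes.append(c)  # coarse approximation c_J
    return np.stack(planes, axis=1)


rng = np.random.default_rng(0)
x = rng.random((2, 32, 32))
coeffs = b3spline_uwt(x, level=3)
print(coeffs.shape)  # (2, 4, 32, 32)
assert np.allclose(coeffs.sum(axis=1), x)  # c_0 = c_J + sum_j w_j
```

Because $w_j = c_{j-1} - c_j$ telescopes, summing the $J + 1$ planes recovers the input up to floating-point error, whatever the filter; the B3-spline choice only shapes how the detail is split across scales.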

forward(x: Tensor) → Tensor

Compute the parameters of the UWD

Parameters:

x (torch.Tensor) – Input images without channel dimension. Shape: (B, H, W)

Returns:

Parameters [w_1, …, w_J, c_J]

Shape: (B, J + 1, H, W)

Return type:

torch.Tensor

class byotrack.implementation.detector.wavelet.B3SplineUWTApprox(level=3)

Bases: Module

Approximation of Undecimated Wavelet Transform with B3Spline for 2D images

Splits the 2D convolution into two 1D convolutions: first along the rows, then along the columns.

After some analysis, this implementation turns out to be slower than B3SplineUWT, even though it should in theory be faster.
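The B3-spline kernel is separable (the 5×5 kernel is the outer product of the 1D filter with itself), so two 1D passes give the same result as one 2D pass while needing 10 instead of 25 multiply-adds per pixel. The NumPy sketch below checks this equivalence for a few dilations; both function names are hypothetical, and `"reflect"` padding is assumed for the mirror padding. It says nothing about speed in torch, where the text above reports the separable version being slower.

```python
import numpy as np

H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def filter_separable(img: np.ndarray, step: int = 1) -> np.ndarray:
    """Two 1D passes with h_1 dilated by `step`: along the rows, then along the columns."""
    out = img.astype(float)
    for axis in (-1, -2):  # -1: along each row, -2: along each column
        pad = [(0, 0)] * out.ndim
        pad[axis] = (2 * step, 2 * step)
        padded = np.pad(out, pad, mode="reflect")
        n = out.shape[axis]
        acc = np.zeros_like(out)
        for k, coef in enumerate(H1):
            # tap k reads the padded signal shifted by k * step
            acc += coef * np.take(padded, np.arange(k * step, k * step + n), axis=axis)
        out = acc
    return out


def filter_2d(img: np.ndarray, step: int = 1) -> np.ndarray:
    """One pass with the 5x5 kernel outer(h_1, h_1) dilated by `step`."""
    kernel = np.outer(H1, H1)
    p = 2 * step
    padded = np.pad(img.astype(float), [(0, 0)] * (img.ndim - 2) + [(p, p), (p, p)], mode="reflect")
    h, w = img.shape[-2:]
    out = np.zeros(img.shape, dtype=float)
    for a in range(5):
        for b in range(5):
            out += kernel[a, b] * padded[..., a * step:a * step + h, b * step:b * step + w]
    return out


img = np.random.default_rng(0).random((1, 20, 20))
for step in (1, 2, 4):
    assert np.allclose(filter_separable(img, step), filter_2d(img, step))
```

The equivalence also holds at the borders because reflect-padding both axes at once is the same as reflect-padding each axis in turn.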

forward(x: Tensor) → Tensor

Compute the parameters of the UWD

Warning: batch sizes other than 1 are not supported yet.

Parameters:

x (torch.Tensor) – Input images without channel dimension. Shape: (B, H, W)

Returns:

Parameters [w_1, …, w_J, c_J]

Shape: (B, J + 1, H, W)

Return type:

torch.Tensor

class byotrack.implementation.detector.wavelet.WaveletDetector(scale=2, k=3.0, min_area=10.0, device: device | None = None, **kwargs)

Bases: BatchDetector

Detection of bright spots using B3SplineUWT

Follows the paper by Olivo-Marin, J.C.: Extraction of spots in biological images using multiscale products. Pattern Recognit. 35, pp. 1989-1996.

The algorithm has four steps:

  1. UWT decomposition

  2. Scale selection

  3. Noise filtering

  4. Connected components extraction
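The four steps can be sketched end to end for a single 2D image. This is a simplified NumPy illustration under stated assumptions, not byotrack's implementation: the threshold uses the standard deviation of the selected plane over the whole image, components use 8-connectivity (labelled with a small breadth-first search instead of OpenCV), and spots are returned as centroids. `wavelet_plane`, `label_components` and `detect_spots` are hypothetical names; the default parameters mirror the WaveletDetector signature.

```python
from collections import deque

import numpy as np

H1 = np.array([1 / 16, 1 / 4, 3 / 8, 1 / 4, 1 / 16])


def wavelet_plane(img: np.ndarray, scale: int) -> np.ndarray:
    """Steps 1-2: a trous UWT of a 2D image up to `scale`; keep only w_scale."""
    c = img.astype(float)
    h, w = c.shape
    for j in range(1, scale + 1):
        step = 2 ** (j - 1)
        padded = np.pad(c, 2 * step, mode="reflect")
        c_next = np.zeros_like(c)
        for a in range(5):
            for b in range(5):
                c_next += H1[a] * H1[b] * padded[a * step:a * step + h, b * step:b * step + w]
        plane, c = c - c_next, c_next  # w_j = c_{j-1} - c_j
    return plane


def label_components(mask: np.ndarray):
    """Step 4: 8-connected component labelling with a breadth-first search."""
    labels = np.zeros(mask.shape, dtype=int)
    n_labels = 0
    for seed in zip(*np.nonzero(mask)):
        if labels[seed]:
            continue  # pixel already belongs to a component
        n_labels += 1
        labels[seed] = n_labels
        queue = deque([seed])
        while queue:
            i, j = queue.popleft()
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    ni, nj = i + di, j + dj
                    if (0 <= ni < mask.shape[0] and 0 <= nj < mask.shape[1]
                            and mask[ni, nj] and not labels[ni, nj]):
                        labels[ni, nj] = n_labels
                        queue.append((ni, nj))
    return labels, n_labels


def detect_spots(img, scale=2, k=3.0, min_area=10.0):
    """Steps 1-4 on one (H, W) image; returns (row, col) centroids of the kept spots."""
    plane = wavelet_plane(img, scale)
    mask = plane > k * plane.std()  # step 3: coefficients <= k * sigma are discarded
    labels, n_labels = label_components(mask)
    areas = np.bincount(labels.ravel(), minlength=n_labels + 1)
    return [np.argwhere(labels == lab).mean(axis=0)
            for lab in range(1, n_labels + 1) if areas[lab] >= min_area]


rng = np.random.default_rng(0)
img = rng.normal(0.0, 0.05, size=(64, 64))
img[15:20, 15:20] += 10.0  # two synthetic bright 5x5 spots
img[40:45, 40:45] += 10.0
spots = detect_spots(img)
print(len(spots), spots)
```

On this synthetic image, the two bright squares should come out as two spots centred near (17, 17) and (42, 42); raising min_area above the spot size discards both.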

The multi-scale behavior (using several scales at once) was implemented, but we decided to drop it: in our experience, it adds complexity without real benefits.

The same algorithm is implemented in the Icy software (SpotDetector). The main differences are:

  • 2D wavelets (rather than two successive 1D wavelets). Icy's separable approach was designed to speed up computation, but with torch we observe no gain in time. (This can be switched with FOLLOW_ICY.)

  • Thresholding: we follow the original paper.

scale

Scale of the wavelet coefficients used. With small scales, the detector focuses on smaller objects.

Type:

int

k

Noise threshold. Following the paper, wavelet coefficients $w$ are discarded when $w \le k\sigma$, where $\sigma$ is the estimated noise standard deviation of the coefficients. The higher k, the fewer spots are retrieved.
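A small sketch of the thresholding rule, assuming $\sigma$ is the standard deviation of the coefficient plane over the image (the compute_threshold note mentions the std): coefficients $w \le k\sigma$ are set to zero, so a larger k keeps fewer coefficients, and hence fewer spots. `threshold_coefficients` is a hypothetical helper, and the plane is random noise standing in for real UWT coefficients.

```python
import numpy as np


def threshold_coefficients(w: np.ndarray, k: float) -> np.ndarray:
    """Set to zero every coefficient w <= k * sigma (sigma: std over the last two axes)."""
    sigma = w.std(axis=(-2, -1), keepdims=True)
    return np.where(w > k * sigma, w, 0.0)


w = np.random.default_rng(0).normal(size=(64, 64))  # stand-in for one wavelet plane
counts = [np.count_nonzero(threshold_coefficients(w, k)) for k in (1.0, 2.0, 3.0)]
print(counts)  # fewer coefficients survive as k grows
```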

Type:

float

min_area

Filter out resulting spots that are too small (fewer than min_area pixels).

Type:

float

device

Device on which to run the B3SplineUWT. Defaults to CPU.

Type:

torch.device

b3swt

Undecimated wavelet transform

Type:

B3SplineUWT

**kwargs

Additional arguments for BatchDetector (batch_size, add_true_frames)

Warning: The connected-components extraction (OpenCV) can cause a segmentation fault when there are too many components.

detect(batch: ndarray) → List[Detections]

Apply the detection on a batch of frames

By default, the frame ids are set from 0 to n-1, with n the size of the batch. The aggregation of batches and the correction of frame ids are handled automatically when calling the run method.

Parameters:

batch (np.ndarray) – Batch of video frames. Shape: (B, H, W, C)

Returns:

Detections for each given frame

Return type:

Sequence[byotrack.Detections]

compute_threshold(coefficients: Tensor) → Tensor

Compute threshold for the UWT coefficients

Note: One could use the MAD approximation of sigma (or the Icy implementation) rather than the std, but the results are almost equivalent (and since k is a parameter to tune anyway, they truly are).

Parameters:

coefficients (torch.Tensor) – Coefficients of the UWT. Shape: (…, H, W)

Returns:

Threshold for each scale

Shape: (…, 1, 1)

Return type:

torch.Tensor
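Both variants mentioned in the note can be sketched in NumPy. Assumptions: the threshold is $k\sigma$ with $\sigma$ estimated per plane over the last two axes, and the MAD variant uses the usual Gaussian consistency constant 0.6745; the function names are hypothetical, and the library's compute_threshold may not multiply by k itself. The `keepdims` reductions give the documented (…, 1, 1) shape, which broadcasts against (…, H, W).

```python
import numpy as np


def threshold_std(coefficients: np.ndarray, k: float = 3.0) -> np.ndarray:
    """k * sigma with sigma the standard deviation of each (H, W) plane; shape (..., 1, 1)."""
    return k * coefficients.std(axis=(-2, -1), keepdims=True)


def threshold_mad(coefficients: np.ndarray, k: float = 3.0) -> np.ndarray:
    """k * sigma with the robust estimate sigma ~= MAD / 0.6745; shape (..., 1, 1)."""
    med = np.median(coefficients, axis=(-2, -1), keepdims=True)
    mad = np.median(np.abs(coefficients - med), axis=(-2, -1), keepdims=True)
    return k * mad / 0.6745


noise = np.random.default_rng(0).normal(0.0, 2.0, size=(3, 128, 128))  # 3 planes, sigma = 2
t_std, t_mad = threshold_std(noise), threshold_mad(noise)
print(t_std.shape, t_std.ravel(), t_mad.ravel())  # both close to k * sigma = 6
```

On pure Gaussian noise the two estimates agree closely; they diverge mainly when bright spots cover a large fraction of the plane, which inflates the standard deviation more than the median-based estimate.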