"""Main beamforming functions for ultrasound imaging."""
import keras
import numpy as np
from keras import ops
from zea.beamform.lens_correction import compute_lens_corrected_travel_times
from zea.func.tensor import vmap
[docs]
def fnum_window_fn_rect(normalized_angle):
    """Rectangular (boxcar) window for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Incidence angle expressed as a fraction of
            the largest angle allowed by the f-number cone (see `fnumber_mask`).

    Returns:
        ops.Tensor: 1.0 wherever `normalized_angle <= 1.0`, 0.0 everywhere else.
    """
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, 1.0, 0.0)
[docs]
def fnum_window_fn_hann(normalized_angle):
    """Hann (raised-cosine) window for f-number masking.

    The weight falls smoothly from 1.0 at `normalized_angle == 0` to 0.0 at
    `normalized_angle == 1`, and is forced to 0.0 beyond that.

    Args:
        normalized_angle (ops.Tensor): Incidence angle expressed as a fraction of
            the largest angle allowed by the f-number cone (see `fnumber_mask`).

    Returns:
        ops.Tensor: Window weights with the same shape as `normalized_angle`.
    """
    # Raised cosine evaluated everywhere; values outside the cone are discarded below.
    raised_cosine = 0.5 * (1 + ops.cos(np.pi * normalized_angle))
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, raised_cosine, 0.0)
[docs]
def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
    """Tukey (tapered-cosine) window for f-number masking.

    The window is flat (1.0) up to `1 - alpha`, then follows a cosine taper, and
    is 0.0 once the absolute normalized angle reaches 1.0.

    Args:
        normalized_angle (ops.Tensor): Normalized angle values. The absolute value is
            taken and clipped to the range [0, 1] before the window is evaluated.
        alpha (float, optional): The alpha parameter for the Tukey window. 0.0 corresponds
            to a rectangular window, 1.0 corresponds to a Hann window. Defaults to 0.5.

    Returns:
        ops.Tensor: Window weights with the same shape as `normalized_angle`.
    """
    magnitude = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)
    flat_top_end = 1.0 - alpha
    # The small constant keeps the division finite when alpha is zero.
    taper_width = ops.abs(alpha) + 1e-6
    cosine_taper = 0.5 * (1 + ops.cos(np.pi * (magnitude - flat_top_end) / taper_width))
    # Strict comparison: a (clipped) magnitude of exactly 1.0 gets weight 0.0.
    taper_or_zero = ops.where(magnitude < 1.0, cosine_taper, 0.0)
    return ops.where(magnitude < flat_top_end, 1.0, taper_or_zero)
[docs]
def tof_correction(
    data,
    flatgrid,
    t0_delays,
    tx_apodizations,
    sound_speed,
    probe_geometry,
    initial_times,
    sampling_frequency,
    demodulation_frequency,
    f_number,
    polar_angles,
    focus_distances,
    t_peak,
    tx_waveform_indices,
    transmit_origins,
    apply_lens_correction=False,
    lens_thickness=1e-3,
    lens_sound_speed=1000,
    fnum_window_fn=fnum_window_fn_rect,
):
    """Time-of-flight correction for a flat grid.

    For every transmit, the transmit and receive delays to each pixel are computed
    with `calculate_delays`, the channel data is sampled at those delays with
    `apply_delays`, the receive f-number mask is applied, and (for two-channel IQ
    data) the samples are phase rotated with `complex_rotate`.

    Args:
        data (ops.Tensor): Input RF/IQ data of shape `(n_tx, n_ax, n_el, n_ch)`.
            Phase rotation is applied when `n_ch == 2`.
        flatgrid (ops.Tensor): Pixel locations x, y, z of shape `(n_pix, 3)`
        t0_delays (ops.Tensor): Times at which the elements fire shifted such
            that the first element fires at t=0 of shape `(n_tx, n_el)`
        tx_apodizations (ops.Tensor): Transmit apodizations of shape `(n_tx, n_el)`
        sound_speed (float): Speed-of-sound.
        probe_geometry (ops.Tensor): Element positions x, y, z of shape (n_el, 3)
        initial_times (Tensor): The probe transmit time offsets of shape `(n_tx,)`.
        sampling_frequency (float): Sampling frequency.
        demodulation_frequency (float): Demodulation frequency. Only used for the
            phase rotation of IQ data.
        f_number (float): Focus number (ratio of focal depth to aperture size).
            When equal to 0 no masking is applied (an all-ones mask is used).
        polar_angles (ops.Tensor): The angles of the waves in radians of shape `(n_tx,)`
        focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
        t_peak (ops.Tensor): Time of the peak of the pulse in seconds.
            Shape `(n_waveforms,)`.
        tx_waveform_indices (ops.Tensor): The indices of the waveform used for each
            transmit of shape `(n_tx,)`.
        transmit_origins (ops.Tensor): Transmit origins of shape (n_tx, 3).
        apply_lens_correction (bool, optional): Whether to apply lens correction to
            time-of-flights. This makes it slower, but more accurate in the near-field.
            Defaults to False.
        lens_thickness (float, optional): Thickness of the lens in meters. Used for
            lens correction. Defaults to 1e-3.
        lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
            for lens correction. Defaults to 1000.
        fnum_window_fn (callable, optional): F-number function to define the transition from
            straight in front of the element (fn(0.0)) to the largest angle within the f-number
            cone (fn(1.0)). The function should be zero for fn(x>1.0). Defaults to
            `fnum_window_fn_rect`.

    Returns:
        (ops.Tensor): time-of-flight corrected data
        with shape: `(n_tx, n_pix, n_el, n_ch)`.

    Raises:
        AssertionError: If `data` does not have exactly 4 dimensions.
    """
    assert len(data.shape) == 4, (
        "The input data should have 4 dimensions, "
        f"namely n_tx, n_ax, n_el, n_ch, got {len(data.shape)} dimensions: {data.shape}"
    )
    n_tx, n_ax, n_el, _ = ops.shape(data)
    # Calculate delays (both returned in samples, not seconds)
    # --------------------------------------------------------------------
    # txdel: The delay from t=0 to the wavefront reaching the pixel
    # txdel has shape (n_pix, n_tx) as returned by calculate_delays
    #
    # rxdel: The delay from the wavefront reaching the pixel to the scattered wave
    # reaching the transducer element.
    # rxdel has shape (n_pix, n_el)
    # --------------------------------------------------------------------
    txdel, rxdel = calculate_delays(
        flatgrid,
        t0_delays,
        tx_apodizations,
        probe_geometry,
        initial_times,
        sampling_frequency,
        sound_speed,
        n_tx,
        n_el,
        focus_distances,
        polar_angles,
        t_peak,
        tx_waveform_indices,
        transmit_origins,
        apply_lens_correction,
        lens_thickness,
        lens_sound_speed,
    )
    n_pix = ops.shape(flatgrid)[0]
    # Receive apodization mask of shape (n_pix, n_el, 1). An f-number of zero
    # disables masking, so every element contributes to every pixel.
    mask = ops.cond(
        f_number == 0,
        lambda: ops.ones((n_pix, n_el, 1)),
        lambda: fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn=fnum_window_fn),
    )

    def _apply_delays(data_tx, txdel):
        """Applies the delays to TOF correct a single transmit.

        Args:
            data_tx (ops.Tensor): The RF/IQ data for a single transmit of shape
                `(n_ax, n_el, n_ch)`.
            txdel (ops.Tensor): The transmit delays for a single transmit in samples
                (not in seconds) of shape `(n_pix, 1)`.

        Returns:
            ops.Tensor: The time-of-flight corrected data of shape
            `(n_pix, n_el, n_ch)`.
        """
        # data_tx is of shape (n_ax, n_el, n_ch) with n_ch equal to 1 (RF) or 2 (IQ)
        # Take receive delays and add the transmit delays for this transmit
        # The txdel tensor has a singleton element axis because the transmit
        # delay of a pixel is the same for all receive elements
        # delays is of shape (n_pix, n_el)
        delays = rxdel + txdel
        # Compute the time-of-flight corrected samples for each element
        # from each pixel of shape (n_pix, n_el, n_ch). Sample indices are
        # clipped to the valid axial range [0, n_ax - 1].
        tof_tx = apply_delays(data_tx, delays, clip_min=0, clip_max=n_ax - 1)
        # Apply the mask
        tof_tx = tof_tx * mask
        # Apply phase rotation if using IQ data
        # This is needed because interpolating the IQ data without phase rotation
        # is not equivalent to interpolating the RF data and then IQ demodulating
        # See the docstring from complex_rotate for more details
        apply_phase_rotation = data_tx.shape[-1] == 2
        if apply_phase_rotation:
            # Delays are in samples; convert back to seconds to obtain the phase
            total_delay_seconds = delays[:, :] / sampling_frequency
            theta = 2 * np.pi * demodulation_frequency * total_delay_seconds
            tof_tx = complex_rotate(tof_tx, theta)
        return tof_tx

    # Reshape from (n_pix, n_tx) to (n_tx, n_pix, 1) so that the leading axis
    # lines up with the transmit axis of `data` when mapping over transmits
    txdel = ops.moveaxis(txdel, 1, 0)
    txdel = txdel[..., None]
    return vmap(_apply_delays)(data, txdel)
[docs]
def calculate_delays(
    grid,
    t0_delays,
    tx_apodizations,
    probe_geometry,
    initial_times,
    sampling_frequency,
    sound_speed,
    n_tx,
    n_el,
    focus_distances,
    polar_angles,
    t_peak,
    tx_waveform_indices,
    transmit_origins,
    apply_lens_correction=False,
    lens_thickness=None,
    lens_sound_speed=None,
    n_iter=2,
):
    """Calculates the delays in samples to every pixel in the grid.

    Two delays are produced. The transmit delay is the time between transmission
    and the wavefront reaching a pixel. The receive delay is the time between the
    wavefront reaching a pixel and the reflection arriving at a given element.

    Args:
        grid (Tensor): The pixel coordinates to beamform to of shape `(n_pix, 3)`.
        t0_delays (Tensor): The transmit delays in seconds of shape
            `(n_tx, n_el)`, shifted such that the smallest delay is 0.
        tx_apodizations (Tensor): The transmit apodizations of shape
            `(n_tx, n_el)`.
        probe_geometry (Tensor): The positions of the transducer elements of shape
            `(n_el, 3)`.
        initial_times (Tensor): The probe transmit time offsets of shape
            `(n_tx,)`.
        sampling_frequency (float): The sampling frequency of the probe in Hz.
        sound_speed (float): The assumed speed of sound in m/s.
        n_tx (int): Number of transmits. Must equal `t0_delays.shape[0]`.
        n_el (int): Number of elements. Must equal `probe_geometry.shape[0]`.
        focus_distances (Tensor): The focus distances of shape `(n_tx,)`.
            A focus distance of zero is replaced by infinity in `transmit_delays`,
            so that the first wavefront arrival is used for all pixels (plane wave).
        polar_angles (Tensor): The polar angles of the plane waves in radians
            of shape `(n_tx,)`.
        t_peak (Tensor): Time of the peak of the pulse in seconds of shape
            `(n_waveforms,)`.
        tx_waveform_indices (Tensor): The indices of the waveform used for each
            transmit of shape `(n_tx,)`.
        transmit_origins (Tensor): Transmit origins of shape (n_tx, 3).
        apply_lens_correction (bool, optional): Whether to apply lens correction to
            time-of-flights. This makes it slower, but more accurate in the near-field.
            Defaults to False.
        lens_thickness (float, optional): Thickness of the lens in meters. Required
            when `apply_lens_correction` is True.
        lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Required
            when `apply_lens_correction` is True.
        n_iter (int, optional): Number of iterations for the Newton-Raphson method
            used in lens correction. Defaults to 2.

    Returns:
        transmit_delays (Tensor): The tensor of transmit delays to every pixel
            in samples (not in seconds), of shape `(n_pix, n_tx)`.
        receive_delays (Tensor): The tensor of receive delays from every pixel
            back to the transducer element in samples (not in seconds), of shape
            `(n_pix, n_el)`.
    """
    # Shape validation: these four inputs must all be rank-2
    rank_two_inputs = (t0_delays, grid, tx_apodizations, probe_geometry)
    for tensor in rank_two_inputs:
        assert tensor.ndim == 2
    assert probe_geometry.shape[0] == n_el
    assert t0_delays.shape[0] == n_tx

    if apply_lens_correction:
        # Travel times (seconds) between elements and pixels, accounting for the lens
        assert lens_thickness is not None, "lens_thickness must be provided for lens correction."
        assert lens_sound_speed is not None, (
            "lens_sound_speed must be provided for lens correction."
        )
        rx_delays = compute_lens_corrected_travel_times(
            probe_geometry,
            grid,
            lens_thickness,
            lens_sound_speed,
            sound_speed,
            n_iter=n_iter,
        )
    else:
        # Straight-line distance (n_pix, n_el) in meters, converted to seconds
        rx_delays = distance_Rx(grid, probe_geometry) / sound_speed

    # Map `transmit_delays` over the transmit axis of the per-transmit inputs.
    # The grid, receive delays and azimuth angle (None) are shared by all transmits.
    # The transmit axis is placed last, giving a result of shape (n_pix, n_tx).
    delays_for_all_transmits = vmap(
        transmit_delays,
        in_axes=(None, 0, 0, None, 0, 0, 0, None, 0),
        out_axes=1,
    )
    tx_delays = delays_for_all_transmits(
        grid,
        t0_delays,
        tx_apodizations,
        rx_delays,
        focus_distances,
        polar_angles,
        initial_times,
        None,
        transmit_origins,
    )

    # Shift by the peak time of the waveform used in each transmit
    peak_times = ops.take(t_peak, tx_waveform_indices)
    tx_delays += peak_times[None]

    # TODO: decide whether non-finite delays should be sanitised here, e.g. with
    # ops.nan_to_num(..., nan=0.0, posinf=0.0, neginf=0.0) on both delay tensors.

    # Seconds -> samples
    tx_delays *= sampling_frequency
    rx_delays *= sampling_frequency
    return tx_delays, rx_delays
[docs]
def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
    """Applies time delays for a single transmit using linear interpolation.

    Delays are generally not an integer number of samples, so there is no
    measurement at the requested instant. The samples just before and just
    after each delay are gathered and blended linearly.

    Args:
        data (ops.Tensor): The RF or IQ data of shape `(n_ax, n_el, n_ch)`. This is
            the data samples are drawn from, for each element and each pixel.
        delays (ops.Tensor): The delays in samples of shape `(n_pix, n_el)`. Contains
            one delay value for every pixel in the image for every transducer element.
        clip_min (int, optional): Smallest sample index to gather from. Defaults to -1.
        clip_max (int, optional): Largest sample index to gather from. Defaults to -1.
            Clipping is only performed when both `clip_min` and `clip_max` differ
            from -1.

    Returns:
        ops.Tensor: The samples received by each transducer element corresponding to the
        reflections of each pixel in the image of shape `(n_pix, n_el, n_ch)`.
    """
    # Trailing channel axis so the delays line up with the data: (n_pix, n_el, 1)
    delays = delays[..., None]

    # Integer sample indices directly below and above each fractional delay
    lower_idx = ops.cast(ops.floor(delays), "int32")
    upper_idx = lower_idx + 1

    # Keep the gather indices inside the valid range (needed for correct behavior on cpu).
    # When both neighbours are clipped to the same sample, the two interpolation weights
    # below cancel and the result is zero.
    if clip_min != -1 and clip_max != -1:
        lowest = ops.cast(clip_min, lower_idx.dtype)
        highest = ops.cast(clip_max, lower_idx.dtype)
        lower_idx = ops.clip(lower_idx, lowest, highest)
        upper_idx = ops.clip(upper_idx, lowest, highest)

    # Two-channel (IQ) data: duplicate the indices so both channels are gathered
    if data.shape[-1] == 2:
        lower_idx = ops.concatenate([lower_idx, lower_idx], axis=-1)
        upper_idx = ops.concatenate([upper_idx, upper_idx], axis=-1)

    # Gather along the axial axis and bring everything to the dtype of the delays
    float_dtype = delays.dtype
    samples_below = ops.cast(ops.take_along_axis(data, lower_idx, 0), float_dtype)
    samples_above = ops.cast(ops.take_along_axis(data, upper_idx, 0), float_dtype)

    # Linear interpolation weights: distance to the opposite neighbour
    weight_below = ops.cast(upper_idx, float_dtype) - delays
    weight_above = delays - ops.cast(lower_idx, float_dtype)
    return weight_below * samples_below + weight_above * samples_above
[docs]
def complex_rotate(iq, theta):
    """Rotates the phase of IQ data by a given angle.

    Args:
        iq (ops.Tensor): The iq data of shape `(..., 2)`, with the I component at
            index 0 and the Q component at index 1 of the last axis.
        theta (float): The complex angle to rotate by.

    Returns:
        Tensor: The rotated tensor of shape `(..., 2)`.

    Raises:
        AssertionError: If the last dimension of `iq` is not 2.

    .. dropdown:: Explanation

        The IQ data is related to the RF data as follows:

        .. math::

            x(t) &= I(t)\\cos(\\omega_c t) + Q(t)\\cos(\\omega_c t + \\pi/2)\\\\
                 &= I(t)\\cos(\\omega_c t) - Q(t)\\sin(\\omega_c t)

        If we want to delay the RF data `x(t)` by `Δt` we can substitute in
        :math:`t=t+\\Delta t`. We also define :math:`I'(t) = I(t + \\Delta t)`,
        :math:`Q'(t) = Q(t + \\Delta t)`, and :math:`\\theta=\\omega_c\\Delta t`.
        This gives us:

        .. math::

            x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
                - Q'(t) \\sin(\\omega_c (t + \\Delta t))\\\\
                &= \\overbrace{(I'(t)\\cos(\\theta)
                - Q'(t)\\sin(\\theta) )}^{I_\\Delta(t)} \\cos(\\omega_c t)\\\\
                &- \\overbrace{(Q'(t)\\cos(\\theta)
                + I'(t)\\sin(\\theta))}^{Q_\\Delta(t)} \\sin(\\omega_c t)

        This means that to correctly interpolate the IQ data to the new components
        :math:`I_\\Delta(t)` and :math:`Q_\\Delta(t)`, it is not sufficient to just
        interpolate the I- and Q-channels independently. We also need to rotate the
        I- and Q-channels by the angle :math:`\\theta`. This function performs this
        rotation.
    """
    assert iq.shape[-1] == 2, (
        "The last dimension of the input tensor should be 2, "
        f"got {iq.shape[-1]} dimensions and shape {iq.shape}."
    )
    # Split the channel axis into its in-phase and quadrature parts
    in_phase, quadrature = iq[..., 0], iq[..., 1]

    cos_theta = ops.cos(theta)
    sin_theta = ops.sin(theta)

    # Standard 2D rotation of the (I, Q) vector by theta
    rotated_i = in_phase * cos_theta - quadrature * sin_theta
    rotated_q = quadrature * cos_theta + in_phase * sin_theta

    # Restore the channel axis and join the two components again
    return ops.concatenate([rotated_i[..., None], rotated_q[..., None]], -1)
[docs]
def distance_Rx(grid, probe_geometry):
    """Computes the Euclidean distance between every pixel and every element.

    Expects all inputs to be arrays specified in SI units.

    Args:
        grid (ops.Tensor): Pixel positions in x,y,z of shape `(n_pix, 3)`.
        probe_geometry (ops.Tensor): Element positions in x,y,z of shape `(n_el, 3)`.

    Returns:
        dist (ops.Tensor): Distance from each pixel to each element of shape
            `(n_pix, n_el)`.
    """
    # Insert singleton axes so the subtraction broadcasts to (n_pix, n_el, 3)
    pixels = grid[:, None, :]
    elements = probe_geometry[None, :, :]
    return ops.linalg.norm(pixels - elements, axis=-1)
[docs]
def transmit_delays(
    grid,
    t0_delays,
    tx_apodization,
    rx_delays,
    focus_distance,
    polar_angle,
    initial_time,
    azimuth_angle=None,
    transmit_origin=None,
):
    """Computes the transmit delay from transmission to each pixel in the grid.

    The first-arrival time over the active elements is used for pixels before the
    focus (or virtual source), and the last-arrival time for pixels beyond it.
    The element-to-pixel travel times (`rx_delays`) are passed in because they do
    not depend on the transmit parameters and can be precomputed.

    Args:
        grid (ops.Tensor): Flattened tensor of pixel positions in x,y,z of shape `(n_pix, 3)`
        t0_delays (Tensor): The transmit delays in seconds of shape (n_el,).
        tx_apodization (Tensor): The transmit apodization of shape (n_el,). Elements
            with an apodization of exactly 0 are excluded from the min/max.
        rx_delays (Tensor): The travel times in seconds from elements to pixels
            of shape (n_pix, n_el).
        focus_distance (float): The focus distance in meters. A value of 0.0 is
            replaced by infinity, so the first arrival is used for all pixels.
        polar_angle (float): The polar angle in radians.
        initial_time (float): The initial time for this transmit in seconds.
        azimuth_angle (float, optional): The azimuth angle in radians. If None, zero
            is used. Defaults to None.
        transmit_origin (ops.Tensor, optional): The origin of the transmit beam of shape (3,).
            If None, defaults to (0, 0, 0). Defaults to None.

    Returns:
        Tensor: The transmit delays of shape `(n_pix,)`.
    """
    # Fill in the optional inputs first
    if azimuth_angle is None:
        azimuth_angle = ops.zeros_like(polar_angle)
    if transmit_origin is None:
        transmit_origin = ops.zeros(3, dtype=grid.dtype)

    # Unit vector along the beam, from the polar and azimuth angles
    sin_polar = ops.sin(polar_angle)
    beam_direction = ops.stack(
        [
            sin_polar * ops.cos(azimuth_angle),
            sin_polar * ops.sin(azimuth_angle),
            ops.cos(polar_angle),
        ]
    )

    # Plane wave case: a focus distance of zero is mapped to infinity, so that the
    # first wavefront arrival is used for all pixels
    effective_focus = ops.where(focus_distance == 0.0, np.inf, focus_distance)

    # Focal point = origin + focus distance * beam direction, shape (3,). A negative
    # focus distance (diverging wave / virtual source) places it behind the origin.
    unguarded_focal_point = transmit_origin + effective_focus * beam_direction
    # An infinite focus distance times a zero direction component gives nan
    # (np.inf * 0.0 -> nan); those components are set to zero
    focal_point = ops.where(ops.isnan(unguarded_focal_point), 0.0, unguarded_focal_point)

    # Signed position of each pixel along the beam, measured from the focal point.
    # Positive: beyond the focus in the propagation direction. Negative: before it.
    offsets_from_focus = grid - focal_point[None, :]  # shape (n_pix, 3)
    along_beam = ops.sum(offsets_from_focus * beam_direction[None, :], axis=-1)  # (n_pix,)

    # Focused wave (positive focus distance): first arrival before the focus, last
    # arrival beyond it. For a diverging wave the sign of the focus distance flips
    # the test.
    focus_sign = ops.cast(ops.sign(effective_focus), "float32")
    is_before_focus = focus_sign * along_beam < 0.0

    # An infinite penalty on unused elements (zero apodization) keeps them from ever
    # being selected: +inf is added for the min and subtracted for the max
    unused_penalty = ops.where(tx_apodization == 0, np.inf, 0.0)[None, :]

    # Time from t=0 to each pixel via each element, shape (n_pix, n_el)
    arrival_times = rx_delays + t0_delays[None, :]
    first_arrival = ops.min(arrival_times + unused_penalty, axis=-1)
    last_arrival = ops.max(arrival_times - unused_penalty, axis=-1)

    # Pick per pixel, then remove the initial time offset of this transmit
    return ops.where(is_before_focus, first_arrival, last_arrival) - initial_time
[docs]
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
    """Apodization mask for the receive beamformer.

    Builds a per-pixel, per-element weight that suppresses pixels outside the
    vision cone of a transducer element. Elements only measure signals accurately
    within some range of incidence angles; waves arriving from the side do not
    register correctly and degrade the image.

    Args:
        flatgrid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
        probe_geometry (ops.Tensor): The transducer element positions of shape
            `(n_el, 3)`.
        f_number (float): The receive f-number, i.e. the ratio between the distance
            from the transducer and the size of the aperture below which elements
            contribute to the signal for a pixel. This function has no special case
            for zero; `tof_correction` skips the mask entirely in that case.
        fnum_window_fn (callable): F-number function to define the transition from
            straight in front of the element (fn(0.0)) to the largest angle within the
            f-number cone (fn(1.0)). The function should be zero for fn(x>1.0).

    Returns:
        Tensor: Mask of shape `(n_pix, n_el, 1)`
    """
    # Vector from every element to every pixel, shape (n_pix, n_el, 3)
    element_to_pixel = flatgrid[:, None] - probe_geometry[None]
    distance = ops.linalg.norm(element_to_pixel, axis=-1)

    # Angle between that vector and the z-axis; the small constant avoids 0/0
    cos_incidence = element_to_pixel[..., 2] / (distance + 1e-6)
    incidence_angle = ops.arccos(cos_incidence)

    # f_number = z / aperture = 1 / (2 * tan(alpha)), hence the largest allowed
    # angle is alpha = arctan(1 / (2 * f_number)). Epsilon guards the division.
    cone_half_angle = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))

    window = fnum_window_fn(incidence_angle / cone_half_angle)
    # Trailing channel axis so the mask broadcasts over the data channels
    return window[..., None]