"""This module contains image manipulation tools for the specsanalyzer package"""
from __future__ import annotations
from typing import Sequence
import numpy as np
import xarray as xr
[docs]
def gauss2d(
x: float | np.ndarray,
y: float | np.ndarray,
mx: float,
my: float,
sx: float,
sy: float,
) -> float | np.ndarray:
"""Function to calculate a 2-dimensional Gaussian peak function without correlation, and
amplitude 1.
Args:
x (float | np.ndarray): independent x-variable
y (float | np.ndarray): independent y-variable
mx (float): x-center of the 2D Gaussian
my (float): y-center of the 2D Gaussian
sx (float): Sigma in y direction
sy (float): Sigma in x direction
Returns:
float | np.ndarray: peak intensity at the given (x, y) coordinates.
"""
return np.exp(
-((x - mx) ** 2.0 / (2.0 * sx**2.0) + (y - my) ** 2.0 / (2.0 * sy**2.0)),
)
[docs]
def fourier_filter_2d(
image: np.ndarray,
peaks: Sequence[dict],
ret: str = "filtered",
) -> np.ndarray:
"""Function to Fourier filter an image for removal of regular pattern artifacts,
e.g. grid lines.
Args:
image (np.ndarray): the input image
peaks (Sequence[dict]): list of dicts containing the following information about a "peak"
in the Fourier image:
'pos_x', 'pos_y', sigma_x', sigma_y', 'amplitude'.
Define one entry for each feature you want to suppress in the Fourier image, where
amplitude 1 corresponds to full suppression.
ret (str, optional): flag to indicate which data to return. Possible values are:
'filtered', 'fft', 'mask', 'filtered_fft'. Defaults to "filtered"
Returns:
np.ndarray: The chosen image data. Default is the filtered real image.
"""
# Do Fourier Transform of the (real-valued) image
image_fft = np.fft.rfft2(image)
# shift fft axis to have 0 in the center
image_fft = np.fft.fftshift(image_fft, axes=0)
mask = np.ones(image_fft.shape)
xgrid, ygrid = np.meshgrid(
range(image_fft.shape[0]),
range(image_fft.shape[1]),
indexing="ij",
sparse=True,
)
for peak in peaks:
try:
mask -= peak["amplitude"] * gauss2d(
xgrid,
ygrid,
image_fft.shape[0] / 2 + peak["pos_x"],
peak["pos_y"],
peak["sigma_x"],
peak["sigma_y"],
)
except KeyError as exc:
raise KeyError(
f"The peaks input is supposed to be a list of dicts with the "
"following structure: pos_x, pos_y, sigma_x, sigma_y, amplitude.",
) from exc
# apply mask to the FFT, and transform back
filtered = np.fft.irfft2(np.fft.ifftshift(image_fft * mask, axes=0))
# strip negative values
filtered = filtered.clip(min=0)
if ret == "filtered":
return filtered
if ret == "fft":
return image_fft
if ret == "mask":
return mask
if ret == "filtered_fft":
return image_fft * mask
return filtered # default return
[docs]
def crop_xarray(
data_array: xr.DataArray,
x_min: float,
x_max: float,
y_min: float,
y_max: float,
) -> xr.DataArray:
"""Crops an xarray according to the provided coordinate boundaries.
Args:
data_array (xr.DataArray): the input xarray DataArray
x_min (float): the minimum position along the first element in the x-array dims list.
x_max (float): the maximum position along the first element in the x-array dims list.
y_min (float): the minimum position along the second element in the x-array dims list.
y_max (float): the maximum position along the second element in the x-array dims list.
Returns:
xr.DataArray: The cropped xarray DataArray.
"""
x_axis = data_array.coords[data_array.dims[0]]
y_axis = data_array.coords[data_array.dims[1]]
x_mask = (x_axis >= x_min) & (x_axis <= x_max)
y_mask = (y_axis >= y_min) & (y_axis <= y_max)
data_array_cropped = data_array.where(x_mask & y_mask, drop=True)
return data_array_cropped