Skip to content

Commit

Permalink
Merge pull request #416 from juaml/fix/correct-interpolation-for-mask
Browse files Browse the repository at this point in the history
[ENH]: Add helper function for getting correct interpolator for masks
  • Loading branch information
synchon authored Dec 9, 2024
2 parents 1102920 + ea9f435 commit 972e208
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 21 deletions.
23 changes: 22 additions & 1 deletion junifer/data/masks/_ants_mask_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any, Optional

import nibabel as nib
import numpy as np

from ...pipeline import WorkDirManager
from ...utils import logger, raise_error, run_ext_cmd
Expand All @@ -20,6 +21,26 @@
__all__ = ["ANTsMaskWarper"]


def _get_interpolation_method(img: "Nifti1Image") -> str:
"""Get correct interpolation method for `img`.
Parameters
----------
img : nibabel.nifti1.Nifti1Image
The image.
Returns
-------
str
The interpolation method.
"""
if np.array_equal(np.unique(img.get_fdata()), [0, 1]):
return "'GenericLabel[NearestNeighbor]'"
else:
return "LanczosWindowedSinc"


class ANTsMaskWarper:
"""Class for mask space warping via ANTs.
Expand Down Expand Up @@ -143,7 +164,7 @@ def warp(
"antsApplyTransforms",
"-d 3",
"-e 3",
"-n 'GenericLabel[NearestNeighbor]'",
f"-n {_get_interpolation_method(mask_img)}",
f"-i {prewarp_mask_path.resolve()}",
f"-r {template_space_img_path.resolve()}",
f"-t {xfm_file_path.resolve()}",
Expand Down
23 changes: 22 additions & 1 deletion junifer/data/masks/_fsl_mask_warper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any

import nibabel as nib
import numpy as np

from ...pipeline import WorkDirManager
from ...utils import logger, run_ext_cmd
Expand All @@ -19,6 +20,26 @@
__all__ = ["FSLMaskWarper"]


def _get_interpolation_method(img: "Nifti1Image") -> str:
"""Get correct interpolation method for `img`.
Parameters
----------
img : nibabel.nifti1.Nifti1Image
The image.
Returns
-------
str
The interpolation method.
"""
if np.array_equal(np.unique(img.get_fdata()), [0, 1]):
return "nn"
else:
return "spline"


class FSLMaskWarper:
"""Class for mask space warping via FSL FLIRT.
Expand Down Expand Up @@ -71,7 +92,7 @@ def warp(
# Set applywarp command
applywarp_cmd = [
"applywarp",
"--interp=nn",
f"--interp={_get_interpolation_method(mask_img)}",
f"-i {prewarp_mask_path.resolve()}",
# use resampled reference
f"-r {target_data['reference']['path'].resolve()}",
Expand Down
59 changes: 40 additions & 19 deletions junifer/data/masks/_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,45 +165,44 @@ def compute_brain_mask(
)

mask_name = f"template_{target_std_space}_for_compute_brain_mask"

# Warp template to correct space (MNI to MNI)
if template_space != "native" and template_space != target_std_space:
logger.debug(
f"Warping template to {target_std_space} space using ANTs."
)
template = ANTsMaskWarper().warp(
mask_name=mask_name,
mask_img=template,
src=template_space,
dst=target_std_space,
target_data=target_data,
warp_data=None,
)

# Resample and warp template if target space is native
if target_data["space"] == "native" and template_space != "native":
if warp_data["warper"] == "fsl":
resampled_template = FSLMaskWarper().warp(
mask_name=f"template_{target_std_space}_for_compute_brain_mask",
mask_name=mask_name,
mask_img=template,
target_data=target_data,
warp_data=warp_data,
)
elif warp_data["warper"] == "ants":
resampled_template = ANTsMaskWarper().warp(
mask_name=f"template_{target_std_space}_for_compute_brain_mask",
mask_name=mask_name,
# use template here
mask_img=template,
src=target_std_space,
dst="native",
target_data=target_data,
warp_data=warp_data,
)
# Resample template to target image
else:
# Warp template to correct space
if template_space != target_std_space:
logger.debug(
f"Warping template to {target_std_space} space using ANTs."
)
template = ANTsMaskWarper().warp(
mask_name=mask_name,
mask_img=template,
src=template_space,
dst=target_std_space,
target_data=target_data,
warp_data=None,
)
# Resample template to target image
resampled_template = nimg.resample_to_img(
source_img=template, target_img=target_data["data"]
source_img=template,
target_img=target_data["data"],
interpolation=_get_interpolation_method(template),
)

# Threshold resampled template and get mask
Expand Down Expand Up @@ -561,6 +560,7 @@ def get( # noqa: C901
mask_img = nimg.resample_to_img(
source_img=mask_img,
target_img=target_data["data"],
interpolation=_get_interpolation_method(mask_img),
)
# Starting with new mask
else:
Expand Down Expand Up @@ -632,6 +632,7 @@ def get( # noqa: C901
mask_img = nimg.resample_to_img(
source_img=mask_img,
target_img=target_img,
interpolation=_get_interpolation_method(mask_img),
)
else:
# Warp mask if target space is native as
Expand Down Expand Up @@ -761,3 +762,23 @@ def _load_ukb_mask(name: str) -> Path:
mask_fname = _masks_path / "ukb" / mask_fname

return mask_fname


def _get_interpolation_method(img: "Nifti1Image") -> str:
"""Get correct interpolation method for `img`.
Parameters
----------
img : nibabel.nifti1.Nifti1Image
The image.
Returns
-------
str
The interpolation method.
"""
if np.array_equal(np.unique(img.get_fdata()), [0, 1]):
return "nearest"
else:
return "continuous"

0 comments on commit 972e208

Please sign in to comment.