from __future__ import print_function, division
from typing import Optional, Union, List
import numpy as np
from scipy import ndimage
from skimage.measure import label
from skimage.transform import resize
from skimage.morphology import dilation
from skimage.segmentation import watershed
from skimage.morphology import remove_small_objects
from connectomics.data.utils import getSegType
from .misc import bbox_ND, crop_ND
# Public API of this module: the decoding functions that turn model
# predictions into segmentation masks.
__all__ = ['binary_connected',
'binary_watershed',
'bc_connected',
'bc_watershed',
'bcd_watershed',
'polarity2instance']
# Post-processing functions for mitochondria instance segmentation model outputs,
# as described in "MitoEM Dataset: Large-scale 3D Mitochondria Instance Segmentation
# from EM Images" (MICCAI 2020, https://donglaiw.github.io/page/mitoEM/index.html).
def binary_connected(volume, thres=0.8, thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                     remove_small_mode='background'):
    r"""Decode a foreground probability map into instance masks by labeling
    the connected components of the thresholded foreground.

    Note:
        The threshold is multiplied by 255 before the comparison, so the
        prediction is presumably stored in the :math:`[0, 255]` range.

    Args:
        volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`;
            only the first channel is used.
        thres (float): threshold of foreground. Default: 0.8
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order.
            Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``.
            Default: ``'background'``

    Returns:
        numpy.ndarray: the instance mask after being passed through ``cast2dtype``.
    """
    probability = volume[0]
    cutoff = int(255*thres)

    instances = label(probability > cutoff)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    # Resize only when at least one axis actually needs rescaling.
    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(probability.shape[axis]*scale_factors[axis])
                          for axis in range(3))
        # order=0 (nearest neighbor) keeps the label ids intact.
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def binary_watershed(volume, thres1=0.98, thres2=0.85, thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                     remove_small_mode='background', seed_thres=32):
    r"""Decode a foreground probability map into instance masks with a
    seeded watershed.

    Seeds are the connected components of the high-confidence region
    (:attr:`thres1`); they are grown over the negated probability map inside
    the foreground region (:attr:`thres2`).

    Note:
        This function uses the `skimage.segmentation.watershed <https://github.com/scikit-image/scikit-image/blob/master/skimage/segmentation/_watershed.py#L89>`_
        function that converts the input image into ``np.float64`` data type for
        processing. Therefore please make sure enough memory is allocated when
        handling large arrays.

    Args:
        volume (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)`;
            only the first channel is used.
        thres1 (float): threshold of seeds. Default: 0.98
        thres2 (float): threshold of foreground. Default: 0.85
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order.
            Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``.
            Default: ``'background'``
        seed_thres (int): size threshold passed to ``remove_small_objects`` for
            discarding small seeds. Default: 32

    Returns:
        numpy.ndarray: the instance mask after being passed through ``cast2dtype``.
    """
    probability = volume[0]
    seed_region = probability > int(255*thres1)
    fg_mask = probability > int(255*thres2)

    # Label the candidate seeds, then drop the ones that are too small.
    markers = remove_small_objects(label(seed_region), seed_thres)

    # Negate the probability so that confident regions become basins.
    landscape = -probability.astype(np.float64)
    instances = watershed(landscape, markers, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(probability.shape[axis]*scale_factors[axis])
                          for axis in range(3))
        # order=0 (nearest neighbor) keeps the label ids intact.
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def bc_connected(volume, thres1=0.8, thres2=0.5, thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                 dilation_struct=(1,5,5), remove_small_mode='background'):
    r"""Decode foreground probability and instance contour maps into instance
    masks via connected-component labeling.

    Note:
        The instance contour provides additional supervision to distinguish closely
        touching objects. However, the decoding algorithm only keeps the intersection
        of foreground and non-contour regions, which systematically results in
        incomplete instance masks. Therefore morphological dilation (check
        :attr:`dilation_struct`) is applied to enlarge the object masks.

    Args:
        volume (numpy.ndarray): foreground and contour probability of shape
            :math:`(C, Z, Y, X)`; channel 0 is the foreground and channel 1 the contour.
        thres1 (float): threshold of foreground. Default: 0.8
        thres2 (float): threshold of instance contours. Default: 0.5
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order.
            Default: (1.0, 1.0, 1.0)
        dilation_struct (tuple): the shape of the structure for morphological dilation.
            Default: (1, 5, 5)
        remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``.
            Default: ``'background'``

    Returns:
        numpy.ndarray: the instance mask after being passed through ``cast2dtype``.
    """
    probability, contour = volume[0], volume[1]

    # Keep voxels that are confidently foreground and not on a contour.
    is_foreground = probability > int(255*thres1)
    off_contour = contour < int(255*thres2)
    instances = label(is_foreground & off_contour)

    # Grow the labeled objects back over the contour band that was cut away.
    footprint = np.ones(dilation_struct)
    instances = dilation(instances, footprint)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(probability.shape[axis]*scale_factors[axis])
                          for axis in range(3))
        # order=0 (nearest neighbor) keeps the label ids intact.
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def bc_watershed(volume, thres1=0.9, thres2=0.8, thres3=0.85, thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                 remove_small_mode='background', seed_thres=32):
    r"""Decode foreground probability and instance contour maps into instance
    masks with a seeded watershed.

    Seeds are the connected components of the region that is confidently
    foreground (:attr:`thres1`) and below the contour threshold (:attr:`thres2`);
    they are grown over the negated foreground probability inside the
    foreground region (:attr:`thres3`).

    Note:
        This function uses the `skimage.segmentation.watershed <https://github.com/scikit-image/scikit-image/blob/master/skimage/segmentation/_watershed.py#L89>`_
        function that converts the input image into ``np.float64`` data type for
        processing. Therefore please make sure enough memory is allocated when
        handling large arrays.

    Args:
        volume (numpy.ndarray): foreground and contour probability of shape
            :math:`(C, Z, Y, X)` with exactly two channels.
        thres1 (float): threshold of seeds. Default: 0.9
        thres2 (float): threshold of instance contours. Default: 0.8
        thres3 (float): threshold of foreground. Default: 0.85
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order.
            Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``.
            Default: ``'background'``
        seed_thres (int): size threshold passed to ``remove_small_objects`` for
            discarding small seeds. Default: 32

    Returns:
        numpy.ndarray: the instance mask after being passed through ``cast2dtype``.
    """
    assert volume.shape[0] == 2
    probability, contour = volume[0], volume[1]

    confident = probability > int(255*thres1)
    off_contour = contour < int(255*thres2)
    seed_region = confident & off_contour
    fg_mask = probability > int(255*thres3)

    # Label the candidate seeds, then drop the ones that are too small.
    markers = remove_small_objects(label(seed_region), seed_thres)

    # Negate the probability so that confident regions become basins.
    landscape = -probability.astype(np.float64)
    instances = watershed(landscape, markers, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(probability.shape[axis]*scale_factors[axis])
                          for axis in range(3))
        # order=0 (nearest neighbor) keeps the label ids intact.
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def bcd_watershed(volume, thres1=0.9, thres2=0.8, thres3=0.85, thres4=0.5, thres5=0.0, thres_small=128,
                  scale_factors=(1.0, 1.0, 1.0), remove_small_mode='background', seed_thres=32):
    r"""Decode foreground probability, instance contour and signed distance
    maps into instance masks with a seeded watershed.

    The distance channel is first mapped linearly from :math:`[0, 255]` to
    :math:`[-1, 1]`. Seeds must pass the foreground (:attr:`thres1`), contour
    (:attr:`thres2`) and distance (:attr:`thres4`) tests; the watershed mask
    must pass the foreground (:attr:`thres3`) and distance (:attr:`thres5`) tests.

    Note:
        This function uses the `skimage.segmentation.watershed <https://github.com/scikit-image/scikit-image/blob/master/skimage/segmentation/_watershed.py#L89>`_
        function that converts the input image into ``np.float64`` data type for
        processing. Therefore please make sure enough memory is allocated when
        handling large arrays.

    Args:
        volume (numpy.ndarray): foreground probability, contour probability and
            signed distance of shape :math:`(C, Z, Y, X)` with exactly three channels.
        thres1 (float): threshold of seeds. Default: 0.9
        thres2 (float): threshold of instance contours. Default: 0.8
        thres3 (float): threshold of foreground. Default: 0.85
        thres4 (float): threshold of signed distance for locating seeds. Default: 0.5
        thres5 (float): threshold of signed distance for foreground. Default: 0.0
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order.
            Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): ``'background'``, ``'neighbor'`` or ``'none'``.
            Default: ``'background'``
        seed_thres (int): size threshold passed to ``remove_small_objects`` for
            discarding small seeds. Default: 32

    Returns:
        numpy.ndarray: the instance mask after being passed through ``cast2dtype``.
    """
    assert volume.shape[0] == 3
    probability = volume[0]
    contour = volume[1]
    # Map the stored distance values from [0, 255] to [-1, 1].
    signed_dist = (volume[2] / 255.0) * 2.0 - 1.0

    confident = probability > int(255*thres1)
    off_contour = contour < int(255*thres2)
    seed_region = confident & off_contour & (signed_dist > thres4)
    fg_mask = (probability > int(255*thres3)) & (signed_dist > thres5)

    # Label the candidate seeds, then drop the ones that are too small.
    markers = remove_small_objects(label(seed_region), seed_thres)

    # Negate the probability so that confident regions become basins.
    landscape = -probability.astype(np.float64)
    instances = watershed(landscape, markers, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(probability.shape[axis]*scale_factors[axis])
                          for axis in range(3))
        # order=0 (nearest neighbor) keeps the label ids intact.
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
# Post-processing functions for synaptic polarity model outputs, as described
# in "Two-Stream Active Query Suggestion for Active Learning in Connectomics"
# (ECCV 2020, https://zudi-lin.github.io/projects/#two_stream_active).
def polarity2instance(volume, thres=0.5, thres_small=128,
                      scale_factors=(1.0, 1.0, 1.0), semantic=False):
    r"""From synaptic polarity prediction to instance masks via connected-component
    labeling. The input volume should be a 3-channel probability map of shape :math:`(C, Z, Y, X)`
    where :math:`C=3`, representing pre-synaptic region, post-synaptic region and their
    union, respectively.

    Note:
        For each pair of pre- and post-synaptic segmentation, the decoding function will
        annotate pre-synaptic region as :math:`2n-1` and post-synaptic region as :math:`2n`,
        for :math:`n>0`. If :attr:`semantic=True`, all pre-synaptic pixels are labeled with 1
        while all post-synaptic pixels are labeled with 2. Both kinds of annotation are
        compatible with the ``TARGET_OPT: ['1']`` configuration in training.

    Note:
        The number of pre- and post-synaptic segments will be reported when setting
        :attr:`semantic=False`. Note that the numbers can be different due to either
        incomplete synapses touching the volume borders, or errors in the prediction.
        We thus make a conservative estimate of the total number of synapses by using
        the smaller of the two numbers.

    Args:
        volume (numpy.ndarray): 3-channel probability map of shape :math:`(3, Z, Y, X)`.
        thres (float): probability threshold of foreground. Default: 0.5
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing the output volume in
            :math:`(Z, Y, X)` order. Default: :math:`(1.0, 1.0, 1.0)`
        semantic (bool): return only the semantic mask of pre- and post-synaptic
            regions. Default: False

    Returns:
        numpy.ndarray: the semantic or instance mask after being passed through
        ``cast2dtype``.

    Examples::
        >>> from connectomics.data.utils import readvol, savevol
        >>> from connectomics.utils.processing import polarity2instance
        >>> volume = readvol(input_name)
        >>> instances = polarity2instance(volume)
        >>> savevol(output_name, instances)
    """
    # The threshold is rescaled by 255 before it is compared with the volume.
    thres = int(255.0 * thres)
    binary = volume > thres

    # remove_small_objects must receive *boolean* masks here: an integer 0/1
    # array is interpreted as an already-labeled image, in which case the whole
    # mask counts as a single object and small components are never removed.
    syn_pre = np.logical_and(binary[0], binary[2])
    syn_pre = remove_small_objects(syn_pre,
                                   min_size=thres_small, connectivity=1)
    syn_post = np.logical_and(binary[1], binary[2])
    syn_post = remove_small_objects(syn_post,
                                    min_size=thres_small, connectivity=1)

    if semantic:
        # Generate only the semantic mask. The pre-synaptic region is labeled
        # with 1, while the post-synaptic region is labeled with 2.
        segm = np.maximum(syn_pre.astype(np.uint8),
                          syn_post.astype(np.uint8) * 2)
    else:
        # Generate the instance mask. The union channel is dilated in-plane
        # before labeling so that each labeled component defines one synapse id.
        foreground = dilation(binary[2].astype(np.uint8), np.ones((1,5,5)))
        foreground = label(foreground)

        # Since non-zero pixels in syn_pre and syn_post are subsets of the
        # union channel, they are naturally subsets of non-zero pixels in
        # foreground, so the products below never yield negative labels.
        seg_pre = (foreground*2 - 1) * syn_pre.astype(foreground.dtype)
        seg_post = (foreground*2) * syn_post.astype(foreground.dtype)
        segm = np.maximum(seg_pre, seg_post)

        # Report the number of synapses. Count the non-zero ids explicitly
        # instead of assuming that the background id 0 is always present.
        num_syn_pre = int(np.count_nonzero(np.unique(seg_pre)))
        num_syn_post = int(np.count_nonzero(np.unique(seg_post)))
        num_syn = min(num_syn_pre, num_syn_post) # a conservative estimate
        print("Stats: found %d pre- and %d post-" % (num_syn_pre, num_syn_post) +
              "synaptic segments in the volume")
        print("There are %d synapses under a conservative estimate." % num_syn)

    # resize the segmentation based on specified scale factors
    if not all(x==1.0 for x in scale_factors):
        target_size = (int(segm.shape[0]*scale_factors[0]),
                       int(segm.shape[1]*scale_factors[1]),
                       int(segm.shape[2]*scale_factors[2]))
        # order=0 (nearest neighbor) keeps the label ids intact.
        segm = resize(segm, target_size, order=0, anti_aliasing=False, preserve_range=True)

    return cast2dtype(segm)
# utils for post-processing
def binarize_and_median(pred, size=(7,7,7), thres=0.8):
    """Threshold the prediction into a binary mask, then smooth the mask
    with a median filter to reduce noise.

    Args:
        pred (numpy.ndarray): predicted foreground probability within (0,1).
        size (tuple): kernel size of filtering. Default: (7,7,7)
        thres (float): threshold for binarizing the prediction. Default: 0.8

    Returns:
        numpy.ndarray: the filtered binary mask (``np.uint8``).
    """
    above = pred > thres
    binary = above.astype(np.uint8)
    return ndimage.median_filter(binary, size=size)
def remove_small_instances(segm: np.ndarray,
                           thres_small: int = 128,
                           mode: str = 'background'):
    """Remove small spurious instances from a segmentation mask.

    Args:
        segm (numpy.ndarray): the segmentation mask.
        thres_small (int): size threshold of small objects. Default: 128
        mode (str): how small objects are handled. ``'none'`` returns the input
            untouched; ``'background'`` calls ``remove_small_objects``;
            ``'neighbor'`` calls ``merge_small_objects``. The ``'_2d'`` variants
            apply the same operation to every slice along the first axis and
            stack the results. Default: ``'background'``

    Returns:
        numpy.ndarray: the processed segmentation mask.
    """
    valid_modes = ('none',
                   'background',
                   'background_2d',
                   'neighbor',
                   'neighbor_2d')
    assert mode in valid_modes

    if mode == 'none':
        return segm

    slice_wise = mode.endswith('_2d')
    if mode.startswith('background'):
        def cleanup(mask):
            return remove_small_objects(mask, thres_small)
    elif slice_wise:
        def cleanup(mask):
            return merge_small_objects(mask, thres_small)
    else:
        def cleanup(mask):
            return merge_small_objects(mask, thres_small, do_3d=True)

    if not slice_wise:
        return cleanup(segm)

    # Process each slice along the first axis independently.
    slices = [cleanup(segm[z]) for z in range(segm.shape[0])]
    return np.stack(slices, axis=0)
def merge_small_objects(segm, thres_small, do_3d=False):
    """Merge every small object into the label that dominates its border.

    For each label whose voxel count is below :attr:`thres_small`, the labels
    found in the one-step dilation ring around the object are counted, and the
    object is relabeled with the most frequent one (non-zero labels take
    precedence over 0 whenever any non-zero label touches the object).

    Note:
        ``segm`` is modified in place and also returned. Label sizes are
        computed once up front and are not updated as objects get merged.

    Args:
        segm (numpy.ndarray): the segmentation mask.
        thres_small (int): objects with fewer voxels than this are merged.
        do_3d (bool): use a ``(1, 3, 3)`` structure (for 3D input) instead of
            a ``(3, 3)`` one (for 2D input). Default: False

    Returns:
        numpy.ndarray: the input array with small objects relabeled.
    """
    struct = np.ones((1,3,3)) if do_3d else np.ones((3,3))
    indices, counts = np.unique(segm, return_counts=True)

    for idx, count in zip(indices, counts):
        if count >= thres_small:
            continue

        temp = (segm == idx).astype(np.uint8)
        # NOTE(review): relax=2 presumably pads the bounding box so that the
        # dilation ring fits inside the crop - confirm against bbox_ND.
        coord = bbox_ND(temp, relax=2)
        cropped = crop_ND(temp, coord)

        # Ring of voxels directly bordering the object.
        diff = dilation(cropped, struct) - cropped

        # Work on an explicit copy: should crop_ND hand back a view, zeroing
        # the crop in place would wipe unrelated labels inside segm itself.
        diff_segm = np.copy(crop_ND(segm, coord))
        diff_segm[diff == 0] = 0

        u, ct = np.unique(diff_segm, return_counts=True)
        # Ignore the zeros (masked-out voxels and background) when any other
        # label is present in the ring.
        if len(u) > 1 and u[0] == 0:
            u, ct = u[1:], ct[1:]

        segm[segm == idx] = u[np.argmax(ct)]

    return segm
def remove_large_instances(segm: np.ndarray,
                           max_size: int = 2000):
    """Remove large instances given a maximum size threshold.

    Every label that covers more than :attr:`max_size` voxels is set to 0 in a
    copy of the input; the input itself is left untouched.

    Args:
        segm (numpy.ndarray): the segmentation mask.
        max_size (int): largest number of voxels an instance may have.
            Default: 2000

    Returns:
        numpy.ndarray: a copy of ``segm`` with the oversized instances zeroed.
    """
    out = np.copy(segm)
    # np.unique is used instead of np.bincount so that uint64 label arrays
    # (which bincount refuses to cast safely) are supported, and so that the
    # memory needed does not grow with the largest label id.
    _, inverse, counts = np.unique(segm, return_inverse=True,
                                   return_counts=True)
    too_large = counts > max_size
    # The shape of `inverse` differs between NumPy versions; normalize it.
    too_large_mask = too_large[inverse].reshape(segm.shape)
    out[too_large_mask] = 0
    return out
def cast2dtype(segm):
    """Cast the segmentation mask to the best dtype to save storage.

    The target dtype is chosen by ``getSegType`` from the largest value in
    the mask.

    Args:
        segm (numpy.ndarray): the segmentation mask.

    Returns:
        numpy.ndarray: ``segm`` converted to the selected dtype.
    """
    # The maximum of the array equals the maximum of its unique values, so
    # the full sort performed by np.unique is not needed here.
    max_id = np.amax(segm)
    m_type = getSegType(int(max_id))
    return segm.astype(m_type)