Skip to content

Commit

Permalink
Merge branch 'omezarr' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
kttakasaki committed Mar 27, 2023
2 parents 3d5ad21 + 750e5b3 commit aace92c
Show file tree
Hide file tree
Showing 4 changed files with 570 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
import json
import pathlib
import shutil
from packaging import version

import argschema

from acpreprocessing.stitching_modules.convert_to_n5.tiff_to_n5 import (
tiffdir_to_n5_group,
N5GenerationParameters
)
from acpreprocessing.stitching_modules.convert_to_n5.tiff_to_ngff import tiffdir_to_ngff_group, NGFFGenerationParameters


def yield_position_paths_from_rootdir(
Expand All @@ -27,17 +24,40 @@ def yield_position_paths_from_rootdir(
yield root_path / stripdir_bn


def get_position_names_from_rootdir(
root_dir, stripjson_bn="hh.log",
stripjson_key="stripdirs"):
root_path = pathlib.Path(root_dir)
stripjson_path = root_path / stripjson_bn
with stripjson_path.open() as f:
stripjson_md = json.load(f)
return stripjson_md[stripjson_key]


def get_pixel_resolution_from_rootdir(
root_dir, md_bn="acqinfo_metadata.json"):
root_path = pathlib.Path(root_dir)
md_path = root_path / md_bn
with md_path.open() as f:
md = json.load(f)
xy = md["settings"]["pixel_spacing_um"]
z = md["positions"][1]["x_step_um"]
z = md["positions"][0]["x_step_um"]
return [xy, xy, z]


def get_strip_positions_from_rootdir(
root_dir, md_bn="acqinfo_metadata.json"):
root_path = pathlib.Path(root_dir)
md_path = root_path / md_bn
with md_path.open() as f:
md = json.load(f)

if version.parse(md["version"]) >= version.parse("0.0.3"):
return [(p["z_start_um"], p["y_start_um"], p["x_start_um"]) for p in md["positions"]]
else:
return [(0, p["y_start_um"], p["x_start_um"]) for p in md["positions"]]


def get_number_interleaved_channels_from_rootdir(
root_dir, md_bn="acqinfo_metadata.json"):
root_path = pathlib.Path(root_dir)
Expand All @@ -49,29 +69,34 @@ def get_number_interleaved_channels_from_rootdir(
return interleaved_channels


def acquisition_to_n5(acquisition_dir, out_dir, concurrency=5,
n5_generation_kwargs=None, copy_top_level_files=True):
def acquisition_to_ngff(acquisition_dir, output, out_dir, concurrency=5,
ngff_generation_kwargs=None, copy_top_level_files=True):
"""
"""
n5_generation_kwargs = (
{} if n5_generation_kwargs is None
else n5_generation_kwargs)
ngff_generation_kwargs = (
{} if ngff_generation_kwargs is None
else ngff_generation_kwargs)

acquisition_path = pathlib.Path(acquisition_dir)
out_path = pathlib.Path(out_dir)
out_n5_dir = str(out_path / f"{out_path.name}.n5")
if output == 'zarr':
output_dir = str(out_path / f"{out_path.name}.zarr")
else:
output_dir = str(out_path / f"{out_path.name}.n5")

interleaved_channels = get_number_interleaved_channels_from_rootdir(
acquisition_path)
positionList = get_strip_positions_from_rootdir(acquisition_path)

try:
setup_group_attributes = {
setup_group_attributes = [{
"pixelResolution": {
"dimensions": get_pixel_resolution_from_rootdir(
acquisition_path),
"unit": "um"
}
}
},
"position": p
} for p in positionList]
except (KeyError, FileNotFoundError):
setup_group_attributes = {}

Expand All @@ -85,16 +110,22 @@ def acquisition_to_n5(acquisition_dir, out_dir, concurrency=5,
# pos_group = pospath.name
# below is more like legacy structure
# out_n5_dir = str(out_path / f"{pos_group}.n5")
if output == 'zarr':
group_names = [pospath.name]
group_attributes = [setup_group_attributes[i]]
else:
group_names = [
f"channel{channel_idx}", f"setup{i}", "timepoint0"]
group_attributes = [channel_group_attributes,
setup_group_attributes[i]]

futs.append(e.submit(
tiffdir_to_n5_group,
str(pospath), out_n5_dir, [
f"channel{channel_idx}", f"setup{i}", "timepoint0"],
group_attributes=[
channel_group_attributes,
setup_group_attributes],
tiffdir_to_ngff_group,
str(pospath), output, output_dir, group_names,
group_attributes=group_attributes,
interleaved_channels=interleaved_channels,
channel=channel_idx,
**n5_generation_kwargs
**ngff_generation_kwargs
))

for fut in concurrent.futures.as_completed(futs):
Expand All @@ -109,33 +140,35 @@ def acquisition_to_n5(acquisition_dir, out_dir, concurrency=5,
shutil.copy(str(tlf_path), str(out_tlf_path))


class AcquisitionDirToN5DirParameters(
argschema.ArgSchema, N5GenerationParameters):
class AcquisitionDirToNGFFParameters(
argschema.ArgSchema, NGFFGenerationParameters):
input_dir = argschema.fields.Str(required=True)
output_dir = argschema.fields.Str(required=True)
output_format = argschema.fields.Str(required=True)
# output_dir = argschema.fields.Str(required=True)
copy_top_level_files = argschema.fields.Bool(required=False, default=True)
position_concurrency = argschema.fields.Int(required=False, default=5)


class AcquisitionDirToN5Dir(argschema.ArgSchemaParser):
default_schema = AcquisitionDirToN5DirParameters
class AcquisitionDirToNGFF(argschema.ArgSchemaParser):
default_schema = AcquisitionDirToNGFFParameters

def _get_n5_kwargs(self):
n5_keys = {
def _get_ngff_kwargs(self):
ngff_keys = {
"max_mip", "concurrency", "compression",
"lvl_to_mip_kwargs", "chunk_size", "mip_dsfactor"}
return {k: self.args[k] for k in (n5_keys & self.args.keys())}
"lvl_to_mip_kwargs", "chunk_size", "mip_dsfactor",
"deskew_options"}
return {k: self.args[k] for k in (ngff_keys & self.args.keys())}

def run(self):
n5_kwargs = self._get_n5_kwargs()
acquisition_to_n5(
self.args["input_dir"], self.args["output_dir"],
ngff_kwargs = self._get_ngff_kwargs()
acquisition_to_ngff(
self.args["input_dir"], self.args["output_format"], self.args["output_file"],
concurrency=self.args["position_concurrency"],
n5_generation_kwargs=n5_kwargs,
ngff_generation_kwargs=ngff_kwargs,
copy_top_level_files=self.args["copy_top_level_files"]
)
)


if __name__ == "__main__":
mod = AcquisitionDirToN5Dir()
mod.run()
mod = AcquisitionDirToNGFF()
mod.run()
154 changes: 154 additions & 0 deletions acpreprocessing/stitching_modules/convert_to_n5/psdeskew.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Pixel shift deskew
implements (chunked) pixel shifting deskew
skew_dims_zyx = dimensions of skewed (input) tiff data (xy are camera coordinates, z is tiff chunk #size, xz define skewed plane and y is non-skewed axis)
stride = number of camera (x) pixels to shift onto a sample (z') plane (sample z dim = camera x #dim/stride)
deskewFlip = flip volume (reflection, parity inversion)
dtype = datatype of input data
NOTE: must be run sequentially as each tiff chunk contains data for the next deskewed block #retained in self.slice1d except for the final chunk which should form the rhomboid edge
"""

import numpy as np


def psdeskew_kwargs(skew_dims_zyx, deskew_stride=1, deskew_flip=False, deskew_crop=1, dtype='uint16', **kwargs):
"""get keyword arguments for deskew_block
Parameters
----------
skew_dims_zyx : tuple of int
dimensions of raw data array block to be deskewed
stride : int
number of camera pixels per deskewed sampling plane (divides z resolution)
deskewFlip : bool
flip data blocks before deskewing
dtype : str
datatype for deskew output
crop_factor : float
reduce y dimension according to ydim*crop_factor < ydim
Returns
----------
kwargs : dict
parameters representing pixel deskew operation for deskew_block
"""
sdims = skew_dims_zyx
crop_factor = deskew_crop
stride = deskew_stride
ydim = int(sdims[1]*crop_factor)
blockdims = (int(sdims[2]/stride), ydim, stride*sdims[0])
subblocks = int(np.ceil((sdims[2]+stride*sdims[0])/(stride*sdims[0])))
# print(subblocks)
blockx = sdims[0]
dsi = []
si = []
for i_block in range(subblocks):
sxv = []
szv = []
for sz in range(blockx):
sxstart = i_block*stride*blockx-stride*sz
sxend = (i_block+1)*stride*blockx-stride*sz
if sxstart < 0:
sxstart = 0
if sxend > sdims[2]:
sxend = sdims[2]
sx = np.arange(sxstart, sxend)
sxv.append(sx)
szv.append(sz*np.ones(sx.shape, dtype=sx.dtype))
sxv = np.concatenate(sxv)
szv = np.concatenate(szv)
dsx = sxv + stride*szv - i_block*stride*blockx
dsz = np.floor(sxv/stride).astype(int)
dsi.append(np.ravel_multi_index(
(dsz, dsx), (blockdims[0], blockdims[2])))
si.append(np.ravel_multi_index((szv, sxv), (sdims[0], sdims[2])))
kwargs = {'dsi': dsi,
'si': si,
'slice1d': np.zeros((subblocks, blockdims[1], blockdims[2]*blockdims[0]), dtype=dtype),
'blockdims': blockdims,
'subblocks': subblocks,
'flip': deskew_flip,
'dtype': dtype,
'chunklength': blockx
}
return kwargs


def deskew_block(blockData, n, dsi, si, slice1d, blockdims, subblocks, flip, dtype, chunklength, *args, **kwargs):
"""deskew a data chunk in sequence with prior chunks
Parameters
----------
blockData : numpy.ndarray
block of raw (nondeskewed) data to be deskewed
n : int
current iteration in block sequence (must be run sequentially)
dsi : numpy.ndarray
deskewed indices for reslicing flattened data
si : numpy.ndarray
skewed indices for sampling flattened raw data
slice1d : numpy.ndarray
flattened data from previous iteration containing data for next deskewed block
blockdims : tuple of int
dimensions of output block
subblocks : int
number of partitions of input block for processing - likely not necessary
flip : bool
deskew flip
dtype : str
datatype
chunklength : int
number of slices expected for raw data block (for zero filling)
Returns
----------
block3d : numpy.ndarray
pixel shifted deskewed data ordered (z,y,x) by sample axes
"""
subb = subblocks
block3d = np.zeros(blockdims, dtype=dtype)
zdim = block3d.shape[0]
ydim = block3d.shape[1]
xdim = block3d.shape[2]
# crop blockData if needed
if blockData.shape[1] > ydim:
y0 = int(np.floor((blockData.shape[1]-ydim)/2))
y1 = int(np.floor((blockData.shape[1]+ydim)/2))
blockData = blockData[:, y0:y1, :]
#print('deskewing block ' + str(n) + ' with shape ' + str(blockData.shape))
if blockData.shape[0] < chunklength:
#print('block is short, filling with zeros')
blockData = np.concatenate((blockData, np.zeros(
(int(chunklength-blockData.shape[0]), blockData.shape[1], blockData.shape[2]))))
order = (np.arange(subb)+n) % subb
for y in range(ydim):
for i, o in enumerate(order):
# flip stack axis 2 for ispim2
s = -1 if flip else 1
slice1d[o, y, :][dsi[i]] = blockData[:, y, ::s].ravel()[si[i]]
block3d[:, y, :] = slice1d[n % subb, y, :].reshape((zdim, xdim))
slice1d[n % subb, y, :] = 0
return block3d


def reshape_joined_shapes(joined_shapes, stride, blockdims, *args, **kwargs):
"""get dimensions of deskewed joined shapes from skewed joined shapes
Parameters
----------
joined_shapes : tuple of int
shape of 3D array represented by concatenating mimg_fns
stride : int
number of camera pixels per deskewed sampling plane (divides z resolution)
blockdims : tuple of int
dimensions of output block
Returns
----------
deskewed_shape : tuple of int
shape of deskewed 3D array represented by joined_shapes
"""
deskewed_shape = (int(np.ceil(joined_shapes[0]/(blockdims[2]/stride))*blockdims[2]),
blockdims[1],
blockdims[0])
return deskewed_shape
Loading

0 comments on commit aace92c

Please sign in to comment.