Skip to content

Commit

Permalink
split out widgets for ease of use
Browse files Browse the repository at this point in the history
  • Loading branch information
brisvag committed Oct 13, 2023
1 parent 7f4ab19 commit e6bcf74
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 273 deletions.
23 changes: 3 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ dependencies = [
"cryotypes>=0.2.0",
"einops",
"morphosamplers>=0.0.6",
"pydantic<2", # migration will take a while for napari
]

# extras
# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
[project.optional-dependencies]
all = [
"napari[all]>=0.4.18",
"napari[all]>=0.4.19",
"napari-properties-plotter",
"napari-properties-viewer",
"napari-label-interpolator",
Expand All @@ -51,7 +52,7 @@ test = [
"pytest>=6.0",
"pytest-cov",
"pytest-qt",
"napari[all]>=0.4.18",
"napari[all]>=0.4.19",
]
dev = [
"blik[test]",
Expand Down Expand Up @@ -139,12 +140,6 @@ disallow_subclassing_any = false
show_error_codes = true
pretty = true

# # module specific overrides
# [[tool.mypy.overrides]]
# module = ["numpy.*",]
# ignore_errors = true


# https://coverage.readthedocs.io/en/6.4/config.html
[tool.coverage.report]
exclude_lines = [
Expand All @@ -166,15 +161,3 @@ ignore = [
".ruff_cache/**/*",
"tests/**/*",
]

# # for things that require compilation
# # https://cibuildwheel.readthedocs.io/en/stable/options/
# [tool.cibuildwheel]
# # Skip 32-bit builds & PyPy wheels on all platforms
# skip = ["*-manylinux_i686", "*-musllinux_i686", "*-win32", "pp*"]
# test-extras = ["test"]
# test-command = "pytest {project}/tests -v"
# test-skip = "*-musllinux*"

# [tool.cibuildwheel.environment]
# HATCH_BUILD_HOOKS_ENABLE = "1"
10 changes: 10 additions & 0 deletions src/blik/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ contributions:
- id: blik.main_widget
python_name: blik.widgets.main_widget:MainBlikWidget
title: Create blik widget
- id: blik.filament_picking
python_name: blik.widgets.picking:FilamentWidget
title: Create filament picking widget
- id: blik.surface_picking
python_name: blik.widgets.picking:SurfaceWidget
title: Create surface picking widget
- id: blik.file_reader
python_name: blik.widgets.file_reader:file_reader
title: Create reader widget
Expand Down Expand Up @@ -88,6 +94,10 @@ contributions:
widgets:
- command: blik.main_widget
display_name: main widget
- command: blik.filament_picking
display_name: filament widget
- command: blik.surface_picking
display_name: surface widget
- command: blik.file_reader
display_name: file reader
- command: blik.bandpass_filter
Expand Down
3 changes: 3 additions & 0 deletions src/blik/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _construct_positions_layer(
"shading": "spherical",
"antialiasing": 0,
"metadata": {"experiment_id": exp_id, "p_id": p_id, "source": source},
"projection_mode": "all",
"out_of_slice_display": True,
**pt_kwargs,
},
Expand All @@ -62,6 +63,7 @@ def _construct_orientations_layer(coords, features, scale, exp_id, p_id, source)
"length": 150 / np.array(scale),
"scale": [scale] * 3,
"metadata": {"experiment_id": exp_id, "p_id": p_id, "source": source},
"projection_mode": "all",
"out_of_slice_display": True,
},
"vectors",
Expand Down Expand Up @@ -149,6 +151,7 @@ def read_image(image):
"depiction": "plane",
"blending": "translucent",
"plane": {"thickness": 5},
"projection_mode": "mean",
},
"image",
)
Expand Down
254 changes: 8 additions & 246 deletions src/blik/widgets/main_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
from scipy.spatial.transform import Rotation

from ..reader import construct_particle_layer_tuples
from ..utils import generate_vectors, invert_xyz, layer_tuples_to_layers
from .picking import (
_generate_filaments_from_points_layer,
_generate_surface_grids_from_shapes_layer,
_resample_filament,
_resample_surfaces,
)
from ..utils import generate_vectors, layer_tuples_to_layers


def _get_choices(wdg, condition=None):
Expand Down Expand Up @@ -209,7 +203,7 @@ def new(
if isinstance(lay, Image) and lay.metadata["experiment_id"] == exp_id:
pts = Points(
name=f"{exp_id} - filament picks",
size=100 / lay.scale[0],
size=20 / lay.scale[0],
scale=lay.scale,
metadata={"experiment_id": exp_id},
face_color_cycle=np.random.rand(30, 3),
Expand All @@ -223,238 +217,12 @@ def new(


@magicgui(
labels=True,
call_button="Generate",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
inside_points={"nullable": True},
)
def surface(
surface_shapes: napari.layers.Shapes,
inside_points: napari.layers.Points,
spacing_A=100,
closed=False,
) -> napari.types.LayerDataTuple:
"""create a new surface representation from picked surface points."""
surface_grids, colors = _generate_surface_grids_from_shapes_layer(
surface_shapes, spacing_A
)

meshes = []
exp_id = surface_shapes.metadata["experiment_id"]

for surf in surface_grids:
meshes.append(surf.mesh())

offset = 0
vert = []
faces = []
ids = []
for surf_id, (v, f) in enumerate(meshes):
f += offset
offset += len(v)
vert.append(v)
faces.append(f)
ids.append(np.full(len(v), surf_id))
vert = np.concatenate(vert)
faces = np.concatenate(faces)
uniq_colors, idx = np.unique(colors, axis=0, return_index=True)
colormap = uniq_colors[np.argsort(idx)]
values = np.concatenate(ids) / len(colormap)
# special case for colormap with 1 color because blacks get autoadded at index 0
if colormap.shape[0] == 1:
values += 1

surface_layer_tuple = (
(invert_xyz(vert), faces, values),
{
"name": f"{exp_id} - surface",
"metadata": {
"experiment_id": exp_id,
"surface_grids": surface_grids,
"surface_colors": colors,
},
"scale": surface_shapes.scale,
"shading": "smooth",
"colormap": colormap,
},
"surface",
)
return [surface_layer_tuple]


@magicgui(
labels=True,
call_button="Generate",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
)
def surface_particles(
surface: napari.layers.Surface,
spacing_A=50,
) -> napari.types.LayerDataTuple:
surface_grids = surface.metadata.get("surface_grids", None)
if surface_grids is None:
raise ValueError("This surface layer contains no surface grid object.")
colors = surface.metadata.get("surface_colors")

exp_id = surface.metadata["experiment_id"]
spacing = spacing_A / surface.scale[0]

pos = []
ori = []
for surf in surface_grids:
if not np.isclose(surf.separation, spacing):
surf.separation = spacing
pos.append(surf.sample())
ori.append(surf.sample_orientations())

pos = np.concatenate(pos)
features = pd.DataFrame({"orientation": np.asarray(Rotation.concatenate(ori))})

vec_layer_tuple, pos_layer_tuple = construct_particle_layer_tuples(
coords=invert_xyz(pos),
features=features,
scale=surface.scale[0],
exp_id=exp_id,
face_color_cycle=colors,
)
return [vec_layer_tuple, pos_layer_tuple]


@magicgui(
labels=True,
call_button="Resample",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
thickness_A={"widget_type": "Slider", "min": 1, "max": 500},
)
def resample_surface(
surface: napari.layers.Surface,
volume: napari.layers.Image,
spacing_A=5,
thickness_A=200,
masked=False,
) -> napari.types.LayerDataTuple:
surface_grids = surface.metadata.get("surface_grids", None)
if surface_grids is None:
raise ValueError("This surface layer contains no surface grid object.")

exp_id = surface.metadata["experiment_id"]
spacing = spacing_A / surface.scale[0]
thickness = int(np.round(thickness_A / surface.scale[0]))
for surf in surface_grids:
if not np.isclose(surf.separation, spacing):
surf.separation = spacing

vols = _resample_surfaces(volume, surface_grids, spacing, thickness, masked)

v = napari.Viewer()
for i, vol in enumerate(vols):
v.add_image(
vol,
name=f"{exp_id} - surface_{i} resampled",
metadata={"experiment_id": exp_id},
scale=surface.scale,
)


@magicgui(
labels=True,
call_button="Generate",
)
def filament(
points: napari.layers.Points,
) -> napari.types.LayerDataTuple:
filament = _generate_filaments_from_points_layer(points)

exp_id = points.metadata["experiment_id"]

path = filament.sample(n_samples=len(points.data) * 50)
shapes_layer_tuple = (
[invert_xyz(path)],
{
"name": f"{exp_id} - filament",
"metadata": {
"experiment_id": exp_id,
"helical_filament": filament,
},
"scale": points.scale,
"shape_type": "path",
},
"shapes",
)
return [shapes_layer_tuple]


@magicgui(
labels=True,
call_button="Generate",
rise_A={"widget_type": "Slider", "min": 1, "max": 500},
radius_A={"widget_type": "Slider", "min": 0, "max": 500},
twist_deg={"widget_type": "Slider", "min": 0, "max": 360},
twist_offset={"widget_type": "Slider", "min": 0, "max": 360},
)
def filament_particles(
filament: napari.layers.Shapes,
rise_A=50,
twist_deg=0,
twist_offset=0,
radius_A=0,
cyclic_symmetry_order=1,
) -> napari.types.LayerDataTuple:
helical_filament = filament.metadata.get("helical_filament", None)
if helical_filament is None:
raise ValueError("This shapes layer contains no helical filament object.")

exp_id = filament.metadata["experiment_id"]

pos, ori = helical_filament.sample_helical(
rise=rise_A / filament.scale[0],
twist=twist_deg,
radial_offset=radius_A / filament.scale[0],
cyclic_symmetry_order=cyclic_symmetry_order,
twist_offset=twist_offset,
degrees=True,
)

features = pd.DataFrame({"orientation": np.asarray(Rotation.concatenate(ori))})

vec_layer_tuple, pos_layer_tuple = construct_particle_layer_tuples(
coords=invert_xyz(pos),
features=features,
scale=filament.scale[0],
exp_id=exp_id,
)
return [vec_layer_tuple, pos_layer_tuple]


@magicgui(
labels=True,
call_button="Resample",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
thickness_A={"widget_type": "Slider", "min": 1, "max": 500},
labels=False,
auto_call=True,
thickness_A={"widget_type": "FloatSlider", "min": 0, "max": 500},
)
def resample_filament(
filament: napari.layers.Shapes,
volume: napari.layers.Image,
spacing_A=5,
thickness_A=200,
) -> napari.types.LayerDataTuple:
helical_filament = filament.metadata.get("helical_filament", None)
if helical_filament is None:
raise ValueError("This shapes layer contains no helical filament object.")

exp_id = filament.metadata["experiment_id"]
spacing = spacing_A / filament.scale[0]
thickness = int(np.round(thickness_A / filament.scale[0]))

vol = _resample_filament(volume, helical_filament, spacing, thickness)

v = napari.Viewer()
v.add_image(
vol,
name=f"{exp_id} - filament resampled",
metadata={"experiment_id": exp_id},
scale=filament.scale,
)
def slice_thickness_A(viewer: napari.Viewer, thickness_A=0):
viewer.dims.thickness = (thickness_A,) * viewer.dims.ndim


class MainBlikWidget(Container):
Expand All @@ -474,13 +242,7 @@ def __init__(self, *args, **kwargs):
self.append(exp)
self.append(new)
self.append(add_to_exp)
self.append(surface)
self.append(surface_particles)
self.append(resample_surface)
self.append(filament)
self.append(filament_particles)
self.append(resample_filament)
self.scrollable = True
self.append(slice_thickness_A)

def append(self, item):
super().append(item)
Expand Down
Loading

0 comments on commit e6bcf74

Please sign in to comment.