Skip to content

Commit

Permalink
Bump bioimageio dependency (#243)
Browse files Browse the repository at this point in the history
Update bioimageio dependencies and modelzoo export scripts
  • Loading branch information
constantinpape authored Apr 29, 2024
1 parent cf6f674 commit 588f2fc
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 316 deletions.
4 changes: 2 additions & 2 deletions environment_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ name:
torch-em-cpu
dependencies:
- affogato
- bioimageio.spec <0.5.0
- bioimageio.core >=0.5.0
- bioimageio.spec >=0.5.0
- bioimageio.core >=0.6.0
- cpuonly
- imagecodecs
- python-elf
Expand Down
4 changes: 2 additions & 2 deletions environment_gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ name:
torch-em
dependencies:
- affogato
- bioimageio.spec <0.5.0
- bioimageio.core >=0.5.0
- bioimageio.spec >=0.5.0
- bioimageio.core >=0.6.0
- imagecodecs
- python-elf
- pytorch >=2.0
Expand Down
33 changes: 19 additions & 14 deletions test/util/test_modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ def __call__(self, labels):


class TestModelzoo(unittest.TestCase):
data_path = "./data.h5"
checkpoint_folder = "./checkpoints"
log_folder = "./logs"
save_folder = "./zoo_export"
name = "test"
name = "test-export"
data_path = "./zoo_export/data.h5"

def setUp(self):
os.makedirs(self.save_folder, exist_ok=True)
shape = (8, 128, 128)
chunks = (1, 128, 128)
with h5py.File(self.data_path, "w") as f:
Expand All @@ -42,10 +44,9 @@ def setUp(self):
chunks=chunks)

def tearDown(self):
if os.path.exists(self.data_path):
os.remove(self.data_path)
rmtree(self.checkpoint_folder, ignore_errors=True)
rmtree(self.save_folder, ignore_errors=True)
rmtree(self.log_folder, ignore_errors=True)

def _create_checkpoint(self, n_channels):
if n_channels > 1:
Expand Down Expand Up @@ -73,38 +74,42 @@ def _create_checkpoint(self, n_channels):

def _test_export(self, n_channels):
from torch_em.util.modelzoo import export_bioimageio_model
self._create_checkpoint(n_channels)

self._create_checkpoint(n_channels)
output_path = os.path.join(self.save_folder, "exported.zip")
success = export_bioimageio_model(
os.path.join(self.checkpoint_folder, self.name),
self.save_folder,
output_path,
input_data=np.random.rand(128, 128).astype("float32"),
maintainers=[{"github_user": "constantinpape"}],
input_optional_parameters=False

)
self.assertTrue(success)
self.assertTrue(os.path.exists(self.save_folder))
self.assertTrue(os.path.exists(os.path.join(self.save_folder, "rdf.yaml")))
self.assertTrue(os.path.exists(output_path))

return output_path

def test_export_single_channel(self):
self._test_export(1)

def test_export_multi_channel(self):
self._test_export(4)

@unittest.expectedFailure
def test_add_weights_torchscript(self):
from torch_em.util.modelzoo import add_weight_formats
self._test_export(1)
additional_formats = ["torchscript"]
add_weight_formats(self.save_folder, additional_formats)

model_path = self._test_export(1)
add_weight_formats(model_path, ["torchscript"])
self.assertTrue(os.path.join(self.save_folder, "weigths-torchscript.pt"))

@unittest.expectedFailure
@unittest.skipIf(onnx is None, "Needs onnx")
def test_add_weights_onnx(self):
from torch_em.util.modelzoo import add_weight_formats

self._test_export(1)
additional_formats = ["onnx"]
add_weight_formats(self.save_folder, additional_formats)
add_weight_formats(self.save_folder, ["onnx"])
self.assertTrue(os.path.join(self.save_folder, "weigths.onnx"))


Expand Down
Loading

0 comments on commit 588f2fc

Please sign in to comment.