Skip to content

Commit

Permalink
typing changes; less verbose creating torch Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
afrendeiro committed Nov 21, 2023
1 parent e630def commit e9248ef
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 78 deletions.
164 changes: 90 additions & 74 deletions wsi_core/WholeSlideImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from xml.dom import minidom
import typing as tp
from pathlib import Path as _Path

import cv2
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -37,7 +38,7 @@
class WholeSlideImage(object):
def __init__(
self,
path: Path | str,
path: Path | _Path | str,
attributes: tp.Optional[dict[str, tp.Any]] = None,
mask_file: Path | None = None,
hdf5_file: Path | None = None,
Expand All @@ -47,7 +48,7 @@ def __init__(
path (str): fullpath to WSI file
attributes
"""
if isinstance(path, str):
if not isinstance(path, Path):
path = Path(path)
self.path = path
self.attributes = attributes
Expand Down Expand Up @@ -424,7 +425,9 @@ def as_data_loader(self, batch_size: int = 128, with_coords: bool = False, **kwa
collate = partial(collate_features, with_coords=with_coords)

dataset = self.as_tile_bag()
loader = DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate)
loader = DataLoader(
dataset=dataset, batch_size=batch_size, collate_fn=collate, **kwargs
)
return loader

def _getPatchGenerator(
Expand Down Expand Up @@ -745,94 +748,107 @@ def segment_tissue_manual(self, level: int | None = None, color_space: str = "RG
for cont in holes_tissue
]

# TODO: Important! Pair holes and contours by checking which holes are in which tissue pieces
# shape of holes_tissue must match contours_tissue, even if there are no holes

self.contours_tissue = [x[:, np.newaxis, :] for x in contours_tissue]
self.holes_tissue = [x[:, np.newaxis, :] for x in holes_tissue]

assert len(self.contours_tissue) > 0, "Segmentation could not find tissue!"
self.saveSegmentation()

def segment(
self,
params: tp.Optional[dict[str, tp.Any]] = None,
method: str = "CLAM",
) -> None:
assert method in ["manual", "CLAM"], f"Unknown segmentation method: {method}"
if method == "manual":
self.segment_tissue_manual(**(params or {}))
return

assert method == "CLAM", f"Unknown segmentation method: {method}"
# import pandas as pd
if params is None:
# url = "https://raw.githubusercontent.com/mahmoodlab/CLAM/master/presets/bwh_biopsy.csv"
# params = pd.read_csv(url).squeeze().to_dict()
params = {
"sthresh": 15,
"mthresh": 11,
"close": 2,
"use_otsu": False,
"a_t": 1,
"a_h": 1,
"max_n_holes": 2,
"vis_level": -1,
"line_thickness": 50,
"white_thresh": 5,
"black_thresh": 50,
"use_padding": True,
"contour_fn": "four_pt",
"keep_ids": "none",
"exclude_ids": "none",
}

if "seg_level" not in params:
g = np.absolute(
(np.asarray(self.wsi.level_dimensions) - np.asarray([1000, 1000]))
).sum(1)
params["seg_level"] = np.argmin(g)
else:
# import pandas as pd
if params is None:
# url = "https://raw.githubusercontent.com/mahmoodlab/CLAM/master/presets/bwh_biopsy.csv"
# params = pd.read_csv(url).squeeze().to_dict()
params = {
"sthresh": 15,
"mthresh": 11,
"close": 2,
"use_otsu": False,
"a_t": 1,
"a_h": 1,
"max_n_holes": 2,
"vis_level": -1,
"line_thickness": 50,
"white_thresh": 5,
"black_thresh": 50,
"use_padding": True,
"contour_fn": "four_pt",
"keep_ids": "none",
"exclude_ids": "none",
}

kwargs = filter_kwargs_by_callable(params, self.segmentTissue)
fkwargs = {k: v for k, v in params.items() if k not in kwargs}
self.segmentTissue(**kwargs, filter_params=fkwargs)
self.saveSegmentation()
if "seg_level" not in params:
g = np.absolute(
(np.asarray(self.wsi.level_dimensions) - np.asarray([1000, 1000]))
).sum(1)
params["seg_level"] = np.argmin(g)

kwargs = filter_kwargs_by_callable(params, self.segmentTissue)
fkwargs = {k: v for k, v in params.items() if k not in kwargs}
self.segmentTissue(**kwargs, filter_params=fkwargs)
assert len(self.contours_tissue) > 0, "Segmentation could not find tissue!"
self.saveSegmentation()
self.plot_segmentation()

# def plot_segmentation(self, output_file: tp.Optional[Path] = None) -> None:
# from shapely.geometry import Polygon

# if output_file is None:
# output_file = self.mask_file.with_suffix(".png")

# level = self.wsi.level_count - 1
# thumbnail = np.array(
# self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB")
# )

# fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# ax.imshow(thumbnail)
# tissue: np.ndarray
# hole: np.ndarray
# for i, tissue in enumerate(self.contours_tissue or [], 1):
# # resize to thumbnail size
# tissue = np.array(
# tissue.squeeze() / self.wsi.level_downsamples[level], dtype="int32"
# )
# poly = Polygon(tissue)
# ax.plot(*tissue.T)
# ax.text(
# *poly.centroid.coords[0],
# str(i),
# color="black",
# ha="center",
# va="center",
# fontsize=10,
# )
# for i, hole in enumerate(self.holes_tissue or [], 1):
# # resize to thumbnail size
# hole = np.array(
# hole.squeeze() / self.wsi.level_downsamples[level], dtype="int32"
# )
# poly = Polygon(hole)
# ax.plot(*hole.T, color="black", linestyle="-", linewidth=0.2)
# ax.axis("off")
# fig.savefig(output_file, bbox_inches="tight", dpi=200, pad_inches=0.0)
# plt.close(fig)
# return fig

def plot_segmentation(self, output_file: tp.Optional[Path] = None) -> None:
from shapely.geometry import Polygon

if output_file is None:
output_file = self.mask_file.with_suffix(".png")
output_file = self.path.with_suffix(".segmentation.png")

level = self.wsi.level_count - 1
thumbnail = np.array(
self.wsi.read_region((0, 0), level, self.level_dim[level]).convert("RGB")
)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(thumbnail)
tissue: np.ndarray
hole: np.ndarray
for i, tissue in enumerate(self.contours_tissue or [], 1):
# resize to thumbnail size
tissue = np.array(
tissue.squeeze() / self.wsi.level_downsamples[level], dtype="int32"
)
poly = Polygon(tissue)
ax.plot(*tissue.T)
ax.text(
*poly.centroid.coords[0],
str(i),
color="black",
ha="center",
va="center",
fontsize=10,
)
for i, hole in enumerate(self.holes_tissue or [], 1):
# resize to thumbnail size
hole = np.array(
hole.squeeze() / self.wsi.level_downsamples[level], dtype="int32"
)
poly = Polygon(hole)
ax.plot(*hole.T, color="black", linestyle="-", linewidth=0.2)
ax.axis("off")
fig.savefig(output_file, bbox_inches="tight", dpi=200, pad_inches=0.0)
plt.close(fig)
return fig
self.visWSI(vis_level=level).save(output_file)

def tile(
self,
Expand Down
8 changes: 4 additions & 4 deletions wsi_core/dataset_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def summary(self):
for name, value in dset.attrs.items():
print(name, value)

print("\nfeature extraction settings")
print("target patch size: ", self.target_patch_size)
print("pretrained: ", self.pretrained)
print("transformations: ", self.roi_transforms)
# print("\nfeature extraction settings")
# print("target patch size: ", self.target_patch_size)
# print("pretrained: ", self.pretrained)
# print("transformations: ", self.roi_transforms)

def __getitem__(self, idx):
with h5py.File(self.file_path, "r") as hdf5_file:
Expand Down

0 comments on commit e9248ef

Please sign in to comment.