Skip to content

Commit

Permalink
Show class samples
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-lightly committed Nov 8, 2023
1 parent 0498841 commit 039b21c
Show file tree
Hide file tree
Showing 5 changed files with 416 additions and 103 deletions.
5 changes: 5 additions & 0 deletions src/lightly_insights/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import logging

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
120 changes: 89 additions & 31 deletions src/lightly_insights/analyze.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,145 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Counter, Dict, Set, Tuple
from typing import Counter, Dict, List, Set, Tuple

import tqdm
from labelformat.model.object_detection import ObjectDetectionInput
from PIL import Image

logger = logging.getLogger(__name__)


IMAGE_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)


@dataclass(frozen=True)
class ImageAnalysis:
num_images: int
image_sizes: Counter[Tuple[int, int]]
filename_set: Set[str]
image_folder: Path
filename_set: Set[str]
image_sizes: Counter[Tuple[int, int]]
median_size: Tuple[int, int]


@dataclass
class ObjectAnalysis:
class ClassAnalysis:
class_id: int
class_name: str

num_objects: int
objects_per_image: Counter[int]
object_sizes_abs: Counter[Tuple[float, float]]
object_sizes_rel: Counter[Tuple[float, float]]
object_sizes_abs: List[Tuple[float, float]]
object_sizes_rel: List[Tuple[float, float]]

sample_filenames: List[str]

@classmethod
def create_empty(cls) -> "ObjectAnalysis":
def create_empty(cls, id: int, name: str) -> "ClassAnalysis":
return cls(
class_id=id,
class_name=name,
num_objects=0,
objects_per_image=Counter(),
object_sizes_abs=Counter(),
object_sizes_rel=Counter(),
object_sizes_abs=[],
object_sizes_rel=[],
sample_filenames=[],
)


@dataclass(frozen=True)
class ObjectDetectionAnalysis:
num_images: int
num_images_zero_objects: int
filename_set: Set[str]
total: ObjectAnalysis
classes: Dict[str, ObjectAnalysis]
total: ClassAnalysis
classes: Dict[int, ClassAnalysis]


def analyze_images(image_folder: Path) -> ImageAnalysis:
num_images = 0
image_sizes = Counter[Tuple[int, int]]()
filename_set = set()

# Param: Recursive?
# Param: Subsample?
# All image types please!
sorted_paths = sorted(image_folder.glob("*.jpg"))
image_sizes = Counter[Tuple[int, int]]()
image_widths = []
image_heights = []

# Currently we list non-recursively. We could add a flag to allow
# recursive listing in the future.
logger.info(f"Listing images in {image_folder}.")
sorted_paths = sorted(
path
for path in image_folder.glob("*.*")
if path.suffix.lower() in IMAGE_EXTENSIONS
)
logger.info(f"Found {len(sorted_paths)} images.")

for image_path in sorted_paths:
num_images += 1
filename_set.add(image_path.name)
with Image.open(image_path) as image:
image_sizes[image.size] += 1
image_widths.append(image.size[0])
image_heights.append(image.size[1])

median_size = (
sorted(image_widths)[num_images // 2] if num_images > 0 else 0,
sorted(image_heights)[num_images // 2] if num_images > 0 else 0,
)

return ImageAnalysis(
num_images=num_images,
image_sizes=image_sizes,
filename_set=filename_set,
image_folder=image_folder,
filename_set=filename_set,
image_sizes=image_sizes,
median_size=median_size,
)


def analyze_object_detections(
label_input: ObjectDetectionInput,
) -> ObjectDetectionAnalysis:
num_images = 0
num_images_zero_objects = 0
filename_set = set()
total_data = ObjectAnalysis.create_empty()
total_data = ClassAnalysis.create_empty(id=-1, name="[All classes]")
class_data = {
category.name: ObjectAnalysis.create_empty()
category.id: ClassAnalysis.create_empty(id=category.id, name=category.name)
for category in label_input.get_categories()
}

for label in label_input.get_labels():
# Iterate over labels and count objects.
for label in tqdm.tqdm(
label_input.get_labels(),
desc="Reading object detection labels",
unit="labels",
):
num_images += 1
if len(label.objects) == 0:
num_images_zero_objects += 1
filename_set.add(label.image.filename)

total_data.num_objects += len(label.objects)
total_data.objects_per_image[len(label.objects)] += 1

num_objects_per_category = Counter[str]()
num_objects_per_category = Counter[int]()

for obj in label.objects:
class_datum = class_data[obj.category.id]

# Number of objects.
class_data[obj.category.name].num_objects += 1
num_objects_per_category[obj.category.name] += 1
class_datum.num_objects += 1
num_objects_per_category[obj.category.id] += 1

# Object sizes.
obj_size_abs = (
Expand All @@ -96,18 +150,22 @@ def analyze_object_detections(
(obj.box.xmax - obj.box.xmin) / label.image.width,
(obj.box.ymax - obj.box.ymin) / label.image.height,
)
total_data.object_sizes_abs[obj_size_abs] += 1
total_data.object_sizes_rel[obj_size_rel] += 1
class_data[obj.category.name].object_sizes_abs[obj_size_abs] += 1
class_data[obj.category.name].object_sizes_rel[obj_size_rel] += 1
total_data.object_sizes_abs.append(obj_size_abs)
total_data.object_sizes_rel.append(obj_size_rel)
class_datum.object_sizes_abs.append(obj_size_abs)
class_datum.object_sizes_rel.append(obj_size_rel)

if len(class_datum.sample_filenames) < 4:
class_datum.sample_filenames.append(label.image.filename)

for category in label_input.get_categories():
class_data[category.name].objects_per_image[
num_objects_per_category[category.name]
] += num_objects_per_category[category.name]
class_data[category.id].objects_per_image[
num_objects_per_category[category.id]
] += num_objects_per_category[category.id]

return ObjectDetectionAnalysis(
num_images=num_images,
num_images_zero_objects=num_images_zero_objects,
filename_set=filename_set,
total=total_data,
classes=class_data,
Expand Down
Loading

0 comments on commit 039b21c

Please sign in to comment.