Skip to content

Commit

Permalink
nonunique error fix + code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
vorozhkog committed Sep 17, 2024
1 parent f50e17e commit 89240f2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 45 deletions.
44 changes: 14 additions & 30 deletions src/compute/layers/processing/SplitDataLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,61 +52,45 @@ def modifies_data(self):
return False

def process(self, data_el: Tuple[ImageDescriptor, Annotation]):
def replace_ds_name(item_desc, new_ds_name):
    """Return a deep copy of ``item_desc`` routed to dataset ``new_ds_name``.

    The input descriptor is left untouched; only the copy gets its
    ``res_ds_name`` attribute overwritten.

    Args:
        item_desc: descriptor object to copy.
        new_ds_name: value to store in the copy's ``res_ds_name``.

    Returns:
        The modified deep copy of ``item_desc``.
    """
    new_item_desc = deepcopy(item_desc)
    new_item_desc.res_ds_name = new_ds_name
    return new_item_desc

def _split_by_percent() -> List[Tuple[ImageDescriptor, Annotation]]:
    """Assign the current item to a ``split_<k>`` dataset sized by percentage.

    Reads ``self.net.total_elements_cnt`` and ``self.settings["split_ratio"]``;
    ``item_idx``, ``item_desc`` and ``ann`` come from the enclosing scope.

    Returns:
        A single-element list with the re-targeted descriptor copy and ``ann``.
    """
    total_items_cnt = self.net.total_elements_cnt
    split_ratio = self.settings["split_ratio"]
    # Number of items per split; a float, since the ratio is a percentage.
    # NOTE(review): a zero ratio or zero total makes this 0 and the division
    # below raises ZeroDivisionError -- confirm settings are validated upstream.
    split_num = total_items_cnt * split_ratio / 100
    # Ceiling of item_idx / split_num for non-negative operands.
    split_index = int(item_idx / split_num) + (item_idx % split_num > 0)
    dataset = f"split_{split_index}"
    return [(replace_ds_name(item_desc, dataset), ann)]

def _split_by_num() -> List[Tuple[ImageDescriptor, Annotation]]:
    """Assign the current item to a ``split_<k>`` dataset of fixed size.

    Reads ``self.settings["split_num"]`` as the number of items per split;
    ``item_idx``, ``item_desc`` and ``ann`` come from the enclosing scope.

    Returns:
        A single-element list with the re-targeted descriptor copy and ``ann``.
    """
    split_num = self.settings["split_num"]
    # Ceiling of item_idx / split_num for non-negative operands.
    split_index = int(item_idx / split_num) + (item_idx % split_num > 0)
    dataset = f"split_{split_index}"
    return [(replace_ds_name(item_desc, dataset), ann)]

def _split_by_class() -> List[Tuple[ImageDescriptor, Annotation]]:
    """Route the current item to one dataset per distinct label class name.

    ``item_desc`` and ``ann`` come from the enclosing scope. An item with no
    labels goes to the ``"unlabeled"`` dataset.

    Returns:
        One ``(descriptor copy, ann)`` pair per distinct class name found in
        ``ann.labels``, or a single ``"unlabeled"`` pair when there are none.
    """
    image_labels = ann.labels
    if len(image_labels) > 0:
        # A set collapses repeated class names so each dataset gets the item once.
        classes = {label.obj_class.name for label in image_labels}
        return [(replace_ds_name(item_desc, class_name), ann) for class_name in classes]
    return [(replace_ds_name(item_desc, "unlabeled"), ann)]

def _split_by_tags() -> List[Tuple[ImageDescriptor, Annotation]]:
    """Route the current item to one dataset per distinct tag name.

    Tag names are collected from ``ann.img_tags`` and from the ``tags`` of
    every label in ``ann.labels``; ``item_desc`` and ``ann`` come from the
    enclosing scope. An item with no tags goes to the ``"no tags"`` dataset.

    Returns:
        One ``(descriptor copy, ann)`` pair per distinct tag name, image tags
        first, or a single ``"no tags"`` pair when there are none.
    """
    image_tags = list(set(ann.img_tags.keys()))
    label_tags = list({tag for label in ann.labels for tag in label.tags.keys()})
    if not image_tags and not label_tags:
        return [(replace_ds_name(item_desc, "no tags"), ann)]

    # dict.fromkeys keeps first-seen order while dropping names that appear
    # both as an image tag and as a label tag.
    unique_tags = dict.fromkeys(image_tags + label_tags)
    return [(replace_ds_name(item_desc, tag), ann) for tag in unique_tags]

if self.net.preview_mode:
Expand Down
34 changes: 19 additions & 15 deletions src/compute/layers/save/CreateNewProjectLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,19 @@ def process_batch(self, data_els: List[Tuple[ImageDescriptor, Annotation]]):
ds_item_map[dataset_name].append((item_desc, ann))

for ds_name in ds_item_map:
out_item_names = [
self.get_free_name(
item_desc.get_item_name(), dataset_name, self.out_project_name
)
+ get_file_ext(item_desc.info.item_info.name)
for item_desc, _ in ds_item_map[ds_name]
]
if self.sly_project_info is not None:
# @TODO: not safe, fix later
orig_ds_info = ds_item_map[ds_name][0][0].info.ds_info
ds_parents = self.get_ds_parents(orig_ds_info)
dataset_info = self.get_or_create_dataset(ds_name, ds_parents)

out_item_names = [
self.get_free_name(
item_desc.get_item_name(), ds_name, self.out_project_name
)
+ get_file_ext(item_desc.info.item_info.name)
for item_desc, _ in ds_item_map[ds_name]
]
if self.net.modality == "images":
if self.net.may_require_items():
item_infos = g.api.image.upload_nps(
Expand All @@ -151,14 +152,17 @@ def process_batch(self, data_els: List[Tuple[ImageDescriptor, Annotation]]):
[item_desc.read_image() for item_desc, _ in ds_item_map[ds_name]],
)
else:
item_infos = g.api.image.upload_ids(
dataset_info.id,
out_item_names,
[
item_desc.info.item_info.id
for item_desc, _ in ds_item_map[ds_name]
],
)
try:
item_infos = g.api.image.upload_ids(
dataset_info.id,
out_item_names,
[
item_desc.info.item_info.id
for item_desc, _ in ds_item_map[ds_name]
],
)
except:
pass
g.api.annotation.upload_anns(
[item_info.id for item_info in item_infos],
[ann for _, ann in ds_item_map[ds_name]],
Expand Down

0 comments on commit 89240f2

Please sign in to comment.