diff --git a/src/compute/layers/processing/SplitDataLayer.py b/src/compute/layers/processing/SplitDataLayer.py index d91959eb..f4e51ca4 100644 --- a/src/compute/layers/processing/SplitDataLayer.py +++ b/src/compute/layers/processing/SplitDataLayer.py @@ -52,61 +52,45 @@ def modifies_data(self): return False def process(self, data_el: Tuple[ImageDescriptor, Annotation]): - def _split_by_percent() -> List[Tuple[ImageDescriptor, Annotation]]: + def replace_ds_name(item_desc, new_ds_name): new_item_desc = deepcopy(item_desc) + new_item_desc.res_ds_name = new_ds_name + return new_item_desc + + def _split_by_percent() -> List[Tuple[ImageDescriptor, Annotation]]: total_items_cnt = self.net.total_elements_cnt split_ratio = self.settings["split_ratio"] split_num = total_items_cnt * split_ratio / 100 split_index = int(item_idx / split_num) + (item_idx % split_num > 0) dataset = f"split_{split_index}" - new_item_desc.res_ds_name = dataset - return [(new_item_desc, ann)] + return [(replace_ds_name(item_desc, dataset), ann)] def _split_by_num() -> List[Tuple[ImageDescriptor, Annotation]]: - new_item_desc = deepcopy(item_desc) split_num = self.settings["split_num"] split_index = int(item_idx / split_num) + (item_idx % split_num > 0) dataset = f"split_{split_index}" - new_item_desc.res_ds_name = dataset - return [(new_item_desc, ann)] + return [(replace_ds_name(item_desc, dataset), ann)] def _split_by_class() -> List[Tuple[ImageDescriptor, Annotation]]: image_labels = ann.labels if len(image_labels) > 0: classes = list({label.obj_class.name for label in image_labels}) - items = [] - for class_name in classes: - curr_item_desc = deepcopy(item_desc) - curr_item_desc.res_ds_name = class_name - items.append((curr_item_desc, ann)) - return items + return [(replace_ds_name(item_desc, class_name), ann) for class_name in classes] else: - new_item_desc = deepcopy(item_desc) - new_item_desc.res_ds_name = "unlabeled" - return [(new_item_desc, ann)] + return [(replace_ds_name(item_desc, "unlabeled"), ann)] def _split_by_tags() -> List[Tuple[ImageDescriptor, Annotation]]: image_tags = list(set(ann.img_tags.keys())) label_tags = list(set([tag for label in ann.labels for tag in label.tags.keys()])) if len(image_tags) == 0 and len(label_tags) == 0: - new_item_desc = deepcopy(item_desc) - new_item_desc.res_ds_name = "no tags" - return [(new_item_desc, ann)] + return [(replace_ds_name(item_desc, "no tags"), ann)] tag_names = set() items = [] - for img_tag_name in image_tags: - if img_tag_name not in tag_names: - tag_names.add(img_tag_name) - new_img_item_desc = deepcopy(item_desc) - new_img_item_desc.res_ds_name = img_tag_name - items.append((new_img_item_desc, ann)) - for tag_name in label_tags: - if tag_name not in tag_names: - tag_names.add(tag_name) - new_label_item_desc = deepcopy(item_desc) - new_label_item_desc.res_ds_name = tag_name - items.append((new_label_item_desc, ann)) + for tag in image_tags + label_tags: + if tag not in tag_names: + tag_names.add(tag) + items.append((replace_ds_name(item_desc, tag), ann)) return items if self.net.preview_mode: diff --git a/src/compute/layers/save/CreateNewProjectLayer.py b/src/compute/layers/save/CreateNewProjectLayer.py index f0ee3855..b14eb7d9 100644 --- a/src/compute/layers/save/CreateNewProjectLayer.py +++ b/src/compute/layers/save/CreateNewProjectLayer.py @@ -131,18 +131,19 @@ def process_batch(self, data_els: List[Tuple[ImageDescriptor, Annotation]]): ds_item_map[dataset_name].append((item_desc, ann)) for ds_name in ds_item_map: - out_item_names = [ - self.get_free_name( - item_desc.get_item_name(), dataset_name, self.out_project_name - ) - + get_file_ext(item_desc.info.item_info.name) - for item_desc, _ in ds_item_map[ds_name] - ] if self.sly_project_info is not None: # @TODO: not safe, fix later orig_ds_info = ds_item_map[ds_name][0][0].info.ds_info ds_parents = self.get_ds_parents(orig_ds_info) dataset_info = self.get_or_create_dataset(ds_name, ds_parents) + + out_item_names = [ + self.get_free_name( + item_desc.get_item_name(), ds_name, self.out_project_name + ) + + get_file_ext(item_desc.info.item_info.name) + for item_desc, _ in ds_item_map[ds_name] + ] if self.net.modality == "images": if self.net.may_require_items(): item_infos = g.api.image.upload_nps( @@ -151,14 +152,17 @@ def process_batch(self, data_els: List[Tuple[ImageDescriptor, Annotation]]): [item_desc.read_image() for item_desc, _ in ds_item_map[ds_name]], ) else: - item_infos = g.api.image.upload_ids( - dataset_info.id, - out_item_names, - [ - item_desc.info.item_info.id - for item_desc, _ in ds_item_map[ds_name] - ], - ) + try: + item_infos = g.api.image.upload_ids( + dataset_info.id, + out_item_names, + [ + item_desc.info.item_info.id + for item_desc, _ in ds_item_map[ds_name] + ], + ) + except: + pass g.api.annotation.upload_anns( [item_info.id for item_info in item_infos], [ann for _, ann in ds_item_map[ds_name]],