From 9d52edfc7089ccf7f26f32d2ec63d1ab8d4338ca Mon Sep 17 00:00:00 2001 From: oneformer3d Date: Thu, 21 Mar 2024 18:24:45 +0000 Subject: [PATCH] initial commit --- .gitignore | 7 + Dockerfile | 89 + LICENSE | 159 ++ README.md | 124 ++ ...eformer3d_1xb2_scannet-and-structured3d.py | 292 ++++ configs/oneformer3d_1xb2_s3dis-area-5.py | 229 +++ configs/oneformer3d_1xb4_scannet.py | 234 +++ configs/oneformer3d_1xb4_scannet200.py | 302 ++++ data/scannet/README.md | 67 + data/scannet/batch_load_scannet_data.py | 187 ++ data/scannet/load_scannet_data.py | 205 +++ data/scannet/meta_data/scannet_means.npz | Bin 0 -> 676 bytes data/scannet/meta_data/scannet_train.txt | 1513 +++++++++++++++++ .../meta_data/scannetv2-labels.combined.tsv | 608 +++++++ data/scannet/meta_data/scannetv2_test.txt | 100 ++ data/scannet/meta_data/scannetv2_train.txt | 1201 +++++++++++++ data/scannet/meta_data/scannetv2_val.txt | 312 ++++ data/scannet/scannet_utils.py | 87 + data/structured3d/README.md | 69 + data/structured3d/data_prepare.py | 73 + data/structured3d/structured3d_data_utils.py | 265 +++ data/structured3d/unzip.py | 57 + data/structured3d/utils.py | 248 +++ oneformer3d/__init__.py | 23 + oneformer3d/data_preprocessor.py | 78 + oneformer3d/evaluate_semantic_instance.py | 368 ++++ oneformer3d/formatting.py | 142 ++ oneformer3d/instance_criterion.py | 724 ++++++++ oneformer3d/instance_seg_eval.py | 131 ++ oneformer3d/instance_seg_metric.py | 106 ++ oneformer3d/loading.py | 106 ++ oneformer3d/mask_matrix_nms.py | 122 ++ oneformer3d/mink_unet.py | 597 +++++++ oneformer3d/oneformer3d.py | 1346 +++++++++++++++ oneformer3d/query_decoder.py | 718 ++++++++ oneformer3d/s3dis_dataset.py | 19 + oneformer3d/scannet_dataset.py | 102 ++ oneformer3d/semantic_criterion.py | 116 ++ oneformer3d/spconv_unet.py | 236 +++ oneformer3d/structured3d_dataset.py | 88 + oneformer3d/structures.py | 25 + oneformer3d/transforms_3d.py | 408 +++++ oneformer3d/unified_criterion.py | 161 ++ oneformer3d/unified_metric.py | 255 +++ tools/create_data.py | 57 + tools/fix_spconv_checkpoint.py | 18 + tools/indoor_converter.py | 67 + tools/scannet_data_utils.py | 281 +++ tools/test.py | 149 ++ tools/train.py | 135 ++ tools/update_infos_to_v2.py | 417 +++++ 51 files changed, 13423 insertions(+) create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 LICENSE create mode 100644 README.md create mode 100644 configs/instance-only-oneformer3d_1xb2_scannet-and-structured3d.py create mode 100644 configs/oneformer3d_1xb2_s3dis-area-5.py create mode 100644 configs/oneformer3d_1xb4_scannet.py create mode 100644 configs/oneformer3d_1xb4_scannet200.py create mode 100644 data/scannet/README.md create mode 100644 data/scannet/batch_load_scannet_data.py create mode 100644 data/scannet/load_scannet_data.py create mode 100644 data/scannet/meta_data/scannet_means.npz create mode 100644 data/scannet/meta_data/scannet_train.txt create mode 100644 data/scannet/meta_data/scannetv2-labels.combined.tsv create mode 100644 data/scannet/meta_data/scannetv2_test.txt create mode 100644 data/scannet/meta_data/scannetv2_train.txt create mode 100644 data/scannet/meta_data/scannetv2_val.txt create mode 100644 data/scannet/scannet_utils.py create mode 100644 data/structured3d/README.md create mode 100644 data/structured3d/data_prepare.py create mode 100644 data/structured3d/structured3d_data_utils.py create mode 100644 data/structured3d/unzip.py create mode 100644 data/structured3d/utils.py create mode 100644 oneformer3d/__init__.py create mode 100644 
oneformer3d/data_preprocessor.py create mode 100644 oneformer3d/evaluate_semantic_instance.py create mode 100644 oneformer3d/formatting.py create mode 100644 oneformer3d/instance_criterion.py create mode 100644 oneformer3d/instance_seg_eval.py create mode 100644 oneformer3d/instance_seg_metric.py create mode 100644 oneformer3d/loading.py create mode 100644 oneformer3d/mask_matrix_nms.py create mode 100644 oneformer3d/mink_unet.py create mode 100644 oneformer3d/oneformer3d.py create mode 100644 oneformer3d/query_decoder.py create mode 100644 oneformer3d/s3dis_dataset.py create mode 100644 oneformer3d/scannet_dataset.py create mode 100644 oneformer3d/semantic_criterion.py create mode 100644 oneformer3d/spconv_unet.py create mode 100644 oneformer3d/structured3d_dataset.py create mode 100644 oneformer3d/structures.py create mode 100644 oneformer3d/transforms_3d.py create mode 100644 oneformer3d/unified_criterion.py create mode 100644 oneformer3d/unified_metric.py create mode 100644 tools/create_data.py create mode 100644 tools/fix_spconv_checkpoint.py create mode 100644 tools/indoor_converter.py create mode 100644 tools/scannet_data_utils.py create mode 100644 tools/test.py create mode 100644 tools/train.py create mode 100644 tools/update_infos_to_v2.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3e006a2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +data +work_dirs +.vscode +__pycache__/ +*.py[cod] +*$py.class +*.ipynb \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..5659b03 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,89 @@ +FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel + +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \ + && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub \ + && apt-get update \ + && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 + +# Install OpenMMLab projects +RUN pip install --no-deps \ + mmengine==0.7.3 \ + mmdet==3.0.0 \ + mmsegmentation==1.0.0 \ + git+https://github.com/open-mmlab/mmdetection3d.git@22aaa47fdb53ce1870ff92cb7e3f96ae38d17f61 +RUN pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html --no-deps + +# Install MinkowskiEngine +# Feel free to skip nvidia-cuda-dev if minkowski installation is fine +RUN apt-get update \ + && apt-get -y install libopenblas-dev nvidia-cuda-dev +RUN TORCH_CUDA_ARCH_LIST="6.1 7.0 8.6" \ + pip install git+https://github.com/NVIDIA/MinkowskiEngine.git@02fc608bea4c0549b0a7b00ca1bf15dee4a0b228 -v --no-deps \ + --install-option="--blas=openblas" \ + --install-option="--force_cuda" + +# Install torch-scatter +RUN pip install torch-scatter==2.1.2 -f https://data.pyg.org/whl/torch-1.13.0+cu116.html --no-deps + +# Install ScanNet superpoint segmentator +RUN git clone https://github.com/Karbo123/segmentator.git \ + && cd segmentator/csrc \ + && git reset --hard 76efe46d03dd27afa78df972b17d07f2c6cfb696 \ + && mkdir build \ + && cd build \ + && cmake .. 
\ + -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` \ + -DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \ + -DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR'))") \ + -DCMAKE_INSTALL_PREFIX=`python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())'` \ + && make \ + && make install \ + && cd ../../.. + +# Install remaining python packages +RUN pip install --no-deps \ + spconv-cu116==2.3.6 \ + addict==2.4.0 \ + yapf==0.33.0 \ + termcolor==2.3.0 \ + packaging==23.1 \ + numpy==1.24.1 \ + rich==13.3.5 \ + opencv-python==4.7.0.72 \ + pycocotools==2.0.6 \ + Shapely==1.8.5 \ + scipy==1.10.1 \ + terminaltables==3.1.10 \ + numba==0.57.0 \ + llvmlite==0.40.0 \ + pccm==0.4.7 \ + ccimport==0.4.2 \ + pybind11==2.10.4 \ + ninja==1.11.1 \ + lark==1.1.5 \ + cumm-cu116==0.4.9 \ + pyquaternion==0.9.9 \ + lyft-dataset-sdk==0.0.8 \ + pandas==2.0.1 \ + python-dateutil==2.8.2 \ + matplotlib==3.5.2 \ + pyparsing==3.0.9 \ + cycler==0.11.0 \ + kiwisolver==1.4.4 \ + scikit-learn==1.2.2 \ + joblib==1.2.0 \ + threadpoolctl==3.1.0 \ + cachetools==5.3.0 \ + nuscenes-devkit==1.1.10 \ + trimesh==3.21.6 \ + open3d==0.17.0 \ + plotly==5.18.0 \ + dash==2.14.2 \ + plyfile==1.0.2 \ + flask==3.0.0 \ + werkzeug==3.0.1 \ + click==8.1.7 \ + blinker==1.7.0 \ + itsdangerous==2.1.2 \ + importlib_metadata==2.1.2 \ + zipp==3.17.0 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e4cf43e --- /dev/null +++ b/LICENSE @@ -0,0 +1,159 @@ +# Attribution-NonCommercial 4.0 International + +> *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.* +> +> ### Using Creative Commons Public Licenses +> +> Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. +> +> * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 
+> +> * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). + +## Creative Commons Attribution-NonCommercial 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. + +### Section 1 – Definitions. + +a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. + +b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. + +c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. + +d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. + +e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. + +f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. + +g. 
__Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. + +h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. + +i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. + +j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. + +k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. + +l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. + +### Section 2 – Scope. + +a. ___License grant.___ + + 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: + + A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and + + B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. + + 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. + + 3. __Term.__ The term of this Public License is specified in Section 6(a). + + 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. + + 5. __Downstream recipients.__ + + A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. + + B. 
__No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. + + 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). + +b. ___Other rights.___ + + 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this Public License. + + 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. + +### Section 3 – License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the following conditions. + +a. ___Attribution.___ + + 1. If You Share the Licensed Material (including in modified form), You must: + + A. retain the following if it is supplied by the Licensor with the Licensed Material: + + i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of warranties; + + v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; + + B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and + + C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. + + 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. + +### Section 4 – Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: + +a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; + +b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and + +c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. + +### Section 5 – Disclaimer of Warranties and Limitation of Liability. + +a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ + +b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ + +c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. + +### Section 6 – Term and Termination. + +a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. + +b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. + +c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. + +d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. + +### Section 7 – Other Terms and Conditions. + +a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. + +b. 
Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. + +### Section 8 – Interpretation. + +a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. + +b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. + +c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. + +d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. + +> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. +> +> Creative Commons may be contacted at creativecommons.org \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..65846aa --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +## OneFormer3D: One Transformer for Unified Point Cloud Segmentation + +**News**: + * :fire: February, 2024. Oneformer3D is now accepted at CVPR 2024. + * :fire: November, 2023. OneFormer3D achieves state-of-the-art in + * 3D instance segmentation on ScanNet ([hidden test](https://kaldir.vc.in.tum.de/scannet_benchmark/semantic_instance_3d)) + [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/oneformer3d-one-transformer-for-unified-point/3d-instance-segmentation-on-scannetv2)](https://paperswithcode.com/sota/3d-instance-segmentation-on-scannetv2?p=oneformer3d-one-transformer-for-unified-point) +
+    *(leaderboard screenshot: ScanNet leaderboard)*
+ * 3D instance segmentation on S3DIS (6-Fold) + [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/oneformer3d-one-transformer-for-unified-point/3d-instance-segmentation-on-s3dis)](https://paperswithcode.com/sota/3d-instance-segmentation-on-s3dis?p=oneformer3d-one-transformer-for-unified-point) + * 3D panoptic segmentation on ScanNet + [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/oneformer3d-one-transformer-for-unified-point/panoptic-segmentation-on-scannet)](https://paperswithcode.com/sota/panoptic-segmentation-on-scannet?p=oneformer3d-one-transformer-for-unified-point) + * 3D object detection on ScanNet (w/o TTA) + [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/oneformer3d-one-transformer-for-unified-point/3d-object-detection-on-scannetv2)](https://paperswithcode.com/sota/3d-object-detection-on-scannetv2?p=oneformer3d-one-transformer-for-unified-point) + * 3D semantic segmentation on ScanNet (val, w/o extra training data) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/oneformer3d-one-transformer-for-unified-point/semantic-segmentation-on-scannet)](https://paperswithcode.com/sota/semantic-segmentation-on-scannet?p=oneformer3d-one-transformer-for-unified-point) + +This repository contains an implementation of OneFormer3D, a 3D (instance, semantic, and panoptic) segmentation method introduced in our paper: + +> **OneFormer3D: One Transformer for Unified Point Cloud Segmentation**
+> [Maksim Kolodiazhnyi](https://github.com/col14m), +> [Anna Vorontsova](https://github.com/highrut), +> [Anton Konushin](https://scholar.google.com/citations?user=ZT_k-wMAAAAJ), +> [Danila Rukhovich](https://github.com/filaPro) +>
+> Samsung Research
+> https://arxiv.org/abs/2311.14405
+
+### Installation
+
+For convenience, we provide a [Dockerfile](Dockerfile).
+This implementation is based on the [mmdetection3d](https://github.com/open-mmlab/mmdetection3d) framework `v1.1.0`. If you install without Docker, please follow its [get_started.md](https://github.com/open-mmlab/mmdetection3d/blob/22aaa47fdb53ce1870ff92cb7e3f96ae38d17f61/docs/en/get_started.md).
+
+### Getting Started
+
+Please see [train_test.md](https://github.com/open-mmlab/mmdetection3d/blob/22aaa47fdb53ce1870ff92cb7e3f96ae38d17f61/docs/en/user_guides/train_test.md) for basic usage examples.
+To preprocess the ScanNet and ScanNet200 datasets, please follow our [instructions](data/scannet); they differ from the original mmdetection3d pipeline only by the addition of superpoint clustering. For S3DIS preprocessing we follow the original mmdetection3d [instructions](https://github.com/open-mmlab/mmdetection3d/tree/22aaa47fdb53ce1870ff92cb7e3f96ae38d17f61/data/s3dis). We also [support](data/structured3d) the Structured3D dataset for pre-training.
+
+Important notes:
+ * The metrics from our paper can be achieved in several ways; we simply chose the most stable recipe for each dataset in this repository.
+ * If you are interested in only one of the three segmentation tasks, it is possible to achieve slightly better metrics than reported in our paper. Specifically, increasing `model.criterion.sem_criterion.loss_weight` in the config file improves the semantic metrics, while decreasing it improves the instance metrics.
+ * All models can be trained on a single GPU with 32 GB of memory (or even 24 GB for the ScanNet dataset). If you run out of RAM during instance segmentation evaluation at the validation or test stage, feel free to decrease `model.test_cfg.topk_insts` in the config file.
+ * Due to a bug in SpConv, we [reshape](tools/fix_spconv_checkpoint.py) the backbone weights between the train and test stages.
+
+#### ScanNet
+
+For ScanNet we present a model with an [SpConv](https://github.com/traveller59/spconv) backbone, superpoint pooling, all queries selected, and semantics predicted directly from the instance queries. The backbone is initialized from an [SSTNet](https://github.com/Gorilla-Lab-SCUT/SSTNet) checkpoint, which should be [downloaded](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/sstnet_scannet.pth) and put into `work_dirs/tmp` before training.
+
+```shell
+# train (with validation)
+python tools/train.py configs/oneformer3d_1xb4_scannet.py
+# test
+python tools/fix_spconv_checkpoint.py \
+    --in-path work_dirs/oneformer3d_1xb4_scannet/epoch_512.pth \
+    --out-path work_dirs/oneformer3d_1xb4_scannet/epoch_512.pth
+python tools/test.py configs/oneformer3d_1xb4_scannet.py \
+    work_dirs/oneformer3d_1xb4_scannet/epoch_512.pth
+```
+
+#### ScanNet200
+
+For ScanNet200 we present a model with a [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) backbone, superpoint pooling, all queries selected, and semantics predicted directly from the instance queries. The backbone is initialized from a [Mask3D](https://github.com/JonasSchult/Mask3D) checkpoint, which should be [downloaded](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/mask3d_scannet200.pth) and put into `work_dirs/tmp` before training.
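+
+One possible way to fetch it (a minimal sketch, assuming `wget` is available; the ScanNet200 config's `load_from` points to `work_dirs/tmp/mask3d_scannet200.pth`):
+
+```shell
+# assumed download step: place the pre-trained Mask3D checkpoint where load_from expects it
+mkdir -p work_dirs/tmp
+wget -P work_dirs/tmp \
+    https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/mask3d_scannet200.pth
+```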
+
+```shell
+# train (with validation)
+python tools/train.py configs/oneformer3d_1xb4_scannet200.py
+# test
+python tools/test.py configs/oneformer3d_1xb4_scannet200.py \
+    work_dirs/oneformer3d_1xb4_scannet200/epoch_512.pth
+```
+
+#### S3DIS
+
+For S3DIS we present a model with an [SpConv](https://github.com/traveller59/spconv) backbone, without superpoint pooling, without query selection, and with separate semantic queries. The backbone is pre-trained on Structured3D and ScanNet; it can either be [downloaded](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/instance-only-oneformer3d_1xb2_scannet-and-structured3d.pth) and put into `work_dirs/tmp` before training, or trained with our code. We train the model on Areas 1, 2, 3, 4, and 6 and test on Area 5. To change this split, modify the `train_area` and `test_area` parameters in the config.
+
+```shell
+# pre-train
+python tools/train.py configs/instance-only-oneformer3d_1xb2_scannet-and-structured3d.py
+python tools/fix_spconv_checkpoint.py \
+    --in-path work_dirs/instance-only-oneformer3d_1xb2_scannet-and-structured3d/iter_600000.pth \
+    --out-path work_dirs/tmp/instance-only-oneformer3d_1xb2_scannet-and-structured3d.pth
+# train (with validation)
+python tools/train.py configs/oneformer3d_1xb2_s3dis-area-5.py
+# test
+python tools/fix_spconv_checkpoint.py \
+    --in-path work_dirs/oneformer3d_1xb2_s3dis-area-5/epoch_512.pth \
+    --out-path work_dirs/oneformer3d_1xb2_s3dis-area-5/epoch_512.pth
+python tools/test.py configs/oneformer3d_1xb2_s3dis-area-5.py \
+    work_dirs/oneformer3d_1xb2_s3dis-area-5/epoch_512.pth
+```
+
+### Models
+
+Metric values in the table are given for the provided checkpoints and may differ slightly from those in our paper. Due to training randomness, it may be necessary to run training with the same config several times to reach the best metrics.
+
+| Dataset | mAP25 | mAP50 | mAP | mIoU | PQ | Download |
+|:-------:|:-----:|:-----:|:---:|:----:|:--:|:--------:|
+| ScanNet | 86.7 | 78.8 | 59.3 | 76.4 | 70.7 | [model](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb4_scannet.pth) \| [log](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb4_scannet.log) \| [config](configs/oneformer3d_1xb4_scannet.py) |
+| ScanNet200 | 44.6 | 40.9 | 30.2 | 29.4 | 29.7 | [model](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb4_scannet200.pth) \| [log](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb4_scannet200.log) \| [config](configs/oneformer3d_1xb4_scannet200.py) |
+| S3DIS | 80.6 | 72.7 | 58.0 | 71.9 | 64.6 | [model](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb2_s3dis-area-5.pth) \| [log](https://github.com/oneformer3d/oneformer3d/releases/download/v1.0/oneformer3d_1xb2_s3dis-area-5.log) \| [config](configs/oneformer3d_1xb2_s3dis-area-5.py) |
+
+### Example Predictions
+

+*(example image: ScanNet predictions)*

+ +### Citation + +If you find this work useful for your research, please cite our paper: + +``` +@misc{kolodiazhnyi2023oneformer3d, + url = {https://arxiv.org/abs/2311.14405}, + author = {Kolodiazhnyi, Maxim and Vorontsova, Anna and Konushin, Anton and Rukhovich, Danila}, + title = {OneFormer3D: One Transformer for Unified Point Cloud Segmentation}, + publisher = {arXiv}, + year = {2023} +} +``` diff --git a/configs/instance-only-oneformer3d_1xb2_scannet-and-structured3d.py b/configs/instance-only-oneformer3d_1xb2_scannet-and-structured3d.py new file mode 100644 index 0000000..b8d22d3 --- /dev/null +++ b/configs/instance-only-oneformer3d_1xb2_scannet-and-structured3d.py @@ -0,0 +1,292 @@ +_base_ = ['mmdet3d::_base_/default_runtime.py'] + +custom_imports = dict(imports=['oneformer3d']) + +# model settings +num_classes_structured3d = 28 +num_classes_scannet = 18 +voxel_size = 0.05 +blocks = 5 +num_channels = 64 + +model = dict( + type='InstanceOnlyOneFormer3D', + data_preprocessor=dict(type='Det3DDataPreprocessor'), + in_channels=6, + num_channels=num_channels, + num_classes_1dataset=num_classes_structured3d, + num_classes_2dataset=num_classes_scannet, + prefix_1dataset='structured3d', + prefix_2dataset ='scannet', + voxel_size=voxel_size, + min_spatial_shape=128, + backbone=dict( + type='SpConvUNet', + num_planes=[num_channels * (i + 1) for i in range(blocks)], + return_blocks=True), + decoder=dict( + type='OneDataQueryDecoder', + num_layers=3, + num_queries_1dataset=400, + num_queries_2dataset=400, + num_classes_1dataset=num_classes_structured3d, + num_classes_2dataset=num_classes_scannet, + prefix_1dataset='structured3d', + prefix_2dataset ='scannet', + in_channels=num_channels, + d_model=256, + num_heads=8, + hidden_dim=1024, + dropout=0.0, + activation_fn='gelu', + iter_pred=True, + attn_mask=True, + fix_attention=True), + criterion=dict( + type='OneDataCriterion', + matcher=dict( + type='HungarianMatcher', + costs=[ + dict(type='QueryClassificationCost', weight=0.5), + dict(type='MaskBCECost', weight=1.0), + dict(type='MaskDiceCost', weight=1.0)]), + loss_weight=[0.5, 1.0, 1.0, 0.5], + non_object_weight=0.05, + num_classes_1dataset=num_classes_structured3d, + num_classes_2dataset=num_classes_scannet, + fix_dice_loss_weight=True, + iter_matcher=True), + train_cfg=dict(), + test_cfg=dict( + topk_insts=400, + score_thr=0.0, + npoint_thr=100, + obj_normalization=True, + obj_normalization_thr=0.01, + sp_score_thr=0.15, + nms=True, + matrix_nms_kernel='linear')) + +# structured3d dataset settings +data_prefix = dict( + pts='points', + pts_instance_mask='instance_mask', + pts_semantic_mask='semantic_mask') +dataset_type_structured3d = 'Structured3DSegDataset' +data_root_structured3d = 'data/structured3d/bins' + +class_names_structured3d = ( + 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', + 'window', 'picture', 'counter', 'desk', 'shelves', 'curtain', 'dresser', + 'pillow', 'mirror', 'ceiling', 'fridge', 'television', 'night stand', + 'toilet', 'sink', 'lamp', 'bathtub', 'structure', 'furniture', 'prop') +metainfo_structured3d = dict( + classes=class_names_structured3d, + ignore_index=num_classes_structured3d) + +train_pipeline_structured3d = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True), + dict( + type='PointSample_', + num_points=200000), + 
dict(type='PointSegClassMapping'), + dict(type='PointInstClassMapping_', + num_classes=num_classes_structured3d, + structured3d=True), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.14, 0.14], + scale_ratio_range=[0.9, 1.1], + translation_std=[0.1, 0.1, 0.1], + shift_height=False), + dict(type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5]), + dict(type='SkipEmptyScene'), + dict( + type='ElasticTransfrom', + gran=[6, 20], + mag=[40, 160], + voxel_size=voxel_size, + p=-1), + dict( + type='Pack3DDetInputs_', + keys=[ + 'points', 'pts_semantic_mask', 'pts_instance_mask', + 'elastic_coords', 'gt_labels_3d' + ]) +] + +# scannet dataset settings +dataset_type_scannet = 'ScanNetDataset' +data_root_scannet = 'data/scannet/' + +class_names_scannet = ( + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', + 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain', + 'toilet', 'sink', 'bathtub', 'garbagebin') +metainfo_scannet = dict( + classes=class_names_scannet, + ignore_index=num_classes_scannet) + +train_pipeline_scannet = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=True, + with_mask_3d=True, + with_seg_3d=True), + dict(type='PointSegClassMapping'), + dict(type='PointInstClassMapping_', + num_classes=num_classes_scannet), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-3.14, 3.14], + scale_ratio_range=[0.8, 1.2], + translation_std=[0.1, 0.1, 0.1], + shift_height=False), + dict(type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5]), + dict( + type='ElasticTransfrom', + gran=[6, 20], + mag=[40, 160], + voxel_size=voxel_size), + dict( + type='Pack3DDetInputs_', + keys=[ + 'points', 'gt_labels_3d', 'pts_semantic_mask', + 'pts_instance_mask', 'elastic_coords' + ]) +] +test_pipeline_scannet = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='NormalizePointsColor', + color_mean=[127.5, 127.5, 127.5]), + ]), + dict(type='Pack3DDetInputs_', keys=['points']) +] + +train_dataloader = dict( + batch_size=2, + num_workers=6, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type='ConcatDataset_', + datasets=[ + dict( + type=dataset_type_structured3d, + data_root=data_root_structured3d, + ann_file='structured3d_infos_train.pkl', + metainfo=metainfo_structured3d, + data_prefix=data_prefix, + pipeline=train_pipeline_structured3d, + ignore_index=num_classes_structured3d, + scene_idxs=None, + test_mode=False), + dict( + type='RepeatDataset', + times=10, + dataset=dict( + type=dataset_type_scannet, + data_root=data_root_scannet, + ann_file='scannet_oneformer3d_infos_train.pkl', + data_prefix=data_prefix, + metainfo=metainfo_scannet, + pipeline=train_pipeline_scannet, + test_mode=False)), + dict( + type='RepeatDataset', + times=10, + dataset=dict( + 
type=dataset_type_scannet, + data_root=data_root_scannet, + ann_file='scannet_oneformer3d_infos_val.pkl', + data_prefix=data_prefix, + metainfo=metainfo_scannet, + pipeline=train_pipeline_scannet, + test_mode=False))])) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type_scannet, + data_root=data_root_scannet, + ann_file='scannet_oneformer3d_infos_val.pkl', + metainfo=metainfo_scannet, + data_prefix=data_prefix, + pipeline=test_pipeline_scannet, + test_mode=True)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='InstanceSegMetric_') +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.05), + clip_grad=dict(max_norm=10, norm_type=2)) +param_scheduler = dict(type='PolyLR', begin=0, end=600000, + power=0.9, by_epoch=False) +log_processor = dict(by_epoch=False) + +custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] +default_hooks = dict(checkpoint=dict(by_epoch=False, interval=25000)) + +train_cfg = dict( + type='IterBasedTrainLoop', # Use iter-based training loop + max_iters=600000, # Maximum iterations + val_interval=25000) # Validation interval +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/configs/oneformer3d_1xb2_s3dis-area-5.py b/configs/oneformer3d_1xb2_s3dis-area-5.py new file mode 100644 index 0000000..7cef879 --- /dev/null +++ b/configs/oneformer3d_1xb2_s3dis-area-5.py @@ -0,0 +1,229 @@ +_base_ = [ + 'mmdet3d::_base_/default_runtime.py', +] +custom_imports = dict(imports=['oneformer3d']) + +# model settings +num_channels = 64 +num_instance_classes = 13 +num_semantic_classes = 13 + +model = dict( + type='S3DISOneFormer3D', + data_preprocessor=dict(type='Det3DDataPreprocessor'), + in_channels=6, + num_channels=num_channels, + voxel_size=0.05, + num_classes=num_instance_classes, + min_spatial_shape=128, + backbone=dict( + type='SpConvUNet', + num_planes=[num_channels * (i + 1) for i in range(5)], + return_blocks=True), + decoder=dict( + type='QueryDecoder', + num_layers=3, + num_classes=num_instance_classes, + num_instance_queries=400, + num_semantic_queries=num_semantic_classes, + num_instance_classes=num_instance_classes, + in_channels=num_channels, + d_model=256, + num_heads=8, + hidden_dim=1024, + dropout=0.0, + activation_fn='gelu', + iter_pred=True, + attn_mask=True, + fix_attention=True, + objectness_flag=True), + criterion=dict( + type='S3DISUnifiedCriterion', + num_semantic_classes=num_semantic_classes, + sem_criterion=dict( + type='S3DISSemanticCriterion', + loss_weight=5.0), + inst_criterion=dict( + type='InstanceCriterion', + matcher=dict( + type='HungarianMatcher', + costs=[ + dict(type='QueryClassificationCost', weight=0.5), + dict(type='MaskBCECost', weight=1.0), + dict(type='MaskDiceCost', weight=1.0)]), + loss_weight=[0.5, 1.0, 1.0, 0.5], + num_classes=num_instance_classes, + non_object_weight=0.05, + fix_dice_loss_weight=True, + iter_matcher=True, + fix_mean_loss=True)), + train_cfg=dict(), + test_cfg=dict( + topk_insts=450, + inst_score_thr=0.0, + pan_score_thr=0.4, + npoint_thr=300, + obj_normalization=True, + obj_normalization_thr=0.01, + sp_score_thr=0.15, + nms=True, + matrix_nms_kernel='linear', + num_sem_cls=num_semantic_classes, + stuff_cls=[0, 1, 2, 3, 4, 5, 6, 12], + thing_cls=[7, 8, 9, 10, 11])) + +# dataset settings +dataset_type = 'S3DISSegDataset_' +data_root = 'data/s3dis/' 
+data_prefix = dict( + pts='points', + pts_instance_mask='instance_mask', + pts_semantic_mask='semantic_mask') + +train_area = [1, 2, 3, 4, 6] +test_area = 5 + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_label_3d=False, + with_bbox_3d=False, + with_mask_3d=True, + with_seg_3d=True), + dict( + type='PointSample_', + num_points=180000), + dict(type='PointInstClassMapping_', + num_classes=num_instance_classes), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[0.0, 0.0], + scale_ratio_range=[0.9, 1.1], + translation_std=[.1, .1, .1], + shift_height=False), + dict( + type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5]), + dict( + type='Pack3DDetInputs_', + keys=[ + 'points', 'gt_labels_3d', + 'pts_semantic_mask', 'pts_instance_mask' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5])]), + dict(type='Pack3DDetInputs_', keys=['points']) +] + +# run settings +train_dataloader = dict( + batch_size=2, + num_workers=3, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='ConcatDataset', + datasets=([ + dict( + type=dataset_type, + data_root=data_root, + ann_file=f's3dis_infos_Area_{i}.pkl', + pipeline=train_pipeline, + filter_empty_gt=True, + data_prefix=data_prefix, + box_type_3d='Depth', + backend_args=None) for i in train_area]))) + +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=f's3dis_infos_Area_{test_area}.pkl', + pipeline=test_pipeline, + test_mode=True, + data_prefix=data_prefix, + box_type_3d='Depth', + backend_args=None)) +test_dataloader = val_dataloader + +class_names = [ + 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', + 'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter', 'unlabeled'] +label2cat = {i: name for i, name in enumerate(class_names)} +metric_meta = dict( + label2cat=label2cat, + ignore_index=[num_semantic_classes], + classes=class_names, + dataset_name='S3DIS') +sem_mapping = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + +val_evaluator = dict( + type='UnifiedSegMetric', + stuff_class_inds=[0, 1, 2, 3, 4, 5, 6, 12], + thing_class_inds=[7, 8, 9, 10, 11], + min_num_points=1, + id_offset=2**16, + sem_mapping=sem_mapping, + inst_mapping=sem_mapping, + submission_prefix_semantic=None, + submission_prefix_instance=None, + metric_meta=metric_meta) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.05), + clip_grad=dict(max_norm=10, norm_type=2)) +param_scheduler = dict(type='PolyLR', begin=0, end=512, power=0.9) + +custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] +default_hooks = dict( + checkpoint=dict( + interval=16, + max_keep_ckpts=1, + 
save_best=['all_ap_50%', 'miou'], + rule='greater')) + +load_from = 'work_dirs/tmp/instance-only-oneformer3d_1xb2_scannet-and-structured3d.pth' + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=512, val_interval=16) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/configs/oneformer3d_1xb4_scannet.py b/configs/oneformer3d_1xb4_scannet.py new file mode 100644 index 0000000..e695b19 --- /dev/null +++ b/configs/oneformer3d_1xb4_scannet.py @@ -0,0 +1,234 @@ +_base_ = [ + 'mmdet3d::_base_/default_runtime.py', + 'mmdet3d::_base_/datasets/scannet-seg.py' +] +custom_imports = dict(imports=['oneformer3d']) + +# model settings +num_channels = 32 +num_instance_classes = 18 +num_semantic_classes = 20 + +model = dict( + type='ScanNetOneFormer3D', + data_preprocessor=dict(type='Det3DDataPreprocessor_'), + in_channels=6, + num_channels=num_channels, + voxel_size=0.02, + num_classes=num_instance_classes, + min_spatial_shape=128, + query_thr=0.5, + backbone=dict( + type='SpConvUNet', + num_planes=[num_channels * (i + 1) for i in range(5)], + return_blocks=True), + decoder=dict( + type='ScanNetQueryDecoder', + num_layers=6, + num_instance_queries=0, + num_semantic_queries=0, + num_instance_classes=num_instance_classes, + num_semantic_classes=num_semantic_classes, + num_semantic_linears=1, + in_channels=32, + d_model=256, + num_heads=8, + hidden_dim=1024, + dropout=0.0, + activation_fn='gelu', + iter_pred=True, + attn_mask=True, + fix_attention=True, + objectness_flag=False), + criterion=dict( + type='ScanNetUnifiedCriterion', + num_semantic_classes=num_semantic_classes, + sem_criterion=dict( + type='ScanNetSemanticCriterion', + ignore_index=num_semantic_classes, + loss_weight=0.2), + inst_criterion=dict( + type='InstanceCriterion', + matcher=dict( + type='SparseMatcher', + costs=[ + dict(type='QueryClassificationCost', weight=0.5), + dict(type='MaskBCECost', weight=1.0), + dict(type='MaskDiceCost', weight=1.0)], + topk=1), + loss_weight=[0.5, 1.0, 1.0, 0.5], + num_classes=num_instance_classes, + non_object_weight=0.1, + fix_dice_loss_weight=True, + iter_matcher=True, + fix_mean_loss=True)), + train_cfg=dict(), + test_cfg=dict( + topk_insts=600, + inst_score_thr=0.0, + pan_score_thr=0.5, + npoint_thr=100, + obj_normalization=True, + sp_score_thr=0.4, + nms=True, + matrix_nms_kernel='linear', + stuff_classes=[0, 1])) + +# dataset settings +dataset_type = 'ScanNetSegDataset_' +data_prefix = dict( + pts='points', + pts_instance_mask='instance_mask', + pts_semantic_mask='semantic_mask', + sp_pts_mask='super_points') + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D_', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True, + with_sp_mask_3d=True), + dict(type='PointSegClassMapping'), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-3.14, 3.14], + scale_ratio_range=[0.8, 1.2], + translation_std=[0.1, 0.1, 0.1], + shift_height=False), + dict( + type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5]), + dict( + type='AddSuperPointAnnotations', + num_classes=num_semantic_classes, + stuff_classes=[0, 1], + merge_non_stuff_cls=False), + dict( + type='ElasticTransfrom', + gran=[6, 20], + mag=[40, 160], + voxel_size=0.02, + p=0.5), + dict( + type='Pack3DDetInputs_', + keys=[ + 
'points', 'gt_labels_3d', 'pts_semantic_mask', 'pts_instance_mask', + 'sp_pts_mask', 'gt_sp_masks', 'elastic_coords' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D_', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True, + with_sp_mask_3d=True), + dict(type='PointSegClassMapping'), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='NormalizePointsColor_', + color_mean=[127.5, 127.5, 127.5]), + dict( + type='AddSuperPointAnnotations', + num_classes=num_semantic_classes, + stuff_classes=[0, 1], + merge_non_stuff_cls=False), + ]), + dict(type='Pack3DDetInputs_', keys=['points', 'sp_pts_mask']) +] + +# run settings +train_dataloader = dict( + batch_size=4, + num_workers=6, + dataset=dict( + type=dataset_type, + ann_file='scannet_oneformer3d_infos_train.pkl', + data_prefix=data_prefix, + pipeline=train_pipeline, + ignore_index=num_semantic_classes, + scene_idxs=None, + test_mode=False)) +val_dataloader = dict( + dataset=dict( + type=dataset_type, + ann_file='scannet_oneformer3d_infos_val.pkl', + data_prefix=data_prefix, + pipeline=test_pipeline, + ignore_index=num_semantic_classes, + test_mode=True)) +test_dataloader = val_dataloader + +class_names = [ + 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', + 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', + 'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink', + 'bathtub', 'otherfurniture'] +class_names += ['unlabeled'] +label2cat = {i: name for i, name in enumerate(class_names)} +metric_meta = dict( + label2cat=label2cat, + ignore_index=[num_semantic_classes], + classes=class_names, + dataset_name='ScanNet') + +sem_mapping = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] +inst_mapping = sem_mapping[2:] +val_evaluator = dict( + type='UnifiedSegMetric', + stuff_class_inds=[0, 1], + thing_class_inds=list(range(2, num_semantic_classes)), + min_num_points=1, + id_offset=2**16, + sem_mapping=sem_mapping, + inst_mapping=inst_mapping, + metric_meta=metric_meta) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.05), + clip_grad=dict(max_norm=10, norm_type=2)) + +param_scheduler = dict(type='PolyLR', begin=0, end=512, power=0.9) + +custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] +default_hooks = dict( + checkpoint=dict(interval=1, max_keep_ckpts=16)) + +load_from = 'work_dirs/tmp/sstnet_scannet.pth' + +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=512, + dynamic_intervals=[(1, 16), (512 - 16, 1)]) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/configs/oneformer3d_1xb4_scannet200.py b/configs/oneformer3d_1xb4_scannet200.py new file mode 100644 index 0000000..36c3fb1 --- /dev/null +++ b/configs/oneformer3d_1xb4_scannet200.py @@ -0,0 +1,302 @@ +_base_ = [ + 'mmdet3d::_base_/default_runtime.py', + 'mmdet3d::_base_/datasets/scannet-seg.py' +] +custom_imports = dict(imports=['oneformer3d']) + +# model settings +num_instance_classes = 198 +num_semantic_classes = 200 + +model = dict( + type='ScanNet200OneFormer3D', + data_preprocessor=dict(type='Det3DDataPreprocessor_'), + voxel_size=0.02, + num_classes=num_instance_classes, + query_thr=0.5, + backbone=dict( + type='Res16UNet34C', + 
in_channels=3, + out_channels=96, + config=dict( + dilations=[1, 1, 1, 1], + conv1_kernel_size=5, + bn_momentum=0.02)), + decoder=dict( + type='ScanNetQueryDecoder', + num_layers=6, + num_instance_queries=0, + num_semantic_queries=0, + num_instance_classes=num_instance_classes, + num_semantic_classes=num_semantic_classes, + num_semantic_linears=1, + in_channels=96, + d_model=256, + num_heads=8, + hidden_dim=1024, + dropout=0.0, + activation_fn='gelu', + iter_pred=True, + attn_mask=True, + fix_attention=True, + objectness_flag=False), + criterion=dict( + type='ScanNetUnifiedCriterion', + num_semantic_classes=num_semantic_classes, + sem_criterion=dict( + type='ScanNetSemanticCriterion', + ignore_index=num_semantic_classes, + loss_weight=0.5), + inst_criterion=dict( + type='InstanceCriterion', + matcher=dict( + type='SparseMatcher', + costs=[ + dict(type='QueryClassificationCost', weight=0.5), + dict(type='MaskBCECost', weight=1.0), + dict(type='MaskDiceCost', weight=1.0)], + topk=1), + loss_weight=[0.5, 1.0, 1.0, 0.5], + num_classes=num_instance_classes, + non_object_weight=0.1, + fix_dice_loss_weight=True, + iter_matcher=True, + fix_mean_loss=True)), + train_cfg=dict(), + test_cfg=dict( + topk_insts=600, + inst_score_thr=0.0, + pan_score_thr=0.5, + npoint_thr=100, + obj_normalization=True, + sp_score_thr=0.4, + nms=True, + matrix_nms_kernel='linear', + stuff_classes=[0, 1])) + +# dataset settings +dataset_type = 'ScanNet200SegDataset_' +data_root = 'data/scannet200/' +data_prefix = dict( + pts='points', + pts_instance_mask='instance_mask', + pts_semantic_mask='semantic_mask', + sp_pts_mask='super_points') + +# floor and chair are changed +class_names = [ + 'wall', 'floor', 'chair', 'table', 'door', 'couch', 'cabinet', 'shelf', + 'desk', 'office chair', 'bed', 'pillow', 'sink', 'picture', 'window', + 'toilet', 'bookshelf', 'monitor', 'curtain', 'book', 'armchair', + 'coffee table', 'box', 'refrigerator', 'lamp', 'kitchen cabinet', 'towel', + 'clothes', 'tv', 'nightstand', 'counter', 'dresser', 'stool', 'cushion', + 'plant', 'ceiling', 'bathtub', 'end table', 'dining table', 'keyboard', + 'bag', 'backpack', 'toilet paper', 'printer', 'tv stand', 'whiteboard', + 'blanket', 'shower curtain', 'trash can', 'closet', 'stairs', 'microwave', + 'stove', 'shoe', 'computer tower', 'bottle', 'bin', 'ottoman', 'bench', + 'board', 'washing machine', 'mirror', 'copier', 'basket', 'sofa chair', + 'file cabinet', 'fan', 'laptop', 'shower', 'paper', 'person', + 'paper towel dispenser', 'oven', 'blinds', 'rack', 'plate', 'blackboard', + 'piano', 'suitcase', 'rail', 'radiator', 'recycling bin', 'container', + 'wardrobe', 'soap dispenser', 'telephone', 'bucket', 'clock', 'stand', + 'light', 'laundry basket', 'pipe', 'clothes dryer', 'guitar', + 'toilet paper holder', 'seat', 'speaker', 'column', 'bicycle', 'ladder', + 'bathroom stall', 'shower wall', 'cup', 'jacket', 'storage bin', + 'coffee maker', 'dishwasher', 'paper towel roll', 'machine', 'mat', + 'windowsill', 'bar', 'toaster', 'bulletin board', 'ironing board', + 'fireplace', 'soap dish', 'kitchen counter', 'doorframe', + 'toilet paper dispenser', 'mini fridge', 'fire extinguisher', 'ball', + 'hat', 'shower curtain rod', 'water cooler', 'paper cutter', 'tray', + 'shower door', 'pillar', 'ledge', 'toaster oven', 'mouse', + 'toilet seat cover dispenser', 'furniture', 'cart', 'storage container', + 'scale', 'tissue box', 'light switch', 'crate', 'power outlet', + 'decoration', 'sign', 'projector', 'closet door', 'vacuum cleaner', + 'candle', 'plunger', 
'stuffed animal', 'headphones', 'dish rack', + 'broom', 'guitar case', 'range hood', 'dustpan', 'hair dryer', + 'water bottle', 'handicap bar', 'purse', 'vent', 'shower floor', + 'water pitcher', 'mailbox', 'bowl', 'paper bag', 'alarm clock', + 'music stand', 'projector screen', 'divider', 'laundry detergent', + 'bathroom counter', 'object', 'bathroom vanity', 'closet wall', + 'laundry hamper', 'bathroom stall door', 'ceiling light', 'trash bin', + 'dumbbell', 'stair rail', 'tube', 'bathroom cabinet', 'cd case', + 'closet rod', 'coffee kettle', 'structure', 'shower head', + 'keyboard piano', 'case of water bottles', 'coat rack', + 'storage organizer', 'folded chair', 'fire alarm', 'power strip', + 'calendar', 'poster', 'potted plant', 'luggage', 'mattress' +] + +color_mean = ( + 0.47793125906962 * 255, + 0.4303257521323044 * 255, + 0.3749598901421883 * 255) +color_std = ( + 0.2834475483823543 * 255, + 0.27566157565723015 * 255, + 0.27018971370874995 * 255) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D_', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True, + with_sp_mask_3d=True), + dict(type='SwapChairAndFloor'), + dict(type='PointSegClassMapping'), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-3.14, 3.14], + scale_ratio_range=[0.8, 1.2], + translation_std=[0.1, 0.1, 0.1], + shift_height=False), + dict( + type='NormalizePointsColor_', + color_mean=color_mean, + color_std=color_std), + dict( + type='AddSuperPointAnnotations', + num_classes=num_semantic_classes, + stuff_classes=[0, 1], + merge_non_stuff_cls=False), + dict( + type='ElasticTransfrom', + gran=[6, 20], + mag=[40, 160], + voxel_size=0.02, + p=0.5), + dict( + type='Pack3DDetInputs_', + keys=[ + 'points', 'gt_labels_3d', 'pts_semantic_mask', 'pts_instance_mask', + 'sp_pts_mask', 'gt_sp_masks', 'elastic_coords' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D_', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=True, + with_seg_3d=True, + with_sp_mask_3d=True), + dict(type='SwapChairAndFloor'), + dict(type='PointSegClassMapping'), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='NormalizePointsColor_', + color_mean=color_mean, + color_std=color_std), + dict( + type='AddSuperPointAnnotations', + num_classes=num_semantic_classes, + stuff_classes=[0, 1], + merge_non_stuff_cls=False), + ]), + dict(type='Pack3DDetInputs_', keys=['points', 'sp_pts_mask']) +] + +# run settings +train_dataloader = dict( + batch_size=4, + num_workers=6, + dataset=dict( + type=dataset_type, + ann_file='scannet200_oneformer3d_infos_train.pkl', + data_root=data_root, + data_prefix=data_prefix, + metainfo=dict(classes=class_names), + pipeline=train_pipeline, + ignore_index=num_semantic_classes, + scene_idxs=None, + test_mode=False)) +val_dataloader = dict( + dataset=dict( + type=dataset_type, + ann_file='scannet200_oneformer3d_infos_val.pkl', + data_root=data_root, + data_prefix=data_prefix, + metainfo=dict(classes=class_names), + pipeline=test_pipeline, + ignore_index=num_semantic_classes, + test_mode=True)) 
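+# Note on the extra class: ignore_index equals num_semantic_classes (200), so
+# points without a valid ScanNet200 label fall into the 'unlabeled' entry that
+# is appended to class_names below and excluded from the metrics through
+# metric_meta['ignore_index'].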
+test_dataloader = val_dataloader + +label2cat = {i: name for i, name in enumerate(class_names + ['unlabeled'])} +metric_meta = dict( + label2cat=label2cat, + ignore_index=[num_semantic_classes], + classes=class_names + ['unlabeled'], + dataset_name='ScanNet200') + +sem_mapping = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, + 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 44, 45, 46, + 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, + 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 110, 112, + 115, 116, 118, 120, 121, 122, 125, 128, 130, 131, 132, 134, 136, 138, 139, + 140, 141, 145, 148, 154, 155, 156, 157, 159, 161, 163, 165, 166, 168, 169, + 170, 177, 180, 185, 188, 191, 193, 195, 202, 208, 213, 214, 221, 229, 230, + 232, 233, 242, 250, 261, 264, 276, 283, 286, 300, 304, 312, 323, 325, 331, + 342, 356, 370, 392, 395, 399, 408, 417, 488, 540, 562, 570, 572, 581, 609, + 748, 776, 1156, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, + 1173, 1174, 1175, 1176, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, + 1186, 1187, 1188, 1189, 1190, 1191 +] +inst_mapping = sem_mapping[2:] + +val_evaluator = dict( + type='UnifiedSegMetric', + stuff_class_inds=[0, 1], + thing_class_inds=list(range(2, num_semantic_classes)), + min_num_points=1, + id_offset=2**16, + sem_mapping=sem_mapping, + inst_mapping=inst_mapping, + metric_meta=metric_meta) +test_evaluator = val_evaluator + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.05), + clip_grad=dict(max_norm=10, norm_type=2)) +param_scheduler = dict(type='PolyLR', begin=0, end=512, power=0.9) + +custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] +default_hooks = dict( + checkpoint=dict( + interval=1, + max_keep_ckpts=1, + save_best=['all_ap_50%', 'miou'], + rule='greater')) + +load_from = 'work_dirs/tmp/mask3d_scannet200.pth' + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=512, val_interval=16) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/data/scannet/README.md b/data/scannet/README.md new file mode 100644 index 0000000..35f3267 --- /dev/null +++ b/data/scannet/README.md @@ -0,0 +1,67 @@ +### Prepare ScanNet Data for Indoor Detection or Segmentation Task + +We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/). + +1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory. If you are performing segmentation tasks and want to upload the results to its official [benchmark](http://kaldir.vc.in.tum.de/scannet_benchmark/), please also link or move the 'scans_test' folder to this directory. + +2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--scannet200` flag if you want to get markup for the ScanNet200 dataset. + +3. 
Enter the project root directory and generate the training data by running:
+
+```bash
+python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
+```
+or for ScanNet200:
+
+```bash
+mkdir data/scannet200
+python tools/create_data.py scannet200 --root-path ./data/scannet --out-dir ./data/scannet200 --extra-tag scannet200
+```
+
+The overall process for ScanNet can be run with the following script:
+
+```bash
+python batch_load_scannet_data.py
+cd ../..
+python tools/create_data.py scannet --root-path ./data/scannet --out-dir ./data/scannet --extra-tag scannet
+```
+
+Or for ScanNet200:
+
+```bash
+python batch_load_scannet_data.py --scannet200
+cd ../..
+mkdir data/scannet200
+python tools/create_data.py scannet200 --root-path ./data/scannet --out-dir ./data/scannet200 --extra-tag scannet200
+```
+
+The directory structure after pre-processing should be as follows:
+
+```
+scannet
+├── meta_data
+├── batch_load_scannet_data.py
+├── load_scannet_data.py
+├── scannet_utils.py
+├── README.md
+├── scans
+├── scans_test
+├── scannet_instance_data
+├── points
+│   ├── xxxxx.bin
+├── instance_mask
+│   ├── xxxxx.bin
+├── semantic_mask
+│   ├── xxxxx.bin
+├── super_points
+│   ├── xxxxx.bin
+├── seg_info
+│   ├── train_label_weight.npy
+│   ├── train_resampled_scene_idxs.npy
+│   ├── val_label_weight.npy
+│   ├── val_resampled_scene_idxs.npy
+├── scannet_oneformer3d_infos_train.pkl
+├── scannet_oneformer3d_infos_val.pkl
+├── scannet_oneformer3d_infos_test.pkl
+
+```
diff --git a/data/scannet/batch_load_scannet_data.py b/data/scannet/batch_load_scannet_data.py
new file mode 100644
index 0000000..7a8f514
--- /dev/null
+++ b/data/scannet/batch_load_scannet_data.py
@@ -0,0 +1,187 @@
+# Modified from
+# https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Batch mode for loading ScanNet scenes with vertices and ground truth labels
+for semantic and instance segmentation.
+ +Usage example: python ./batch_load_scannet_data.py +""" +import argparse +import datetime +import os +from os import path as osp + +import torch +import segmentator +import open3d as o3d +import numpy as np +from load_scannet_data import export + +DONOTCARE_CLASS_IDS = np.array([]) + +SCANNET_OBJ_CLASS_IDS = np.array( + [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]) + +SCANNET200_OBJ_CLASS_IDS = np.array([2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 110, 112, 115, 116, 118, 120, 121, 122, 125, 128, 130, 131, 132, 134, 136, 138, 139, 140, 141, 145, 148, 154, + 155, 156, 157, 159, 161, 163, 165, 166, 168, 169, 170, 177, 180, 185, 188, 191, 193, 195, 202, 208, 213, 214, 221, 229, 230, 232, 233, 242, 250, 261, 264, 276, 283, 286, 300, 304, 312, 323, 325, 331, 342, 356, 370, 392, 395, 399, 408, 417, + 488, 540, 562, 570, 572, 581, 609, 748, 776, 1156, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191]) + + + +def export_one_scan(scan_name, + output_filename_prefix, + max_num_point, + label_map_file, + scannet_dir, + test_mode=False, + scannet200=False): + mesh_file = osp.join(scannet_dir, scan_name, scan_name + '_vh_clean_2.ply') + agg_file = osp.join(scannet_dir, scan_name, + scan_name + '.aggregation.json') + seg_file = osp.join(scannet_dir, scan_name, + scan_name + '_vh_clean_2.0.010000.segs.json') + # includes axisAlignment info for the train set scans. 
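+    # Illustrative meta file content: train-set .txt files contain a line like
+    # 'axisAlignment = <16 floats>' that export() reshapes into a 4x4 matrix;
+    # test scans lack this line, and an identity matrix is used instead.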
+ meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt') + mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \ + aligned_bboxes, instance2semantic, axis_align_matrix = export( + mesh_file, agg_file, seg_file, meta_file, label_map_file, None, + test_mode, scannet200) + + if not test_mode: + mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS)) + mesh_vertices = mesh_vertices[mask, :] + semantic_labels = semantic_labels[mask] + instance_labels = instance_labels[mask] + + num_instances = len(np.unique(instance_labels)) + print(f'Num of instances: {num_instances}') + if scannet200: + OBJ_CLASS_IDS = SCANNET200_OBJ_CLASS_IDS + else: + OBJ_CLASS_IDS = SCANNET_OBJ_CLASS_IDS + + bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS) + unaligned_bboxes = unaligned_bboxes[bbox_mask, :] + bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS) + aligned_bboxes = aligned_bboxes[bbox_mask, :] + assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0] + print(f'Num of care instances: {unaligned_bboxes.shape[0]}') + + if max_num_point is not None: + max_num_point = int(max_num_point) + N = mesh_vertices.shape[0] + if N > max_num_point: + choices = np.random.choice(N, max_num_point, replace=False) + mesh_vertices = mesh_vertices[choices, :] + if not test_mode: + semantic_labels = semantic_labels[choices] + instance_labels = instance_labels[choices] + + mesh = o3d.io.read_triangle_mesh(mesh_file) + vertices = torch.from_numpy(np.array(mesh.vertices).astype(np.float32)) + faces = torch.from_numpy(np.array(mesh.triangles).astype(np.int64)) + superpoints = segmentator.segment_mesh(vertices, faces).numpy() + + np.save(f'{output_filename_prefix}_sp_label.npy', superpoints) + np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices) + + if not test_mode: + assert superpoints.shape == semantic_labels.shape + np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels) + np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels) + np.save(f'{output_filename_prefix}_unaligned_bbox.npy', + unaligned_bboxes) + np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes) + np.save(f'{output_filename_prefix}_axis_align_matrix.npy', + axis_align_matrix) + + +def batch_export(max_num_point, + output_folder, + scan_names_file, + label_map_file, + scannet_dir, + test_mode=False, + scannet200=False): + if test_mode and not os.path.exists(scannet_dir): + # test data preparation is optional + return + if not os.path.exists(output_folder): + print(f'Creating new data folder: {output_folder}') + os.mkdir(output_folder) + + scan_names = [line.rstrip() for line in open(scan_names_file)] + for scan_name in scan_names: + print('-' * 20 + 'begin') + print(datetime.datetime.now()) + print(scan_name) + output_filename_prefix = osp.join(output_folder, scan_name) + if osp.isfile(f'{output_filename_prefix}_vert.npy'): + print('File already exists. 
skipping.')
+            print('-' * 20 + 'done')
+            continue
+        try:
+            export_one_scan(scan_name, output_filename_prefix, max_num_point,
+                            label_map_file, scannet_dir, test_mode, scannet200)
+        except Exception:
+            print(f'Failed to export scan: {scan_name}')
+        print('-' * 20 + 'done')
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--max_num_point',
+        default=None,
+        help='The maximum number of points to keep per scene.')
+    parser.add_argument(
+        '--output_folder',
+        default='./scannet_instance_data',
+        help='Output folder for the processed data.')
+    parser.add_argument(
+        '--train_scannet_dir', default='scans', help='ScanNet train directory.')
+    parser.add_argument(
+        '--test_scannet_dir',
+        default='scans_test',
+        help='ScanNet test directory.')
+    parser.add_argument(
+        '--label_map_file',
+        default='meta_data/scannetv2-labels.combined.tsv',
+        help='Path to the label map file.')
+    parser.add_argument(
+        '--train_scan_names_file',
+        default='meta_data/scannet_train.txt',
+        help='Path to the file that stores the train scan names.')
+    parser.add_argument(
+        '--test_scan_names_file',
+        default='meta_data/scannetv2_test.txt',
+        help='Path to the file that stores the test scan names.')
+    parser.add_argument(
+        '--scannet200',
+        action='store_true',
+        help='Use the ScanNet200 label mapping.')
+    args = parser.parse_args()
+    batch_export(
+        args.max_num_point,
+        args.output_folder,
+        args.train_scan_names_file,
+        args.label_map_file,
+        args.train_scannet_dir,
+        test_mode=False,
+        scannet200=args.scannet200)
+    batch_export(
+        args.max_num_point,
+        args.output_folder,
+        args.test_scan_names_file,
+        args.label_map_file,
+        args.test_scannet_dir,
+        test_mode=True,
+        scannet200=args.scannet200)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/data/scannet/load_scannet_data.py b/data/scannet/load_scannet_data.py
new file mode 100644
index 0000000..7cbe499
--- /dev/null
+++ b/data/scannet/load_scannet_data.py
@@ -0,0 +1,205 @@
+# Modified from
+# https://github.com/facebookresearch/votenet/blob/master/scannet/load_scannet_data.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
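+# Example invocation (hypothetical paths):
+#   python load_scannet_data.py --scan_path scans/scene0000_00 \
+#       --output_file scannet_instance_data/scene0000_00 \
+#       --label_map_file meta_data/scannetv2-labels.combined.tsv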
+"""Load Scannet scenes with vertices and ground truth labels for semantic and +instance segmentations.""" +import argparse +import inspect +import json +import os + +import numpy as np +import scannet_utils + +currentdir = os.path.dirname( + os.path.abspath(inspect.getfile(inspect.currentframe()))) + + +def read_aggregation(filename): + assert os.path.isfile(filename) + object_id_to_segs = {} + label_to_segs = {} + with open(filename) as f: + data = json.load(f) + num_objects = len(data['segGroups']) + for i in range(num_objects): + object_id = data['segGroups'][i][ + 'objectId'] + 1 # instance ids should be 1-indexed + label = data['segGroups'][i]['label'] + segs = data['segGroups'][i]['segments'] + object_id_to_segs[object_id] = segs + if label in label_to_segs: + label_to_segs[label].extend(segs) + else: + label_to_segs[label] = segs + return object_id_to_segs, label_to_segs + + +def read_segmentation(filename): + assert os.path.isfile(filename) + seg_to_verts = {} + with open(filename) as f: + data = json.load(f) + num_verts = len(data['segIndices']) + for i in range(num_verts): + seg_id = data['segIndices'][i] + if seg_id in seg_to_verts: + seg_to_verts[seg_id].append(i) + else: + seg_to_verts[seg_id] = [i] + return seg_to_verts, num_verts + + +def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id, + instance_ids): + num_instances = len(np.unique(list(object_id_to_segs.keys()))) + instance_bboxes = np.zeros((num_instances, 7)) + for obj_id in object_id_to_segs: + label_id = object_id_to_label_id[obj_id] + obj_pc = mesh_vertices[instance_ids == obj_id, 0:3] + if len(obj_pc) == 0: + continue + xyz_min = np.min(obj_pc, axis=0) + xyz_max = np.max(obj_pc, axis=0) + bbox = np.concatenate([(xyz_min + xyz_max) / 2.0, xyz_max - xyz_min, + np.array([label_id])]) + # NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES + instance_bboxes[obj_id - 1, :] = bbox + return instance_bboxes + + +def export(mesh_file, + agg_file, + seg_file, + meta_file, + label_map_file, + output_file=None, + test_mode=False, + scannet200=False): + """Export original files to vert, ins_label, sem_label and bbox file. + + Args: + mesh_file (str): Path of the mesh_file. + agg_file (str): Path of the agg_file. + seg_file (str): Path of the seg_file. + meta_file (str): Path of the meta_file. + label_map_file (str): Path of the label_map_file. + output_file (str): Path of the output folder. + Default: None. + test_mode (bool): Whether is generating test data without labels. + Default: False. + + It returns a tuple, which contains the the following things: + np.ndarray: Vertices of points data. + np.ndarray: Indexes of label. + np.ndarray: Indexes of instance. + np.ndarray: Instance bboxes. + dict: Map from object_id to label_id. 
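+        np.ndarray: Axis align matrix of shape (4, 4).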
+ """ + if scannet200: + label_map = scannet_utils.read_label_mapping( + label_map_file, label_from='raw_category', label_to='id') + else: + label_map = scannet_utils.read_label_mapping( + label_map_file, label_from='raw_category', label_to='nyu40id') + + mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file) + + # Load scene axis alignment matrix + lines = open(meta_file).readlines() + # test set data doesn't have align_matrix + axis_align_matrix = np.eye(4) + for line in lines: + if 'axisAlignment' in line: + axis_align_matrix = [ + float(x) + for x in line.rstrip().strip('axisAlignment = ').split(' ') + ] + break + axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4)) + + # perform global alignment of mesh vertices + pts = np.ones((mesh_vertices.shape[0], 4)) + pts[:, 0:3] = mesh_vertices[:, 0:3] + pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4 + aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]], + axis=1) + + # Load semantic and instance labels + if not test_mode: + object_id_to_segs, label_to_segs = read_aggregation(agg_file) + seg_to_verts, num_verts = read_segmentation(seg_file) + label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) + object_id_to_label_id = {} + for label, segs in label_to_segs.items(): + label_id = label_map[label] + for seg in segs: + verts = seg_to_verts[seg] + label_ids[verts] = label_id + instance_ids = np.zeros( + shape=(num_verts), dtype=np.uint32) # 0: unannotated + for object_id, segs in object_id_to_segs.items(): + for seg in segs: + verts = seg_to_verts[seg] + instance_ids[verts] = object_id + if object_id not in object_id_to_label_id: + object_id_to_label_id[object_id] = label_ids[verts][0] + unaligned_bboxes = extract_bbox(mesh_vertices, object_id_to_segs, + object_id_to_label_id, instance_ids) + aligned_bboxes = extract_bbox(aligned_mesh_vertices, object_id_to_segs, + object_id_to_label_id, instance_ids) + else: + label_ids = None + instance_ids = None + unaligned_bboxes = None + aligned_bboxes = None + object_id_to_label_id = None + + if output_file is not None: + np.save(output_file + '_vert.npy', mesh_vertices) + if not test_mode: + np.save(output_file + '_sem_label.npy', label_ids) + np.save(output_file + '_ins_label.npy', instance_ids) + np.save(output_file + '_unaligned_bbox.npy', unaligned_bboxes) + np.save(output_file + '_aligned_bbox.npy', aligned_bboxes) + np.save(output_file + '_axis_align_matrix.npy', axis_align_matrix) + + return mesh_vertices, label_ids, instance_ids, unaligned_bboxes, \ + aligned_bboxes, object_id_to_label_id, axis_align_matrix + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--scan_path', + required=True, + help='path to scannet scene (e.g., data/ScanNet/v2/scene0000_00') + parser.add_argument('--output_file', required=True, help='output file') + parser.add_argument( + '--label_map_file', + required=True, + help='path to scannetv2-labels.combined.tsv') + parser.add_argument( + '--scannet200', + action='store_true', + help='Use it for scannet200 mapping') + + opt = parser.parse_args() + + scan_name = os.path.split(opt.scan_path)[-1] + mesh_file = os.path.join(opt.scan_path, scan_name + '_vh_clean_2.ply') + agg_file = os.path.join(opt.scan_path, scan_name + '.aggregation.json') + seg_file = os.path.join(opt.scan_path, + scan_name + '_vh_clean_2.0.010000.segs.json') + meta_file = os.path.join( + opt.scan_path, scan_name + + '.txt') # includes axisAlignment info for the train set scans. 
+ export(mesh_file, agg_file, seg_file, meta_file, opt.label_map_file, + opt.output_file, scannet200=opt.scannet200) + + +if __name__ == '__main__': + main() diff --git a/data/scannet/meta_data/scannet_means.npz b/data/scannet/meta_data/scannet_means.npz new file mode 100644 index 0000000000000000000000000000000000000000..e57647c9a3553ca4653a9d1e53ed4a2a58def822 GIT binary patch literal 676 zcmWIWW@Zs#fB;2?qMfn(4VV}hK$sIKm{?R4Z=jb~P&wHz)HfiKk)e#CT0JGTIJrpO zNVaPx7@2E}8$<-lIU?aSr=G`p*2`D_wHFi=4Ad6?Z2wg7 zT-4wBkL}%0{@(Oq&0~8dx7n)%|2(tb8@ImvNANfMQmeV|P8j~M=W9QlxM}k%d#067 zUa&{JxBoLix428>qy623b*tSSpV$lOrYR*cKeT7o{Lg!y@3DP|USJPD!yo(h`{xDw z&%CprZTw4G*5JAQt*;YQe#U*W=YO@prfcCV`xdJe0b(aV*c%%g_q#Iuwy*#AW9BTL z)ApCn)_;q8earr|Y)f&z;Y0i57fRxnOMSNohP6OR)Ia;sqByq**NgVcY$r}^ulQi! zH{+X)O!+sExn{*IZ|vWC+dbI9_RT)O$!tdV_b2wkTMovrc=pTwslw+=kseR&J0DEc zt3CM4-oiOsUFyUa`~8c<*;F6ivQO-)nYPC1hW(a21J<3af9!2LeYE%ddt~pgbh)GX z^&@*-e= + min_region_size and gt['med_dist'] <= distance_thresh + and gt['dist_conf'] >= distance_conf + ] + if gt_instances: + has_gt = True + if pred_instances: + has_pred = True + + cur_true = np.ones(len(gt_instances)) + cur_score = np.ones(len(gt_instances)) * (-float('inf')) + cur_match = np.zeros(len(gt_instances), dtype=bool) + # collect matches + for (gti, gt) in enumerate(gt_instances): + found_match = False + for pred in gt['matched_pred']: + # greedy assignments + if pred_visited[pred['filename']]: + continue + overlap = float(pred['intersection']) / ( + gt['vert_count'] + pred['vert_count'] - + pred['intersection']) + if overlap > overlap_th: + confidence = pred['confidence'] + # if already have a prediction for this gt, + # the prediction with the lower score is automatically a false positive # noqa + if cur_match[gti]: + max_score = max(cur_score[gti], confidence) + min_score = min(cur_score[gti], confidence) + cur_score[gti] = max_score + # append false positive + cur_true = np.append(cur_true, 0) + cur_score = np.append(cur_score, min_score) + cur_match = np.append(cur_match, True) + # otherwise set score + else: + found_match = True + cur_match[gti] = True + cur_score[gti] = confidence + pred_visited[pred['filename']] = True + if not found_match: + hard_false_negatives += 1 + # remove non-matched ground truth instances + cur_true = cur_true[cur_match] + cur_score = cur_score[cur_match] + + # collect non-matched predictions as false positive + for pred in pred_instances: + found_gt = False + for gt in pred['matched_gt']: + overlap = float(gt['intersection']) / ( + gt['vert_count'] + pred['vert_count'] - + gt['intersection']) + if overlap > overlap_th: + found_gt = True + break + if not found_gt: + num_ignore = pred['void_intersection'] + for gt in pred['matched_gt']: + # group? 
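+                        # gt ids below 1000 were left unrenamed by rename_gt,
+                        # i.e. they belong to classes outside valid_class_ids;
+                        # their overlap counts toward the ignored portion of
+                        # this prediction rather than as a direct match.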
+ if gt['instance_id'] < 1000: + num_ignore += gt['intersection'] + # small ground truth instances + if gt['vert_count'] < min_region_size or gt[ + 'med_dist'] > distance_thresh or gt[ + 'dist_conf'] < distance_conf: + num_ignore += gt['intersection'] + proportion_ignore = float( + num_ignore) / pred['vert_count'] + # if not ignored append false positive + if proportion_ignore <= overlap_th: + cur_true = np.append(cur_true, 0) + confidence = pred['confidence'] + cur_score = np.append(cur_score, confidence) + + # append to overall results + y_true = np.append(y_true, cur_true) + y_score = np.append(y_score, cur_score) + + # compute average precision + if has_gt and has_pred: + # compute precision recall curve first + + # sorting and cumsum + score_arg_sort = np.argsort(y_score) + y_score_sorted = y_score[score_arg_sort] + y_true_sorted = y_true[score_arg_sort] + y_true_sorted_cumsum = np.cumsum(y_true_sorted) + + # unique thresholds + (thresholds, unique_indices) = np.unique( + y_score_sorted, return_index=True) + num_prec_recall = len(unique_indices) + 1 + + # prepare precision recall + num_examples = len(y_score_sorted) + # follow https://github.com/ScanNet/ScanNet/pull/26 ? # noqa + num_true_examples = y_true_sorted_cumsum[-1] if len( + y_true_sorted_cumsum) > 0 else 0 + precision = np.zeros(num_prec_recall) + recall = np.zeros(num_prec_recall) + + # deal with the first point + y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0) + # deal with remaining + for idx_res, idx_scores in enumerate(unique_indices): + cumsum = y_true_sorted_cumsum[idx_scores - 1] + tp = num_true_examples - cumsum + fp = num_examples - idx_scores - tp + fn = cumsum + hard_false_negatives + p = float(tp) / (tp + fp) + r = float(tp) / (tp + fn) + precision[idx_res] = p + recall[idx_res] = r + + # first point in curve is artificial + precision[-1] = 1. + recall[-1] = 0. + + #compute optimal precision and recall, based on f1_score + f1_score = 2 * precision * recall / (precision + recall + 0.0001) + f1_argmax = f1_score.argmax() + best_pr = precision[f1_argmax] + best_rc = recall[f1_argmax] + + # compute average of precision-recall curve + recall_for_conv = np.copy(recall) + recall_for_conv = np.append(recall_for_conv[0], + recall_for_conv) + recall_for_conv = np.append(recall_for_conv, 0.) + + stepWidths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], + 'valid') + # integrate is now simply a dot product + ap_current = np.dot(precision, stepWidths) + + elif has_gt: + ap_current = 0.0 + best_pr = 0 + best_rc = 0 + else: + ap_current = float('nan') + best_pr = float('nan') + best_rc = float('nan') + ap[di, li, oi] = ap_current + pr_rc[0, li, oi] = best_pr + pr_rc[1, li, oi] = best_rc + + return ap, pr_rc + + +def compute_averages(aps, pr_rc, options, class_labels): + """Averages AP scores for all categories. + + Args: + aps (np.array): AP scores for all thresholds and categories. + options (dict): ScanNet evaluator options. See get_options. + class_labels (tuple[str]): Class names. + + Returns: + dict: Overall and per-category AP scores. 
+ """ + d_inf = 0 + o50 = np.where(np.isclose(options['overlaps'], 0.5)) + o25 = np.where(np.isclose(options['overlaps'], 0.25)) + o_all_but25 = np.where( + np.logical_not(np.isclose(options['overlaps'], 0.25))) + avg_dict = {} + avg_dict['all_ap'] = np.nanmean(aps[d_inf, :, o_all_but25]) + avg_dict['all_ap_50%'] = np.nanmean(aps[d_inf, :, o50]) + avg_dict['all_ap_25%'] = np.nanmean(aps[d_inf, :, o25]) + avg_dict['all_prec_50%'] = np.nanmean(pr_rc[0, :, o50]) + avg_dict['all_rec_50%'] = np.nanmean(pr_rc[1, :, o50]) + avg_dict['classes'] = {} + for (li, label_name) in enumerate(class_labels): + avg_dict['classes'][label_name] = {} + avg_dict['classes'][label_name]['ap'] = np.average(aps[d_inf, li, + o_all_but25]) + avg_dict['classes'][label_name]['ap50%'] = np.average(aps[d_inf, li, + o50]) + avg_dict['classes'][label_name]['ap25%'] = np.average(aps[d_inf, li, + o25]) + avg_dict['classes'][label_name]['prec50%'] = np.average(pr_rc[0, li, + o50]) + avg_dict['classes'][label_name]['rec50%'] = np.average(pr_rc[1, li, + o50]) + return avg_dict + + +def assign_instances_for_scan(pred_info, gt_ids, options, valid_class_ids, + class_labels, id_to_label): + """Assign gt and predicted instances for a single scene. + + Args: + pred_info (dict): Predicted masks, labels and scores. + gt_ids (np.array): Ground truth instance masks. + options (dict): ScanNet evaluator options. See get_options. + valid_class_ids (tuple[int]): Ids of valid categories. + class_labels (tuple[str]): Class names. + id_to_label (dict[int, str]): Mapping of valid class id to class label. + + Returns: + dict: Per class assigned gt to predicted instances. + dict: Per class assigned predicted to gt instances. + """ + # get gt instances + gt_instances = util_3d.get_instances(gt_ids, valid_class_ids, class_labels, + id_to_label) + # associate + gt2pred = deepcopy(gt_instances) + for label in gt2pred: + for gt in gt2pred[label]: + gt['matched_pred'] = [] + pred2gt = {} + for label in class_labels: + pred2gt[label] = [] + num_pred_instances = 0 + # mask of void labels in the ground truth + bool_void = np.logical_not(np.in1d(gt_ids // 1000, valid_class_ids)) + # go through all prediction masks + for pred_mask_file in pred_info: + label_id = int(pred_info[pred_mask_file]['label_id']) + conf = pred_info[pred_mask_file]['conf'] + if not label_id in id_to_label: # noqa E713 + continue + label_name = id_to_label[label_id] + # read the mask + pred_mask = pred_info[pred_mask_file]['mask'] + if len(pred_mask) != len(gt_ids): + raise ValueError('len(pred_mask) != len(gt_ids)') + # convert to binary + pred_mask = np.not_equal(pred_mask, 0) + num = np.count_nonzero(pred_mask) + if num < options['min_region_sizes'][0]: + continue # skip if empty + + pred_instance = {} + pred_instance['filename'] = pred_mask_file + pred_instance['pred_id'] = num_pred_instances + pred_instance['label_id'] = label_id + pred_instance['vert_count'] = num + pred_instance['confidence'] = conf + pred_instance['void_intersection'] = np.count_nonzero( + np.logical_and(bool_void, pred_mask)) + + # matched gt instances + matched_gt = [] + # go through all gt instances with matching label + for (gt_num, gt_inst) in enumerate(gt2pred[label_name]): + intersection = np.count_nonzero( + np.logical_and(gt_ids == gt_inst['instance_id'], pred_mask)) + if intersection > 0: + gt_copy = gt_inst.copy() + pred_copy = pred_instance.copy() + gt_copy['intersection'] = intersection + pred_copy['intersection'] = intersection + matched_gt.append(gt_copy) + 
gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy) + pred_instance['matched_gt'] = matched_gt + num_pred_instances += 1 + pred2gt[label_name].append(pred_instance) + + return gt2pred, pred2gt + + +def scannet_eval(preds, gts, options, valid_class_ids, class_labels, + id_to_label): + """Evaluate instance segmentation in ScanNet protocol. + + Args: + preds (list[dict]): Per scene predictions of mask, label and + confidence. + gts (list[np.array]): Per scene ground truth instance masks. + options (dict): ScanNet evaluator options. See get_options. + valid_class_ids (tuple[int]): Ids of valid categories. + class_labels (tuple[str]): Class names. + id_to_label (dict[int, str]): Mapping of valid class id to class label. + + Returns: + dict: Overall and per-category AP scores. + """ + options = get_options(options) + matches = {} + for i, (pred, gt) in enumerate(zip(preds, gts)): + matches_key = i + # assign gt to predictions + gt2pred, pred2gt = assign_instances_for_scan(pred, gt, options, + valid_class_ids, + class_labels, id_to_label) + matches[matches_key] = {} + matches[matches_key]['gt'] = gt2pred + matches[matches_key]['pred'] = pred2gt + + ap_scores, pr_rc = evaluate_matches(matches, class_labels, options) + avgs = compute_averages(ap_scores, pr_rc, options, class_labels) + return avgs + + +def get_options(options=None): + """Set ScanNet evaluator options. + + Args: + options (dict, optional): Not default options. Default: None. + + Returns: + dict: Updated options with all 4 keys. + """ + assert options is None or isinstance(options, dict) + _options = dict( + overlaps=np.append(np.arange(0.5, 0.95, 0.05), 0.25), + min_region_sizes=np.array([100]), + distance_threshes=np.array([float('inf')]), + distance_confs=np.array([-float('inf')])) + if options is not None: + _options.update(options) + return _options diff --git a/oneformer3d/formatting.py b/oneformer3d/formatting.py new file mode 100644 index 0000000..cdf00f2 --- /dev/null +++ b/oneformer3d/formatting.py @@ -0,0 +1,142 @@ +# Adapted from mmdet3d/datasets/transforms/formating.py +import numpy as np +from .structures import InstanceData_ +from mmdet3d.datasets.transforms import Pack3DDetInputs +from mmdet3d.datasets.transforms.formating import to_tensor +from mmdet3d.registry import TRANSFORMS +from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample, PointData +from mmdet3d.structures.points import BasePoints + + +@TRANSFORMS.register_module() +class Pack3DDetInputs_(Pack3DDetInputs): + """Just add elastic_coords, sp_pts_mask, and gt_sp_masks. + """ + INPUTS_KEYS = ['points', 'img', 'elastic_coords'] + SEG_KEYS = [ + 'gt_seg_map', + 'pts_instance_mask', + 'pts_semantic_mask', + 'gt_semantic_seg', + 'sp_pts_mask', + ] + INSTANCEDATA_3D_KEYS = [ + 'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d', + 'gt_sp_masks' + ] + + def pack_single_results(self, results: dict) -> dict: + """Method to pack the single input data. when the value in this dict is + a list, it usually is in Augmentations Testing. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: A dict contains + + - 'inputs' (dict): The forward data of models. It usually contains + following keys: + + - points + - img + + - 'data_samples' (:obj:`Det3DDataSample`): The annotation info + of the sample. 
+ """ + # Format 3D data + if 'points' in results: + if isinstance(results['points'], BasePoints): + results['points'] = results['points'].tensor + + if 'img' in results: + if isinstance(results['img'], list): + # process multiple imgs in single frame + imgs = np.stack(results['img'], axis=0) + if imgs.flags.c_contiguous: + imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous() + else: + imgs = to_tensor( + np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))) + results['img'] = imgs + else: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + # To improve the computational speed by by 3-5 times, apply: + # `torch.permute()` rather than `np.transpose()`. + # Refer to https://github.com/open-mmlab/mmdetection/pull/9533 + # for more details + if img.flags.c_contiguous: + img = to_tensor(img).permute(2, 0, 1).contiguous() + else: + img = to_tensor( + np.ascontiguousarray(img.transpose(2, 0, 1))) + results['img'] = img + + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', + 'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask', + 'pts_semantic_mask', 'sp_pts_mask', 'gt_sp_masks', + 'elastic_coords', 'centers_2d', 'depths', 'gt_labels_3d' + ]: + if key not in results: + continue + if isinstance(results[key], list): + results[key] = [to_tensor(res) for res in results[key]] + else: + results[key] = to_tensor(results[key]) + if 'gt_bboxes_3d' in results: + if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): + results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d']) + + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = to_tensor( + results['gt_semantic_seg'][None]) + if 'gt_seg_map' in results: + results['gt_seg_map'] = results['gt_seg_map'][None, ...] + + data_sample = Det3DDataSample() + gt_instances_3d = InstanceData_() + gt_instances = InstanceData_() + gt_pts_seg = PointData() + + img_metas = {} + for key in self.meta_keys: + if key in results: + img_metas[key] = results[key] + data_sample.set_metainfo(img_metas) + + inputs = {} + for key in self.keys: + if key in results: + if key in self.INPUTS_KEYS: + inputs[key] = results[key] + elif key in self.INSTANCEDATA_3D_KEYS: + gt_instances_3d[self._remove_prefix(key)] = results[key] + elif key in self.INSTANCEDATA_2D_KEYS: + if key == 'gt_bboxes_labels': + gt_instances['labels'] = results[key] + else: + gt_instances[self._remove_prefix(key)] = results[key] + elif key in self.SEG_KEYS: + gt_pts_seg[self._remove_prefix(key)] = results[key] + else: + raise NotImplementedError(f'Please modified ' + f'`Pack3DDetInputs` ' + f'to put {key} to ' + f'corresponding field') + + data_sample.gt_instances_3d = gt_instances_3d + data_sample.gt_instances = gt_instances + data_sample.gt_pts_seg = gt_pts_seg + if 'eval_ann_info' in results: + data_sample.eval_ann_info = results['eval_ann_info'] + else: + data_sample.eval_ann_info = None + + packed_results = dict() + packed_results['data_samples'] = data_sample + packed_results['inputs'] = inputs + + return packed_results diff --git a/oneformer3d/instance_criterion.py b/oneformer3d/instance_criterion.py new file mode 100644 index 0000000..61e1797 --- /dev/null +++ b/oneformer3d/instance_criterion.py @@ -0,0 +1,724 @@ +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from .structures import InstanceData_ +from mmdet3d.registry import MODELS, TASK_UTILS + + +def batch_sigmoid_bce_loss(inputs, targets): + """Sigmoid BCE loss. + + Args: + inputs: of shape (n_queries, n_points). 
+ targets: of shape (n_gts, n_points). + + Returns: + Tensor: Loss of shape (n_queries, n_gts). + """ + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction='none') + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction='none') + + pos_loss = torch.einsum('nc,mc->nm', pos, targets) + neg_loss = torch.einsum('nc,mc->nm', neg, (1 - targets)) + return (pos_loss + neg_loss) / inputs.shape[1] + + +def batch_dice_loss(inputs, targets): + """Dice loss. + + Args: + inputs: of shape (n_queries, n_points). + targets: of shape (n_gts, n_points). + + Returns: + Tensor: Loss of shape (n_queries, n_gts). + """ + inputs = inputs.sigmoid() + numerator = 2 * torch.einsum('nc,mc->nm', inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def get_iou(inputs, targets): + """IoU for to equal shape masks. + + Args: + inputs (Tensor): of shape (n_gts, n_points). + targets (Tensor): of shape (n_gts, n_points). + + Returns: + Tensor: IoU of shape (n_gts,). + """ + inputs = inputs.sigmoid() + binarized_inputs = (inputs >= 0.5).float() + targets = (targets > 0.5).float() + intersection = (binarized_inputs * targets).sum(-1) + union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection + score = intersection / (union + 1e-6) + return score + + +def dice_loss(inputs, targets): + """Compute the DICE loss, similar to generalized IOU for masks. + + Args: + inputs (Tensor): A float tensor of arbitrary shape. + The predictions for each example. + targets (Tensor): A float tensor with the same shape as inputs. + Stores the binary classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + Tensor: loss value. + """ + inputs = inputs.sigmoid() + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.mean() + + +@MODELS.register_module() +class InstanceCriterion: + """Instance criterion. + + Args: + matcher (Callable): Class for matching queries with gt. + loss_weight (List[float]): 4 weights for query classification, + mask bce, mask dice, and score losses. + non_object_weight (float): no_object weight for query classification. + num_classes (int): number of classes. + fix_dice_loss_weight (bool): Whether to fix dice loss for + batch_size != 4. + iter_matcher (bool): Whether to use separate matcher for + each decoder layer. + fix_mean_loss (bool): Whether to use .mean() instead of .sum() + for mask losses. + + """ + + def __init__(self, matcher, loss_weight, non_object_weight, num_classes, + fix_dice_loss_weight, iter_matcher, fix_mean_loss=False): + self.matcher = TASK_UTILS.build(matcher) + class_weight = [1] * num_classes + [non_object_weight] + self.class_weight = class_weight + self.loss_weight = loss_weight + self.num_classes = num_classes + self.fix_dice_loss_weight = fix_dice_loss_weight + self.iter_matcher = iter_matcher + self.fix_mean_loss = fix_mean_loss + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def get_layer_loss(self, aux_outputs, insts, indices=None): + """Per layer auxiliary loss. 
+ + Args: + aux_outputs (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_queries, n_classes + 1) + List `scores` of len batch_size each of shape (n_queries, 1) + List `masks` of len batch_size each of shape + (n_queries, n_points) + insts (List): + Ground truth of len batch_size, each InstanceData_ with + `sp_masks` of shape (n_gts_i, n_points_i) + `labels_3d` of shape (n_gts_i,) + `query_masks` of shape (n_gts_i, n_queries_i). + + Returns: + Tensor: loss value. + """ + cls_preds = aux_outputs['cls_preds'] + pred_scores = aux_outputs['scores'] + pred_masks = aux_outputs['masks'] + + if indices is None: + indices = [] + for i in range(len(insts)): + pred_instances = InstanceData_( + scores=cls_preds[i], + masks=pred_masks[i]) + gt_instances = InstanceData_( + labels=insts[i].labels_3d, + masks=insts[i].sp_masks) + if insts[i].get('query_masks') is not None: + gt_instances.query_masks = insts[i].query_masks + indices.append(self.matcher(pred_instances, gt_instances)) + + cls_losses = [] + for cls_pred, inst, (idx_q, idx_gt) in zip(cls_preds, insts, indices): + n_classes = cls_pred.shape[1] - 1 + cls_target = cls_pred.new_full( + (len(cls_pred),), n_classes, dtype=torch.long) + cls_target[idx_q] = inst.labels_3d[idx_gt] + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, cls_pred.new_tensor(self.class_weight))) + cls_loss = torch.mean(torch.stack(cls_losses)) + + # 3 other losses + score_losses, mask_bce_losses, mask_dice_losses = [], [], [] + for mask, score, inst, (idx_q, idx_gt) in zip(pred_masks, pred_scores, + insts, indices): + if len(inst) == 0: + continue + + pred_mask = mask[idx_q] + tgt_mask = inst.sp_masks[idx_gt] + mask_bce_losses.append(F.binary_cross_entropy_with_logits( + pred_mask, tgt_mask.float())) + mask_dice_losses.append(dice_loss(pred_mask, tgt_mask.float())) + + # check if skip objectness loss + if score is None: + continue + + pred_score = score[idx_q] + with torch.no_grad(): + tgt_score = get_iou(pred_mask, tgt_mask).unsqueeze(1) + + filter_id, _ = torch.where(tgt_score > 0.5) + if filter_id.numel(): + tgt_score = tgt_score[filter_id] + pred_score = pred_score[filter_id] + score_losses.append(F.mse_loss(pred_score, tgt_score)) + # todo: actually .mean() should be better + if len(score_losses): + score_loss = torch.stack(score_losses).sum() / len(pred_masks) + else: + score_loss = 0 + + if len(mask_bce_losses): + mask_bce_loss = torch.stack(mask_bce_losses).sum() / len(pred_masks) + mask_dice_loss = torch.stack(mask_dice_losses).sum() / len(pred_masks) + + if self.fix_dice_loss_weight: + mask_dice_loss = mask_dice_loss / len(pred_masks) * 4 + + if self.fix_mean_loss: + mask_bce_loss = mask_bce_loss * len(pred_masks) \ + / len(mask_bce_losses) + mask_dice_loss = mask_dice_loss * len(pred_masks) \ + / len(mask_dice_losses) + else: + mask_bce_loss = 0 + mask_dice_loss = 0 + + loss = ( + self.loss_weight[0] * cls_loss + + self.loss_weight[1] * mask_bce_loss + + self.loss_weight[2] * mask_dice_loss + + self.loss_weight[3] * score_loss) + + return loss + + # todo: refactor pred to InstanceData_ + def __call__(self, pred, insts): + """Loss main function. + + Args: + pred (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_queries, n_classes + 1) + List `scores` of len batch_size each of shape (n_queries, 1) + List `masks` of len batch_size each of shape + (n_queries, n_points) + Dict `aux_preds` with list of cls_preds, scores, and masks. 
+ insts (List): + Ground truth of len batch_size, each InstanceData_ with + `sp_masks` of shape (n_gts_i, n_points_i) + `labels_3d` of shape (n_gts_i,) + `query_masks` of shape (n_gts_i, n_queries_i). + + Returns: + Dict: with instance loss value. + """ + cls_preds = pred['cls_preds'] + pred_scores = pred['scores'] + pred_masks = pred['masks'] + + # match + indices = [] + for i in range(len(insts)): + pred_instances = InstanceData_( + scores=cls_preds[i], + masks=pred_masks[i]) + gt_instances = InstanceData_( + labels=insts[i].labels_3d, + masks=insts[i].sp_masks) + if insts[i].get('query_masks') is not None: + gt_instances.query_masks = insts[i].query_masks + indices.append(self.matcher(pred_instances, gt_instances)) + + # class loss + cls_losses = [] + for cls_pred, inst, (idx_q, idx_gt) in zip(cls_preds, insts, indices): + n_classes = cls_pred.shape[1] - 1 + cls_target = cls_pred.new_full( + (len(cls_pred),), n_classes, dtype=torch.long) + cls_target[idx_q] = inst.labels_3d[idx_gt] + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, cls_pred.new_tensor(self.class_weight))) + cls_loss = torch.mean(torch.stack(cls_losses)) + + # 3 other losses + score_losses, mask_bce_losses, mask_dice_losses = [], [], [] + for mask, score, inst, (idx_q, idx_gt) in zip(pred_masks, pred_scores, + insts, indices): + if len(inst) == 0: + continue + pred_mask = mask[idx_q] + tgt_mask = inst.sp_masks[idx_gt] + mask_bce_losses.append(F.binary_cross_entropy_with_logits( + pred_mask, tgt_mask.float())) + mask_dice_losses.append(dice_loss(pred_mask, tgt_mask.float())) + + # check if skip objectness loss + if score is None: + continue + + pred_score = score[idx_q] + with torch.no_grad(): + tgt_score = get_iou(pred_mask, tgt_mask).unsqueeze(1) + + filter_id, _ = torch.where(tgt_score > 0.5) + if filter_id.numel(): + tgt_score = tgt_score[filter_id] + pred_score = pred_score[filter_id] + score_losses.append(F.mse_loss(pred_score, tgt_score)) + # todo: actually .mean() should be better + if len(score_losses): + score_loss = torch.stack(score_losses).sum() / len(pred_masks) + else: + score_loss = 0 + + if len(mask_bce_losses): + mask_bce_loss = torch.stack(mask_bce_losses).sum() / len(pred_masks) + mask_dice_loss = torch.stack(mask_dice_losses).sum() + + if self.fix_dice_loss_weight: + mask_dice_loss = mask_dice_loss / len(pred_masks) * 4 + + if self.fix_mean_loss: + mask_bce_loss = mask_bce_loss * len(pred_masks) \ + / len(mask_bce_losses) + mask_dice_loss = mask_dice_loss * len(pred_masks) \ + / len(mask_dice_losses) + else: + mask_bce_loss = 0 + mask_dice_loss = 0 + + loss = ( + self.loss_weight[0] * cls_loss + + self.loss_weight[1] * mask_bce_loss + + self.loss_weight[2] * mask_dice_loss + + self.loss_weight[3] * score_loss) + + if 'aux_outputs' in pred: + if self.iter_matcher: + indices = None + for i, aux_outputs in enumerate(pred['aux_outputs']): + loss += self.get_layer_loss(aux_outputs, insts, indices) + + return {'inst_loss': loss} + + +@TASK_UTILS.register_module() +class QueryClassificationCost: + """Classification cost for queries. + + Args: + weigth (float): Weight of the cost. + """ + def __init__(self, weight): + self.weight = weight + + def __call__(self, pred_instances, gt_instances, **kwargs): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData_`): Predicted instances which + must contain `scores` of shape (n_queries, n_classes + 1), + gt_instances (:obj:`InstanceData_`): Ground truth which must contain + `labels` of shape (n_gts,). 
+ + Returns: + Tensor: Cost of shape (n_queries, n_gts). + """ + scores = pred_instances.scores.softmax(-1) + cost = -scores[:, gt_instances.labels] + return cost * self.weight + + +@TASK_UTILS.register_module() +class MaskBCECost: + """Sigmoid BCE cost for masks. + + Args: + weigth (float): Weight of the cost. + """ + def __init__(self, weight): + self.weight = weight + + def __call__(self, pred_instances, gt_instances, **kwargs): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData_`): Predicted instances which + mast contain `masks` of shape (n_queries, n_points). + gt_instances (:obj:`InstanceData_`): Ground truth which must contain + `labels` of shape (n_gts,), `masks` of shape (n_gts, n_points). + + Returns: + Tensor: Cost of shape (n_queries, n_gts). + """ + cost = batch_sigmoid_bce_loss( + pred_instances.masks, gt_instances.masks.float()) + return cost * self.weight + + +@TASK_UTILS.register_module() +class MaskDiceCost: + """Dice cost for masks. + + Args: + weigth (float): Weight of the cost. + """ + def __init__(self, weight): + self.weight = weight + + def __call__(self, pred_instances, gt_instances, **kwargs): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData_`): Predicted instances which + mast contain `masks` of shape (n_queries, n_points). + gt_instances (:obj:`InstanceData_`): Ground truth which must contain + `masks` of shape (n_gts, n_points). + + Returns: + Tensor: Cost of shape (n_queries, n_gts). + """ + cost = batch_dice_loss( + pred_instances.masks, gt_instances.masks.float()) + return cost * self.weight + + +@TASK_UTILS.register_module() +class HungarianMatcher: + """Hungarian matcher. + + Args: + costs (List[ConfigDict]): Cost functions. + """ + def __init__(self, costs): + self.costs = [] + for cost in costs: + self.costs.append(TASK_UTILS.build(cost)) + + @torch.no_grad() + def __call__(self, pred_instances, gt_instances, **kwargs): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData_`): Predicted instances which + can contain `masks` of shape (n_queries, n_points), `scores` + of shape (n_queries, n_classes + 1), + gt_instances (:obj:`InstanceData_`): Ground truth which can contain + `labels` of shape (n_gts,), `masks` of shape (n_gts, n_points). + + Returns: + Tuple: + - Tensor: Query ids of shape (n_matched,), + - Tensor: Object ids of shape (n_matched,). + """ + labels = gt_instances.labels + n_gts = len(labels) + if n_gts == 0: + return labels.new_empty((0,)), labels.new_empty((0,)) + + cost_values = [] + for cost in self.costs: + cost_values.append(cost(pred_instances, gt_instances)) + cost_value = torch.stack(cost_values).sum(dim=0) + query_ids, object_ids = linear_sum_assignment(cost_value.cpu().numpy()) + return labels.new_tensor(query_ids), labels.new_tensor(object_ids) + + +@TASK_UTILS.register_module() +class SparseMatcher: + """Match only queries to their including objects. + + Args: + costs (List[Callable]): Cost functions. + topk (int): Limit topk matches per query. + """ + + def __init__(self, costs, topk): + self.topk = topk + self.costs = [] + self.inf = 1e8 + for cost in costs: + self.costs.append(TASK_UTILS.build(cost)) + + @torch.no_grad() + def __call__(self, pred_instances, gt_instances, **kwargs): + """Compute match cost. 
+ + Args: + pred_instances (:obj:`InstanceData_`): Predicted instances which + can contain `masks` of shape (n_queries, n_points), `scores` + of shape (n_queries, n_classes + 1), + gt_instances (:obj:`InstanceData_`): Ground truth which can contain + `labels` of shape (n_gts,), `masks` of shape (n_gts, n_points), + `query_masks` of shape (n_gts, n_queries). + + Returns: + Tuple: + Tensor: Query ids of shape (n_matched,), + Tensor: Object ids of shape (n_matched,). + """ + labels = gt_instances.labels + n_gts = len(labels) + if n_gts == 0: + return labels.new_empty((0,)), labels.new_empty((0,)) + + cost_values = [] + for cost in self.costs: + cost_values.append(cost(pred_instances, gt_instances)) + # of shape (n_queries, n_gts) + cost_value = torch.stack(cost_values).sum(dim=0) + cost_value = torch.where( + gt_instances.query_masks.T, cost_value, self.inf) + + values = torch.topk( + cost_value, self.topk + 1, dim=0, sorted=True, + largest=False).values[-1:, :] + ids = torch.argwhere(cost_value < values) + return ids[:, 0], ids[:, 1] + + +@MODELS.register_module() +class OneDataCriterion: + """Loss module for SPFormer. + + Args: + matcher (Callable): Class for matching queries with gt. + loss_weight (List[float]): 4 weights for query classification, + mask bce, mask dice, and score losses. + non_object_weight (float): no_object weight for query classification. + num_classes_1dataset (int): Number of classes in the first dataset. + num_classes_2dataset (int): Number of classes in the second dataset. + fix_dice_loss_weight (bool): Whether to fix dice loss for + batch_size != 4. + iter_matcher (bool): Whether to use separate matcher for + each decoder layer. + """ + + def __init__(self, matcher, loss_weight, non_object_weight, + num_classes_1dataset, num_classes_2dataset, + fix_dice_loss_weight, iter_matcher): + self.matcher = TASK_UTILS.build(matcher) + self.num_classes_1dataset = num_classes_1dataset + self.num_classes_2dataset = num_classes_2dataset + self.class_weight_1dataset = [1] * num_classes_1dataset + [non_object_weight] + self.class_weight_2dataset = [1] * num_classes_2dataset + [non_object_weight] + self.loss_weight = loss_weight + self.fix_dice_loss_weight = fix_dice_loss_weight + self.iter_matcher = iter_matcher + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def get_layer_loss(self, aux_outputs, insts, indices=None): + cls_preds = aux_outputs['cls_preds'] + pred_scores = aux_outputs['scores'] + pred_masks = aux_outputs['masks'] + + if indices is None: + indices = [] + for i in range(len(insts)): + pred_instances = InstanceData_( + scores=cls_preds[i], + masks=pred_masks[i]) + gt_instances = InstanceData_( + labels=insts[i].labels_3d, + masks=insts[i].sp_masks) + if insts[i].get('query_masks') is not None: + gt_instances.query_masks = insts[i].query_masks + indices.append(self.matcher(pred_instances, gt_instances)) + + cls_losses = [] + for cls_pred, inst, (idx_q, idx_gt) in zip(cls_preds, insts, indices): + n_classes = cls_pred.shape[1] - 1 + cls_target = cls_pred.new_full( + (len(cls_pred),), n_classes, dtype=torch.long) + cls_target[idx_q] = inst.labels_3d[idx_gt] + if cls_pred.shape[1] == self.num_classes_1dataset + 1: + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, + cls_pred.new_tensor(self.class_weight_1dataset))) + elif cls_pred.shape[1] == 
self.num_classes_2dataset + 1: + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, + cls_pred.new_tensor(self.class_weight_2dataset))) + else: + raise RuntimeError( + f'Invalid classes number {cls_pred.shape[1]}.') + + cls_loss = torch.mean(torch.stack(cls_losses)) + + # 3 other losses + score_losses, mask_bce_losses, mask_dice_losses = [], [], [] + for mask, score, inst, (idx_q, idx_gt) in zip( + pred_masks, pred_scores, insts, indices): + if len(inst) == 0: + continue + + pred_mask = mask[idx_q] + tgt_mask = inst.sp_masks[idx_gt] + mask_bce_losses.append(F.binary_cross_entropy_with_logits( + pred_mask, tgt_mask.float())) + mask_dice_losses.append(dice_loss(pred_mask, tgt_mask.float())) + + # check if skip objectness loss + if score is None: + continue + + pred_score = score[idx_q] + with torch.no_grad(): + tgt_score = get_iou(pred_mask, tgt_mask).unsqueeze(1) + + filter_id, _ = torch.where(tgt_score > 0.5) + if filter_id.numel(): + tgt_score = tgt_score[filter_id] + pred_score = pred_score[filter_id] + score_losses.append(F.mse_loss(pred_score, tgt_score)) + # todo: actually .mean() should be better + if len(score_losses): + score_loss = torch.stack(score_losses).sum() / len(pred_masks) + else: + score_loss = 0 + mask_bce_loss = torch.stack(mask_bce_losses).sum() / len(pred_masks) + mask_dice_loss = torch.stack(mask_dice_losses).sum() / len(pred_masks) + + loss = ( + self.loss_weight[0] * cls_loss + + self.loss_weight[1] * mask_bce_loss + + self.loss_weight[2] * mask_dice_loss + + self.loss_weight[3] * score_loss) + + return loss + + # todo: refactor pred to InstanceData + def __call__(self, pred, insts): + """Loss main function. + + Args: + pred (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_gts, n_classes + 1); + List `scores` of len batch_size each of shape (n_gts, 1); + List `masks` of len batch_size each of shape (n_gts, n_points). + Dict `aux_preds` with list of cls_preds, scores, and masks. 
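+            insts (List):
+                Ground truth of len batch_size, each InstanceData_ with
+                `sp_masks`, `labels_3d`, and optionally `query_masks`.
+
+        Returns:
+            Dict: with instance loss value.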
+ """ + cls_preds = pred['cls_preds'] + pred_scores = pred['scores'] + pred_masks = pred['masks'] + + # match + indices = [] + for i in range(len(insts)): + pred_instances = InstanceData_( + scores=cls_preds[i], + masks=pred_masks[i]) + gt_instances = InstanceData_( + labels=insts[i].labels_3d, + masks=insts[i].sp_masks) + if insts[i].get('query_masks') is not None: + gt_instances.query_masks = insts[i].query_masks + indices.append(self.matcher(pred_instances, gt_instances)) + + # class loss + cls_losses = [] + for cls_pred, inst, (idx_q, idx_gt) in zip(cls_preds, insts, indices): + n_classes = cls_pred.shape[1] - 1 + cls_target = cls_pred.new_full( + (len(cls_pred),), n_classes, dtype=torch.long) + cls_target[idx_q] = inst.labels_3d[idx_gt] + if cls_pred.shape[1] == self.num_classes_1dataset + 1: + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, + cls_pred.new_tensor(self.class_weight_1dataset))) + elif cls_pred.shape[1] == self.num_classes_2dataset + 1: + cls_losses.append(F.cross_entropy( + cls_pred, cls_target, + cls_pred.new_tensor(self.class_weight_2dataset))) + else: + raise RuntimeError( + f'Invalid classes number {cls_pred.shape[1]}.') + + cls_loss = torch.mean(torch.stack(cls_losses)) + + # 3 other losses + score_losses, mask_bce_losses, mask_dice_losses = [], [], [] + for mask, score, inst, (idx_q, idx_gt) in zip(pred_masks, pred_scores, + insts, indices): + if len(inst) == 0: + continue + pred_mask = mask[idx_q] + tgt_mask = inst.sp_masks[idx_gt] + mask_bce_losses.append(F.binary_cross_entropy_with_logits( + pred_mask, tgt_mask.float())) + mask_dice_losses.append(dice_loss(pred_mask, tgt_mask.float())) + + # check if skip objectness loss + if score is None: + continue + + pred_score = score[idx_q] + with torch.no_grad(): + tgt_score = get_iou(pred_mask, tgt_mask).unsqueeze(1) + + filter_id, _ = torch.where(tgt_score > 0.5) + if filter_id.numel(): + tgt_score = tgt_score[filter_id] + pred_score = pred_score[filter_id] + score_losses.append(F.mse_loss(pred_score, tgt_score)) + # todo: actually .mean() should be better + if len(score_losses): + score_loss = torch.stack(score_losses).sum() / len(pred_masks) + else: + score_loss = 0 + mask_bce_loss = torch.stack(mask_bce_losses).sum() / len(pred_masks) + mask_dice_loss = torch.stack(mask_dice_losses).sum() + + if self.fix_dice_loss_weight: + mask_dice_loss = mask_dice_loss / len(pred_masks) * 4 + + loss = ( + self.loss_weight[0] * cls_loss + + self.loss_weight[1] * mask_bce_loss + + self.loss_weight[2] * mask_dice_loss + + self.loss_weight[3] * score_loss) + + if 'aux_outputs' in pred: + if self.iter_matcher: + indices = None + for i, aux_outputs in enumerate(pred['aux_outputs']): + loss += self.get_layer_loss(aux_outputs, insts, indices) + + return {'inst_loss': loss} diff --git a/oneformer3d/instance_seg_eval.py b/oneformer3d/instance_seg_eval.py new file mode 100644 index 0000000..748a4e0 --- /dev/null +++ b/oneformer3d/instance_seg_eval.py @@ -0,0 +1,131 @@ +# Copied from mmdet3d/evaluation/functional/instance_seg_eval.py +# We fix instance seg metric to accept boolean instance seg mask of +# shape (n_points, n_instances) instead of integer mask of shape +# (n_points, ). +import numpy as np +from mmengine.logging import print_log +from terminaltables import AsciiTable + +from .evaluate_semantic_instance import scannet_eval + + +# 1) We fix this line: info[file_name]['mask'] = mask[i]. +# 2) mask.max() + 1 in for is always equal to 2. +# We have changed it to mask.shape[0] for iterating over all masks. 
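+# For reference (illustrative values only), the evaluator consumes one dict
+# per scene that maps a unique prediction name to its boolean point mask,
+# label id, and confidence, e.g.:
+#     info['0_3'] = {'mask': <bool array (n_points,)>, 'label_id': 36,
+#                    'conf': 0.87}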
+def aggregate_predictions(masks, labels, scores, valid_class_ids):
+    """Maps predictions to ScanNet evaluator format.
+
+    Args:
+        masks (list[torch.Tensor]): Per scene predicted instance masks.
+        labels (list[torch.Tensor]): Per scene predicted instance labels.
+        scores (list[torch.Tensor]): Per scene predicted instance scores.
+        valid_class_ids (tuple[int]): Ids of valid categories.
+
+    Returns:
+        list[dict]: Per scene aggregated predictions.
+    """
+    infos = []
+    for id, (mask, label, score) in enumerate(zip(masks, labels, scores)):
+        mask = mask.numpy()
+        label = label.numpy()
+        score = score.numpy()
+        info = dict()
+        for i in range(mask.shape[0]):
+            # match pred_instance['filename'] from assign_instances_for_scan
+            file_name = f'{id}_{i}'
+            info[file_name] = dict()
+            info[file_name]['mask'] = mask[i]
+            info[file_name]['label_id'] = valid_class_ids[label[i]]
+            info[file_name]['conf'] = score[i]
+        infos.append(info)
+    return infos
+
+
+# For some reason the inputs are not torch.Tensor but np.ndarray.
+# We just remove torch -> numpy conversion here.
+def rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids):
+    """Maps gt instance and semantic masks to instance masks for ScanNet
+    evaluator.
+
+    Args:
+        gt_semantic_masks (list[np.ndarray]): Per scene gt semantic masks.
+        gt_instance_masks (list[np.ndarray]): Per scene gt instance masks.
+        valid_class_ids (tuple[int]): Ids of valid categories.
+
+    Returns:
+        list[np.ndarray]: Per scene instance masks.
+    """
+    renamed_instance_masks = []
+    for semantic_mask, instance_mask in zip(gt_semantic_masks,
+                                            gt_instance_masks):
+        unique = np.unique(instance_mask)
+        assert len(unique) < 1000
+        for i in unique:
+            semantic_instance = semantic_mask[instance_mask == i]
+            semantic_unique = np.unique(semantic_instance)
+            assert len(semantic_unique) == 1
+            if semantic_unique[0] in valid_class_ids:
+                instance_mask[instance_mask ==
+                              i] = 1000 * semantic_unique[0] + i
+        renamed_instance_masks.append(instance_mask)
+    return renamed_instance_masks
+
+
+def instance_seg_eval(gt_semantic_masks,
+                      gt_instance_masks,
+                      pred_instance_masks,
+                      pred_instance_labels,
+                      pred_instance_scores,
+                      valid_class_ids,
+                      class_labels,
+                      options=None,
+                      logger=None):
+    """Instance Segmentation Evaluation.
+
+    Evaluate the result of the instance segmentation.
+
+    Args:
+        gt_semantic_masks (list[torch.Tensor]): Ground truth semantic masks.
+        gt_instance_masks (list[torch.Tensor]): Ground truth instance masks.
+        pred_instance_masks (list[torch.Tensor]): Predicted instance masks.
+        pred_instance_labels (list[torch.Tensor]): Predicted instance labels.
+        pred_instance_scores (list[torch.Tensor]): Predicted instance scores.
+        valid_class_ids (tuple[int]): Ids of valid categories.
+        class_labels (tuple[str]): Names of valid categories.
+        options (dict, optional): Additional options. Keys may contain:
+            `overlaps`, `min_region_sizes`, `distance_threshes`,
+            `distance_confs`. Default: None.
+        logger (logging.Logger | str, optional): The way to print the mAP
+            summary. See `mmdet.utils.print_log()` for details. Default: None.
+
+    Returns:
+        dict[str, float]: Dict of results.
+ """ + assert len(valid_class_ids) == len(class_labels) + id_to_label = { + valid_class_ids[i]: class_labels[i] + for i in range(len(valid_class_ids)) + } + preds = aggregate_predictions( + masks=pred_instance_masks, + labels=pred_instance_labels, + scores=pred_instance_scores, + valid_class_ids=valid_class_ids) + gts = rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids) + metrics = scannet_eval( + preds=preds, + gts=gts, + options=options, + valid_class_ids=valid_class_ids, + class_labels=class_labels, + id_to_label=id_to_label) + header = ['classes', 'AP_0.25', 'AP_0.50', 'AP', 'Prec_0.50', 'Rec_0.50'] + rows = [] + for label, data in metrics['classes'].items(): + aps = [data['ap25%'], data['ap50%'], data['ap'], data['prec50%'], data['rec50%']] + rows.append([label] + [f'{ap:.4f}' for ap in aps]) + aps = metrics['all_ap_25%'], metrics['all_ap_50%'], metrics['all_ap'], metrics['all_prec_50%'], metrics['all_rec_50%'] + footer = ['Overall'] + [f'{ap:.4f}' for ap in aps] + table = AsciiTable([header] + rows + [footer]) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) + return metrics diff --git a/oneformer3d/instance_seg_metric.py b/oneformer3d/instance_seg_metric.py new file mode 100644 index 0000000..eab50b5 --- /dev/null +++ b/oneformer3d/instance_seg_metric.py @@ -0,0 +1,106 @@ +# Copied from mmdet3d/evaluation/metrics/instance_seg_metric.py +from mmengine.logging import MMLogger + +from mmdet3d.evaluation import InstanceSegMetric +from mmdet3d.registry import METRICS +from .instance_seg_eval import instance_seg_eval + + +@METRICS.register_module() +class SPInstanceSegMetric(InstanceSegMetric): + """The only difference with InstanceSegMetric is that following ScanNet + evaluator we accept instance prediction as a boolean tensor of shape + (n_points, n_instances) instead of integer tensor of shape (n_points, ). + + For this purpose we only replace instance_seg_eval call. + """ + + def compute_metrics(self, results): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + self.classes = self.dataset_meta['classes'] + self.valid_class_ids = self.dataset_meta['seg_valid_class_ids'] + + gt_semantic_masks = [] + gt_instance_masks = [] + pred_instance_masks = [] + pred_instance_labels = [] + pred_instance_scores = [] + + for eval_ann, single_pred_results in results: + gt_semantic_masks.append(eval_ann['pts_semantic_mask']) + gt_instance_masks.append(eval_ann['pts_instance_mask']) + pred_instance_masks.append( + single_pred_results['pts_instance_mask']) + pred_instance_labels.append(single_pred_results['instance_labels']) + pred_instance_scores.append(single_pred_results['instance_scores']) + + ret_dict = instance_seg_eval( + gt_semantic_masks, + gt_instance_masks, + pred_instance_masks, + pred_instance_labels, + pred_instance_scores, + valid_class_ids=self.valid_class_ids, + class_labels=self.classes, + logger=logger) + + return ret_dict + + +@METRICS.register_module() +class SPS3DISInstanceSegMetric(InstanceSegMetric): + """The only difference with SPInstanceSegMetric is that we shift + predicted and gt class labels with +1, as ScanNet evaluator ignores + gt label of 0. + """ + + def compute_metrics(self, results): + """Compute the metrics from processed results. 
+ + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + self.classes = self.dataset_meta['classes'] + self.valid_class_ids = self.dataset_meta['seg_valid_class_ids'] + + gt_semantic_masks = [] + gt_instance_masks = [] + pred_instance_masks = [] + pred_instance_labels = [] + pred_instance_scores = [] + + for eval_ann, single_pred_results in results: + gt_semantic_masks.append(eval_ann['pts_semantic_mask'] + 1) + gt_instance_masks.append(eval_ann['pts_instance_mask']) + pred_instance_masks.append( + single_pred_results['pts_instance_mask']) + pred_instance_labels.append(single_pred_results['instance_labels']) + pred_instance_scores.append(single_pred_results['instance_scores']) + + ret_dict = instance_seg_eval( + gt_semantic_masks, + gt_instance_masks, + pred_instance_masks, + pred_instance_labels, + pred_instance_scores, + valid_class_ids=[class_id + 1 for class_id in self.valid_class_ids], + class_labels=self.classes, + logger=logger) + + return ret_dict diff --git a/oneformer3d/loading.py b/oneformer3d/loading.py new file mode 100644 index 0000000..0ea0c12 --- /dev/null +++ b/oneformer3d/loading.py @@ -0,0 +1,106 @@ +# Adapted from mmdet3d/datasets/transforms/loading.py +import mmengine +import numpy as np + +from mmdet3d.datasets.transforms import LoadAnnotations3D +from mmdet3d.datasets.transforms.loading import get +from mmdet3d.datasets.transforms.loading import NormalizePointsColor +from mmdet3d.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class LoadAnnotations3D_(LoadAnnotations3D): + """Just add super point mask loading. + + Args: + with_sp_mask_3d (bool): Whether to load super point maks. + """ + + def __init__(self, with_sp_mask_3d, **kwargs): + self.with_sp_mask_3d = with_sp_mask_3d + super().__init__(**kwargs) + + def _load_sp_pts_3d(self, results): + """Private function to load 3D superpoints mask annotations. + + Args: + results (dict): Result dict from :obj:`mmdet3d.CustomDataset`. + + Returns: + dict: The dict containing loaded 3D mask annotations. + """ + sp_pts_mask_path = results['super_pts_path'] + + try: + mask_bytes = get( + sp_pts_mask_path, backend_args=self.backend_args) + # add .copy() to fix read-only bug + sp_pts_mask = np.frombuffer( + mask_bytes, dtype=np.int64).copy() + except ConnectionError: + mmengine.check_file_exist(sp_pts_mask_path) + sp_pts_mask = np.fromfile( + sp_pts_mask_path, dtype=np.int64) + + results['sp_pts_mask'] = sp_pts_mask + + # 'eval_ann_info' will be passed to evaluator + if 'eval_ann_info' in results: + results['eval_ann_info']['sp_pts_mask'] = sp_pts_mask + results['eval_ann_info']['lidar_idx'] = \ + sp_pts_mask_path.split("/")[-1][:-4] + return results + + def transform(self, results: dict) -> dict: + """Function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmdet3d.CustomDataset`. + + Returns: + dict: The dict containing loaded 3D bounding box, label, mask and + semantic segmentation annotations. + """ + results = super().transform(results) + if self.with_sp_mask_3d: + results = self._load_sp_pts_3d(results) + return results + + +@TRANSFORMS.register_module() +class NormalizePointsColor_(NormalizePointsColor): + """Just add color_std parameter. + + Args: + color_mean (list[float]): Mean color of the point cloud. + color_std (list[float]): Std color of the point cloud. 
+ Default value is from SPFormer preprocessing. + """ + + def __init__(self, color_mean, color_std=127.5): + self.color_mean = color_mean + self.color_std = color_std + + def transform(self, input_dict): + """Call function to normalize color of points. + + Args: + results (dict): Result dict containing point clouds data. + + Returns: + dict: The result dict containing the normalized points. + Updated key and value are described below. + - points (:obj:`BasePoints`): Points after color normalization. + """ + points = input_dict['points'] + assert points.attribute_dims is not None and \ + 'color' in points.attribute_dims.keys(), \ + 'Expect points have color attribute' + if self.color_mean is not None: + points.color = points.color - \ + points.color.new_tensor(self.color_mean) + if self.color_std is not None: + points.color = points.color / \ + points.color.new_tensor(self.color_std) + input_dict['points'] = points + return input_dict diff --git a/oneformer3d/mask_matrix_nms.py b/oneformer3d/mask_matrix_nms.py new file mode 100644 index 0000000..59f45b8 --- /dev/null +++ b/oneformer3d/mask_matrix_nms.py @@ -0,0 +1,122 @@ +# This is a copy from mmdet/models/layers/matrix_nms.py. +# We just change the input shape of `masks` tensor. +import torch + + +def mask_matrix_nms(masks, + labels, + scores, + filter_thr=-1, + nms_pre=-1, + max_num=-1, + kernel='gaussian', + sigma=2.0, + mask_area=None): + """Matrix NMS for multi-class masks. + + Args: + masks (Tensor): Has shape (num_instances, m) + labels (Tensor): Labels of corresponding masks, + has shape (num_instances,). + scores (Tensor): Mask scores of corresponding masks, + has shape (num_instances). + filter_thr (float): Score threshold to filter the masks + after matrix nms. Default: -1, which means do not + use filter_thr. + nms_pre (int): The max number of instances to do the matrix nms. + Default: -1, which means do not use nms_pre. + max_num (int, optional): If there are more than max_num masks after + matrix, only top max_num will be kept. Default: -1, which means + do not use max_num. + kernel (str): 'linear' or 'gaussian'. + sigma (float): std in gaussian method. + mask_area (Tensor): The sum of seg_masks. + + Returns: + tuple(Tensor): Processed mask results. + + - scores (Tensor): Updated scores, has shape (n,). + - labels (Tensor): Remained labels, has shape (n,). + - masks (Tensor): Remained masks, has shape (n, m). + - keep_inds (Tensor): The indices number of + the remaining mask in the input mask, has shape (n,). + """ + assert len(labels) == len(masks) == len(scores) + if len(labels) == 0: + return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( + 0, *masks.shape[-1:]), labels.new_zeros(0) + if mask_area is None: + mask_area = masks.sum(1).float() + else: + assert len(masks) == len(mask_area) + + # sort and keep top nms_pre + scores, sort_inds = torch.sort(scores, descending=True) + + keep_inds = sort_inds + if nms_pre > 0 and len(sort_inds) > nms_pre: + sort_inds = sort_inds[:nms_pre] + keep_inds = keep_inds[:nms_pre] + scores = scores[:nms_pre] + masks = masks[sort_inds] + mask_area = mask_area[sort_inds] + labels = labels[sort_inds] + + num_masks = len(labels) + flatten_masks = masks.reshape(num_masks, -1).float() + # inter. + inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0)) + expanded_mask_area = mask_area.expand(num_masks, num_masks) + # Upper triangle iou matrix. 
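+    # IoU(i, j) = inter(i, j) / (area(i) + area(j) - inter(i, j)). Masks are
+    # already sorted by descending score, so keeping only the upper triangle
+    # compares each mask against higher-scoring masks exclusively.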
+ iou_matrix = (inter_matrix / + (expanded_mask_area + expanded_mask_area.transpose(1, 0) - + inter_matrix)).triu(diagonal=1) + # label_specific matrix. + expanded_labels = labels.expand(num_masks, num_masks) + # Upper triangle label matrix. + label_matrix = (expanded_labels == expanded_labels.transpose( + 1, 0)).triu(diagonal=1) + + # IoU compensation + compensate_iou, _ = (iou_matrix * label_matrix).max(0) + compensate_iou = compensate_iou.expand(num_masks, + num_masks).transpose(1, 0) + + # IoU decay + decay_iou = iou_matrix * label_matrix + + # Calculate the decay_coefficient + if kernel == 'gaussian': + decay_matrix = torch.exp(-1 * sigma * (decay_iou**2)) + compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2)) + decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) + elif kernel == 'linear': + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient, _ = decay_matrix.min(0) + else: + raise NotImplementedError( + f'{kernel} kernel is not supported in matrix nms!') + # update the score. + scores = scores * decay_coefficient + + if filter_thr > 0: + keep = scores >= filter_thr + keep_inds = keep_inds[keep] + if not keep.any(): + return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( + 0, *masks.shape[-1:]), labels.new_zeros(0) + masks = masks[keep] + scores = scores[keep] + labels = labels[keep] + + # sort and keep top max_num + scores, sort_inds = torch.sort(scores, descending=True) + keep_inds = keep_inds[sort_inds] + if max_num > 0 and len(sort_inds) > max_num: + sort_inds = sort_inds[:max_num] + keep_inds = keep_inds[:max_num] + scores = scores[:max_num] + masks = masks[sort_inds] + labels = labels[sort_inds] + + return scores, labels, masks, keep_inds diff --git a/oneformer3d/mink_unet.py b/oneformer3d/mink_unet.py new file mode 100644 index 0000000..b1de6bf --- /dev/null +++ b/oneformer3d/mink_unet.py @@ -0,0 +1,597 @@ +# Adapted from JonasSchult/Mask3D. 
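+# This file defines sparse-convolutional Res16UNet encoder-decoder backbones
+# built on MinkowskiEngine; `Res16UNet34C` at the bottom of the file is the
+# variant registered in MODELS.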
+from enum import Enum +from collections.abc import Sequence +import torch.nn as nn +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiOps as me +from MinkowskiEngine import MinkowskiReLU + +from mmengine.model import BaseModule +from mmdet3d.registry import MODELS + + +class NormType(Enum): + BATCH_NORM = 0 + INSTANCE_NORM = 1 + INSTANCE_BATCH_NORM = 2 + + +def get_norm(norm_type, n_channels, D, bn_momentum=0.1): + if norm_type == NormType.BATCH_NORM: + return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) + elif norm_type == NormType.INSTANCE_NORM: + return ME.MinkowskiInstanceNorm(n_channels) + elif norm_type == NormType.INSTANCE_BATCH_NORM: + return nn.Sequential( + ME.MinkowskiInstanceNorm(n_channels), + ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)) + else: + raise ValueError(f"Norm type: {norm_type} not supported") + + +class ConvType(Enum): + """ + Define the kernel region type + """ + + HYPERCUBE = 0, "HYPERCUBE" + SPATIAL_HYPERCUBE = 1, "SPATIAL_HYPERCUBE" + SPATIO_TEMPORAL_HYPERCUBE = 2, "SPATIO_TEMPORAL_HYPERCUBE" + HYPERCROSS = 3, "HYPERCROSS" + SPATIAL_HYPERCROSS = 4, "SPATIAL_HYPERCROSS" + SPATIO_TEMPORAL_HYPERCROSS = 5, "SPATIO_TEMPORAL_HYPERCROSS" + SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = ( + 6, + "SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS") + + def __new__(cls, value, name): + member = object.__new__(cls) + member._value_ = value + member.fullname = name + return member + + def __int__(self): + return self.value + + +# Convert the ConvType var to a RegionType var +conv_to_region_type = { + # kernel_size = [k, k, k, 1] + ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, + ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, + ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CUBE +} + +# int_to_region_type = {m.value: m for m in ME.RegionType} +int_to_region_type = {m: ME.RegionType(m) for m in range(3)} + + +def convert_region_type(region_type): + """Convert the integer region_type to the corresponding + RegionType enum object. 
+ """ + return int_to_region_type[region_type] + + +def convert_conv_type(conv_type, kernel_size, D): + assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" + region_type = conv_to_region_type[conv_type] + axis_types = None + if conv_type == ConvType.SPATIAL_HYPERCUBE: + # No temporal convolution + if isinstance(kernel_size, Sequence): + kernel_size = kernel_size[:3] + else: + kernel_size = [ + kernel_size, + ] * 3 + if D == 4: + kernel_size.append(1) + elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: + # conv_type conversion already handled + assert D == 4 + elif conv_type == ConvType.HYPERCUBE: + # conv_type conversion already handled + pass + elif conv_type == ConvType.SPATIAL_HYPERCROSS: + if isinstance(kernel_size, Sequence): + kernel_size = kernel_size[:3] + else: + kernel_size = [ + kernel_size, + ] * 3 + if D == 4: + kernel_size.append(1) + elif conv_type == ConvType.HYPERCROSS: + # conv_type conversion already handled + pass + elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: + # conv_type conversion already handled + assert D == 4 + elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: + # Define the CUBIC conv kernel for spatial dims + # and CROSS conv for temp dim + axis_types = [ + ME.RegionType.HYPER_CUBE, + ] * 3 + if D == 4: + axis_types.append(ME.RegionType.HYPER_CROSS) + return region_type, axis_types, kernel_size + + +def conv(in_planes, + out_planes, + kernel_size, + stride=1, + dilation=1, + bias=False, + conv_type=ConvType.HYPERCUBE, + D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type( + conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + stride, + dilation, + region_type=region_type, + axis_types=None, # axis_types JONAS + dimension=D) + + return ME.MinkowskiConvolution( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + kernel_generator=kernel_generator, + dimension=D) + + +def conv_tr(in_planes, + out_planes, + kernel_size, + upsample_stride=1, + dilation=1, + bias=False, + conv_type=ConvType.HYPERCUBE, + D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type( + conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, + upsample_stride, + dilation, + region_type=region_type, + axis_types=axis_types, + dimension=D) + + return ME.MinkowskiConvolutionTranspose( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=kernel_size, + stride=upsample_stride, + dilation=dilation, + bias=bias, + kernel_generator=kernel_generator, + dimension=D) + + +class BasicBlockBase(nn.Module): + expansion = 1 + NORM_TYPE = NormType.BATCH_NORM + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + conv_type=ConvType.HYPERCUBE, + bn_momentum=0.1, + D=3): + super().__init__() + + self.conv1 = conv( + inplanes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + conv_type=conv_type, + D=D) + self.norm1 = get_norm( + self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + self.conv2 = conv( + planes, + planes, + kernel_size=3, + stride=1, + dilation=dilation, + bias=False, + conv_type=conv_type, + D=D) + self.norm2 = get_norm( + self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) + self.relu = MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + 
out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicBlock(BasicBlockBase): + NORM_TYPE = NormType.BATCH_NORM + + +class Res16UNetBase(BaseModule): + """Base class for Minkowski U-Net. + + Args: + in_channels (int): Number of input channels. + out_channles (int): Number of output channels. + config (dict): Extra parameters including + `dilations`, `conv1_kernel_size`, `bn_momentum`. + D (int): Conv dimension. + """ + BLOCK = None + PLANES = (32, 64, 128, 256, 256, 256, 256, 256) + DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + INIT_DIM = 32 + OUT_PIXEL_DIST = 1 + NORM_TYPE = NormType.BATCH_NORM + NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE + CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS + + def __init__(self, + in_channels, + out_channels, + config, + D=3, + **kwargs): + self.D = D + super().__init__() + self.network_initialization(in_channels, out_channels, config, D) + self.weight_initialization() + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + norm_type=NormType.BATCH_NORM, + bn_momentum=0.1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + D=self.D), + get_norm( + norm_type, + planes * block.expansion, + D=self.D, + bn_momentum=bn_momentum)) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride=stride, + dilation=dilation, + downsample=downsample, + conv_type=self.CONV_TYPE, + D=self.D)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + stride=1, + dilation=dilation, + conv_type=self.CONV_TYPE, + D=self.D)) + + return nn.Sequential(*layers) + + def network_initialization(self, in_channels, out_channels, config, D): + # Setup net_metadata + dilations = self.DILATIONS + bn_momentum = config.bn_momentum + + def space_n_time_m(n, m): + return n if D == 3 else [n, n, n, m] + + if D == 4: + self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) + + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv0p1s1 = conv( + in_channels, + self.inplanes, + kernel_size=space_n_time_m(config.conv1_kernel_size, 1), + stride=1, + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + + self.bn0 = get_norm( + self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + + self.conv1p1s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bn1 = get_norm( + self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block1 = self._make_layer( + self.BLOCK, + self.PLANES[0], + self.LAYERS[0], + dilation=dilations[0], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + + self.conv2p2s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bn2 = get_norm( + self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + 
self.block2 = self._make_layer( + self.BLOCK, + self.PLANES[1], + self.LAYERS[1], + dilation=dilations[1], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + + self.conv3p4s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bn3 = get_norm( + self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block3 = self._make_layer( + self.BLOCK, + self.PLANES[2], + self.LAYERS[2], + dilation=dilations[2], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + + self.conv4p8s2 = conv( + self.inplanes, + self.inplanes, + kernel_size=space_n_time_m(2, 1), + stride=space_n_time_m(2, 1), + dilation=1, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bn4 = get_norm( + self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) + self.block4 = self._make_layer( + self.BLOCK, + self.PLANES[3], + self.LAYERS[3], + dilation=dilations[3], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + self.convtr4p16s2 = conv_tr( + self.inplanes, + self.PLANES[4], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bntr4 = get_norm( + self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer( + self.BLOCK, + self.PLANES[4], + self.LAYERS[4], + dilation=dilations[4], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + self.convtr5p8s2 = conv_tr( + self.inplanes, + self.PLANES[5], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bntr5 = get_norm( + self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer( + self.BLOCK, + self.PLANES[5], + self.LAYERS[5], + dilation=dilations[5], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + self.convtr6p4s2 = conv_tr( + self.inplanes, + self.PLANES[6], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bntr6 = get_norm( + self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum) + + self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion + self.block7 = self._make_layer( + self.BLOCK, + self.PLANES[6], + self.LAYERS[6], + dilation=dilations[6], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + self.convtr7p2s2 = conv_tr( + self.inplanes, + self.PLANES[7], + kernel_size=space_n_time_m(2, 1), + upsample_stride=space_n_time_m(2, 1), + dilation=1, + bias=False, + conv_type=self.NON_BLOCK_CONV_TYPE, + D=D) + self.bntr7 = get_norm( + self.NORM_TYPE, self.PLANES[7], D, bn_momentum=bn_momentum) + + self.inplanes = self.PLANES[7] + self.INIT_DIM + self.block8 = self._make_layer( + self.BLOCK, + self.PLANES[7], + self.LAYERS[7], + dilation=dilations[7], + norm_type=self.NORM_TYPE, + bn_momentum=bn_momentum) + + self.final = conv( + self.PLANES[7], + out_channels, + kernel_size=1, + stride=1, + bias=True, + D=D) + self.relu = MinkowskiReLU(inplace=True) + + def forward(self, x): + feature_maps = [] + + out = self.conv0p1s1(x) + out = self.bn0(out) + out_p1 = self.relu(out) + + out = self.conv1p1s2(out_p1) + out = self.bn1(out) + out = self.relu(out) + out_b1p2 = self.block1(out) + + out 
= self.conv2p2s2(out_b1p2) + out = self.bn2(out) + out = self.relu(out) + out_b2p4 = self.block2(out) + + out = self.conv3p4s2(out_b2p4) + out = self.bn3(out) + out = self.relu(out) + out_b3p8 = self.block3(out) + + # pixel_dist=16 + out = self.conv4p8s2(out_b3p8) + out = self.bn4(out) + out = self.relu(out) + out = self.block4(out) + + feature_maps.append(out) + + # pixel_dist=8 + out = self.convtr4p16s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = me.cat(out, out_b3p8) + out = self.block5(out) + + feature_maps.append(out) + + # pixel_dist=4 + out = self.convtr5p8s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = me.cat(out, out_b2p4) + out = self.block6(out) + + feature_maps.append(out) + + # pixel_dist=2 + out = self.convtr6p4s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = me.cat(out, out_b1p2) + out = self.block7(out) + + feature_maps.append(out) + + # pixel_dist=1 + out = self.convtr7p2s2(out) + out = self.bntr7(out) + out = self.relu(out) + + out = me.cat(out, out_p1) + out = self.block8(out) + + feature_maps.append(out) + + return out + + +class Res16UNet34(Res16UNetBase): + BLOCK = BasicBlock + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +@MODELS.register_module() +class Res16UNet34C(Res16UNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) diff --git a/oneformer3d/oneformer3d.py b/oneformer3d/oneformer3d.py new file mode 100644 index 0000000..fdfa33e --- /dev/null +++ b/oneformer3d/oneformer3d.py @@ -0,0 +1,1346 @@ +import torch +import torch.nn.functional as F +import spconv.pytorch as spconv +from torch_scatter import scatter_mean +import MinkowskiEngine as ME + +from mmdet3d.registry import MODELS +from mmdet3d.structures import PointData +from mmdet3d.models import Base3DDetector +from .mask_matrix_nms import mask_matrix_nms + + +class ScanNetOneFormer3DMixin: + """Class contains common methods for ScanNet and ScanNet200.""" + + def predict_by_feat(self, out, superpoints): + """Predict instance, semantic, and panoptic masks for a single scene. + + Args: + out (Dict): Decoder output, each value is List of len 1. Keys: + `cls_preds` of shape (n_queries, n_instance_classes + 1), + `sem_preds` of shape (n_queries, n_semantic_classes + 1), + `masks` of shape (n_queries, n_points), + `scores` of shape (n_queris, 1) or None. + superpoints (Tensor): of shape (n_raw_points,). + + Returns: + List[PointData]: of len 1 with `pts_semantic_mask`, + `pts_instance_mask`, `instance_labels`, `instance_scores`. + """ + inst_res = self.predict_by_feat_instance( + out, superpoints, self.test_cfg.inst_score_thr) + sem_res = self.predict_by_feat_semantic(out, superpoints) + pan_res = self.predict_by_feat_panoptic(out, superpoints) + + pts_semantic_mask = [sem_res.cpu().numpy(), pan_res[0].cpu().numpy()] + pts_instance_mask = [inst_res[0].cpu().bool().numpy(), + pan_res[1].cpu().numpy()] + + return [ + PointData( + pts_semantic_mask=pts_semantic_mask, + pts_instance_mask=pts_instance_mask, + instance_labels=inst_res[1].cpu().numpy(), + instance_scores=inst_res[2].cpu().numpy())] + + def predict_by_feat_instance(self, out, superpoints, score_threshold): + """Predict instance masks for a single scene. + + Args: + out (Dict): Decoder output, each value is List of len 1. Keys: + `cls_preds` of shape (n_queries, n_instance_classes + 1), + `masks` of shape (n_queries, n_points), + `scores` of shape (n_queris, 1) or None. + superpoints (Tensor): of shape (n_raw_points,). + score_threshold (float): minimal score for predicted object. 
+
+
+        Returns:
+            Tuple:
+                Tensor: mask_preds of shape (n_preds, n_raw_points),
+                Tensor: labels of shape (n_preds,),
+                Tensor: scores of shape (n_preds,).
+        """
+        cls_preds = out['cls_preds'][0]
+        pred_masks = out['masks'][0]
+
+        scores = F.softmax(cls_preds, dim=-1)[:, :-1]
+        if out['scores'][0] is not None:
+            scores *= out['scores'][0]
+        labels = torch.arange(
+            self.num_classes,
+            device=scores.device).unsqueeze(0).repeat(
+                len(cls_preds), 1).flatten(0, 1)
+        scores, topk_idx = scores.flatten(0, 1).topk(
+            self.test_cfg.topk_insts, sorted=False)
+        labels = labels[topk_idx]
+
+        topk_idx = torch.div(topk_idx, self.num_classes, rounding_mode='floor')
+        mask_pred = pred_masks
+        mask_pred = mask_pred[topk_idx]
+        mask_pred_sigmoid = mask_pred.sigmoid()
+
+        if self.test_cfg.get('obj_normalization', None):
+            mask_scores = (mask_pred_sigmoid * (mask_pred > 0)).sum(1) / \
+                ((mask_pred > 0).sum(1) + 1e-6)
+            scores = scores * mask_scores
+
+        if self.test_cfg.get('nms', None):
+            kernel = self.test_cfg.matrix_nms_kernel
+            scores, labels, mask_pred_sigmoid, _ = mask_matrix_nms(
+                mask_pred_sigmoid, labels, scores, kernel=kernel)
+
+        mask_pred_sigmoid = mask_pred_sigmoid[:, superpoints]
+        mask_pred = mask_pred_sigmoid > self.test_cfg.sp_score_thr
+
+        # score_thr
+        score_mask = scores > score_threshold
+        scores = scores[score_mask]
+        labels = labels[score_mask]
+        mask_pred = mask_pred[score_mask]
+
+        # npoint_thr
+        mask_pointnum = mask_pred.sum(1)
+        npoint_mask = mask_pointnum > self.test_cfg.npoint_thr
+        scores = scores[npoint_mask]
+        labels = labels[npoint_mask]
+        mask_pred = mask_pred[npoint_mask]
+
+        return mask_pred, labels, scores
+
+    def predict_by_feat_semantic(self, out, superpoints, classes=None):
+        """Predict semantic masks for a single scene.
+
+        Args:
+            out (Dict): Decoder output, each value is List of len 1. Keys:
+                `sem_preds` of shape (n_queries, n_semantic_classes + 1).
+            superpoints (Tensor): of shape (n_raw_points,).
+            classes (List[int] or None): semantic (stuff) class ids.
+
+        Returns:
+            Tensor: semantic preds of shape (n_raw_points,).
+        """
+        if classes is None:
+            classes = list(range(out['sem_preds'][0].shape[1] - 1))
+        return out['sem_preds'][0][:, classes].argmax(dim=1)[superpoints]
+
+    def predict_by_feat_panoptic(self, out, superpoints):
+        """Predict panoptic masks for a single scene.
+
+        Args:
+            out (Dict): Decoder output, each value is List of len 1. Keys:
+                `cls_preds` of shape (n_queries, n_instance_classes + 1),
+                `sem_preds` of shape (n_queries, n_semantic_classes + 1),
+                `masks` of shape (n_queries, n_points),
+                `scores` of shape (n_queries, 1) or None.
+            superpoints (Tensor): of shape (n_raw_points,).
+
+        Returns:
+            Tuple:
+                Tensor: semantic mask of shape (n_raw_points,),
+                Tensor: instance mask of shape (n_raw_points,).
+ """ + sem_map = self.predict_by_feat_semantic( + out, superpoints, self.test_cfg.stuff_classes) + mask_pred, labels, scores = self.predict_by_feat_instance( + out, superpoints, self.test_cfg.pan_score_thr) + if mask_pred.shape[0] == 0: + return sem_map, sem_map + + scores, idxs = scores.sort() + labels = labels[idxs] + mask_pred = mask_pred[idxs] + + n_stuff_classes = len(self.test_cfg.stuff_classes) + inst_idxs = torch.arange( + n_stuff_classes, + mask_pred.shape[0] + n_stuff_classes, + device=mask_pred.device).view(-1, 1) + insts = inst_idxs * mask_pred + things_inst_mask, idxs = insts.max(axis=0) + things_sem_mask = labels[idxs] + n_stuff_classes + + inst_idxs, num_pts = things_inst_mask.unique(return_counts=True) + for inst, pts in zip(inst_idxs, num_pts): + if pts <= self.test_cfg.npoint_thr and inst != 0: + things_inst_mask[things_inst_mask == inst] = 0 + + things_sem_mask[things_inst_mask == 0] = 0 + + sem_map[things_inst_mask != 0] = 0 + inst_map = sem_map.clone() + inst_map += things_inst_mask + sem_map += things_sem_mask + return sem_map, inst_map + + def _select_queries(self, x, gt_instances): + """Select queries for train pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, n_channels). + gt_instances (List[InstanceData_]): of len batch_size. + Ground truth which can contain `labels` of shape (n_gts_i,), + `sp_masks` of shape (n_gts_i, n_points_i). + + Returns: + Tuple: + List[Tensor]: Queries of len batch_size, each queries of shape + (n_queries_i, n_channels). + List[InstanceData_]: of len batch_size, each updated + with `query_masks` of shape (n_gts_i, n_queries_i). + """ + queries = [] + for i in range(len(x)): + if self.query_thr < 1: + n = (1 - self.query_thr) * torch.rand(1) + self.query_thr + n = (n * len(x[i])).int() + ids = torch.randperm(len(x[i]))[:n].to(x[i].device) + queries.append(x[i][ids]) + gt_instances[i].query_masks = gt_instances[i].sp_masks[:, ids] + else: + queries.append(x[i]) + gt_instances[i].query_masks = gt_instances[i].sp_masks + return queries, gt_instances + + +@MODELS.register_module() +class ScanNetOneFormer3D(ScanNetOneFormer3DMixin, Base3DDetector): + r"""OneFormer3D for ScanNet dataset. + + Args: + in_channels (int): Number of input channels. + num_channels (int): NUmber of output channels. + voxel_size (float): Voxel size. + num_classes (int): Number of classes. + min_spatial_shape (int): Minimal shape for spconv tensor. + query_thr (float): We select >= query_thr * n_queries queries + for training and all n_queries for testing. + backbone (ConfigDict): Config dict of the backbone. + decoder (ConfigDict): Config dict of the decoder. + criterion (ConfigDict): Config dict of the criterion. + train_cfg (dict, optional): Config dict of training hyper-parameters. + Defaults to None. + test_cfg (dict, optional): Config dict of test hyper-parameters. + Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. 
+ """ + + def __init__(self, + in_channels, + num_channels, + voxel_size, + num_classes, + min_spatial_shape, + query_thr, + backbone=None, + decoder=None, + criterion=None, + train_cfg=None, + test_cfg=None, + data_preprocessor=None, + init_cfg=None): + super(Base3DDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.unet = MODELS.build(backbone) + self.decoder = MODELS.build(decoder) + self.criterion = MODELS.build(criterion) + self.voxel_size = voxel_size + self.num_classes = num_classes + self.min_spatial_shape = min_spatial_shape + self.query_thr = query_thr + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(in_channels, num_channels) + + def _init_layers(self, in_channels, num_channels): + self.input_conv = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, + num_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key='subm1')) + self.output_layer = spconv.SparseSequential( + torch.nn.BatchNorm1d(num_channels, eps=1e-4, momentum=0.1), + torch.nn.ReLU(inplace=True)) + + def extract_feat(self, x, superpoints, inverse_mapping, batch_offsets): + """Extract features from sparse tensor. + + Args: + x (SparseTensor): Input sparse tensor of shape + (n_points, in_channels). + superpoints (Tensor): of shape (n_points,). + inverse_mapping (Tesnor): of shape (n_points,). + batch_offsets (List[int]): of len batch_size + 1. + + Returns: + List[Tensor]: of len batch_size, + each of shape (n_points_i, n_channels). + """ + x = self.input_conv(x) + x, _ = self.unet(x) + x = self.output_layer(x) + x = scatter_mean(x.features[inverse_mapping], superpoints, dim=0) + out = [] + for i in range(len(batch_offsets) - 1): + out.append(x[batch_offsets[i]: batch_offsets[i + 1]]) + return out + + def collate(self, points, elastic_points=None): + """Collate batch of points to sparse tensor. + + Args: + points (List[Tensor]): Batch of points. + quantization_mode (SparseTensorQuantizationMode): Minkowski + quantization mode. We use random sample for training + and unweighted average for inference. + + Returns: + TensorField: Containing features and coordinates of a + sparse tensor. + """ + if elastic_points is None: + coordinates, features = ME.utils.batch_sparse_collate( + [((p[:, :3] - p[:, :3].min(0)[0]) / self.voxel_size, + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for p in points]) + else: + coordinates, features = ME.utils.batch_sparse_collate( + [((el_p - el_p.min(0)[0]), + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for el_p, p in zip(elastic_points, points)]) + + spatial_shape = torch.clip( + coordinates.max(0)[0][1:] + 1, self.min_spatial_shape) + field = ME.TensorField(features=features, coordinates=coordinates) + tensor = field.sparse() + coordinates = tensor.coordinates + features = tensor.features + inverse_mapping = field.inverse_mapping(tensor.coordinate_map_key) + + return coordinates, features, inverse_mapping, spatial_shape + + def _forward(*args, **kwargs): + """Implement abstract method of Base3DDetector.""" + pass + + def loss(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs dict and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instances_3d` and `gt_sem_seg_3d`. + Returns: + dict: A dictionary of loss components. 
+ """ + batch_offsets = [0] + superpoint_bias = 0 + sp_gt_instances = [] + sp_pts_masks = [] + for i in range(len(batch_data_samples)): + gt_pts_seg = batch_data_samples[i].gt_pts_seg + + gt_pts_seg.sp_pts_mask += superpoint_bias + superpoint_bias = gt_pts_seg.sp_pts_mask.max().item() + 1 + batch_offsets.append(superpoint_bias) + + sp_gt_instances.append(batch_data_samples[i].gt_instances_3d) + sp_pts_masks.append(gt_pts_seg.sp_pts_mask) + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points'], + batch_inputs_dict.get('elastic_coords', None)) + + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + sp_pts_masks = torch.hstack(sp_pts_masks) + x = self.extract_feat( + x, sp_pts_masks, inverse_mapping, batch_offsets) + queries, sp_gt_instances = self._select_queries(x, sp_gt_instances) + x = self.decoder(x, queries) + loss = self.criterion(x, sp_gt_instances) + return loss + + def predict(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instance_3d` and `gt_sem_seg_3d`. + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input samples. Each Det3DDataSample contains 'pred_pts_seg'. + And the `pred_pts_seg` contains following keys. + - instance_scores (Tensor): Classification scores, has a shape + (num_instance, ) + - instance_labels (Tensor): Labels of instances, has a shape + (num_instances, ) + - pts_instance_mask (Tensor): Instance mask, has a shape + (num_points, num_instances) of type bool. + """ + batch_offsets = [0] + superpoint_bias = 0 + sp_pts_masks = [] + for i in range(len(batch_data_samples)): + gt_pts_seg = batch_data_samples[i].gt_pts_seg + gt_pts_seg.sp_pts_mask += superpoint_bias + superpoint_bias = gt_pts_seg.sp_pts_mask.max().item() + 1 + batch_offsets.append(superpoint_bias) + sp_pts_masks.append(gt_pts_seg.sp_pts_mask) + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points']) + + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + sp_pts_masks = torch.hstack(sp_pts_masks) + x = self.extract_feat( + x, sp_pts_masks, inverse_mapping, batch_offsets) + x = self.decoder(x, x) + + results_list = self.predict_by_feat(x, sp_pts_masks) + for i, data_sample in enumerate(batch_data_samples): + data_sample.pred_pts_seg = results_list[i] + return batch_data_samples + + +@MODELS.register_module() +class ScanNet200OneFormer3D(ScanNetOneFormer3DMixin, Base3DDetector): + """OneFormer3D for ScanNet200 dataset. + + Args: + voxel_size (float): Voxel size. + num_classes (int): Number of classes. + query_thr (float): Min percent of queries. + backbone (ConfigDict): Config dict of the backbone. + neck (ConfigDict, optional): Config dict of the neck. + decoder (ConfigDict): Config dict of the decoder. + criterion (ConfigDict): Config dict of the criterion. + matcher (ConfigDict): To match superpoints to objects. + train_cfg (dict, optional): Config dict of training hyper-parameters. + Defaults to None. + test_cfg (dict, optional): Config dict of test hyper-parameters. + Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. 
it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. + """ + + def __init__(self, + voxel_size, + num_classes, + query_thr, + backbone=None, + neck=None, + decoder=None, + criterion=None, + train_cfg=None, + test_cfg=None, + data_preprocessor=None, + init_cfg=None): + super(Base3DDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self.decoder = MODELS.build(decoder) + self.criterion = MODELS.build(criterion) + self.voxel_size = voxel_size + self.num_classes = num_classes + self.query_thr = query_thr + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_feat(self, batch_inputs_dict, batch_data_samples): + """Extract features from sparse tensor. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_pts_seg.sp_pts_mask`. + + Returns: + Tuple: + List[Tensor]: of len batch_size, + each of shape (n_points_i, n_channels). + List[Tensor]: of len batch_size, + each of shape (n_points_i, n_classes + 1). + """ + # construct tensor field + coordinates, features = [], [] + for i in range(len(batch_inputs_dict['points'])): + if 'elastic_coords' in batch_inputs_dict: + coordinates.append( + batch_inputs_dict['elastic_coords'][i] * self.voxel_size) + else: + coordinates.append(batch_inputs_dict['points'][i][:, :3]) + features.append(batch_inputs_dict['points'][i][:, 3:]) + + coordinates, features = ME.utils.batch_sparse_collate( + [(c / self.voxel_size, f) for c, f in zip(coordinates, features)], + device=coordinates[0].device) + field = ME.TensorField(coordinates=coordinates, features=features) + + # forward of backbone and neck + x = self.backbone(field.sparse()) + if self.with_neck: + x = self.neck(x) + x = x.slice(field).features + + # apply scatter_mean + sp_pts_masks, n_super_points = [], [] + for data_sample in batch_data_samples: + sp_pts_mask = data_sample.gt_pts_seg.sp_pts_mask + sp_pts_masks.append(sp_pts_mask + sum(n_super_points)) + n_super_points.append(sp_pts_mask.max() + 1) + x = scatter_mean(x, torch.cat(sp_pts_masks), dim=0) # todo: do we need dim? + + # apply cls_layer + features = [] + for i in range(len(n_super_points)): + begin = sum(n_super_points[:i]) + end = sum(n_super_points[:i + 1]) + features.append(x[begin: end]) + return features + + def _forward(*args, **kwargs): + """Implement abstract method of Base3DDetector.""" + pass + + def loss(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs dict and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instances_3d` and `gt_sem_seg_3d`. + Returns: + dict: A dictionary of loss components. 
+ """ + x = self.extract_feat(batch_inputs_dict, batch_data_samples) + gt_instances = [s.gt_instances_3d for s in batch_data_samples] + queries, gt_instances = self._select_queries(x, gt_instances) + x = self.decoder(x, queries) + return self.criterion(x, gt_instances) + + def predict(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_pts_seg.sp_pts_mask`. + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input samples. Each Det3DDataSample contains 'pred_pts_seg'. + And the `pred_pts_seg` contains following keys. + - instance_scores (Tensor): Classification scores, has a shape + (num_instance, ) + - instance_labels (Tensor): Labels of instances, has a shape + (num_instances, ) + - pts_instance_mask (Tensor): Instance mask, has a shape + (num_points, num_instances) of type bool. + """ + assert len(batch_data_samples) == 1 + x = self.extract_feat(batch_inputs_dict, batch_data_samples) + x = self.decoder(x, x) + pred_pts_seg = self.predict_by_feat( + x, batch_data_samples[0].gt_pts_seg.sp_pts_mask) + batch_data_samples[0].pred_pts_seg = pred_pts_seg[0] + return batch_data_samples + + +@MODELS.register_module() +class S3DISOneFormer3D(Base3DDetector): + r"""OneFormer3D for S3DIS dataset. + + Args: + in_channels (int): Number of input channels. + num_channels (int): NUmber of output channels. + voxel_size (float): Voxel size. + num_classes (int): Number of classes. + min_spatial_shape (int): Minimal shape for spconv tensor. + backbone (ConfigDict): Config dict of the backbone. + decoder (ConfigDict): Config dict of the decoder. + criterion (ConfigDict): Config dict of the criterion. + train_cfg (dict, optional): Config dict of training hyper-parameters. + Defaults to None. + test_cfg (dict, optional): Config dict of test hyper-parameters. + Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. + """ + + def __init__(self, + in_channels, + num_channels, + voxel_size, + num_classes, + min_spatial_shape, + backbone=None, + decoder=None, + criterion=None, + train_cfg=None, + test_cfg=None, + data_preprocessor=None, + init_cfg=None): + super(Base3DDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.unet = MODELS.build(backbone) + self.decoder = MODELS.build(decoder) + self.criterion = MODELS.build(criterion) + self.voxel_size = voxel_size + self.num_classes = num_classes + self.min_spatial_shape = min_spatial_shape + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(in_channels, num_channels) + + def _init_layers(self, in_channels, num_channels): + self.input_conv = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, + num_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key='subm1')) + self.output_layer = spconv.SparseSequential( + torch.nn.BatchNorm1d(num_channels, eps=1e-4, momentum=0.1), + torch.nn.ReLU(inplace=True)) + + def extract_feat(self, x): + """Extract features from sparse tensor. 
+ + Args: + x (SparseTensor): Input sparse tensor of shape + (n_points, in_channels). + + Returns: + List[Tensor]: of len batch_size, + each of shape (n_points_i, n_channels). + """ + x = self.input_conv(x) + x, _ = self.unet(x) + x = self.output_layer(x) + out = [] + for i in x.indices[:, 0].unique(): + out.append(x.features[x.indices[:, 0] == i]) + return out + + def collate(self, points, elastic_points=None): + """Collate batch of points to sparse tensor. + + Args: + points (List[Tensor]): Batch of points. + quantization_mode (SparseTensorQuantizationMode): Minkowski + quantization mode. We use random sample for training + and unweighted average for inference. + + Returns: + TensorField: Containing features and coordinates of a + sparse tensor. + """ + if elastic_points is None: + coordinates, features = ME.utils.batch_sparse_collate( + [((p[:, :3] - p[:, :3].min(0)[0]) / self.voxel_size, + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for p in points]) + else: + coordinates, features = ME.utils.batch_sparse_collate( + [((el_p - el_p.min(0)[0]), + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for el_p, p in zip(elastic_points, points)]) + + spatial_shape = torch.clip( + coordinates.max(0)[0][1:] + 1, self.min_spatial_shape) + field = ME.TensorField(features=features, coordinates=coordinates) + tensor = field.sparse() + coordinates = tensor.coordinates + features = tensor.features + inverse_mapping = field.inverse_mapping(tensor.coordinate_map_key) + + return coordinates, features, inverse_mapping, spatial_shape + + def _forward(*args, **kwargs): + """Implement abstract method of Base3DDetector.""" + pass + + def loss(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs dict and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instances_3d` and `gt_sem_seg_3d`. + Returns: + dict: A dictionary of loss components. + """ + batch_offsets = [0] + superpoint_bias = 0 + sp_gt_instances = [] + sp_pts_masks = [] + for i in range(len(batch_data_samples)): + gt_pts_seg = batch_data_samples[i].gt_pts_seg + + gt_pts_seg.sp_pts_mask += superpoint_bias + superpoint_bias = gt_pts_seg.sp_pts_mask.max().item() + 1 + batch_offsets.append(superpoint_bias) + + sp_gt_instances.append(batch_data_samples[i].gt_instances_3d) + sp_pts_masks.append(gt_pts_seg.sp_pts_mask) + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points'], + batch_inputs_dict.get('elastic_coords', None)) + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + + sp_pts_masks = torch.hstack(sp_pts_masks) + + x = self.extract_feat( + x, sp_pts_masks, inverse_mapping, batch_offsets) + queries, sp_gt_instances = self._select_queries(x, sp_gt_instances) + x = self.decoder(x, queries) + + loss = self.criterion(x, sp_gt_instances) + return loss + + def loss(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs dict and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instances_3d` and `gt_sem_seg_3d`. + Returns: + dict: A dictionary of loss components. 
+ """ + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points'], + batch_inputs_dict.get('elastic_coords', None)) + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + + x = self.extract_feat(x) + + x = self.decoder(x) + + sp_gt_instances = [] + for i in range(len(batch_data_samples)): + voxel_superpoints = inverse_mapping[coordinates[:, 0][ \ + inverse_mapping] == i] + voxel_superpoints = torch.unique(voxel_superpoints, + return_inverse=True)[1] + inst_mask = batch_data_samples[i].gt_pts_seg.pts_instance_mask + sem_mask = batch_data_samples[i].gt_pts_seg.pts_semantic_mask + assert voxel_superpoints.shape == inst_mask.shape + + batch_data_samples[i].gt_instances_3d.sp_sem_masks = \ + self.get_gt_semantic_masks(sem_mask, + voxel_superpoints, + self.num_classes) + batch_data_samples[i].gt_instances_3d.sp_inst_masks = \ + self.get_gt_inst_masks(inst_mask, + voxel_superpoints) + sp_gt_instances.append(batch_data_samples[i].gt_instances_3d) + + loss = self.criterion(x, sp_gt_instances) + return loss + + def predict(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Predict results from a batch of inputs and data samples with post- + processing. + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instance_3d` and `gt_sem_seg_3d`. + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input samples. Each Det3DDataSample contains 'pred_pts_seg'. + And the `pred_pts_seg` contains following keys. + - instance_scores (Tensor): Classification scores, has a shape + (num_instance, ) + - instance_labels (Tensor): Labels of instances, has a shape + (num_instances, ) + - pts_instance_mask (Tensor): Instance mask, has a shape + (num_points, num_instances) of type bool. + """ + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points']) + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + + x = self.extract_feat(x) + + x = self.decoder(x) + + results_list = self.predict_by_feat(x, inverse_mapping) + + for i, data_sample in enumerate(batch_data_samples): + data_sample.pred_pts_seg = results_list[i] + return batch_data_samples + + def predict_by_feat(self, out, superpoints): + """Predict instance, semantic, and panoptic masks for a single scene. + + Args: + out (Dict): Decoder output, each value is List of len 1. Keys: + `cls_preds` of shape (n_queries, n_instance_classes + 1), + `masks` of shape (n_queries, n_points), + `scores` of shape (n_queris, 1) or None. + superpoints (Tensor): of shape (n_raw_points,). + + Returns: + List[PointData]: of len 1 with `pts_semantic_mask`, + `pts_instance_mask`, `instance_labels`, `instance_scores`. 
+ """ + pred_labels = out['cls_preds'][0] + pred_masks = out['masks'][0] + pred_scores = out['scores'][0] + + inst_res = self.pred_inst(pred_masks[:-self.test_cfg.num_sem_cls, :], + pred_scores[:-self.test_cfg.num_sem_cls, :], + pred_labels[:-self.test_cfg.num_sem_cls, :], + superpoints, self.test_cfg.inst_score_thr) + sem_res = self.pred_sem(pred_masks[-self.test_cfg.num_sem_cls:, :], + superpoints) + pan_res = self.pred_pan(pred_masks, pred_scores, pred_labels, + superpoints) + + pts_semantic_mask = [sem_res.cpu().numpy(), pan_res[0].cpu().numpy()] + pts_instance_mask = [inst_res[0].cpu().bool().numpy(), + pan_res[1].cpu().numpy()] + + return [ + PointData( + pts_semantic_mask=pts_semantic_mask, + pts_instance_mask=pts_instance_mask, + instance_labels=inst_res[1].cpu().numpy(), + instance_scores=inst_res[2].cpu().numpy())] + + def pred_inst(self, pred_masks, pred_scores, pred_labels, + superpoints, score_threshold): + """Predict instance masks for a single scene. + + Args: + pred_masks (Tensor): of shape (n_queries, n_points). + pred_scores (Tensor): of shape (n_queris, 1). + pred_labels (Tensor): of shape (n_queries, n_instance_classes + 1). + superpoints (Tensor): of shape (n_raw_points,). + score_threshold (float): minimal score for predicted object. + + Returns: + Tuple: + Tensor: mask_preds of shape (n_preds, n_raw_points), + Tensor: labels of shape (n_preds,), + Tensor: scors of shape (n_preds,). + """ + scores = F.softmax(pred_labels, dim=-1)[:, :-1] + scores *= pred_scores + + labels = torch.arange( + self.num_classes, + device=scores.device).unsqueeze(0).repeat( + self.decoder.num_queries - self.test_cfg.num_sem_cls, + 1).flatten(0, 1) + + scores, topk_idx = scores.flatten(0, 1).topk( + self.test_cfg.topk_insts, sorted=False) + labels = labels[topk_idx] + + topk_idx = torch.div(topk_idx, self.num_classes, rounding_mode='floor') + mask_pred = pred_masks + mask_pred = mask_pred[topk_idx] + mask_pred_sigmoid = mask_pred.sigmoid() + if self.test_cfg.get('obj_normalization', None): + mask_pred_thr = mask_pred_sigmoid > \ + self.test_cfg.obj_normalization_thr + mask_scores = (mask_pred_sigmoid * mask_pred_thr).sum(1) / \ + (mask_pred_thr.sum(1) + 1e-6) + scores = scores * mask_scores + + if self.test_cfg.get('nms', None): + kernel = self.test_cfg.matrix_nms_kernel + scores, labels, mask_pred_sigmoid, _ = mask_matrix_nms( + mask_pred_sigmoid, labels, scores, kernel=kernel) + + mask_pred = mask_pred_sigmoid > self.test_cfg.sp_score_thr + mask_pred = mask_pred[:, superpoints] + # score_thr + score_mask = scores > score_threshold + scores = scores[score_mask] + labels = labels[score_mask] + mask_pred = mask_pred[score_mask] + + # npoint_thr + mask_pointnum = mask_pred.sum(1) + npoint_mask = mask_pointnum > self.test_cfg.npoint_thr + scores = scores[npoint_mask] + labels = labels[npoint_mask] + mask_pred = mask_pred[npoint_mask] + + return mask_pred, labels, scores + + def pred_sem(self, pred_masks, superpoints): + """Predict semantic masks for a single scene. + + Args: + pred_masks (Tensor): of shape (n_points, n_semantic_classes). + superpoints (Tensor): of shape (n_raw_points,). + + Returns: + Tensor: semantic preds of shape + (n_raw_points, 1). + """ + mask_pred = pred_masks.sigmoid() + mask_pred = mask_pred[:, superpoints] + seg_map = mask_pred.argmax(0) + return seg_map + + def pred_pan(self, pred_masks, pred_scores, pred_labels, + superpoints): + """Predict panoptic masks for a single scene. + + Args: + pred_masks (Tensor): of shape (n_queries, n_points). 
+ pred_scores (Tensor): of shape (n_queris, 1). + pred_labels (Tensor): of shape (n_queries, n_instance_classes + 1). + superpoints (Tensor): of shape (n_raw_points,). + + Returns: + Tuple: + Tensor: semantic mask of shape (n_raw_points,), + Tensor: instance mask of shape (n_raw_points,). + """ + stuff_cls = pred_masks.new_tensor(self.test_cfg.stuff_cls).long() + sem_map = self.pred_sem( + pred_masks[-self.test_cfg.num_sem_cls + stuff_cls, :], superpoints) + sem_map_src_mapping = stuff_cls[sem_map] + + n_cls = self.test_cfg.num_sem_cls + thr = self.test_cfg.pan_score_thr + mask_pred, labels, scores = self.pred_inst( + pred_masks[:-n_cls, :], pred_scores[:-n_cls, :], + pred_labels[:-n_cls, :], superpoints, thr) + + thing_idxs = torch.zeros_like(labels) + for thing_cls in self.test_cfg.thing_cls: + thing_idxs = thing_idxs.logical_or(labels == thing_cls) + + mask_pred = mask_pred[thing_idxs] + scores = scores[thing_idxs] + labels = labels[thing_idxs] + + if mask_pred.shape[0] == 0: + return sem_map_src_mapping, sem_map + + scores, idxs = scores.sort() + labels = labels[idxs] + mask_pred = mask_pred[idxs] + + inst_idxs = torch.arange( + 0, mask_pred.shape[0], device=mask_pred.device).view(-1, 1) + insts = inst_idxs * mask_pred + things_inst_mask, idxs = insts.max(axis=0) + things_sem_mask = labels[idxs] + + inst_idxs, num_pts = things_inst_mask.unique(return_counts=True) + for inst, pts in zip(inst_idxs, num_pts): + if pts <= self.test_cfg.npoint_thr and inst != 0: + things_inst_mask[things_inst_mask == inst] = 0 + + things_inst_mask = torch.unique( + things_inst_mask, return_inverse=True)[1] + things_inst_mask[things_inst_mask != 0] += len(stuff_cls) - 1 + things_sem_mask[things_inst_mask == 0] = 0 + + sem_map_src_mapping[things_inst_mask != 0] = 0 + sem_map[things_inst_mask != 0] = 0 + sem_map += things_inst_mask + sem_map_src_mapping += things_sem_mask + return sem_map_src_mapping, sem_map + + @staticmethod + def get_gt_semantic_masks(mask_src, sp_pts_mask, num_classes): + """Create ground truth semantic masks. + + Args: + mask_src (Tensor): of shape (n_raw_points, 1). + sp_pts_mask (Tensor): of shape (n_raw_points, 1). + num_classes (Int): number of classes. + + Returns: + sp_masks (Tensor): semantic mask of shape (n_points, num_classes). + """ + + mask = torch.nn.functional.one_hot( + mask_src, num_classes=num_classes + 1) + + mask = mask.T + sp_masks = scatter_mean(mask.float(), sp_pts_mask, dim=-1) + sp_masks = sp_masks > 0.5 + sp_masks[-1, sp_masks.sum(axis=0) == 0] = True + assert sp_masks.sum(axis=0).max().item() == 1 + + return sp_masks + + @staticmethod + def get_gt_inst_masks(mask_src, sp_pts_mask): + """Create ground truth instance masks. + + Args: + mask_src (Tensor): of shape (n_raw_points, 1). + sp_pts_mask (Tensor): of shape (n_raw_points, 1). + + Returns: + sp_masks (Tensor): semantic mask of shape (n_points, num_inst_obj). + """ + mask = mask_src.clone() + if torch.sum(mask == -1) != 0: + mask[mask == -1] = torch.max(mask) + 1 + mask = torch.nn.functional.one_hot(mask)[:, :-1] + else: + mask = torch.nn.functional.one_hot(mask) + + mask = mask.T + sp_masks = scatter_mean(mask, sp_pts_mask, dim=-1) + sp_masks = sp_masks > 0.5 + + return sp_masks + + +@MODELS.register_module() +class InstanceOnlyOneFormer3D(Base3DDetector): + r"""InstanceOnlyOneFormer3D for training on different datasets jointly. + + Args: + in_channels (int): Number of input channels. + num_channels (int): Number of output channels. + voxel_size (float): Voxel size. 
+ num_classes_1dataset (int): Number of classes in the first dataset. + num_classes_2dataset (int): Number of classes in the second dataset. + prefix_1dataset (string): Prefix for the first dataset. + prefix_2dataset (string): Prefix for the second dataset. + min_spatial_shape (int): Minimal shape for spconv tensor. + backbone (ConfigDict): Config dict of the backbone. + decoder (ConfigDict): Config dict of the decoder. + criterion (ConfigDict): Config dict of the criterion. + train_cfg (dict, optional): Config dict of training hyper-parameters. + Defaults to None. + test_cfg (dict, optional): Config dict of test hyper-parameters. + Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`BaseDataPreprocessor`. it usually includes, + ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. + init_cfg (dict or ConfigDict, optional): the config to control the + initialization. Defaults to None. + """ + + def __init__(self, + in_channels, + num_channels, + voxel_size, + num_classes_1dataset, + num_classes_2dataset, + prefix_1dataset, + prefix_2dataset, + min_spatial_shape, + backbone=None, + decoder=None, + criterion=None, + train_cfg=None, + test_cfg=None, + data_preprocessor=None, + init_cfg=None): + super(InstanceOnlyOneFormer3D, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.num_classes_1dataset = num_classes_1dataset + self.num_classes_2dataset = num_classes_2dataset + + self.prefix_1dataset = prefix_1dataset + self.prefix_2dataset = prefix_2dataset + + self.unet = MODELS.build(backbone) + self.decoder = MODELS.build(decoder) + self.criterion = MODELS.build(criterion) + self.voxel_size = voxel_size + self.min_spatial_shape = min_spatial_shape + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(in_channels, num_channels) + + def _init_layers(self, in_channels, num_channels): + self.input_conv = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, + num_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key='subm1')) + self.output_layer = spconv.SparseSequential( + torch.nn.BatchNorm1d(num_channels, eps=1e-4, momentum=0.1), + torch.nn.ReLU(inplace=True)) + + def extract_feat(self, x): + """Extract features from sparse tensor. + + Args: + x (SparseTensor): Input sparse tensor of shape + (n_points, in_channels). + + Returns: + List[Tensor]: of len batch_size, + each of shape (n_points_i, n_channels). + """ + x = self.input_conv(x) + x, _ = self.unet(x) + x = self.output_layer(x) + out = [] + for i in x.indices[:, 0].unique(): + out.append(x.features[x.indices[:, 0] == i]) + return out + + def collate(self, points, elastic_points=None): + """Collate batch of points to sparse tensor. + + Args: + points (List[Tensor]): Batch of points. + quantization_mode (SparseTensorQuantizationMode): Minkowski + quantization mode. We use random sample for training + and unweighted average for inference. + + Returns: + TensorField: Containing features and coordinates of a + sparse tensor. 
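        Example of the voxelization this method performs (a minimal sketch
        mirroring the code below; voxel size and point counts are made up):

            import torch
            import MinkowskiEngine as ME

            voxel_size = 0.02                                   # assumed value
            points = [torch.rand(1000, 6), torch.rand(800, 6)]  # xyz + rgb per scene
            coordinates, features = ME.utils.batch_sparse_collate(
                [((p[:, :3] - p[:, :3].min(0)[0]) / voxel_size,           # voxel coords
                  torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0))))  # colors + centered xyz
                 for p in points])
            field = ME.TensorField(features=features, coordinates=coordinates)
            tensor = field.sparse()                             # one entry per occupied voxel
            inverse_mapping = field.inverse_mapping(tensor.coordinate_map_key)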
+ """ + if elastic_points is None: + coordinates, features = ME.utils.batch_sparse_collate( + [((p[:, :3] - p[:, :3].min(0)[0]) / self.voxel_size, + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for p in points]) + else: + coordinates, features = ME.utils.batch_sparse_collate( + [((el_p - el_p.min(0)[0]), + torch.hstack((p[:, 3:], p[:, :3] - p[:, :3].mean(0)))) + for el_p, p in zip(elastic_points, points)]) + + spatial_shape = torch.clip( + coordinates.max(0)[0][1:] + 1, self.min_spatial_shape) + field = ME.TensorField(features=features, coordinates=coordinates) + tensor = field.sparse() + coordinates = tensor.coordinates + features = tensor.features + inverse_mapping = field.inverse_mapping(tensor.coordinate_map_key) + + return coordinates, features, inverse_mapping, spatial_shape + + def _forward(*args, **kwargs): + """Implement abstract method of Base3DDetector.""" + pass + + def loss(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Calculate losses from a batch of inputs dict and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instances_3d` and `gt_sem_seg_3d`. + Returns: + dict: A dictionary of loss components. + """ + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points'], + batch_inputs_dict.get('elastic_coords', None)) + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + + x = self.extract_feat(x) + + scene_names = [] + for i in range(len(batch_data_samples)): + scene_names.append(batch_data_samples[i].lidar_path) + x = self.decoder(x, scene_names) + + sp_gt_instances = [] + for i in range(len(batch_data_samples)): + voxel_superpoints = inverse_mapping[ + coordinates[:, 0][inverse_mapping] == i] + voxel_superpoints = torch.unique( + voxel_superpoints, return_inverse=True)[1] + inst_mask = batch_data_samples[i].gt_pts_seg.pts_instance_mask + assert voxel_superpoints.shape == inst_mask.shape + + batch_data_samples[i].gt_instances_3d.sp_masks = \ + S3DISOneFormer3D.get_gt_inst_masks(inst_mask, voxel_superpoints) + sp_gt_instances.append(batch_data_samples[i].gt_instances_3d) + + loss = self.criterion(x, sp_gt_instances) + return loss + + def predict(self, batch_inputs_dict, batch_data_samples, **kwargs): + """Predict results from a batch of inputs and data samples with post- + processing. + Args: + batch_inputs_dict (dict): The model input dict which include + `points` key. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It includes information such as + `gt_instance_3d` and `gt_sem_seg_3d`. + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input samples. Each Det3DDataSample contains 'pred_pts_seg'. + And the `pred_pts_seg` contains following keys. + - instance_scores (Tensor): Classification scores, has a shape + (num_instance, ) + - instance_labels (Tensor): Labels of instances, has a shape + (num_instances, ) + - pts_instance_mask (Tensor): Instance mask, has a shape + (num_points, num_instances) of type bool. 
+ """ + + coordinates, features, inverse_mapping, spatial_shape = self.collate( + batch_inputs_dict['points']) + x = spconv.SparseConvTensor( + features, coordinates, spatial_shape, len(batch_data_samples)) + + x = self.extract_feat(x) + + scene_names = [] + for i in range(len(batch_data_samples)): + scene_names.append(batch_data_samples[i].lidar_path) + x = self.decoder(x, scene_names) + + results_list = self.predict_by_feat(x, inverse_mapping, scene_names) + + for i, data_sample in enumerate(batch_data_samples): + data_sample.pred_pts_seg = results_list[i] + return batch_data_samples + + def predict_by_feat(self, out, superpoints, scene_names): + """Predict instance masks for a single scene. + + Args: + out (Dict): Decoder output, each value is List of len 1. Keys: + `cls_preds` of shape (n_queries, n_instance_classes + 1), + `masks` of shape (n_queries, n_points), + `scores` of shape (n_queris, 1) or None. + superpoints (Tensor): of shape (n_raw_points,). + scene_names (List[string]): of len 1, which contain scene name. + + Returns: + List[PointData]: of len 1 with `pts_instance_mask`, + `instance_labels`, `instance_scores`. + """ + pred_labels = out['cls_preds'] + pred_masks = out['masks'] + pred_scores = out['scores'] + scene_name = scene_names[0] + + scores = F.softmax(pred_labels[0], dim=-1)[:, :-1] + scores *= pred_scores[0] + + if self.prefix_1dataset in scene_name: + labels = torch.arange( + self.num_classes_1dataset, + device=scores.device).unsqueeze(0).repeat( + self.decoder.num_queries_1dataset, + 1).flatten(0, 1) + elif self.prefix_2dataset in scene_name: + labels = torch.arange( + self.num_classes_2dataset, + device=scores.device).unsqueeze(0).repeat( + self.decoder.num_queries_2dataset, + 1).flatten(0, 1) + else: + raise RuntimeError(f'Invalid scene name "{scene_name}".') + + scores, topk_idx = scores.flatten(0, 1).topk( + self.test_cfg.topk_insts, sorted=False) + labels = labels[topk_idx] + + if self.prefix_1dataset in scene_name: + topk_idx = torch.div(topk_idx, self.num_classes_1dataset, + rounding_mode='floor') + elif self.prefix_2dataset in scene_name: + topk_idx = torch.div(topk_idx, self.num_classes_2dataset, + rounding_mode='floor') + else: + raise RuntimeError(f'Invalid scene name "{scene_name}".') + + mask_pred = pred_masks[0] + mask_pred = mask_pred[topk_idx] + mask_pred_sigmoid = mask_pred.sigmoid() + if self.test_cfg.get('obj_normalization', None): + mask_pred_thr = mask_pred_sigmoid > \ + self.test_cfg.obj_normalization_thr + mask_scores = (mask_pred_sigmoid * mask_pred_thr).sum(1) / \ + (mask_pred_thr.sum(1) + 1e-6) + scores = scores * mask_scores + + if self.test_cfg.get('nms', None): + kernel = self.test_cfg.matrix_nms_kernel + scores, labels, mask_pred_sigmoid, _ = mask_matrix_nms( + mask_pred_sigmoid, labels, scores, kernel=kernel) + + mask_pred = mask_pred_sigmoid > self.test_cfg.sp_score_thr + mask_pred = mask_pred[:, superpoints] + # score_thr + score_mask = scores > self.test_cfg.score_thr + scores = scores[score_mask] + labels = labels[score_mask] + mask_pred = mask_pred[score_mask] + + # npoint_thr + mask_pointnum = mask_pred.sum(1) + npoint_mask = mask_pointnum > self.test_cfg.npoint_thr + scores = scores[npoint_mask] + labels = labels[npoint_mask] + mask_pred = mask_pred[npoint_mask] + + return [ + PointData( + pts_instance_mask=mask_pred, + instance_labels=labels, + instance_scores=scores) + ] diff --git a/oneformer3d/query_decoder.py b/oneformer3d/query_decoder.py new file mode 100644 index 0000000..b0cb9ad --- /dev/null +++ 
b/oneformer3d/query_decoder.py @@ -0,0 +1,718 @@ +import torch +import torch.nn as nn + +from mmengine.model import BaseModule +from mmdet3d.registry import MODELS + + +class CrossAttentionLayer(BaseModule): + """Cross attention layer. + + Args: + d_model (int): Model dimension. + num_heads (int): Number of heads. + dropout (float): Dropout rate. + """ + + def __init__(self, d_model, num_heads, dropout, fix=False): + super().__init__() + self.fix = fix + self.attn = nn.MultiheadAttention( + d_model, num_heads, dropout=dropout, batch_first=True) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + # todo: why BaseModule doesn't call it without us? + self.init_weights() + + def init_weights(self): + """Init weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, sources, queries, attn_masks=None): + """Forward pass. + + Args: + sources (List[Tensor]): of len batch_size, + each of shape (n_points_i, d_model). + queries (List[Tensor]): of len batch_size, + each of shape(n_queries_i, d_model). + attn_masks (List[Tensor] or None): of len batch_size, + each of shape (n_queries, n_points). + + Return: + List[Tensor]: Queries of len batch_size, + each of shape(n_queries_i, d_model). + """ + outputs = [] + for i in range(len(sources)): + k = v = sources[i] + attn_mask = attn_masks[i] if attn_masks is not None else None + output, _ = self.attn(queries[i], k, v, attn_mask=attn_mask) + if self.fix: + output = self.dropout(output) + output = output + queries[i] + if self.fix: + output = self.norm(output) + outputs.append(output) + return outputs + + +class SelfAttentionLayer(BaseModule): + """Self attention layer. + + Args: + d_model (int): Model dimension. + num_heads (int): Number of heads. + dropout (float): Dropout rate. + """ + + def __init__(self, d_model, num_heads, dropout): + super().__init__() + self.attn = nn.MultiheadAttention( + d_model, num_heads, dropout=dropout, batch_first=True) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """Forward pass. + + Args: + x (List[Tensor]): Queries of len batch_size, + each of shape(n_queries_i, d_model). + + Returns: + List[Tensor]: Queries of len batch_size, + each of shape(n_queries_i, d_model). + """ + out = [] + for y in x: + z, _ = self.attn(y, y, y) + z = self.dropout(z) + y + z = self.norm(z) + out.append(z) + return out + + +class FFN(BaseModule): + """Feed forward network. + + Args: + d_model (int): Model dimension. + hidden_dim (int): Hidden dimension. + dropout (float): Dropout rate. + activation_fn (str): 'relu' or 'gelu'. + """ + + def __init__(self, d_model, hidden_dim, dropout, activation_fn): + super().__init__() + self.net = nn.Sequential( + nn.Linear(d_model, hidden_dim), + nn.ReLU() if activation_fn == 'relu' else nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, d_model), + nn.Dropout(dropout)) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x): + """Forward pass. + + Args: + x (List[Tensor]): Queries of len batch_size, + each of shape(n_queries_i, d_model). + + Returns: + List[Tensor]: Queries of len batch_size, + each of shape(n_queries_i, d_model). + """ + out = [] + for y in x: + z = self.net(y) + z = z + y + z = self.norm(z) + out.append(z) + return out + +@MODELS.register_module() +class QueryDecoder(BaseModule): + """Query decoder. + + Args: + num_layers (int): Number of transformer layers. + num_instance_queries (int): Number of instance queries. 
+ num_semantic_queries (int): Number of semantic queries. + num_classes (int): Number of classes. + in_channels (int): Number of input channels. + d_model (int): Number of channels for model layers. + num_heads (int): Number of head in attention layer. + hidden_dim (int): Dimension of attention layer. + dropout (float): Dropout rate for transformer layer. + activation_fn (str): 'relu' of 'gelu'. + iter_pred (bool): Whether to predict iteratively. + attn_mask (bool): Whether to use mask attention. + pos_enc_flag (bool): Whether to use positional enconding. + """ + + def __init__(self, num_layers, num_instance_queries, num_semantic_queries, + num_classes, in_channels, d_model, num_heads, hidden_dim, + dropout, activation_fn, iter_pred, attn_mask, fix_attention, + objectness_flag, **kwargs): + super().__init__() + self.objectness_flag = objectness_flag + self.input_proj = nn.Sequential( + nn.Linear(in_channels, d_model), nn.LayerNorm(d_model), nn.ReLU()) + self.num_queries = num_instance_queries + num_semantic_queries + if num_instance_queries + num_semantic_queries > 0: + self.query = nn.Embedding(num_instance_queries + num_semantic_queries, d_model) + if num_instance_queries == 0: + self.query_proj = nn.Sequential( + nn.Linear(in_channels, d_model), nn.ReLU(), + nn.Linear(d_model, d_model)) + self.cross_attn_layers = nn.ModuleList([]) + self.self_attn_layers = nn.ModuleList([]) + self.ffn_layers = nn.ModuleList([]) + for i in range(num_layers): + self.cross_attn_layers.append( + CrossAttentionLayer( + d_model, num_heads, dropout, fix_attention)) + self.self_attn_layers.append( + SelfAttentionLayer(d_model, num_heads, dropout)) + self.ffn_layers.append( + FFN(d_model, hidden_dim, dropout, activation_fn)) + self.out_norm = nn.LayerNorm(d_model) + self.out_cls = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), + nn.Linear(d_model, num_classes + 1)) + if objectness_flag: + self.out_score = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1)) + self.x_mask = nn.Sequential( + nn.Linear(in_channels, d_model), nn.ReLU(), + nn.Linear(d_model, d_model)) + self.iter_pred = iter_pred + self.attn_mask = attn_mask + + def _get_queries(self, queries=None, batch_size=None): + """Get query tensor. + + Args: + queries (List[Tensor], optional): of len batch_size, + each of shape (n_queries_i, in_channels). + batch_size (int, optional): batch size. + + Returns: + List[Tensor]: of len batch_size, each of shape + (n_queries_i, d_model). + """ + if batch_size is None: + batch_size = len(queries) + + result_queries = [] + for i in range(batch_size): + result_query = [] + if hasattr(self, 'query'): + result_query.append(self.query.weight) + if queries is not None: + result_query.append(self.query_proj(queries[i])) + result_queries.append(torch.cat(result_query)) + return result_queries + + def _forward_head(self, queries, mask_feats): + """Prediction head forward. + + Args: + queries (List[Tensor] | Tensor): List of len batch_size, + each of shape (n_queries_i, d_model). Or tensor of + shape (batch_size, n_queries, d_model). + mask_feats (List[Tensor]): of len batch_size, + each of shape (n_points_i, d_model). + + Returns: + Tuple: + List[Tensor]: Classification predictions of len batch_size, + each of shape (n_queries_i, n_classes + 1). + List[Tensor]: Confidence scores of len batch_size, + each of shape (n_queries_i, 1). + List[Tensor]: Predicted masks of len batch_size, + each of shape (n_queries_i, n_points_i). 
+ List[Tensor] or None: Attention masks of len batch_size, + each of shape (n_queries_i, n_points_i). + """ + cls_preds, pred_scores, pred_masks, attn_masks = [], [], [], [] + for i in range(len(queries)): + norm_query = self.out_norm(queries[i]) + cls_preds.append(self.out_cls(norm_query)) + pred_score = self.out_score(norm_query) if self.objectness_flag \ + else None + pred_scores.append(pred_score) + pred_mask = torch.einsum('nd,md->nm', norm_query, mask_feats[i]) + if self.attn_mask: + attn_mask = (pred_mask.sigmoid() < 0.5).bool() + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = attn_mask.detach() + attn_masks.append(attn_mask) + pred_masks.append(pred_mask) + attn_masks = attn_masks if self.attn_mask else None + return cls_preds, pred_scores, pred_masks, attn_masks + + def forward_simple(self, x, queries): + """Simple forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + queries (List[Tensor], optional): of len batch_size, each of shape + (n_points_i, in_channles). + + Returns: + Dict: with labels, masks, and scores. + """ + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(queries, len(x)) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + cls_preds, pred_scores, pred_masks, _ = self._forward_head( + queries, mask_feats) + return dict( + cls_preds=cls_preds, + masks=pred_masks, + scores=pred_scores) + + def forward_iter_pred(self, x, queries): + """Iterative forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + queries (List[Tensor], optional): of len batch_size, each of shape + (n_points_i, in_channles). + + Returns: + Dict: with labels, masks, scores, and aux_outputs. + """ + cls_preds, pred_scores, pred_masks = [], [], [] + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(queries, len(x)) + cls_pred, pred_score, pred_mask, attn_mask = self._forward_head( + queries, mask_feats) + cls_preds.append(cls_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries, attn_mask) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + cls_pred, pred_score, pred_mask, attn_mask = self._forward_head( + queries, mask_feats) + cls_preds.append(cls_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + + aux_outputs = [ + {'cls_preds': cls_pred, 'masks': masks, 'scores': scores} + for cls_pred, scores, masks in zip( + cls_preds[:-1], pred_scores[:-1], pred_masks[:-1])] + return dict( + cls_preds=cls_preds[-1], + masks=pred_masks[-1], + scores=pred_scores[-1], + aux_outputs=aux_outputs) + + def forward(self, x, queries=None): + """Forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + queries (List[Tensor], optional): of len batch_size, each of shape + (n_points_i, in_channles). + + Returns: + Dict: with labels, masks, scores, and possibly aux_outputs. 
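        Example of the per-layer mask and attention-mask computation used by
        ``_forward_head`` above (a runnable toy with made-up sizes):

            import torch
            n_queries, n_points, d_model = 4, 10, 8
            norm_query = torch.randn(n_queries, d_model)
            mask_feats = torch.randn(n_points, d_model)
            pred_mask = torch.einsum('nd,md->nm', norm_query, mask_feats)  # (4, 10)
            attn_mask = (pred_mask.sigmoid() < 0.5).bool()
            # a query whose mask would be fully suppressed attends everywhere instead
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            attn_mask = attn_mask.detach()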
+ """ + if self.iter_pred: + return self.forward_iter_pred(x, queries) + else: + return self.forward_simple(x, queries) + + +@MODELS.register_module() +class ScanNetQueryDecoder(QueryDecoder): + """We simply add semantic prediction for each instance query. + """ + def __init__(self, num_instance_classes, num_semantic_classes, + d_model, num_semantic_linears, **kwargs): + super().__init__( + num_classes=num_instance_classes, d_model=d_model, **kwargs) + assert num_semantic_linears in [1, 2] + if num_semantic_linears == 2: + self.out_sem = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), + nn.Linear(d_model, num_semantic_classes + 1)) + else: + self.out_sem = nn.Linear(d_model, num_semantic_classes + 1) + + def _forward_head(self, queries, mask_feats, last_flag): + """Prediction head forward. + + Args: + queries (List[Tensor] | Tensor): List of len batch_size, + each of shape (n_queries_i, d_model). Or tensor of + shape (batch_size, n_queries, d_model). + mask_feats (List[Tensor]): of len batch_size, + each of shape (n_points_i, d_model). + + Returns: + Tuple: + List[Tensor]: Classification predictions of len batch_size, + each of shape (n_queries_i, n_instance_classes + 1). + List[Tensor] or None: Semantic predictions of len batch_size, + each of shape (n_queries_i, n_semantic_classes + 1). + List[Tensor]: Confidence scores of len batch_size, + each of shape (n_queries_i, 1). + List[Tensor]: Predicted masks of len batch_size, + each of shape (n_queries_i, n_points_i). + List[Tensor] or None: Attention masks of len batch_size, + each of shape (n_queries_i, n_points_i). + """ + cls_preds, sem_preds, pred_scores, pred_masks, attn_masks = \ + [], [], [], [], [] + for i in range(len(queries)): + norm_query = self.out_norm(queries[i]) + cls_preds.append(self.out_cls(norm_query)) + if last_flag: + sem_preds.append(self.out_sem(norm_query)) + pred_score = self.out_score(norm_query) if self.objectness_flag \ + else None + pred_scores.append(pred_score) + pred_mask = torch.einsum('nd,md->nm', norm_query, mask_feats[i]) + if self.attn_mask: + attn_mask = (pred_mask.sigmoid() < 0.5).bool() + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = attn_mask.detach() + attn_masks.append(attn_mask) + pred_masks.append(pred_mask) + attn_masks = attn_masks if self.attn_mask else None + sem_preds = sem_preds if last_flag else None + return cls_preds, sem_preds, pred_scores, pred_masks, attn_masks + + def forward_simple(self, x, queries): + """Simple forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + queries (List[Tensor], optional): of len batch_size, each of shape + (n_points_i, in_channles). + + Returns: + Dict: with instance scores, semantic scores, masks, and scores. + """ + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(queries, len(x)) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + cls_preds, sem_preds, pred_scores, pred_masks, _ = self._forward_head( + queries, mask_feats, last_flag=True) + return dict( + cls_preds=cls_preds, + sem_preds=sem_preds, + masks=pred_masks, + scores=pred_scores) + + def forward_iter_pred(self, x, queries): + """Iterative forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). 
+ queries (List[Tensor], optional): of len batch_size, each of shape + (n_points_i, in_channles). + + Returns: + Dict: with instance scores, semantic scores, masks, scores, + and aux_outputs. + """ + cls_preds, sem_preds, pred_scores, pred_masks = [], [], [], [] + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(queries, len(x)) + cls_pred, sem_pred, pred_score, pred_mask, attn_mask = \ + self._forward_head(queries, mask_feats, last_flag=False) + cls_preds.append(cls_pred) + sem_preds.append(sem_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries, attn_mask) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + last_flag = i == len(self.cross_attn_layers) - 1 + cls_pred, sem_pred, pred_score, pred_mask, attn_mask = \ + self._forward_head(queries, mask_feats, last_flag) + cls_preds.append(cls_pred) + sem_preds.append(sem_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + + aux_outputs = [ + dict( + cls_preds=cls_pred, + sem_preds=sem_pred, + masks=masks, + scores=scores) + for cls_pred, sem_pred, scores, masks in zip( + cls_preds[:-1], sem_preds[:-1], + pred_scores[:-1], pred_masks[:-1])] + return dict( + cls_preds=cls_preds[-1], + sem_preds=sem_preds[-1], + masks=pred_masks[-1], + scores=pred_scores[-1], + aux_outputs=aux_outputs) + + +@MODELS.register_module() +class OneDataQueryDecoder(BaseModule): + """Query decoder. The same as above, but for 2 datasets. + + Args: + num_layers (int): Number of transformer layers. + num_queries_1dataset (int): Number of queries for the first dataset. + num_queries_2dataset (int): Number of queries for the second dataset. + num_classes_1dataset (int): Number of classes in the first dataset. + num_classes_2dataset (int): Number of classes in the second dataset. + prefix_1dataset (string): Prefix for the first dataset. + prefix_2dataset (string): Prefix for the second dataset. + in_channels (int): Number of input channels. + d_model (int): Number of channels for model layers. + num_heads (int): Number of head in attention layer. + hidden_dim (int): Dimension of attention layer. + dropout (float): Dropout rate for transformer layer. + activation_fn (str): 'relu' of 'gelu'. + iter_pred (bool): Whether to predict iteratively. + attn_mask (bool): Whether to use mask attention. + pos_enc_flag (bool): Whether to use positional enconding. 
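        Example config fragment (all values here are illustrative guesses,
        not taken from the shipped configs):

            decoder=dict(
                type='OneDataQueryDecoder',
                num_layers=6,
                num_queries_1dataset=400,
                num_queries_2dataset=400,
                num_classes_1dataset=18,
                num_classes_2dataset=28,
                prefix_1dataset='scannet',
                prefix_2dataset='Structured3D',
                in_channels=32,
                d_model=256,
                num_heads=8,
                hidden_dim=1024,
                dropout=0.0,
                activation_fn='gelu',
                iter_pred=True,
                attn_mask=True,
                fix_attention=True)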
+ """ + + def __init__(self, + num_layers, + num_queries_1dataset, + num_queries_2dataset, + num_classes_1dataset, + num_classes_2dataset, + prefix_1dataset, + prefix_2dataset, + in_channels, + d_model, + num_heads, + hidden_dim, + dropout, + activation_fn, + iter_pred, + attn_mask, + fix_attention, + **kwargs): + super().__init__() + self.input_proj = nn.Sequential( + nn.Linear(in_channels, d_model), nn.LayerNorm(d_model), nn.ReLU()) + + self.num_queries_1dataset = num_queries_1dataset + self.num_queries_2dataset = num_queries_2dataset + + self.queries_1dataset = nn.Embedding(num_queries_1dataset, d_model) + self.queries_2dataset = nn.Embedding(num_queries_2dataset, d_model) + + self.prefix_1dataset = prefix_1dataset + self.prefix_2dataset = prefix_2dataset + + self.cross_attn_layers = nn.ModuleList([]) + self.self_attn_layers = nn.ModuleList([]) + self.ffn_layers = nn.ModuleList([]) + for i in range(num_layers): + self.cross_attn_layers.append( + CrossAttentionLayer( + d_model, num_heads, dropout, fix_attention)) + self.self_attn_layers.append( + SelfAttentionLayer(d_model, num_heads, dropout)) + self.ffn_layers.append( + FFN(d_model, hidden_dim, dropout, activation_fn)) + self.out_norm = nn.LayerNorm(d_model) + self.out_cls_1dataset = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), + nn.Linear(d_model, num_classes_1dataset + 1)) + self.out_cls_2dataset = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), + nn.Linear(d_model, num_classes_2dataset + 1)) + self.out_score = nn.Sequential( + nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1)) + self.x_mask = nn.Sequential( + nn.Linear(in_channels, d_model), nn.ReLU(), + nn.Linear(d_model, d_model)) + self.iter_pred = iter_pred + self.attn_mask = attn_mask + self.num_classes_1dataset = num_classes_1dataset + self.num_classes_2dataset = num_classes_2dataset + + def _get_queries(self, batch_size, scene_names): + """Get query tensor. + + Args: + batch_size (int, optional): batch size. + scene_names (List[string]): list of len batch size, which + contains scene names. + Returns: + List[Tensor]: of len batch_size, each of shape + (n_queries_i, d_model). + """ + + result_queries = [] + for i in range(batch_size): + if self.prefix_1dataset in scene_names[i]: + result_queries.append(self.queries_1dataset.weight) + elif self.prefix_2dataset in scene_names[i]: + result_queries.append(self.queries_2dataset.weight) + else: + raise RuntimeError(f'Invalid scene name "{scene_names[i]}".') + + return result_queries + + def _forward_head(self, queries, mask_feats, scene_names): + """Prediction head forward. + + Args: + queries (List[Tensor] | Tensor): List of len batch_size, + each of shape (n_queries_i, d_model). Or tensor of + shape (batch_size, n_queries, d_model). + mask_feats (List[Tensor]): of len batch_size, + each of shape (n_points_i, d_model). + scene_names (List[string]): list of len batch size, which + contains scene names. + + Returns: + Tuple: + List[Tensor]: Classification predictions of len batch_size, + each of shape (n_queries_i, n_classes + 1). + List[Tensor]: Confidence scores of len batch_size, + each of shape (n_queries_i, 1). + List[Tensor]: Predicted masks of len batch_size, + each of shape (n_queries_i, n_points_i). + List[Tensor]: Attention masks of len batch_size, + each of shape (n_queries_i, n_points_i). 
+ """ + cls_preds, pred_scores, pred_masks, attn_masks = [], [], [], [] + for i in range(len(queries)): + norm_query = self.out_norm(queries[i]) + + if self.prefix_1dataset in scene_names[i]: + cls_preds.append(self.out_cls_1dataset(norm_query)) + elif self.prefix_2dataset in scene_names[i]: + cls_preds.append(self.out_cls_2dataset(norm_query)) + else: + raise RuntimeError(f'Invalid scene name "{scene_names[i]}".') + + + pred_scores.append(self.out_score(norm_query)) + pred_mask = torch.einsum('nd,md->nm', norm_query, mask_feats[i]) + if self.attn_mask: + attn_mask = (pred_mask.sigmoid() < 0.5).bool() + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = attn_mask.detach() + attn_masks.append(attn_mask) + pred_masks.append(pred_mask) + return cls_preds, pred_scores, pred_masks, attn_masks + + def forward_simple(self, x, scene_names): + """Simple forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + scene_names (List[string]): list of len batch size, which + contains scene names. + + Returns: + Dict: with labels, masks, and scores. + """ + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(len(x), scene_names) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + cls_preds, pred_scores, pred_masks, _ = self._forward_head( + queries, mask_feats, scene_names) + return dict( + cls_preds=cls_preds, + masks=pred_masks, + scores=pred_scores) + + def forward_iter_pred(self, x, scene_names): + """Iterative forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + scene_names (List[string]): list of len batch size, which + contains scene names. + + Returns: + Dict: with labels, masks, scores, and aux_outputs. + """ + cls_preds, pred_scores, pred_masks = [], [], [] + inst_feats = [self.input_proj(y) for y in x] + mask_feats = [self.x_mask(y) for y in x] + queries = self._get_queries(len(x), scene_names) + cls_pred, pred_score, pred_mask, attn_mask = self._forward_head( + queries, mask_feats, scene_names) + cls_preds.append(cls_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + for i in range(len(self.cross_attn_layers)): + queries = self.cross_attn_layers[i](inst_feats, queries, attn_mask) + queries = self.self_attn_layers[i](queries) + queries = self.ffn_layers[i](queries) + cls_pred, pred_score, pred_mask, attn_mask = self._forward_head( + queries, mask_feats, scene_names) + cls_preds.append(cls_pred) + pred_scores.append(pred_score) + pred_masks.append(pred_mask) + + aux_outputs = [ + {'cls_preds': cls_pred, 'masks': masks, 'scores': scores} + for cls_pred, scores, masks in zip( + cls_preds[:-1], pred_scores[:-1], pred_masks[:-1])] + return dict( + cls_preds=cls_preds[-1], + masks=pred_masks[-1], + scores=pred_scores[-1], + aux_outputs=aux_outputs) + + def forward(self, x, scene_names): + """Forward pass. + + Args: + x (List[Tensor]): of len batch_size, each of shape + (n_points_i, in_channels). + scene_names (List[string]): list of len batch size, which + contains scene names. + + Returns: + Dict: with labels, masks, scores, and possibly aux_outputs. 
+ """ + if self.iter_pred: + return self.forward_iter_pred(x, scene_names) + else: + return self.forward_simple(x, scene_names) diff --git a/oneformer3d/s3dis_dataset.py b/oneformer3d/s3dis_dataset.py new file mode 100644 index 0000000..7cc1ea4 --- /dev/null +++ b/oneformer3d/s3dis_dataset.py @@ -0,0 +1,19 @@ +from mmdet3d.registry import DATASETS +from mmdet3d.datasets.s3dis_dataset import S3DISDataset + + +@DATASETS.register_module() +class S3DISSegDataset_(S3DISDataset): + METAINFO = { + 'classes': + ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', + 'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'), + 'palette': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0], + [255, 0, 255], [100, 100, 255], [200, 200, 100], + [170, 120, 200], [255, 0, 0], [200, 100, 100], + [10, 200, 100], [200, 200, 200], [50, 50, 50]], + 'seg_valid_class_ids': + tuple(range(13)), + 'seg_all_class_ids': + tuple(range(14)) # possibly with 'stair' class + } diff --git a/oneformer3d/scannet_dataset.py b/oneformer3d/scannet_dataset.py new file mode 100644 index 0000000..535062c --- /dev/null +++ b/oneformer3d/scannet_dataset.py @@ -0,0 +1,102 @@ +from os import path as osp +import numpy as np +import random + +from mmdet3d.datasets.scannet_dataset import ScanNetSegDataset +from mmdet3d.registry import DATASETS + + +@DATASETS.register_module() +class ScanNetSegDataset_(ScanNetSegDataset): + """We just add super_pts_path.""" + + def get_scene_idxs(self, *args, **kwargs): + """Compute scene_idxs for data sampling.""" + return np.arange(len(self)).astype(np.int32) + + def parse_data_info(self, info: dict) -> dict: + """Process the raw data info. + + Args: + info (dict): Raw info dict. + + Returns: + dict: Has `ann_info` in training stage. And + all path has been converted to absolute path. + """ + info['super_pts_path'] = osp.join( + self.data_prefix.get('sp_pts_mask', ''), info['super_pts_path']) + + info = super().parse_data_info(info) + + return info + + +@DATASETS.register_module() +class ScanNet200SegDataset_(ScanNetSegDataset_): + # IMPORTANT: the floor and chair categories are swapped. 
+ METAINFO = { + 'classes': ('wall', 'floor', 'chair', 'table', 'door', 'couch', 'cabinet', + 'shelf', 'desk', 'office chair', 'bed', 'pillow', 'sink', + 'picture', 'window', 'toilet', 'bookshelf', 'monitor', + 'curtain', 'book', 'armchair', 'coffee table', 'box', + 'refrigerator', 'lamp', 'kitchen cabinet', 'towel', 'clothes', + 'tv', 'nightstand', 'counter', 'dresser', 'stool', 'cushion', + 'plant', 'ceiling', 'bathtub', 'end table', 'dining table', + 'keyboard', 'bag', 'backpack', 'toilet paper', 'printer', + 'tv stand', 'whiteboard', 'blanket', 'shower curtain', + 'trash can', 'closet', 'stairs', 'microwave', 'stove', 'shoe', + 'computer tower', 'bottle', 'bin', 'ottoman', 'bench', 'board', + 'washing machine', 'mirror', 'copier', 'basket', 'sofa chair', + 'file cabinet', 'fan', 'laptop', 'shower', 'paper', 'person', + 'paper towel dispenser', 'oven', 'blinds', 'rack', 'plate', + 'blackboard', 'piano', 'suitcase', 'rail', 'radiator', + 'recycling bin', 'container', 'wardrobe', 'soap dispenser', + 'telephone', 'bucket', 'clock', 'stand', 'light', + 'laundry basket', 'pipe', 'clothes dryer', 'guitar', + 'toilet paper holder', 'seat', 'speaker', 'column', 'bicycle', + 'ladder', 'bathroom stall', 'shower wall', 'cup', 'jacket', + 'storage bin', 'coffee maker', 'dishwasher', + 'paper towel roll', 'machine', 'mat', 'windowsill', 'bar', + 'toaster', 'bulletin board', 'ironing board', 'fireplace', + 'soap dish', 'kitchen counter', 'doorframe', + 'toilet paper dispenser', 'mini fridge', 'fire extinguisher', + 'ball', 'hat', 'shower curtain rod', 'water cooler', + 'paper cutter', 'tray', 'shower door', 'pillar', 'ledge', + 'toaster oven', 'mouse', 'toilet seat cover dispenser', + 'furniture', 'cart', 'storage container', 'scale', + 'tissue box', 'light switch', 'crate', 'power outlet', + 'decoration', 'sign', 'projector', 'closet door', + 'vacuum cleaner', 'candle', 'plunger', 'stuffed animal', + 'headphones', 'dish rack', 'broom', 'guitar case', + 'range hood', 'dustpan', 'hair dryer', 'water bottle', + 'handicap bar', 'purse', 'vent', 'shower floor', + 'water pitcher', 'mailbox', 'bowl', 'paper bag', 'alarm clock', + 'music stand', 'projector screen', 'divider', + 'laundry detergent', 'bathroom counter', 'object', + 'bathroom vanity', 'closet wall', 'laundry hamper', + 'bathroom stall door', 'ceiling light', 'trash bin', + 'dumbbell', 'stair rail', 'tube', 'bathroom cabinet', + 'cd case', 'closet rod', 'coffee kettle', 'structure', + 'shower head', 'keyboard piano', 'case of water bottles', + 'coat rack', 'storage organizer', 'folded chair', 'fire alarm', + 'power strip', 'calendar', 'poster', 'potted plant', 'luggage', + 'mattress'), + # the valid ids of segmentation annotations + 'seg_valid_class_ids': ( + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, + 23, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, 41, 42, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 82, 84, 86, + 87, 88, 89, 90, 93, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 110, 112, 115, 116, 118, 120, 121, 122, 125, 128, 130, 131, + 132, 134, 136, 138, 139, 140, 141, 145, 148, 154,155, 156, 157, 159, + 161, 163, 165, 166, 168, 169, 170, 177, 180, 185, 188, 191, 193, 195, + 202, 208, 213, 214, 221, 229, 230, 232, 233, 242, 250, 261, 264, 276, + 283, 286, 300, 304, 312, 323, 325, 331, 342, 356, 370, 392, 395, 399, + 408, 417, 488, 540, 562, 570, 572, 581, 609, 748, 776, 1156, 1163, + 1164, 
1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, + 1176, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, + 1189, 1190, 1191), + 'seg_all_class_ids': tuple(range(1, 1358)), + 'palette': [random.sample(range(0, 255), 3) for i in range(200)]} diff --git a/oneformer3d/semantic_criterion.py b/oneformer3d/semantic_criterion.py new file mode 100644 index 0000000..09f55d2 --- /dev/null +++ b/oneformer3d/semantic_criterion.py @@ -0,0 +1,116 @@ +import torch +import torch.nn.functional as F + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class ScanNetSemanticCriterion: + """Semantic criterion for ScanNet. + + Args: + ignore_index (int): Ignore index. + loss_weight (float): Loss weight. + """ + + def __init__(self, ignore_index, loss_weight): + self.ignore_index = ignore_index + self.loss_weight = loss_weight + + def __call__(self, pred, insts): + """Calculate loss. + + Args: + pred (dict): Predictions with List `sem_preds` + of len batch_size, each of shape + (n_queries_i, n_classes + 1). + insts (list): Ground truth of len batch_size, + each InstanceData_ with `sp_masks` of shape + (n_classes + 1, n_queries_i). + + Returns: + Dict: with semantic loss value. + """ + losses = [] + for pred_mask, gt_mask in zip(pred['sem_preds'], insts): + if self.ignore_index >= 0: + pred_mask = pred_mask[:, :-1] + losses.append(F.cross_entropy( + pred_mask, + gt_mask.sp_masks.float().argmax(0), + ignore_index=self.ignore_index)) + loss = self.loss_weight * torch.mean(torch.stack(losses)) + return dict(seg_loss=loss) + + +@MODELS.register_module() +class S3DISSemanticCriterion: + """Semantic criterion for S3DIS. + + Args: + loss_weight (float): loss weight. + seg_loss (ConfigDict): loss config. + """ + + def __init__(self, + loss_weight, + seg_loss=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True)): + self.seg_loss = MODELS.build(seg_loss) + self.loss_weight = loss_weight + + def get_layer_loss(self, layer, aux_outputs, insts): + """Calculate loss at intermediate level. + + Args: + layer (int): transformer layer number + aux_outputs (dict): Predictions with List `masks` + of len batch_size, each of shape + (n_points_i, n_classes + 1). + insts (list): Ground truth of len batch_size, + each InstanceData_ with `sp_masks` of shape + (n_classes + 1, n_points_i). + + Returns: + Dict: with semantic loss value. + """ + pred_masks = aux_outputs['masks'] + seg_losses = [] + for pred_mask, gt_mask in zip(pred_masks, insts): + seg_loss = self.seg_loss( + pred_mask.T, gt_mask.sp_masks.float().argmax(0)) + seg_losses.append(seg_loss) + + seg_loss = self.loss_weight * torch.mean(torch.stack(seg_losses)) + return {f'layer_{layer}_seg_loss': seg_loss} + + def __call__(self, pred, insts): + """Calculate loss. + + Args: + pred (dict): Predictions with List `masks` + of len batch_size, each of shape + (n_points_i, n_classes + 1). + insts (list): Ground truth of len batch_size, + each InstanceData_ with `sp_masks` of shape + (n_classes + 1, n_points_i). + + Returns: + Dict: with semantic loss value. 
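        Example of the shape bookkeeping behind the loss call below (a
        runnable toy using plain cross entropy; the actual loss module is
        built from the ``seg_loss`` config, and all sizes are made up):

            import torch
            import torch.nn.functional as F

            n_sp, n_classes = 5, 3
            pred_mask = torch.randn(n_classes + 1, n_sp)          # per-superpoint logits
            sp_masks = torch.nn.functional.one_hot(
                torch.tensor([0, 2, 1, 3, 0]), n_classes + 1).T   # one-hot GT, same layout
            loss = F.cross_entropy(pred_mask.T, sp_masks.float().argmax(0))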
+ """ + pred_masks = pred['masks'] + seg_losses = [] + for pred_mask, gt_mask in zip(pred_masks, insts): + seg_loss = self.seg_loss( + pred_mask.T, gt_mask.sp_masks.float().argmax(0)) + seg_losses.append(seg_loss) + + seg_loss = self.loss_weight * torch.mean(torch.stack(seg_losses)) + loss = {'last_layer_seg_loss': seg_loss} + + if 'aux_outputs' in pred: + for i, aux_outputs in enumerate(pred['aux_outputs']): + loss_i = self.get_layer_loss(i, aux_outputs, insts) + loss.update(loss_i) + + return loss diff --git a/oneformer3d/spconv_unet.py b/oneformer3d/spconv_unet.py new file mode 100644 index 0000000..f81fcbf --- /dev/null +++ b/oneformer3d/spconv_unet.py @@ -0,0 +1,236 @@ +# Adapted from sunjiahao1999/SPFormer. +import functools +from collections import OrderedDict + +import spconv.pytorch as spconv +import torch +from spconv.pytorch.modules import SparseModule +from torch import nn + +from mmdet3d.registry import MODELS + + +class ResidualBlock(SparseModule): + """Resudual block for SpConv U-Net. + + Args: + in_channels (int): Number of input channels. + out_channels (int: Number of output channels. + norm_fn (Callable): Normalization function constructor. + indice_key (str): SpConv key for conv layer. + normalize_before (bool): Wheter to call norm before conv. + """ + + def __init__(self, + in_channels, + out_channels, + norm_fn=functools.partial( + nn.BatchNorm1d, eps=1e-4, momentum=0.1), + indice_key=None, + normalize_before=True): + super().__init__() + + if in_channels == out_channels: + self.i_branch = spconv.SparseSequential(nn.Identity()) + else: + self.i_branch = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, out_channels, kernel_size=1, bias=False)) + + if normalize_before: + self.conv_branch = spconv.SparseSequential( + norm_fn(in_channels), nn.ReLU(), + spconv.SubMConv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key), norm_fn(out_channels), nn.ReLU(), + spconv.SubMConv3d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key)) + else: + self.conv_branch = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key), norm_fn(out_channels), nn.ReLU(), + spconv.SubMConv3d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key), norm_fn(out_channels), nn.ReLU()) + + def forward(self, input): + """Forward pass. + + Args: + input (SparseConvTensor): Input tensor. + + Returns: + SparseConvTensor: Output tensor. + """ + identity = spconv.SparseConvTensor(input.features, input.indices, + input.spatial_shape, + input.batch_size) + + output = self.conv_branch(input) + output = output.replace_feature(output.features + + self.i_branch(identity).features) + + return output + + +@MODELS.register_module() +class SpConvUNet(nn.Module): + """SpConv U-Net model. + + Args: + num_planes (List[int]): Number of channels in each level. + norm_fn (Callable): Normalization function constructor. + block_reps (int): Times to repeat each block. + block (Callable): Block base class. + indice_key_id (int): Id of current level. + normalize_before (bool): Wheter to call norm before conv. + return_blocks (bool): Whether to return previous blocks. 
+ """ + + def __init__(self, + num_planes, + norm_fn=functools.partial( + nn.BatchNorm1d, eps=1e-4, momentum=0.1), + block_reps=2, + block=ResidualBlock, + indice_key_id=1, + normalize_before=True, + return_blocks=False): + super().__init__() + self.return_blocks = return_blocks + self.num_planes = num_planes + + # process block and norm_fn caller + if isinstance(block, str): + area = ['residual', 'vgg', 'asym'] + assert block in area, f'block must be in {area}, but got {block}' + if block == 'residual': + block = ResidualBlock + + blocks = { + f'block{i}': block( + num_planes[0], + num_planes[0], + norm_fn, + normalize_before=normalize_before, + indice_key=f'subm{indice_key_id}') + for i in range(block_reps) + } + blocks = OrderedDict(blocks) + self.blocks = spconv.SparseSequential(blocks) + + if len(num_planes) > 1: + if normalize_before: + self.conv = spconv.SparseSequential( + norm_fn(num_planes[0]), nn.ReLU(), + spconv.SparseConv3d( + num_planes[0], + num_planes[1], + kernel_size=2, + stride=2, + bias=False, + indice_key=f'spconv{indice_key_id}')) + else: + self.conv = spconv.SparseSequential( + spconv.SparseConv3d( + num_planes[0], + num_planes[1], + kernel_size=2, + stride=2, + bias=False, + indice_key=f'spconv{indice_key_id}'), + norm_fn(num_planes[1]), nn.ReLU()) + + self.u = SpConvUNet( + num_planes[1:], + norm_fn, + block_reps, + block, + indice_key_id=indice_key_id + 1, + normalize_before=normalize_before, + return_blocks=return_blocks) + + if normalize_before: + self.deconv = spconv.SparseSequential( + norm_fn(num_planes[1]), nn.ReLU(), + spconv.SparseInverseConv3d( + num_planes[1], + num_planes[0], + kernel_size=2, + bias=False, + indice_key=f'spconv{indice_key_id}')) + else: + self.deconv = spconv.SparseSequential( + spconv.SparseInverseConv3d( + num_planes[1], + num_planes[0], + kernel_size=2, + bias=False, + indice_key=f'spconv{indice_key_id}'), + norm_fn(num_planes[0]), nn.ReLU()) + + blocks_tail = {} + for i in range(block_reps): + blocks_tail[f'block{i}'] = block( + num_planes[0] * (2 - i), + num_planes[0], + norm_fn, + indice_key=f'subm{indice_key_id}', + normalize_before=normalize_before) + blocks_tail = OrderedDict(blocks_tail) + self.blocks_tail = spconv.SparseSequential(blocks_tail) + + def forward(self, input, previous_outputs=None): + """Forward pass. + + Args: + input (SparseConvTensor): Input tensor. + previous_outputs (List[SparseConvTensor]): Previous imput tensors. + + Returns: + SparseConvTensor: Output tensor. 
+ """ + output = self.blocks(input) + identity = spconv.SparseConvTensor(output.features, output.indices, + output.spatial_shape, + output.batch_size) + + if len(self.num_planes) > 1: + output_decoder = self.conv(output) + if self.return_blocks: + output_decoder, previous_outputs = self.u( + output_decoder, previous_outputs) + else: + output_decoder = self.u(output_decoder) + output_decoder = self.deconv(output_decoder) + + output = output.replace_feature( + torch.cat((identity.features, output_decoder.features), dim=1)) + output = self.blocks_tail(output) + + if self.return_blocks: + # NOTE: to avoid the residual bug + if previous_outputs is None: + previous_outputs = [] + previous_outputs.append(output) + return output, previous_outputs + else: + return output diff --git a/oneformer3d/structured3d_dataset.py b/oneformer3d/structured3d_dataset.py new file mode 100644 index 0000000..c19cbb1 --- /dev/null +++ b/oneformer3d/structured3d_dataset.py @@ -0,0 +1,88 @@ +import numpy as np + +from mmengine.dataset.dataset_wrapper import ConcatDataset +from mmengine.dataset.base_dataset import BaseDataset +from mmdet3d.datasets.seg3d_dataset import Seg3DDataset +from mmdet3d.registry import DATASETS + + +@DATASETS.register_module() +class Structured3DSegDataset(Seg3DDataset): + METAINFO = { + 'classes': + ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', + 'window', 'picture', 'counter', 'desk', 'shelves', 'curtain', + 'dresser', 'pillow', 'mirror', 'ceiling', 'fridge', 'television', + 'night stand', 'toilet', 'sink', 'lamp', 'bathtub', 'structure', + 'furniture', 'prop'), + 'palette': [[135, 141, 249], [91, 186, 154], [134, 196, 138], + [205, 82, 150], [245, 38, 29], [238, 130,249], [189, 22, 4], + [128, 94, 103], [121, 74, 63], [98, 252, 9], [227, 8, 226], + [224, 58, 233], [244, 26, 146], [50, 62, 237], + [141, 30, 106], [60, 187, 63], [206, 106, 254], + [164, 85, 194], [187, 218, 244], [244, 140, 56], + [118, 8, 242], [88, 60, 134], [230, 110, 157], + [174, 48, 170], [3, 119, 80], [69, 148, 166], + [171, 16, 47], [81, 66, 251]], + 'seg_valid_class_ids': + (1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, + 32, 33, 34, 35, 36, 38, 39, 40), + 'seg_all_class_ids': + tuple(range(41)), + } + + def get_scene_idxs(self, scene_idxs): + """Compute scene_idxs for data sampling. + + We sample more times for scenes with more points. + """ + return np.arange(len(self)).astype(np.int32) + + +@DATASETS.register_module() +class ConcatDataset_(ConcatDataset): + """A wrapper of concatenated dataset. + + Args: + datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets + which will be concatenated. + lazy_init (bool, optional): Whether to load annotation during + instantiation. Defaults to False. + ignore_keys (List[str] or str): Ignore the keys that can be + unequal in `dataset.metainfo`. Defaults to None. 
+ `New in version 0.3.0.` + """ + + def __init__(self, + datasets, + lazy_init=False, + ignore_keys=None): + self.datasets = [] + for i, dataset in enumerate(datasets): + if isinstance(dataset, dict): + self.datasets.append(DATASETS.build(dataset)) + elif isinstance(dataset, BaseDataset): + self.datasets.append(dataset) + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`BaseDataset` instance, but got {type(dataset)}') + if ignore_keys is None: + self.ignore_keys = [] + elif isinstance(ignore_keys, str): + self.ignore_keys = [ignore_keys] + elif isinstance(ignore_keys, list): + self.ignore_keys = ignore_keys + else: + raise TypeError('ignore_keys should be a list or str, ' + f'but got {type(ignore_keys)}') + + meta_keys: set = set() + for dataset in self.datasets: + meta_keys |= dataset.metainfo.keys() + # Only use metainfo of first dataset. + self._metainfo = self.datasets[0].metainfo + + self._fully_initialized = False + if not lazy_init: + self.full_init() diff --git a/oneformer3d/structures.py b/oneformer3d/structures.py new file mode 100644 index 0000000..2037bb5 --- /dev/null +++ b/oneformer3d/structures.py @@ -0,0 +1,25 @@ +from collections.abc import Sized +from mmengine.structures import InstanceData + + +class InstanceData_(InstanceData): + """We only remove a single assert from __setattr__.""" + + def __setattr__(self, name: str, value: Sized): + """setattr is only used to set data. + + The value must have the attribute of `__len__` and have the same length + of `InstanceData`. + """ + if name in ('_metainfo_fields', '_data_fields'): + if not hasattr(self, name): + super(InstanceData, self).__setattr__(name, value) + else: + raise AttributeError(f'{name} has been used as a ' + 'private attribute, which is immutable.') + + else: + assert isinstance(value, + Sized), 'value must contain `__len__` attribute' + + super(InstanceData, self).__setattr__(name, value) diff --git a/oneformer3d/transforms_3d.py b/oneformer3d/transforms_3d.py new file mode 100644 index 0000000..242d306 --- /dev/null +++ b/oneformer3d/transforms_3d.py @@ -0,0 +1,408 @@ +import numpy as np +import scipy +import torch +from torch_scatter import scatter_mean +from mmcv.transforms import BaseTransform +from mmdet3d.datasets.transforms import PointSample + +from mmdet3d.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class ElasticTransfrom(BaseTransform): + """Apply elastic augmentation to a 3D scene. Required Keys: + + Args: + gran (List[float]): Size of the noise grid (in same scale[m/cm] + as the voxel grid). + mag (List[float]): Noise multiplier. + voxel_size (float): Voxel size. + p (float): probability of applying this transform. + """ + + def __init__(self, gran, mag, voxel_size, p=1.0): + self.gran = gran + self.mag = mag + self.voxel_size = voxel_size + self.p = p + + def transform(self, input_dict): + """Private function-wrapper for elastic transform. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Results after elastic, 'points' is updated + in the result dict. + """ + coords = input_dict['points'].tensor[:, :3].numpy() / self.voxel_size + if np.random.rand() < self.p: + coords = self.elastic(coords, self.gran[0], self.mag[0]) + coords = self.elastic(coords, self.gran[1], self.mag[1]) + input_dict['elastic_coords'] = coords + return input_dict + + def elastic(self, x, gran, mag): + """Private function for elastic transform to a points. + + Args: + x (ndarray): Point cloud. 
+ gran (List[float]): Size of the noise grid (in same scale[m/cm] + as the voxel grid). + mag: (List[float]): Noise multiplier. + + Returns: + dict: Results after elastic, 'points' is updated + in the result dict. + """ + blur0 = np.ones((3, 1, 1)).astype('float32') / 3 + blur1 = np.ones((1, 3, 1)).astype('float32') / 3 + blur2 = np.ones((1, 1, 3)).astype('float32') / 3 + + noise_dim = np.abs(x).max(0).astype(np.int32) // gran + 3 + noise = [ + np.random.randn(noise_dim[0], noise_dim[1], + noise_dim[2]).astype('float32') for _ in range(3) + ] + + for blur in [blur0, blur1, blur2, blur0, blur1, blur2]: + noise = [ + scipy.ndimage.filters.convolve( + n, blur, mode='constant', cval=0) for n in noise + ] + + ax = [ + np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in noise_dim + ] + interp = [ + scipy.interpolate.RegularGridInterpolator( + ax, n, bounds_error=0, fill_value=0) for n in noise + ] + + return x + np.hstack([i(x)[:, None] for i in interp]) * mag + + +@TRANSFORMS.register_module() +class AddSuperPointAnnotations(BaseTransform): + """Prepare ground truth markup for training. + + Required Keys: + - pts_semantic_mask (np.float32) + + Added Keys: + - gt_sp_masks (np.int64) + + Args: + num_classes (int): Number of classes. + """ + + def __init__(self, + num_classes, + stuff_classes, + merge_non_stuff_cls=True): + self.num_classes = num_classes + self.stuff_classes = stuff_classes + self.merge_non_stuff_cls = merge_non_stuff_cls + + def transform(self, input_dict): + """Private function for preparation ground truth + markup for training. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: results, 'gt_sp_masks' is added. + """ + # create class mapping + # because pts_instance_mask contains instances from non-instaces classes + pts_instance_mask = torch.tensor(input_dict['pts_instance_mask']) + pts_semantic_mask = torch.tensor(input_dict['pts_semantic_mask']) + + pts_instance_mask[pts_semantic_mask == self.num_classes] = -1 + for stuff_cls in self.stuff_classes: + pts_instance_mask[pts_semantic_mask == stuff_cls] = -1 + + idxs = torch.unique(pts_instance_mask) + assert idxs[0] == -1 + + mapping = torch.zeros(torch.max(idxs) + 2, dtype=torch.long) + new_idxs = torch.arange(len(idxs), device=idxs.device) + mapping[idxs] = new_idxs - 1 + pts_instance_mask = mapping[pts_instance_mask] + input_dict['pts_instance_mask'] = pts_instance_mask.numpy() + + + # create gt instance markup + insts_mask = pts_instance_mask.clone() + + if torch.sum(insts_mask == -1) != 0: + insts_mask[insts_mask == -1] = torch.max(insts_mask) + 1 + insts_mask = torch.nn.functional.one_hot(insts_mask)[:, :-1] + else: + insts_mask = torch.nn.functional.one_hot(insts_mask) + + if insts_mask.shape[1] != 0: + insts_mask = insts_mask.T + sp_pts_mask = torch.tensor(input_dict['sp_pts_mask']) + sp_masks_inst = scatter_mean( + insts_mask.float(), sp_pts_mask, dim=-1) + sp_masks_inst = sp_masks_inst > 0.5 + else: + sp_masks_inst = insts_mask.new_zeros( + (0, input_dict['sp_pts_mask'].max() + 1), dtype=torch.bool) + + num_stuff_cls = len(self.stuff_classes) + insts = new_idxs[1:] - 1 + if self.merge_non_stuff_cls: + gt_labels = insts.new_zeros(len(insts) + num_stuff_cls + 1) + else: + gt_labels = insts.new_zeros(len(insts) + self.num_classes + 1) + + for inst in insts: + index = pts_semantic_mask[pts_instance_mask == inst][0] + gt_labels[inst] = index - num_stuff_cls + + input_dict['gt_labels_3d'] = gt_labels.numpy() + + # create gt semantic markup + sem_mask = 
torch.tensor(input_dict['pts_semantic_mask']) + sem_mask = torch.nn.functional.one_hot(sem_mask, + num_classes=self.num_classes + 1) + + sem_mask = sem_mask.T + sp_pts_mask = torch.tensor(input_dict['sp_pts_mask']) + sp_masks_seg = scatter_mean(sem_mask.float(), sp_pts_mask, dim=-1) + sp_masks_seg = sp_masks_seg > 0.5 + + sp_masks_seg[-1, sp_masks_seg.sum(axis=0) == 0] = True + + assert sp_masks_seg.sum(axis=0).max().item() + + if self.merge_non_stuff_cls: + sp_masks_seg = torch.vstack(( + sp_masks_seg[:num_stuff_cls, :], + sp_masks_seg[num_stuff_cls:, :].sum(axis=0).unsqueeze(0))) + + sp_masks_all = torch.vstack((sp_masks_inst, sp_masks_seg)) + + input_dict['gt_sp_masks'] = sp_masks_all.numpy() + + # create eval markup + if 'eval_ann_info' in input_dict.keys(): + pts_instance_mask[pts_instance_mask != -1] += num_stuff_cls + for idx, stuff_cls in enumerate(self.stuff_classes): + pts_instance_mask[pts_semantic_mask == stuff_cls] = idx + + input_dict['eval_ann_info']['pts_instance_mask'] = \ + pts_instance_mask.numpy() + + return input_dict + + +@TRANSFORMS.register_module() +class SwapChairAndFloor(BaseTransform): + """Swap two categories for ScanNet200 dataset. It is convenient for + panoptic evaluation. After this swap first two categories are + `stuff` and other 198 are `thing`. + """ + def transform(self, input_dict): + """Private function-wrapper for swap transform. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Results after swap, 'pts_semantic_mask' is updated + in the result dict. + """ + mask = input_dict['pts_semantic_mask'].copy() + mask[input_dict['pts_semantic_mask'] == 2] = 3 + mask[input_dict['pts_semantic_mask'] == 3] = 2 + input_dict['pts_semantic_mask'] = mask + if 'eval_ann_info' in input_dict: + input_dict['eval_ann_info']['pts_semantic_mask'] = mask + return input_dict + + +@TRANSFORMS.register_module() +class PointInstClassMapping_(BaseTransform): + """Delete instances from non-instaces classes. + + Required Keys: + - pts_instance_mask (np.float32) + - pts_semantic_mask (np.float32) + + Modified Keys: + - pts_instance_mask (np.float32) + - pts_semantic_mask (np.float32) + + Added Keys: + - gt_labels_3d (int) + + Args: + num_classes (int): Number of classes. + """ + + def __init__(self, num_classes, structured3d=False): + self.num_classes = num_classes + self.structured3d = structured3d + + def transform(self, input_dict): + """Private function for deleting + instances from non-instaces classes. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: results, 'pts_instance_mask', 'pts_semantic_mask', + are updated in the result dict. 'gt_labels_3d' is added. 
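+
+        Example (toy sketch with ``num_classes=18``; values are
+        illustrative):
+
+            >>> import numpy as np
+            >>> t = PointInstClassMapping_(num_classes=18)
+            >>> res = t.transform(dict(
+            ...     pts_instance_mask=np.array([0, 0, 5, 5]),
+            ...     pts_semantic_mask=np.array([2, 2, 18, 18])))
+            >>> res['pts_instance_mask']  # -> [0, 0, -1, -1], ignored class dropped
+            >>> res['gt_labels_3d']  # -> [2], one semantic label per kept instance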
+ """ + + # because pts_instance_mask contains instances from non-instaces + # classes + pts_instance_mask = np.array(input_dict['pts_instance_mask']) + pts_semantic_mask = input_dict['pts_semantic_mask'] + + if self.structured3d: + # wall as one instance + pts_instance_mask[pts_semantic_mask == 0] = \ + pts_instance_mask.max() + 1 + # floor as one instance + pts_instance_mask[pts_semantic_mask == 1] = \ + pts_instance_mask.max() + 1 + + pts_instance_mask[pts_semantic_mask == self.num_classes] = -1 + pts_semantic_mask[pts_semantic_mask == self.num_classes] = -1 + + idxs = np.unique(pts_instance_mask) + mapping = np.zeros(np.max(idxs) + 2, dtype=int) + new_idxs = np.arange(len(idxs)) + if idxs[0] == -1: + mapping[idxs] = new_idxs - 1 + new_idxs = new_idxs[:-1] + else: + mapping[idxs] = new_idxs + pts_instance_mask = mapping[pts_instance_mask] + + input_dict['pts_instance_mask'] = pts_instance_mask + input_dict['pts_semantic_mask'] = pts_semantic_mask + + gt_labels = np.zeros(len(new_idxs), dtype=int) + for inst in new_idxs: + gt_labels[inst] = pts_semantic_mask[pts_instance_mask == inst][0] + + input_dict['gt_labels_3d'] = gt_labels + + return input_dict + + +@TRANSFORMS.register_module() +class PointSample_(PointSample): + + def _points_random_sampling(self, points, num_samples): + """Points random sampling. Sample points to a certain number. + + Args: + points (:obj:`BasePoints`): 3D Points. + num_samples (int): Number of samples to be sampled. + + Returns: + tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: + - points (:obj:`BasePoints`): 3D Points. + - choices (np.ndarray, optional): The generated random samples. + """ + + point_range = range(len(points)) + choices = np.random.choice(point_range, + min(num_samples, len(points))) + + return points[choices], choices + + def transform(self, input_dict): + """Transform function to sample points to in indoor scenes. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: Results after sampling, 'points', 'pts_instance_mask', + 'pts_semantic_mask', sp_pts_mask' keys are updated in the + result dict. + """ + points = input_dict['points'] + points, choices = self._points_random_sampling( + points, self.num_points) + input_dict['points'] = points + pts_instance_mask = input_dict.get('pts_instance_mask', None) + pts_semantic_mask = input_dict.get('pts_semantic_mask', None) + sp_pts_mask = input_dict.get('sp_pts_mask', None) + + if pts_instance_mask is not None: + pts_instance_mask = pts_instance_mask[choices] + + idxs = np.unique(pts_instance_mask) + mapping = np.zeros(np.max(idxs) + 2, dtype=int) + new_idxs = np.arange(len(idxs)) + if idxs[0] == -1: + mapping[idxs] = new_idxs - 1 + else: + mapping[idxs] = new_idxs + pts_instance_mask = mapping[pts_instance_mask] + + input_dict['pts_instance_mask'] = pts_instance_mask + + if pts_semantic_mask is not None: + pts_semantic_mask = pts_semantic_mask[choices] + input_dict['pts_semantic_mask'] = pts_semantic_mask + + if sp_pts_mask is not None: + sp_pts_mask = sp_pts_mask[choices] + sp_pts_mask = np.unique( + sp_pts_mask, return_inverse=True)[1] + input_dict['sp_pts_mask'] = sp_pts_mask + return input_dict + +@TRANSFORMS.register_module() +class SkipEmptyScene(BaseTransform): + """Skip empty scene during training. 
+ + Required Keys: + - pts_instance_mask (np.float32) + - pts_semantic_mask (np.float32) + - points (:obj:`BasePoints`) + - gt_labels_3d (int) + + Modified Keys: + - pts_instance_mask (np.float32) + - pts_semantic_mask (np.float32) + - points (:obj:`BasePoints`) + - gt_labels_3d (int) + + """ + + def transform(self, input_dict): + """Private function for skipping empty scene during training. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: results, 'pts_instance_mask', 'pts_semantic_mask', + 'points', 'gt_labels_3d' are updated in the result dict. + """ + + if len(input_dict['gt_labels_3d']) != 0: + self.inst = input_dict['pts_instance_mask'] + self.sem = input_dict['pts_semantic_mask'] + self.gt_labels = input_dict['gt_labels_3d'] + self.points = input_dict['points'] + else: + input_dict['pts_instance_mask'] = self.inst + input_dict['pts_semantic_mask'] = self.sem + input_dict['gt_labels_3d'] = self.gt_labels + input_dict['points'] = self.points + + return input_dict diff --git a/oneformer3d/unified_criterion.py b/oneformer3d/unified_criterion.py new file mode 100644 index 0000000..da4a718 --- /dev/null +++ b/oneformer3d/unified_criterion.py @@ -0,0 +1,161 @@ +from mmdet3d.registry import MODELS +from .structures import InstanceData_ + + +@MODELS.register_module() +class ScanNetUnifiedCriterion: + """Simply call semantic and instance criterions. + + Args: + num_semantic_classes (int): Number of semantic classes. + sem_criterion (ConfigDict): Class for semantic loss calculation. + inst_criterion (ConfigDict): Class for instance loss calculation. + """ + + def __init__(self, num_semantic_classes, sem_criterion, inst_criterion): + self.num_semantic_classes = num_semantic_classes + self.sem_criterion = MODELS.build(sem_criterion) + self.inst_criterion = MODELS.build(inst_criterion) + + def __call__(self, pred, insts): + """Calculate loss. + + Args: + pred (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_queries, n_classes + 1) + List `scores` of len batch_size each of shape (n_queries, 1) + List `masks` of len batch_size each of shape + (n_queries, n_points) + Dict `aux_preds` with list of cls_preds, scores, and masks + List `sem_preds` of len batch_size each of shape + (n_queries, n_classes + 1). + insts (list): Ground truth of len batch_size, + each InstanceData_ with + `sp_masks` of shape (n_gts_i + n_classes + 1, n_points_i) + `labels_3d` of shape (n_gts_i + n_classes + 1,) + `query_masks` of shape + (n_gts_i + n_classes + 1, n_queries_i). + + Returns: + Dict: with semantic and instance loss values. + """ + sem_gts = [] + inst_gts = [] + n = self.num_semantic_classes + + for i in range(len(pred['masks'])): + sem_gt = InstanceData_() + if insts[i].get('query_masks') is not None: + sem_gt.sp_masks = insts[i].query_masks[-n - 1:, :] + else: + sem_gt.sp_masks = insts[i].sp_masks[-n - 1:, :] + sem_gts.append(sem_gt) + + inst_gt = InstanceData_() + inst_gt.sp_masks = insts[i].sp_masks[:-n - 1, :] + inst_gt.labels_3d = insts[i].labels_3d[:-n - 1] + if insts[i].get('query_masks') is not None: + inst_gt.query_masks = insts[i].query_masks[:-n - 1, :] + inst_gts.append(inst_gt) + + loss = self.inst_criterion(pred, inst_gts) + loss.update(self.sem_criterion(pred, sem_gts)) + return loss + +@MODELS.register_module() +class S3DISUnifiedCriterion: + """Simply call semantic and instance criterions. + + Args: + num_semantic_classes (int): Number of semantic classes. + sem_criterion (ConfigDict): Class for semantic loss calculation. 
+ inst_criterion (ConfigDict): Class for instance loss calculation. + """ + + def __init__(self, num_semantic_classes, sem_criterion, inst_criterion): + self.num_semantic_classes = num_semantic_classes + self.sem_criterion = MODELS.build(sem_criterion) + self.inst_criterion = MODELS.build(inst_criterion) + + def __call__(self, pred, insts): + """Calculate loss. + + Args: + pred (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_queries, n_classes + 1) + List `scores` of len batch_size each of shape (n_queries, 1) + List `masks` of len batch_size each of shape + (n_queries, n_points) + Dict `aux_preds` with list of cls_preds, scores, and masks + insts (list): Ground truth of len batch_size, + each InstanceData_ with + `sp_inst_masks` of shape + (n_gts_i, n_points_i) + `sp_sem_masks` of shape + (n_classes + 1, n_points_i) + `labels_3d` of shape (n_gts_i + n_classes + 1,). + + Returns: + Dict: with semantic and instance loss values. + """ + pred_masks = pred['masks'] + pred_cls = pred['cls_preds'] + pred_scores = pred['scores'] + + sem_preds = [] + sem_gts = [] + inst_gts = [] + n = self.num_semantic_classes + for i in range(len(pred_masks)): + sem_preds.append(pred_masks[i][-n:, :]) + pred_masks[i] = pred_masks[i][:-n, :] + pred_cls[i] = pred_cls[i][:-n, :] + pred_scores[i] = pred_scores[i][:-n, :] + + sem_gt = InstanceData_() + inst_gt = InstanceData_() + sem_gt.sp_masks = insts[i].sp_sem_masks + sem_gts.append(sem_gt) + inst_gt.sp_masks = insts[i].sp_inst_masks + inst_gt.labels_3d = insts[i].labels_3d + inst_gts.append(inst_gt) + + if 'aux_outputs' in pred: + sem_aux_outputs = [] + for aux_outputs in pred['aux_outputs']: + sem_aux_outputs.append(self.prepare_aux_outputs(aux_outputs)) + + loss = self.inst_criterion(pred, inst_gts) + loss.update(self.sem_criterion( + {'masks': sem_preds, 'aux_outputs': sem_aux_outputs}, sem_gts)) + return loss + + def prepare_aux_outputs(self, aux_outputs): + """Prepare aux outputs for intermediate layers. + + Args: + aux_outputs (Dict): + List `cls_preds` of shape len batch_size, each of shape + (n_queries, n_classes + 1) + List `scores` of len batch_size each of shape (n_queries, 1) + List `masks` of len batch_size each of shape + (n_queries, n_points). + + Returns: + Dict: with semantic predictions. + """ + pred_masks = aux_outputs['masks'] + pred_cls = aux_outputs['cls_preds'] + pred_scores = aux_outputs['scores'] + + sem_preds = [] + n = self.num_semantic_classes + for i in range(len(pred_masks)): + sem_preds.append(pred_masks[i][-n:, :]) + pred_masks[i] = pred_masks[i][:-n, :] + pred_cls[i] = pred_cls[i][:-n, :] + pred_scores[i] = pred_scores[i][:-n, :] + + return {'masks': sem_preds} diff --git a/oneformer3d/unified_metric.py b/oneformer3d/unified_metric.py new file mode 100644 index 0000000..d6e526c --- /dev/null +++ b/oneformer3d/unified_metric.py @@ -0,0 +1,255 @@ +import torch +import numpy as np + +from mmengine.logging import MMLogger + +from mmdet3d.evaluation import InstanceSegMetric +from mmdet3d.evaluation.metrics import SegMetric +from mmdet3d.registry import METRICS +from mmdet3d.evaluation import panoptic_seg_eval, seg_eval +from .instance_seg_eval import instance_seg_eval + + +@METRICS.register_module() +class UnifiedSegMetric(SegMetric): + """Metric for instance, semantic, and panoptic evaluation. + The order of classes must be [stuff classes, thing classes, unlabeled]. + + Args: + thing_class_inds (List[int]): Ids of thing classes. + stuff_class_inds (List[int]): Ids of stuff classes. 
+ min_num_points (int): Minimal size of mask for panoptic segmentation. + id_offset (int): Offset for instance classes. + sem_mapping (List[int]): Semantic class to gt id. + inst_mapping (List[int]): Instance class to gt id. + metric_meta (Dict): Analogue of dataset meta of SegMetric. Keys: + `label2cat` (Dict[int, str]): class names, + `ignore_index` (List[int]): ids of semantic categories to ignore, + `classes` (List[str]): class names. + logger_keys (List[Tuple]): Keys for logger to save; of len 3: + semantic, instance, and panoptic. + """ + + def __init__(self, + thing_class_inds, + stuff_class_inds, + min_num_points, + id_offset, + sem_mapping, + inst_mapping, + metric_meta, + logger_keys=[('miou',), + ('all_ap', 'all_ap_50%', 'all_ap_25%'), + ('pq',)], + **kwargs): + self.thing_class_inds = thing_class_inds + self.stuff_class_inds = stuff_class_inds + self.min_num_points = min_num_points + self.id_offset = id_offset + self.metric_meta = metric_meta + self.logger_keys = logger_keys + self.sem_mapping = np.array(sem_mapping) + self.inst_mapping = np.array(inst_mapping) + super().__init__(**kwargs) + + def compute_metrics(self, results): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + self.valid_class_ids = self.dataset_meta['seg_valid_class_ids'] + label2cat = self.metric_meta['label2cat'] + ignore_index = self.metric_meta['ignore_index'] + classes = self.metric_meta['classes'] + thing_classes = [classes[i] for i in self.thing_class_inds] + stuff_classes = [classes[i] for i in self.stuff_class_inds] + num_stuff_cls = len(stuff_classes) + + gt_semantic_masks_inst_task = [] + gt_instance_masks_inst_task = [] + pred_instance_masks_inst_task = [] + pred_instance_labels = [] + pred_instance_scores = [] + + gt_semantic_masks_sem_task = [] + pred_semantic_masks_sem_task = [] + + gt_masks_pan = [] + pred_masks_pan = [] + + for eval_ann, single_pred_results in results: + + if self.metric_meta['dataset_name'] == 'S3DIS': + pan_gt = {} + pan_gt['pts_semantic_mask'] = eval_ann['pts_semantic_mask'] + pan_gt['pts_instance_mask'] = \ + eval_ann['pts_instance_mask'].copy() + + for stuff_cls in self.stuff_class_inds: + pan_gt['pts_instance_mask'][\ + pan_gt['pts_semantic_mask'] == stuff_cls] = \ + np.max(pan_gt['pts_instance_mask']) + 1 + + pan_gt['pts_instance_mask'] = np.unique( + pan_gt['pts_instance_mask'], + return_inverse=True)[1] + gt_masks_pan.append(pan_gt) + else: + gt_masks_pan.append(eval_ann) + + pred_masks_pan.append({ + 'pts_instance_mask': \ + single_pred_results['pts_instance_mask'][1], + 'pts_semantic_mask': \ + single_pred_results['pts_semantic_mask'][1] + }) + + gt_semantic_masks_sem_task.append(eval_ann['pts_semantic_mask']) + pred_semantic_masks_sem_task.append( + single_pred_results['pts_semantic_mask'][0]) + + if self.metric_meta['dataset_name'] == 'S3DIS': + gt_semantic_masks_inst_task.append(eval_ann['pts_semantic_mask']) + gt_instance_masks_inst_task.append(eval_ann['pts_instance_mask']) + else: + sem_mask, inst_mask = self.map_inst_markup( + eval_ann['pts_semantic_mask'].copy(), + eval_ann['pts_instance_mask'].copy(), + self.valid_class_ids[num_stuff_cls:], + num_stuff_cls) + gt_semantic_masks_inst_task.append(sem_mask) + gt_instance_masks_inst_task.append(inst_mask) + + pred_instance_masks_inst_task.append( + 
torch.tensor(single_pred_results['pts_instance_mask'][0])) + pred_instance_labels.append( + torch.tensor(single_pred_results['instance_labels'])) + pred_instance_scores.append( + torch.tensor(single_pred_results['instance_scores'])) + + ret_pan = panoptic_seg_eval( + gt_masks_pan, pred_masks_pan, classes, thing_classes, + stuff_classes, self.min_num_points, self.id_offset, + label2cat, ignore_index, logger) + + ret_sem = seg_eval( + gt_semantic_masks_sem_task, + pred_semantic_masks_sem_task, + label2cat, + ignore_index[0], + logger=logger) + + if self.metric_meta['dataset_name'] == 'S3DIS': + # :-1 for unlabeled + ret_inst = instance_seg_eval( + gt_semantic_masks_inst_task, + gt_instance_masks_inst_task, + pred_instance_masks_inst_task, + pred_instance_labels, + pred_instance_scores, + valid_class_ids=self.valid_class_ids, + class_labels=classes[:-1], + logger=logger) + else: + # :-1 for unlabeled + ret_inst = instance_seg_eval( + gt_semantic_masks_inst_task, + gt_instance_masks_inst_task, + pred_instance_masks_inst_task, + pred_instance_labels, + pred_instance_scores, + valid_class_ids=self.valid_class_ids[num_stuff_cls:], + class_labels=classes[num_stuff_cls:-1], + logger=logger) + + metrics = dict() + for ret, keys in zip((ret_sem, ret_inst, ret_pan), self.logger_keys): + for key in keys: + metrics[key] = ret[key] + return metrics + + def map_inst_markup(self, + pts_semantic_mask, + pts_instance_mask, + valid_class_ids, + num_stuff_cls): + """Map gt instance and semantic classes back from panoptic annotations. + + Args: + pts_semantic_mask (np.array): of shape (n_raw_points,) + pts_instance_mask (np.array): of shape (n_raw_points.) + valid_class_ids (Tuple): of len n_instance_classes + num_stuff_cls (int): number of stuff classes + + Returns: + Tuple: + np.array: pts_semantic_mask of shape (n_raw_points,) + np.array: pts_instance_mask of shape (n_raw_points,) + """ + pts_instance_mask -= num_stuff_cls + pts_instance_mask[pts_instance_mask < 0] = -1 + pts_semantic_mask -= num_stuff_cls + pts_semantic_mask[pts_instance_mask == -1] = -1 + + mapping = np.array(list(valid_class_ids) + [-1]) + pts_semantic_mask = mapping[pts_semantic_mask] + + return pts_semantic_mask, pts_instance_mask + + +@METRICS.register_module() +class InstanceSegMetric_(InstanceSegMetric): + """The only difference with InstanceSegMetric is that following ScanNet + evaluator we accept instance prediction as a boolean tensor of shape + (n_points, n_instances) instead of integer tensor of shape (n_points, ). + + For this purpose we only replace instance_seg_eval call. + """ + + def compute_metrics(self, results): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. 
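+
+        Example (sketch of one entry in ``results``; shapes follow the
+        class docstring above, all names are placeholders):
+
+            >>> eval_ann = dict(pts_semantic_mask=sem_gt,
+            ...                 pts_instance_mask=inst_gt)
+            >>> pred = dict(
+            ...     pts_instance_mask=np.zeros((n_points, n_pred), dtype=bool),
+            ...     instance_labels=labels,
+            ...     instance_scores=scores)
+            >>> results = [(eval_ann, pred)]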
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + self.classes = self.dataset_meta['classes'] + self.valid_class_ids = self.dataset_meta['seg_valid_class_ids'] + + gt_semantic_masks = [] + gt_instance_masks = [] + pred_instance_masks = [] + pred_instance_labels = [] + pred_instance_scores = [] + + for eval_ann, single_pred_results in results: + gt_semantic_masks.append(eval_ann['pts_semantic_mask']) + gt_instance_masks.append(eval_ann['pts_instance_mask']) + pred_instance_masks.append( + single_pred_results['pts_instance_mask']) + pred_instance_labels.append(single_pred_results['instance_labels']) + pred_instance_scores.append(single_pred_results['instance_scores']) + + ret_dict = instance_seg_eval( + gt_semantic_masks, + gt_instance_masks, + pred_instance_masks, + pred_instance_labels, + pred_instance_scores, + valid_class_ids=self.valid_class_ids, + class_labels=self.classes, + logger=logger) + + return ret_dict diff --git a/tools/create_data.py b/tools/create_data.py new file mode 100644 index 0000000..88658d6 --- /dev/null +++ b/tools/create_data.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from os import path as osp + +from indoor_converter import create_indoor_info_file +from update_infos_to_v2 import update_pkl_infos + + +def scannet_data_prep(root_path, info_prefix, out_dir, workers): + """Prepare the info file for scannet dataset. + + Args: + root_path (str): Path of dataset root. + info_prefix (str): The prefix of info filenames. + out_dir (str): Output directory of the generated info file. + workers (int): Number of threads to be used. + """ + create_indoor_info_file( + root_path, info_prefix, out_dir, workers=workers) + info_train_path = osp.join(out_dir, f'{info_prefix}_oneformer3d_infos_train.pkl') + info_val_path = osp.join(out_dir, f'{info_prefix}_oneformer3d_infos_val.pkl') + info_test_path = osp.join(out_dir, f'{info_prefix}_oneformer3d_infos_test.pkl') + update_pkl_infos(info_prefix, out_dir=out_dir, pkl_path=info_train_path) + update_pkl_infos(info_prefix, out_dir=out_dir, pkl_path=info_val_path) + update_pkl_infos(info_prefix, out_dir=out_dir, pkl_path=info_test_path) + + +parser = argparse.ArgumentParser(description='Data converter arg parser') +parser.add_argument('dataset', metavar='kitti', help='name of the dataset') +parser.add_argument( + '--root-path', + type=str, + default='./data/kitti', + help='specify the root path of dataset') +parser.add_argument( + '--out-dir', + type=str, + default='./data/kitti', + required=False, + help='name of info pkl') +parser.add_argument('--extra-tag', type=str, default='kitti') +parser.add_argument( + '--workers', type=int, default=4, help='number of threads to be used') +args = parser.parse_args() + +if __name__ == '__main__': + from mmdet3d.utils import register_all_modules + register_all_modules() + + if args.dataset in ('scannet', 'scannet200'): + scannet_data_prep( + root_path=args.root_path, + info_prefix=args.extra_tag, + out_dir=args.out_dir, + workers=args.workers) + else: + raise NotImplementedError(f'Don\'t support {args.dataset} dataset.') diff --git a/tools/fix_spconv_checkpoint.py b/tools/fix_spconv_checkpoint.py new file mode 100644 index 0000000..b838aaa --- /dev/null +++ b/tools/fix_spconv_checkpoint.py @@ -0,0 +1,18 @@ +import argparse +import torch + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--in-path', type=str, required=True) + parser.add_argument('--out-path', type=str, required=True) + args = 
parser.parse_args() + + checkpoint = torch.load(args.in_path) + key = 'state_dict' # 'model' for SSTNet + for layer in checkpoint[key]: + if (layer.startswith('unet') or layer.startswith('input_conv')) \ + and layer.endswith('weight') \ + and len(checkpoint[key][layer].shape) == 5: + checkpoint[key][layer] = checkpoint[key][layer].permute(1, 2, 3, 4, 0) + torch.save(checkpoint, args.out_path) diff --git a/tools/indoor_converter.py b/tools/indoor_converter.py new file mode 100644 index 0000000..1fc6f74 --- /dev/null +++ b/tools/indoor_converter.py @@ -0,0 +1,67 @@ +# Modified from mmdetection3d/tools/dataset_converters/indoor_converter.py +# We just support ScanNet 200. +import os + +import mmengine + +from scannet_data_utils import ScanNetData + + +def create_indoor_info_file(data_path, + pkl_prefix='sunrgbd', + save_path=None, + use_v1=False, + workers=4): + """Create indoor information file. + + Get information of the raw data and save it to the pkl file. + + Args: + data_path (str): Path of the data. + pkl_prefix (str, optional): Prefix of the pkl to be saved. + Default: 'sunrgbd'. + save_path (str, optional): Path of the pkl to be saved. Default: None. + use_v1 (bool, optional): Whether to use v1. Default: False. + workers (int, optional): Number of threads to be used. Default: 4. + """ + assert os.path.exists(data_path) + assert pkl_prefix in ['scannet', 'scannet200'], \ + f'unsupported indoor dataset {pkl_prefix}' + save_path = data_path if save_path is None else save_path + assert os.path.exists(save_path) + + # generate infos for both detection and segmentation task + train_filename = os.path.join( + save_path, f'{pkl_prefix}_oneformer3d_infos_train.pkl') + val_filename = os.path.join( + save_path, f'{pkl_prefix}_oneformer3d_infos_val.pkl') + test_filename = os.path.join( + save_path, f'{pkl_prefix}_oneformer3d_infos_test.pkl') + if pkl_prefix == 'scannet': + # ScanNet has a train-val-test split + train_dataset = ScanNetData(root_path=data_path, split='train') + val_dataset = ScanNetData(root_path=data_path, split='val') + test_dataset = ScanNetData(root_path=data_path, split='test') + else: # ScanNet200 + # ScanNet has a train-val-test split + train_dataset = ScanNetData(root_path=data_path, split='train', + scannet200=True, save_path=save_path) + val_dataset = ScanNetData(root_path=data_path, split='val', + scannet200=True, save_path=save_path) + test_dataset = ScanNetData(root_path=data_path, split='test', + scannet200=True, save_path=save_path) + + infos_train = train_dataset.get_infos( + num_workers=workers, has_label=True) + mmengine.dump(infos_train, train_filename, 'pkl') + print(f'{pkl_prefix} info train file is saved to {train_filename}') + + infos_val = val_dataset.get_infos( + num_workers=workers, has_label=True) + mmengine.dump(infos_val, val_filename, 'pkl') + print(f'{pkl_prefix} info val file is saved to {val_filename}') + + infos_test = test_dataset.get_infos( + num_workers=workers, has_label=False) + mmengine.dump(infos_test, test_filename, 'pkl') + print(f'{pkl_prefix} info test file is saved to {test_filename}') diff --git a/tools/scannet_data_utils.py b/tools/scannet_data_utils.py new file mode 100644 index 0000000..942b527 --- /dev/null +++ b/tools/scannet_data_utils.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from concurrent import futures as futures +from os import path as osp + +import mmengine +import numpy as np + + +class ScanNetData(object): + """ScanNet data. + Generate scannet infos for scannet_converter. 
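+
+    Besides points, boxes, and semantic / instance masks, per-point
+    superpoint labels are dumped to ``super_points/*.bin``
+    (see ``get_infos``).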
+ + Args: + root_path (str): Root path of the raw data. + split (str, optional): Set split type of the data. Default: 'train'. + scannet200 (bool): True for ScanNet200, else for ScanNet. + save_path (str, optional): Output directory. + """ + + def __init__(self, root_path, split='train', scannet200=False, save_path=None): + self.root_dir = root_path + self.save_path = root_path if save_path is None else save_path + self.split = split + self.split_dir = osp.join(root_path) + self.scannet200 = scannet200 + if self.scannet200: + self.classes = [ + 'chair', 'table', 'door', 'couch', 'cabinet', 'shelf', 'desk', + 'office chair', 'bed', 'pillow', 'sink', 'picture', 'window', + 'toilet', 'bookshelf', 'monitor', 'curtain', 'book', + 'armchair', 'coffee table', 'box', 'refrigerator', 'lamp', + 'kitchen cabinet', 'towel', 'clothes', 'tv', 'nightstand', + 'counter', 'dresser', 'stool', 'cushion', 'plant', 'ceiling', + 'bathtub', 'end table', 'dining table', 'keyboard', 'bag', + 'backpack', 'toilet paper', 'printer', 'tv stand', + 'whiteboard', 'blanket', 'shower curtain', 'trash can', + 'closet', 'stairs', 'microwave', 'stove', 'shoe', + 'computer tower', 'bottle', 'bin', 'ottoman', 'bench', 'board', + 'washing machine', 'mirror', 'copier', 'basket', 'sofa chair', + 'file cabinet', 'fan', 'laptop', 'shower', 'paper', 'person', + 'paper towel dispenser', 'oven', 'blinds', 'rack', 'plate', + 'blackboard', 'piano', 'suitcase', 'rail', 'radiator', + 'recycling bin', 'container', 'wardrobe', 'soap dispenser', + 'telephone', 'bucket', 'clock', 'stand', 'light', + 'laundry basket', 'pipe', 'clothes dryer', 'guitar', + 'toilet paper holder', 'seat', 'speaker', 'column', 'bicycle', + 'ladder', 'bathroom stall', 'shower wall', 'cup', 'jacket', + 'storage bin', 'coffee maker', 'dishwasher', + 'paper towel roll', 'machine', 'mat', 'windowsill', 'bar', + 'toaster', 'bulletin board', 'ironing board', 'fireplace', + 'soap dish', 'kitchen counter', 'doorframe', + 'toilet paper dispenser', 'mini fridge', 'fire extinguisher', + 'ball', 'hat', 'shower curtain rod', 'water cooler', + 'paper cutter', 'tray', 'shower door', 'pillar', 'ledge', + 'toaster oven', 'mouse', 'toilet seat cover dispenser', + 'furniture', 'cart', 'storage container', 'scale', + 'tissue box', 'light switch', 'crate', 'power outlet', + 'decoration', 'sign', 'projector', 'closet door', + 'vacuum cleaner', 'candle', 'plunger', 'stuffed animal', + 'headphones', 'dish rack', 'broom', 'guitar case', + 'range hood', 'dustpan', 'hair dryer', 'water bottle', + 'handicap bar', 'purse', 'vent', 'shower floor', + 'water pitcher', 'mailbox', 'bowl', 'paper bag', 'alarm clock', + 'music stand', 'projector screen', 'divider', + 'laundry detergent', 'bathroom counter', 'object', + 'bathroom vanity', 'closet wall', 'laundry hamper', + 'bathroom stall door', 'ceiling light', 'trash bin', + 'dumbbell', 'stair rail', 'tube', 'bathroom cabinet', + 'cd case', 'closet rod', 'coffee kettle', 'structure', + 'shower head', 'keyboard piano', 'case of water bottles', + 'coat rack', 'storage organizer', 'folded chair', 'fire alarm', + 'power strip', 'calendar', 'poster', 'potted plant', 'luggage', + 'mattress' + ] + self.cat_ids = np.array([ + 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, + 22, 23, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 38, 39, 40, + 41, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, + 59, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 82, 84, 86, 87, 88, 89, 90, 93, 95, 96, 97, 98, + 99, 100, 
101, 102, 103, 104, 105, 106, 107, 110, 112, 115, 116, + 118, 120, 121, 122, 125, 128, 130, 131, 132, 134, 136, 138, + 139, 140, 141, 145, 148, 154, 155, 156, 157, 159, 161, 163, + 165, 166, 168, 169, 170, 177, 180, 185, 188, 191, 193, 195, + 202, 208, 213, 214, 221, 229, 230, 232, 233, 242, 250, 261, + 264, 276, 283, 286, 300, 304, 312, 323, 325, 331, 342, 356, + 370, 392, 395, 399, 408, 417, 488, 540, 562, 570, 572, 581, + 609, 748, 776, 1156, 1163, 1164, 1165, 1166, 1167, 1168, 1169, + 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1178, 1179, 1180, + 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, + 1191 + ]) + else: + self.classes = [ + 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin' + ] + self.cat_ids = np.array([ + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39 + ]) + + self.cat2label = {cat: self.classes.index(cat) for cat in self.classes} + self.label2cat = {self.cat2label[t]: t for t in self.cat2label} + self.cat_ids2class = { + nyu40id: i + for i, nyu40id in enumerate(list(self.cat_ids)) + } + assert split in ['train', 'val', 'test'] + split_file = osp.join(self.root_dir, 'meta_data', + f'scannetv2_{split}.txt') + mmengine.check_file_exist(split_file) + self.sample_id_list = mmengine.list_from_file(split_file) + self.test_mode = (split == 'test') + + def __len__(self): + return len(self.sample_id_list) + + def get_aligned_box_label(self, idx): + box_file = osp.join(self.root_dir, 'scannet_instance_data', + f'{idx}_aligned_bbox.npy') + mmengine.check_file_exist(box_file) + return np.load(box_file) + + def get_unaligned_box_label(self, idx): + box_file = osp.join(self.root_dir, 'scannet_instance_data', + f'{idx}_unaligned_bbox.npy') + mmengine.check_file_exist(box_file) + return np.load(box_file) + + def get_axis_align_matrix(self, idx): + matrix_file = osp.join(self.root_dir, 'scannet_instance_data', + f'{idx}_axis_align_matrix.npy') + mmengine.check_file_exist(matrix_file) + return np.load(matrix_file) + + def get_images(self, idx): + paths = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.jpg'): + paths.append(osp.join('posed_images', idx, file)) + return paths + + def get_extrinsics(self, idx): + extrinsics = [] + path = osp.join(self.root_dir, 'posed_images', idx) + for file in sorted(os.listdir(path)): + if file.endswith('.txt') and not file == 'intrinsic.txt': + extrinsics.append(np.loadtxt(osp.join(path, file))) + return extrinsics + + def get_intrinsics(self, idx): + matrix_file = osp.join(self.root_dir, 'posed_images', idx, + 'intrinsic.txt') + mmengine.check_file_exist(matrix_file) + return np.loadtxt(matrix_file) + + def get_infos(self, num_workers=4, has_label=True, sample_id_list=None): + """Get data infos. + + This method gets information from the raw data. + + Args: + num_workers (int, optional): Number of threads to be used. + Default: 4. + has_label (bool, optional): Whether the data has label. + Default: True. + sample_id_list (list[int], optional): Index list of the sample. + Default: None. + + Returns: + infos (list[dict]): Information of the raw data. 
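+
+        Example (hypothetical call; the root path is a placeholder that
+        must already contain ``meta_data`` and ``scannet_instance_data``):
+
+            >>> data = ScanNetData(root_path='./data/scannet', split='val')
+            >>> infos = data.get_infos(num_workers=4, has_label=True)
+            >>> infos[0]['pts_path']  # e.g. 'points/scene0011_00.bin'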
+ """ + + def process_single_scene(sample_idx): + print(f'{self.split} sample_idx: {sample_idx}') + info = dict() + pc_info = {'num_features': 6, 'lidar_idx': sample_idx} + info['point_cloud'] = pc_info + pts_filename = osp.join(self.root_dir, 'scannet_instance_data', + f'{sample_idx}_vert.npy') + points = np.load(pts_filename) + mmengine.mkdir_or_exist(osp.join(self.save_path, 'points')) + points.tofile( + osp.join(self.save_path, 'points', f'{sample_idx}.bin')) + info['pts_path'] = osp.join('points', f'{sample_idx}.bin') + + sp_filename = osp.join(self.root_dir, 'scannet_instance_data', + f'{sample_idx}_sp_label.npy') + super_points = np.load(sp_filename) + mmengine.mkdir_or_exist(osp.join(self.save_path, 'super_points')) + super_points.tofile( + osp.join(self.save_path, 'super_points', f'{sample_idx}.bin')) + info['super_pts_path'] = osp.join('super_points', f'{sample_idx}.bin') + + # update with RGB image paths if exist + if os.path.exists(osp.join(self.root_dir, 'posed_images')): + info['intrinsics'] = self.get_intrinsics(sample_idx) + all_extrinsics = self.get_extrinsics(sample_idx) + all_img_paths = self.get_images(sample_idx) + # some poses in ScanNet are invalid + extrinsics, img_paths = [], [] + for extrinsic, img_path in zip(all_extrinsics, all_img_paths): + if np.all(np.isfinite(extrinsic)): + img_paths.append(img_path) + extrinsics.append(extrinsic) + info['extrinsics'] = extrinsics + info['img_paths'] = img_paths + + if not self.test_mode: + pts_instance_mask_path = osp.join( + self.root_dir, 'scannet_instance_data', + f'{sample_idx}_ins_label.npy') + pts_semantic_mask_path = osp.join( + self.root_dir, 'scannet_instance_data', + f'{sample_idx}_sem_label.npy') + + pts_instance_mask = np.load(pts_instance_mask_path).astype( + np.int64) + pts_semantic_mask = np.load(pts_semantic_mask_path).astype( + np.int64) + + mmengine.mkdir_or_exist( + osp.join(self.save_path, 'instance_mask')) + mmengine.mkdir_or_exist( + osp.join(self.save_path, 'semantic_mask')) + + pts_instance_mask.tofile( + osp.join(self.save_path, 'instance_mask', + f'{sample_idx}.bin')) + pts_semantic_mask.tofile( + osp.join(self.save_path, 'semantic_mask', + f'{sample_idx}.bin')) + + info['pts_instance_mask_path'] = osp.join( + 'instance_mask', f'{sample_idx}.bin') + info['pts_semantic_mask_path'] = osp.join( + 'semantic_mask', f'{sample_idx}.bin') + + if has_label: + annotations = {} + # box is of shape [k, 6 + class] + aligned_box_label = self.get_aligned_box_label(sample_idx) + unaligned_box_label = self.get_unaligned_box_label(sample_idx) + annotations['gt_num'] = aligned_box_label.shape[0] + if annotations['gt_num'] != 0: + aligned_box = aligned_box_label[:, :-1] # k, 6 + unaligned_box = unaligned_box_label[:, :-1] + classes = aligned_box_label[:, -1] # k + annotations['name'] = np.array([ + self.label2cat[self.cat_ids2class[classes[i]]] + for i in range(annotations['gt_num']) + ]) + # default names are given to aligned bbox for compatibility + # we also save unaligned bbox info with marked names + annotations['location'] = aligned_box[:, :3] + annotations['dimensions'] = aligned_box[:, 3:6] + annotations['gt_boxes_upright_depth'] = aligned_box + annotations['unaligned_location'] = unaligned_box[:, :3] + annotations['unaligned_dimensions'] = unaligned_box[:, 3:6] + annotations[ + 'unaligned_gt_boxes_upright_depth'] = unaligned_box + annotations['index'] = np.arange( + annotations['gt_num'], dtype=np.int32) + annotations['class'] = np.array([ + self.cat_ids2class[classes[i]] + for i in 
range(annotations['gt_num']) + ]) + axis_align_matrix = self.get_axis_align_matrix(sample_idx) + annotations['axis_align_matrix'] = axis_align_matrix # 4x4 + info['annos'] = annotations + return info + + sample_id_list = sample_id_list if sample_id_list is not None \ + else self.sample_id_list + with futures.ThreadPoolExecutor(num_workers) as executor: + infos = executor.map(process_single_scene, sample_id_list) + return list(infos) diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000..cee82d8 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,149 @@ +# This is an exact copy of tools/test.py from open-mmlab/mmdetection3d. +import argparse +import os +import os.path as osp + +from mmengine.config import Config, ConfigDict, DictAction +from mmengine.registry import RUNNERS +from mmengine.runner import Runner + +from mmdet3d.utils import replace_ceph_backend + + +# TODO: support fuse_conv_bn and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet3D test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--ceph', action='store_true', help='Use ceph as data storage backend') + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--score-thr', type=float, default=0.1, help='bbox score threshold') + parser.add_argument( + '--task', + type=str, + choices=[ + 'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg', + 'multi-modality_det' + ], + help='Determine the visualization method depending on the task.') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/test.py` instead + # of `--local_rank`. 
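+    # Registering both spellings keeps the script compatible with either
+    # launcher behaviour.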
+ parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def trigger_visualization_hook(cfg, args): + default_hooks = cfg.default_hooks + if 'visualization' in default_hooks: + visualization_hook = default_hooks['visualization'] + # Turn on visualization + visualization_hook['draw'] = True + if args.show: + visualization_hook['show'] = True + visualization_hook['wait_time'] = args.wait_time + if args.show_dir: + visualization_hook['test_out_dir'] = args.show_dir + all_task_choices = [ + 'mono_det', 'multi-view_det', 'lidar_det', 'lidar_seg', + 'multi-modality_det' + ] + assert args.task in all_task_choices, 'You must set '\ + f"'--task' in {all_task_choices} in the command " \ + 'if you want to use visualization hook' + visualization_hook['vis_task'] = args.task + visualization_hook['score_thr'] = args.score_thr + else: + raise RuntimeError( + 'VisualizationHook must be included in default_hooks.' + 'refer to usage ' + '"visualization=dict(type=\'VisualizationHook\')"') + + return cfg + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + + # TODO: We will unify the ceph support approach with other OpenMMLab repos + if args.ceph: + cfg = replace_ceph_backend(cfg) + + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + + if args.tta: + # Currently, we only support tta for 3D segmentation + # TODO: Support tta for 3D detection + assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' + assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` in config.' + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) + + # build the runner from config + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000..dd904ed --- /dev/null +++ b/tools/train.py @@ -0,0 +1,135 @@ +# This is an exact copy of tools/train.py from open-mmlab/mmdetection3d. 
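+# Typical invocation (sketch; config and work-dir paths are placeholders):
+#   python tools/train.py configs/<config>.py --work-dir work_dirs/<exp>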
+import argparse +import logging +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.logging import print_log +from mmengine.registry import RUNNERS +from mmengine.runner import Runner + +from mmdet3d.utils import replace_ceph_backend + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a 3D detector') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--amp', + action='store_true', + default=False, + help='enable automatic-mixed-precision training') + parser.add_argument( + '--auto-scale-lr', + action='store_true', + help='enable automatically scaling LR.') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpoint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') + parser.add_argument( + '--ceph', action='store_true', help='Use ceph as data storage backend') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. 
+ parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + + # TODO: We will unify the ceph support approach with other OpenMMLab repos + if args.ceph: + cfg = replace_ceph_backend(cfg) + + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.type + if optim_wrapper == 'AmpOptimWrapper': + print_log( + 'AMP training is already enabled in your config.', + logger='current', + level=logging.WARNING) + else: + assert optim_wrapper == 'OptimWrapper', ( + '`--amp` is only supported when the optimizer wrapper type is ' + f'`OptimWrapper` but got {optim_wrapper}.') + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.loss_scale = 'dynamic' + + # enable automatically scaling LR + if args.auto_scale_lr: + if 'auto_scale_lr' in cfg and \ + 'enable' in cfg.auto_scale_lr and \ + 'base_batch_size' in cfg.auto_scale_lr: + cfg.auto_scale_lr.enable = True + else: + raise RuntimeError('Can not find "auto_scale_lr" or ' + '"auto_scale_lr.enable" or ' + '"auto_scale_lr.base_batch_size" in your' + ' configuration file.') + + # resume is determined in this priority: resume from > auto_resume + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() diff --git a/tools/update_infos_to_v2.py b/tools/update_infos_to_v2.py new file mode 100644 index 0000000..00722d7 --- /dev/null +++ b/tools/update_infos_to_v2.py @@ -0,0 +1,417 @@ +# Modified from mmdetection3d/tools/dataset_converters /update_infos_to_v2.py +"""Convert the annotation pkl to the standard format in OpenMMLab V2.0. + +Example: + python tools/dataset_converters/update_infos_to_v2.py + --dataset kitti + --pkl-path ./data/kitti/kitti_infos_train.pkl + --out-dir ./kitti_v2/ +""" + +import argparse +import time +from os import path as osp +from pathlib import Path + +import mmengine + +def get_empty_instance(): + """Empty annotation for single instance.""" + instance = dict( + # (list[float], required): list of 4 numbers representing + # the bounding box of the instance, in (x1, y1, x2, y2) order. + bbox=None, + # (int, required): an integer in the range + # [0, num_categories-1] representing the category label. + bbox_label=None, + # (list[float], optional): list of 7 (or 9) numbers representing + # the 3D bounding box of the instance, + # in [x, y, z, w, h, l, yaw] + # (or [x, y, z, w, h, l, yaw, vx, vy]) order. 
+        bbox_3d=None,
+        # (bool, optional): Whether to use the
+        # 3D bounding box during training.
+        bbox_3d_isvalid=None,
+        # (int, optional): 3D category label
+        # (typically the same as label).
+        bbox_label_3d=None,
+        # (float, optional): Projected center depth of the
+        # 3D bounding box compared to the image plane.
+        depth=None,
+        # (list[float], optional): Projected
+        # 2D center of the 3D bounding box.
+        center_2d=None,
+        # (int, optional): Attribute labels
+        # (fine-grained labels such as stopping, moving, ignore, crowd).
+        attr_label=None,
+        # (int, optional): The number of LiDAR
+        # points in the 3D bounding box.
+        num_lidar_pts=None,
+        # (int, optional): The number of Radar
+        # points in the 3D bounding box.
+        num_radar_pts=None,
+        # (int, optional): Difficulty level of
+        # detecting the 3D bounding box.
+        difficulty=None,
+        unaligned_bbox_3d=None)
+    return instance
+
+def get_empty_lidar_points():
+    lidar_points = dict(
+        # (int, optional) : Number of features for each point.
+        num_pts_feats=None,
+        # (str, optional): Path of LiDAR data file.
+        lidar_path=None,
+        # (list[list[float]], optional): Transformation matrix
+        # from lidar to ego-vehicle
+        # with shape [4, 4].
+        # (Referenced camera coordinate system is ego in KITTI.)
+        lidar2ego=None,
+    )
+    return lidar_points
+
+
+def get_empty_radar_points():
+    radar_points = dict(
+        # (int, optional) : Number of features for each point.
+        num_pts_feats=None,
+        # (str, optional): Path of RADAR data file.
+        radar_path=None,
+        # Transformation matrix from lidar to
+        # ego-vehicle with shape [4, 4].
+        # (Referenced camera coordinate system is ego in KITTI.)
+        radar2ego=None,
+    )
+    return radar_points
+
+def get_empty_img_info():
+    img_info = dict(
+        # (str, required): the path to the image file.
+        img_path=None,
+        # (int) The height of the image.
+        height=None,
+        # (int) The width of the image.
+        width=None,
+        # (str, optional): Path of the depth map file
+        depth_map=None,
+        # (list[list[float]], optional) : Transformation
+        # matrix from camera to image with
+        # shape [3, 3], [3, 4] or [4, 4].
+        cam2img=None,
+        # (list[list[float]]): Transformation matrix from lidar
+        # or depth to image with shape [4, 4].
+        lidar2img=None,
+        # (list[list[float]], optional) : Transformation
+        # matrix from camera to ego-vehicle
+        # with shape [4, 4].
+        cam2ego=None)
+    return img_info
+
+def get_single_image_sweep(camera_types):
+    single_image_sweep = dict(
+        # (float, optional) : Timestamp of the current frame.
+        timestamp=None,
+        # (list[list[float]], optional) : Transformation matrix
+        # from ego-vehicle to the global
+        ego2global=None)
+    # (dict): Information of images captured by multiple cameras
+    images = dict()
+    for cam_type in camera_types:
+        images[cam_type] = get_empty_img_info()
+    single_image_sweep['images'] = images
+    return single_image_sweep
+
+def get_empty_standard_data_info(
+        camera_types=['CAM0', 'CAM1', 'CAM2', 'CAM3', 'CAM4']):
+
+    data_info = dict(
+        # (str): Sample id of the frame.
+        sample_idx=None,
+        # (str, optional): '000010'
+        token=None,
+        **get_single_image_sweep(camera_types),
+        # (dict, optional): dict contains information
+        # of LiDAR point cloud frame.
+        lidar_points=get_empty_lidar_points(),
+        # (dict, optional) Each dict contains
+        # information of Radar point cloud frame.
+        radar_points=get_empty_radar_points(),
+        # (list[dict], optional): Image sweeps data.
+        image_sweeps=[],
+        lidar_sweeps=[],
+        instances=[],
+        # (list[dict], optional): Required by object
+        # detection, instance to be ignored during training.
+        instances_ignore=[],
+        # (str, optional): Path of semantic labels for each point.
+        pts_semantic_mask_path=None,
+        # (str, optional): Path of instance labels for each point.
+        pts_instance_mask_path=None)
+    return data_info
+
+
+def clear_instance_unused_keys(instance):
+    keys = list(instance.keys())
+    for k in keys:
+        if instance[k] is None:
+            del instance[k]
+    return instance
+
+
+def clear_data_info_unused_keys(data_info):
+    keys = list(data_info.keys())
+    empty_flag = True
+    for key in keys:
+        # we allow no annotations in datainfo
+        if key in ['instances', 'cam_sync_instances', 'cam_instances']:
+            empty_flag = False
+            continue
+        if isinstance(data_info[key], list):
+            if len(data_info[key]) == 0:
+                del data_info[key]
+            else:
+                empty_flag = False
+        elif data_info[key] is None:
+            del data_info[key]
+        elif isinstance(data_info[key], dict):
+            _, sub_empty_flag = clear_data_info_unused_keys(data_info[key])
+            if sub_empty_flag is False:
+                empty_flag = False
+            else:
+                # sub field is empty
+                del data_info[key]
+        else:
+            empty_flag = False
+
+    return data_info, empty_flag
+
+def update_scannet_infos(pkl_path, out_dir):
+    print(f'{pkl_path} will be modified.')
+    if out_dir in pkl_path:
+        print(f'Warning, you may be overwriting '
+              f'the original data {pkl_path}.')
+        time.sleep(5)
+    METAINFO = {
+        'classes':
+        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
+         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
+         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin')
+    }
+    print(f'Reading from input file: {pkl_path}.')
+    data_list = mmengine.load(pkl_path)
+    print('Start updating:')
+    converted_list = []
+    for ori_info_dict in mmengine.track_iter_progress(data_list):
+        temp_data_info = get_empty_standard_data_info()
+        temp_data_info['lidar_points']['num_pts_feats'] = ori_info_dict[
+            'point_cloud']['num_features']
+        temp_data_info['lidar_points']['lidar_path'] = Path(
+            ori_info_dict['pts_path']).name
+        if 'pts_semantic_mask_path' in ori_info_dict:
+            temp_data_info['pts_semantic_mask_path'] = Path(
+                ori_info_dict['pts_semantic_mask_path']).name
+        if 'pts_instance_mask_path' in ori_info_dict:
+            temp_data_info['pts_instance_mask_path'] = Path(
+                ori_info_dict['pts_instance_mask_path']).name
+        if 'super_pts_path' in ori_info_dict:
+            temp_data_info['super_pts_path'] = Path(
+                ori_info_dict['super_pts_path']).name
+
+        # TODO support camera
+        # np.linalg.inv(info['axis_align_matrix'] @ extrinsic): depth2cam
+        anns = ori_info_dict.get('annos', None)
+        ignore_class_name = set()
+        if anns is not None:
+            temp_data_info['axis_align_matrix'] = anns[
+                'axis_align_matrix'].tolist()
+            if anns['gt_num'] == 0:
+                instance_list = []
+            else:
+                num_instances = len(anns['name'])
+                instance_list = []
+                for instance_id in range(num_instances):
+                    empty_instance = get_empty_instance()
+                    empty_instance['bbox_3d'] = anns['gt_boxes_upright_depth'][
+                        instance_id].tolist()
+
+                    if anns['name'][instance_id] in METAINFO['classes']:
+                        empty_instance['bbox_label_3d'] = METAINFO[
+                            'classes'].index(anns['name'][instance_id])
+                    else:
+                        ignore_class_name.add(anns['name'][instance_id])
+                        empty_instance['bbox_label_3d'] = -1
+
+                    empty_instance = clear_instance_unused_keys(empty_instance)
+                    instance_list.append(empty_instance)
+            temp_data_info['instances'] = instance_list
+        temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
+        converted_list.append(temp_data_info)
+    pkl_name = Path(pkl_path).name
+    out_path = osp.join(out_dir, pkl_name)
+    print(f'Writing to output file: {out_path}.')
+    print(f'ignore classes: {ignore_class_name}')
+
+    # dataset metainfo
+    metainfo = dict()
+    metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])}
+    if ignore_class_name:
+        for ignore_class in ignore_class_name:
+            metainfo['categories'][ignore_class] = -1
+    metainfo['dataset'] = 'scannet'
+    metainfo['info_version'] = '1.1'
+
+    converted_data_info = dict(metainfo=metainfo, data_list=converted_list)
+
+    mmengine.dump(converted_data_info, out_path, 'pkl')
+
+def update_scannet200_infos(pkl_path, out_dir):
+    print(f'{pkl_path} will be modified.')
+    if out_dir in pkl_path:
+        print(f'Warning, you may be overwriting '
+              f'the original data {pkl_path}.')
+        time.sleep(5)
+    METAINFO = {
+        'classes':
+        ('chair', 'table', 'door', 'couch', 'cabinet', 'shelf', 'desk',
+         'office chair', 'bed', 'pillow', 'sink', 'picture', 'window',
+         'toilet', 'bookshelf', 'monitor', 'curtain', 'book', 'armchair',
+         'coffee table', 'box', 'refrigerator', 'lamp', 'kitchen cabinet',
+         'towel', 'clothes', 'tv', 'nightstand', 'counter', 'dresser', 'stool',
+         'cushion', 'plant', 'ceiling', 'bathtub', 'end table', 'dining table',
+         'keyboard', 'bag', 'backpack', 'toilet paper', 'printer', 'tv stand',
+         'whiteboard', 'blanket', 'shower curtain', 'trash can', 'closet',
+         'stairs', 'microwave', 'stove', 'shoe', 'computer tower', 'bottle',
+         'bin', 'ottoman', 'bench', 'board', 'washing machine', 'mirror',
+         'copier', 'basket', 'sofa chair', 'file cabinet', 'fan', 'laptop',
+         'shower', 'paper', 'person', 'paper towel dispenser', 'oven',
+         'blinds', 'rack', 'plate', 'blackboard', 'piano', 'suitcase', 'rail',
+         'radiator', 'recycling bin', 'container', 'wardrobe',
+         'soap dispenser', 'telephone', 'bucket', 'clock', 'stand', 'light',
+         'laundry basket', 'pipe', 'clothes dryer', 'guitar',
+         'toilet paper holder', 'seat', 'speaker', 'column', 'bicycle',
+         'ladder', 'bathroom stall', 'shower wall', 'cup', 'jacket',
+         'storage bin', 'coffee maker', 'dishwasher', 'paper towel roll',
+         'machine', 'mat', 'windowsill', 'bar', 'toaster', 'bulletin board',
+         'ironing board', 'fireplace', 'soap dish', 'kitchen counter',
+         'doorframe', 'toilet paper dispenser', 'mini fridge',
+         'fire extinguisher', 'ball', 'hat', 'shower curtain rod',
+         'water cooler', 'paper cutter', 'tray', 'shower door', 'pillar',
+         'ledge', 'toaster oven', 'mouse', 'toilet seat cover dispenser',
+         'furniture', 'cart', 'storage container', 'scale', 'tissue box',
+         'light switch', 'crate', 'power outlet', 'decoration', 'sign',
+         'projector', 'closet door', 'vacuum cleaner', 'candle', 'plunger',
+         'stuffed animal', 'headphones', 'dish rack', 'broom', 'guitar case',
+         'range hood', 'dustpan', 'hair dryer', 'water bottle', 'handicap bar',
+         'purse', 'vent', 'shower floor', 'water pitcher', 'mailbox', 'bowl',
+         'paper bag', 'alarm clock', 'music stand', 'projector screen',
+         'divider', 'laundry detergent', 'bathroom counter', 'object',
+         'bathroom vanity', 'closet wall', 'laundry hamper',
+         'bathroom stall door', 'ceiling light', 'trash bin', 'dumbbell',
+         'stair rail', 'tube', 'bathroom cabinet', 'cd case', 'closet rod',
+         'coffee kettle', 'structure', 'shower head', 'keyboard piano',
+         'case of water bottles', 'coat rack', 'storage organizer',
+         'folded chair', 'fire alarm', 'power strip', 'calendar', 'poster',
+         'potted plant', 'luggage', 'mattress')
+    }
+    print(f'Reading from input file: {pkl_path}.')
+    data_list = mmengine.load(pkl_path)
+    print('Start updating:')
+    converted_list = []
+    for ori_info_dict in mmengine.track_iter_progress(data_list):
+        temp_data_info = get_empty_standard_data_info()
+        temp_data_info['lidar_points']['num_pts_feats'] = ori_info_dict[
+            'point_cloud']['num_features']
+        temp_data_info['lidar_points']['lidar_path'] = Path(
+            ori_info_dict['pts_path']).name
+        if 'pts_semantic_mask_path' in ori_info_dict:
+            temp_data_info['pts_semantic_mask_path'] = Path(
+                ori_info_dict['pts_semantic_mask_path']).name
+        if 'pts_instance_mask_path' in ori_info_dict:
+            temp_data_info['pts_instance_mask_path'] = Path(
+                ori_info_dict['pts_instance_mask_path']).name
+        if 'super_pts_path' in ori_info_dict:
+            temp_data_info['super_pts_path'] = Path(
+                ori_info_dict['super_pts_path']).name
+
+        # TODO support camera
+        # np.linalg.inv(info['axis_align_matrix'] @ extrinsic): depth2cam
+        anns = ori_info_dict.get('annos', None)
+        ignore_class_name = set()
+        if anns is not None:
+            temp_data_info['axis_align_matrix'] = anns[
+                'axis_align_matrix'].tolist()
+            if anns['gt_num'] == 0:
+                instance_list = []
+            else:
+                num_instances = len(anns['name'])
+                instance_list = []
+                for instance_id in range(num_instances):
+                    empty_instance = get_empty_instance()
+                    empty_instance['bbox_3d'] = anns['gt_boxes_upright_depth'][
+                        instance_id].tolist()
+
+                    if anns['name'][instance_id] in METAINFO['classes']:
+                        empty_instance['bbox_label_3d'] = METAINFO[
+                            'classes'].index(anns['name'][instance_id])
+                    else:
+                        ignore_class_name.add(anns['name'][instance_id])
+                        empty_instance['bbox_label_3d'] = -1
+
+                    empty_instance = clear_instance_unused_keys(empty_instance)
+                    instance_list.append(empty_instance)
+            temp_data_info['instances'] = instance_list
+        temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
+        converted_list.append(temp_data_info)
+    pkl_name = Path(pkl_path).name
+    out_path = osp.join(out_dir, pkl_name)
+    print(f'Writing to output file: {out_path}.')
+    print(f'ignore classes: {ignore_class_name}')
+
+    # dataset metainfo
+    metainfo = dict()
+    metainfo['categories'] = {k: i for i, k in enumerate(METAINFO['classes'])}
+    if ignore_class_name:
+        for ignore_class in ignore_class_name:
+            metainfo['categories'][ignore_class] = -1
+    metainfo['dataset'] = 'scannet200'
+    metainfo['info_version'] = '1.1'
+
+    converted_data_info = dict(metainfo=metainfo, data_list=converted_list)
+
+    mmengine.dump(converted_data_info, out_path, 'pkl')
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert info pkl files to the OpenMMLab V2.0 format.')
+    parser.add_argument(
+        '--dataset', type=str, default='kitti', help='name of dataset')
+    parser.add_argument(
+        '--pkl-path',
+        type=str,
+        default='./data/kitti/kitti_infos_train.pkl',
+        help='specify the path of the info pkl to be converted')
+    parser.add_argument(
+        '--out-dir',
+        type=str,
+        default='converted_annotations',
+        required=False,
+        help='output directory of the converted info pkl')
+    args = parser.parse_args()
+    return args
+
+
+def update_pkl_infos(dataset, out_dir, pkl_path):
+    if dataset.lower() == 'scannet':
+        update_scannet_infos(pkl_path=pkl_path, out_dir=out_dir)
+    elif dataset.lower() == 'scannet200':
+        update_scannet200_infos(pkl_path=pkl_path, out_dir=out_dir)
+    else:
+        raise NotImplementedError(f'Converting {dataset} to v2 is not supported.')
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    # `--out-dir` has a default value, so no extra fallback is needed
+    # before calling the converter.
+    update_pkl_infos(
+        dataset=args.dataset, out_dir=args.out_dir, pkl_path=args.pkl_path)
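
As a quick sanity check after running the converter above, the file it writes (a dict with 'metainfo' and 'data_list', as produced by mmengine.dump) can be inspected with a few lines of Python. This is a minimal illustrative sketch, not part of the patch; the pkl name and output directory below are assumptions and should be replaced with whatever --pkl-path and --out-dir were actually used.

    # inspect_converted_pkl.py -- illustrative sketch only
    import mmengine

    # assumed location: the default --out-dir plus the original pkl name
    info = mmengine.load('converted_annotations/scannet_infos_train.pkl')

    print(info['metainfo']['dataset'])           # 'scannet' or 'scannet200'
    print(info['metainfo']['info_version'])      # '1.1'
    print(len(info['metainfo']['categories']))   # class name -> label id map

    sample = info['data_list'][0]
    print(sample['lidar_points']['lidar_path'])  # point cloud file name
    print(sample.get('pts_semantic_mask_path'))  # per-point semantic labels
    print(sample.get('pts_instance_mask_path'))  # per-point instance labels
    print(len(sample.get('instances', [])))      # number of annotated 3D boxes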