diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..719cb0bf4d63f666856bae3bb3551d44c2d5fcec
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+data/ikun/reference_images/wt.jpg filter=lfs diff=lfs merge=lfs -text
+data/motorbike/reference_images/pink-motor.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..be6814a6026ea3c342bd0fe0cc35c01ba63cc466
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,165 @@
+checkpoints/
+experiments/
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+.DS_Store
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..74a6f299d51d145387ae5ab220e9730719614111
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 hysts
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Make-A-Protagonist/.gitignore b/Make-A-Protagonist/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b69727e7e4ffae8230f962072f430efebea7a5a6
--- /dev/null
+++ b/Make-A-Protagonist/.gitignore
@@ -0,0 +1,5 @@
+outputs/
+data_images/
+__pycache__
+.idea/
+.DS_Store
diff --git a/Make-A-Protagonist/LICENSE b/Make-A-Protagonist/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..897920c676c0d4a53071d717da9ada09eeee671f
--- /dev/null
+++ b/Make-A-Protagonist/LICENSE
@@ -0,0 +1,191 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2023 Yuyang Zhao
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.he Software.
+
diff --git a/Make-A-Protagonist/README.md b/Make-A-Protagonist/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b79d8a133d8df68a4d8f26e0cc66debd3e26881
--- /dev/null
+++ b/Make-A-Protagonist/README.md
@@ -0,0 +1,191 @@
+# Make-A-Protagonist
+
+This repository is the official implementation of **Make-A-Protagonist**.
+
+**[Make-A-Protagonist: Generic Video Editing with An Ensemble of Experts](https://arxiv.org/abs/2305.08850)**
+
+[Yuyang Zhao](https://yuyangzhao.com), [Enze Xie](https://xieenze.github.io/), [Lanqing Hong](https://scholar.google.com.sg/citations?user=2p7x6OUAAAAJ&hl=en), [Zhenguo Li](https://scholar.google.com.sg/citations?user=XboZC1AAAAAJ&hl=en), [Gim Hee Lee](https://www.comp.nus.edu.sg/~leegh/)
+
+
+[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://make-a-protagonist.github.io/) [![arXiv](https://img.shields.io/badge/arXiv-2305.08850-b31b1b.svg)](https://arxiv.org/abs/2305.08850)
+
+
+
+
+
+The first framework for generic video editing with both visual and textual clues.
+
+
+
+## Abstract
+> The text-driven image and video diffusion models have achieved unprecedented success in generating realistic and diverse content. Recently, the editing and variation of existing images and videos in diffusion-based generative models have garnered significant attention. However, previous works are limited to editing content with text or providing coarse personalization using a single visual clue, rendering them unsuitable for indescribable content that requires fine-grained and detailed control. In this regard, we propose a generic video editing framework called Make-A-Protagonist, which utilizes textual and visual clues to edit videos with the goal of empowering individuals to become the protagonists. Specifically, we leverage multiple experts to parse source video, target visual and textual clues, and propose a visual-textual-based video generation model that employs mask-guided denoising sampling to generate the desired output. Extensive results demonstrate the versatile and remarkable editing capabilities of Make-A-Protagonist.
+
+## News
+- [16/05/2023] Code released!
+
+### Todo
+- [ ] Release training code for ControlNet UnCLIP Small
+- [ ] Release inference demo
+
+
+## Setup
+
+### Requirements
+- Python 3.9 and Pytorch 1.13.1
+- xformers 0.0.17
+- Other packages in `requirements.txt`
+- Build GroundedSAM expert
+```bash
+cd experts/GroundedSAM
+python -m pip install -e GroundingDINO
+python -m pip install -e segment_anything
+```
+
+### Weights
+
+The following weights from HuggingFace are used in this project. You can download them into `checkpoints` or load them from HuggingFace repo.
+- [Stable Diffusion UnCLIP Small](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip-small)
+- [BLIP-2 Flan T5-xL](https://huggingface.co/Salesforce/blip2-flan-t5-xl)
+- [CLIP ViT-L](https://huggingface.co/openai/clip-vit-large-patch14)
+- [DALL-E 2 Prior](https://huggingface.co/kakaobrain/karlo-v1-alpha)
+
+ControlNet for Stable Diffusion UnCLIP Small should be downloaded manually into `checkpoints`:
+- [ControlNet UnCLIP Small](https://huggingface.co/Make-A-Protagonist/Make-A-Protagonist/tree/main)
+
+The code for training these models will be released soon.
+
+Pre-trained model for other experts should be downloaded manually into `checkpoints`:
+- [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) `wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth`
+- [Segment Anything](https://github.com/facebookresearch/segment-anything) `wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth`
+- [XMem](https://github.com/hkchengrex/XMem) `wget https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth`
+
+
+
+## Usage
+
+### Data Preprocess
+
+#### Source Video Parsing
+
+**Captioning and VQA**:
+```bash
+python experts/blip_inference.py -d data//images
+```
+
+**Protagonist Segmentation**:
+
+- Frame segmentation with GroundedSAM
+```bash
+python experts/grounded_sam_inference.py -d data//images/0000.jpg -t
+```
+
+- Video object segmentation through the video
+```bash
+python experts/xmem_inference.py -d data//images -v --mask_dir .mask
+```
+
+**Control Signals Extraction**:
+```bash
+python experts/controlnet_signal_extraction.py -d data//images -c
+```
+Currently we only support two types of control signals: depth and openposefull.
+
+#### Visual Clue Parsing
+
+**Reference Protagonist Segmentation**:
+```bash
+python experts/grounded_sam_inference.py -d data//reference_images -t --masked_out
+```
+
+### Training
+
+To fine-tune the text-to-image diffusion models with visual and textual clues, run this command:
+
+```bash
+python train.py --config="configs//train.yaml"
+```
+
+Note: At least 24 GB is requires to train the model.
+
+### Inference
+
+Once the training is done, run inference:
+
+```bash
+python eval.py --config="configs//eval.yaml"
+```
+**Applications**: Three applications are supported by Make-A-Protagonist, which can be achieved by modifying the inference configuration file.
+- Protagonist Editing: `source_protagonist: true`
+- Background Editing: `source_background: true`
+- Text-to-Video Editing with Protagonist: `source_protagonist: false & source_background: false`
+
+## Results
+
+
+
+ Input Video |
+ Reference Image |
+ Generated Video |
+
+
+ |
+ |
+ |
+
+
+ "A man walking down the street" |
+ |
+ "A panda walking down the snowy street" |
+
+
+
+ |
+ |
+ |
+
+
+ "A man playing basketball" |
+ |
+ "A man playing basketball on the beach, anime style" |
+
+
+
+ |
+ |
+ |
+
+
+ "A man walking down the street" |
+ |
+ "Elon Musk walking down the street" |
+
+
+
+ |
+ |
+ |
+
+
+ "A Suzuki Jimny driving down a mountain road" |
+ |
+ "A Suzuki Jimny driving down a mountain road in the rain" |
+
+
+
+
+
+
+## Citation
+If you make use of our work, please cite our paper.
+```bibtex
+@article{zhao2023makeaprotagonist,
+ title={Make-A-Protagonist: Generic Video Editing with An Ensemble of Experts},
+ author={Zhao, Yuyang and Xie, Enze and Hong, Lanqing and Li, Zhenguo and Lee, Gim Hee},
+ journal={arXiv preprint arXiv:2305.08850},
+ year={2023}
+}
+```
+
+## Acknowledgements
+
+This code is heavily derived from [diffusers](https://github.com/huggingface/diffusers) and [Tune-A-Video](https://github.com/showlab/Tune-A-Video). If you use this code in your research, please also acknowledge their work.
diff --git a/Make-A-Protagonist/configs/car-turn.yaml b/Make-A-Protagonist/configs/car-turn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6edaedfc31da4d4b759182a105d4723d72004510
--- /dev/null
+++ b/Make-A-Protagonist/configs/car-turn.yaml
@@ -0,0 +1,13 @@
+video_dir: "data/car-turn"
+prompt: "a suzuki jimny driving down a mountain road"
+n_sample_frames: 8
+width: 768
+height: 768
+sample_start_idx: 0
+sample_frame_rate: 1
+condition: [openposefull, depth]
+video_suffix: .jpg
+condition_suffix: .png
+noise_level: 10000
+image_embed_drop: 0.1
+mask_dir: suzuki-jimny.mask
diff --git a/Make-A-Protagonist/configs/car-turn/eval.yaml b/Make-A-Protagonist/configs/car-turn/eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d086a89e9a2b1e10d8a8ee5d27361b2beeb25590
--- /dev/null
+++ b/Make-A-Protagonist/configs/car-turn/eval.yaml
@@ -0,0 +1,63 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-car-turn"
+resume_from_checkpoint: "outputs/car-turn/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/car-turn"
+ prompt: "a suzuki jimny driving down a mountain road"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: suzuki-jimny.mask
+
+validation_data:
+ prompts:
+ - "a suzuki jimny driving down a mountain road in the rain"
+ - "a suzuki jimny driving down a mountain road in the rain"
+
+ ref_images:
+ - "data/car-turn/images/0000.jpg"
+ - "data/car-turn/images/0000.jpg"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: true # using source protagonist and changing the background
+ controlnet_conditioning_scale: 1.0
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/car-turn/train.yaml b/Make-A-Protagonist/configs/car-turn/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bb843828b6c3d50370a927533842a40062a8a1f2
--- /dev/null
+++ b/Make-A-Protagonist/configs/car-turn/train.yaml
@@ -0,0 +1,62 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/car-turn"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/car-turn"
+ prompt: "a suzuki jimny driving down a mountain road"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: suzuki-jimny.mask
+
+validation_data:
+ prompts:
+ - "a suzuki jimny driving down a mountain road in the rain"
+ - "a suzuki jimny driving down a mountain road in the rain"
+
+ ref_images:
+ - "data/car-turn/images/0000.jpg"
+ - "data/car-turn/images/0000.jpg"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: true # using source protagonist and changing the background
+ controlnet_conditioning_scale: 1.0
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 200
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/huaqiang.yaml b/Make-A-Protagonist/configs/huaqiang.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4115ff54d094583f3fd43591b56f1257e3a0075a
--- /dev/null
+++ b/Make-A-Protagonist/configs/huaqiang.yaml
@@ -0,0 +1,13 @@
+video_dir: "data/huaqiang"
+prompt: "a man walking down the street"
+n_sample_frames: 8
+width: 768
+height: 768
+sample_start_idx: 0
+sample_frame_rate: 1
+condition: [openposefull, depth]
+video_suffix: .jpg
+condition_suffix: .png
+noise_level: 10000
+image_embed_drop: 0.1
+mask_dir: man.mask
diff --git a/Make-A-Protagonist/configs/huaqiang/eval.yaml b/Make-A-Protagonist/configs/huaqiang/eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5c395a8c161c6c3df52fad78f5d2f5f6f9700680
--- /dev/null
+++ b/Make-A-Protagonist/configs/huaqiang/eval.yaml
@@ -0,0 +1,61 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-huaqiang"
+resume_from_checkpoint: "outputs/huaqiang/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/huaqiang"
+ prompt: "a man walking down the street"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "elon musk walking down the street"
+
+ ref_images:
+ - "data/huaqiang/masked_musk.png"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: true # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/huaqiang/train.yaml b/Make-A-Protagonist/configs/huaqiang/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7fe20f764a0d62515b43dcf9904740c40836a260
--- /dev/null
+++ b/Make-A-Protagonist/configs/huaqiang/train.yaml
@@ -0,0 +1,60 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/huaqiang"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/huaqiang"
+ prompt: "a man walking down the street"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "elon musk walking down the street"
+
+ ref_images:
+ - "data/huaqiang/masked_musk.png"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: true # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 200
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/ikun.yaml b/Make-A-Protagonist/configs/ikun.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb579532189c2635adaaf57890e6c8a4509235d2
--- /dev/null
+++ b/Make-A-Protagonist/configs/ikun.yaml
@@ -0,0 +1,14 @@
+video_dir: "data/ikun"
+prompt: "A man is playing basketball"
+n_sample_frames: 8
+width: 768
+height: 768
+sample_start_idx: 0
+sample_frame_rate: 1
+condition: [openposefull, depth]
+video_suffix: .jpg
+condition_suffix: .png
+noise_level: 10000
+image_embed_drop: 0.1
+mask_dir: man.mask
+
diff --git a/Make-A-Protagonist/configs/ikun/eval-background.yaml b/Make-A-Protagonist/configs/ikun/eval-background.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6a287b6938f9f4790c549759f25c7b10bb7163d8
--- /dev/null
+++ b/Make-A-Protagonist/configs/ikun/eval-background.yaml
@@ -0,0 +1,66 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-ikun-background"
+resume_from_checkpoint: "outputs/ikun/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/ikun"
+ prompt: "A man is playing basketball"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "A man is dribbling a basketball in the forest"
+ - "A man is dribbling a basketball in the forest"
+
+ ref_images:
+ - "data/ikun/images/0000.jpg"
+ - "data/ikun/images/0000.jpg"
+
+
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: true # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/ikun/eval-both.yaml b/Make-A-Protagonist/configs/ikun/eval-both.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4a79ea839bf2697261e693f90e5cabdd2cacb8b
--- /dev/null
+++ b/Make-A-Protagonist/configs/ikun/eval-both.yaml
@@ -0,0 +1,68 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-ikun-both"
+resume_from_checkpoint: "outputs/ikun/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/ikun"
+ prompt: "A man is playing basketball"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "A man is playing a basketball on the beach, anime style"
+ - "A man is playing a basketball on the beach, anime style"
+
+
+ ref_images:
+ - "data/ikun/masked_zhongli.png"
+ - "data/ikun/masked_zhongli.png"
+
+
+
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/ikun/eval-protagonist.yaml b/Make-A-Protagonist/configs/ikun/eval-protagonist.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..54d99c74c12313a96e6fa2d47813f08750c263f9
--- /dev/null
+++ b/Make-A-Protagonist/configs/ikun/eval-protagonist.yaml
@@ -0,0 +1,62 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-ikun-protagonist"
+resume_from_checkpoint: "outputs/ikun/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/ikun"
+ prompt: "A man is playing basketball"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "A man is playing basketball"
+
+ ref_images:
+ - "data/ikun/masked_wt.png"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: true # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/ikun/train.yaml b/Make-A-Protagonist/configs/ikun/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b4fbfb46916740021ccdb5ee0fc8b3910bd50a0f
--- /dev/null
+++ b/Make-A-Protagonist/configs/ikun/train.yaml
@@ -0,0 +1,64 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/ikun"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+train_data:
+ video_dir: "data/ikun"
+ prompt: "A man is playing basketball"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "A man is playing a basketball on the beach, anime style"
+ - "A man is playing a basketball on the beach, anime style"
+
+ ref_images:
+ - "data/ikun/masked_zhongli.png"
+ - "data/ikun/masked_zhongli.png"
+
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 200
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/yanzi.yaml b/Make-A-Protagonist/configs/yanzi.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..889b422e12889932d09fe60a170c6525e813cf6f
--- /dev/null
+++ b/Make-A-Protagonist/configs/yanzi.yaml
@@ -0,0 +1,13 @@
+video_dir: "data/yanzi"
+prompt: "a man walking down the street at night"
+n_sample_frames: 8
+width: 768
+height: 768
+sample_start_idx: 0
+sample_frame_rate: 1
+condition: [openposefull, depth]
+video_suffix: .jpg
+condition_suffix: .png
+noise_level: 10000
+image_embed_drop: 0.1
+mask_dir: man.mask
diff --git a/Make-A-Protagonist/configs/yanzi/eval.yaml b/Make-A-Protagonist/configs/yanzi/eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a03e8406889db140e8a4c9605c576db9b1eb5526
--- /dev/null
+++ b/Make-A-Protagonist/configs/yanzi/eval.yaml
@@ -0,0 +1,64 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/eval-yanzi"
+resume_from_checkpoint: "outputs/yanzi/checkpoint-200"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+
+train_data:
+ video_dir: "data/yanzi"
+ prompt: "a man walking down the street at night"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "a panda walking down the snowy street"
+ - "a panda walking down the snowy street"
+
+ ref_images:
+ - "data/yanzi/masked_panda.png"
+ - "data/yanzi/masked_panda.png"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 500
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/configs/yanzi/train.yaml b/Make-A-Protagonist/configs/yanzi/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..982fbc3f0dba20d58b05007ac87a1b0ebeeac1a5
--- /dev/null
+++ b/Make-A-Protagonist/configs/yanzi/train.yaml
@@ -0,0 +1,63 @@
+pretrained_model_path: "./checkpoints/stable-diffusion-2-1-unclip-small"
+output_dir: "./outputs/yanzi"
+controlnet_pretrained_model_path: [checkpoints/controlnet-2-1-unclip-small-openposefull, checkpoints/controlnet-2-1-unclip-small-depth]
+use_temporal_conv: True
+
+
+train_data:
+ video_dir: "data/yanzi"
+ prompt: "a man walking down the street at night"
+ n_sample_frames: 8
+ width: 768
+ height: 768
+ sample_start_idx: 0
+ sample_frame_rate: 1
+ condition: [openposefull, depth]
+ video_suffix: .jpg
+ condition_suffix: .png
+ noise_level: 10000
+ image_embed_drop: 0.1
+ mask_dir: man.mask
+
+validation_data:
+ prompts:
+ - "a panda walking down the snowy street"
+ - "a panda walking down the snowy street"
+
+ ref_images:
+ - "data/yanzi/masked_panda.png"
+ - "data/yanzi/masked_panda.png"
+
+ video_length: 8 # 24
+ width: 768
+ height: 768
+ num_inference_steps: 50
+ guidance_scale: 12.5
+ use_inv_latent: True
+ num_inv_steps: 50 #50
+ noise_level: 0
+ interpolate_embed_weight: 1.0 ## 1.0 means all use image embedding
+ use_masks: true
+ start_step: 0 ## start to use mask
+ end_step: 50 ## end to use mask
+ mask_mode: all # mask_mode: emb / latent / all
+ mask_latent_fuse_mode: all # inverse or all
+ source_background: false # using source background and changing the protagonist
+ source_protagonist: false # using source protagonist and changing the background
+ controlnet_conditioning_scale: [.5, .5]
+
+learning_rate: 3e-5
+train_batch_size: 1
+max_train_steps: 200
+checkpointing_steps: 200
+validation_steps: 200
+trainable_modules:
+ - "attn1.to_q"
+ - "attn2.to_q"
+ - "attn_temp"
+
+seed: 33
+mixed_precision: fp16
+use_8bit_adam: False
+gradient_checkpointing: True
+enable_xformers_memory_efficient_attention: True
diff --git a/Make-A-Protagonist/eval.py b/Make-A-Protagonist/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f1310a62b91392ba4aa205b21e916be894d3bdc
--- /dev/null
+++ b/Make-A-Protagonist/eval.py
@@ -0,0 +1,368 @@
+import argparse
+import datetime
+import logging
+import inspect
+import math
+import os
+from typing import Dict, Optional, Tuple
+from omegaconf import OmegaConf
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import numpy as np
+from PIL import Image
+
+import diffusers
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, PNDMScheduler, ControlNetModel, PriorTransformer, UnCLIPScheduler
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
+
+from makeaprotagonist.models.unet import UNet3DConditionModel
+from makeaprotagonist.dataset.dataset import MakeAProtagonistDataset
+from makeaprotagonist.pipelines.pipeline_stable_unclip_controlavideo import MakeAProtagonistStableUnCLIPPipeline, MultiControlNetModel
+from makeaprotagonist.util import save_videos_grid, ddim_inversion_unclip, ddim_inversion_prior
+from einops import rearrange
+from makeaprotagonist.args_util import DictAction, config_merge_dict
+import ipdb
+import random
+from glob import glob
+import sys
+
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.15.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def main(
+ pretrained_model_path: str,
+ controlnet_pretrained_model_path: str,
+ output_dir: str,
+ train_data: Dict,
+ validation_data: Dict,
+ validation_steps: int = 100,
+ trainable_modules: Tuple[str] = (
+ "attn1.to_q",
+ "attn2.to_q",
+ "attn_temp",
+ ),
+ trainable_params: Tuple[str] = (),
+ train_batch_size: int = 1,
+ max_train_steps: int = 500,
+ learning_rate: float = 3e-5,
+ scale_lr: bool = False,
+ lr_scheduler: str = "constant",
+ lr_warmup_steps: int = 0,
+ adam_beta1: float = 0.9,
+ adam_beta2: float = 0.999,
+ adam_weight_decay: float = 1e-2,
+ adam_epsilon: float = 1e-08,
+ max_grad_norm: float = 1.0,
+ gradient_accumulation_steps: int = 1,
+ gradient_checkpointing: bool = True,
+ checkpointing_steps: int = 500,
+ resume_from_checkpoint: Optional[str] = None,
+ mixed_precision: Optional[str] = "fp16",
+ use_8bit_adam: bool = False,
+ enable_xformers_memory_efficient_attention: bool = True,
+ seed: Optional[int] = None,
+ adapter_config=None, # the config for adapter
+ use_temporal_conv=False, ## use temporal conv in resblocks
+):
+ *_, config = inspect.getargvalues(inspect.currentframe())
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ mixed_precision=mixed_precision,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if seed is not None:
+ set_seed(seed)
+
+ # Handle the output folder creation
+ if accelerator.is_main_process:
+ # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ # output_dir = os.path.join(output_dir, now)
+ os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
+ os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+ prior_model_id = "kakaobrain/karlo-v1-alpha"
+ data_type = torch.float16
+ prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type)
+
+ prior_text_model_id = "openai/clip-vit-large-patch14"
+ prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)
+ prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)
+ prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler")
+ prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+
+
+ # image encoding components
+ feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
+ # image noising components
+ image_normalizer = StableUnCLIPImageNormalizer.from_pretrained(pretrained_model_path, subfolder="image_normalizer")
+ image_noising_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="image_noising_scheduler")
+ # regular denoising components
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_temporal_conv=use_temporal_conv)
+
+
+ # vae
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+ ## controlnet
+ assert not isinstance(controlnet_pretrained_model_path, str)
+ controlnet = MultiControlNetModel( [ControlNetModel.from_pretrained(_control_model_path) for _control_model_path in controlnet_pretrained_model_path] )
+
+ # Freeze vae and text_encoder and adapter
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ ## freeze image embed
+ image_encoder.requires_grad_(False)
+
+ unet.requires_grad_(False)
+ ## freeze controlnet
+ controlnet.requires_grad_(False)
+
+ ## freeze prior
+ prior.requires_grad_(False)
+ prior_text_model.requires_grad_(False)
+
+
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if scale_lr:
+ learning_rate = (
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
+ )
+
+ # Get the training dataset
+ train_dataset = MakeAProtagonistDataset(**train_data)
+
+ # Preprocessing the dataset
+ train_dataset.prompt_ids = tokenizer(
+ train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids[0]
+
+ train_dataset.preprocess_img_embedding(feature_extractor, image_encoder)
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=train_batch_size, num_workers=0,
+ )
+
+ prior_val_scheduler = DDIMScheduler.from_config(prior_scheduler.config) if validation_data.get("prior_val_scheduler", "") == "DDIM" else prior_scheduler
+ # ipdb.set_trace()
+ validation_pipeline = MakeAProtagonistStableUnCLIPPipeline(
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_val_scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ )
+
+
+ validation_pipeline.enable_vae_slicing()
+ ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
+ ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
+
+ ddim_inv_prior_scheduler = None
+ if validation_data.get("use_prior_inv_latent", False):
+ ddim_inv_prior_scheduler = DDIMScheduler.from_config(prior_scheduler.config)
+ ddim_inv_prior_scheduler.set_timesteps(validation_data.prior_num_inv_steps)
+
+ unet, train_dataloader = accelerator.prepare(
+ unet, train_dataloader
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ ## note controlnet use the unet dtype
+ controlnet.to(accelerator.device, dtype=weight_dtype)
+ ## prior
+ prior.to(accelerator.device, dtype=weight_dtype)
+ prior_text_model.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2video-fine-tune")
+
+ global_step = 0
+ # Potentially load in the weights and states from a previous save
+ if resume_from_checkpoint:
+ ## resume_from_checkpoint is the path to the checkpoint-300 dir
+ accelerator.load_state(resume_from_checkpoint)
+ path = os.path.basename(resume_from_checkpoint)
+ global_step = int(path.split("-")[1])
+
+
+ if not "noise_level" in validation_data:
+ validation_data.noise_level = train_data.noise_level
+ if not "noise_level_inv" in validation_data:
+ validation_data.noise_level_inv = validation_data.noise_level
+ # Checks if the accelerator has performed an optimization step behind the scenes
+
+ if accelerator.is_main_process:
+
+ batch = next(iter(train_dataloader))
+
+ # ipdb.set_trace()
+ pixel_values = batch["pixel_values"].to(weight_dtype)
+ video_length = pixel_values.shape[1]
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+ latents = vae.encode(pixel_values).latent_dist.sample()
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+ latents = latents * vae.config.scaling_factor
+
+
+ # ControlNet
+ # ipdb.set_trace()
+ conditions = [_condition.to(weight_dtype) for _, _condition in batch["conditions"].items()] # b f c h w
+ masks = batch["masks"].to(weight_dtype) # b,f,1,h,w
+ # ipdb.set_trace()
+ if not validation_data.get("use_masks", False):
+ masks = torch.ones_like(masks)
+ # conditions = rearrange(conditions, "b f c h w -> (b f) c h w") ## here is rgb
+ ## NOTE in this pretrained model, the config is also rgb
+ ## https://huggingface.co/thibaud/controlnet-sd21-openpose-diffusers/blob/main/config.json
+
+ # ipdb.set_trace()
+ ddim_inv_latent = None
+ if validation_data.use_inv_latent: #
+ emb_dim = train_dataset.img_embeddings[0].size(0)
+ key_frame_embed = torch.zeros((1, emb_dim)).to(device=latents.device, dtype=latents.dtype) ## this is dim 0
+ ddim_inv_latent = ddim_inversion_unclip(
+ validation_pipeline, ddim_inv_scheduler, video_latent=latents,
+ num_inv_steps=validation_data.num_inv_steps, prompt="", image_embed=key_frame_embed, noise_level=validation_data.noise_level, seed=seed)[-1].to(weight_dtype)
+
+ set_noise = validation_data.pop("noise_level")
+ v_noise = set_noise
+
+ if not validation_data.get("interpolate_embed_weight", False):
+ validation_data.interpolate_embed_weight = 0
+
+
+ samples = []
+
+ generator = torch.Generator(device=accelerator.device)
+ generator.manual_seed(seed)
+
+ for idx, prompt in enumerate(validation_data.prompts):
+
+ _ref_image = Image.open(validation_data.ref_images[idx])
+ image_embed = None
+ ## prior latents
+ prior_embeds = None
+ prior_denoised_embeds = None
+ if validation_data.get("source_background", False):
+ ## using source background and changing the protagonist
+ prior_denoised_embeds = train_dataset.img_embeddings[0][None].to(device=latents.device, dtype=latents.dtype) # 1, 768 for UnCLIP-small
+
+ if validation_data.get("source_protagonist", False):
+ # using source protagonist and changing the background
+ sample_indices = batch["sample_indices"][0]
+ image_embed = [train_dataset.img_embeddings[idx] for idx in sample_indices]
+ image_embed = torch.stack(image_embed, dim=0).to(device=latents.device, dtype=latents.dtype) # F, 768 for UnCLIP-small # F,C
+ _ref_image = None
+
+ sample = validation_pipeline(image=_ref_image, prompt=prompt, control_image=conditions, generator=generator, latents=ddim_inv_latent, image_embeds=image_embed, noise_level=v_noise, masks=masks, prior_latents=prior_embeds, prior_denoised_embeds=prior_denoised_embeds, **validation_data).videos
+
+ save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}-seed{seed}/{idx}-{prompt}.gif")
+ samples.append(sample)
+
+ #
+ samples = [sample.float() for sample in samples]
+ samples = torch.concat(samples)
+ save_path = f"{output_dir}/samples/sample-{global_step}-s{validation_data.start_step}-e{validation_data.end_step}-seed{seed}.gif" # noise level and noise level for inv
+ save_videos_grid(samples, save_path, n_rows=len(samples))
+ logger.info(f"Saved samples to {save_path}")
+
+
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction, ##NOTE cannot support multi-level config change
+ help="--options is deprecated in favor of --cfg_options' and it will "
+ 'not be supported in version v0.22.0. Override some settings in the '
+ 'used config, the key-value pair in xxx=yyy format will be merged '
+ 'into config file. If the value to be overwritten is a list, it '
+ 'should be like key="[a,b]" or key=a,b It also allows nested '
+ 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
+ 'marks are necessary and that no white space is allowed.')
+
+ args = parser.parse_args()
+
+ ## read from cmd line
+ # ipdb.set_trace()
+ # Load the YAML configuration file
+ config = OmegaConf.load(args.config)
+ # Merge the command-line arguments with the configuration file
+ if args.options is not None:
+ # config = OmegaConf.merge(config, args.options)
+ config_merge_dict(args.options, config)
+
+ main(**config)
diff --git a/Make-A-Protagonist/experts/BLIP2/blip_video_model.py b/Make-A-Protagonist/experts/BLIP2/blip_video_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8d6f21128809cbb92f3eb11e51175b62c4a40e
--- /dev/null
+++ b/Make-A-Protagonist/experts/BLIP2/blip_video_model.py
@@ -0,0 +1,87 @@
+
+
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.utils import logging
+
+from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGeneration
+import ipdb
+
+
+logger = logging.get_logger(__name__)
+
+class Blip2ForVideoConditionalGeneration(Blip2ForConditionalGeneration):
+
+ @torch.no_grad()
+ def generate(
+ self,
+ pixel_values: torch.FloatTensor,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ **generate_kwargs,
+ ) -> torch.LongTensor:
+ """
+ Overrides `generate` function to be able to use the model as a conditional generator.
+ Args:
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
+ Input images to be processed.
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ The sequence used as a prompt for the generation.
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
+ Mask to avoid performing attention on padding token indices
+ Returns:
+ captions (list): A list of strings of length batch_size * num_captions.
+ """
+ if hasattr(self, "hf_device_map"):
+ # preprocess for `accelerate`
+ self._preprocess_accelerate()
+
+ batch_size = pixel_values.shape[0]
+ image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
+ ## image_embeds B,257, 1408
+ ## NOTE the video should be concatenated here
+ ## NOTE only support one video now
+ image_embeds = image_embeds.reshape(1, -1, image_embeds.size(-1)) # 1, 257*B, C
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) # 1, 257*B
+ # ipdb.set_trace()
+ # self.query_tokens 1,32,768
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_outputs = self.qformer(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_attention_mask,
+ return_dict=True,
+ )
+ query_output = query_outputs.last_hidden_state # 1,32,768
+
+ language_model_inputs = self.language_projection(query_output)
+ language_attention_mask = torch.ones(
+ language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
+ )
+ if input_ids is None:
+ input_ids = (
+ torch.LongTensor([[self.config.text_config.bos_token_id]])
+ .repeat(batch_size, 1)
+ .to(image_embeds.device)
+ )
+ ## NOTE only support one video now
+ input_ids = input_ids[:1] #
+
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+ attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
+
+ # concatenate query embeddings with prompt embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
+
+ outputs = self.language_model.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ **generate_kwargs,
+ )
+
+ return outputs
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/LICENSE b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b1395e94b016dd1b95b4c7e3ed493e1d0b342917
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 - present, Facebook, Inc
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..358ac997bbef3c43f6c692f1cccbe6abad07c222
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/__init__.py
@@ -0,0 +1 @@
+from .groundingdino import *
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f490c4bbd598a35de43d36ceafcbd769e7ff21bf
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py
@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_B_384_22k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
new file mode 100644
index 0000000000000000000000000000000000000000..9158d5f6260ec74bded95377d382387430d7cd70
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
@@ -0,0 +1,43 @@
+batch_size = 1
+modelname = "groundingdino"
+backbone = "swin_T_224_1k"
+position_embedding = "sine"
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+two_stage_type = "standard"
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+transformer_activation = "relu"
+dec_pred_bbox_embed_share = True
+dn_box_noise_scale = 1.0
+dn_label_noise_ratio = 0.5
+dn_label_coef = 1.0
+dn_bbox_coef = 1.0
+embed_init_tgt = True
+dn_labelbook_size = 2000
+max_text_len = 256
+text_encoder_type = "bert-base-uncased"
+use_text_enhancer = True
+use_fusion_layer = True
+use_checkpoint = True
+use_transformer_ckpt = True
+use_text_cross_attention = True
+text_dropout = 0.0
+fusion_dropout = 0.0
+fusion_droppath = 0.1
+sub_sentence_present = True
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/datasets/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/datasets/transforms.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/datasets/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..91cf9269e4b31008a3ddca34a19b038a9b399991
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/datasets/transforms.py
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Transforms and data augmentation for both image + bbox.
+"""
+import os
+import random
+
+import PIL
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+
+from groundingdino.util.box_ops import box_xyxy_to_cxcywh
+from groundingdino.util.misc import interpolate
+
+
+def crop(image, target, region):
+ cropped_image = F.crop(image, *region)
+
+ target = target.copy()
+ i, j, h, w = region
+
+ # should we do something wrt the original size?
+ target["size"] = torch.tensor([h, w])
+
+ fields = ["labels", "area", "iscrowd", "positive_map"]
+
+ if "boxes" in target:
+ boxes = target["boxes"]
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+ cropped_boxes = cropped_boxes.clamp(min=0)
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
+ target["area"] = area
+ fields.append("boxes")
+
+ if "masks" in target:
+ # FIXME should we update the area here if there are no boxes?
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
+ fields.append("masks")
+
+ # remove elements for which the boxes or masks that have zero area
+ if "boxes" in target or "masks" in target:
+ # favor boxes selection when defining which elements to keep
+ # this is compatible with previous implementation
+ if "boxes" in target:
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+ else:
+ keep = target["masks"].flatten(1).any(1)
+
+ for field in fields:
+ if field in target:
+ target[field] = target[field][keep]
+
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
+ # for debug and visualization only.
+ if "strings_positive" in target:
+ target["strings_positive"] = [
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
+ ]
+
+ return cropped_image, target
+
+
+def hflip(image, target):
+ flipped_image = F.hflip(image)
+
+ w, h = image.size
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
+ [w, 0, w, 0]
+ )
+ target["boxes"] = boxes
+
+ if "masks" in target:
+ target["masks"] = target["masks"].flip(-1)
+
+ return flipped_image, target
+
+
+def resize(image, target, size, max_size=None):
+ # size can be min_size (scalar) or (w, h) tuple
+
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(round(max_size * min_original_size / max_original_size))
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (h, w)
+
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+
+ return (oh, ow)
+
+ def get_size(image_size, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size[::-1]
+ else:
+ return get_size_with_aspect_ratio(image_size, size, max_size)
+
+ size = get_size(image.size, size, max_size)
+ rescaled_image = F.resize(image, size)
+
+ if target is None:
+ return rescaled_image, None
+
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+ ratio_width, ratio_height = ratios
+
+ target = target.copy()
+ if "boxes" in target:
+ boxes = target["boxes"]
+ scaled_boxes = boxes * torch.as_tensor(
+ [ratio_width, ratio_height, ratio_width, ratio_height]
+ )
+ target["boxes"] = scaled_boxes
+
+ if "area" in target:
+ area = target["area"]
+ scaled_area = area * (ratio_width * ratio_height)
+ target["area"] = scaled_area
+
+ h, w = size
+ target["size"] = torch.tensor([h, w])
+
+ if "masks" in target:
+ target["masks"] = (
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
+ )
+
+ return rescaled_image, target
+
+
+def pad(image, target, padding):
+ # assumes that we only pad on the bottom right corners
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
+ if target is None:
+ return padded_image, None
+ target = target.copy()
+ # should we do something wrt the original size?
+ target["size"] = torch.tensor(padded_image.size[::-1])
+ if "masks" in target:
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
+ return padded_image, target
+
+
+class ResizeDebug(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ return resize(img, target, self.size)
+
+
+class RandomCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ region = T.RandomCrop.get_params(img, self.size)
+ return crop(img, target, region)
+
+
+class RandomSizeCrop(object):
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
+ # respect_boxes: True to keep all boxes
+ # False to tolerence box filter
+ self.min_size = min_size
+ self.max_size = max_size
+ self.respect_boxes = respect_boxes
+
+ def __call__(self, img: PIL.Image.Image, target: dict):
+ init_boxes = len(target["boxes"])
+ max_patience = 10
+ for i in range(max_patience):
+ w = random.randint(self.min_size, min(img.width, self.max_size))
+ h = random.randint(self.min_size, min(img.height, self.max_size))
+ region = T.RandomCrop.get_params(img, [h, w])
+ result_img, result_target = crop(img, target, region)
+ if (
+ not self.respect_boxes
+ or len(result_target["boxes"]) == init_boxes
+ or i == max_patience - 1
+ ):
+ return result_img, result_target
+ return result_img, result_target
+
+
+class CenterCrop(object):
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img, target):
+ image_width, image_height = img.size
+ crop_height, crop_width = self.size
+ crop_top = int(round((image_height - crop_height) / 2.0))
+ crop_left = int(round((image_width - crop_width) / 2.0))
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return hflip(img, target)
+ return img, target
+
+
+class RandomResize(object):
+ def __init__(self, sizes, max_size=None):
+ assert isinstance(sizes, (list, tuple))
+ self.sizes = sizes
+ self.max_size = max_size
+
+ def __call__(self, img, target=None):
+ size = random.choice(self.sizes)
+ return resize(img, target, size, self.max_size)
+
+
+class RandomPad(object):
+ def __init__(self, max_pad):
+ self.max_pad = max_pad
+
+ def __call__(self, img, target):
+ pad_x = random.randint(0, self.max_pad)
+ pad_y = random.randint(0, self.max_pad)
+ return pad(img, target, (pad_x, pad_y))
+
+
+class RandomSelect(object):
+ """
+ Randomly selects between transforms1 and transforms2,
+ with probability p for transforms1 and (1 - p) for transforms2
+ """
+
+ def __init__(self, transforms1, transforms2, p=0.5):
+ self.transforms1 = transforms1
+ self.transforms2 = transforms2
+ self.p = p
+
+ def __call__(self, img, target):
+ if random.random() < self.p:
+ return self.transforms1(img, target)
+ return self.transforms2(img, target)
+
+
+class ToTensor(object):
+ def __call__(self, img, target):
+ return F.to_tensor(img), target
+
+
+class RandomErasing(object):
+ def __init__(self, *args, **kwargs):
+ self.eraser = T.RandomErasing(*args, **kwargs)
+
+ def __call__(self, img, target):
+ return self.eraser(img), target
+
+
+class Normalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, image, target=None):
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ if target is None:
+ return image, None
+ target = target.copy()
+ h, w = image.shape[-2:]
+ if "boxes" in target:
+ boxes = target["boxes"]
+ boxes = box_xyxy_to_cxcywh(boxes)
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
+ target["boxes"] = boxes
+ return image, target
+
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af819d61d589cfec2e0ca46612a7456f42b831a
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/__init__.py
@@ -0,0 +1,15 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+from .groundingdino import build_groundingdino
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e4b272b479a26c63d120c818c140870cd8c287
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py
@@ -0,0 +1 @@
+from .backbone import build_backbone
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8340c723fad8e07e2fc62daaa3912487498814b
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py
@@ -0,0 +1,221 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Backbone modules.
+"""
+
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+
+from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
+
+from .position_encoding import build_position_encoding
+from .swin_transformer import build_swin_transformer
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ train_backbone: bool,
+ num_channels: int,
+ return_interm_indices: list,
+ ):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if (
+ not train_backbone
+ or "layer2" not in name
+ and "layer3" not in name
+ and "layer4" not in name
+ ):
+ parameter.requires_grad_(False)
+
+ return_layers = {}
+ for idx, layer_index in enumerate(return_interm_indices):
+ return_layers.update(
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
+ )
+
+ # if len:
+ # if use_stage1_feature:
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ # else:
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+ # else:
+ # return_layers = {'layer4': "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ # import ipdb; ipdb.set_trace()
+ return out
+
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+
+ def __init__(
+ self,
+ name: str,
+ train_backbone: bool,
+ dilation: bool,
+ return_interm_indices: list,
+ batch_norm=FrozenBatchNorm2d,
+ ):
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=is_main_process(),
+ norm_layer=batch_norm,
+ )
+ else:
+ raise NotImplementedError("Why you can get here with name {}".format(name))
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
+ num_channels_all = [256, 512, 1024, 2048]
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in xs.items():
+ out.append(x)
+ # position encoding
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ """
+ Useful args:
+ - backbone: backbone name
+ - lr_backbone:
+ - dilation
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
+ - backbone_freeze_keywords:
+ - use_checkpoint: for swin only for now
+
+ """
+ position_embedding = build_position_encoding(args)
+ train_backbone = True
+ if not train_backbone:
+ raise ValueError("Please set lr_backbone > 0")
+ return_interm_indices = args.return_interm_indices
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
+ args.backbone_freeze_keywords
+ use_checkpoint = getattr(args, "use_checkpoint", False)
+
+ if args.backbone in ["resnet50", "resnet101"]:
+ backbone = Backbone(
+ args.backbone,
+ train_backbone,
+ args.dilation,
+ return_interm_indices,
+ batch_norm=FrozenBatchNorm2d,
+ )
+ bb_num_channels = backbone.num_channels
+ elif args.backbone in [
+ "swin_T_224_1k",
+ "swin_B_224_22k",
+ "swin_B_384_22k",
+ "swin_L_224_22k",
+ "swin_L_384_22k",
+ ]:
+ pretrain_img_size = int(args.backbone.split("_")[-2])
+ backbone = build_swin_transformer(
+ args.backbone,
+ pretrain_img_size=pretrain_img_size,
+ out_indices=tuple(return_interm_indices),
+ dilation=False,
+ use_checkpoint=use_checkpoint,
+ )
+
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
+ else:
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
+
+ assert len(bb_num_channels) == len(
+ return_interm_indices
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
+
+ model = Joiner(backbone, position_embedding)
+ model.num_channels = bb_num_channels
+ assert isinstance(
+ bb_num_channels, List
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
+ # import ipdb; ipdb.set_trace()
+ return model
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac7e896bbe85a670824bfe8ef487d0535d5bd99
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py
@@ -0,0 +1,186 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+from groundingdino.util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ # if os.environ.get("SHILONG_AMP", None) == '1':
+ # eps = 1e-4
+ # else:
+ # eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingSineHW(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperatureH = temperatureH
+ self.temperatureW = temperatureW
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+ # import ipdb; ipdb.set_trace()
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
+ pos_x = x_embed[:, :, :, None] / dim_tx
+
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
+ pos_y = y_embed[:, :, :, None] / dim_ty
+
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+ # import ipdb; ipdb.set_trace()
+
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = (
+ torch.cat(
+ [
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ],
+ dim=-1,
+ )
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .repeat(x.shape[0], 1, 1, 1)
+ )
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ("v2", "sine"):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSineHW(
+ N_steps,
+ temperatureH=args.pe_temperatureH,
+ temperatureW=args.pe_temperatureW,
+ normalize=True,
+ )
+ elif args.position_embedding in ("v3", "learned"):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c66194deb5dd370e797e57e2712f44303e568cc
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
@@ -0,0 +1,802 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# --------------------------------------------------------
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from groundingdino.util.misc import NestedTensor
+
+
+class Mlp(nn.Module):
+ """Multilayer perceptron."""
+
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """Forward function.
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ """Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+ )
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size * self.window_size, C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """Patch Merging Layer
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0)
+ )
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute postion embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
+ """
+
+ def __init__(
+ self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ dilation=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.dilation = dilation
+
+ # if use_checkpoint:
+ # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ pretrain_img_size[0] // patch_size[0],
+ pretrain_img_size[1] // patch_size[1],
+ ]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ # prepare downsample list
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
+ downsamplelist[-1] = None
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
+ if self.dilation:
+ downsamplelist[-2] = None
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ # dim=int(embed_dim * 2 ** i_layer),
+ dim=num_features[i_layer],
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ norm_layer=norm_layer,
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ downsample=downsamplelist[i_layer],
+ use_checkpoint=use_checkpoint,
+ )
+ self.layers.append(layer)
+
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f"norm{i_layer}"
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ # def init_weights(self, pretrained=None):
+ # """Initialize the weights in backbone.
+ # Args:
+ # pretrained (str, optional): Path to pre-trained weights.
+ # Defaults to None.
+ # """
+
+ # def _init_weights(m):
+ # if isinstance(m, nn.Linear):
+ # trunc_normal_(m.weight, std=.02)
+ # if isinstance(m, nn.Linear) and m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+ # elif isinstance(m, nn.LayerNorm):
+ # nn.init.constant_(m.bias, 0)
+ # nn.init.constant_(m.weight, 1.0)
+
+ # if isinstance(pretrained, str):
+ # self.apply(_init_weights)
+ # logger = get_root_logger()
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
+ # elif pretrained is None:
+ # self.apply(_init_weights)
+ # else:
+ # raise TypeError('pretrained must be a str or None')
+
+ def forward_raw(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+ )
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+ # import ipdb; ipdb.set_trace()
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f"norm{i}")
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+ # in:
+ # torch.Size([2, 3, 1024, 1024])
+ # outs:
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
+ return tuple(outs)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+ )
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f"norm{i}")
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+ # in:
+ # torch.Size([2, 3, 1024, 1024])
+ # out:
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
+
+ # collect for nesttensors
+ outs_dict = {}
+ for idx, out_i in enumerate(outs):
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
+ outs_dict[idx] = NestedTensor(out_i, mask)
+
+ return outs_dict
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+
+def build_swin_transformer(modelname, pretrain_img_size, **kw):
+ assert modelname in [
+ "swin_T_224_1k",
+ "swin_B_224_22k",
+ "swin_B_384_22k",
+ "swin_L_224_22k",
+ "swin_L_384_22k",
+ ]
+
+ model_para_dict = {
+ "swin_T_224_1k": dict(
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
+ ),
+ "swin_B_224_22k": dict(
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
+ ),
+ "swin_B_384_22k": dict(
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
+ ),
+ "swin_L_224_22k": dict(
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
+ ),
+ "swin_L_384_22k": dict(
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
+ ),
+ }
+ kw_cgf = model_para_dict[modelname]
+ kw_cgf.update(kw)
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
+ return model
+
+
+if __name__ == "__main__":
+ model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
+ x = torch.rand(2, 3, 1024, 1024)
+ y = model.forward_raw(x)
+ import ipdb
+
+ ipdb.set_trace()
+ x = torch.rand(2, 3, 384, 384)
+ y = model.forward_raw(x)
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0cf9779b270e1aead32845006f8b881fcba37ad
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py
@@ -0,0 +1,273 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch import Tensor, nn
+from torchvision.ops.boxes import nms
+from transformers import BertConfig, BertModel, BertPreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+
+
+class BertModelWarper(nn.Module):
+ def __init__(self, bert_model):
+ super().__init__()
+ # self.bert = bert_modelc
+
+ self.config = bert_model.config
+ self.embeddings = bert_model.embeddings
+ self.encoder = bert_model.encoder
+ self.pooler = bert_model.pooler
+
+ self.get_extended_attention_mask = bert_model.get_extended_attention_mask
+ self.invert_attention_mask = bert_model.invert_attention_mask
+ self.get_head_mask = bert_model.get_head_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ )
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, device
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class TextEncoderShell(nn.Module):
+ def __init__(self, text_encoder):
+ super().__init__()
+ self.text_encoder = text_encoder
+ self.config = self.text_encoder.config
+
+ def forward(self, **kw):
+ # feed into text encoder
+ return self.text_encoder(**kw)
+
+
+def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
+ """Generate attention mask between each pair of special tokens
+ Args:
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
+ special_tokens_mask (list): special tokens mask.
+ Returns:
+ torch.Tensor: attention mask between each special tokens.
+ """
+ input_ids = tokenized["input_ids"]
+ bs, num_token = input_ids.shape
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
+ for special_token in special_tokens_list:
+ special_tokens_mask |= input_ids == special_token
+
+ # idxs: each row is a list of indices of special tokens
+ idxs = torch.nonzero(special_tokens_mask)
+
+ # generate attention mask and positional ids
+ attention_mask = (
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
+ )
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
+ previous_col = 0
+ for i in range(idxs.shape[0]):
+ row, col = idxs[i]
+ if (col == 0) or (col == num_token - 1):
+ attention_mask[row, col, col] = True
+ position_ids[row, col] = 0
+ else:
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
+ 0, col - previous_col, device=input_ids.device
+ )
+
+ previous_col = col
+
+ # # padding mask
+ # padding_mask = tokenized['attention_mask']
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
+
+ return attention_mask, position_ids.to(torch.long)
+
+
+def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
+ """Generate attention mask between each pair of special tokens
+ Args:
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
+ special_tokens_mask (list): special tokens mask.
+ Returns:
+ torch.Tensor: attention mask between each special tokens.
+ """
+ input_ids = tokenized["input_ids"]
+ bs, num_token = input_ids.shape
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
+ for special_token in special_tokens_list:
+ special_tokens_mask |= input_ids == special_token
+
+ # idxs: each row is a list of indices of special tokens
+ idxs = torch.nonzero(special_tokens_mask)
+
+ # generate attention mask and positional ids
+ attention_mask = (
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
+ )
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
+ cate_to_token_mask_list = [[] for _ in range(bs)]
+ previous_col = 0
+ for i in range(idxs.shape[0]):
+ row, col = idxs[i]
+ if (col == 0) or (col == num_token - 1):
+ attention_mask[row, col, col] = True
+ position_ids[row, col] = 0
+ else:
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
+ 0, col - previous_col, device=input_ids.device
+ )
+ c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
+ c2t_maski[previous_col + 1 : col] = True
+ cate_to_token_mask_list[row].append(c2t_maski)
+ previous_col = col
+
+ cate_to_token_mask_list = [
+ torch.stack(cate_to_token_mask_listi, dim=0)
+ for cate_to_token_mask_listi in cate_to_token_mask_list
+ ]
+
+ # # padding mask
+ # padding_mask = tokenized['attention_mask']
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
+
+ return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..c7408eba007b424194618baa63726657e36875e3
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
@@ -0,0 +1,64 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "ms_deform_attn_cuda.h"
+#endif
+
+namespace groundingdino {
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+} // namespace groundingdino
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..551243fdadfd1682b5dc6628623b67a79b3f6c74
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp
@@ -0,0 +1,43 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+
+#include
+#include
+
+namespace groundingdino {
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+} // namespace groundingdino
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2b88e8c46f19b6db0933163e57ccdb51180f517
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h
@@ -0,0 +1,35 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+namespace groundingdino {
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+} // namespace groundingdino
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d04fae8a9a45c11e4e74f3035e94762796da4096
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
@@ -0,0 +1,156 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+#include "ms_deform_im2col_cuda.cuh"
+
+#include
+#include
+#include
+#include
+
+namespace groundingdino {
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
+
+} // namespace groundingdino
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..ad1311a78f61303616504eb991aaa9c4a93d9948
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+namespace groundingdino {
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+} // namespace groundingdino
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6bc2acb7aea0eab2e9e91e769a16861e1652c284
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel
+ <<>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu
new file mode 100644
index 0000000000000000000000000000000000000000..64569e34ffb250964de27e33e7a53f3822270b9e
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu
@@ -0,0 +1,7 @@
+#include
+
+namespace groundingdino {
+int get_cudart_version() {
+ return CUDART_VERSION;
+}
+} // namespace groundingdino
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c1f2c50c82909bbd5492c163d634af77a3ba1781
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp
@@ -0,0 +1,58 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+#include "MsDeformAttn/ms_deform_attn.h"
+
+namespace groundingdino {
+
+#ifdef WITH_CUDA
+extern int get_cudart_version();
+#endif
+
+std::string get_cuda_version() {
+#ifdef WITH_CUDA
+ std::ostringstream oss;
+
+ // copied from
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
+ auto printCudaStyleVersion = [&](int v) {
+ oss << (v / 1000) << "." << (v / 10 % 100);
+ if (v % 10 != 0) {
+ oss << "." << (v % 10);
+ }
+ };
+ printCudaStyleVersion(get_cudart_version());
+ return oss.str();
+#else
+ return std::string("not available");
+#endif
+}
+
+// similar to
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
+std::string get_compiler_version() {
+ std::ostringstream ss;
+#if defined(__GNUC__)
+#ifndef __clang__
+ { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
+#endif
+#endif
+
+#if defined(__clang_major__)
+ {
+ ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
+ << __clang_patchlevel__;
+ }
+#endif
+
+#if defined(_MSC_VER)
+ { ss << "MSVC " << _MSC_FULL_VER; }
+#endif
+ return ss.str();
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
+
+} // namespace groundingdino
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2753b3ddee43c7a9fe28d1824db5d786e7e1ad59
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py
@@ -0,0 +1,297 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import DropPath
+
+
+class FeatureResizer(nn.Module):
+ """
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
+ """
+
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
+ super().__init__()
+ self.do_ln = do_ln
+ # Object feature encoding
+ self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, encoder_features):
+ x = self.fc(encoder_features)
+ if self.do_ln:
+ x = self.layer_norm(x)
+ output = self.dropout(x)
+ return output
+
+
+def l1norm(X, dim, eps=1e-8):
+ """L1-normalize columns of X"""
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
+ X = torch.div(X, norm)
+ return X
+
+
+def l2norm(X, dim, eps=1e-8):
+ """L2-normalize columns of X"""
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
+ X = torch.div(X, norm)
+ return X
+
+
+def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
+ """
+ query: (n_context, queryL, d)
+ context: (n_context, sourceL, d)
+ """
+ batch_size_q, queryL = query.size(0), query.size(1)
+ batch_size, sourceL = context.size(0), context.size(1)
+
+ # Get attention
+ # --> (batch, d, queryL)
+ queryT = torch.transpose(query, 1, 2)
+
+ # (batch, sourceL, d)(batch, d, queryL)
+ # --> (batch, sourceL, queryL)
+ attn = torch.bmm(context, queryT)
+ if raw_feature_norm == "softmax":
+ # --> (batch*sourceL, queryL)
+ attn = attn.view(batch_size * sourceL, queryL)
+ attn = nn.Softmax()(attn)
+ # --> (batch, sourceL, queryL)
+ attn = attn.view(batch_size, sourceL, queryL)
+ elif raw_feature_norm == "l2norm":
+ attn = l2norm(attn, 2)
+ elif raw_feature_norm == "clipped_l2norm":
+ attn = nn.LeakyReLU(0.1)(attn)
+ attn = l2norm(attn, 2)
+ else:
+ raise ValueError("unknown first norm type:", raw_feature_norm)
+ # --> (batch, queryL, sourceL)
+ attn = torch.transpose(attn, 1, 2).contiguous()
+ # --> (batch*queryL, sourceL)
+ attn = attn.view(batch_size * queryL, sourceL)
+ attn = nn.Softmax()(attn * smooth)
+ # --> (batch, queryL, sourceL)
+ attn = attn.view(batch_size, queryL, sourceL)
+ # --> (batch, sourceL, queryL)
+ attnT = torch.transpose(attn, 1, 2).contiguous()
+
+ # --> (batch, d, sourceL)
+ contextT = torch.transpose(context, 1, 2)
+ # (batch x d x sourceL)(batch x sourceL x queryL)
+ # --> (batch, d, queryL)
+ weightedContext = torch.bmm(contextT, attnT)
+ # --> (batch, queryL, d)
+ weightedContext = torch.transpose(weightedContext, 1, 2)
+
+ return weightedContext, attnT
+
+
+class BiMultiHeadAttention(nn.Module):
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
+ super(BiMultiHeadAttention, self).__init__()
+
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+ self.v_dim = v_dim
+ self.l_dim = l_dim
+
+ assert (
+ self.head_dim * self.num_heads == self.embed_dim
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+ self.scale = self.head_dim ** (-0.5)
+ self.dropout = dropout
+
+ self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
+ self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
+ self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
+ self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
+
+ self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
+ self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
+
+ self.stable_softmax_2d = True
+ self.clamp_min_for_underflow = True
+ self.clamp_max_for_overflow = True
+
+ self._reset_parameters()
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def _reset_parameters(self):
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ self.v_proj.bias.data.fill_(0)
+ nn.init.xavier_uniform_(self.l_proj.weight)
+ self.l_proj.bias.data.fill_(0)
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
+ self.values_v_proj.bias.data.fill_(0)
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
+ self.values_l_proj.bias.data.fill_(0)
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
+ self.out_v_proj.bias.data.fill_(0)
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
+ self.out_l_proj.bias.data.fill_(0)
+
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
+ """_summary_
+
+ Args:
+ v (_type_): bs, n_img, dim
+ l (_type_): bs, n_text, dim
+ attention_mask_v (_type_, optional): _description_. bs, n_img
+ attention_mask_l (_type_, optional): _description_. bs, n_text
+
+ Returns:
+ _type_: _description_
+ """
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+ bsz, tgt_len, _ = v.size()
+
+ query_states = self.v_proj(v) * self.scale
+ key_states = self._shape(self.l_proj(l), -1, bsz)
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_v_states = value_v_states.view(*proj_shape)
+ value_l_states = value_l_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+ )
+
+ if self.stable_softmax_2d:
+ attn_weights = attn_weights - attn_weights.max()
+
+ if self.clamp_min_for_underflow:
+ attn_weights = torch.clamp(
+ attn_weights, min=-50000
+ ) # Do not increase -50000, data type half has quite limited range
+ if self.clamp_max_for_overflow:
+ attn_weights = torch.clamp(
+ attn_weights, max=50000
+ ) # Do not increase 50000, data type half has quite limited range
+
+ attn_weights_T = attn_weights.transpose(1, 2)
+ attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
+ if self.clamp_min_for_underflow:
+ attn_weights_l = torch.clamp(
+ attn_weights_l, min=-50000
+ ) # Do not increase -50000, data type half has quite limited range
+ if self.clamp_max_for_overflow:
+ attn_weights_l = torch.clamp(
+ attn_weights_l, max=50000
+ ) # Do not increase 50000, data type half has quite limited range
+
+ # mask vison for language
+ if attention_mask_v is not None:
+ attention_mask_v = (
+ attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+ )
+ attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
+
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
+
+ # mask language for vision
+ if attention_mask_l is not None:
+ attention_mask_l = (
+ attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+ )
+ attn_weights.masked_fill_(attention_mask_l, float("-inf"))
+ attn_weights_v = attn_weights.softmax(dim=-1)
+
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
+
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
+
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
+ )
+
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
+ )
+
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output_v = attn_output_v.transpose(1, 2)
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
+ attn_output_l = attn_output_l.transpose(1, 2)
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
+
+ attn_output_v = self.out_v_proj(attn_output_v)
+ attn_output_l = self.out_l_proj(attn_output_l)
+
+ return attn_output_v, attn_output_l
+
+
+# Bi-Direction MHA (text->image, image->text)
+class BiAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ v_dim,
+ l_dim,
+ embed_dim,
+ num_heads,
+ dropout=0.1,
+ drop_path=0.0,
+ init_values=1e-4,
+ cfg=None,
+ ):
+ """
+ Inputs:
+ embed_dim - Dimensionality of input and attention feature vectors
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
+ (usually 2-4x larger than embed_dim)
+ num_heads - Number of heads to use in the Multi-Head Attention block
+ dropout - Amount of dropout to apply in the feed-forward network
+ """
+ super(BiAttentionBlock, self).__init__()
+
+ # pre layer norm
+ self.layer_norm_v = nn.LayerNorm(v_dim)
+ self.layer_norm_l = nn.LayerNorm(l_dim)
+ self.attn = BiMultiHeadAttention(
+ v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
+ )
+
+ # add layer scale for training stability
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
+
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
+ v = self.layer_norm_v(v)
+ l = self.layer_norm_l(l)
+ delta_v, delta_l = self.attn(
+ v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
+ )
+ # v, l = v + delta_v, l + delta_l
+ v = v + self.drop_path(self.gamma_v * delta_v)
+ l = l + self.drop_path(self.gamma_l * delta_l)
+ return v, l
+
+ # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py
new file mode 100644
index 0000000000000000000000000000000000000000..052df6220595a1b39b7e2aea37ca4872d113dfd2
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py
@@ -0,0 +1,395 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR model and criterion classes.
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# ------------------------------------------------------------------------
+import copy
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops.boxes import nms
+from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
+
+from groundingdino.util import box_ops, get_tokenlizer
+from groundingdino.util.misc import (
+ NestedTensor,
+ accuracy,
+ get_world_size,
+ interpolate,
+ inverse_sigmoid,
+ is_dist_avail_and_initialized,
+ nested_tensor_from_tensor_list,
+)
+from groundingdino.util.utils import get_phrases_from_posmap
+from groundingdino.util.visualizer import COCOVisualizer
+from groundingdino.util.vl_utils import create_positive_map_from_span
+
+from ..registry import MODULE_BUILD_FUNCS
+from .backbone import build_backbone
+from .bertwarper import (
+ BertModelWarper,
+ generate_masks_with_special_tokens,
+ generate_masks_with_special_tokens_and_transfer_map,
+)
+from .transformer import build_transformer
+from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
+
+
+class GroundingDINO(nn.Module):
+ """This is the Cross-Attention Detector module that performs object detection"""
+
+ def __init__(
+ self,
+ backbone,
+ transformer,
+ num_queries,
+ aux_loss=False,
+ iter_update=False,
+ query_dim=2,
+ num_feature_levels=1,
+ nheads=8,
+ # two stage
+ two_stage_type="no", # ['no', 'standard']
+ dec_pred_bbox_embed_share=True,
+ two_stage_class_embed_share=True,
+ two_stage_bbox_embed_share=True,
+ num_patterns=0,
+ dn_number=100,
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ dn_labelbook_size=100,
+ text_encoder_type="bert-base-uncased",
+ sub_sentence_present=True,
+ max_text_len=256,
+ ):
+ """Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.transformer = transformer
+ self.hidden_dim = hidden_dim = transformer.d_model
+ self.num_feature_levels = num_feature_levels
+ self.nheads = nheads
+ self.max_text_len = 256
+ self.sub_sentence_present = sub_sentence_present
+
+ # setting query dim
+ self.query_dim = query_dim
+ assert query_dim == 4
+
+ # for dn training
+ self.num_patterns = num_patterns
+ self.dn_number = dn_number
+ self.dn_box_noise_scale = dn_box_noise_scale
+ self.dn_label_noise_ratio = dn_label_noise_ratio
+ self.dn_labelbook_size = dn_labelbook_size
+
+ # bert
+ self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
+ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
+ self.bert.pooler.dense.weight.requires_grad_(False)
+ self.bert.pooler.dense.bias.requires_grad_(False)
+ self.bert = BertModelWarper(bert_model=self.bert)
+
+ self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
+ nn.init.constant_(self.feat_map.bias.data, 0)
+ nn.init.xavier_uniform_(self.feat_map.weight.data)
+ # freeze
+
+ # special tokens
+ self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
+
+ # prepare input projection layers
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.num_channels)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ )
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ )
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
+ self.input_proj = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ ]
+ )
+
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.box_pred_damping = box_pred_damping = None
+
+ self.iter_update = iter_update
+ assert iter_update, "Why not iter_update?"
+
+ # prepare pred layers
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
+ # prepare class & box embed
+ _class_embed = ContrastiveEmbed()
+
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
+
+ if dec_pred_bbox_embed_share:
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
+ else:
+ box_embed_layerlist = [
+ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
+ ]
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ self.transformer.decoder.class_embed = self.class_embed
+
+ # two stage
+ self.two_stage_type = two_stage_type
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
+ two_stage_type
+ )
+ if two_stage_type != "no":
+ if two_stage_bbox_embed_share:
+ assert dec_pred_bbox_embed_share
+ self.transformer.enc_out_bbox_embed = _bbox_embed
+ else:
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
+
+ if two_stage_class_embed_share:
+ assert dec_pred_bbox_embed_share
+ self.transformer.enc_out_class_embed = _class_embed
+ else:
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
+
+ self.refpoint_embed = None
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ # init input_proj
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ def init_ref_points(self, use_num_queries):
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
+
+ def forward(self, samples: NestedTensor, targets: List = None, **kw):
+ """The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all queries.
+ Shape= [batch_size x num_queries x num_classes]
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
+ dictionnaries containing the two above keys for each decoder layer.
+ """
+ if targets is None:
+ captions = kw["captions"]
+ else:
+ captions = [t["caption"] for t in targets]
+ len(captions)
+
+ # encoder texts
+ tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
+ samples.device
+ )
+ (
+ text_self_attention_masks,
+ position_ids,
+ cate_to_token_mask_list,
+ ) = generate_masks_with_special_tokens_and_transfer_map(
+ tokenized, self.specical_tokens, self.tokenizer
+ )
+
+ if text_self_attention_masks.shape[1] > self.max_text_len:
+ text_self_attention_masks = text_self_attention_masks[
+ :, : self.max_text_len, : self.max_text_len
+ ]
+ position_ids = position_ids[:, : self.max_text_len]
+ tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
+ tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
+ tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
+
+ # extract text embeddings
+ if self.sub_sentence_present:
+ tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
+ tokenized_for_encoder["attention_mask"] = text_self_attention_masks
+ tokenized_for_encoder["position_ids"] = position_ids
+ else:
+ # import ipdb; ipdb.set_trace()
+ tokenized_for_encoder = tokenized
+
+ bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
+
+ encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
+ text_token_mask = tokenized.attention_mask.bool() # bs, 195
+ # text_token_mask: True for nomask, False for mask
+ # text_self_attention_masks: True for nomask, False for mask
+
+ if encoded_text.shape[1] > self.max_text_len:
+ encoded_text = encoded_text[:, : self.max_text_len, :]
+ text_token_mask = text_token_mask[:, : self.max_text_len]
+ position_ids = position_ids[:, : self.max_text_len]
+ text_self_attention_masks = text_self_attention_masks[
+ :, : self.max_text_len, : self.max_text_len
+ ]
+
+ text_dict = {
+ "encoded_text": encoded_text, # bs, 195, d_model
+ "text_token_mask": text_token_mask, # bs, 195
+ "position_ids": position_ids, # bs, 195
+ "text_self_attention_masks": text_self_attention_masks, # bs, 195,195
+ }
+
+ # import ipdb; ipdb.set_trace()
+
+ if isinstance(samples, (list, torch.Tensor)):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, poss = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ poss.append(pos_l)
+
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
+ srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
+ )
+
+ # deformable-detr-like anchor update
+ outputs_coord_list = []
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
+ zip(reference[:-1], self.bbox_embed, hs)
+ ):
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+ outputs_coord_list.append(layer_outputs_unsig)
+ outputs_coord_list = torch.stack(outputs_coord_list)
+
+ # output
+ outputs_class = torch.stack(
+ [
+ layer_cls_embed(layer_hs, text_dict)
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
+ ]
+ )
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
+
+ # # for intermediate outputs
+ # if self.aux_loss:
+ # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
+
+ # # for encoder output
+ # if hs_enc is not None:
+ # # prepare intermediate outputs
+ # interm_coord = ref_enc[-1]
+ # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
+ # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
+ # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
+
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [
+ {"pred_logits": a, "pred_boxes": b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+ ]
+
+
+@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
+def build_groundingdino(args):
+
+ backbone = build_backbone(args)
+ transformer = build_transformer(args)
+
+ dn_labelbook_size = args.dn_labelbook_size
+ dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
+ sub_sentence_present = args.sub_sentence_present
+
+ model = GroundingDINO(
+ backbone,
+ transformer,
+ num_queries=args.num_queries,
+ aux_loss=True,
+ iter_update=True,
+ query_dim=4,
+ num_feature_levels=args.num_feature_levels,
+ nheads=args.nheads,
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
+ two_stage_type=args.two_stage_type,
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
+ num_patterns=args.num_patterns,
+ dn_number=0,
+ dn_box_noise_scale=args.dn_box_noise_scale,
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
+ dn_labelbook_size=dn_labelbook_size,
+ text_encoder_type=args.text_encoder_type,
+ sub_sentence_present=sub_sentence_present,
+ max_text_len=args.max_text_len,
+ )
+
+ return model
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..489d501bef364020212306d81e9b85c8daa27491
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py
@@ -0,0 +1,413 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from:
+# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
+# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
+# https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
+# ------------------------------------------------------------------------------------------------
+
+import math
+import warnings
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.init import constant_, xavier_uniform_
+
+try:
+ from groundingdino import _C
+except:
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
+
+
+# helpers
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+
+class MultiScaleDeformableAttnFunction(Function):
+ @staticmethod
+ def forward(
+ ctx,
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step,
+ ):
+ ctx.im2col_step = im2col_step
+ output = _C.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ctx.im2col_step,
+ )
+ ctx.save_for_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ (
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ) = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output,
+ ctx.im2col_step,
+ )
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def multi_scale_deformable_attn_pytorch(
+ value: torch.Tensor,
+ value_spatial_shapes: torch.Tensor,
+ sampling_locations: torch.Tensor,
+ attention_weights: torch.Tensor,
+) -> torch.Tensor:
+
+ bs, _, num_heads, embed_dims = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
+ # bs, H_*W_, num_heads, embed_dims ->
+ # bs, H_*W_, num_heads*embed_dims ->
+ # bs, num_heads*embed_dims, H_*W_ ->
+ # bs*num_heads, embed_dims, H_, W_
+ value_l_ = (
+ value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
+ )
+ # bs, num_queries, num_heads, num_points, 2 ->
+ # bs, num_heads, num_queries, num_points, 2 ->
+ # bs*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+ # bs*num_heads, embed_dims, num_queries, num_points
+ sampling_value_l_ = F.grid_sample(
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ bs * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(bs, num_heads * embed_dims, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
+
+class MultiScaleDeformableAttention(nn.Module):
+ """Multi-Scale Deformable Attention Module used in Deformable-DETR
+
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+ `_.
+
+ Args:
+ embed_dim (int): The embedding dimension of Attention. Default: 256.
+ num_heads (int): The number of attention heads. Default: 8.
+ num_levels (int): The number of feature map used in Attention. Default: 4.
+ num_points (int): The number of sampling points for each query
+ in each head. Default: 4.
+ img2col_steps (int): The step used in image_to_column. Defualt: 64.
+ dropout (float): Dropout layer used in output. Default: 0.1.
+ batch_first (bool): if ``True``, then the input and output tensor will be
+ provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 256,
+ num_heads: int = 8,
+ num_levels: int = 4,
+ num_points: int = 4,
+ img2col_step: int = 64,
+ batch_first: bool = False,
+ ):
+ super().__init__()
+ if embed_dim % num_heads != 0:
+ raise ValueError(
+ "embed_dim must be divisible by num_heads, but got {} and {}".format(
+ embed_dim, num_heads
+ )
+ )
+ head_dim = embed_dim // num_heads
+
+ self.batch_first = batch_first
+
+ if not _is_power_of_2(head_dim):
+ warnings.warn(
+ """
+ You'd better set d_model in MSDeformAttn to make sure that
+ each dim of the attention head a power of 2, which is more efficient.
+ """
+ )
+
+ self.im2col_step = img2col_step
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.num_levels = num_levels
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
+ self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
+ self.value_proj = nn.Linear(embed_dim, embed_dim)
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
+
+ self.init_weights()
+
+ def _reset_parameters(self):
+ return self.init_weights()
+
+ def init_weights(self):
+ """
+ Default initialization for Parameters of Module.
+ """
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+ 2.0 * math.pi / self.num_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.num_heads, 1, 1, 2)
+ .repeat(1, self.num_levels, self.num_points, 1)
+ )
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ def freeze_sampling_offsets(self):
+ print("Freeze sampling offsets")
+ self.sampling_offsets.weight.requires_grad = False
+ self.sampling_offsets.bias.requires_grad = False
+
+ def freeze_attention_weights(self):
+ print("Freeze attention weights")
+ self.attention_weights.weight.requires_grad = False
+ self.attention_weights.bias.requires_grad = False
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ value: Optional[torch.Tensor] = None,
+ query_pos: Optional[torch.Tensor] = None,
+ key_padding_mask: Optional[torch.Tensor] = None,
+ reference_points: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.Tensor] = None,
+ level_start_index: Optional[torch.Tensor] = None,
+ **kwargs
+ ) -> torch.Tensor:
+
+ """Forward Function of MultiScaleDeformableAttention
+
+ Args:
+ query (torch.Tensor): Query embeddings with shape
+ `(num_query, bs, embed_dim)`
+ key (torch.Tensor): Key embeddings with shape
+ `(num_key, bs, embed_dim)`
+ value (torch.Tensor): Value embeddings with shape
+ `(num_key, bs, embed_dim)`
+ query_pos (torch.Tensor): The position embedding for `query`. Default: None.
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
+ indicating which elements within `key` to be ignored in attention.
+ reference_points (torch.Tensor): The normalized reference points
+ with shape `(bs, num_query, num_levels, 2)`,
+ all elements is range in [0, 1], top-left (0, 0),
+ bottom-right (1, 1), including padding are.
+ or `(N, Length_{query}, num_levels, 4)`, add additional
+ two dimensions `(h, w)` to form reference boxes.
+ spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
+ With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
+ level_start_index (torch.Tensor): The start index of each level. A tensor with
+ shape `(num_levels, )` which can be represented as
+ `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
+
+ Returns:
+ torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
+ """
+
+ if value is None:
+ value = query
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], float(0))
+ value = value.view(bs, num_value, self.num_heads, -1)
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
+ )
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
+ )
+ attention_weights = attention_weights.softmax(-1)
+ attention_weights = attention_weights.view(
+ bs,
+ num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points,
+ )
+
+ # bs, num_query, num_heads, num_levels, num_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.num_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
+ reference_points.shape[-1]
+ )
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ halffloat = False
+ if value.dtype == torch.float16:
+ halffloat = True
+ value = value.float()
+ sampling_locations = sampling_locations.float()
+ attention_weights = attention_weights.float()
+
+ output = MultiScaleDeformableAttnFunction.apply(
+ value,
+ spatial_shapes,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+
+ if halffloat:
+ output = output.half()
+ else:
+ output = multi_scale_deformable_attn_pytorch(
+ value, spatial_shapes, sampling_locations, attention_weights
+ )
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ output = output.permute(1, 0, 2)
+
+ return output
+
+
+def create_dummy_class(klass, dependency, message=""):
+ """
+ When a dependency of a class is not available, create a dummy class which throws ImportError
+ when used.
+
+ Args:
+ klass (str): name of the class.
+ dependency (str): name of the dependency.
+ message: extra message to print
+ Returns:
+ class: a class object
+ """
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
+ if message:
+ err = err + " " + message
+
+ class _DummyMetaClass(type):
+ # throw error on class attribute access
+ def __getattr__(_, __): # noqa: B902
+ raise ImportError(err)
+
+ class _Dummy(object, metaclass=_DummyMetaClass):
+ # throw error on constructor
+ def __init__(self, *args, **kwargs):
+ raise ImportError(err)
+
+ return _Dummy
+
+
+def create_dummy_func(func, dependency, message=""):
+ """
+ When a dependency of a function is not available, create a dummy function which throws
+ ImportError when used.
+
+ Args:
+ func (str): name of the function.
+ dependency (str or list[str]): name(s) of the dependency.
+ message: extra message to print
+ Returns:
+ function: a function object
+ """
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
+ if message:
+ err = err + " " + message
+
+ if isinstance(dependency, (list, tuple)):
+ dependency = ",".join(dependency)
+
+ def _dummy(*args, **kwargs):
+ raise ImportError(err)
+
+ return _dummy
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcb8742dbdde6e80fd38b11d064211f6935aae76
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer.py
@@ -0,0 +1,959 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR Transformer class.
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+from typing import Optional
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from torch import Tensor, nn
+
+from groundingdino.util.misc import inverse_sigmoid
+
+from .fuse_modules import BiAttentionBlock
+from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
+from .transformer_vanilla import TransformerEncoderLayer
+from .utils import (
+ MLP,
+ _get_activation_fn,
+ _get_clones,
+ gen_encoder_output_proposals,
+ gen_sineembed_for_position,
+ get_sine_pos_embed,
+)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ nhead=8,
+ num_queries=300,
+ num_encoder_layers=6,
+ num_unicoder_layers=0,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.0,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ query_dim=4,
+ num_patterns=0,
+ # for deformable encoder
+ num_feature_levels=1,
+ enc_n_points=4,
+ dec_n_points=4,
+ # init query
+ learnable_tgt_init=False,
+ # two stage
+ two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
+ embed_init_tgt=False,
+ # for text
+ use_text_enhancer=False,
+ use_fusion_layer=False,
+ use_checkpoint=False,
+ use_transformer_ckpt=False,
+ use_text_cross_attention=False,
+ text_dropout=0.1,
+ fusion_dropout=0.1,
+ fusion_droppath=0.0,
+ ):
+ super().__init__()
+ self.num_feature_levels = num_feature_levels
+ self.num_encoder_layers = num_encoder_layers
+ self.num_unicoder_layers = num_unicoder_layers
+ self.num_decoder_layers = num_decoder_layers
+ self.num_queries = num_queries
+ assert query_dim == 4
+
+ # choose encoder layer type
+ encoder_layer = DeformableTransformerEncoderLayer(
+ d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
+ )
+
+ if use_text_enhancer:
+ text_enhance_layer = TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=nhead // 2,
+ dim_feedforward=dim_feedforward // 2,
+ dropout=text_dropout,
+ )
+ else:
+ text_enhance_layer = None
+
+ if use_fusion_layer:
+ feature_fusion_layer = BiAttentionBlock(
+ v_dim=d_model,
+ l_dim=d_model,
+ embed_dim=dim_feedforward // 2,
+ num_heads=nhead // 2,
+ dropout=fusion_dropout,
+ drop_path=fusion_droppath,
+ )
+ else:
+ feature_fusion_layer = None
+
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ assert encoder_norm is None
+ self.encoder = TransformerEncoder(
+ encoder_layer,
+ num_encoder_layers,
+ d_model=d_model,
+ num_queries=num_queries,
+ text_enhance_layer=text_enhance_layer,
+ feature_fusion_layer=feature_fusion_layer,
+ use_checkpoint=use_checkpoint,
+ use_transformer_ckpt=use_transformer_ckpt,
+ )
+
+ # choose decoder layer type
+ decoder_layer = DeformableTransformerDecoderLayer(
+ d_model,
+ dim_feedforward,
+ dropout,
+ activation,
+ num_feature_levels,
+ nhead,
+ dec_n_points,
+ use_text_cross_attention=use_text_cross_attention,
+ )
+
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ d_model=d_model,
+ query_dim=query_dim,
+ num_feature_levels=num_feature_levels,
+ )
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dec_layers = num_decoder_layers
+ self.num_queries = num_queries # useful for single stage model only
+ self.num_patterns = num_patterns
+ if not isinstance(num_patterns, int):
+ Warning("num_patterns should be int but {}".format(type(num_patterns)))
+ self.num_patterns = 0
+
+ if num_feature_levels > 1:
+ if self.num_encoder_layers > 0:
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+ else:
+ self.level_embed = None
+
+ self.learnable_tgt_init = learnable_tgt_init
+ assert learnable_tgt_init, "why not learnable_tgt_init"
+ self.embed_init_tgt = embed_init_tgt
+ if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
+ nn.init.normal_(self.tgt_embed.weight.data)
+ else:
+ self.tgt_embed = None
+
+ # for two stage
+ self.two_stage_type = two_stage_type
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
+ two_stage_type
+ )
+ if two_stage_type == "standard":
+ # anchor selection at the output of encoder
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+ self.two_stage_wh_embedding = None
+
+ if two_stage_type == "no":
+ self.init_ref_points(num_queries) # init self.refpoint_embed
+
+ self.enc_out_class_embed = None
+ self.enc_out_bbox_embed = None
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ if self.num_feature_levels > 1 and self.level_embed is not None:
+ nn.init.normal_(self.level_embed)
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def init_ref_points(self, use_num_queries):
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
+
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
+ """
+ Input:
+ - srcs: List of multi features [bs, ci, hi, wi]
+ - masks: List of multi masks [bs, hi, wi]
+ - refpoint_embed: [bs, num_dn, 4]. None in infer
+ - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
+ - tgt: [bs, num_dn, d_model]. None in infer
+
+ """
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
+ mask = mask.flatten(1) # bs, hw
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
+ if self.num_feature_levels > 1 and self.level_embed is not None:
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ else:
+ lvl_pos_embed = pos_embed
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
+ )
+ level_start_index = torch.cat(
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
+ )
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # two stage
+ enc_topk_proposals = enc_refpoint_embed = None
+
+ #########################################################
+ # Begin Encoder
+ #########################################################
+ memory, memory_text = self.encoder(
+ src_flatten,
+ pos=lvl_pos_embed_flatten,
+ level_start_index=level_start_index,
+ spatial_shapes=spatial_shapes,
+ valid_ratios=valid_ratios,
+ key_padding_mask=mask_flatten,
+ memory_text=text_dict["encoded_text"],
+ text_attention_mask=~text_dict["text_token_mask"],
+ # we ~ the mask . False means use the token; True means pad the token
+ position_ids=text_dict["position_ids"],
+ text_self_attention_masks=text_dict["text_self_attention_masks"],
+ )
+ #########################################################
+ # End Encoder
+ # - memory: bs, \sum{hw}, c
+ # - mask_flatten: bs, \sum{hw}
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
+ #########################################################
+ text_dict["encoded_text"] = memory_text
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
+ # if memory.isnan().any() | memory.isinf().any():
+ # import ipdb; ipdb.set_trace()
+
+ if self.two_stage_type == "standard":
+ output_memory, output_proposals = gen_encoder_output_proposals(
+ memory, mask_flatten, spatial_shapes
+ )
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+ if text_dict is not None:
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
+ else:
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
+
+ topk_logits = enc_outputs_class_unselected.max(-1)[0]
+ enc_outputs_coord_unselected = (
+ self.enc_out_bbox_embed(output_memory) + output_proposals
+ ) # (bs, \sum{hw}, 4) unsigmoid
+ topk = self.num_queries
+
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
+
+ # gather boxes
+ refpoint_embed_undetach = torch.gather(
+ enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+ ) # unsigmoid
+ refpoint_embed_ = refpoint_embed_undetach.detach()
+ init_box_proposal = torch.gather(
+ output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+ ).sigmoid() # sigmoid
+
+ # gather tgt
+ tgt_undetach = torch.gather(
+ output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
+ )
+ if self.embed_init_tgt:
+ tgt_ = (
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
+ ) # nq, bs, d_model
+ else:
+ tgt_ = tgt_undetach.detach()
+
+ if refpoint_embed is not None:
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
+ tgt = torch.cat([tgt, tgt_], dim=1)
+ else:
+ refpoint_embed, tgt = refpoint_embed_, tgt_
+
+ elif self.two_stage_type == "no":
+ tgt_ = (
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
+ ) # nq, bs, d_model
+ refpoint_embed_ = (
+ self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
+ ) # nq, bs, 4
+
+ if refpoint_embed is not None:
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
+ tgt = torch.cat([tgt, tgt_], dim=1)
+ else:
+ refpoint_embed, tgt = refpoint_embed_, tgt_
+
+ if self.num_patterns > 0:
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
+ self.num_queries, 1
+ ) # 1, n_q*n_pat, d_model
+ tgt = tgt_embed + tgt_pat
+
+ init_box_proposal = refpoint_embed_.sigmoid()
+
+ else:
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
+ #########################################################
+ # End preparing tgt
+ # - tgt: bs, NQ, d_model
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
+ #########################################################
+
+ #########################################################
+ # Begin Decoder
+ #########################################################
+ hs, references = self.decoder(
+ tgt=tgt.transpose(0, 1),
+ memory=memory.transpose(0, 1),
+ memory_key_padding_mask=mask_flatten,
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
+ level_start_index=level_start_index,
+ spatial_shapes=spatial_shapes,
+ valid_ratios=valid_ratios,
+ tgt_mask=attn_mask,
+ memory_text=text_dict["encoded_text"],
+ text_attention_mask=~text_dict["text_token_mask"],
+ # we ~ the mask . False means use the token; True means pad the token
+ )
+ #########################################################
+ # End Decoder
+ # hs: n_dec, bs, nq, d_model
+ # references: n_dec+1, bs, nq, query_dim
+ #########################################################
+
+ #########################################################
+ # Begin postprocess
+ #########################################################
+ if self.two_stage_type == "standard":
+ hs_enc = tgt_undetach.unsqueeze(0)
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
+ else:
+ hs_enc = ref_enc = None
+ #########################################################
+ # End postprocess
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
+ #########################################################
+
+ return hs, references, hs_enc, ref_enc, init_box_proposal
+ # hs: (n_dec, bs, nq, d_model)
+ # references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
+ # ref_enc: sigmoid coordinates. \
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ encoder_layer,
+ num_layers,
+ d_model=256,
+ num_queries=300,
+ enc_layer_share=False,
+ text_enhance_layer=None,
+ feature_fusion_layer=None,
+ use_checkpoint=False,
+ use_transformer_ckpt=False,
+ ):
+ """_summary_
+
+ Args:
+ encoder_layer (_type_): _description_
+ num_layers (_type_): _description_
+ norm (_type_, optional): _description_. Defaults to None.
+ d_model (int, optional): _description_. Defaults to 256.
+ num_queries (int, optional): _description_. Defaults to 300.
+ enc_layer_share (bool, optional): _description_. Defaults to False.
+
+ """
+ super().__init__()
+ # prepare layers
+ self.layers = []
+ self.text_layers = []
+ self.fusion_layers = []
+ if num_layers > 0:
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
+
+ if text_enhance_layer is not None:
+ self.text_layers = _get_clones(
+ text_enhance_layer, num_layers, layer_share=enc_layer_share
+ )
+ if feature_fusion_layer is not None:
+ self.fusion_layers = _get_clones(
+ feature_fusion_layer, num_layers, layer_share=enc_layer_share
+ )
+ else:
+ self.layers = []
+ del encoder_layer
+
+ if text_enhance_layer is not None:
+ self.text_layers = []
+ del text_enhance_layer
+ if feature_fusion_layer is not None:
+ self.fusion_layers = []
+ del feature_fusion_layer
+
+ self.query_scale = None
+ self.num_queries = num_queries
+ self.num_layers = num_layers
+ self.d_model = d_model
+
+ self.use_checkpoint = use_checkpoint
+ self.use_transformer_ckpt = use_transformer_ckpt
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(
+ self,
+ # for images
+ src: Tensor,
+ pos: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ key_padding_mask: Tensor,
+ # for texts
+ memory_text: Tensor = None,
+ text_attention_mask: Tensor = None,
+ pos_text: Tensor = None,
+ text_self_attention_masks: Tensor = None,
+ position_ids: Tensor = None,
+ ):
+ """
+ Input:
+ - src: [bs, sum(hi*wi), 256]
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
+ - spatial_shapes: h,w of each level [num_level, 2]
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
+ - valid_ratios: [bs, num_level, 2]
+ - key_padding_mask: [bs, sum(hi*wi)]
+
+ - memory_text: bs, n_text, 256
+ - text_attention_mask: bs, n_text
+ False for no padding; True for padding
+ - pos_text: bs, n_text, 256
+
+ - position_ids: bs, n_text
+ Intermedia:
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
+ Outpus:
+ - output: [bs, sum(hi*wi), 256]
+ """
+
+ output = src
+
+ # preparation and reshape
+ if self.num_layers > 0:
+ reference_points = self.get_reference_points(
+ spatial_shapes, valid_ratios, device=src.device
+ )
+
+ if self.text_layers:
+ # generate pos_text
+ bs, n_text, text_dim = memory_text.shape
+ if pos_text is None and position_ids is None:
+ pos_text = (
+ torch.arange(n_text, device=memory_text.device)
+ .float()
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .repeat(bs, 1, 1)
+ )
+ pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
+ if position_ids is not None:
+ pos_text = get_sine_pos_embed(
+ position_ids[..., None], num_pos_feats=256, exchange_xy=False
+ )
+
+ # main process
+ for layer_id, layer in enumerate(self.layers):
+ # if output.isnan().any() or memory_text.isnan().any():
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+ if self.fusion_layers:
+ if self.use_checkpoint:
+ output, memory_text = checkpoint.checkpoint(
+ self.fusion_layers[layer_id],
+ output,
+ memory_text,
+ key_padding_mask,
+ text_attention_mask,
+ )
+ else:
+ output, memory_text = self.fusion_layers[layer_id](
+ v=output,
+ l=memory_text,
+ attention_mask_v=key_padding_mask,
+ attention_mask_l=text_attention_mask,
+ )
+
+ if self.text_layers:
+ memory_text = self.text_layers[layer_id](
+ src=memory_text.transpose(0, 1),
+ src_mask=~text_self_attention_masks, # note we use ~ for mask here
+ src_key_padding_mask=text_attention_mask,
+ pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
+ ).transpose(0, 1)
+
+ # main process
+ if self.use_transformer_ckpt:
+ output = checkpoint.checkpoint(
+ layer,
+ output,
+ pos,
+ reference_points,
+ spatial_shapes,
+ level_start_index,
+ key_padding_mask,
+ )
+ else:
+ output = layer(
+ src=output,
+ pos=pos,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ key_padding_mask=key_padding_mask,
+ )
+
+ return output, memory_text
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ decoder_layer,
+ num_layers,
+ norm=None,
+ return_intermediate=False,
+ d_model=256,
+ query_dim=4,
+ num_feature_levels=1,
+ ):
+ super().__init__()
+ if num_layers > 0:
+ self.layers = _get_clones(decoder_layer, num_layers)
+ else:
+ self.layers = []
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+ assert return_intermediate, "support return_intermediate only"
+ self.query_dim = query_dim
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
+ self.num_feature_levels = num_feature_levels
+
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
+ self.query_pos_sine_scale = None
+
+ self.query_scale = None
+ self.bbox_embed = None
+ self.class_embed = None
+
+ self.d_model = d_model
+
+ self.ref_anchor_head = None
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
+ # for memory
+ level_start_index: Optional[Tensor] = None, # num_levels
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ valid_ratios: Optional[Tensor] = None,
+ # for text
+ memory_text: Optional[Tensor] = None,
+ text_attention_mask: Optional[Tensor] = None,
+ ):
+ """
+ Input:
+ - tgt: nq, bs, d_model
+ - memory: hw, bs, d_model
+ - pos: hw, bs, d_model
+ - refpoints_unsigmoid: nq, bs, 2/4
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
+ """
+ output = tgt
+
+ intermediate = []
+ reference_points = refpoints_unsigmoid.sigmoid()
+ ref_points = [reference_points]
+
+ for layer_id, layer in enumerate(self.layers):
+
+ if reference_points.shape[-1] == 4:
+ reference_points_input = (
+ reference_points[:, :, None]
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+ ) # nq, bs, nlevel, 4
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
+ query_sine_embed = gen_sineembed_for_position(
+ reference_points_input[:, :, 0, :]
+ ) # nq, bs, 256*2
+
+ # conditional query
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
+ query_pos = pos_scale * raw_query_pos
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
+ # if query_pos.isnan().any() | query_pos.isinf().any():
+ # import ipdb; ipdb.set_trace()
+
+ # main process
+ output = layer(
+ tgt=output,
+ tgt_query_pos=query_pos,
+ tgt_query_sine_embed=query_sine_embed,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ tgt_reference_points=reference_points_input,
+ memory_text=memory_text,
+ text_attention_mask=text_attention_mask,
+ memory=memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ memory_level_start_index=level_start_index,
+ memory_spatial_shapes=spatial_shapes,
+ memory_pos=pos,
+ self_attn_mask=tgt_mask,
+ cross_attn_mask=memory_mask,
+ )
+ if output.isnan().any() | output.isinf().any():
+ print(f"output layer_id {layer_id} is nan")
+ try:
+ num_nan = output.isnan().sum().item()
+ num_inf = output.isinf().sum().item()
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
+ except Exception as e:
+ print(e)
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
+ # import ipdb; ipdb.set_trace()
+
+ # iter update
+ if self.bbox_embed is not None:
+ # box_holder = self.bbox_embed(output)
+ # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
+ # new_reference_points = box_holder[..., :self.query_dim].sigmoid()
+
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
+ delta_unsig = self.bbox_embed[layer_id](output)
+ outputs_unsig = delta_unsig + reference_before_sigmoid
+ new_reference_points = outputs_unsig.sigmoid()
+
+ reference_points = new_reference_points.detach()
+ # if layer_id != self.num_layers - 1:
+ ref_points.append(new_reference_points)
+
+ intermediate.append(self.norm(output))
+
+ return [
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
+ ]
+
+
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ ):
+ super().__init__()
+
+ # self attention
+ self.self_attn = MSDeformAttn(
+ embed_dim=d_model,
+ num_levels=n_levels,
+ num_heads=n_heads,
+ num_points=n_points,
+ batch_first=True,
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(
+ self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
+ ):
+ # self attention
+ # import ipdb; ipdb.set_trace()
+ src2 = self.self_attn(
+ query=self.with_pos_embed(src, pos),
+ reference_points=reference_points,
+ value=src,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ key_padding_mask=key_padding_mask,
+ )
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ return src
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ use_text_feat_guide=False,
+ use_text_cross_attention=False,
+ ):
+ super().__init__()
+
+ # cross attention
+ self.cross_attn = MSDeformAttn(
+ embed_dim=d_model,
+ num_levels=n_levels,
+ num_heads=n_heads,
+ num_points=n_points,
+ batch_first=True,
+ )
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # cross attention text
+ if use_text_cross_attention:
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.catext_norm = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+ self.norm3 = nn.LayerNorm(d_model)
+
+ self.key_aware_proj = None
+ self.use_text_feat_guide = use_text_feat_guide
+ assert not use_text_feat_guide
+ self.use_text_cross_attention = use_text_cross_attention
+
+ def rm_self_attn_modules(self):
+ self.self_attn = None
+ self.dropout2 = None
+ self.norm2 = None
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ with torch.cuda.amp.autocast(enabled=False):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(
+ self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+ memory_text: Optional[Tensor] = None, # bs, num_token, d_model
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+ """
+ Input:
+ - tgt/tgt_query_pos: nq, bs, d_model
+ -
+ """
+ assert cross_attn_mask is None
+
+ # self attention
+ if self.self_attn is not None:
+ # import ipdb; ipdb.set_trace()
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ if self.use_text_cross_attention:
+ tgt2 = self.ca_text(
+ self.with_pos_embed(tgt, tgt_query_pos),
+ memory_text.transpose(0, 1),
+ memory_text.transpose(0, 1),
+ key_padding_mask=text_attention_mask,
+ )[0]
+ tgt = tgt + self.catext_dropout(tgt2)
+ tgt = self.catext_norm(tgt)
+
+ tgt2 = self.cross_attn(
+ query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
+ value=memory.transpose(0, 1),
+ spatial_shapes=memory_spatial_shapes,
+ level_start_index=memory_level_start_index,
+ key_padding_mask=memory_key_padding_mask,
+ ).transpose(0, 1)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+
+ return tgt
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ num_queries=args.num_queries,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ query_dim=args.query_dim,
+ activation=args.transformer_activation,
+ num_patterns=args.num_patterns,
+ num_feature_levels=args.num_feature_levels,
+ enc_n_points=args.enc_n_points,
+ dec_n_points=args.dec_n_points,
+ learnable_tgt_init=True,
+ # two stage
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
+ embed_init_tgt=args.embed_init_tgt,
+ use_text_enhancer=args.use_text_enhancer,
+ use_fusion_layer=args.use_fusion_layer,
+ use_checkpoint=args.use_checkpoint,
+ use_transformer_ckpt=args.use_transformer_ckpt,
+ use_text_cross_attention=args.use_text_cross_attention,
+ text_dropout=args.text_dropout,
+ fusion_dropout=args.fusion_dropout,
+ fusion_droppath=args.fusion_droppath,
+ )
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c0920c1a217af5bb3e1b13077568035ab3b7b5
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py
@@ -0,0 +1,123 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .utils import (
+ MLP,
+ _get_activation_fn,
+ _get_clones,
+ gen_encoder_output_proposals,
+ gen_sineembed_for_position,
+ sigmoid_focal_loss,
+)
+
+
+class TextTransformer(nn.Module):
+ def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
+ super().__init__()
+ self.num_layers = num_layers
+ self.d_model = d_model
+ self.nheads = nheads
+ self.dim_feedforward = dim_feedforward
+ self.norm = None
+
+ single_encoder_layer = TransformerEncoderLayer(
+ d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
+ )
+ self.layers = _get_clones(single_encoder_layer, num_layers)
+
+ def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
+ """
+
+ Args:
+ text_attention_mask: bs, num_token
+ memory_text: bs, num_token, d_model
+
+ Raises:
+ RuntimeError: _description_
+
+ Returns:
+ output: bs, num_token, d_model
+ """
+
+ output = memory_text.transpose(0, 1)
+
+ for layer in self.layers:
+ output = layer(output, src_key_padding_mask=text_attention_mask)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output.transpose(0, 1)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self.nhead = nhead
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ # repeat attn mask
+ if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
+ # bs, num_q, num_k
+ src_mask = src_mask.repeat(self.nhead, 1, 1)
+
+ q = k = self.with_pos_embed(src, pos)
+
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
+
+ # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/utils.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd18f70225e12b2e27fdb4eabcde91d959f8e31
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/GroundingDINO/utils.py
@@ -0,0 +1,268 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import copy
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+def _get_clones(module, N, layer_share=False):
+ # import ipdb; ipdb.set_trace()
+ if layer_share:
+ return nn.ModuleList([module for i in range(N)])
+ else:
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def get_sine_pos_embed(
+ pos_tensor: torch.Tensor,
+ num_pos_feats: int = 128,
+ temperature: int = 10000,
+ exchange_xy: bool = True,
+):
+ """generate sine position embedding from a position tensor
+ Args:
+ pos_tensor (torch.Tensor): shape: [..., n].
+ num_pos_feats (int): projected shape for each float in the tensor.
+ temperature (int): temperature in the sine/cosine function.
+ exchange_xy (bool, optional): exchange pos x and pos y. \
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
+ Returns:
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
+ """
+ scale = 2 * math.pi
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
+
+ def sine_func(x: torch.Tensor):
+ sin_x = x * scale / dim_t
+ sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
+ return sin_x
+
+ pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
+ if exchange_xy:
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
+ pos_res = torch.cat(pos_res, dim=-1)
+ return pos_res
+
+
+def gen_encoder_output_proposals(
+ memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
+):
+ """
+ Input:
+ - memory: bs, \sum{hw}, d_model
+ - memory_padding_mask: bs, \sum{hw}
+ - spatial_shapes: nlevel, 2
+ - learnedwh: 2
+ Output:
+ - output_memory: bs, \sum{hw}, d_model
+ - output_proposals: bs, \sum{hw}, 4
+ """
+ N_, S_, C_ = memory.shape
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ # import ipdb; ipdb.set_trace()
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+
+ if learnedwh is not None:
+ # import ipdb; ipdb.set_trace()
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
+ else:
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+
+ # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
+ # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ # wh = torch.ones_like(grid) / scale
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += H_ * W_
+ # import ipdb; ipdb.set_trace()
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
+ -1, keepdim=True
+ )
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+
+ # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
+ # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
+
+ return output_memory, output_proposals
+
+
+class RandomBoxPerturber:
+ def __init__(
+ self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
+ ) -> None:
+ self.noise_scale = torch.Tensor(
+ [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
+ )
+
+ def __call__(self, refanchors: Tensor) -> Tensor:
+ nq, bs, query_dim = refanchors.shape
+ device = refanchors.device
+
+ noise_raw = torch.rand_like(refanchors)
+ noise_scale = self.noise_scale.to(device)[:query_dim]
+
+ new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
+ return new_refanchors.clamp_(0, 1)
+
+
+def sigmoid_focal_loss(
+ inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
+):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default = -1 (no weighting).
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ if no_reduction:
+ return loss
+
+ return loss.mean(1).sum() / num_boxes
+
+
+class MLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def _get_activation_fn(activation, d_model=256, batch_dim=0):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ if activation == "prelu":
+ return nn.PReLU()
+ if activation == "selu":
+ return F.selu
+
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+def gen_sineembed_for_position(pos_tensor):
+ # n_query, bs, _ = pos_tensor.size()
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
+ scale = 2 * math.pi
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
+ x_embed = pos_tensor[:, :, 0] * scale
+ y_embed = pos_tensor[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+ if pos_tensor.size(-1) == 2:
+ pos = torch.cat((pos_y, pos_x), dim=2)
+ elif pos_tensor.size(-1) == 4:
+ w_embed = pos_tensor[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ h_embed = pos_tensor[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
+ return pos
+
+
+class ContrastiveEmbed(nn.Module):
+ def __init__(self, max_text_len=256):
+ """
+ Args:
+ max_text_len: max length of text.
+ """
+ super().__init__()
+ self.max_text_len = max_text_len
+
+ def forward(self, x, text_dict):
+ """_summary_
+
+ Args:
+ x (_type_): _description_
+ text_dict (_type_): _description_
+ {
+ 'encoded_text': encoded_text, # bs, 195, d_model
+ 'text_token_mask': text_token_mask, # bs, 195
+ # True for used tokens. False for padding tokens
+ }
+ Returns:
+ _type_: _description_
+ """
+ assert isinstance(text_dict, dict)
+
+ y = text_dict["encoded_text"]
+ text_token_mask = text_dict["text_token_mask"]
+
+ res = x @ y.transpose(-1, -2)
+ res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
+
+ # padding to max_text_len
+ new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
+ new_res[..., : res.shape[-1]] = res
+
+ return new_res
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3413961d1d184b99835eb1e919b052d70298bc6
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/__init__.py
@@ -0,0 +1,18 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from .GroundingDINO import build_groundingdino
+
+
+def build_model(args):
+ # we use register to maintain models from catdet6 on.
+ from .registry import MODULE_BUILD_FUNCS
+
+ assert args.modelname in MODULE_BUILD_FUNCS._module_dict
+ build_func = MODULE_BUILD_FUNCS.get(args.modelname)
+ model = build_func(args)
+ return model
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/registry.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d22a59eec79a2a19b83fa1779f2adaf5753aec6
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/models/registry.py
@@ -0,0 +1,66 @@
+# ------------------------------------------------------------------------
+# Grounding DINO
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# -*- coding: utf-8 -*-
+# @Author: Yihao Chen
+# @Date: 2021-08-16 16:03:17
+# @Last Modified by: Shilong Liu
+# @Last Modified time: 2022-01-23 15:26
+# modified from mmcv
+
+import inspect
+from functools import partial
+
+
+class Registry(object):
+ def __init__(self, name):
+ self._name = name
+ self._module_dict = dict()
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + "(name={}, items={})".format(
+ self._name, list(self._module_dict.keys())
+ )
+ return format_str
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ def get(self, key):
+ return self._module_dict.get(key, None)
+
+ def registe_with_name(self, module_name=None, force=False):
+ return partial(self.register, module_name=module_name, force=force)
+
+ def register(self, module_build_function, module_name=None, force=False):
+ """Register a module build function.
+ Args:
+ module (:obj:`nn.Module`): Module to be registered.
+ """
+ if not inspect.isfunction(module_build_function):
+ raise TypeError(
+ "module_build_function must be a function, but got {}".format(
+ type(module_build_function)
+ )
+ )
+ if module_name is None:
+ module_name = module_build_function.__name__
+ if not force and module_name in self._module_dict:
+ raise KeyError("{} is already registered in {}".format(module_name, self.name))
+ self._module_dict[module_name] = module_build_function
+
+ return module_build_function
+
+
+MODULE_BUILD_FUNCS = Registry("model build functions")
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/box_ops.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..781068d294e576954edb4bd07b6e0f30e4e1bcd9
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/box_ops.py
@@ -0,0 +1,140 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ # import ipdb; ipdb.set_trace()
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / (union + 1e-6)
+ return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ The boxes should be in [x0, y0, x1, y1] format
+
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
+ and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ # except:
+ # import ipdb; ipdb.set_trace()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / (area + 1e-6)
+
+
+# modified from torchvision to also return the union
+def box_iou_pairwise(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,2]
+ inter = wh[:, 0] * wh[:, 1] # [N]
+
+ union = area1 + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
+
+def generalized_box_iou_pairwise(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ Input:
+ - boxes1, boxes2: N,4
+ Output:
+ - giou: N, 4
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ assert boxes1.shape == boxes2.shape
+ iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
+
+ lt = torch.min(boxes1[:, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,2]
+ area = wh[:, 0] * wh[:, 1]
+
+ return iou - (area - union) / area
+
+
+def masks_to_boxes(masks):
+ """Compute the bounding boxes around the provided masks
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensors, with the boxes in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float)
+ x = torch.arange(0, w, dtype=torch.float)
+ y, x = torch.meshgrid(y, x)
+
+ x_mask = masks * x.unsqueeze(0)
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ y_mask = masks * y.unsqueeze(0)
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
+
+
+if __name__ == "__main__":
+ x = torch.rand(5, 4)
+ y = torch.rand(3, 4)
+ iou, union = box_iou(x, y)
+ import ipdb
+
+ ipdb.set_trace()
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/get_tokenlizer.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/get_tokenlizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7dcf7e95f03f95b20546b26442a94225924618b
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/get_tokenlizer.py
@@ -0,0 +1,26 @@
+from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
+
+
+def get_tokenlizer(text_encoder_type):
+ if not isinstance(text_encoder_type, str):
+ # print("text_encoder_type is not a str")
+ if hasattr(text_encoder_type, "text_encoder_type"):
+ text_encoder_type = text_encoder_type.text_encoder_type
+ elif text_encoder_type.get("text_encoder_type", False):
+ text_encoder_type = text_encoder_type.get("text_encoder_type")
+ else:
+ raise ValueError(
+ "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
+ )
+ print("final text_encoder_type: {}".format(text_encoder_type))
+
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
+ return tokenizer
+
+
+def get_pretrained_language_model(text_encoder_type):
+ if text_encoder_type == "bert-base-uncased":
+ return BertModel.from_pretrained(text_encoder_type)
+ if text_encoder_type == "roberta-base":
+ return RobertaModel.from_pretrained(text_encoder_type)
+ raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/inference.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..8168b96ca51e6e494c7c675c2f4a610e21b095d6
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/inference.py
@@ -0,0 +1,98 @@
+from typing import Tuple, List
+
+import cv2
+import numpy as np
+import supervision as sv
+import torch
+from PIL import Image
+from torchvision.ops import box_convert
+
+import groundingdino.datasets.transforms as T
+from groundingdino.models import build_model
+from groundingdino.util.misc import clean_state_dict
+from groundingdino.util.slconfig import SLConfig
+from groundingdino.util.utils import get_phrases_from_posmap
+
+
+def preprocess_caption(caption: str) -> str:
+ result = caption.lower().strip()
+ if result.endswith("."):
+ return result
+ return result + "."
+
+
+def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
+ args = SLConfig.fromfile(model_config_path)
+ args.device = device
+ model = build_model(args)
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+ model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+ model.eval()
+ return model
+
+
+def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image_source = Image.open(image_path).convert("RGB")
+ image = np.asarray(image_source)
+ image_transformed, _ = transform(image_source, None)
+ return image, image_transformed
+
+
+def predict(
+ model,
+ image: torch.Tensor,
+ caption: str,
+ box_threshold: float,
+ text_threshold: float,
+ device: str = "cuda"
+) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
+ caption = preprocess_caption(caption=caption)
+
+ model = model.to(device)
+ image = image.to(device)
+
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+
+ prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
+ prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
+
+ mask = prediction_logits.max(dim=1)[0] > box_threshold
+ logits = prediction_logits[mask] # logits.shape = (n, 256)
+ boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
+
+ tokenizer = model.tokenizer
+ tokenized = tokenizer(caption)
+
+ phrases = [
+ get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+ for logit
+ in logits
+ ]
+
+ return boxes, logits.max(dim=1)[0], phrases
+
+
+def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
+ h, w, _ = image_source.shape
+ boxes = boxes * torch.Tensor([w, h, w, h])
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+ detections = sv.Detections(xyxy=xyxy)
+
+ labels = [
+ f"{phrase} {logit:.2f}"
+ for phrase, logit
+ in zip(phrases, logits)
+ ]
+
+ box_annotator = sv.BoxAnnotator()
+ annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+ return annotated_frame
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/logger.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..18145f54c927abd59b95f3fa6e6da8002bc2ce97
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/logger.py
@@ -0,0 +1,93 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+# so that calling setup_logger multiple times won't add many handlers
+@functools.lru_cache()
+def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
+ """
+ Initialize the detectron2 logger and set its verbosity level to "INFO".
+
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logger
+
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = name
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+ return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/misc.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64b84ef24bea0c98e76824feb1903f6bfebe7a5
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/misc.py
@@ -0,0 +1,717 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import colorsys
+import datetime
+import functools
+import io
+import json
+import os
+import pickle
+import subprocess
+import time
+from collections import OrderedDict, defaultdict, deque
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+from torch import Tensor
+
+__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
+if __torchvision_need_compat_flag:
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ if d.shape[0] == 0:
+ return 0
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ if os.environ.get("SHILONG_AMP", None) == "1":
+ eps = 1e-4
+ else:
+ eps = 1e-6
+ return self.total / (self.count + eps)
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+
+ return dist.group.WORLD
+
+
+def all_gather_cpu(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ cpu_group = _get_global_gloo_group()
+
+ buffer = io.BytesIO()
+ torch.save(data, buffer)
+ data_view = buffer.getbuffer()
+ device = "cuda" if cpu_group is None else "cpu"
+ tensor = torch.ByteTensor(data_view).to(device)
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
+ size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
+ if cpu_group is None:
+ dist.all_gather(size_list, local_size)
+ else:
+ print("gathering on cpu")
+ dist.all_gather(size_list, local_size, group=cpu_group)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+ assert isinstance(local_size.item(), int)
+ local_size = int(local_size.item())
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ if cpu_group is None:
+ dist.all_gather(tensor_list, tensor)
+ else:
+ dist.all_gather(tensor_list, tensor, group=cpu_group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
+ buffer = io.BytesIO(tensor.cpu().numpy())
+ obj = torch.load(buffer)
+ data_list.append(obj)
+
+ return data_list
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+
+ if os.getenv("CPU_REDUCE") == "1":
+ return all_gather_cpu(data)
+
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ # print(name, str(meter))
+ # import ipdb;ipdb.set_trace()
+ if meter.count > 0:
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, logger=None):
+ if logger is None:
+ print_func = print
+ else:
+ print_func = logger.info
+
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ "max mem: {memory:.0f}",
+ ]
+ )
+ else:
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ )
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ # import ipdb; ipdb.set_trace()
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print_func(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print_func(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print_func(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ # import ipdb; ipdb.set_trace()
+ batch = list(zip(*batch))
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+ if mask == "auto":
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
+ if self.mask.dim() == 3:
+ self.mask = self.mask.sum(0).to(bool)
+ elif self.mask.dim() == 4:
+ self.mask = self.mask.sum(1).to(bool)
+ else:
+ raise ValueError(
+ "tensors dim must be 3 or 4 but {}({})".format(
+ self.tensors.dim(), self.tensors.shape
+ )
+ )
+
+ def imgsize(self):
+ res = []
+ for i in range(self.tensors.shape[0]):
+ mask = self.mask[i]
+ maxH = (~mask).sum(0).max()
+ maxW = (~mask).sum(1).max()
+ res.append(torch.Tensor([maxH, maxW]))
+ return res
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def to_img_list_single(self, tensor, mask):
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
+ maxH = (~mask).sum(0).max()
+ maxW = (~mask).sum(1).max()
+ img = tensor[:, :maxH, :maxW]
+ return img
+
+ def to_img_list(self):
+ """remove the padding and convert to img list
+
+ Returns:
+ [type]: [description]
+ """
+ if self.tensors.dim() == 3:
+ return self.to_img_list_single(self.tensors, self.mask)
+ else:
+ res = []
+ for i in range(self.tensors.shape[0]):
+ tensor_i = self.tensors[i]
+ mask_i = self.mask[i]
+ res.append(self.to_img_list_single(tensor_i, mask_i))
+ return res
+
+ @property
+ def device(self):
+ return self.tensors.device
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+ @property
+ def shape(self):
+ return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
+
+ # launch by torch.distributed.launch
+ # Single node
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
+ # Multi nodes
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
+ # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
+ # local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
+ # args.world_size = args.world_size * local_world_size
+ # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
+ # args.rank = args.rank * local_world_size + args.local_rank
+ print(
+ "world size: {}, rank: {}, local rank: {}".format(
+ args.world_size, args.rank, args.local_rank
+ )
+ )
+ print(json.dumps(dict(os.environ), indent=2))
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
+ args.world_size = int(os.environ["SLURM_NPROCS"])
+
+ print(
+ "world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
+ args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
+ )
+ )
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ args.world_size = 1
+ args.rank = 0
+ args.local_rank = 0
+ return
+
+ print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
+ args.distributed = True
+ torch.cuda.set_device(args.local_rank)
+ args.dist_backend = "nccl"
+ print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ world_size=args.world_size,
+ rank=args.rank,
+ init_method=args.dist_url,
+ )
+
+ print("Before torch.distributed.barrier()")
+ torch.distributed.barrier()
+ print("End torch.distributed.barrier()")
+ setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+@torch.no_grad()
+def accuracy_onehot(pred, gt):
+ """_summary_
+
+ Args:
+ pred (_type_): n, c
+ gt (_type_): n, c
+ """
+ tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
+ acc = tp / gt.shape[0] * 100
+ return acc
+
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ if __torchvision_need_compat_flag < 0.7:
+ if input.numel() > 0:
+ return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
+
+ output_shape = _output_size(2, input, size, scale_factor)
+ output_shape = list(input.shape[:-2]) + list(output_shape)
+ return _new_empty_tensor(input, output_shape)
+ else:
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class color_sys:
+ def __init__(self, num_colors) -> None:
+ self.num_colors = num_colors
+ colors = []
+ for i in np.arange(0.0, 360.0, 360.0 / num_colors):
+ hue = i / 360.0
+ lightness = (50 + np.random.rand() * 10) / 100.0
+ saturation = (90 + np.random.rand() * 10) / 100.0
+ colors.append(
+ tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
+ )
+ self.colors = colors
+
+ def __call__(self, idx):
+ return self.colors[idx]
+
+
+def inverse_sigmoid(x, eps=1e-3):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+def clean_state_dict(state_dict):
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k[:7] == "module.":
+ k = k[7:] # remove `module.`
+ new_state_dict[k] = v
+ return new_state_dict
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slconfig.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f293e3aff215a3c7c2f7d21d27853493b6ebfbc
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slconfig.py
@@ -0,0 +1,427 @@
+# ==========================================================
+# Modified from mmcv
+# ==========================================================
+import ast
+import os.path as osp
+import shutil
+import sys
+import tempfile
+from argparse import Action
+from importlib import import_module
+import platform
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+BASE_KEY = "_base_"
+DELETE_KEY = "_delete_"
+RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+class ConfigDict(Dict):
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'")
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+class SLConfig(object):
+ """
+ config files.
+ only support .py file as config now.
+
+ ref: mmcv.utils.config
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename) as f:
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError:
+ raise SyntaxError("There are syntax errors in config " f"file {filename}")
+
+ @staticmethod
+ def _file2dict(filename):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ if filename.lower().endswith(".py"):
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
+ temp_config_name = osp.basename(temp_config_file.name)
+ if platform.system() == 'Windows':
+ temp_config_file.close()
+ shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ SLConfig._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value for name, value in mod.__dict__.items() if not name.startswith("__")
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ # close temp file
+ temp_config_file.close()
+ elif filename.lower().endswith((".yml", ".yaml", ".json")):
+ from .slio import slload
+
+ cfg_dict = slload(filename)
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+
+ cfg_text = filename + "\n"
+ with open(filename, "r") as f:
+ cfg_text += f.read()
+
+ # parse the base file
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError("Duplicate key is not allowed among bases")
+ # TODO Allow the duplicate key while warnning user
+ base_cfg_dict.update(c)
+
+ base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = "\n".join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b):
+ """merge dict `a` into dict `b` (non-inplace).
+ values in `a` will overwrite `b`.
+ copy first to avoid inplace modification
+
+ Args:
+ a ([type]): [description]
+ b ([type]): [description]
+
+ Returns:
+ [dict]: [description]
+ """
+ # import ipdb; ipdb.set_trace()
+ if not isinstance(a, dict):
+ return a
+
+ b = b.copy()
+ for k, v in a.items():
+ if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
+
+ if not isinstance(b[k], dict) and not isinstance(b[k], list):
+ # if :
+ # import ipdb; ipdb.set_trace()
+ raise TypeError(
+ f"{k}={v} in child config cannot inherit from base "
+ f"because {k} is a dict in the child config but is of "
+ f"type {type(b[k])} in base config. You may set "
+ f"`{DELETE_KEY}=True` to ignore the base config"
+ )
+ b[k] = SLConfig._merge_a_into_b(v, b[k])
+ elif isinstance(b, list):
+ try:
+ _ = int(k)
+ except:
+ raise TypeError(
+ f"b is a list, " f"index {k} should be an int when input but {type(k)}"
+ )
+ b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
+ else:
+ b[k] = v
+
+ return b
+
+ @staticmethod
+ def fromfile(filename):
+ cfg_dict, cfg_text = SLConfig._file2dict(filename)
+ return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f"{key} is reserved for config file")
+
+ super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
+ super(SLConfig, self).__setattr__("_filename", filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, "r") as f:
+ text = f.read()
+ else:
+ text = ""
+ super(SLConfig, self).__setattr__("_text", text)
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split("\n")
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = "[\n"
+ v_str += "\n".join(
+ f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
+ ).rstrip(",")
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: {v_str}"
+ else:
+ attr_str = f"{str(k)}={v_str}"
+ attr_str = _indent(attr_str, indent) + "]"
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= not str(key_name).isidentifier()
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ""
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += "{"
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = "" if outest_level or is_last else ","
+ if isinstance(v, dict):
+ v_str = "\n" + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f"{k_str}: dict({v_str}"
+ else:
+ attr_str = f"{str(k)}=dict({v_str}"
+ attr_str = _indent(attr_str, indent) + ")" + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += "\n".join(s)
+ if use_mapping:
+ r += "}"
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style="pep8",
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True,
+ )
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+ def __repr__(self):
+ return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ # # debug
+ # print('+'*15)
+ # print('name=%s' % name)
+ # print("addr:", id(self))
+ # # print('type(self):', type(self))
+ # print(self.__dict__)
+ # print('+'*15)
+ # if self.__dict__ == {}:
+ # raise ValueError
+
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def dump(self, file=None):
+ # import ipdb; ipdb.set_trace()
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, "w") as f:
+ f.write(self.pretty_text)
+
+ def merge_from_dict(self, options):
+ """Merge list into cfg_dict
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+
+ Args:
+ options (dict): dict of configs to merge from.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split(".")
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
+ super(SLConfig, self).__setattr__(
+ "_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)
+ )
+
+ # for multiprocess
+ def __setstate__(self, state):
+ self.__init__(state)
+
+ def copy(self):
+ return SLConfig(self._cfg_dict.copy())
+
+ def deepcopy(self):
+ return SLConfig(self._cfg_dict.deepcopy())
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options should
+ be passed as comma separated values, i.e KEY=V1,V2,V3
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ["true", "false"]:
+ return True if val.lower() == "true" else False
+ if val.lower() in ["none", "null"]:
+ return None
+ return val
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split("=", maxsplit=1)
+ val = [self._parse_int_float_bool(v) for v in val.split(",")]
+ if len(val) == 1:
+ val = val[0]
+ options[key] = val
+ setattr(namespace, self.dest, options)
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slio.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slio.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c1f0f7b82cdc931d381feef64fe15815ba657e
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/slio.py
@@ -0,0 +1,177 @@
+# ==========================================================
+# Modified from mmcv
+# ==========================================================
+
+import json
+import pickle
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+
+import yaml
+
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+
+# ===========================
+# Rigister handler
+# ===========================
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath, mode="r", **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath, mode="w", **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
+
+
+class JsonHandler(BaseFileHandler):
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ return json.dumps(obj, **kwargs)
+
+
+class PickleHandler(BaseFileHandler):
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault("protocol", 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault("protocol", 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
+
+
+class YamlHandler(BaseFileHandler):
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault("Loader", Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault("Dumper", Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault("Dumper", Dumper)
+ return yaml.dump(obj, **kwargs)
+
+
+file_handlers = {
+ "json": JsonHandler(),
+ "yaml": YamlHandler(),
+ "yml": YamlHandler(),
+ "pickle": PickleHandler(),
+ "pkl": PickleHandler(),
+}
+
+# ===========================
+# load and dump
+# ===========================
+
+
+def is_str(x):
+ """Whether the input is an string instance.
+
+ Note: This method is deprecated since python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+
+def slload(file, file_format=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and is_str(file):
+ file_format = file.split(".")[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f"Unsupported format: {file_format}")
+
+ handler = file_handlers[file_format]
+ if is_str(file):
+ obj = handler.load_from_path(file, **kwargs)
+ elif hasattr(file, "read"):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def sldump(obj, file=None, file_format=None, **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dump to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+
+ Returns:
+ bool: True for success, False otherwise.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if is_str(file):
+ file_format = file.split(".")[-1]
+ elif file is None:
+ raise ValueError("file_format must be specified since file is None")
+ if file_format not in file_handlers:
+ raise TypeError(f"Unsupported format: {file_format}")
+
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif is_str(file):
+ handler.dump_to_path(obj, file, **kwargs)
+ elif hasattr(file, "write"):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/time_counter.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/time_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aedb2e4d61bfbe7571dca9d50053f0fedaa1359
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/time_counter.py
@@ -0,0 +1,62 @@
+import json
+import time
+
+
+class TimeCounter:
+ def __init__(self) -> None:
+ pass
+
+ def clear(self):
+ self.timedict = {}
+ self.basetime = time.perf_counter()
+
+ def timeit(self, name):
+ nowtime = time.perf_counter() - self.basetime
+ self.timedict[name] = nowtime
+ self.basetime = time.perf_counter()
+
+
+class TimeHolder:
+ def __init__(self) -> None:
+ self.timedict = {}
+
+ def update(self, _timedict: dict):
+ for k, v in _timedict.items():
+ if k not in self.timedict:
+ self.timedict[k] = AverageMeter(name=k, val_only=True)
+ self.timedict[k].update(val=v)
+
+ def final_res(self):
+ return {k: v.avg for k, v in self.timedict.items()}
+
+ def __str__(self):
+ return json.dumps(self.final_res(), indent=2)
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, fmt=":f", val_only=False):
+ self.name = name
+ self.fmt = fmt
+ self.val_only = val_only
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ if self.val_only:
+ fmtstr = "{name} {val" + self.fmt + "}"
+ else:
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+ return fmtstr.format(**self.__dict__)
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/utils.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f0318e306fa04bff0ada70486b41aaa69b07c8
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/utils.py
@@ -0,0 +1,608 @@
+import argparse
+import json
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from groundingdino.util.slconfig import SLConfig
+
+
+def slprint(x, name="x"):
+ if isinstance(x, (torch.Tensor, np.ndarray)):
+ print(f"{name}.shape:", x.shape)
+ elif isinstance(x, (tuple, list)):
+ print("type x:", type(x))
+ for i in range(min(10, len(x))):
+ slprint(x[i], f"{name}[{i}]")
+ elif isinstance(x, dict):
+ for k, v in x.items():
+ slprint(v, f"{name}[{k}]")
+ else:
+ print(f"{name}.type:", type(x))
+
+
+def clean_state_dict(state_dict):
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k[:7] == "module.":
+ k = k[7:] # remove `module.`
+ new_state_dict[k] = v
+ return new_state_dict
+
+
+def renorm(
+ img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+) -> torch.FloatTensor:
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
+ # return: same as img
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
+ if img.dim() == 3:
+ assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
+ img.size(0),
+ str(img.size()),
+ )
+ img_perm = img.permute(1, 2, 0)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(2, 0, 1)
+ else: # img.dim() == 4
+ assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
+ img.size(1),
+ str(img.size()),
+ )
+ img_perm = img.permute(0, 2, 3, 1)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(0, 3, 1, 2)
+
+
+class CocoClassMapper:
+ def __init__(self) -> None:
+ self.category_map_str = {
+ "1": 1,
+ "2": 2,
+ "3": 3,
+ "4": 4,
+ "5": 5,
+ "6": 6,
+ "7": 7,
+ "8": 8,
+ "9": 9,
+ "10": 10,
+ "11": 11,
+ "13": 12,
+ "14": 13,
+ "15": 14,
+ "16": 15,
+ "17": 16,
+ "18": 17,
+ "19": 18,
+ "20": 19,
+ "21": 20,
+ "22": 21,
+ "23": 22,
+ "24": 23,
+ "25": 24,
+ "27": 25,
+ "28": 26,
+ "31": 27,
+ "32": 28,
+ "33": 29,
+ "34": 30,
+ "35": 31,
+ "36": 32,
+ "37": 33,
+ "38": 34,
+ "39": 35,
+ "40": 36,
+ "41": 37,
+ "42": 38,
+ "43": 39,
+ "44": 40,
+ "46": 41,
+ "47": 42,
+ "48": 43,
+ "49": 44,
+ "50": 45,
+ "51": 46,
+ "52": 47,
+ "53": 48,
+ "54": 49,
+ "55": 50,
+ "56": 51,
+ "57": 52,
+ "58": 53,
+ "59": 54,
+ "60": 55,
+ "61": 56,
+ "62": 57,
+ "63": 58,
+ "64": 59,
+ "65": 60,
+ "67": 61,
+ "70": 62,
+ "72": 63,
+ "73": 64,
+ "74": 65,
+ "75": 66,
+ "76": 67,
+ "77": 68,
+ "78": 69,
+ "79": 70,
+ "80": 71,
+ "81": 72,
+ "82": 73,
+ "84": 74,
+ "85": 75,
+ "86": 76,
+ "87": 77,
+ "88": 78,
+ "89": 79,
+ "90": 80,
+ }
+ self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
+ self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}
+
+ def origin2compact(self, idx):
+ return self.origin2compact_mapper[int(idx)]
+
+ def compact2origin(self, idx):
+ return self.compact2origin_mapper[int(idx)]
+
+
+def to_device(item, device):
+ if isinstance(item, torch.Tensor):
+ return item.to(device)
+ elif isinstance(item, list):
+ return [to_device(i, device) for i in item]
+ elif isinstance(item, dict):
+ return {k: to_device(v, device) for k, v in item.items()}
+ else:
+ raise NotImplementedError(
+ "Call Shilong if you use other containers! type: {}".format(type(item))
+ )
+
+
+#
+def get_gaussian_mean(x, axis, other_axis, softmax=True):
+ """
+
+ Args:
+ x (float): Input images(BxCxHxW)
+ axis (int): The index for weighted mean
+ other_axis (int): The other index
+
+ Returns: weighted index for axis, BxC
+
+ """
+ mat2line = torch.sum(x, axis=other_axis)
+ # mat2line = mat2line / mat2line.mean() * 10
+ if softmax:
+ u = torch.softmax(mat2line, axis=2)
+ else:
+ u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
+ size = x.shape[axis]
+ ind = torch.linspace(0, 1, size).to(x.device)
+ batch = x.shape[0]
+ channel = x.shape[1]
+ index = ind.repeat([batch, channel, 1])
+ mean_position = torch.sum(index * u, dim=2)
+ return mean_position
+
+
+def get_expected_points_from_map(hm, softmax=True):
+ """get_gaussian_map_from_points
+ B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
+ softargmax function
+
+ Args:
+ hm (float): Input images(BxCxHxW)
+
+ Returns:
+ weighted index for axis, BxCx2. float between 0 and 1.
+
+ """
+ # hm = 10*hm
+ B, C, H, W = hm.shape
+ y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
+ x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
+ # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
+ return torch.stack([x_mean, y_mean], dim=2)
+
+
+# Positional encoding (section 5.1)
+# borrow from nerf
+class Embedder:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs["input_dims"]
+ out_dim = 0
+ if self.kwargs["include_input"]:
+ embed_fns.append(lambda x: x)
+ out_dim += d
+
+ max_freq = self.kwargs["max_freq_log2"]
+ N_freqs = self.kwargs["num_freqs"]
+
+ if self.kwargs["log_sampling"]:
+ freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
+ else:
+ freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs["periodic_fns"]:
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+
+def get_embedder(multires, i=0):
+ import torch.nn as nn
+
+ if i == -1:
+ return nn.Identity(), 3
+
+ embed_kwargs = {
+ "include_input": True,
+ "input_dims": 3,
+ "max_freq_log2": multires - 1,
+ "num_freqs": multires,
+ "log_sampling": True,
+ "periodic_fns": [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+ embed = lambda x, eo=embedder_obj: eo.embed(x)
+ return embed, embedder_obj.out_dim
+
+
+class APOPMeter:
+ def __init__(self) -> None:
+ self.tp = 0
+ self.fp = 0
+ self.tn = 0
+ self.fn = 0
+
+ def update(self, pred, gt):
+ """
+ Input:
+ pred, gt: Tensor()
+ """
+ assert pred.shape == gt.shape
+ self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
+ self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
+ self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
+ self.tn += torch.logical_and(pred == 1, gt == 0).sum().item()
+
+ def update_cm(self, tp, fp, tn, fn):
+ self.tp += tp
+ self.fp += fp
+ self.tn += tn
+ self.tn += fn
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+def get_raw_dict(args):
+ """
+ return the dicf contained in args.
+
+ e.g:
+ >>> with open(path, 'w') as f:
+ json.dump(get_raw_dict(args), f, indent=2)
+ """
+ if isinstance(args, argparse.Namespace):
+ return vars(args)
+ elif isinstance(args, dict):
+ return args
+ elif isinstance(args, SLConfig):
+ return args._cfg_dict
+ else:
+ raise NotImplementedError("Unknown type {}".format(type(args)))
+
+
+def stat_tensors(tensor):
+ assert tensor.dim() == 1
+ tensor_sm = tensor.softmax(0)
+ entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
+
+ return {
+ "max": tensor.max(),
+ "min": tensor.min(),
+ "mean": tensor.mean(),
+ "var": tensor.var(),
+ "std": tensor.var() ** 0.5,
+ "entropy": entropy,
+ }
+
+
+class NiceRepr:
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
+ objects.
+
+ Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+ If the inheriting class has a ``__len__``, method then the default
+ ``__nice__`` method will return its length.
+
+ Example:
+ >>> class Foo(NiceRepr):
+ ... def __nice__(self):
+ ... return 'info'
+ >>> foo = Foo()
+ >>> assert str(foo) == ''
+ >>> assert repr(foo).startswith('>> class Bar(NiceRepr):
+ ... pass
+ >>> bar = Bar()
+ >>> import pytest
+ >>> with pytest.warns(None) as record:
+ >>> assert 'object at' in str(bar)
+ >>> assert 'object at' in repr(bar)
+
+ Example:
+ >>> class Baz(NiceRepr):
+ ... def __len__(self):
+ ... return 5
+ >>> baz = Baz()
+ >>> assert str(baz) == ''
+ """
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this module"""
+ if hasattr(self, "__len__"):
+ # It is a common pattern for objects to use __len__ in __nice__
+ # As a convenience we define a default __nice__ for these objects
+ return str(len(self))
+ else:
+ # In all other cases force the subclass to overload __nice__
+ raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
+
+ def __repr__(self):
+ """str: the string of the module"""
+ try:
+ nice = self.__nice__()
+ classname = self.__class__.__name__
+ return f"<{classname}({nice}) at {hex(id(self))}>"
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+ def __str__(self):
+ """str: the string of the module"""
+ try:
+ classname = self.__class__.__name__
+ nice = self.__nice__()
+ return f"<{classname}({nice})>"
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+
+def ensure_rng(rng=None):
+ """Coerces input into a random number generator.
+
+ If the input is None, then a global random state is returned.
+
+ If the input is a numeric value, then that is used as a seed to construct a
+ random state. Otherwise the input is returned as-is.
+
+ Adapted from [1]_.
+
+ Args:
+ rng (int | numpy.random.RandomState | None):
+ if None, then defaults to the global rng. Otherwise this can be an
+ integer or a RandomState class
+ Returns:
+ (numpy.random.RandomState) : rng -
+ a numpy random number generator
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
+ """
+
+ if rng is None:
+ rng = np.random.mtrand._rand
+ elif isinstance(rng, int):
+ rng = np.random.RandomState(rng)
+ else:
+ rng = rng
+ return rng
+
+
+def random_boxes(num=1, scale=1, rng=None):
+ """Simple version of ``kwimage.Boxes.random``
+
+ Returns:
+ Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+ References:
+ https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+ Example:
+ >>> num = 3
+ >>> scale = 512
+ >>> rng = 0
+ >>> boxes = random_boxes(num, scale, rng)
+ >>> print(boxes)
+ tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+ [216.9113, 330.6978, 224.0446, 456.5878],
+ [405.3632, 196.3221, 493.3953, 270.7942]])
+ """
+ rng = ensure_rng(rng)
+
+ tlbr = rng.rand(num, 4).astype(np.float32)
+
+ tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
+ tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
+ br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
+ br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
+
+ tlbr[:, 0] = tl_x * scale
+ tlbr[:, 1] = tl_y * scale
+ tlbr[:, 2] = br_x * scale
+ tlbr[:, 3] = br_y * scale
+
+ boxes = torch.from_numpy(tlbr)
+ return boxes
+
+
+class ModelEma(torch.nn.Module):
+ def __init__(self, model, decay=0.9997, device=None):
+ super(ModelEma, self).__init__()
+ # make a copy of the model for accumulating moving average of weights
+ self.module = deepcopy(model)
+ self.module.eval()
+
+ # import ipdb; ipdb.set_trace()
+
+ self.decay = decay
+ self.device = device # perform ema on different device from model if set
+ if self.device is not None:
+ self.module.to(device=device)
+
+ def _update(self, model, update_fn):
+ with torch.no_grad():
+ for ema_v, model_v in zip(
+ self.module.state_dict().values(), model.state_dict().values()
+ ):
+ if self.device is not None:
+ model_v = model_v.to(device=self.device)
+ ema_v.copy_(update_fn(ema_v, model_v))
+
+ def update(self, model):
+ self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)
+
+ def set(self, model):
+ self._update(model, update_fn=lambda e, m: m)
+
+
+class BestMetricSingle:
+ def __init__(self, init_res=0.0, better="large") -> None:
+ self.init_res = init_res
+ self.best_res = init_res
+ self.best_ep = -1
+
+ self.better = better
+ assert better in ["large", "small"]
+
+ def isbetter(self, new_res, old_res):
+ if self.better == "large":
+ return new_res > old_res
+ if self.better == "small":
+ return new_res < old_res
+
+ def update(self, new_res, ep):
+ if self.isbetter(new_res, self.best_res):
+ self.best_res = new_res
+ self.best_ep = ep
+ return True
+ return False
+
+ def __str__(self) -> str:
+ return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def summary(self) -> dict:
+ return {
+ "best_res": self.best_res,
+ "best_ep": self.best_ep,
+ }
+
+
+class BestMetricHolder:
+ def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
+ self.best_all = BestMetricSingle(init_res, better)
+ self.use_ema = use_ema
+ if use_ema:
+ self.best_ema = BestMetricSingle(init_res, better)
+ self.best_regular = BestMetricSingle(init_res, better)
+
+ def update(self, new_res, epoch, is_ema=False):
+ """
+ return if the results is the best.
+ """
+ if not self.use_ema:
+ return self.best_all.update(new_res, epoch)
+ else:
+ if is_ema:
+ self.best_ema.update(new_res, epoch)
+ return self.best_all.update(new_res, epoch)
+ else:
+ self.best_regular.update(new_res, epoch)
+ return self.best_all.update(new_res, epoch)
+
+ def summary(self):
+ if not self.use_ema:
+ return self.best_all.summary()
+
+ res = {}
+ res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
+ res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
+ res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
+ return res
+
+ def __repr__(self) -> str:
+ return json.dumps(self.summary(), indent=2)
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
+
+def targets_to(targets: List[Dict[str, Any]], device):
+ """Moves the target dicts to the given device."""
+ excluded_keys = [
+ "questionId",
+ "tokens_positive",
+ "strings_positive",
+ "tokens",
+ "dataset_name",
+ "sentence_id",
+ "original_img_id",
+ "nb_eval",
+ "task_id",
+ "original_id",
+ "token_span",
+ "caption",
+ "dataset_type",
+ ]
+ return [
+ {k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets
+ ]
+
+
+def get_phrases_from_posmap(
+ posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
+):
+ assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
+ if posmap.dim() == 1:
+ non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
+ token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
+ return tokenizer.decode(token_ids)
+ else:
+ raise NotImplementedError("posmap must be 1-dim")
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/visualizer.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1b7b101e9b73f75f9136bc67f2063c7c1cf1c1
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/visualizer.py
@@ -0,0 +1,318 @@
+# -*- coding: utf-8 -*-
+"""
+@File : visualizer.py
+@Time : 2022/04/05 11:39:33
+@Author : Shilong Liu
+@Contact : slongliu86@gmail.com
+"""
+
+import datetime
+import os
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib import transforms
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+from pycocotools import mask as maskUtils
+
+
+def renorm(
+ img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+) -> torch.FloatTensor:
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
+ # return: same as img
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
+ if img.dim() == 3:
+ assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
+ img.size(0),
+ str(img.size()),
+ )
+ img_perm = img.permute(1, 2, 0)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(2, 0, 1)
+ else: # img.dim() == 4
+ assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
+ img.size(1),
+ str(img.size()),
+ )
+ img_perm = img.permute(0, 2, 3, 1)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(0, 3, 1, 2)
+
+
+class ColorMap:
+ def __init__(self, basergb=[255, 255, 0]):
+ self.basergb = np.array(basergb)
+
+ def __call__(self, attnmap):
+ # attnmap: h, w. np.uint8.
+ # return: h, w, 4. np.uint8.
+ assert attnmap.dtype == np.uint8
+ h, w = attnmap.shape
+ res = self.basergb.copy()
+ res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3
+ attn1 = attnmap.copy()[..., None] # h, w, 1
+ res = np.concatenate((res, attn1), axis=-1).astype(np.uint8)
+ return res
+
+
+def rainbow_text(x, y, ls, lc, **kw):
+ """
+ Take a list of strings ``ls`` and colors ``lc`` and place them next to each
+ other, with text ls[i] being shown in color lc[i].
+
+ This example shows how to do both vertical and horizontal text, and will
+ pass all keyword arguments to plt.text, so you can set the font size,
+ family, etc.
+ """
+ t = plt.gca().transData
+ fig = plt.gcf()
+ plt.show()
+
+ # horizontal version
+ for s, c in zip(ls, lc):
+ text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
+ text.draw(fig.canvas.get_renderer())
+ ex = text.get_window_extent()
+ t = transforms.offset_copy(text._transform, x=ex.width, units="dots")
+
+ # #vertical version
+ # for s,c in zip(ls,lc):
+ # text = plt.text(x,y," "+s+" ",color=c, transform=t,
+ # rotation=90,va='bottom',ha='center',**kw)
+ # text.draw(fig.canvas.get_renderer())
+ # ex = text.get_window_extent()
+ # t = transforms.offset_copy(text._transform, y=ex.height, units='dots')
+
+
+class COCOVisualizer:
+ def __init__(self, coco=None, tokenlizer=None) -> None:
+ self.coco = coco
+
+ def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
+ """
+ img: tensor(3, H, W)
+ tgt: make sure they are all on cpu.
+ must have items: 'image_id', 'boxes', 'size'
+ """
+ plt.figure(dpi=dpi)
+ plt.rcParams["font.size"] = "5"
+ ax = plt.gca()
+ img = renorm(img).permute(1, 2, 0)
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+ ax.imshow(img)
+
+ self.addtgt(tgt)
+
+ if tgt is None:
+ image_id = 0
+ elif "image_id" not in tgt:
+ image_id = 0
+ else:
+ image_id = tgt["image_id"]
+
+ if caption is None:
+ savename = "{}/{}-{}.png".format(
+ savedir, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
+ )
+ else:
+ savename = "{}/{}-{}-{}.png".format(
+ savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
+ )
+ print("savename: {}".format(savename))
+ os.makedirs(os.path.dirname(savename), exist_ok=True)
+ plt.savefig(savename)
+ plt.close()
+
+ def addtgt(self, tgt):
+ """ """
+ if tgt is None or not "boxes" in tgt:
+ ax = plt.gca()
+
+ if "caption" in tgt:
+ ax.set_title(tgt["caption"], wrap=True)
+
+ ax.set_axis_off()
+ return
+
+ ax = plt.gca()
+ H, W = tgt["size"]
+ numbox = tgt["boxes"].shape[0]
+
+ color = []
+ polygons = []
+ boxes = []
+ for box in tgt["boxes"].cpu():
+ unnormbbox = box * torch.Tensor([W, H, W, H])
+ unnormbbox[:2] -= unnormbbox[2:] / 2
+ [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
+ boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
+ poly = [
+ [bbox_x, bbox_y],
+ [bbox_x, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y],
+ ]
+ np_poly = np.array(poly).reshape((4, 2))
+ polygons.append(Polygon(np_poly))
+ c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
+ color.append(c)
+
+ p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
+ ax.add_collection(p)
+ p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
+ ax.add_collection(p)
+
+ if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
+ assert (
+ len(tgt["strings_positive"]) == numbox
+ ), f"{len(tgt['strings_positive'])} = {numbox}, "
+ for idx, strlist in enumerate(tgt["strings_positive"]):
+ cate_id = int(tgt["labels"][idx])
+ _string = str(cate_id) + ":" + " ".join(strlist)
+ bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
+ # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
+ ax.text(
+ bbox_x,
+ bbox_y,
+ _string,
+ color="black",
+ bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
+ )
+
+ if "box_label" in tgt:
+ assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
+ for idx, bl in enumerate(tgt["box_label"]):
+ _string = str(bl)
+ bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
+ # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
+ ax.text(
+ bbox_x,
+ bbox_y,
+ _string,
+ color="black",
+ bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
+ )
+
+ if "caption" in tgt:
+ ax.set_title(tgt["caption"], wrap=True)
+ # plt.figure()
+ # rainbow_text(0.0,0.0,"all unicorns poop rainbows ! ! !".split(),
+ # ['red', 'orange', 'brown', 'green', 'blue', 'purple', 'black'])
+
+ if "attn" in tgt:
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+ if isinstance(tgt["attn"], tuple):
+ tgt["attn"] = [tgt["attn"]]
+ for item in tgt["attn"]:
+ attn_map, basergb = item
+ attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
+ attn_map = (attn_map * 255).astype(np.uint8)
+ cm = ColorMap(basergb)
+ heatmap = cm(attn_map)
+ ax.imshow(heatmap)
+ ax.set_axis_off()
+
+ def showAnns(self, anns, draw_bbox=False):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ if "segmentation" in anns[0] or "keypoints" in anns[0]:
+ datasetType = "instances"
+ elif "caption" in anns[0]:
+ datasetType = "captions"
+ else:
+ raise Exception("datasetType not supported")
+ if datasetType == "instances":
+ ax = plt.gca()
+ ax.set_autoscale_on(False)
+ polygons = []
+ color = []
+ for ann in anns:
+ c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
+ if "segmentation" in ann:
+ if type(ann["segmentation"]) == list:
+ # polygon
+ for seg in ann["segmentation"]:
+ poly = np.array(seg).reshape((int(len(seg) / 2), 2))
+ polygons.append(Polygon(poly))
+ color.append(c)
+ else:
+ # mask
+ t = self.imgs[ann["image_id"]]
+ if type(ann["segmentation"]["counts"]) == list:
+ rle = maskUtils.frPyObjects(
+ [ann["segmentation"]], t["height"], t["width"]
+ )
+ else:
+ rle = [ann["segmentation"]]
+ m = maskUtils.decode(rle)
+ img = np.ones((m.shape[0], m.shape[1], 3))
+ if ann["iscrowd"] == 1:
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
+ if ann["iscrowd"] == 0:
+ color_mask = np.random.random((1, 3)).tolist()[0]
+ for i in range(3):
+ img[:, :, i] = color_mask[i]
+ ax.imshow(np.dstack((img, m * 0.5)))
+ if "keypoints" in ann and type(ann["keypoints"]) == list:
+ # turn skeleton into zero-based index
+ sks = np.array(self.loadCats(ann["category_id"])[0]["skeleton"]) - 1
+ kp = np.array(ann["keypoints"])
+ x = kp[0::3]
+ y = kp[1::3]
+ v = kp[2::3]
+ for sk in sks:
+ if np.all(v[sk] > 0):
+ plt.plot(x[sk], y[sk], linewidth=3, color=c)
+ plt.plot(
+ x[v > 0],
+ y[v > 0],
+ "o",
+ markersize=8,
+ markerfacecolor=c,
+ markeredgecolor="k",
+ markeredgewidth=2,
+ )
+ plt.plot(
+ x[v > 1],
+ y[v > 1],
+ "o",
+ markersize=8,
+ markerfacecolor=c,
+ markeredgecolor=c,
+ markeredgewidth=2,
+ )
+
+ if draw_bbox:
+ [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
+ poly = [
+ [bbox_x, bbox_y],
+ [bbox_x, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y],
+ ]
+ np_poly = np.array(poly).reshape((4, 2))
+ polygons.append(Polygon(np_poly))
+ color.append(c)
+
+ # p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
+ # ax.add_collection(p)
+ p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
+ ax.add_collection(p)
+ elif datasetType == "captions":
+ for ann in anns:
+ print(ann["caption"])
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/vl_utils.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/vl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91bb02f584398f08a28e6b7719e2b99f6e28616
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/util/vl_utils.py
@@ -0,0 +1,100 @@
+import os
+import random
+from typing import List
+
+import torch
+
+
+def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
+ """construct a map such that positive_map[i,j] = True iff box i is associated to token j
+ Input:
+ - tokenized:
+ - input_ids: Tensor[1, ntokens]
+ - attention_mask: Tensor[1, ntokens]
+ - token_span: list with length num_boxes.
+ - each item: [start_idx, end_idx]
+ """
+ positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
+ for j, tok_list in enumerate(token_span):
+ for (beg, end) in tok_list:
+ beg_pos = tokenized.char_to_token(beg)
+ end_pos = tokenized.char_to_token(end - 1)
+ if beg_pos is None:
+ try:
+ beg_pos = tokenized.char_to_token(beg + 1)
+ if beg_pos is None:
+ beg_pos = tokenized.char_to_token(beg + 2)
+ except:
+ beg_pos = None
+ if end_pos is None:
+ try:
+ end_pos = tokenized.char_to_token(end - 2)
+ if end_pos is None:
+ end_pos = tokenized.char_to_token(end - 3)
+ except:
+ end_pos = None
+ if beg_pos is None or end_pos is None:
+ continue
+
+ assert beg_pos is not None and end_pos is not None
+ if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
+ positive_map[j, beg_pos] = 1
+ break
+ else:
+ positive_map[j, beg_pos : end_pos + 1].fill_(1)
+
+ return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
+
+
+def build_captions_and_token_span(cat_list, force_lowercase):
+ """
+ Return:
+ captions: str
+ cat2tokenspan: dict
+ {
+ 'dog': [[0, 2]],
+ ...
+ }
+ """
+
+ cat2tokenspan = {}
+ captions = ""
+ for catname in cat_list:
+ class_name = catname
+ if force_lowercase:
+ class_name = class_name.lower()
+ if "/" in class_name:
+ class_name_list: List = class_name.strip().split("/")
+ class_name_list.append(class_name)
+ class_name: str = random.choice(class_name_list)
+
+ tokens_positive_i = []
+ subnamelist = [i.strip() for i in class_name.strip().split(" ")]
+ for subname in subnamelist:
+ if len(subname) == 0:
+ continue
+ if len(captions) > 0:
+ captions = captions + " "
+ strat_idx = len(captions)
+ end_idx = strat_idx + len(subname)
+ tokens_positive_i.append([strat_idx, end_idx])
+ captions = captions + subname
+
+ if len(tokens_positive_i) > 0:
+ captions = captions + " ."
+ cat2tokenspan[class_name] = tokens_positive_i
+
+ return captions, cat2tokenspan
+
+
+def build_id2posspan_and_caption(category_dict: dict):
+ """Build id2pos_span and caption from category_dict
+
+ Args:
+ category_dict (dict): category_dict
+ """
+ cat_list = [item["name"].lower() for item in category_dict]
+ id2catname = {item["id"]: item["name"].lower() for item in category_dict}
+ caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
+ id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
+ return id2posspan, caption
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/version.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..b794fd409a5e3b3b65ad76a43d6a01a318877640
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/version.py
@@ -0,0 +1 @@
+__version__ = '0.1.0'
diff --git a/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/setup.py b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a045b763fb4a4f61bac23b735544a18ffc68d20a
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/setup.py
@@ -0,0 +1,208 @@
+# coding=utf-8
+# Copyright 2022 The IDEA Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------------------------
+# Modified from
+# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
+# https://github.com/facebookresearch/detectron2/blob/main/setup.py
+# https://github.com/open-mmlab/mmdetection/blob/master/setup.py
+# https://github.com/Oneflow-Inc/libai/blob/main/setup.py
+# ------------------------------------------------------------------------------------------------
+
+import glob
+import os
+import subprocess
+
+import torch
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+
+# groundingdino version info
+version = "0.1.0"
+package_name = "groundingdino"
+cwd = os.path.dirname(os.path.abspath(__file__))
+
+
+sha = "Unknown"
+try:
+ sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
+except Exception:
+ pass
+
+
+def write_version_file():
+ version_path = os.path.join(cwd, "groundingdino", "version.py")
+ with open(version_path, "w") as f:
+ f.write(f"__version__ = '{version}'\n")
+ # f.write(f"git_version = {repr(sha)}\n")
+
+
+requirements = ["torch", "torchvision"]
+
+torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
+
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
+
+ main_source = os.path.join(extensions_dir, "vision.cpp")
+ sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
+ os.path.join(extensions_dir, "*.cu")
+ )
+
+ sources = [main_source] + sources
+
+ extension = CppExtension
+
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ print("Compiling with CUDA")
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+ print("Compiling without CUDA")
+ define_macros += [("WITH_HIP", None)]
+ extra_compile_args["nvcc"] = []
+ return None
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+
+ ext_modules = [
+ extension(
+ "groundingdino._C",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+
+ return ext_modules
+
+
+def parse_requirements(fname="requirements.txt", with_version=True):
+ """Parse the package dependencies listed in a requirements file but strips
+ specific versioning information.
+
+ Args:
+ fname (str): path to requirements file
+ with_version (bool, default=False): if True include version specs
+
+ Returns:
+ List[str]: list of requirements items
+
+ CommandLine:
+ python -c "import setup; print(setup.parse_requirements())"
+ """
+ import re
+ import sys
+ from os.path import exists
+
+ require_fpath = fname
+
+ def parse_line(line):
+ """Parse information from a line in a requirements text file."""
+ if line.startswith("-r "):
+ # Allow specifying requirements in other files
+ target = line.split(" ")[1]
+ for info in parse_require_file(target):
+ yield info
+ else:
+ info = {"line": line}
+ if line.startswith("-e "):
+ info["package"] = line.split("#egg=")[1]
+ elif "@git+" in line:
+ info["package"] = line
+ else:
+ # Remove versioning from the package
+ pat = "(" + "|".join([">=", "==", ">"]) + ")"
+ parts = re.split(pat, line, maxsplit=1)
+ parts = [p.strip() for p in parts]
+
+ info["package"] = parts[0]
+ if len(parts) > 1:
+ op, rest = parts[1:]
+ if ";" in rest:
+ # Handle platform specific dependencies
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
+ version, platform_deps = map(str.strip, rest.split(";"))
+ info["platform_deps"] = platform_deps
+ else:
+ version = rest # NOQA
+ info["version"] = (op, version)
+ yield info
+
+ def parse_require_file(fpath):
+ with open(fpath, "r") as f:
+ for line in f.readlines():
+ line = line.strip()
+ if line and not line.startswith("#"):
+ for info in parse_line(line):
+ yield info
+
+ def gen_packages_items():
+ if exists(require_fpath):
+ for info in parse_require_file(require_fpath):
+ parts = [info["package"]]
+ if with_version and "version" in info:
+ parts.extend(info["version"])
+ if not sys.version.startswith("3.4"):
+ # apparently package_deps are broken in 3.4
+ platform_deps = info.get("platform_deps")
+ if platform_deps is not None:
+ parts.append(";" + platform_deps)
+ item = "".join(parts)
+ yield item
+
+ packages = list(gen_packages_items())
+ return packages
+
+
+if __name__ == "__main__":
+ print(f"Building wheel {package_name}-{version}")
+
+ with open("LICENSE", "r", encoding="utf-8") as f:
+ license = f.read()
+
+ write_version_file()
+
+ setup(
+ name="groundingdino",
+ version="0.1.0",
+ author="International Digital Economy Academy, Shilong Liu",
+ url="https://github.com/IDEA-Research/GroundingDINO",
+ description="open-set object detector",
+ license=license,
+ install_requires=parse_requirements("requirements.txt"),
+ packages=find_packages(
+ exclude=(
+ "configs",
+ "tests",
+ )
+ ),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+ )
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/.flake8 b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/.flake8
new file mode 100644
index 0000000000000000000000000000000000000000..6b0759587aa5756e66a13ef034c6bcdd76a885f5
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/.flake8
@@ -0,0 +1,7 @@
+[flake8]
+ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
+max-line-length = 100
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
+per-file-ignores =
+ **/__init__.py:F401,F403,E402
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/CODE_OF_CONDUCT.md b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..08b500a221857ec3f451338e80b4a9ab1173a1af
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/CODE_OF_CONDUCT.md
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at . All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/LICENSE b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/linter.sh b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/linter.sh
new file mode 100755
index 0000000000000000000000000000000000000000..df2e17436d30e89ff1728109301599f425f1ad6b
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/linter.sh
@@ -0,0 +1,32 @@
+#!/bin/bash -e
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+{
+ black --version | grep -E "23\." > /dev/null
+} || {
+ echo "Linter requires 'black==23.*' !"
+ exit 1
+}
+
+ISORT_VERSION=$(isort --version-number)
+if [[ "$ISORT_VERSION" != 5.12* ]]; then
+ echo "Linter requires isort==5.12.0 !"
+ exit 1
+fi
+
+echo "Running isort ..."
+isort . --atomic
+
+echo "Running black ..."
+black -l 100 .
+
+echo "Running flake8 ..."
+if [ -x "$(command -v flake8)" ]; then
+ flake8 .
+else
+ python3 -m flake8 .
+fi
+
+echo "Running mypy..."
+
+mypy --exclude 'setup.py|notebooks' .
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34383d83f5e76bc801f31b20e5651e383be348b6
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+ build_sam,
+ build_sam_vit_h,
+ build_sam_vit_l,
+ build_sam_vit_b,
+ sam_model_registry,
+)
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/automatic_mask_generator.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..23264971b7ff5aa0b4f499ade7773b68dce984b6
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/automatic_mask_generator.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: Sam,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ ) -> None:
+ """
+ Using a SAM model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM with a ViT-H backbone.
+
+ Arguments:
+ model (Sam): The SAM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+ calculated the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crops_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crops_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = SamPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros(len(data["boxes"])), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros(len(data["boxes"])), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros(len(boxes)), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/build_sam.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..07abfca24e96eced7f13bdefd3212ce1b77b8999
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/build_sam.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1280,
+ encoder_depth=32,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[7, 15, 23, 31],
+ checkpoint=checkpoint,
+ )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1024,
+ encoder_depth=24,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[5, 11, 17, 23],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam_vit_b(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=768,
+ encoder_depth=12,
+ encoder_num_heads=12,
+ encoder_global_attn_indexes=[2, 5, 8, 11],
+ checkpoint=checkpoint,
+ )
+
+
+sam_model_registry = {
+ "default": build_sam,
+ "vit_h": build_sam,
+ "vit_l": build_sam_vit_l,
+ "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+):
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+ sam = Sam(
+ image_encoder=ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ ),
+ prompt_encoder=PromptEncoder(
+ embed_dim=prompt_embed_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(image_size, image_size),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ ),
+ pixel_mean=[123.675, 116.28, 103.53],
+ pixel_std=[58.395, 57.12, 57.375],
+ )
+ sam.eval()
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ sam.load_state_dict(state_dict)
+ return sam
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e906243d898d7fc071c0fe218338c5cace3ea1
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/common.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/image_encoder.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ad9ad2938842308e482a05c9d35ab08db9b2c3
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/image_encoder.py
@@ -0,0 +1,395 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ """
+ super().__init__()
+ self.img_size = img_size
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+
+ return x
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (int or None): Input resolution for calculating the relative positional
+ parameter size.
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (int or None): Input resolution for calculating the relative positional
+ parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/mask_decoder.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e86f7cc9ad95582a08ef2531c68d03fa4af8d99
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/mask_decoder.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ tranformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ )
+
+ # Select the correct mask or masks for outptu
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+
+ # Prepare output
+ return masks, iou_pred
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/prompt_encoder.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (int): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/sam.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..303bc2f40c3dbc84f5d4286bb73336e075a86589
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/sam.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ @torch.no_grad()
+ def forward(
+ self,
+ batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if it is not present.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+ as dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input promts,
+ C is determiend by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ if "point_coords" in image_record:
+ points = (image_record["point_coords"], image_record["point_labels"])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get("boxes", None),
+ masks=image_record.get("mask_inputs", None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record["image"].shape[-2:],
+ original_size=image_record["original_size"],
+ )
+ masks = masks > self.mask_threshold
+ outputs.append(
+ {
+ "masks": masks,
+ "iou_predictions": iou_predictions,
+ "low_res_logits": low_res_masks,
+ }
+ )
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_encoder.img_size - h
+ padw = self.image_encoder.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/transformer.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a2812f613cc55b1d0b3e3e1d0c84a760d1fb87
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/modeling/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attenion layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/predictor.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..57c089d1fc4a6bbf5786e1ef62c59e22d582f5aa
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/predictor.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from segment_anything.modeling import Sam
+
+from typing import Optional, Tuple
+
+from .utils.transforms import ResizeLongestSide
+
+
+class SamPredictor:
+ def __init__(
+ self,
+ sam_model: Sam,
+ ) -> None:
+ """
+ Uses SAM to calculate the image embedding for an image, and then
+ allow repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam): The model to use for mask prediction.
+ """
+ super().__init__()
+ self.model = sam_model
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+ self.reset_image()
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = "RGB",
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+ assert (
+ len(transformed_image.shape) == 4
+ and transformed_image.shape[1] == 3
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features = self.model.image_encoder(input_image)
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks = masks[0].detach().cpu().numpy()
+ iou_predictions = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks = low_res_masks[0].detach().cpu().numpy()
+ return masks, iou_predictions, low_res_masks
+
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A Bx4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert self.features is not None, "Features must exist if an image has been set."
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/__init__.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/amg.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a137778e45c464c079658ecb87ec53270e789f7
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/amg.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecesary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+ # Crops in XYWH format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of if the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/onnx.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..4297b31291e036700d6ad0b818afb7dd72da3054
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/onnx.py
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+ """
+ This model should not be called directly, but is used in ONNX export.
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+ with some functions modified to enable model tracing. Also supports extra
+ options controlling what information. See the ONNX export script for details.
+ """
+
+ def __init__(
+ self,
+ model: Sam,
+ return_single_mask: bool,
+ use_stability_score: bool = False,
+ return_extra_metrics: bool = False,
+ ) -> None:
+ super().__init__()
+ self.mask_decoder = model.mask_decoder
+ self.model = model
+ self.img_size = model.image_encoder.img_size
+ self.return_single_mask = return_single_mask
+ self.use_stability_score = use_stability_score
+ self.stability_score_offset = 1.0
+ self.return_extra_metrics = return_extra_metrics
+
+ @staticmethod
+ def resize_longest_image_size(
+ input_image_size: torch.Tensor, longest_side: int
+ ) -> torch.Tensor:
+ input_image_size = input_image_size.to(torch.float32)
+ scale = longest_side / torch.max(input_image_size)
+ transformed_size = scale * input_image_size
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+ return transformed_size
+
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+ point_coords = point_coords + 0.5
+ point_coords = point_coords / self.img_size
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+ point_embedding = point_embedding * (point_labels != -1)
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+ point_labels == -1
+ )
+
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+ i
+ ].weight * (point_labels == i)
+
+ return point_embedding
+
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+ mask_embedding = mask_embedding + (
+ 1 - has_mask_input
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+ return mask_embedding
+
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+ masks = F.interpolate(
+ masks,
+ size=(self.img_size, self.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
+ masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
+
+ orig_im_size = orig_im_size.to(torch.int64)
+ h, w = orig_im_size[0], orig_im_size[1]
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+ return masks
+
+ def select_masks(
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Determine if we should return the multiclick mask or not from the number of points.
+ # The reweighting is used to avoid control flow.
+ score_reweight = torch.tensor(
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+ ).to(iou_preds.device)
+ score = iou_preds + (num_points - 2.5) * score_reweight
+ best_idx = torch.argmax(score, dim=1)
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+ return masks, iou_preds
+
+ @torch.no_grad()
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mask_input: torch.Tensor,
+ has_mask_input: torch.Tensor,
+ orig_im_size: torch.Tensor,
+ ):
+ sparse_embedding = self._embed_points(point_coords, point_labels)
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+ masks, scores = self.model.mask_decoder.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embedding,
+ dense_prompt_embeddings=dense_embedding,
+ )
+
+ if self.use_stability_score:
+ scores = calculate_stability_score(
+ masks, self.model.mask_threshold, self.stability_score_offset
+ )
+
+ if self.return_single_mask:
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+ if self.return_extra_metrics:
+ stability_scores = calculate_stability_score(
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+ )
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+ return upscaled_masks, scores, stability_scores, areas, masks
+
+ return upscaled_masks, scores, masks
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/transforms.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad346661f84b0647026e130a552c4b38b83e2ac
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/segment_anything/utils/transforms.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+ """
+ Resizes images to longest side 'target_length', as well as provides
+ methods for resizing coordinates and boxes. Provides methods for
+ transforming both numpy array and batched torch tensors.
+ """
+
+ def __init__(self, target_length: int) -> None:
+ self.target_length = target_length
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array shape Bx4. Requires the original image size
+ in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Expects batched images with shape BxCxHxW and float format. This
+ transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+ # Expects an image in BCHW format. May not exactly match apply_image.
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return F.interpolate(
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
+ )
+
+ def apply_coords_torch(
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with shape Bx4. Requires the original image
+ size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.cfg b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..0eee130ba71d14ec260d33a8ebd96a6491079a54
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.cfg
@@ -0,0 +1,11 @@
+[isort]
+line_length=100
+multi_line_output=3
+include_trailing_comma=True
+known_standard_library=numpy,setuptools
+skip_glob=*/__init__.py
+known_myself=segment_anything
+known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort
+no_lines_before=STDLIB,THIRDPARTY
+sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
+default_section=FIRSTPARTY
diff --git a/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.py b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c0986317eb576a14ec774205c88fdee3cc6c0b3
--- /dev/null
+++ b/Make-A-Protagonist/experts/GroundedSAM/segment_anything/setup.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from setuptools import find_packages, setup
+
+setup(
+ name="segment_anything",
+ version="1.0",
+ install_requires=[],
+ packages=find_packages(exclude="notebooks"),
+ extras_require={
+ "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"],
+ "dev": ["flake8", "isort", "black", "mypy"],
+ },
+)
diff --git a/Make-A-Protagonist/experts/XMem/__init__.py b/Make-A-Protagonist/experts/XMem/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/dataset/__init__.py b/Make-A-Protagonist/experts/XMem/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/dataset/range_transform.py b/Make-A-Protagonist/experts/XMem/dataset/range_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/range_transform.py
@@ -0,0 +1,12 @@
+import torchvision.transforms as transforms
+
+im_mean = (124, 116, 104)
+
+im_normalization = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+
+inv_im_trans = transforms.Normalize(
+ mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+ std=[1/0.229, 1/0.224, 1/0.225])
diff --git a/Make-A-Protagonist/experts/XMem/dataset/reseed.py b/Make-A-Protagonist/experts/XMem/dataset/reseed.py
new file mode 100644
index 0000000000000000000000000000000000000000..600c998fa33485c073af7f9e13e885350a5c6940
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/reseed.py
@@ -0,0 +1,6 @@
+import torch
+import random
+
+def reseed(seed):
+ random.seed(seed)
+ torch.manual_seed(seed)
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/dataset/static_dataset.py b/Make-A-Protagonist/experts/XMem/dataset/static_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5800f5f3471de261f0bad168556b16fd71ce1dff
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/static_dataset.py
@@ -0,0 +1,179 @@
+import os
+from os import path
+
+import torch
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+import numpy as np
+
+from dataset.range_transform import im_normalization, im_mean
+from dataset.tps import random_tps_warp
+from dataset.reseed import reseed
+
+
+class StaticTransformDataset(Dataset):
+ """
+ Generate pseudo VOS data by applying random transforms on static images.
+ Single-object only.
+
+ Method 0 - FSS style (class/1.jpg class/1.png)
+ Method 1 - Others style (XXX.jpg XXX.png)
+ """
+ def __init__(self, parameters, num_frames=3, max_num_obj=1):
+ self.num_frames = num_frames
+ self.max_num_obj = max_num_obj
+
+ self.im_list = []
+ for parameter in parameters:
+ root, method, multiplier = parameter
+ if method == 0:
+ # Get images
+ classes = os.listdir(root)
+ for c in classes:
+ imgs = os.listdir(path.join(root, c))
+ jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
+
+ joint_list = [path.join(root, c, im) for im in jpg_list]
+ self.im_list.extend(joint_list * multiplier)
+
+ elif method == 1:
+ self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
+
+ print(f'{len(self.im_list)} images found.')
+
+ # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
+ self.pair_im_lone_transform = transforms.Compose([
+ transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
+ ])
+
+ self.pair_im_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
+ transforms.Resize(384, InterpolationMode.BICUBIC),
+ transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
+ ])
+
+ self.pair_gt_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
+ transforms.Resize(384, InterpolationMode.NEAREST),
+ transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
+ ])
+
+
+ # These transform are the same for all pairs in the sampled sequence
+ self.all_im_lone_transform = transforms.Compose([
+ transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
+ transforms.RandomGrayscale(0.05),
+ ])
+
+ self.all_im_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
+ transforms.RandomHorizontalFlip(),
+ ])
+
+ self.all_gt_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
+ transforms.RandomHorizontalFlip(),
+ ])
+
+ # Final transform without randomness
+ self.final_im_transform = transforms.Compose([
+ transforms.ToTensor(),
+ im_normalization,
+ ])
+
+ self.final_gt_transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ def _get_sample(self, idx):
+ im = Image.open(self.im_list[idx]).convert('RGB')
+ gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
+
+ sequence_seed = np.random.randint(2147483647)
+
+ images = []
+ masks = []
+ for _ in range(self.num_frames):
+ reseed(sequence_seed)
+ this_im = self.all_im_dual_transform(im)
+ this_im = self.all_im_lone_transform(this_im)
+ reseed(sequence_seed)
+ this_gt = self.all_gt_dual_transform(gt)
+
+ pairwise_seed = np.random.randint(2147483647)
+ reseed(pairwise_seed)
+ this_im = self.pair_im_dual_transform(this_im)
+ this_im = self.pair_im_lone_transform(this_im)
+ reseed(pairwise_seed)
+ this_gt = self.pair_gt_dual_transform(this_gt)
+
+ # Use TPS only some of the times
+ # Not because TPS is bad -- just that it is too slow and I need to speed up data loading
+ if np.random.rand() < 0.33:
+ this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
+
+ this_im = self.final_im_transform(this_im)
+ this_gt = self.final_gt_transform(this_gt)
+
+ images.append(this_im)
+ masks.append(this_gt)
+
+ images = torch.stack(images, 0)
+ masks = torch.stack(masks, 0)
+
+ return images, masks.numpy()
+
+ def __getitem__(self, idx):
+ additional_objects = np.random.randint(self.max_num_obj)
+ indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
+
+ merged_images = None
+ merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
+
+ for i, list_id in enumerate(indices):
+ images, masks = self._get_sample(list_id)
+ if merged_images is None:
+ merged_images = images
+ else:
+ merged_images = merged_images*(1-masks) + images*masks
+ merged_masks[masks[:,0]>0.5] = (i+1)
+
+ masks = merged_masks
+
+ labels = np.unique(masks[0])
+ # Remove background
+ labels = labels[labels!=0]
+ target_objects = labels.tolist()
+
+ # Generate one-hot ground-truth
+ cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
+ first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
+ for i, l in enumerate(target_objects):
+ this_mask = (masks==l)
+ cls_gt[this_mask] = i+1
+ first_frame_gt[0,i] = (this_mask[0])
+ cls_gt = np.expand_dims(cls_gt, 1)
+
+ info = {}
+ info['name'] = self.im_list[idx]
+ info['num_objects'] = max(1, len(target_objects))
+
+ # 1 if object exist, 0 otherwise
+ selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
+ selector = torch.FloatTensor(selector)
+
+ data = {
+ 'rgb': merged_images,
+ 'first_frame_gt': first_frame_gt,
+ 'cls_gt': cls_gt,
+ 'selector': selector,
+ 'info': info
+ }
+
+ return data
+
+
+ def __len__(self):
+ return len(self.im_list)
diff --git a/Make-A-Protagonist/experts/XMem/dataset/tps.py b/Make-A-Protagonist/experts/XMem/dataset/tps.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee3747c110a8ca03169e3ece5654ba4e8abd7fe
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/tps.py
@@ -0,0 +1,37 @@
+import numpy as np
+from PIL import Image
+import cv2
+import thinplate as tps
+
+cv2.setNumThreads(0)
+
+def pick_random_points(h, w, n_samples):
+ y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
+ x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
+ return y_idx/h, x_idx/w
+
+
+def warp_dual_cv(img, mask, c_src, c_dst):
+ dshape = img.shape
+ theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
+ grid = tps.tps_grid(theta, c_dst, dshape)
+ mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
+ return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
+
+
+def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
+ """
+ Apply a random TPS warp of the input image and mask
+ Uses randomness from numpy
+ """
+ img = np.asarray(img)
+ mask = np.asarray(mask)
+
+ h, w = mask.shape
+ points = pick_random_points(h, w, n_ctrl_pts)
+ c_src = np.stack(points, 1)
+ c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
+ warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
+
+ return Image.fromarray(warp_im), Image.fromarray(warp_gt)
+
diff --git a/Make-A-Protagonist/experts/XMem/dataset/util.py b/Make-A-Protagonist/experts/XMem/dataset/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e5523c4d2cea4e9010b3c28db0b1f03624e5af
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/util.py
@@ -0,0 +1,13 @@
+import numpy as np
+
+
+def all_to_onehot(masks, labels):
+ if len(masks.shape) == 3:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
+ else:
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
+
+ for ni, l in enumerate(labels):
+ Ms[ni] = (masks == l).astype(np.uint8)
+
+ return Ms
diff --git a/Make-A-Protagonist/experts/XMem/dataset/vos_dataset.py b/Make-A-Protagonist/experts/XMem/dataset/vos_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..be0f8a15a4c31f47a8a59aa115c5b4f937a033cf
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/dataset/vos_dataset.py
@@ -0,0 +1,216 @@
+import os
+from os import path, replace
+
+import torch
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from PIL import Image
+import numpy as np
+
+from dataset.range_transform import im_normalization, im_mean
+from dataset.reseed import reseed
+
+
+class VOSDataset(Dataset):
+ """
+ Works for DAVIS/YouTubeVOS/BL30K training
+ For each sequence:
+ - Pick three frames
+ - Pick two objects
+ - Apply some random transforms that are the same for all frames
+ - Apply random transform to each of the frame
+ - The distance between frames is controlled
+ """
+ def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
+ self.im_root = im_root
+ self.gt_root = gt_root
+ self.max_jump = max_jump
+ self.is_bl = is_bl
+ self.num_frames = num_frames
+ self.max_num_obj = max_num_obj
+
+ self.videos = []
+ self.frames = {}
+
+ vid_list = sorted(os.listdir(self.im_root))
+ # Pre-filtering
+ for vid in vid_list:
+ if subset is not None:
+ if vid not in subset:
+ continue
+ frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
+ if len(frames) < num_frames:
+ continue
+ self.frames[vid] = frames
+ self.videos.append(vid)
+
+ print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
+
+ # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
+ self.pair_im_lone_transform = transforms.Compose([
+ transforms.ColorJitter(0.01, 0.01, 0.01, 0),
+ ])
+
+ self.pair_im_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
+ ])
+
+ self.pair_gt_dual_transform = transforms.Compose([
+ transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
+ ])
+
+ # These transform are the same for all pairs in the sampled sequence
+ self.all_im_lone_transform = transforms.Compose([
+ transforms.ColorJitter(0.1, 0.03, 0.03, 0),
+ transforms.RandomGrayscale(0.05),
+ ])
+
+ if self.is_bl:
+ # Use a different cropping scheme for the blender dataset because the image size is different
+ self.all_im_dual_transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
+ ])
+
+ self.all_gt_dual_transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
+ ])
+ else:
+ self.all_im_dual_transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
+ ])
+
+ self.all_gt_dual_transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
+ ])
+
+ # Final transform without randomness
+ self.final_im_transform = transforms.Compose([
+ transforms.ToTensor(),
+ im_normalization,
+ ])
+
+ def __getitem__(self, idx):
+ video = self.videos[idx]
+ info = {}
+ info['name'] = video
+
+ vid_im_path = path.join(self.im_root, video)
+ vid_gt_path = path.join(self.gt_root, video)
+ frames = self.frames[video]
+
+ trials = 0
+ while trials < 5:
+ info['frames'] = [] # Appended with actual frames
+
+ num_frames = self.num_frames
+ length = len(frames)
+ this_max_jump = min(len(frames), self.max_jump)
+
+ # iterative sampling
+ frames_idx = [np.random.randint(length)]
+ acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
+ while(len(frames_idx) < num_frames):
+ idx = np.random.choice(list(acceptable_set))
+ frames_idx.append(idx)
+ new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
+ acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
+
+ frames_idx = sorted(frames_idx)
+ if np.random.rand() < 0.5:
+ # Reverse time
+ frames_idx = frames_idx[::-1]
+
+ sequence_seed = np.random.randint(2147483647)
+ images = []
+ masks = []
+ target_objects = []
+ for f_idx in frames_idx:
+ jpg_name = frames[f_idx][:-4] + '.jpg'
+ png_name = frames[f_idx][:-4] + '.png'
+ info['frames'].append(jpg_name)
+
+ reseed(sequence_seed)
+ this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
+ this_im = self.all_im_dual_transform(this_im)
+ this_im = self.all_im_lone_transform(this_im)
+ reseed(sequence_seed)
+ this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
+ this_gt = self.all_gt_dual_transform(this_gt)
+
+ pairwise_seed = np.random.randint(2147483647)
+ reseed(pairwise_seed)
+ this_im = self.pair_im_dual_transform(this_im)
+ this_im = self.pair_im_lone_transform(this_im)
+ reseed(pairwise_seed)
+ this_gt = self.pair_gt_dual_transform(this_gt)
+
+ this_im = self.final_im_transform(this_im)
+ this_gt = np.array(this_gt)
+
+ images.append(this_im)
+ masks.append(this_gt)
+
+ images = torch.stack(images, 0)
+
+ labels = np.unique(masks[0])
+ # Remove background
+ labels = labels[labels!=0]
+
+ if self.is_bl:
+ # Find large enough labels
+ good_lables = []
+ for l in labels:
+ pixel_sum = (masks[0]==l).sum()
+ if pixel_sum > 10*10:
+ # OK if the object is always this small
+ # Not OK if it is actually much bigger
+ if pixel_sum > 30*30:
+ good_lables.append(l)
+ elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
+ good_lables.append(l)
+ labels = np.array(good_lables, dtype=np.uint8)
+
+ if len(labels) == 0:
+ target_objects = []
+ trials += 1
+ else:
+ target_objects = labels.tolist()
+ break
+
+ if len(target_objects) > self.max_num_obj:
+ target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
+
+ info['num_objects'] = max(1, len(target_objects))
+
+ masks = np.stack(masks, 0)
+
+ # Generate one-hot ground-truth
+ cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
+ first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
+ for i, l in enumerate(target_objects):
+ this_mask = (masks==l)
+ cls_gt[this_mask] = i+1
+ first_frame_gt[0,i] = (this_mask[0])
+ cls_gt = np.expand_dims(cls_gt, 1)
+
+ # 1 if object exist, 0 otherwise
+ selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
+ selector = torch.FloatTensor(selector)
+
+ data = {
+ 'rgb': images,
+ 'first_frame_gt': first_frame_gt,
+ 'cls_gt': cls_gt,
+ 'selector': selector,
+ 'info': info,
+ }
+
+ return data
+
+ def __len__(self):
+ return len(self.videos)
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/__init__.py b/Make-A-Protagonist/experts/XMem/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/data/__init__.py b/Make-A-Protagonist/experts/XMem/inference/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/data/mask_mapper.py b/Make-A-Protagonist/experts/XMem/inference/data/mask_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..29290c16c3043310aa5ede043f3096f0edc4eb09
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/data/mask_mapper.py
@@ -0,0 +1,64 @@
+import numpy as np
+import torch
+
+from XMem.dataset.util import all_to_onehot
+
+
+class MaskMapper:
+ """
+ This class is used to convert a indexed-mask to a one-hot representation.
+ It also takes care of remapping non-continuous indices
+ It has two modes:
+ 1. Default. Only masks with new indices are supposed to go into the remapper.
+ This is also the case for YouTubeVOS.
+ i.e., regions with index 0 are not "background", but "don't care".
+
+ 2. Exhaustive. Regions with index 0 are considered "background".
+ Every single pixel is considered to be "labeled".
+ """
+ def __init__(self):
+ self.labels = []
+ self.remappings = {}
+
+ # if coherent, no mapping is required
+ self.coherent = True
+
+ def convert_mask(self, mask, exhaustive=False):
+ # mask is in index representation, H*W numpy array
+ labels = np.unique(mask).astype(np.uint8)
+ labels = labels[labels!=0].tolist()
+
+ new_labels = list(set(labels) - set(self.labels))
+ if not exhaustive:
+ assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
+
+ # add new remappings
+ for i, l in enumerate(new_labels):
+ self.remappings[l] = i+len(self.labels)+1
+ if self.coherent and i+len(self.labels)+1 != l:
+ self.coherent = False
+
+ if exhaustive:
+ new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
+ else:
+ if self.coherent:
+ new_mapped_labels = new_labels
+ else:
+ new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
+
+ self.labels.extend(new_labels)
+ mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
+
+ # mask num_objects*H*W
+ return mask, new_mapped_labels
+
+
+ def remap_index_mask(self, mask):
+ # mask is in index representation, H*W numpy array
+ if self.coherent:
+ return mask
+
+ new_mask = np.zeros_like(mask)
+ for l, i in self.remappings.items():
+ new_mask[mask==i] = l
+ return new_mask
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/data/test_datasets.py b/Make-A-Protagonist/experts/XMem/inference/data/test_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..16484580a66089736752a7b3bc949e8d721ef109
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/data/test_datasets.py
@@ -0,0 +1,120 @@
+import os
+from os import path
+import json
+
+# from inference.data.video_reader import VideoReader
+from .video_reader import VideoReader
+
+class CustomDataset:
+ def __init__(self, data_root, mask_dir, size=-1):
+ self.image_dir = data_root # data/dir/images
+ self.mask_dir = data_root.replace('images', mask_dir)
+ self.size = size
+ self.video_name = data_root.split('/')[-2] #
+ self.vid_list = [self.video_name]
+
+ def get_datasets(self):
+ yield VideoReader(self.video_name,
+ self.image_dir,
+ self.mask_dir,
+ to_save = [
+ name[:-4] for name in os.listdir(path.join(self.mask_dir))
+ ],
+ size=self.size,
+ )
+
+ def __len__(self):
+ return len(self.vid_list)
+
+
+class LongTestDataset:
+ def __init__(self, data_root, size=-1):
+ self.image_dir = path.join(data_root, 'JPEGImages')
+ self.mask_dir = path.join(data_root, 'Annotations')
+ self.size = size
+
+ self.vid_list = sorted(os.listdir(self.image_dir))
+
+ def get_datasets(self):
+ for video in self.vid_list:
+ yield VideoReader(video,
+ path.join(self.image_dir, video),
+ path.join(self.mask_dir, video),
+ to_save = [
+ name[:-4] for name in os.listdir(path.join(self.mask_dir, video))
+ ],
+ size=self.size,
+ )
+
+ def __len__(self):
+ return len(self.vid_list)
+
+
+
+
+class DAVISTestDataset:
+ def __init__(self, data_root, imset='2017/val.txt', size=-1):
+ if size != 480:
+ self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
+ self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
+ if not path.exists(self.image_dir):
+ print(f'{self.image_dir} not found. Look at other options.')
+ self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
+ self.mask_dir = path.join(data_root, 'Annotations', '1080p')
+ assert path.exists(self.image_dir), 'path not found'
+ else:
+ self.image_dir = path.join(data_root, 'JPEGImages', '480p')
+ self.mask_dir = path.join(data_root, 'Annotations', '480p')
+ self.size_dir = path.join(data_root, 'JPEGImages', '480p')
+ self.size = size
+
+ with open(path.join(data_root, 'ImageSets', imset)) as f:
+ self.vid_list = sorted([line.strip() for line in f])
+
+ def get_datasets(self):
+ for video in self.vid_list:
+ yield VideoReader(video,
+ path.join(self.image_dir, video),
+ path.join(self.mask_dir, video),
+ size=self.size,
+ size_dir=path.join(self.size_dir, video),
+ )
+
+ def __len__(self):
+ return len(self.vid_list)
+
+
+class YouTubeVOSTestDataset:
+ def __init__(self, data_root, split, size=480):
+ self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
+ self.mask_dir = path.join(data_root, split, 'Annotations')
+ self.size = size
+
+ self.vid_list = sorted(os.listdir(self.image_dir))
+ self.req_frame_list = {}
+
+ with open(path.join(data_root, split, 'meta.json')) as f:
+ # read meta.json to know which frame is required for evaluation
+ meta = json.load(f)['videos']
+
+ for vid in self.vid_list:
+ req_frames = []
+ objects = meta[vid]['objects']
+ for value in objects.values():
+ req_frames.extend(value['frames'])
+
+ req_frames = list(set(req_frames))
+ self.req_frame_list[vid] = req_frames
+
+ def get_datasets(self):
+ for video in self.vid_list:
+ yield VideoReader(video,
+ path.join(self.image_dir, video),
+ path.join(self.mask_dir, video),
+ size=self.size,
+ to_save=self.req_frame_list[video],
+ use_all_mask=True
+ )
+
+ def __len__(self):
+ return len(self.vid_list)
diff --git a/Make-A-Protagonist/experts/XMem/inference/data/video_reader.py b/Make-A-Protagonist/experts/XMem/inference/data/video_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b0fdfaf0c80a7975b7d5387e2bce78e36810349
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/data/video_reader.py
@@ -0,0 +1,100 @@
+import os
+from os import path
+
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+import numpy as np
+
+from XMem.dataset.range_transform import im_normalization
+
+
+class VideoReader(Dataset):
+ """
+ This class is used to read a video, one frame at a time
+ """
+ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
+ """
+ image_dir - points to a directory of jpg images
+ mask_dir - points to a directory of png masks
+ size - resize min. side to size. Does nothing if <0.
+ to_save - optionally contains a list of file names without extensions
+ where the segmentation mask is required
+ use_all_mask - when true, read all available mask in mask_dir.
+ Default false. Set to true for YouTubeVOS validation.
+ """
+ self.vid_name = vid_name
+ self.image_dir = image_dir
+ self.mask_dir = mask_dir
+ self.to_save = to_save
+ self.use_all_mask = use_all_mask
+ if size_dir is None:
+ self.size_dir = self.image_dir
+ else:
+ self.size_dir = size_dir
+
+ self.frames = sorted(os.listdir(self.image_dir))
+ self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
+ self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
+
+ if size < 0:
+ self.im_transform = transforms.Compose([
+ transforms.ToTensor(),
+ im_normalization,
+ ])
+ else:
+ self.im_transform = transforms.Compose([
+ transforms.ToTensor(),
+ im_normalization,
+ transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
+ ])
+ self.size = size
+
+
+ def __getitem__(self, idx):
+ frame = self.frames[idx]
+ info = {}
+ data = {}
+ info['frame'] = frame
+ info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
+
+ im_path = path.join(self.image_dir, frame)
+ img = Image.open(im_path).convert('RGB')
+
+ if self.image_dir == self.size_dir:
+ shape = np.array(img).shape[:2]
+ else:
+ size_path = path.join(self.size_dir, frame)
+ size_im = Image.open(size_path).convert('RGB')
+ shape = np.array(size_im).shape[:2]
+
+ gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
+ img = self.im_transform(img)
+
+ load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
+ if load_mask and path.exists(gt_path):
+ mask = Image.open(gt_path).convert('P')
+ mask = np.array(mask, dtype=np.uint8)
+ data['mask'] = mask
+
+ info['shape'] = shape
+ info['need_resize'] = not (self.size < 0)
+ data['rgb'] = img
+ data['info'] = info
+
+ return data
+
+ def resize_mask(self, mask):
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
+ h, w = mask.shape[-2:]
+ min_hw = min(h, w)
+ return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
+ mode='nearest')
+
+ def get_palette(self):
+ return self.palette
+
+ def __len__(self):
+ return len(self.frames)
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/inference_core.py b/Make-A-Protagonist/experts/XMem/inference/inference_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..b696cbf8884ac79e992d9e7c1da0be7fb5f3c74b
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/inference_core.py
@@ -0,0 +1,107 @@
+from XMem.inference.memory_manager import MemoryManager
+from XMem.model.network import XMem
+from XMem.model.aggregate import aggregate
+
+from XMem.util.tensor_util import pad_divide_by, unpad
+
+
+class InferenceCore:
+ def __init__(self, network:XMem, config):
+ self.config = config
+ self.network = network
+ self.mem_every = config['mem_every']
+ self.deep_update_every = config['deep_update_every']
+ self.enable_long_term = config['enable_long_term']
+
+ # if deep_update_every < 0, synchronize deep update with memory frame
+ self.deep_update_sync = (self.deep_update_every < 0)
+
+ self.clear_memory()
+ self.all_labels = None
+
+ def clear_memory(self):
+ self.curr_ti = -1
+ self.last_mem_ti = 0
+ if not self.deep_update_sync:
+ self.last_deep_update_ti = -self.deep_update_every
+ self.memory = MemoryManager(config=self.config)
+
+ def update_config(self, config):
+ self.mem_every = config['mem_every']
+ self.deep_update_every = config['deep_update_every']
+ self.enable_long_term = config['enable_long_term']
+
+ # if deep_update_every < 0, synchronize deep update with memory frame
+ self.deep_update_sync = (self.deep_update_every < 0)
+ self.memory.update_config(config)
+
+ def set_all_labels(self, all_labels):
+ # self.all_labels = [l.item() for l in all_labels]
+ self.all_labels = all_labels
+
+ def step(self, image, mask=None, valid_labels=None, end=False):
+ # image: 3*H*W
+ # mask: num_objects*H*W or None
+ self.curr_ti += 1
+ image, self.pad = pad_divide_by(image, 16)
+ image = image.unsqueeze(0) # add the batch dimension
+
+ is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
+ need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
+ is_deep_update = (
+ (self.deep_update_sync and is_mem_frame) or # synchronized
+ (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
+ ) and (not end)
+ is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
+
+ key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
+ need_ek=(self.enable_long_term or need_segment),
+ need_sk=is_mem_frame)
+ multi_scale_features = (f16, f8, f4)
+
+ # segment the current frame is needed
+ if need_segment:
+ memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
+ hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
+ self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
+ # remove batch dim
+ pred_prob_with_bg = pred_prob_with_bg[0]
+ pred_prob_no_bg = pred_prob_with_bg[1:]
+ if is_normal_update:
+ self.memory.set_hidden(hidden)
+ else:
+ pred_prob_no_bg = pred_prob_with_bg = None
+
+ # use the input mask if any
+ if mask is not None:
+ mask, _ = pad_divide_by(mask, 16)
+
+ if pred_prob_no_bg is not None:
+ # if we have a predicted mask, we work on it
+ # make pred_prob_no_bg consistent with the input mask
+ mask_regions = (mask.sum(0) > 0.5)
+ pred_prob_no_bg[:, mask_regions] = 0
+ # shift by 1 because mask/pred_prob_no_bg do not contain background
+ mask = mask.type_as(pred_prob_no_bg)
+ if valid_labels is not None:
+ shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels]
+ # non-labelled objects are copied from the predicted mask
+ mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
+ pred_prob_with_bg = aggregate(mask, dim=0)
+
+ # also create new hidden states
+ self.memory.create_hidden_state(len(self.all_labels), key)
+
+ # save as memory if needed
+ if is_mem_frame:
+ value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
+ pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
+ self.memory.add_memory(key, shrinkage, value, self.all_labels,
+ selection=selection if self.enable_long_term else None)
+ self.last_mem_ti = self.curr_ti
+
+ if is_deep_update:
+ self.memory.set_hidden(hidden)
+ self.last_deep_update_ti = self.curr_ti
+
+ return unpad(pred_prob_with_bg, self.pad)
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/LICENSE b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..fa0086a952236971ab37901954d596efae9f4af6
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/LICENSE
@@ -0,0 +1,373 @@
+Mozilla Public License Version 2.0
+==================================
+
+1. Definitions
+--------------
+
+1.1. "Contributor"
+ means each individual or legal entity that creates, contributes to
+ the creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+ means the combination of the Contributions of others (if any) used
+ by a Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+ means Source Code Form to which the initial Contributor has attached
+ the notice in Exhibit A, the Executable Form of such Source Code
+ Form, and Modifications of such Source Code Form, in each case
+ including portions thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ (a) that the initial Contributor has attached the notice described
+ in Exhibit B to the Covered Software; or
+
+ (b) that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the
+ terms of a Secondary License.
+
+1.6. "Executable Form"
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+ means a work that combines Covered Software with other material, in
+ a separate file or files, that is not Covered Software.
+
+1.8. "License"
+ means this document.
+
+1.9. "Licensable"
+ means having the right to grant, to the maximum extent possible,
+ whether at the time of the initial grant or subsequently, any and
+ all of the rights conveyed by this License.
+
+1.10. "Modifications"
+ means any of the following:
+
+ (a) any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered
+ Software; or
+
+ (b) any new file in Source Code Form that contains any Covered
+ Software.
+
+1.11. "Patent Claims" of a Contributor
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the
+ License, by the making, using, selling, offering for sale, having
+ made, import, or transfer of either its Contributions or its
+ Contributor Version.
+
+1.12. "Secondary License"
+ means either the GNU General Public License, Version 2.0, the GNU
+ Lesser General Public License, Version 2.1, the GNU Affero General
+ Public License, Version 3.0, or any later versions of those
+ licenses.
+
+1.13. "Source Code Form"
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that
+ controls, is controlled by, or is under common control with You. For
+ purposes of this definition, "control" means (a) the power, direct
+ or indirect, to cause the direction or management of such entity,
+ whether by contract or otherwise, or (b) ownership of more than
+ fifty percent (50%) of the outstanding shares or beneficial
+ ownership of such entity.
+
+2. License Grants and Conditions
+--------------------------------
+
+2.1. Grants
+
+Each Contributor hereby grants You a world-wide, royalty-free,
+non-exclusive license:
+
+(a) under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+(b) under Patent Claims of such Contributor to make, use, sell, offer
+ for sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+The licenses granted in Section 2.1 with respect to any Contribution
+become effective for each Contribution on the date the Contributor first
+distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+The licenses granted in this Section 2 are the only rights granted under
+this License. No additional rights or licenses will be implied from the
+distribution or licensing of Covered Software under this License.
+Notwithstanding Section 2.1(b) above, no patent license is granted by a
+Contributor:
+
+(a) for any code that a Contributor has removed from Covered Software;
+ or
+
+(b) for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+(c) under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+This License does not grant any rights in the trademarks, service marks,
+or logos of any Contributor (except as may be necessary to comply with
+the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+No Contributor makes additional grants as a result of Your choice to
+distribute the Covered Software under a subsequent version of this
+License (see Section 10.2) or under the terms of a Secondary License (if
+permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+Each Contributor represents that the Contributor believes its
+Contributions are its original creation(s) or it has sufficient rights
+to grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+This License is not intended to limit any rights You have under
+applicable copyright doctrines of fair use, fair dealing, or other
+equivalents.
+
+2.7. Conditions
+
+Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
+in Section 2.1.
+
+3. Responsibilities
+-------------------
+
+3.1. Distribution of Source Form
+
+All distribution of Covered Software in Source Code Form, including any
+Modifications that You create or to which You contribute, must be under
+the terms of this License. You must inform recipients that the Source
+Code Form of the Covered Software is governed by the terms of this
+License, and how they can obtain a copy of this License. You may not
+attempt to alter or restrict the recipients' rights in the Source Code
+Form.
+
+3.2. Distribution of Executable Form
+
+If You distribute Covered Software in Executable Form then:
+
+(a) such Covered Software must also be made available in Source Code
+ Form, as described in Section 3.1, and You must inform recipients of
+ the Executable Form how they can obtain a copy of such Source Code
+ Form by reasonable means in a timely manner, at a charge no more
+ than the cost of distribution to the recipient; and
+
+(b) You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter
+ the recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+You may create and distribute a Larger Work under terms of Your choice,
+provided that You also comply with the requirements of this License for
+the Covered Software. If the Larger Work is a combination of Covered
+Software with a work governed by one or more Secondary Licenses, and the
+Covered Software is not Incompatible With Secondary Licenses, this
+License permits You to additionally distribute such Covered Software
+under the terms of such Secondary License(s), so that the recipient of
+the Larger Work may, at their option, further distribute the Covered
+Software under the terms of either this License or such Secondary
+License(s).
+
+3.4. Notices
+
+You may not remove or alter the substance of any license notices
+(including copyright notices, patent notices, disclaimers of warranty,
+or limitations of liability) contained within the Source Code Form of
+the Covered Software, except that You may alter any license notices to
+the extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+You may choose to offer, and to charge a fee for, warranty, support,
+indemnity or liability obligations to one or more recipients of Covered
+Software. However, You may do so only on Your own behalf, and not on
+behalf of any Contributor. You must make it absolutely clear that any
+such warranty, support, indemnity, or liability obligation is offered by
+You alone, and You hereby agree to indemnify every Contributor for any
+liability incurred by such Contributor as a result of warranty, support,
+indemnity or liability terms You offer. You may include additional
+disclaimers of warranty and limitations of liability specific to any
+jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+---------------------------------------------------
+
+If it is impossible for You to comply with any of the terms of this
+License with respect to some or all of the Covered Software due to
+statute, judicial order, or regulation then You must: (a) comply with
+the terms of this License to the maximum extent possible; and (b)
+describe the limitations and the code they affect. Such description must
+be placed in a text file included with all distributions of the Covered
+Software under this License. Except to the extent prohibited by statute
+or regulation, such description must be sufficiently detailed for a
+recipient of ordinary skill to be able to understand it.
+
+5. Termination
+--------------
+
+5.1. The rights granted under this License will terminate automatically
+if You fail to comply with any of its terms. However, if You become
+compliant, then the rights granted under this License from a particular
+Contributor are reinstated (a) provisionally, unless and until such
+Contributor explicitly and finally terminates Your grants, and (b) on an
+ongoing basis, if such Contributor fails to notify You of the
+non-compliance by some reasonable means prior to 60 days after You have
+come back into compliance. Moreover, Your grants from a particular
+Contributor are reinstated on an ongoing basis if such Contributor
+notifies You of the non-compliance by some reasonable means, this is the
+first time You have received notice of non-compliance with this License
+from such Contributor, and You become compliant prior to 30 days after
+Your receipt of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+infringement claim (excluding declaratory judgment actions,
+counter-claims, and cross-claims) alleging that a Contributor Version
+directly or indirectly infringes any patent, then the rights granted to
+You by any and all Contributors for the Covered Software under Section
+2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all
+end user license agreements (excluding distributors and resellers) which
+have been validly granted by You or Your distributors under this License
+prior to termination shall survive termination.
+
+************************************************************************
+* *
+* 6. Disclaimer of Warranty *
+* ------------------------- *
+* *
+* Covered Software is provided under this License on an "as is" *
+* basis, without warranty of any kind, either expressed, implied, or *
+* statutory, including, without limitation, warranties that the *
+* Covered Software is free of defects, merchantable, fit for a *
+* particular purpose or non-infringing. The entire risk as to the *
+* quality and performance of the Covered Software is with You. *
+* Should any Covered Software prove defective in any respect, You *
+* (not any Contributor) assume the cost of any necessary servicing, *
+* repair, or correction. This disclaimer of warranty constitutes an *
+* essential part of this License. No use of any Covered Software is *
+* authorized under this License except under this disclaimer. *
+* *
+************************************************************************
+
+************************************************************************
+* *
+* 7. Limitation of Liability *
+* -------------------------- *
+* *
+* Under no circumstances and under no legal theory, whether tort *
+* (including negligence), contract, or otherwise, shall any *
+* Contributor, or anyone who distributes Covered Software as *
+* permitted above, be liable to You for any direct, indirect, *
+* special, incidental, or consequential damages of any character *
+* including, without limitation, damages for lost profits, loss of *
+* goodwill, work stoppage, computer failure or malfunction, or any *
+* and all other commercial damages or losses, even if such party *
+* shall have been informed of the possibility of such damages. This *
+* limitation of liability shall not apply to liability for death or *
+* personal injury resulting from such party's negligence to the *
+* extent applicable law prohibits such limitation. Some *
+* jurisdictions do not allow the exclusion or limitation of *
+* incidental or consequential damages, so this exclusion and *
+* limitation may not apply to You. *
+* *
+************************************************************************
+
+8. Litigation
+-------------
+
+Any litigation relating to this License may be brought only in the
+courts of a jurisdiction where the defendant maintains its principal
+place of business and such litigation shall be governed by laws of that
+jurisdiction, without reference to its conflict-of-law provisions.
+Nothing in this Section shall prevent a party's ability to bring
+cross-claims or counter-claims.
+
+9. Miscellaneous
+----------------
+
+This License represents the complete agreement concerning the subject
+matter hereof. If any provision of this License is held to be
+unenforceable, such provision shall be reformed only to the extent
+necessary to make it enforceable. Any law or regulation which provides
+that the language of a contract shall be construed against the drafter
+shall not be used to construe this License against a Contributor.
+
+10. Versions of the License
+---------------------------
+
+10.1. New Versions
+
+Mozilla Foundation is the license steward. Except as provided in Section
+10.3, no one other than the license steward has the right to modify or
+publish new versions of this License. Each version will be given a
+distinguishing version number.
+
+10.2. Effect of New Versions
+
+You may distribute the Covered Software under the terms of the version
+of the License under which You originally received the Covered Software,
+or under the terms of any subsequent version published by the license
+steward.
+
+10.3. Modified Versions
+
+If you create software not governed by this License, and you want to
+create a new license for such software, you may create and use a
+modified version of this License if you rename the license and remove
+any references to the name of the license steward (except to note that
+such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+Licenses
+
+If You choose to distribute Source Code Form that is Incompatible With
+Secondary Licenses under the terms of this version of the License, the
+notice described in Exhibit B of this License must be attached.
+
+Exhibit A - Source Code Form License Notice
+-------------------------------------------
+
+ This Source Code Form is subject to the terms of the Mozilla Public
+ License, v. 2.0. If a copy of the MPL was not distributed with this
+ file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular
+file, then You may include the notice in a location (such as a LICENSE
+file in a relevant directory) where a recipient would be likely to look
+for such a notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+---------------------------------------------------------
+
+ This Source Code Form is "Incompatible With Secondary Licenses", as
+ defined by the Mozilla Public License, v. 2.0.
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/controller.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a0a9b7fec9a7bc9d0b6bc605b268b662fef77b
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/controller.py
@@ -0,0 +1,103 @@
+import torch
+
+from ..fbrs.inference import clicker
+from ..fbrs.inference.predictors import get_predictor
+
+
+class InteractiveController:
+ def __init__(self, net, device, predictor_params, prob_thresh=0.5):
+ self.net = net.to(device)
+ self.prob_thresh = prob_thresh
+ self.clicker = clicker.Clicker()
+ self.states = []
+ self.probs_history = []
+ self.object_count = 0
+ self._result_mask = None
+
+ self.image = None
+ self.predictor = None
+ self.device = device
+ self.predictor_params = predictor_params
+ self.reset_predictor()
+
+ def set_image(self, image):
+ self.image = image
+ self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
+ self.object_count = 0
+ self.reset_last_object()
+
+ def add_click(self, x, y, is_positive):
+ self.states.append({
+ 'clicker': self.clicker.get_state(),
+ 'predictor': self.predictor.get_states()
+ })
+
+ click = clicker.Click(is_positive=is_positive, coords=(y, x))
+ self.clicker.add_click(click)
+ pred = self.predictor.get_prediction(self.clicker)
+ torch.cuda.empty_cache()
+
+ if self.probs_history:
+ self.probs_history.append((self.probs_history[-1][0], pred))
+ else:
+ self.probs_history.append((torch.zeros_like(pred), pred))
+
+ def undo_click(self):
+ if not self.states:
+ return
+
+ prev_state = self.states.pop()
+ self.clicker.set_state(prev_state['clicker'])
+ self.predictor.set_states(prev_state['predictor'])
+ self.probs_history.pop()
+
+ def partially_finish_object(self):
+ object_prob = self.current_object_prob
+ if object_prob is None:
+ return
+
+ self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
+ self.states.append(self.states[-1])
+
+ self.clicker.reset_clicks()
+ self.reset_predictor()
+
+ def finish_object(self):
+ object_prob = self.current_object_prob
+ if object_prob is None:
+ return
+
+ self.object_count += 1
+ object_mask = object_prob > self.prob_thresh
+ self._result_mask[object_mask] = self.object_count
+ self.reset_last_object()
+
+ def reset_last_object(self):
+ self.states = []
+ self.probs_history = []
+ self.clicker.reset_clicks()
+ self.reset_predictor()
+
+ def reset_predictor(self, predictor_params=None):
+ if predictor_params is not None:
+ self.predictor_params = predictor_params
+ self.predictor = get_predictor(self.net, device=self.device,
+ **self.predictor_params)
+ if self.image is not None:
+ self.predictor.set_input_image(self.image)
+
+ @property
+ def current_object_prob(self):
+ if self.probs_history:
+ current_prob_total, current_prob_additive = self.probs_history[-1]
+ return torch.maximum(current_prob_total, current_prob_additive)
+ else:
+ return None
+
+ @property
+ def is_incomplete_mask(self):
+ return len(self.probs_history) > 0
+
+ @property
+ def result_mask(self):
+ return self._result_mask.clone()
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/clicker.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/clicker.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1ea9cf319f88639fa0af45088cdf79c8954f83a
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/clicker.py
@@ -0,0 +1,103 @@
+from collections import namedtuple
+
+import numpy as np
+from copy import deepcopy
+from scipy.ndimage import distance_transform_edt
+
+Click = namedtuple('Click', ['is_positive', 'coords'])
+
+
+class Clicker(object):
+ def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1):
+ if gt_mask is not None:
+ self.gt_mask = gt_mask == 1
+ self.not_ignore_mask = gt_mask != ignore_label
+ else:
+ self.gt_mask = None
+
+ self.reset_clicks()
+
+ if init_clicks is not None:
+ for click in init_clicks:
+ self.add_click(click)
+
+ def make_next_click(self, pred_mask):
+ assert self.gt_mask is not None
+ click = self._get_click(pred_mask)
+ self.add_click(click)
+
+ def get_clicks(self, clicks_limit=None):
+ return self.clicks_list[:clicks_limit]
+
+ def _get_click(self, pred_mask, padding=True):
+ fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
+ fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
+
+ if padding:
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
+
+ fn_mask_dt = distance_transform_edt(fn_mask)
+ fp_mask_dt = distance_transform_edt(fp_mask)
+
+ if padding:
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
+
+ fn_mask_dt = fn_mask_dt * self.not_clicked_map
+ fp_mask_dt = fp_mask_dt * self.not_clicked_map
+
+ fn_max_dist = np.max(fn_mask_dt)
+ fp_max_dist = np.max(fp_mask_dt)
+
+ is_positive = fn_max_dist > fp_max_dist
+ if is_positive:
+ coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
+ else:
+ coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
+
+ return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
+
+ def add_click(self, click):
+ coords = click.coords
+
+ if click.is_positive:
+ self.num_pos_clicks += 1
+ else:
+ self.num_neg_clicks += 1
+
+ self.clicks_list.append(click)
+ if self.gt_mask is not None:
+ self.not_clicked_map[coords[0], coords[1]] = False
+
+ def _remove_last_click(self):
+ click = self.clicks_list.pop()
+ coords = click.coords
+
+ if click.is_positive:
+ self.num_pos_clicks -= 1
+ else:
+ self.num_neg_clicks -= 1
+
+ if self.gt_mask is not None:
+ self.not_clicked_map[coords[0], coords[1]] = True
+
+ def reset_clicks(self):
+ if self.gt_mask is not None:
+ self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool)
+
+ self.num_pos_clicks = 0
+ self.num_neg_clicks = 0
+
+ self.clicks_list = []
+
+ def get_state(self):
+ return deepcopy(self.clicks_list)
+
+ def set_state(self, state):
+ self.reset_clicks()
+ for click in state:
+ self.add_click(click)
+
+ def __len__(self):
+ return len(self.clicks_list)
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/evaluation.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be3ed813eb257309f433ece0035e0890a82207e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/evaluation.py
@@ -0,0 +1,56 @@
+from time import time
+
+import numpy as np
+import torch
+
+from ..inference import utils
+from ..inference.clicker import Clicker
+
+try:
+ get_ipython()
+ from tqdm import tqdm_notebook as tqdm
+except NameError:
+ from tqdm import tqdm
+
+
+def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs):
+ all_ious = []
+
+ start_time = time()
+ for index in tqdm(range(len(dataset)), leave=False):
+ sample = dataset.get_sample(index)
+ item = dataset[index]
+
+ if oracle_eval:
+ gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32)
+ gt_mask = gt_mask.unsqueeze(0).unsqueeze(0)
+ predictor.opt_functor.mask_loss.set_gt_mask(gt_mask)
+ _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs)
+ all_ious.append(sample_ious)
+ end_time = time()
+ elapsed_time = end_time - start_time
+
+ return all_ious, elapsed_time
+
+
+def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr,
+ pred_thr=0.49, max_clicks=20):
+ clicker = Clicker(gt_mask=instances_mask)
+ pred_mask = np.zeros_like(instances_mask)
+ ious_list = []
+
+ with torch.no_grad():
+ predictor.set_input_image(image_nd)
+
+ for click_number in range(max_clicks):
+ clicker.make_next_click(pred_mask)
+ pred_probs = predictor.get_prediction(clicker)
+ pred_mask = pred_probs > pred_thr
+
+ iou = utils.get_iou(instances_mask, pred_mask)
+ ious_list.append(iou)
+
+ if iou >= max_iou_thr:
+ break
+
+ return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b8b8618cd33efabdaec69328de2f5a8a58d2f9
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/__init__.py
@@ -0,0 +1,95 @@
+from .base import BasePredictor
+from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
+from .brs_functors import InputOptimizer, ScaleBiasOptimizer
+from ..transforms import ZoomIn
+from ...model.is_hrnet_model import DistMapsHRNetModel
+
+
+def get_predictor(net, brs_mode, device,
+ prob_thresh=0.49,
+ with_flip=True,
+ zoom_in_params=dict(),
+ predictor_params=None,
+ brs_opt_func_params=None,
+ lbfgs_params=None):
+ lbfgs_params_ = {
+ 'm': 20,
+ 'factr': 0,
+ 'pgtol': 1e-8,
+ 'maxfun': 20,
+ }
+
+ predictor_params_ = {
+ 'optimize_after_n_clicks': 1
+ }
+
+ if zoom_in_params is not None:
+ zoom_in = ZoomIn(**zoom_in_params)
+ else:
+ zoom_in = None
+
+ if lbfgs_params is not None:
+ lbfgs_params_.update(lbfgs_params)
+ lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
+
+ if brs_opt_func_params is None:
+ brs_opt_func_params = dict()
+
+ if brs_mode == 'NoBRS':
+ if predictor_params is not None:
+ predictor_params_.update(predictor_params)
+ predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
+ elif brs_mode.startswith('f-BRS'):
+ predictor_params_.update({
+ 'net_clicks_limit': 8,
+ })
+ if predictor_params is not None:
+ predictor_params_.update(predictor_params)
+
+ insertion_mode = {
+ 'f-BRS-A': 'after_c4',
+ 'f-BRS-B': 'after_aspp',
+ 'f-BRS-C': 'after_deeplab'
+ }[brs_mode]
+
+ opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
+ with_flip=with_flip,
+ optimizer_params=lbfgs_params_,
+ **brs_opt_func_params)
+
+ if isinstance(net, DistMapsHRNetModel):
+ FeaturePredictor = HRNetFeatureBRSPredictor
+ insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
+ else:
+ FeaturePredictor = FeatureBRSPredictor
+
+ predictor = FeaturePredictor(net, device,
+ opt_functor=opt_functor,
+ with_flip=with_flip,
+ insertion_mode=insertion_mode,
+ zoom_in=zoom_in,
+ **predictor_params_)
+ elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
+ use_dmaps = brs_mode == 'DistMap-BRS'
+
+ predictor_params_.update({
+ 'net_clicks_limit': 5,
+ })
+ if predictor_params is not None:
+ predictor_params_.update(predictor_params)
+
+ opt_functor = InputOptimizer(prob_thresh=prob_thresh,
+ with_flip=with_flip,
+ optimizer_params=lbfgs_params_,
+ **brs_opt_func_params)
+
+ predictor = InputBRSPredictor(net, device,
+ optimize_target='dmaps' if use_dmaps else 'rgb',
+ opt_functor=opt_functor,
+ with_flip=with_flip,
+ zoom_in=zoom_in,
+ **predictor_params_)
+ else:
+ raise NotImplementedError
+
+ return predictor
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/base.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3776506328ef9457afdad047fb4219c5e25c3ab6
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/base.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn.functional as F
+
+from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
+
+
+class BasePredictor(object):
+ def __init__(self, net, device,
+ net_clicks_limit=None,
+ with_flip=False,
+ zoom_in=None,
+ max_size=None,
+ **kwargs):
+ self.net = net
+ self.with_flip = with_flip
+ self.net_clicks_limit = net_clicks_limit
+ self.original_image = None
+ self.device = device
+ self.zoom_in = zoom_in
+
+ self.transforms = [zoom_in] if zoom_in is not None else []
+ if max_size is not None:
+ self.transforms.append(LimitLongestSide(max_size=max_size))
+ self.transforms.append(SigmoidForPred())
+ if with_flip:
+ self.transforms.append(AddHorizontalFlip())
+
+ def set_input_image(self, image_nd):
+ for transform in self.transforms:
+ transform.reset()
+ self.original_image = image_nd.to(self.device)
+ if len(self.original_image.shape) == 3:
+ self.original_image = self.original_image.unsqueeze(0)
+
+ def get_prediction(self, clicker):
+ clicks_list = clicker.get_clicks()
+
+ image_nd, clicks_lists, is_image_changed = self.apply_transforms(
+ self.original_image, [clicks_list]
+ )
+
+ pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
+ prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
+ size=image_nd.size()[2:])
+
+ for t in reversed(self.transforms):
+ prediction = t.inv_transform(prediction)
+
+ if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
+ print('zooming')
+ return self.get_prediction(clicker)
+
+ # return prediction.cpu().numpy()[0, 0]
+ return prediction
+
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
+ points_nd = self.get_points_nd(clicks_lists)
+ return self.net(image_nd, points_nd)['instances']
+
+ def _get_transform_states(self):
+ return [x.get_state() for x in self.transforms]
+
+ def _set_transform_states(self, states):
+ assert len(states) == len(self.transforms)
+ for state, transform in zip(states, self.transforms):
+ transform.set_state(state)
+
+ def apply_transforms(self, image_nd, clicks_lists):
+ is_image_changed = False
+ for t in self.transforms:
+ image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
+ is_image_changed |= t.image_changed
+
+ return image_nd, clicks_lists, is_image_changed
+
+ def get_points_nd(self, clicks_lists):
+ total_clicks = []
+ num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
+ num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
+ num_max_points = max(num_pos_clicks + num_neg_clicks)
+ if self.net_clicks_limit is not None:
+ num_max_points = min(self.net_clicks_limit, num_max_points)
+ num_max_points = max(1, num_max_points)
+
+ for clicks_list in clicks_lists:
+ clicks_list = clicks_list[:self.net_clicks_limit]
+ pos_clicks = [click.coords for click in clicks_list if click.is_positive]
+ pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)]
+
+ neg_clicks = [click.coords for click in clicks_list if not click.is_positive]
+ neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)]
+ total_clicks.append(pos_clicks + neg_clicks)
+
+ return torch.tensor(total_clicks, device=self.device)
+
+ def get_states(self):
+ return {'transform_states': self._get_transform_states()}
+
+ def set_states(self, states):
+ self._set_transform_states(states['transform_states'])
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc7296e52d5e575956eec8a614682a35cff9cd7
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs.py
@@ -0,0 +1,280 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.optimize import fmin_l_bfgs_b
+
+from .base import BasePredictor
+from ...model.is_hrnet_model import DistMapsHRNetModel
+
+
+class BRSBasePredictor(BasePredictor):
+ def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
+ super().__init__(model, device, **kwargs)
+ self.optimize_after_n_clicks = optimize_after_n_clicks
+ self.opt_functor = opt_functor
+
+ self.opt_data = None
+ self.input_data = None
+
+ def set_input_image(self, image_nd):
+ super().set_input_image(image_nd)
+ self.opt_data = None
+ self.input_data = None
+
+ def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
+ pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
+ neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
+
+ for list_indx, clicks_list in enumerate(clicks_lists):
+ for click in clicks_list:
+ y, x = click.coords
+ y, x = int(round(y)), int(round(x))
+ y1, x1 = y - radius, x - radius
+ y2, x2 = y + radius + 1, x + radius + 1
+
+ if click.is_positive:
+ pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
+ else:
+ neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
+
+ with torch.no_grad():
+ pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
+ neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)
+
+ return pos_clicks_map, neg_clicks_map
+
+ def get_states(self):
+ return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
+
+ def set_states(self, states):
+ self._set_transform_states(states['transform_states'])
+ self.opt_data = states['opt_data']
+
+
+class FeatureBRSPredictor(BRSBasePredictor):
+ def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
+ self.insertion_mode = insertion_mode
+ self._c1_features = None
+
+ if self.insertion_mode == 'after_deeplab':
+ self.num_channels = model.feature_extractor.ch
+ elif self.insertion_mode == 'after_c4':
+ self.num_channels = model.feature_extractor.aspp_in_channels
+ elif self.insertion_mode == 'after_aspp':
+ self.num_channels = model.feature_extractor.ch + 32
+ else:
+ raise NotImplementedError
+
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
+ points_nd = self.get_points_nd(clicks_lists)
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
+
+ num_clicks = len(clicks_lists[0])
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
+
+ if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
+ self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
+
+ if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
+ self.input_data = self._get_head_input(image_nd, points_nd)
+
+ def get_prediction_logits(scale, bias):
+ scale = scale.view(bs, -1, 1, 1)
+ bias = bias.view(bs, -1, 1, 1)
+ if self.with_flip:
+ scale = scale.repeat(2, 1, 1, 1)
+ bias = bias.repeat(2, 1, 1, 1)
+
+ scaled_backbone_features = self.input_data * scale
+ scaled_backbone_features = scaled_backbone_features + bias
+ if self.insertion_mode == 'after_c4':
+ x = self.net.feature_extractor.aspp(scaled_backbone_features)
+ x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
+ align_corners=True)
+ x = torch.cat((x, self._c1_features), dim=1)
+ scaled_backbone_features = self.net.feature_extractor.head(x)
+ elif self.insertion_mode == 'after_aspp':
+ scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
+
+ pred_logits = self.net.head(scaled_backbone_features)
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
+ align_corners=True)
+ return pred_logits
+
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
+ if num_clicks > self.optimize_after_n_clicks:
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
+ **self.opt_functor.optimizer_params)
+ self.opt_data = opt_result[0]
+
+ with torch.no_grad():
+ if self.opt_functor.best_prediction is not None:
+ opt_pred_logits = self.opt_functor.best_prediction
+ else:
+ opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
+ opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
+ opt_pred_logits = get_prediction_logits(*opt_vars)
+
+ return opt_pred_logits
+
+ def _get_head_input(self, image_nd, points):
+ with torch.no_grad():
+ coord_features = self.net.dist_maps(image_nd, points)
+ x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
+ if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
+ c1, _, c3, c4 = self.net.feature_extractor.backbone(x)
+ c1 = self.net.feature_extractor.skip_project(c1)
+
+ if self.insertion_mode == 'after_aspp':
+ x = self.net.feature_extractor.aspp(c4)
+ x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
+ x = torch.cat((x, c1), dim=1)
+ backbone_features = x
+ else:
+ backbone_features = c4
+ self._c1_features = c1
+ else:
+ backbone_features = self.net.feature_extractor(x)[0]
+
+ return backbone_features
+
+
+class HRNetFeatureBRSPredictor(BRSBasePredictor):
+ def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
+ self.insertion_mode = insertion_mode
+ self._c1_features = None
+
+ if self.insertion_mode == 'A':
+ self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
+ elif self.insertion_mode == 'C':
+ self.num_channels = 2 * model.feature_extractor.ocr_width
+ else:
+ raise NotImplementedError
+
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
+ points_nd = self.get_points_nd(clicks_lists)
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
+ num_clicks = len(clicks_lists[0])
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
+
+ if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
+ self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
+
+ if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
+ self.input_data = self._get_head_input(image_nd, points_nd)
+
+ def get_prediction_logits(scale, bias):
+ scale = scale.view(bs, -1, 1, 1)
+ bias = bias.view(bs, -1, 1, 1)
+ if self.with_flip:
+ scale = scale.repeat(2, 1, 1, 1)
+ bias = bias.repeat(2, 1, 1, 1)
+
+ scaled_backbone_features = self.input_data * scale
+ scaled_backbone_features = scaled_backbone_features + bias
+ if self.insertion_mode == 'A':
+ out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
+ feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
+
+ context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
+ feats = self.net.feature_extractor.ocr_distri_head(feats, context)
+ pred_logits = self.net.feature_extractor.cls_head(feats)
+ elif self.insertion_mode == 'C':
+ pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
+ else:
+ raise NotImplementedError
+
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
+ align_corners=True)
+ return pred_logits
+
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
+ if num_clicks > self.optimize_after_n_clicks:
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
+ **self.opt_functor.optimizer_params)
+ self.opt_data = opt_result[0]
+
+ with torch.no_grad():
+ if self.opt_functor.best_prediction is not None:
+ opt_pred_logits = self.opt_functor.best_prediction
+ else:
+ opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
+ opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
+ opt_pred_logits = get_prediction_logits(*opt_vars)
+
+ return opt_pred_logits
+
+ def _get_head_input(self, image_nd, points):
+ with torch.no_grad():
+ coord_features = self.net.dist_maps(image_nd, points)
+ x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
+ feats = self.net.feature_extractor.compute_hrnet_feats(x)
+ if self.insertion_mode == 'A':
+ backbone_features = feats
+ elif self.insertion_mode == 'C':
+ out_aux = self.net.feature_extractor.aux_head(feats)
+ feats = self.net.feature_extractor.conv3x3_ocr(feats)
+
+ context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
+ backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
+ else:
+ raise NotImplementedError
+
+ return backbone_features
+
+
+class InputBRSPredictor(BRSBasePredictor):
+ def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
+ self.optimize_target = optimize_target
+
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
+ points_nd = self.get_points_nd(clicks_lists)
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
+ num_clicks = len(clicks_lists[0])
+
+ if self.opt_data is None or is_image_changed:
+ opt_channels = 2 if self.optimize_target == 'dmaps' else 3
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
+ self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
+ device=self.device, dtype=torch.float32)
+
+ def get_prediction_logits(opt_bias):
+ input_image = image_nd
+ if self.optimize_target == 'rgb':
+ input_image = input_image + opt_bias
+ dmaps = self.net.dist_maps(input_image, points_nd)
+ if self.optimize_target == 'dmaps':
+ dmaps = dmaps + opt_bias
+
+ x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
+ if self.optimize_target == 'all':
+ x = x + opt_bias
+
+ if isinstance(self.net, DistMapsHRNetModel):
+ pred_logits = self.net.feature_extractor(x)[0]
+ else:
+ backbone_features = self.net.feature_extractor(x)
+ pred_logits = self.net.head(backbone_features[0])
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
+
+ return pred_logits
+
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
+ shape=self.opt_data.shape)
+ if num_clicks > self.optimize_after_n_clicks:
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
+ **self.opt_functor.optimizer_params)
+
+ self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
+
+ with torch.no_grad():
+ if self.opt_functor.best_prediction is not None:
+ opt_pred_logits = self.opt_functor.best_prediction
+ else:
+ opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
+ opt_pred_logits = get_prediction_logits(*opt_vars)
+
+ return opt_pred_logits
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_functors.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_functors.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6eb9037a4a3dc0f7671d134eea4113529455f5
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_functors.py
@@ -0,0 +1,109 @@
+import torch
+import numpy as np
+
+from ...model.metrics import _compute_iou
+from .brs_losses import BRSMaskLoss
+
+
+class BaseOptimizer:
+ def __init__(self, optimizer_params,
+ prob_thresh=0.49,
+ reg_weight=1e-3,
+ min_iou_diff=0.01,
+ brs_loss=BRSMaskLoss(),
+ with_flip=False,
+ flip_average=False,
+ **kwargs):
+ self.brs_loss = brs_loss
+ self.optimizer_params = optimizer_params
+ self.prob_thresh = prob_thresh
+ self.reg_weight = reg_weight
+ self.min_iou_diff = min_iou_diff
+ self.with_flip = with_flip
+ self.flip_average = flip_average
+
+ self.best_prediction = None
+ self._get_prediction_logits = None
+ self._opt_shape = None
+ self._best_loss = None
+ self._click_masks = None
+ self._last_mask = None
+ self.device = None
+
+ def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
+ self.best_prediction = None
+ self._get_prediction_logits = get_prediction_logits
+ self._click_masks = (pos_mask, neg_mask)
+ self._opt_shape = shape
+ self._last_mask = None
+ self.device = device
+
+ def __call__(self, x):
+ opt_params = torch.from_numpy(x).float().to(self.device)
+ opt_params.requires_grad_(True)
+
+ with torch.enable_grad():
+ opt_vars, reg_loss = self.unpack_opt_params(opt_params)
+ result_before_sigmoid = self._get_prediction_logits(*opt_vars)
+ result = torch.sigmoid(result_before_sigmoid)
+
+ pos_mask, neg_mask = self._click_masks
+ if self.with_flip and self.flip_average:
+ result, result_flipped = torch.chunk(result, 2, dim=0)
+ result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
+ pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
+
+ loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
+ loss = loss + reg_loss
+
+ f_val = loss.detach().cpu().numpy()
+ if self.best_prediction is None or f_val < self._best_loss:
+ self.best_prediction = result_before_sigmoid.detach()
+ self._best_loss = f_val
+
+ if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
+ return [f_val, np.zeros_like(x)]
+
+ current_mask = result > self.prob_thresh
+ if self._last_mask is not None and self.min_iou_diff > 0:
+ diff_iou = _compute_iou(current_mask, self._last_mask)
+ if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
+ return [f_val, np.zeros_like(x)]
+ self._last_mask = current_mask
+
+ loss.backward()
+ f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32)
+
+ return [f_val, f_grad]
+
+ def unpack_opt_params(self, opt_params):
+ raise NotImplementedError
+
+
+class InputOptimizer(BaseOptimizer):
+ def unpack_opt_params(self, opt_params):
+ opt_params = opt_params.view(self._opt_shape)
+ if self.with_flip:
+ opt_params_flipped = torch.flip(opt_params, dims=[3])
+ opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
+ reg_loss = self.reg_weight * torch.sum(opt_params**2)
+
+ return (opt_params,), reg_loss
+
+
+class ScaleBiasOptimizer(BaseOptimizer):
+ def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.scale_act = scale_act
+ self.reg_bias_weight = reg_bias_weight
+
+ def unpack_opt_params(self, opt_params):
+ scale, bias = torch.chunk(opt_params, 2, dim=0)
+ reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
+
+ if self.scale_act == 'tanh':
+ scale = torch.tanh(scale)
+ elif self.scale_act == 'sin':
+ scale = torch.sin(scale)
+
+ return (1 + scale, bias), reg_loss
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_losses.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9998ab120b9987e79509d0ee594e8b6c431a9f
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/predictors/brs_losses.py
@@ -0,0 +1,58 @@
+import torch
+
+from ...model.losses import SigmoidBinaryCrossEntropyLoss
+
+
+class BRSMaskLoss(torch.nn.Module):
+ def __init__(self, eps=1e-5):
+ super().__init__()
+ self._eps = eps
+
+ def forward(self, result, pos_mask, neg_mask):
+ pos_diff = (1 - result) * pos_mask
+ pos_target = torch.sum(pos_diff ** 2)
+ pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
+
+ neg_diff = result * neg_mask
+ neg_target = torch.sum(neg_diff ** 2)
+ neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
+
+ loss = pos_target + neg_target
+
+ with torch.no_grad():
+ f_max_pos = torch.max(torch.abs(pos_diff)).item()
+ f_max_neg = torch.max(torch.abs(neg_diff)).item()
+
+ return loss, f_max_pos, f_max_neg
+
+
+class OracleMaskLoss(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gt_mask = None
+ self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
+ self.predictor = None
+ self.history = []
+
+ def set_gt_mask(self, gt_mask):
+ self.gt_mask = gt_mask
+ self.history = []
+
+ def forward(self, result, pos_mask, neg_mask):
+ gt_mask = self.gt_mask.to(result.device)
+ if self.predictor.object_roi is not None:
+ r1, r2, c1, c2 = self.predictor.object_roi[:4]
+ gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
+ gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
+
+ if result.shape[0] == 2:
+ gt_mask_flipped = torch.flip(gt_mask, dims=[3])
+ gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
+
+ loss = self.loss(result, gt_mask)
+ self.history.append(loss.detach().cpu().numpy()[0])
+
+ if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
+ return 0, 0, 0
+
+ return loss, 1.0, 1.0
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd54e38a2f84b3fef481672a7ceab070eb01b82
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/__init__.py
@@ -0,0 +1,5 @@
+from .base import SigmoidForPred
+from .flip import AddHorizontalFlip
+from .zoom_in import ZoomIn
+from .limit_longest_side import LimitLongestSide
+from .crops import Crops
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/base.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb5a2deb3c44f5aed7530fd1e299fff1273737b8
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/base.py
@@ -0,0 +1,38 @@
+import torch
+
+
+class BaseTransform(object):
+ def __init__(self):
+ self.image_changed = False
+
+ def transform(self, image_nd, clicks_lists):
+ raise NotImplementedError
+
+ def inv_transform(self, prob_map):
+ raise NotImplementedError
+
+ def reset(self):
+ raise NotImplementedError
+
+ def get_state(self):
+ raise NotImplementedError
+
+ def set_state(self, state):
+ raise NotImplementedError
+
+
+class SigmoidForPred(BaseTransform):
+ def transform(self, image_nd, clicks_lists):
+ return image_nd, clicks_lists
+
+ def inv_transform(self, prob_map):
+ return torch.sigmoid(prob_map)
+
+ def reset(self):
+ pass
+
+ def get_state(self):
+ return None
+
+ def set_state(self, state):
+ pass
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/crops.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/crops.py
new file mode 100644
index 0000000000000000000000000000000000000000..0910a2825608cf3fa761212d182dc1e8e5c242c4
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/crops.py
@@ -0,0 +1,97 @@
+import math
+
+import torch
+import numpy as np
+
+from ...inference.clicker import Click
+from .base import BaseTransform
+
+
+class Crops(BaseTransform):
+ def __init__(self, crop_size=(320, 480), min_overlap=0.2):
+ super().__init__()
+ self.crop_height, self.crop_width = crop_size
+ self.min_overlap = min_overlap
+
+ self.x_offsets = None
+ self.y_offsets = None
+ self._counts = None
+
+ def transform(self, image_nd, clicks_lists):
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
+ image_height, image_width = image_nd.shape[2:4]
+ self._counts = None
+
+ if image_height < self.crop_height or image_width < self.crop_width:
+ return image_nd, clicks_lists
+
+ self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
+ self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
+ self._counts = np.zeros((image_height, image_width))
+
+ image_crops = []
+ for dy in self.y_offsets:
+ for dx in self.x_offsets:
+ self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
+ image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
+ image_crops.append(image_crop)
+ image_crops = torch.cat(image_crops, dim=0)
+ self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
+
+ clicks_list = clicks_lists[0]
+ clicks_lists = []
+ for dy in self.y_offsets:
+ for dx in self.x_offsets:
+ crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx))
+ for x in clicks_list]
+ clicks_lists.append(crop_clicks)
+
+ return image_crops, clicks_lists
+
+ def inv_transform(self, prob_map):
+ if self._counts is None:
+ return prob_map
+
+ new_prob_map = torch.zeros((1, 1, *self._counts.shape),
+ dtype=prob_map.dtype, device=prob_map.device)
+
+ crop_indx = 0
+ for dy in self.y_offsets:
+ for dx in self.x_offsets:
+ new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
+ crop_indx += 1
+ new_prob_map = torch.div(new_prob_map, self._counts)
+
+ return new_prob_map
+
+ def get_state(self):
+ return self.x_offsets, self.y_offsets, self._counts
+
+ def set_state(self, state):
+ self.x_offsets, self.y_offsets, self._counts = state
+
+ def reset(self):
+ self.x_offsets = None
+ self.y_offsets = None
+ self._counts = None
+
+
+def get_offsets(length, crop_size, min_overlap_ratio=0.2):
+ if length == crop_size:
+ return [0]
+
+ N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
+ N = math.ceil(N)
+
+ overlap_ratio = (N - length / crop_size) / (N - 1)
+ overlap_width = int(crop_size * overlap_ratio)
+
+ offsets = [0]
+ for i in range(1, N):
+ new_offset = offsets[-1] + crop_size - overlap_width
+ if new_offset + crop_size > length:
+ new_offset = length - crop_size
+
+ offsets.append(new_offset)
+
+ return offsets
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/flip.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1543cb65f8d3892054dc96f39a8196987fb6bfd
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/flip.py
@@ -0,0 +1,37 @@
+import torch
+
+from ..clicker import Click
+from .base import BaseTransform
+
+
+class AddHorizontalFlip(BaseTransform):
+ def transform(self, image_nd, clicks_lists):
+ assert len(image_nd.shape) == 4
+ image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
+
+ image_width = image_nd.shape[3]
+ clicks_lists_flipped = []
+ for clicks_list in clicks_lists:
+ clicks_list_flipped = [Click(is_positive=click.is_positive,
+ coords=(click.coords[0], image_width - click.coords[1] - 1))
+ for click in clicks_list]
+ clicks_lists_flipped.append(clicks_list_flipped)
+ clicks_lists = clicks_lists + clicks_lists_flipped
+
+ return image_nd, clicks_lists
+
+ def inv_transform(self, prob_map):
+ assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
+ num_maps = prob_map.shape[0] // 2
+ prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
+
+ return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
+
+ def get_state(self):
+ return None
+
+ def set_state(self, state):
+ pass
+
+ def reset(self):
+ pass
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c5a53d2670df52285621dc0d33e86df520d77c
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py
@@ -0,0 +1,22 @@
+from .zoom_in import ZoomIn, get_roi_image_nd
+
+
+class LimitLongestSide(ZoomIn):
+ def __init__(self, max_size=800):
+ super().__init__(target_size=max_size, skip_clicks=0)
+
+ def transform(self, image_nd, clicks_lists):
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
+ image_max_size = max(image_nd.shape[2:4])
+ self.image_changed = False
+
+ if image_max_size <= self.target_size:
+ return image_nd, clicks_lists
+ self._input_image = image_nd
+
+ self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
+ self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
+ self.image_changed = True
+
+ tclicks_lists = [self._transform_clicks(clicks_lists[0])]
+ return self._roi_image, tclicks_lists
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/zoom_in.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/zoom_in.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c11ecc241570fe2429e85bdccbb713a70d9ffd6
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/transforms/zoom_in.py
@@ -0,0 +1,171 @@
+import torch
+
+from ..clicker import Click
+from ...utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
+from .base import BaseTransform
+
+
+class ZoomIn(BaseTransform):
+ def __init__(self,
+ target_size=400,
+ skip_clicks=1,
+ expansion_ratio=1.4,
+ min_crop_size=200,
+ recompute_thresh_iou=0.5,
+ prob_thresh=0.50):
+ super().__init__()
+ self.target_size = target_size
+ self.min_crop_size = min_crop_size
+ self.skip_clicks = skip_clicks
+ self.expansion_ratio = expansion_ratio
+ self.recompute_thresh_iou = recompute_thresh_iou
+ self.prob_thresh = prob_thresh
+
+ self._input_image_shape = None
+ self._prev_probs = None
+ self._object_roi = None
+ self._roi_image = None
+
+ def transform(self, image_nd, clicks_lists):
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
+ self.image_changed = False
+
+ clicks_list = clicks_lists[0]
+ if len(clicks_list) <= self.skip_clicks:
+ return image_nd, clicks_lists
+
+ self._input_image_shape = image_nd.shape
+
+ current_object_roi = None
+ if self._prev_probs is not None:
+ current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
+ if current_pred_mask.sum() > 0:
+ current_object_roi = get_object_roi(current_pred_mask, clicks_list,
+ self.expansion_ratio, self.min_crop_size)
+
+ if current_object_roi is None:
+ return image_nd, clicks_lists
+
+ update_object_roi = False
+ if self._object_roi is None:
+ update_object_roi = True
+ elif not check_object_roi(self._object_roi, clicks_list):
+ update_object_roi = True
+ elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
+ update_object_roi = True
+
+ if update_object_roi:
+ self._object_roi = current_object_roi
+ self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
+ self.image_changed = True
+
+ tclicks_lists = [self._transform_clicks(clicks_list)]
+ return self._roi_image.to(image_nd.device), tclicks_lists
+
+ def inv_transform(self, prob_map):
+ if self._object_roi is None:
+ self._prev_probs = prob_map.cpu().numpy()
+ return prob_map
+
+ assert prob_map.shape[0] == 1
+ rmin, rmax, cmin, cmax = self._object_roi
+ prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
+ mode='bilinear', align_corners=True)
+
+ if self._prev_probs is not None:
+ new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
+ new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
+ else:
+ new_prob_map = prob_map
+
+ self._prev_probs = new_prob_map.cpu().numpy()
+
+ return new_prob_map
+
+ def check_possible_recalculation(self):
+ if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
+ return False
+
+ pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
+ if pred_mask.sum() > 0:
+ possible_object_roi = get_object_roi(pred_mask, [],
+ self.expansion_ratio, self.min_crop_size)
+ image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
+ if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
+ return True
+ return False
+
+ def get_state(self):
+ roi_image = self._roi_image.cpu() if self._roi_image is not None else None
+ return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
+
+ def set_state(self, state):
+ self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
+
+ def reset(self):
+ self._input_image_shape = None
+ self._object_roi = None
+ self._prev_probs = None
+ self._roi_image = None
+ self.image_changed = False
+
+ def _transform_clicks(self, clicks_list):
+ if self._object_roi is None:
+ return clicks_list
+
+ rmin, rmax, cmin, cmax = self._object_roi
+ crop_height, crop_width = self._roi_image.shape[2:]
+
+ transformed_clicks = []
+ for click in clicks_list:
+ new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
+ new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
+ transformed_clicks.append(Click(is_positive=click.is_positive, coords=(new_r, new_c)))
+ return transformed_clicks
+
+
+def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
+ pred_mask = pred_mask.copy()
+
+ for click in clicks_list:
+ if click.is_positive:
+ pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
+
+ bbox = get_bbox_from_mask(pred_mask)
+ bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
+ h, w = pred_mask.shape[0], pred_mask.shape[1]
+ bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
+
+ return bbox
+
+
+def get_roi_image_nd(image_nd, object_roi, target_size):
+ rmin, rmax, cmin, cmax = object_roi
+
+ height = rmax - rmin + 1
+ width = cmax - cmin + 1
+
+ if isinstance(target_size, tuple):
+ new_height, new_width = target_size
+ else:
+ scale = target_size / max(height, width)
+ new_height = int(round(height * scale))
+ new_width = int(round(width * scale))
+
+ with torch.no_grad():
+ roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
+ roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
+ mode='bilinear', align_corners=True)
+
+ return roi_image_nd
+
+
+def check_object_roi(object_roi, clicks_list):
+ for click in clicks_list:
+ if click.is_positive:
+ if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
+ return False
+ if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
+ return False
+
+ return True
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/utils.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1bec96ae744d68a4d471fbce68717f56296542c
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/inference/utils.py
@@ -0,0 +1,177 @@
+from datetime import timedelta
+from pathlib import Path
+
+import torch
+import numpy as np
+
+from ..model.is_deeplab_model import get_deeplab_model
+from ..model.is_hrnet_model import get_hrnet_model
+
+
+def get_time_metrics(all_ious, elapsed_time):
+ n_images = len(all_ious)
+ n_clicks = sum(map(len, all_ious))
+
+ mean_spc = elapsed_time / n_clicks
+ mean_spi = elapsed_time / n_images
+
+ return mean_spc, mean_spi
+
+
+def load_is_model(checkpoint, device, backbone='auto', **kwargs):
+ if isinstance(checkpoint, (str, Path)):
+ state_dict = torch.load(checkpoint, map_location='cpu')
+ else:
+ state_dict = checkpoint
+
+ if backbone == 'auto':
+ for k in state_dict.keys():
+ if 'feature_extractor.stage2.0.branches' in k:
+ return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
+ return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
+ elif 'resnet' in backbone:
+ return load_deeplab_is_model(state_dict, device, backbone, **kwargs)
+ elif 'hrnet' in backbone:
+ return load_hrnet_is_model(state_dict, device, backbone, **kwargs)
+ else:
+ raise NotImplementedError('Unknown backbone')
+
+
+def load_hrnet_is_model(state_dict, device, backbone='auto', width=48, ocr_width=256,
+ small=False, cpu_dist_maps=False, norm_radius=260):
+ if backbone == 'auto':
+ num_fe_weights = len([x for x in state_dict.keys() if 'feature_extractor.' in x])
+ small = num_fe_weights < 1800
+
+ ocr_f_down = [v for k, v in state_dict.items() if 'object_context_block.f_down.1.0.bias' in k]
+ assert len(ocr_f_down) == 1
+ ocr_width = ocr_f_down[0].shape[0]
+
+ s2_conv1_w = [v for k, v in state_dict.items() if 'stage2.0.branches.0.0.conv1.weight' in k]
+ assert len(s2_conv1_w) == 1
+ width = s2_conv1_w[0].shape[0]
+
+ model = get_hrnet_model(width=width, ocr_width=ocr_width, small=small,
+ with_aux_output=False, cpu_dist_maps=cpu_dist_maps,
+ norm_radius=norm_radius)
+
+ model.load_state_dict(state_dict, strict=False)
+ for param in model.parameters():
+ param.requires_grad = False
+ model.to(device)
+ model.eval()
+
+ return model
+
+
+def load_deeplab_is_model(state_dict, device, backbone='auto', deeplab_ch=128, aspp_dropout=0.2,
+ cpu_dist_maps=False, norm_radius=260):
+ if backbone == 'auto':
+ num_backbone_params = len([x for x in state_dict.keys()
+ if 'feature_extractor.backbone' in x and not('num_batches_tracked' in x)])
+
+ if num_backbone_params <= 181:
+ backbone = 'resnet34'
+ elif num_backbone_params <= 276:
+ backbone = 'resnet50'
+ elif num_backbone_params <= 531:
+ backbone = 'resnet101'
+ else:
+ raise NotImplementedError('Unknown backbone')
+
+ if 'aspp_dropout' in state_dict:
+ aspp_dropout = float(state_dict['aspp_dropout'].cpu().numpy())
+ else:
+ aspp_project_weight = [v for k, v in state_dict.items() if 'aspp.project.0.weight' in k][0]
+ deeplab_ch = aspp_project_weight.size(0)
+ if deeplab_ch == 256:
+ aspp_dropout = 0.5
+
+ model = get_deeplab_model(backbone=backbone, deeplab_ch=deeplab_ch,
+ aspp_dropout=aspp_dropout, cpu_dist_maps=cpu_dist_maps,
+ norm_radius=norm_radius)
+
+ model.load_state_dict(state_dict, strict=False)
+ for param in model.parameters():
+ param.requires_grad = False
+ model.to(device)
+ model.eval()
+
+ return model
+
+
+def get_iou(gt_mask, pred_mask, ignore_label=-1):
+ ignore_gt_mask_inv = gt_mask != ignore_label
+ obj_gt_mask = gt_mask == 1
+
+ intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
+ union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
+
+ return intersection / union
+
+
+def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
+ def _get_noc(iou_arr, iou_thr):
+ vals = iou_arr >= iou_thr
+ return np.argmax(vals) + 1 if np.any(vals) else max_clicks
+
+ noc_list = []
+ over_max_list = []
+ for iou_thr in iou_thrs:
+ scores_arr = np.array([_get_noc(iou_arr, iou_thr)
+ for iou_arr in all_ious], dtype=np.int32)
+
+ score = scores_arr.mean()
+ over_max = (scores_arr == max_clicks).sum()
+
+ noc_list.append(score)
+ over_max_list.append(over_max)
+
+ return noc_list, over_max_list
+
+
+def find_checkpoint(weights_folder, checkpoint_name):
+ weights_folder = Path(weights_folder)
+ if ':' in checkpoint_name:
+ model_name, checkpoint_name = checkpoint_name.split(':')
+ models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
+ assert len(models_candidates) == 1
+ model_folder = models_candidates[0]
+ else:
+ model_folder = weights_folder
+
+ if checkpoint_name.endswith('.pth'):
+ if Path(checkpoint_name).exists():
+ checkpoint_path = checkpoint_name
+ else:
+ checkpoint_path = weights_folder / checkpoint_name
+ else:
+ model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
+ assert len(model_checkpoints) == 1
+ checkpoint_path = model_checkpoints[0]
+
+ return str(checkpoint_path)
+
+
+def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
+ n_clicks=20, model_name=None):
+ table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
+ f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
+ f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
+ f'{"SPC,s":^7}|{"Time":^9}|')
+ row_width = len(table_header)
+
+ header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
+ header += '-' * row_width + '\n'
+ header += table_header + '\n' + '-' * row_width
+
+ eval_time = str(timedelta(seconds=int(elapsed_time)))
+ table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
+ table_row += f'{noc_list[0]:^9.2f}|'
+ table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
+ table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
+ table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
+ table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
+ table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
+
+ return header, table_row
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/initializer.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/initializer.py
new file mode 100644
index 0000000000000000000000000000000000000000..470c7df4659bc1e80ceec80a170b3b2e0302fb84
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/initializer.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class Initializer(object):
+ def __init__(self, local_init=True, gamma=None):
+ self.local_init = local_init
+ self.gamma = gamma
+
+ def __call__(self, m):
+ if getattr(m, '__initialized', False):
+ return
+
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+ nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
+ nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
+ if m.weight is not None:
+ self._init_gamma(m.weight.data)
+ if m.bias is not None:
+ self._init_beta(m.bias.data)
+ else:
+ if getattr(m, 'weight', None) is not None:
+ self._init_weight(m.weight.data)
+ if getattr(m, 'bias', None) is not None:
+ self._init_bias(m.bias.data)
+
+ if self.local_init:
+ object.__setattr__(m, '__initialized', True)
+
+ def _init_weight(self, data):
+ nn.init.uniform_(data, -0.07, 0.07)
+
+ def _init_bias(self, data):
+ nn.init.constant_(data, 0)
+
+ def _init_gamma(self, data):
+ if self.gamma is None:
+ nn.init.constant_(data, 1.0)
+ else:
+ nn.init.normal_(data, 1.0, self.gamma)
+
+ def _init_beta(self, data):
+ nn.init.constant_(data, 0)
+
+
+class Bilinear(Initializer):
+ def __init__(self, scale, groups, in_channels, **kwargs):
+ super().__init__(**kwargs)
+ self.scale = scale
+ self.groups = groups
+ self.in_channels = in_channels
+
+ def _init_weight(self, data):
+ """Reset the weight and bias."""
+ bilinear_kernel = self.get_bilinear_kernel(self.scale)
+ weight = torch.zeros_like(data)
+ for i in range(self.in_channels):
+ if self.groups == 1:
+ j = i
+ else:
+ j = 0
+ weight[i, j] = bilinear_kernel
+ data[:] = weight
+
+ @staticmethod
+ def get_bilinear_kernel(scale):
+ """Generate a bilinear upsampling kernel."""
+ kernel_size = 2 * scale - scale % 2
+ scale = (kernel_size + 1) // 2
+ center = scale - 0.5 * (1 + kernel_size % 2)
+
+ og = np.ogrid[:kernel_size, :kernel_size]
+ kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
+
+ return torch.tensor(kernel, dtype=torch.float32)
+
+
+class XavierGluon(Initializer):
+ def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
+ super().__init__(**kwargs)
+
+ self.rnd_type = rnd_type
+ self.factor_type = factor_type
+ self.magnitude = float(magnitude)
+
+ def _init_weight(self, arr):
+ fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
+
+ if self.factor_type == 'avg':
+ factor = (fan_in + fan_out) / 2.0
+ elif self.factor_type == 'in':
+ factor = fan_in
+ elif self.factor_type == 'out':
+ factor = fan_out
+ else:
+ raise ValueError('Incorrect factor type')
+ scale = np.sqrt(self.magnitude / factor)
+
+ if self.rnd_type == 'uniform':
+ nn.init.uniform_(arr, -scale, scale)
+ elif self.rnd_type == 'gaussian':
+ nn.init.normal_(arr, 0, scale)
+ else:
+ raise ValueError('Unknown random type')
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_deeplab_model.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_deeplab_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9a75cc0f56c1a068dc742f65a42a6ec85e9ad83
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_deeplab_model.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+from .ops import DistMaps
+from .modeling.deeplab_v3 import DeepLabV3Plus
+from .modeling.basic_blocks import SepConvHead
+
+
+def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
+ norm_layer=nn.BatchNorm2d, backbone_norm_layer=None,
+ use_rgb_conv=True, cpu_dist_maps=False,
+ norm_radius=260):
+ model = DistMapsModel(
+ feature_extractor=DeepLabV3Plus(backbone=backbone,
+ ch=deeplab_ch,
+ project_dropout=aspp_dropout,
+ norm_layer=norm_layer,
+ backbone_norm_layer=backbone_norm_layer),
+ head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
+ num_layers=2, norm_layer=norm_layer),
+ use_rgb_conv=use_rgb_conv,
+ norm_layer=norm_layer,
+ norm_radius=norm_radius,
+ cpu_dist_maps=cpu_dist_maps
+ )
+
+ return model
+
+
+class DistMapsModel(nn.Module):
+ def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
+ cpu_dist_maps=False, norm_radius=260):
+ super(DistMapsModel, self).__init__()
+
+ if use_rgb_conv:
+ self.rgb_conv = nn.Sequential(
+ nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
+ nn.LeakyReLU(negative_slope=0.2),
+ norm_layer(8),
+ nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
+ )
+ else:
+ self.rgb_conv = None
+
+ self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
+ cpu_mode=cpu_dist_maps)
+ self.feature_extractor = feature_extractor
+ self.head = head
+
+ def forward(self, image, points):
+ coord_features = self.dist_maps(image, points)
+
+ if self.rgb_conv is not None:
+ x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
+ else:
+ c1, c2 = torch.chunk(coord_features, 2, dim=1)
+ c3 = torch.ones_like(c1)
+ coord_features = torch.cat((c1, c2, c3), dim=1)
+ x = 0.8 * image * coord_features + 0.2 * image
+
+ backbone_features = self.feature_extractor(x)
+ instance_out = self.head(backbone_features[0])
+ instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
+ mode='bilinear', align_corners=True)
+
+ return {'instances': instance_out}
+
+ def load_weights(self, path_to_weights):
+ current_state_dict = self.state_dict()
+ new_state_dict = torch.load(path_to_weights, map_location='cpu')
+ current_state_dict.update(new_state_dict)
+ self.load_state_dict(current_state_dict)
+
+ def get_trainable_params(self):
+ backbone_params = nn.ParameterList()
+ other_params = nn.ParameterList()
+
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ if 'backbone' in name:
+ backbone_params.append(param)
+ else:
+ other_params.append(param)
+ return backbone_params, other_params
+
+
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_hrnet_model.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_hrnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced540a782c7b6e5b498d2e345faa95cb4015f4c
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/is_hrnet_model.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+from .ops import DistMaps
+from .modeling.hrnet_ocr import HighResolutionNet
+
+
+def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260,
+ use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False,
+ norm_layer=nn.BatchNorm2d):
+ model = DistMapsHRNetModel(
+ feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
+ num_classes=1, norm_layer=norm_layer),
+ use_rgb_conv=use_rgb_conv,
+ with_aux_output=with_aux_output,
+ norm_layer=norm_layer,
+ norm_radius=norm_radius,
+ cpu_dist_maps=cpu_dist_maps
+ )
+
+ return model
+
+
+class DistMapsHRNetModel(nn.Module):
+ def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False,
+ norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False):
+ super(DistMapsHRNetModel, self).__init__()
+ self.with_aux_output = with_aux_output
+
+ if use_rgb_conv:
+ self.rgb_conv = nn.Sequential(
+ nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
+ nn.LeakyReLU(negative_slope=0.2),
+ norm_layer(8),
+ nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
+ )
+ else:
+ self.rgb_conv = None
+
+ self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps)
+ self.feature_extractor = feature_extractor
+
+ def forward(self, image, points):
+ coord_features = self.dist_maps(image, points)
+
+ if self.rgb_conv is not None:
+ x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
+ else:
+ c1, c2 = torch.chunk(coord_features, 2, dim=1)
+ c3 = torch.ones_like(c1)
+ coord_features = torch.cat((c1, c2, c3), dim=1)
+ x = 0.8 * image * coord_features + 0.2 * image
+
+ feature_extractor_out = self.feature_extractor(x)
+ instance_out = feature_extractor_out[0]
+ instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:],
+ mode='bilinear', align_corners=True)
+ outputs = {'instances': instance_out}
+ if self.with_aux_output:
+ instance_aux_out = feature_extractor_out[1]
+ instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:],
+ mode='bilinear', align_corners=True)
+ outputs['instances_aux'] = instance_aux_out
+
+ return outputs
+
+ def load_weights(self, path_to_weights):
+ current_state_dict = self.state_dict()
+ new_state_dict = torch.load(path_to_weights)
+ current_state_dict.update(new_state_dict)
+ self.load_state_dict(current_state_dict)
+
+ def get_trainable_params(self):
+ backbone_params = nn.ParameterList()
+ other_params = nn.ParameterList()
+ other_params_keys = []
+ nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head']
+
+ for name, param in self.named_parameters():
+ if param.requires_grad:
+ if any(x in name for x in nonbackbone_keywords):
+ other_params.append(param)
+ other_params_keys.append(name)
+ else:
+ backbone_params.append(param)
+ print('Nonbackbone params:', sorted(other_params_keys))
+ return backbone_params, other_params
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/losses.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd89bf02b108533bc8c5639f233549d7387d3dbc
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/losses.py
@@ -0,0 +1,134 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import misc
+
+
+class NormalizedFocalLossSigmoid(nn.Module):
+ def __init__(self, axis=-1, alpha=0.25, gamma=2,
+ from_logits=False, batch_axis=0,
+ weight=None, size_average=True, detach_delimeter=True,
+ eps=1e-12, scale=1.0,
+ ignore_label=-1):
+ super(NormalizedFocalLossSigmoid, self).__init__()
+ self._axis = axis
+ self._alpha = alpha
+ self._gamma = gamma
+ self._ignore_label = ignore_label
+ self._weight = weight if weight is not None else 1.0
+ self._batch_axis = batch_axis
+
+ self._scale = scale
+ self._from_logits = from_logits
+ self._eps = eps
+ self._size_average = size_average
+ self._detach_delimeter = detach_delimeter
+ self._k_sum = 0
+
+ def forward(self, pred, label, sample_weight=None):
+ one_hot = label > 0
+ sample_weight = label != self._ignore_label
+
+ if not self._from_logits:
+ pred = torch.sigmoid(pred)
+
+ alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
+ pt = torch.where(one_hot, pred, 1 - pred)
+ pt = torch.where(sample_weight, pt, torch.ones_like(pt))
+
+ beta = (1 - pt) ** self._gamma
+
+ sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
+ beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
+ mult = sw_sum / (beta_sum + self._eps)
+ if self._detach_delimeter:
+ mult = mult.detach()
+ beta = beta * mult
+
+ ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
+ sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
+ if np.any(ignore_area == 0):
+ self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
+
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
+ loss = self._weight * (loss * sample_weight)
+
+ if self._size_average:
+ bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
+ else:
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
+
+ return self._scale * loss
+
+ def log_states(self, sw, name, global_step):
+ sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
+
+
+class FocalLoss(nn.Module):
+ def __init__(self, axis=-1, alpha=0.25, gamma=2,
+ from_logits=False, batch_axis=0,
+ weight=None, num_class=None,
+ eps=1e-9, size_average=True, scale=1.0):
+ super(FocalLoss, self).__init__()
+ self._axis = axis
+ self._alpha = alpha
+ self._gamma = gamma
+ self._weight = weight if weight is not None else 1.0
+ self._batch_axis = batch_axis
+
+ self._scale = scale
+ self._num_class = num_class
+ self._from_logits = from_logits
+ self._eps = eps
+ self._size_average = size_average
+
+ def forward(self, pred, label, sample_weight=None):
+ if not self._from_logits:
+ pred = F.sigmoid(pred)
+
+ one_hot = label > 0
+ pt = torch.where(one_hot, pred, 1 - pred)
+
+ t = label != -1
+ alpha = torch.where(one_hot, self._alpha * t, (1 - self._alpha) * t)
+ beta = (1 - pt) ** self._gamma
+
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
+ sample_weight = label != -1
+
+ loss = self._weight * (loss * sample_weight)
+
+ if self._size_average:
+ tsum = torch.sum(label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
+ else:
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
+
+ return self._scale * loss
+
+
+class SigmoidBinaryCrossEntropyLoss(nn.Module):
+ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
+ super(SigmoidBinaryCrossEntropyLoss, self).__init__()
+ self._from_sigmoid = from_sigmoid
+ self._ignore_label = ignore_label
+ self._weight = weight if weight is not None else 1.0
+ self._batch_axis = batch_axis
+
+ def forward(self, pred, label):
+ label = label.view(pred.size())
+ sample_weight = label != self._ignore_label
+ label = torch.where(sample_weight, label, torch.zeros_like(label))
+
+ if not self._from_sigmoid:
+ loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
+ else:
+ eps = 1e-12
+ loss = -(torch.log(pred + eps) * label
+ + torch.log(1. - pred + eps) * (1. - label))
+
+ loss = self._weight * (loss * sample_weight)
+ return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/metrics.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..9944feb1cf76cfb8707122c7a6ea7a830c02070a
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/metrics.py
@@ -0,0 +1,101 @@
+import torch
+import numpy as np
+
+from ..utils import misc
+
+
+class TrainMetric(object):
+ def __init__(self, pred_outputs, gt_outputs):
+ self.pred_outputs = pred_outputs
+ self.gt_outputs = gt_outputs
+
+ def update(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def get_epoch_value(self):
+ raise NotImplementedError
+
+ def reset_epoch_stats(self):
+ raise NotImplementedError
+
+ def log_states(self, sw, tag_prefix, global_step):
+ pass
+
+ @property
+ def name(self):
+ return type(self).__name__
+
+
+class AdaptiveIoU(TrainMetric):
+ def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
+ ignore_label=-1, from_logits=True,
+ pred_output='instances', gt_output='instances'):
+ super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
+ self._ignore_label = ignore_label
+ self._from_logits = from_logits
+ self._iou_thresh = init_thresh
+ self._thresh_step = thresh_step
+ self._thresh_beta = thresh_beta
+ self._iou_beta = iou_beta
+ self._ema_iou = 0.0
+ self._epoch_iou_sum = 0.0
+ self._epoch_batch_count = 0
+
+ def update(self, pred, gt):
+ gt_mask = gt > 0
+ if self._from_logits:
+ pred = torch.sigmoid(pred)
+
+ gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
+ if np.all(gt_mask_area == 0):
+ return
+
+ ignore_mask = gt == self._ignore_label
+ max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
+ best_thresh = self._iou_thresh
+ for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
+ temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
+ if temp_iou > max_iou:
+ max_iou = temp_iou
+ best_thresh = t
+
+ self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
+ self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
+ self._epoch_iou_sum += max_iou
+ self._epoch_batch_count += 1
+
+ def get_epoch_value(self):
+ if self._epoch_batch_count > 0:
+ return self._epoch_iou_sum / self._epoch_batch_count
+ else:
+ return 0.0
+
+ def reset_epoch_stats(self):
+ self._epoch_iou_sum = 0.0
+ self._epoch_batch_count = 0
+
+ def log_states(self, sw, tag_prefix, global_step):
+ sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
+ sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
+
+ @property
+ def iou_thresh(self):
+ return self._iou_thresh
+
+
+def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
+ if ignore_mask is not None:
+ pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
+
+ reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
+ union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
+ intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
+ nonzero = union > 0
+
+ iou = intersection[nonzero] / union[nonzero]
+ if not keep_ignore:
+ return iou
+ else:
+ result = np.full_like(intersection, -1)
+ result[nonzero] = iou
+ return result
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/basic_blocks.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/basic_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..35946e8b6639460d5822b46a3e82a85bc4f1060e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/basic_blocks.py
@@ -0,0 +1,71 @@
+import torch.nn as nn
+
+from ...model import ops
+
+
+class ConvHead(nn.Module):
+ def __init__(self, out_channels, in_channels=32, num_layers=1,
+ kernel_size=3, padding=1,
+ norm_layer=nn.BatchNorm2d):
+ super(ConvHead, self).__init__()
+ convhead = []
+
+ for i in range(num_layers):
+ convhead.extend([
+ nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
+ nn.ReLU(),
+ norm_layer(in_channels) if norm_layer is not None else nn.Identity()
+ ])
+ convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
+
+ self.convhead = nn.Sequential(*convhead)
+
+ def forward(self, *inputs):
+ return self.convhead(inputs[0])
+
+
+class SepConvHead(nn.Module):
+ def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
+ kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
+ norm_layer=nn.BatchNorm2d):
+ super(SepConvHead, self).__init__()
+
+ sepconvhead = []
+
+ for i in range(num_layers):
+ sepconvhead.append(
+ SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
+ out_channels=mid_channels,
+ dw_kernel=kernel_size, dw_padding=padding,
+ norm_layer=norm_layer, activation='relu')
+ )
+ if dropout_ratio > 0 and dropout_indx == i:
+ sepconvhead.append(nn.Dropout(dropout_ratio))
+
+ sepconvhead.append(
+ nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
+ )
+
+ self.layers = nn.Sequential(*sepconvhead)
+
+ def forward(self, *inputs):
+ x = inputs[0]
+
+ return self.layers(x)
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
+ activation=None, use_bias=False, norm_layer=None):
+ super(SeparableConv2d, self).__init__()
+ _activation = ops.select_activation_function(activation)
+ self.body = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
+ padding=dw_padding, bias=use_bias, groups=in_channels),
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
+ _activation()
+ )
+
+ def forward(self, x):
+ return self.body(x)
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e863862c48a75a2ba9d9aa8a8025ee4333308d5
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/deeplab_v3.py
@@ -0,0 +1,176 @@
+from contextlib import ExitStack
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .basic_blocks import SeparableConv2d
+from .resnet import ResNetBackbone
+from ...model import ops
+
+
+class DeepLabV3Plus(nn.Module):
+ def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
+ backbone_norm_layer=None,
+ ch=256,
+ project_dropout=0.5,
+ inference_mode=False,
+ **kwargs):
+ super(DeepLabV3Plus, self).__init__()
+ if backbone_norm_layer is None:
+ backbone_norm_layer = norm_layer
+
+ self.backbone_name = backbone
+ self.norm_layer = norm_layer
+ self.backbone_norm_layer = backbone_norm_layer
+ self.inference_mode = False
+ self.ch = ch
+ self.aspp_in_channels = 2048
+ self.skip_project_in_channels = 256 # layer 1 out_channels
+
+ self._kwargs = kwargs
+ if backbone == 'resnet34':
+ self.aspp_in_channels = 512
+ self.skip_project_in_channels = 64
+
+ self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
+ norm_layer=self.backbone_norm_layer, **kwargs)
+
+ self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
+ norm_layer=self.norm_layer)
+ self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
+ self.aspp = _ASPP(in_channels=self.aspp_in_channels,
+ atrous_rates=[12, 24, 36],
+ out_channels=ch,
+ project_dropout=project_dropout,
+ norm_layer=self.norm_layer)
+
+ if inference_mode:
+ self.set_prediction_mode()
+
+ def load_pretrained_weights(self):
+ pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
+ norm_layer=self.backbone_norm_layer, **self._kwargs)
+ backbone_state_dict = self.backbone.state_dict()
+ pretrained_state_dict = pretrained.state_dict()
+
+ backbone_state_dict.update(pretrained_state_dict)
+ self.backbone.load_state_dict(backbone_state_dict)
+
+ if self.inference_mode:
+ for param in self.backbone.parameters():
+ param.requires_grad = False
+
+ def set_prediction_mode(self):
+ self.inference_mode = True
+ self.eval()
+
+ def forward(self, x):
+ with ExitStack() as stack:
+ if self.inference_mode:
+ stack.enter_context(torch.no_grad())
+
+ c1, _, c3, c4 = self.backbone(x)
+ c1 = self.skip_project(c1)
+
+ x = self.aspp(c4)
+ x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
+ x = torch.cat((x, c1), dim=1)
+ x = self.head(x)
+
+ return x,
+
+
+class _SkipProject(nn.Module):
+ def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
+ super(_SkipProject, self).__init__()
+ _activation = ops.select_activation_function("relu")
+
+ self.skip_project = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+ norm_layer(out_channels),
+ _activation()
+ )
+
+ def forward(self, x):
+ return self.skip_project(x)
+
+
+class _DeepLabHead(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
+ super(_DeepLabHead, self).__init__()
+
+ self.block = nn.Sequential(
+ SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
+ SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
+ nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class _ASPP(nn.Module):
+ def __init__(self, in_channels, atrous_rates, out_channels=256,
+ project_dropout=0.5, norm_layer=nn.BatchNorm2d):
+ super(_ASPP, self).__init__()
+
+ b0 = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
+ norm_layer(out_channels),
+ nn.ReLU()
+ )
+
+ rate1, rate2, rate3 = tuple(atrous_rates)
+ b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
+ b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
+ b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
+ b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
+
+ self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
+
+ project = [
+ nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
+ kernel_size=1, bias=False),
+ norm_layer(out_channels),
+ nn.ReLU()
+ ]
+ if project_dropout > 0:
+ project.append(nn.Dropout(project_dropout))
+ self.project = nn.Sequential(*project)
+
+ def forward(self, x):
+ x = torch.cat([block(x) for block in self.concurent], dim=1)
+
+ return self.project(x)
+
+
+class _AsppPooling(nn.Module):
+ def __init__(self, in_channels, out_channels, norm_layer):
+ super(_AsppPooling, self).__init__()
+
+ self.gap = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=1, bias=False),
+ norm_layer(out_channels),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ pool = self.gap(x)
+ return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
+
+
+def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
+ block = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=3, padding=atrous_rate,
+ dilation=atrous_rate, bias=False),
+ norm_layer(out_channels),
+ nn.ReLU()
+ )
+
+ return block
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f8eff39c5a7e10ed712f96929644b325a90660
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/hrnet_ocr.py
@@ -0,0 +1,399 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch._utils
+import torch.nn.functional as F
+from .ocr import SpatialOCR_Module, SpatialGather_Module
+from .resnetv1b import BasicBlockV1b, BottleneckV1b
+
+relu_inplace = True
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method,multi_scale_output=True,
+ norm_layer=nn.BatchNorm2d, align_corners=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+ self.norm_layer = norm_layer
+ self.align_corners = align_corners
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=relu_inplace)
+
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ self.norm_layer(num_channels[branch_index] * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index], stride,
+ downsample=downsample, norm_layer=self.norm_layer))
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index],
+ norm_layer=self.norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(in_channels=num_inchannels[j],
+ out_channels=num_inchannels[i],
+ kernel_size=1,
+ bias=False),
+ self.norm_layer(num_inchannels[i])))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ kernel_size=3, stride=2, padding=1, bias=False),
+ self.norm_layer(num_outchannels_conv3x3)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ kernel_size=3, stride=2, padding=1, bias=False),
+ self.norm_layer(num_outchannels_conv3x3),
+ nn.ReLU(inplace=relu_inplace)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ elif j > i:
+ width_output = x[i].shape[-1]
+ height_output = x[i].shape[-2]
+ y = y + F.interpolate(
+ self.fuse_layers[i][j](x[j]),
+ size=[height_output, width_output],
+ mode='bilinear', align_corners=self.align_corners)
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+class HighResolutionNet(nn.Module):
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
+ norm_layer=nn.BatchNorm2d, align_corners=True):
+ super(HighResolutionNet, self).__init__()
+ self.norm_layer = norm_layer
+ self.width = width
+ self.ocr_width = ocr_width
+ self.align_corners = align_corners
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = norm_layer(64)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn2 = norm_layer(64)
+ self.relu = nn.ReLU(inplace=relu_inplace)
+
+ num_blocks = 2 if small else 4
+
+ stage1_num_channels = 64
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
+
+ self.stage2_num_branches = 2
+ num_channels = [width, 2 * width]
+ num_inchannels = [
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer(
+ [stage1_out_channel], num_inchannels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
+
+ self.stage3_num_branches = 3
+ num_channels = [width, 2 * width, 4 * width]
+ num_inchannels = [
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(
+ pre_stage_channels, num_inchannels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ BasicBlockV1b, num_inchannels=num_inchannels,
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
+
+ self.stage4_num_branches = 4
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
+ num_inchannels = [
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(
+ pre_stage_channels, num_inchannels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
+ num_branches=self.stage4_num_branches,
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
+
+ last_inp_channels = np.int32(np.sum(pre_stage_channels))
+ ocr_mid_channels = 2 * ocr_width
+ ocr_key_channels = ocr_width
+
+ self.conv3x3_ocr = nn.Sequential(
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
+ kernel_size=3, stride=1, padding=1),
+ norm_layer(ocr_mid_channels),
+ nn.ReLU(inplace=relu_inplace),
+ )
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
+
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
+ key_channels=ocr_key_channels,
+ out_channels=ocr_mid_channels,
+ scale=1,
+ dropout=0.05,
+ norm_layer=norm_layer,
+ align_corners=align_corners)
+ self.cls_head = nn.Conv2d(
+ ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
+
+ self.aux_head = nn.Sequential(
+ nn.Conv2d(last_inp_channels, last_inp_channels,
+ kernel_size=1, stride=1, padding=0),
+ norm_layer(last_inp_channels),
+ nn.ReLU(inplace=relu_inplace),
+ nn.Conv2d(last_inp_channels, num_classes,
+ kernel_size=1, stride=1, padding=0, bias=True)
+ )
+
+ def _make_transition_layer(
+ self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ self.norm_layer(num_channels_cur_layer[i]),
+ nn.ReLU(inplace=relu_inplace)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv3x3s = []
+ for j in range(i + 1 - num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(inchannels, outchannels,
+ kernel_size=3, stride=2, padding=1, bias=False),
+ self.norm_layer(outchannels),
+ nn.ReLU(inplace=relu_inplace)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ self.norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride,
+ downsample=downsample, norm_layer=self.norm_layer))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, block, num_inchannels,
+ num_modules, num_branches, num_blocks, num_channels,
+ fuse_method='SUM',
+ multi_scale_output=True):
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+ modules.append(
+ HighResolutionModule(num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output,
+ norm_layer=self.norm_layer,
+ align_corners=self.align_corners)
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, x):
+ feats = self.compute_hrnet_feats(x)
+ out_aux = self.aux_head(feats)
+ feats = self.conv3x3_ocr(feats)
+
+ context = self.ocr_gather_head(feats, out_aux)
+ feats = self.ocr_distri_head(feats, context)
+ out = self.cls_head(feats)
+
+ return [out, out_aux]
+
+ def compute_hrnet_feats(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_num_branches):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_num_branches):
+ if self.transition2[i] is not None:
+ if i < self.stage2_num_branches:
+ x_list.append(self.transition2[i](y_list[i]))
+ else:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_num_branches):
+ if self.transition3[i] is not None:
+ if i < self.stage3_num_branches:
+ x_list.append(self.transition3[i](y_list[i]))
+ else:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ x = self.stage4(x_list)
+
+ # Upsampling
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
+ mode='bilinear', align_corners=self.align_corners)
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
+ mode='bilinear', align_corners=self.align_corners)
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
+ mode='bilinear', align_corners=self.align_corners)
+
+ return torch.cat([x[0], x1, x2, x3], 1)
+
+ def load_pretrained_weights(self, pretrained_path=''):
+ model_dict = self.state_dict()
+
+ if not os.path.exists(pretrained_path):
+ print(f'\nFile "{pretrained_path}" does not exist.')
+ print('You need to specify the correct path to the pre-trained weights.\n'
+ 'You can download the weights for HRNet from the repository:\n'
+ 'https://github.com/HRNet/HRNet-Image-Classification')
+ exit(1)
+ pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
+ pretrained_dict.items()}
+
+ print('model_dict-pretrained_dict:', sorted(list(set(model_dict) - set(pretrained_dict))))
+ print('pretrained_dict-model_dict:', sorted(list(set(pretrained_dict) - set(model_dict))))
+
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()}
+
+ model_dict.update(pretrained_dict)
+ self.load_state_dict(model_dict)
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/ocr.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..df3b4f67959fc6a088b93ee7a34b15c1e07402df
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/ocr.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch._utils
+import torch.nn.functional as F
+
+
+class SpatialGather_Module(nn.Module):
+ """
+ Aggregate the context features according to the initial
+ predicted probability distribution.
+ Employ the soft-weighted method to aggregate the context.
+ """
+
+ def __init__(self, cls_num=0, scale=1):
+ super(SpatialGather_Module, self).__init__()
+ self.cls_num = cls_num
+ self.scale = scale
+
+ def forward(self, feats, probs):
+ batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
+ probs = probs.view(batch_size, c, -1)
+ feats = feats.view(batch_size, feats.size(1), -1)
+ feats = feats.permute(0, 2, 1) # batch x hw x c
+ probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
+ ocr_context = torch.matmul(probs, feats) \
+ .permute(0, 2, 1).unsqueeze(3) # batch x k x c
+ return ocr_context
+
+
+class SpatialOCR_Module(nn.Module):
+ """
+ Implementation of the OCR module:
+ We aggregate the global object representation to update the representation for each pixel.
+ """
+
+ def __init__(self,
+ in_channels,
+ key_channels,
+ out_channels,
+ scale=1,
+ dropout=0.1,
+ norm_layer=nn.BatchNorm2d,
+ align_corners=True):
+ super(SpatialOCR_Module, self).__init__()
+ self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
+ norm_layer, align_corners)
+ _in_channels = 2 * in_channels
+
+ self.conv_bn_dropout = nn.Sequential(
+ nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
+ nn.Dropout2d(dropout)
+ )
+
+ def forward(self, feats, proxy_feats):
+ context = self.object_context_block(feats, proxy_feats)
+
+ output = self.conv_bn_dropout(torch.cat([context, feats], 1))
+
+ return output
+
+
+class ObjectAttentionBlock2D(nn.Module):
+ '''
+ The basic implementation for object context block
+ Input:
+ N X C X H X W
+ Parameters:
+ in_channels : the dimension of the input feature map
+ key_channels : the dimension after the key/query transform
+ scale : choose the scale to downsample the input feature maps (save memory cost)
+ bn_type : specify the bn type
+ Return:
+ N X C X H X W
+ '''
+
+ def __init__(self,
+ in_channels,
+ key_channels,
+ scale=1,
+ norm_layer=nn.BatchNorm2d,
+ align_corners=True):
+ super(ObjectAttentionBlock2D, self).__init__()
+ self.scale = scale
+ self.in_channels = in_channels
+ self.key_channels = key_channels
+ self.align_corners = align_corners
+
+ self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
+ self.f_pixel = nn.Sequential(
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
+ )
+ self.f_object = nn.Sequential(
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
+ )
+ self.f_down = nn.Sequential(
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
+ )
+ self.f_up = nn.Sequential(
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
+ kernel_size=1, stride=1, padding=0, bias=False),
+ nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
+ )
+
+ def forward(self, x, proxy):
+ batch_size, h, w = x.size(0), x.size(2), x.size(3)
+ if self.scale > 1:
+ x = self.pool(x)
+
+ query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
+ query = query.permute(0, 2, 1)
+ key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
+ value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
+ value = value.permute(0, 2, 1)
+
+ sim_map = torch.matmul(query, key)
+ sim_map = (self.key_channels ** -.5) * sim_map
+ sim_map = F.softmax(sim_map, dim=-1)
+
+ # add bg context ...
+ context = torch.matmul(sim_map, value)
+ context = context.permute(0, 2, 1).contiguous()
+ context = context.view(batch_size, self.key_channels, *x.size()[2:])
+ context = self.f_up(context)
+ if self.scale > 1:
+ context = F.interpolate(input=context, size=(h, w),
+ mode='bilinear', align_corners=self.align_corners)
+
+ return context
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnet.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..349ea1cbd882a9b0daa1d6146b634e9baf3726e0
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnet.py
@@ -0,0 +1,39 @@
+import torch
+from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
+
+
+class ResNetBackbone(torch.nn.Module):
+ def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
+ super(ResNetBackbone, self).__init__()
+
+ if backbone == 'resnet34':
+ pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
+ elif backbone == 'resnet50':
+ pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
+ elif backbone == 'resnet101':
+ pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
+ elif backbone == 'resnet152':
+ pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
+ else:
+ raise RuntimeError(f'unknown backbone: {backbone}')
+
+ self.conv1 = pretrained.conv1
+ self.bn1 = pretrained.bn1
+ self.relu = pretrained.relu
+ self.maxpool = pretrained.maxpool
+ self.layer1 = pretrained.layer1
+ self.layer2 = pretrained.layer2
+ self.layer3 = pretrained.layer3
+ self.layer4 = pretrained.layer4
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ c1 = self.layer1(x)
+ c2 = self.layer2(c1)
+ c3 = self.layer3(c2)
+ c4 = self.layer4(c3)
+
+ return c1, c2, c3, c4
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnetv1b.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnetv1b.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad24cef5bde19f2627cfd3f755636f37cfb39ac
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/modeling/resnetv1b.py
@@ -0,0 +1,276 @@
+import torch
+import torch.nn as nn
+GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
+
+
+class BasicBlockV1b(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
+ super(BasicBlockV1b, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, bias=False)
+ self.bn1 = norm_layer(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
+ padding=previous_dilation, dilation=previous_dilation, bias=False)
+ self.bn2 = norm_layer(planes)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out = out + residual
+ out = self.relu(out)
+
+ return out
+
+
+class BottleneckV1b(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
+ super(BottleneckV1b, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = norm_layer(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, bias=False)
+ self.bn2 = norm_layer(planes)
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = norm_layer(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out = out + residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetV1b(nn.Module):
+ """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
+
+ Parameters
+ ----------
+ block : Block
+ Class for the residual block. Options are BasicBlockV1, BottleneckV1.
+ layers : list of int
+ Numbers of layers in each block
+ classes : int, default 1000
+ Number of classification classes.
+ dilated : bool, default False
+ Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
+ typically used in Semantic Segmentation.
+ norm_layer : object
+ Normalization layer used (default: :class:`nn.BatchNorm2d`)
+ deep_stem : bool, default False
+ Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
+ avg_down : bool, default False
+ Whether to use average pooling for projection skip connection between stages/downsample.
+ final_drop : float, default 0.0
+ Dropout ratio before the final classification layer.
+
+ Reference:
+ - He, Kaiming, et al. "Deep residual learning for image recognition."
+ Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
+
+ - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
+ """
+ def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
+ avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
+ self.inplanes = stem_width*2 if deep_stem else 64
+ super(ResNetV1b, self).__init__()
+ if not deep_stem:
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
+ norm_layer(stem_width),
+ nn.ReLU(True),
+ nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
+ norm_layer(stem_width),
+ nn.ReLU(True),
+ nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
+ )
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(True)
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
+ norm_layer=norm_layer)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
+ norm_layer=norm_layer)
+ if dilated:
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
+ avg_down=avg_down, norm_layer=norm_layer)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
+ avg_down=avg_down, norm_layer=norm_layer)
+ else:
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ avg_down=avg_down, norm_layer=norm_layer)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ avg_down=avg_down, norm_layer=norm_layer)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.drop = None
+ if final_drop > 0.0:
+ self.drop = nn.Dropout(final_drop)
+ self.fc = nn.Linear(512 * block.expansion, classes)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
+ avg_down=False, norm_layer=nn.BatchNorm2d):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = []
+ if avg_down:
+ if dilation == 1:
+ downsample.append(
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
+ )
+ else:
+ downsample.append(
+ nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
+ )
+ downsample.extend([
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
+ kernel_size=1, stride=1, bias=False),
+ norm_layer(planes * block.expansion)
+ ])
+ downsample = nn.Sequential(*downsample)
+ else:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ norm_layer(planes * block.expansion)
+ )
+
+ layers = []
+ if dilation in (1, 2):
+ layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
+ previous_dilation=dilation, norm_layer=norm_layer))
+ elif dilation == 4:
+ layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
+ previous_dilation=dilation, norm_layer=norm_layer))
+ else:
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
+
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dilation=dilation,
+ previous_dilation=dilation, norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ if self.drop is not None:
+ x = self.drop(x)
+ x = self.fc(x)
+
+ return x
+
+
+def _safe_state_dict_filtering(orig_dict, model_dict_keys):
+ filtered_orig_dict = {}
+ for k, v in orig_dict.items():
+ if k in model_dict_keys:
+ filtered_orig_dict[k] = v
+ else:
+ print(f"[ERROR] Failed to load <{k}> in backbone")
+ return filtered_orig_dict
+
+
+def resnet34_v1b(pretrained=False, **kwargs):
+ model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model_dict = model.state_dict()
+ filtered_orig_dict = _safe_state_dict_filtering(
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
+ model_dict.keys()
+ )
+ model_dict.update(filtered_orig_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+
+def resnet50_v1s(pretrained=False, **kwargs):
+ model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
+ if pretrained:
+ model_dict = model.state_dict()
+ filtered_orig_dict = _safe_state_dict_filtering(
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
+ model_dict.keys()
+ )
+ model_dict.update(filtered_orig_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+
+def resnet101_v1s(pretrained=False, **kwargs):
+ model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
+ if pretrained:
+ model_dict = model.state_dict()
+ filtered_orig_dict = _safe_state_dict_filtering(
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
+ model_dict.keys()
+ )
+ model_dict.update(filtered_orig_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+
+def resnet152_v1s(pretrained=False, **kwargs):
+ model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
+ if pretrained:
+ model_dict = model.state_dict()
+ filtered_orig_dict = _safe_state_dict_filtering(
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
+ model_dict.keys()
+ )
+ model_dict.update(filtered_orig_dict)
+ model.load_state_dict(model_dict)
+ return model
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/ops.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..f46ae39aeb14cdb0ca6d9922b67f4562c40be8df
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/ops.py
@@ -0,0 +1,83 @@
+import torch
+from torch import nn as nn
+import numpy as np
+
+from . import initializer as initializer
+from ..utils.cython import get_dist_maps
+
+
+def select_activation_function(activation):
+ if isinstance(activation, str):
+ if activation.lower() == 'relu':
+ return nn.ReLU
+ elif activation.lower() == 'softplus':
+ return nn.Softplus
+ else:
+ raise ValueError(f"Unknown activation type {activation}")
+ elif isinstance(activation, nn.Module):
+ return activation
+ else:
+ raise ValueError(f"Unknown activation type {activation}")
+
+
+class BilinearConvTranspose2d(nn.ConvTranspose2d):
+ def __init__(self, in_channels, out_channels, scale, groups=1):
+ kernel_size = 2 * scale - scale % 2
+ self.scale = scale
+
+ super().__init__(
+ in_channels, out_channels,
+ kernel_size=kernel_size,
+ stride=scale,
+ padding=1,
+ groups=groups,
+ bias=False)
+
+ self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
+
+
+class DistMaps(nn.Module):
+ def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
+ super(DistMaps, self).__init__()
+ self.spatial_scale = spatial_scale
+ self.norm_radius = norm_radius
+ self.cpu_mode = cpu_mode
+
+ def get_coord_features(self, points, batchsize, rows, cols):
+ if self.cpu_mode:
+ coords = []
+ for i in range(batchsize):
+ norm_delimeter = self.spatial_scale * self.norm_radius
+ coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
+ norm_delimeter))
+ coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
+ else:
+ num_points = points.shape[1] // 2
+ points = points.view(-1, 2)
+ invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
+ row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
+ col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
+
+ coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
+ coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
+
+ add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
+ coords.add_(-add_xy)
+ coords.div_(self.norm_radius * self.spatial_scale)
+ coords.mul_(coords)
+
+ coords[:, 0] += coords[:, 1]
+ coords = coords[:, :1]
+
+ coords[invalid_points, :, :, :] = 1e6
+
+ coords = coords.view(-1, num_points, 1, rows, cols)
+ coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w
+ coords = coords.view(-1, 2, rows, cols)
+
+ coords.sqrt_().mul_(2).tanh_()
+
+ return coords
+
+ def forward(self, x, coords):
+ return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/LICENSE b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..fec54698d35926513ca1ddb7b6cee791daca834e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Tamaki Kojima
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/README.md b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..d9a9ea21ca73d08dbac027aea3a4909d6b67ace3
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/README.md
@@ -0,0 +1,127 @@
+# pytorch-syncbn
+
+Tamaki Kojima(tamakoji@gmail.com)
+
+## Announcement
+
+**Pytorch 1.0 support**
+
+## Overview
+This is alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global stats across gpus instead of locally computed. SyncBN are getting important for those input image is large, and must use multi-gpu to increase the minibatch-size for the training.
+
+The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn)
+
+## Remarks
+- Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need custom `nn.DataParallel`
+- Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can just replace your `nn.BatchNorm2d` to this module implementation, since it will not mark for inplace operation
+- You can plug into arbitrary module written in PyTorch to enable Synchronized BatchNorm
+- Backward computation is rewritten and tested against behavior of `nn.BatchNorm2d`
+
+## Requirements
+For PyTorch, please refer to https://pytorch.org/
+
+NOTE : The code is tested only with PyTorch v1.0.0, CUDA10/CuDNN7.4.2 on ubuntu18.04
+
+It utilize Pytorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use.
+
+```
+sudo apt-get install ninja-build
+```
+
+Also install all dependencies for python. For pip, run:
+
+
+```
+pip install -U -r requirements.txt
+```
+
+## Build
+
+There is no need to build. just run and JIT will take care.
+JIT and cpp extensions are supported after PyTorch0.4, however it is highly recommended to use PyTorch > 1.0 due to huge design changes.
+
+## Usage
+
+Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
+
+```
+import torch
+from modules import nn as NN
+num_gpu = torch.cuda.device_count()
+model = nn.Sequential(
+ nn.Conv2d(3, 3, 1, 1, bias=False),
+ NN.BatchNorm2d(3),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(3, 3, 1, 1, bias=False),
+ NN.BatchNorm2d(3),
+).cuda()
+model = nn.DataParallel(model, device_ids=range(num_gpu))
+x = torch.rand(num_gpu, 3, 2, 2).cuda()
+z = model(x)
+```
+
+## Math
+
+### Forward
+1. compute in each gpu
+2. gather all from workers to master and compute where
+
+
+
+ and
+
+
+
+ and then above global stats to be shared to all gpus, update running_mean and running_var by moving average using global stats.
+
+3. forward batchnorm using global stats by
+
+
+
+ and then
+
+
+
+ where is weight parameter and is bias parameter.
+
+4. save for backward
+
+### Backward
+
+1. Restore saved
+
+2. Compute below sums on each gpu
+
+
+
+ and
+
+
+
+ where
+
+ then gather them at master node to sum up global, and normalize with N where N is total number of elements for each channels. Global sums are then shared among all gpus.
+
+3. compute gradients using global stats
+
+
+
+ where
+
+
+
+ and
+
+
+
+ and finally,
+
+
+
+
+
+
+
+ Note that in the implementation, normalization with N is performed at step (2) and above equation and implementation is not exactly the same, but mathematically is same.
+
+ You can go deeper on above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8eb83a9d88b25cb8f1faebc9236da929a7722c7
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py
@@ -0,0 +1 @@
+from .syncbn import batchnorm2d_sync
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c14098f0cfa422920f01fe4985dbeb7fedc2d1
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py
@@ -0,0 +1,54 @@
+"""
+/*****************************************************************************/
+
+Extension module loader
+
+code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os.path
+
+import torch
+
+try:
+ from torch.utils.cpp_extension import load
+ from torch.utils.cpp_extension import CUDA_HOME
+except ImportError:
+ raise ImportError(
+ "The cpp layer extensions requires PyTorch 0.4 or higher")
+
+
+def _load_C_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ this_dir = os.path.join(this_dir, "csrc")
+
+ main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
+ sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
+ sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))
+
+ sources = main_file + sources_cpu
+
+ extra_cflags = []
+ extra_cuda_cflags = []
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ sources.extend(sources_cuda)
+ extra_cflags = ["-O3", "-DWITH_CUDA"]
+ extra_cuda_cflags = ["--expt-extended-lambda"]
+ sources = [os.path.join(this_dir, s) for s in sources]
+ extra_include_paths = [this_dir]
+ return load(
+ name="ext_lib",
+ sources=sources,
+ extra_cflags=extra_cflags,
+ extra_include_paths=extra_include_paths,
+ extra_cuda_cflags=extra_cuda_cflags,
+ )
+
+
+_backend = _load_C_extensions()
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
new file mode 100644
index 0000000000000000000000000000000000000000..52567a478633aa043ad86624253763e594121bd1
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h
@@ -0,0 +1,70 @@
+/*****************************************************************************
+
+SyncBN
+
+*****************************************************************************/
+#pragma once
+
+#ifdef WITH_CUDA
+#include "cuda/ext_lib.h"
+#endif
+
+/// SyncBN
+
+std::vector syncbn_sum_sqsum(const at::Tensor& x) {
+ if (x.is_cuda()) {
+#ifdef WITH_CUDA
+ return syncbn_sum_sqsum_cuda(x);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("CPU implementation not supported");
+ }
+}
+
+at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
+ const at::Tensor& bias, const at::Tensor& mean,
+ const at::Tensor& var, bool affine, float eps) {
+ if (x.is_cuda()) {
+#ifdef WITH_CUDA
+ return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("CPU implementation not supported");
+ }
+}
+
+std::vector syncbn_backward_xhat(const at::Tensor& dz,
+ const at::Tensor& x,
+ const at::Tensor& mean,
+ const at::Tensor& var, float eps) {
+ if (dz.is_cuda()) {
+#ifdef WITH_CUDA
+ return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("CPU implementation not supported");
+ }
+}
+
+std::vector syncbn_backward(
+ const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
+ const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
+ const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
+ float eps) {
+ if (dz.is_cuda()) {
+#ifdef WITH_CUDA
+ return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
+ sum_dz_xhat, affine, eps);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ } else {
+ AT_ERROR("CPU implementation not supported");
+ }
+}
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9458eba4f4715673ba480fae2c318f4745e8fe78
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu
@@ -0,0 +1,280 @@
+/*****************************************************************************
+
+CUDA SyncBN code
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+*****************************************************************************/
+#include
+#include
+#include
+#include
+#include "cuda/common.h"
+
+// Utilities
+void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
+ num = x.size(0);
+ chn = x.size(1);
+ sp = 1;
+ for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
+}
+
+/// SyncBN
+
+template
+struct SqSumOp {
+ __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
+ __device__ __forceinline__ Pair operator()(int batch, int plane, int n) {
+ T x = tensor[(batch * chn + plane) * sp + n];
+ return Pair(x, x * x); // x, x^2
+ }
+ const T *tensor;
+ const int chn;
+ const int sp;
+};
+
+template
+__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
+ int num, int chn, int sp) {
+ int plane = blockIdx.x;
+ Pair res =
+ reduce, SqSumOp>(SqSumOp(x, chn, sp), plane, num, chn, sp);
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ sum[plane] = res.v1;
+ sqsum[plane] = res.v2;
+ }
+}
+
+std::vector syncbn_sum_sqsum_cuda(const at::Tensor &x) {
+ CHECK_INPUT(x);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Prepare output tensors
+ auto sum = at::empty({chn}, x.options());
+ auto sqsum = at::empty({chn}, x.options());
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ AT_DISPATCH_FLOATING_TYPES(
+ x.type(), "syncbn_sum_sqsum_cuda", ([&] {
+ syncbn_sum_sqsum_kernel<<>>(
+ x.data(), sum.data(),
+ sqsum.data(), num, chn, sp);
+ }));
+ return {sum, sqsum};
+}
+
+template
+__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
+ const T *bias, const T *mean,
+ const T *var, bool affine, float eps,
+ int num, int chn, int sp) {
+ int plane = blockIdx.x;
+ T _mean = mean[plane];
+ T _var = var[plane];
+ T _weight = affine ? weight[plane] : T(1);
+ T _bias = affine ? bias[plane] : T(0);
+ float _invstd = T(0);
+ if (_var || eps) {
+ _invstd = rsqrt(_var + eps);
+ }
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ T _x = x[(batch * chn + plane) * sp + n];
+ T _xhat = (_x - _mean) * _invstd;
+ T _z = _xhat * _weight + _bias;
+ z[(batch * chn + plane) * sp + n] = _z;
+ }
+ }
+}
+
+at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
+ const at::Tensor &bias, const at::Tensor &mean,
+ const at::Tensor &var, bool affine, float eps) {
+ CHECK_INPUT(x);
+ CHECK_INPUT(weight);
+ CHECK_INPUT(bias);
+ CHECK_INPUT(mean);
+ CHECK_INPUT(var);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ auto z = at::zeros_like(x);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ AT_DISPATCH_FLOATING_TYPES(
+ x.type(), "syncbn_forward_cuda", ([&] {
+ syncbn_forward_kernel<<>>(
+ z.data(), x.data(),
+ weight.data(), bias.data(),
+ mean.data(), var.data(),
+ affine, eps, num, chn, sp);
+ }));
+ return z;
+}
+
+template
+struct XHatOp {
+ __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
+ : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
+ __device__ __forceinline__ Pair operator()(int batch, int plane, int n) {
+ // xhat = (x - bias) * weight
+ T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
+ // dxhat * x_hat
+ T _dz = dz[(batch * chn + plane) * sp + n];
+ return Pair(_dz, _dz * _xhat);
+ }
+ const T weight;
+ const T bias;
+ const T *dz;
+ const T *x;
+ const int chn;
+ const int sp;
+};
+
+template
+__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
+ const T *mean, const T *var,
+ T *sum_dz, T *sum_dz_xhat,
+ float eps, int num, int chn,
+ int sp) {
+ int plane = blockIdx.x;
+ T _mean = mean[plane];
+ T _var = var[plane];
+ T _invstd = T(0);
+ if (_var || eps) {
+ _invstd = rsqrt(_var + eps);
+ }
+ Pair res = reduce, XHatOp>(
+ XHatOp(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ // \sum(\frac{dJ}{dy_i})
+ sum_dz[plane] = res.v1;
+ // \sum(\frac{dJ}{dy_i}*\hat{x_i})
+ sum_dz_xhat[plane] = res.v2;
+ }
+}
+
+std::vector syncbn_backward_xhat_cuda(const at::Tensor &dz,
+ const at::Tensor &x,
+ const at::Tensor &mean,
+ const at::Tensor &var,
+ float eps) {
+ CHECK_INPUT(dz);
+ CHECK_INPUT(x);
+ CHECK_INPUT(mean);
+ CHECK_INPUT(var);
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+ // Prepare output tensors
+ auto sum_dz = at::empty({chn}, x.options());
+ auto sum_dz_xhat = at::empty({chn}, x.options());
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ AT_DISPATCH_FLOATING_TYPES(
+ x.type(), "syncbn_backward_xhat_cuda", ([&] {
+ syncbn_backward_xhat_kernel<<>>(
+ dz.data(), x.data(), mean.data(),
+ var.data(), sum_dz.data(),
+ sum_dz_xhat.data(), eps, num, chn, sp);
+ }));
+ return {sum_dz, sum_dz_xhat};
+}
+
+template
+__global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
+ const T *bias, const T *mean,
+ const T *var, const T *sum_dz,
+ const T *sum_dz_xhat, T *dx, T *dweight,
+ T *dbias, bool affine, float eps,
+ int num, int chn, int sp) {
+ int plane = blockIdx.x;
+ T _mean = mean[plane];
+ T _var = var[plane];
+ T _weight = affine ? weight[plane] : T(1);
+ T _sum_dz = sum_dz[plane];
+ T _sum_dz_xhat = sum_dz_xhat[plane];
+ T _invstd = T(0);
+ if (_var || eps) {
+ _invstd = rsqrt(_var + eps);
+ }
+ /*
+ \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
+ N\frac{dJ}{d\hat{x_i}} -
+ \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
+ \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
+ )
+ Note : N is omitted here since it will be accumulated and
+ _sum_dz and _sum_dz_xhat expected to be already normalized
+ before the call.
+ */
+ if (dx) {
+ T _mul = _weight * _invstd;
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ T _dz = dz[(batch * chn + plane) * sp + n];
+ T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
+ T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
+ dx[(batch * chn + plane) * sp + n] = _dx;
+ }
+ }
+ }
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ if (affine) {
+ T _norm = num * sp;
+ dweight[plane] += _sum_dz_xhat * _norm;
+ dbias[plane] += _sum_dz * _norm;
+ }
+ }
+}
+
+std::vector syncbn_backward_cuda(
+ const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
+ const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
+ const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
+ float eps) {
+ CHECK_INPUT(dz);
+ CHECK_INPUT(x);
+ CHECK_INPUT(weight);
+ CHECK_INPUT(bias);
+ CHECK_INPUT(mean);
+ CHECK_INPUT(var);
+ CHECK_INPUT(sum_dz);
+ CHECK_INPUT(sum_dz_xhat);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Prepare output tensors
+ auto dx = at::zeros_like(dz);
+ auto dweight = at::zeros_like(weight);
+ auto dbias = at::zeros_like(bias);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ AT_DISPATCH_FLOATING_TYPES(
+ x.type(), "syncbn_backward_cuda", ([&] {
+ syncbn_backward_kernel<<>>(
+ dz.data(), x.data(), weight.data(),
+ bias.data(), mean.data(), var.data(),
+ sum_dz.data(), sum_dz_xhat.data(),
+ dx.data(), dweight.data(),
+ dbias.data(), affine, eps, num, chn, sp);
+ }));
+ return {dx, dweight, dbias};
+}
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..a6cb2debeea3b8caa0f7c640601a94dce4e629cb
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h
@@ -0,0 +1,124 @@
+/*****************************************************************************
+
+CUDA utility funcs
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+*****************************************************************************/
+#pragma once
+
+#include
+
+// Checks
+#ifndef AT_CHECK
+ #define AT_CHECK AT_ASSERT
+#endif
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+/*
+ * General settings
+ */
+const int WARP_SIZE = 32;
+const int MAX_BLOCK_SIZE = 512;
+
+template
+struct Pair {
+ T v1, v2;
+ __device__ Pair() {}
+ __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
+ __device__ Pair(T v) : v1(v), v2(v) {}
+ __device__ Pair(int v) : v1(v), v2(v) {}
+ __device__ Pair &operator+=(const Pair &a) {
+ v1 += a.v1;
+ v2 += a.v2;
+ return *this;
+ }
+};
+
+/*
+ * Utility functions
+ */
+template
+__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
+ int width = warpSize,
+ unsigned int mask = 0xffffffff) {
+#if CUDART_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+ return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
+
+static int getNumThreads(int nElem) {
+ int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
+ for (int i = 0; i != 5; ++i) {
+ if (nElem <= threadSizes[i]) {
+ return threadSizes[i];
+ }
+ }
+ return MAX_BLOCK_SIZE;
+}
+
+template
+static __device__ __forceinline__ T warpSum(T val) {
+#if __CUDA_ARCH__ >= 300
+ for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+ val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
+ }
+#else
+ __shared__ T values[MAX_BLOCK_SIZE];
+ values[threadIdx.x] = val;
+ __threadfence_block();
+ const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
+ for (int i = 1; i < WARP_SIZE; i++) {
+ val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
+ }
+#endif
+ return val;
+}
+
+template
+static __device__ __forceinline__ Pair warpSum(Pair value) {
+ value.v1 = warpSum(value.v1);
+ value.v2 = warpSum(value.v2);
+ return value;
+}
+
+template
+__device__ T reduce(Op op, int plane, int N, int C, int S) {
+ T sum = (T)0;
+ for (int batch = 0; batch < N; ++batch) {
+ for (int x = threadIdx.x; x < S; x += blockDim.x) {
+ sum += op(batch, plane, x);
+ }
+ }
+
+ // sum over NumThreads within a warp
+ sum = warpSum(sum);
+
+ // 'transpose', and reduce within warp again
+ __shared__ T shared[32];
+ __syncthreads();
+ if (threadIdx.x % WARP_SIZE == 0) {
+ shared[threadIdx.x / WARP_SIZE] = sum;
+ }
+ if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
+ // zero out the other entries in shared
+ shared[threadIdx.x] = (T)0;
+ }
+ __syncthreads();
+ if (threadIdx.x / WARP_SIZE == 0) {
+ sum = warpSum(shared[threadIdx.x]);
+ if (threadIdx.x == 0) {
+ shared[0] = sum;
+ }
+ }
+ __syncthreads();
+
+ // Everyone picks it up, should be broadcast into the whole gradInput
+ return shared[0];
+}
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
new file mode 100644
index 0000000000000000000000000000000000000000..1d707615ffcf5ad7dcabc60de8c9a0cfe035bf14
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h
@@ -0,0 +1,24 @@
+/*****************************************************************************
+
+CUDA SyncBN code
+
+*****************************************************************************/
+#pragma once
+#include
+#include
+
+/// Sync-BN
+std::vector syncbn_sum_sqsum_cuda(const at::Tensor& x);
+at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
+ const at::Tensor& bias, const at::Tensor& mean,
+ const at::Tensor& var, bool affine, float eps);
+std::vector syncbn_backward_xhat_cuda(const at::Tensor& dz,
+ const at::Tensor& x,
+ const at::Tensor& mean,
+ const at::Tensor& var,
+ float eps);
+std::vector syncbn_backward_cuda(
+ const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
+ const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
+ const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
+ float eps);
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9c2ecf142dd70de8a3bdaf9b04470c4cacee3086
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp
@@ -0,0 +1,10 @@
+#include "bn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
+ m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
+ m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
+ "First part of SyncBN backward computation");
+ m.def("syncbn_backward", &syncbn_backward,
+ "Second part of SyncBN backward computation");
+}
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
new file mode 100644
index 0000000000000000000000000000000000000000..867a432d14f4f28c25075caa85b22726424293ae
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/functional/syncbn.py
@@ -0,0 +1,137 @@
+"""
+/*****************************************************************************/
+
+BatchNorm2dSync with multi-gpu
+
+code referenced from : https://github.com/mapillary/inplace_abn
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch.cuda.comm as comm
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from ._csrc import _backend
+
+
+def _count_samples(x):
+ count = 1
+ for i, s in enumerate(x.size()):
+ if i != 1:
+ count *= s
+ return count
+
+
+class BatchNorm2dSyncFunc(Function):
+
+ @staticmethod
+ def forward(ctx, x, weight, bias, running_mean, running_var,
+ extra, compute_stats=True, momentum=0.1, eps=1e-05):
+ def _parse_extra(ctx, extra):
+ ctx.is_master = extra["is_master"]
+ if ctx.is_master:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queues = extra["worker_queues"]
+ ctx.worker_ids = extra["worker_ids"]
+ else:
+ ctx.master_queue = extra["master_queue"]
+ ctx.worker_queue = extra["worker_queue"]
+ # Save context
+ if extra is not None:
+ _parse_extra(ctx, extra)
+ ctx.compute_stats = compute_stats
+ ctx.momentum = momentum
+ ctx.eps = eps
+ ctx.affine = weight is not None and bias is not None
+ if ctx.compute_stats:
+ N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
+ assert N > 1
+ # 1. compute sum(x) and sum(x^2)
+ xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
+ if ctx.is_master:
+ xsums, xsqsums = [xsum], [xsqsum]
+ # master : gatther all sum(x) and sum(x^2) from slaves
+ for _ in range(ctx.master_queue.maxsize):
+ xsum_w, xsqsum_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ xsums.append(xsum_w)
+ xsqsums.append(xsqsum_w)
+ xsum = comm.reduce_add(xsums)
+ xsqsum = comm.reduce_add(xsqsums)
+ mean = xsum / N
+ sumvar = xsqsum - xsum * mean
+ var = sumvar / N
+ uvar = sumvar / (N - 1)
+ # master : broadcast global mean, variance to all slaves
+ tensors = comm.broadcast_coalesced(
+ (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ # slave : send sum(x) and sum(x^2) to master
+ ctx.master_queue.put((xsum, xsqsum))
+ # slave : get global mean and variance
+ mean, uvar, var = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
+
+ # Update running stats
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
+ ctx.N = N
+ ctx.save_for_backward(x, weight, bias, mean, var)
+ else:
+ mean, var = running_mean, running_var
+
+ # do batch norm forward
+ z = _backend.syncbn_forward(x, weight, bias, mean, var,
+ ctx.affine, ctx.eps)
+ return z
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, dz):
+ x, weight, bias, mean, var = ctx.saved_tensors
+ dz = dz.contiguous()
+
+ # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
+ sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
+ dz, x, mean, var, ctx.eps)
+ if ctx.is_master:
+ sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
+ # master : gatther from slaves
+ for _ in range(ctx.master_queue.maxsize):
+ sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
+ ctx.master_queue.task_done()
+ sum_dzs.append(sum_dz_w)
+ sum_dz_xhats.append(sum_dz_xhat_w)
+ # master : compute global stats
+ sum_dz = comm.reduce_add(sum_dzs)
+ sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
+ sum_dz /= ctx.N
+ sum_dz_xhat /= ctx.N
+ # master : broadcast global stats
+ tensors = comm.broadcast_coalesced(
+ (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
+ for ts, queue in zip(tensors[1:], ctx.worker_queues):
+ queue.put(ts)
+ else:
+ # slave : send to master
+ ctx.master_queue.put((sum_dz, sum_dz_xhat))
+ # slave : get global stats
+ sum_dz, sum_dz_xhat = ctx.worker_queue.get()
+ ctx.worker_queue.task_done()
+
+ # do batch norm backward
+ dx, dweight, dbias = _backend.syncbn_backward(
+ dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
+ ctx.affine, ctx.eps)
+
+ return dx, dweight, dbias, \
+ None, None, None, None, None, None
+
+batchnorm2d_sync = BatchNorm2dSyncFunc.apply
+
+__all__ = ["batchnorm2d_sync"]
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c5aca9879273811b681baddc5755e20e838a361
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py
@@ -0,0 +1 @@
+from .syncbn import *
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b118c9d4aac3ee86821797bc9f794cd9aa38b1b2
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/model/syncbn/modules/nn/syncbn.py
@@ -0,0 +1,148 @@
+"""
+/*****************************************************************************/
+
+BatchNorm2dSync with multi-gpu
+
+/*****************************************************************************/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+try:
+ # python 3
+ from queue import Queue
+except ImportError:
+ # python 2
+ from Queue import Queue
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+from isegm.model.syncbn.modules.functional import batchnorm2d_sync
+
+
+class _BatchNorm(nn.Module):
+ """
+ Customized BatchNorm from nn.BatchNorm
+ >> added freeze attribute to enable bn freeze.
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True):
+ super(_BatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ self.freezed = False
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ if self.track_running_stats:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ else:
+ self.register_parameter('running_mean', None)
+ self.register_parameter('running_var', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.track_running_stats:
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+ if self.affine:
+ self.weight.data.uniform_()
+ self.bias.data.zero_()
+
+ def _check_input_dim(self, input):
+ return NotImplemented
+
+ def forward(self, input):
+ self._check_input_dim(input)
+
+ compute_stats = not self.freezed and \
+ self.training and self.track_running_stats
+
+ ret = F.batch_norm(input, self.running_mean, self.running_var,
+ self.weight, self.bias, compute_stats,
+ self.momentum, self.eps)
+ return ret
+
+ def extra_repr(self):
+ return '{num_features}, eps={eps}, momentum={momentum}, '\
+ 'affine={affine}, ' \
+ 'track_running_stats={track_running_stats}'.format(
+ **self.__dict__)
+
+
+class BatchNorm2dNoSync(_BatchNorm):
+ """
+ Equivalent to nn.BatchNorm2d
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+
+
+class BatchNorm2dSync(BatchNorm2dNoSync):
+ """
+ BatchNorm2d with automatic multi-GPU Sync
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True):
+ super(BatchNorm2dSync, self).__init__(
+ num_features, eps=eps, momentum=momentum, affine=affine,
+ track_running_stats=track_running_stats)
+ self.sync_enabled = True
+ self.devices = list(range(torch.cuda.device_count()))
+ if len(self.devices) > 1:
+ # Initialize queues
+ self.worker_ids = self.devices[1:]
+ self.master_queue = Queue(len(self.worker_ids))
+ self.worker_queues = [Queue(1) for _ in self.worker_ids]
+
+ def forward(self, x):
+ compute_stats = not self.freezed and \
+ self.training and self.track_running_stats
+ if self.sync_enabled and compute_stats and len(self.devices) > 1:
+ if x.get_device() == self.devices[0]:
+ # Master mode
+ extra = {
+ "is_master": True,
+ "master_queue": self.master_queue,
+ "worker_queues": self.worker_queues,
+ "worker_ids": self.worker_ids
+ }
+ else:
+ # Worker mode
+ extra = {
+ "is_master": False,
+ "master_queue": self.master_queue,
+ "worker_queue": self.worker_queues[
+ self.worker_ids.index(x.get_device())]
+ }
+ return batchnorm2d_sync(x, self.weight, self.bias,
+ self.running_mean, self.running_var,
+ extra, compute_stats, self.momentum,
+ self.eps)
+ return super(BatchNorm2dSync, self).forward(x)
+
+ def __repr__(self):
+ """repr"""
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
+ 'affine={affine}, ' \
+ 'track_running_stats={track_running_stats},' \
+ 'devices={devices})'
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
+
+#BatchNorm2d = BatchNorm2dNoSync
+BatchNorm2d = BatchNorm2dSync
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb66bdbba883b9477bbc1a52d8355131d32a04cb
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/__init__.py
@@ -0,0 +1,2 @@
+# noinspection PyUnresolvedReferences
+from .dist_maps import get_dist_maps
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..779a7f02ad7c2ba25e68302c6fc6683cd4ab54f7
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx
@@ -0,0 +1,63 @@
+import numpy as np
+cimport cython
+cimport numpy as np
+from libc.stdlib cimport malloc, free
+
+ctypedef struct qnode:
+ int row
+ int col
+ int layer
+ int orig_row
+ int orig_col
+
+@cython.infer_types(True)
+@cython.boundscheck(False)
+@cython.wraparound(False)
+@cython.nonecheck(False)
+def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points,
+ int height, int width, float norm_delimeter):
+ cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \
+ np.full((2, height, width), 1e6, dtype=np.float32, order="C")
+
+ cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0]
+ cdef int i, j, x, y, dx, dy
+ cdef qnode v
+ cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode))
+ cdef int qhead = 0, qtail = -1
+ cdef float ndist
+
+ for i in range(points.shape[0]):
+ x, y = round(points[i, 0]), round(points[i, 1])
+ if x >= 0:
+ qtail += 1
+ q[qtail].row = x
+ q[qtail].col = y
+ q[qtail].orig_row = x
+ q[qtail].orig_col = y
+ if i >= points.shape[0] / 2:
+ q[qtail].layer = 1
+ else:
+ q[qtail].layer = 0
+ dist_maps[q[qtail].layer, x, y] = 0
+
+ while qtail - qhead + 1 > 0:
+ v = q[qhead]
+ qhead += 1
+
+ for k in range(4):
+ x = v.row + dxy[2 * k]
+ y = v.col + dxy[2 * k + 1]
+
+ ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2
+ if (x >= 0 and y >= 0 and x < height and y < width and
+ dist_maps[v.layer, x, y] > ndist):
+ qtail += 1
+ q[qtail].orig_col = v.orig_col
+ q[qtail].orig_row = v.orig_row
+ q[qtail].layer = v.layer
+ q[qtail].row = x
+ q[qtail].col = y
+ dist_maps[v.layer, x, y] = ndist
+
+ free(q)
+ return dist_maps
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld
new file mode 100644
index 0000000000000000000000000000000000000000..bd4451729201b5ebc6bbbd8f392389ab6b530636
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld
@@ -0,0 +1,7 @@
+import numpy
+
+def make_ext(modname, pyxfilename):
+ from distutils.extension import Extension
+ return Extension(modname, [pyxfilename],
+ include_dirs=[numpy.get_include()],
+ extra_compile_args=['-O3'], language='c++')
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/dist_maps.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/dist_maps.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ffa1e3f25231cd7c48b66ef8ef5167235c3ea4e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/cython/dist_maps.py
@@ -0,0 +1,3 @@
+import pyximport; pyximport.install(pyximport=True, language_level=3)
+# noinspection PyUnresolvedReferences
+from ._get_dist_maps import get_dist_maps
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/misc.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..65ce96dc5667494446110fda75e29243338e2b88
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/misc.py
@@ -0,0 +1,62 @@
+from functools import partial
+
+import torch
+import numpy as np
+
+
+def get_dims_with_exclusion(dim, exclude=None):
+ dims = list(range(dim))
+ if exclude is not None:
+ dims.remove(exclude)
+
+ return dims
+
+
+def get_unique_labels(mask):
+ return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1
+
+
+def get_bbox_from_mask(mask):
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+ rmin, rmax = np.where(rows)[0][[0, -1]]
+ cmin, cmax = np.where(cols)[0][[0, -1]]
+
+ return rmin, rmax, cmin, cmax
+
+
+def expand_bbox(bbox, expand_ratio, min_crop_size=None):
+ rmin, rmax, cmin, cmax = bbox
+ rcenter = 0.5 * (rmin + rmax)
+ ccenter = 0.5 * (cmin + cmax)
+ height = expand_ratio * (rmax - rmin + 1)
+ width = expand_ratio * (cmax - cmin + 1)
+ if min_crop_size is not None:
+ height = max(height, min_crop_size)
+ width = max(width, min_crop_size)
+
+ rmin = int(round(rcenter - 0.5 * height))
+ rmax = int(round(rcenter + 0.5 * height))
+ cmin = int(round(ccenter - 0.5 * width))
+ cmax = int(round(ccenter + 0.5 * width))
+
+ return rmin, rmax, cmin, cmax
+
+
+def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
+ return (max(rmin, bbox[0]), min(rmax, bbox[1]),
+ max(cmin, bbox[2]), min(cmax, bbox[3]))
+
+
+def get_bbox_iou(b1, b2):
+ h_iou = get_segments_iou(b1[:2], b2[:2])
+ w_iou = get_segments_iou(b1[2:4], b2[2:4])
+ return h_iou * w_iou
+
+
+def get_segments_iou(s1, s2):
+ a, b = s1
+ c, d = s2
+ intersection = max(0, min(b, d) - max(a, c) + 1)
+ union = max(1e-6, max(b, d) - min(a, c) + 1)
+ return intersection / union
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/vis.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c1a291306453c15bdfe5117302beb62e0fe7248
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs/utils/vis.py
@@ -0,0 +1,129 @@
+from functools import lru_cache
+
+import cv2
+import numpy as np
+
+
+def visualize_instances(imask, bg_color=255,
+ boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8):
+ num_objects = imask.max() + 1
+ palette = get_palette(num_objects)
+ if bg_color is not None:
+ palette[0] = bg_color
+
+ result = palette[imask].astype(np.uint8)
+ if boundaries_color is not None:
+ boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width)
+ tresult = result.astype(np.float32)
+ tresult[boundaries_mask] = boundaries_color
+ tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result
+ result = tresult.astype(np.uint8)
+
+ return result
+
+
+@lru_cache(maxsize=16)
+def get_palette(num_cls):
+ palette = np.zeros(3 * num_cls, dtype=np.int32)
+
+ for j in range(0, num_cls):
+ lab = j
+ i = 0
+
+ while lab > 0:
+ palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i))
+ palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i))
+ palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i))
+ i = i + 1
+ lab >>= 3
+
+ return palette.reshape((-1, 3))
+
+
+def visualize_mask(mask, num_cls):
+ palette = get_palette(num_cls)
+ mask[mask == -1] = 0
+
+ return palette[mask].astype(np.uint8)
+
+
+def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1):
+ proposal_map, colors, candidates = proposals_info
+
+ proposal_map = draw_probmap(proposal_map)
+ for x, y in candidates:
+ proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1)
+
+ return proposal_map
+
+
+def draw_probmap(x):
+ return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT)
+
+
+def draw_points(image, points, color, radius=3):
+ image = image.copy()
+ for p in points:
+ image = cv2.circle(image, (int(p[1]), int(p[0])), radius, color, -1)
+
+ return image
+
+
+def draw_instance_map(x, palette=None):
+ num_colors = x.max() + 1
+ if palette is None:
+ palette = get_palette(num_colors)
+
+ return palette[x].astype(np.uint8)
+
+
+def blend_mask(image, mask, alpha=0.6):
+ if mask.min() == -1:
+ mask = mask.copy() + 1
+
+ imap = draw_instance_map(mask)
+ result = (image * (1 - alpha) + alpha * imap).astype(np.uint8)
+ return result
+
+
+def get_boundaries(instances_masks, boundaries_width=1):
+ boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool)
+
+ for obj_id in np.unique(instances_masks.flatten()):
+ if obj_id == 0:
+ continue
+
+ obj_mask = instances_masks == obj_id
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+ inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool)
+
+ obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask))
+ boundaries = np.logical_or(boundaries, obj_boundary)
+ return boundaries
+
+
+def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0),
+ neg_color=(255, 0, 0), radius=4):
+ result = img.copy()
+
+ if mask is not None:
+ palette = get_palette(np.max(mask) + 1)
+ rgb_mask = palette[mask.astype(np.uint8)]
+
+ mask_region = (mask > 0).astype(np.uint8)
+ result = result * (1 - mask_region[:, :, np.newaxis]) + \
+ (1 - alpha) * mask_region[:, :, np.newaxis] * result + \
+ alpha * rgb_mask
+ result = result.astype(np.uint8)
+
+ # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8)
+
+ if clicks_list is not None and len(clicks_list) > 0:
+ pos_points = [click.coords for click in clicks_list if click.is_positive]
+ neg_points = [click.coords for click in clicks_list if not click.is_positive]
+
+ result = draw_points(result, pos_points, pos_color, radius=radius)
+ result = draw_points(result, neg_points, neg_color, radius=radius)
+
+ return result
+
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/fbrs_controller.py b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe9ca496193829990b3db7b0f141aabeb61fd35
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/fbrs_controller.py
@@ -0,0 +1,53 @@
+import torch
+from .fbrs.controller import InteractiveController
+from .fbrs.inference import utils
+
+
+class FBRSController:
+ def __init__(self, checkpoint_path, device='cuda:0', max_size=800):
+ model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260)
+
+ # Predictor params
+ zoomin_params = {
+ 'skip_clicks': 1,
+ 'target_size': 480,
+ 'expansion_ratio': 1.4,
+ }
+
+ predictor_params = {
+ 'brs_mode': 'f-BRS-B',
+ 'prob_thresh': 0.5,
+ 'zoom_in_params': zoomin_params,
+ 'predictor_params': {
+ 'net_clicks_limit': 8,
+ 'max_size': 800,
+ },
+ 'brs_opt_func_params': {'min_iou_diff': 1e-3},
+ 'lbfgs_params': {'maxfun': 20}
+ }
+
+ self.controller = InteractiveController(model, device, predictor_params)
+ self.anchored = False
+ self.device = device
+
+ def unanchor(self):
+ self.anchored = False
+
+ def interact(self, image, x, y, is_positive):
+ image = image.to(self.device, non_blocking=True)
+ if not self.anchored:
+ self.controller.set_image(image)
+ self.controller.reset_predictor()
+ self.anchored = True
+
+ self.controller.add_click(x, y, is_positive)
+ # return self.controller.result_mask
+ # return self.controller.probs_history[-1][1]
+ return (self.controller.probs_history[-1][1]>0.5).float()
+
+ def undo(self):
+ self.controller.undo_click()
+ if len(self.controller.probs_history) == 0:
+ return None
+ else:
+ return (self.controller.probs_history[-1][1]>0.5).float()
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/gui.py b/Make-A-Protagonist/experts/XMem/inference/interact/gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..039a382bda5b5a892723df894c4dffab356e99c4
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/gui.py
@@ -0,0 +1,933 @@
+"""
+Based on https://github.com/hkchengrex/MiVOS/tree/MiVOS-STCN
+(which is based on https://github.com/seoungwugoh/ivs-demo)
+
+This version is much simplified.
+In this repo, we don't have
+- local control
+- fusion module
+- undo
+- timers
+
+but with XMem as the backbone and is more memory (for both CPU and GPU) friendly
+"""
+
+import functools
+
+import os
+import cv2
+# fix conflicts between qt5 and cv2
+os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH")
+
+import numpy as np
+import torch
+
+from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
+ QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog,
+ QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton)
+
+from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon
+from PyQt5.QtCore import Qt, QTimer
+
+from model.network import XMem
+
+from inference.inference_core import InferenceCore
+from .s2m_controller import S2MController
+from .fbrs_controller import FBRSController
+
+from .interactive_utils import *
+from .interaction import *
+from .resource_manager import ResourceManager
+from .gui_utils import *
+
+
+class App(QWidget):
+ def __init__(self, net: XMem,
+ resource_manager: ResourceManager,
+ s2m_ctrl:S2MController,
+ fbrs_ctrl:FBRSController, config):
+ super().__init__()
+
+ self.initialized = False
+ self.num_objects = config['num_objects']
+ self.s2m_controller = s2m_ctrl
+ self.fbrs_controller = fbrs_ctrl
+ self.config = config
+ self.processor = InferenceCore(net, config)
+ self.processor.set_all_labels(list(range(1, self.num_objects+1)))
+ self.res_man = resource_manager
+
+ self.num_frames = len(self.res_man)
+ self.height, self.width = self.res_man.h, self.res_man.w
+
+ # set window
+ self.setWindowTitle('XMem Demo')
+ self.setGeometry(100, 100, self.width, self.height+100)
+ self.setWindowIcon(QIcon('docs/icon.png'))
+
+ # some buttons
+ self.play_button = QPushButton('Play Video')
+ self.play_button.clicked.connect(self.on_play_video)
+ self.commit_button = QPushButton('Commit')
+ self.commit_button.clicked.connect(self.on_commit)
+
+ self.forward_run_button = QPushButton('Forward Propagate')
+ self.forward_run_button.clicked.connect(self.on_forward_propagation)
+ self.forward_run_button.setMinimumWidth(200)
+
+ self.backward_run_button = QPushButton('Backward Propagate')
+ self.backward_run_button.clicked.connect(self.on_backward_propagation)
+ self.backward_run_button.setMinimumWidth(200)
+
+ self.reset_button = QPushButton('Reset Frame')
+ self.reset_button.clicked.connect(self.on_reset_mask)
+
+ # LCD
+ self.lcd = QTextEdit()
+ self.lcd.setReadOnly(True)
+ self.lcd.setMaximumHeight(28)
+ self.lcd.setMaximumWidth(120)
+ self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1))
+
+ # timeline slider
+ self.tl_slider = QSlider(Qt.Horizontal)
+ self.tl_slider.valueChanged.connect(self.tl_slide)
+ self.tl_slider.setMinimum(0)
+ self.tl_slider.setMaximum(self.num_frames-1)
+ self.tl_slider.setValue(0)
+ self.tl_slider.setTickPosition(QSlider.TicksBelow)
+ self.tl_slider.setTickInterval(1)
+
+ # brush size slider
+ self.brush_label = QLabel()
+ self.brush_label.setAlignment(Qt.AlignCenter)
+ self.brush_label.setMinimumWidth(100)
+
+ self.brush_slider = QSlider(Qt.Horizontal)
+ self.brush_slider.valueChanged.connect(self.brush_slide)
+ self.brush_slider.setMinimum(1)
+ self.brush_slider.setMaximum(100)
+ self.brush_slider.setValue(3)
+ self.brush_slider.setTickPosition(QSlider.TicksBelow)
+ self.brush_slider.setTickInterval(2)
+ self.brush_slider.setMinimumWidth(300)
+
+ # combobox
+ self.combo = QComboBox(self)
+ self.combo.addItem("davis")
+ self.combo.addItem("fade")
+ self.combo.addItem("light")
+ self.combo.addItem("popup")
+ self.combo.addItem("layered")
+ self.combo.currentTextChanged.connect(self.set_viz_mode)
+
+ self.save_visualization_checkbox = QCheckBox(self)
+ self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle)
+ self.save_visualization_checkbox.setChecked(False)
+ self.save_visualization = False
+
+ # Radio buttons for type of interactions
+ self.curr_interaction = 'Click'
+ self.interaction_group = QButtonGroup()
+ self.radio_fbrs = QRadioButton('Click')
+ self.radio_s2m = QRadioButton('Scribble')
+ self.radio_free = QRadioButton('Free')
+ self.interaction_group.addButton(self.radio_fbrs)
+ self.interaction_group.addButton(self.radio_s2m)
+ self.interaction_group.addButton(self.radio_free)
+ self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
+ self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
+ self.radio_free.toggled.connect(self.interaction_radio_clicked)
+ self.radio_fbrs.toggle()
+
+ # Main canvas -> QLabel
+ self.main_canvas = QLabel()
+ self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
+ QSizePolicy.Expanding)
+ self.main_canvas.setAlignment(Qt.AlignCenter)
+ self.main_canvas.setMinimumSize(100, 100)
+
+ self.main_canvas.mousePressEvent = self.on_mouse_press
+ self.main_canvas.mouseMoveEvent = self.on_mouse_motion
+ self.main_canvas.setMouseTracking(True) # Required for all-time tracking
+ self.main_canvas.mouseReleaseEvent = self.on_mouse_release
+
+ # Minimap -> Also a QLbal
+ self.minimap = QLabel()
+ self.minimap.setSizePolicy(QSizePolicy.Expanding,
+ QSizePolicy.Expanding)
+ self.minimap.setAlignment(Qt.AlignTop)
+ self.minimap.setMinimumSize(100, 100)
+
+ # Zoom-in buttons
+ self.zoom_p_button = QPushButton('Zoom +')
+ self.zoom_p_button.clicked.connect(self.on_zoom_plus)
+ self.zoom_m_button = QPushButton('Zoom -')
+ self.zoom_m_button.clicked.connect(self.on_zoom_minus)
+
+ # Parameters setting
+ self.clear_mem_button = QPushButton('Clear memory')
+ self.clear_mem_button.clicked.connect(self.on_clear_memory)
+
+ self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
+ self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
+ self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
+ self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')
+
+ self.update_memory_size()
+ self.update_gpu_usage()
+
+ self.work_mem_min, self.work_mem_min_layout = create_parameter_box(1, 100, 'Min. working memory frames',
+ callback=self.on_work_min_change)
+ self.work_mem_max, self.work_mem_max_layout = create_parameter_box(2, 100, 'Max. working memory frames',
+ callback=self.on_work_max_change)
+ self.long_mem_max, self.long_mem_max_layout = create_parameter_box(1000, 100000,
+ 'Max. long-term memory size', step=1000, callback=self.update_config)
+ self.num_prototypes_box, self.num_prototypes_box_layout = create_parameter_box(32, 1280,
+ 'Number of prototypes', step=32, callback=self.update_config)
+ self.mem_every_box, self.mem_every_box_layout = create_parameter_box(1, 100, 'Memory frame every (r)',
+ callback=self.update_config)
+
+ self.work_mem_min.setValue(self.processor.memory.min_mt_frames)
+ self.work_mem_max.setValue(self.processor.memory.max_mt_frames)
+ self.long_mem_max.setValue(self.processor.memory.max_long_elements)
+ self.num_prototypes_box.setValue(self.processor.memory.num_prototypes)
+ self.mem_every_box.setValue(self.processor.mem_every)
+
+ # import mask/layer
+ self.import_mask_button = QPushButton('Import mask')
+ self.import_mask_button.clicked.connect(self.on_import_mask)
+ self.import_layer_button = QPushButton('Import layer')
+ self.import_layer_button.clicked.connect(self.on_import_layer)
+
+ # Console on the GUI
+ self.console = QPlainTextEdit()
+ self.console.setReadOnly(True)
+ self.console.setMinimumHeight(100)
+ self.console.setMaximumHeight(100)
+
+ # navigator
+ navi = QHBoxLayout()
+ navi.addWidget(self.lcd)
+ navi.addWidget(self.play_button)
+
+ interact_subbox = QVBoxLayout()
+ interact_topbox = QHBoxLayout()
+ interact_botbox = QHBoxLayout()
+ interact_topbox.setAlignment(Qt.AlignCenter)
+ interact_topbox.addWidget(self.radio_s2m)
+ interact_topbox.addWidget(self.radio_fbrs)
+ interact_topbox.addWidget(self.radio_free)
+ interact_topbox.addWidget(self.brush_label)
+ interact_botbox.addWidget(self.brush_slider)
+ interact_subbox.addLayout(interact_topbox)
+ interact_subbox.addLayout(interact_botbox)
+ navi.addLayout(interact_subbox)
+
+ navi.addStretch(1)
+ navi.addWidget(self.reset_button)
+
+ navi.addStretch(1)
+ navi.addWidget(QLabel('Overlay Mode'))
+ navi.addWidget(self.combo)
+ navi.addWidget(QLabel('Save overlay during propagation'))
+ navi.addWidget(self.save_visualization_checkbox)
+ navi.addStretch(1)
+ navi.addWidget(self.commit_button)
+ navi.addWidget(self.forward_run_button)
+ navi.addWidget(self.backward_run_button)
+
+ # Drawing area, main canvas and minimap
+ draw_area = QHBoxLayout()
+ draw_area.addWidget(self.main_canvas, 4)
+
+ # Minimap area
+ minimap_area = QVBoxLayout()
+ minimap_area.setAlignment(Qt.AlignTop)
+ mini_label = QLabel('Minimap')
+ mini_label.setAlignment(Qt.AlignTop)
+ minimap_area.addWidget(mini_label)
+
+ # Minimap zooming
+ minimap_ctrl = QHBoxLayout()
+ minimap_ctrl.setAlignment(Qt.AlignTop)
+ minimap_ctrl.addWidget(self.zoom_p_button)
+ minimap_ctrl.addWidget(self.zoom_m_button)
+ minimap_area.addLayout(minimap_ctrl)
+ minimap_area.addWidget(self.minimap)
+
+ # Parameters
+ minimap_area.addLayout(self.work_mem_gauge_layout)
+ minimap_area.addLayout(self.long_mem_gauge_layout)
+ minimap_area.addLayout(self.gpu_mem_gauge_layout)
+ minimap_area.addLayout(self.torch_mem_gauge_layout)
+ minimap_area.addWidget(self.clear_mem_button)
+ minimap_area.addLayout(self.work_mem_min_layout)
+ minimap_area.addLayout(self.work_mem_max_layout)
+ minimap_area.addLayout(self.long_mem_max_layout)
+ minimap_area.addLayout(self.num_prototypes_box_layout)
+ minimap_area.addLayout(self.mem_every_box_layout)
+
+ # import mask/layer
+ import_area = QHBoxLayout()
+ import_area.setAlignment(Qt.AlignTop)
+ import_area.addWidget(self.import_mask_button)
+ import_area.addWidget(self.import_layer_button)
+ minimap_area.addLayout(import_area)
+
+ # console
+ minimap_area.addWidget(self.console)
+
+ draw_area.addLayout(minimap_area, 1)
+
+ layout = QVBoxLayout()
+ layout.addLayout(draw_area)
+ layout.addWidget(self.tl_slider)
+ layout.addLayout(navi)
+ self.setLayout(layout)
+
+ # timer to play video
+ self.timer = QTimer()
+ self.timer.setSingleShot(False)
+
+ # timer to update GPU usage
+ self.gpu_timer = QTimer()
+ self.gpu_timer.setSingleShot(False)
+ self.gpu_timer.timeout.connect(self.on_gpu_timer)
+ self.gpu_timer.setInterval(2000)
+ self.gpu_timer.start()
+
+ # current frame info
+ self.curr_frame_dirty = False
+ self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+ self.current_image_torch = None
+ self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
+ self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()
+
+ # initialize visualization
+ self.viz_mode = 'davis'
+ self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+ self.vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
+ self.brush_vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+ self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32)
+ self.cursur = 0
+ self.on_showing = None
+
+ # Zoom parameters
+ self.zoom_pixels = 150
+
+ # initialize action
+ self.interaction = None
+ self.pressed = False
+ self.right_click = False
+ self.current_object = 1
+ self.last_ex = self.last_ey = 0
+
+ self.propagating = False
+
+ # Objects shortcuts
+ for i in range(1, self.num_objects+1):
+ QShortcut(QKeySequence(str(i)), self).activated.connect(functools.partial(self.hit_number_key, i))
+
+ # <- and -> shortcuts
+ QShortcut(QKeySequence(Qt.Key_Left), self).activated.connect(self.on_prev_frame)
+ QShortcut(QKeySequence(Qt.Key_Right), self).activated.connect(self.on_next_frame)
+
+ self.interacted_prob = None
+ self.overlay_layer = None
+ self.overlay_layer_torch = None
+
+ # the object id used for popup/layered overlay
+ self.vis_target_objects = [1]
+ # try to load the default overlay
+ self._try_load_layer('./docs/ECCV-logo.png')
+
+ self.load_current_image_mask()
+ self.show_current_frame()
+ self.show()
+
+ self.console_push_text('Initialized.')
+ self.initialized = True
+
+ def resizeEvent(self, event):
+ self.show_current_frame()
+
+ def console_push_text(self, text):
+ self.console.moveCursor(QTextCursor.End)
+ self.console.insertPlainText(text+'\n')
+
+ def interaction_radio_clicked(self, event):
+ self.last_interaction = self.curr_interaction
+ if self.radio_s2m.isChecked():
+ self.curr_interaction = 'Scribble'
+ self.brush_size = 3
+ self.brush_slider.setDisabled(True)
+ elif self.radio_fbrs.isChecked():
+ self.curr_interaction = 'Click'
+ self.brush_size = 3
+ self.brush_slider.setDisabled(True)
+ elif self.radio_free.isChecked():
+ self.brush_slider.setDisabled(False)
+ self.brush_slide()
+ self.curr_interaction = 'Free'
+ if self.curr_interaction == 'Scribble':
+ self.commit_button.setEnabled(True)
+ else:
+ self.commit_button.setEnabled(False)
+
+ def load_current_image_mask(self, no_mask=False):
+ self.current_image = self.res_man.get_image(self.cursur)
+ self.current_image_torch = None
+
+ if not no_mask:
+ loaded_mask = self.res_man.get_mask(self.cursur)
+ if loaded_mask is None:
+ self.current_mask.fill(0)
+ else:
+ self.current_mask = loaded_mask.copy()
+ self.current_prob = None
+
+ def load_current_torch_image_mask(self, no_mask=False):
+ if self.current_image_torch is None:
+ self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image)
+
+ if self.current_prob is None and not no_mask:
+ self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
+
+ def compose_current_im(self):
+ self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask,
+ self.overlay_layer, self.vis_target_objects)
+
+ def update_interact_vis(self):
+ # Update the interactions without re-computing the overlay
+ height, width, channel = self.viz.shape
+ bytesPerLine = 3 * width
+
+ vis_map = self.vis_map
+ vis_alpha = self.vis_alpha
+ brush_vis_map = self.brush_vis_map
+ brush_vis_alpha = self.brush_vis_alpha
+
+ self.viz_with_stroke = self.viz*(1-vis_alpha) + vis_map*vis_alpha
+ self.viz_with_stroke = self.viz_with_stroke*(1-brush_vis_alpha) + brush_vis_map*brush_vis_alpha
+ self.viz_with_stroke = self.viz_with_stroke.astype(np.uint8)
+
+ qImg = QImage(self.viz_with_stroke.data, width, height, bytesPerLine, QImage.Format_RGB888)
+ self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
+ Qt.KeepAspectRatio, Qt.FastTransformation)))
+
+ self.main_canvas_size = self.main_canvas.size()
+ self.image_size = qImg.size()
+
+ def update_minimap(self):
+ ex, ey = self.last_ex, self.last_ey
+ r = self.zoom_pixels//2
+ ex = int(round(max(r, min(self.width-r, ex))))
+ ey = int(round(max(r, min(self.height-r, ey))))
+
+ patch = self.viz_with_stroke[ey-r:ey+r, ex-r:ex+r, :].astype(np.uint8)
+
+ height, width, channel = patch.shape
+ bytesPerLine = 3 * width
+ qImg = QImage(patch.data, width, height, bytesPerLine, QImage.Format_RGB888)
+ self.minimap.setPixmap(QPixmap(qImg.scaled(self.minimap.size(),
+ Qt.KeepAspectRatio, Qt.FastTransformation)))
+
+ def update_current_image_fast(self):
+ # fast path, uses gpu. Changes the image in-place to avoid copying
+ self.viz = get_visualization_torch(self.viz_mode, self.current_image_torch_no_norm,
+ self.current_prob, self.overlay_layer_torch, self.vis_target_objects)
+ if self.save_visualization:
+ self.res_man.save_visualization(self.cursur, self.viz)
+
+ height, width, channel = self.viz.shape
+ bytesPerLine = 3 * width
+
+ qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888)
+ self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(),
+ Qt.KeepAspectRatio, Qt.FastTransformation)))
+
+ def show_current_frame(self, fast=False):
+ # Re-compute overlay and show the image
+ if fast:
+ self.update_current_image_fast()
+ else:
+ self.compose_current_im()
+ self.update_interact_vis()
+ self.update_minimap()
+
+ self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1))
+ self.tl_slider.setValue(self.cursur)
+
+ def pixel_pos_to_image_pos(self, x, y):
+ # Un-scale and un-pad the label coordinates into image coordinates
+ oh, ow = self.image_size.height(), self.image_size.width()
+ nh, nw = self.main_canvas_size.height(), self.main_canvas_size.width()
+
+ h_ratio = nh/oh
+ w_ratio = nw/ow
+ dominate_ratio = min(h_ratio, w_ratio)
+
+ # Solve scale
+ x /= dominate_ratio
+ y /= dominate_ratio
+
+ # Solve padding
+ fh, fw = nh/dominate_ratio, nw/dominate_ratio
+ x -= (fw-ow)/2
+ y -= (fh-oh)/2
+
+ return x, y
+
+ def is_pos_out_of_bound(self, x, y):
+ x, y = self.pixel_pos_to_image_pos(x, y)
+
+ out_of_bound = (
+ (x < 0) or
+ (y < 0) or
+ (x > self.width-1) or
+ (y > self.height-1)
+ )
+
+ return out_of_bound
+
+ def get_scaled_pos(self, x, y):
+ x, y = self.pixel_pos_to_image_pos(x, y)
+
+ x = max(0, min(self.width-1, x))
+ y = max(0, min(self.height-1, y))
+
+ return x, y
+
+ def clear_visualization(self):
+ self.vis_map.fill(0)
+ self.vis_alpha.fill(0)
+
+ def reset_this_interaction(self):
+ self.complete_interaction()
+ self.clear_visualization()
+ self.interaction = None
+ if self.fbrs_controller is not None:
+ self.fbrs_controller.unanchor()
+
+ def set_viz_mode(self):
+ self.viz_mode = self.combo.currentText()
+ self.show_current_frame()
+
+ def save_current_mask(self):
+ # save mask to hard disk
+ self.res_man.save_mask(self.cursur, self.current_mask)
+
+ def tl_slide(self):
+ # if we are propagating, the on_run function will take care of everything
+ # don't do duplicate work here
+ if not self.propagating:
+ if self.curr_frame_dirty:
+ self.save_current_mask()
+ self.curr_frame_dirty = False
+
+ self.reset_this_interaction()
+ self.cursur = self.tl_slider.value()
+ self.load_current_image_mask()
+ self.show_current_frame()
+
+ def brush_slide(self):
+ self.brush_size = self.brush_slider.value()
+ self.brush_label.setText('Brush size: %d' % self.brush_size)
+ try:
+ if type(self.interaction) == FreeInteraction:
+ self.interaction.set_size(self.brush_size)
+ except AttributeError:
+ # Initialization, forget about it
+ pass
+
+ def on_forward_propagation(self):
+ if self.propagating:
+ # acts as a pause button
+ self.propagating = False
+ else:
+ self.propagate_fn = self.on_next_frame
+ self.backward_run_button.setEnabled(False)
+ self.forward_run_button.setText('Pause Propagation')
+ self.on_propagation()
+
+ def on_backward_propagation(self):
+ if self.propagating:
+ # acts as a pause button
+ self.propagating = False
+ else:
+ self.propagate_fn = self.on_prev_frame
+ self.forward_run_button.setEnabled(False)
+ self.backward_run_button.setText('Pause Propagation')
+ self.on_propagation()
+
+ def on_pause(self):
+ self.propagating = False
+ self.forward_run_button.setEnabled(True)
+ self.backward_run_button.setEnabled(True)
+ self.clear_mem_button.setEnabled(True)
+ self.forward_run_button.setText('Forward Propagate')
+ self.backward_run_button.setText('Backward Propagate')
+ self.console_push_text('Propagation stopped.')
+
+ def on_propagation(self):
+ # start to propagate
+ self.load_current_torch_image_mask()
+ self.show_current_frame(fast=True)
+
+ self.console_push_text('Propagation started.')
+ self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:])
+ self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
+ # clear
+ self.interacted_prob = None
+ self.reset_this_interaction()
+
+ self.propagating = True
+ self.clear_mem_button.setEnabled(False)
+ # propagate till the end
+ while self.propagating:
+ self.propagate_fn()
+
+ self.load_current_image_mask(no_mask=True)
+ self.load_current_torch_image_mask(no_mask=True)
+
+ self.current_prob = self.processor.step(self.current_image_torch)
+ self.current_mask = torch_prob_to_numpy_mask(self.current_prob)
+
+ self.save_current_mask()
+ self.show_current_frame(fast=True)
+
+ self.update_memory_size()
+ QApplication.processEvents()
+
+ if self.cursur == 0 or self.cursur == self.num_frames-1:
+ break
+
+ self.propagating = False
+ self.curr_frame_dirty = False
+ self.on_pause()
+ self.tl_slide()
+ QApplication.processEvents()
+
+ def pause_propagation(self):
+ self.propagating = False
+
+ def on_commit(self):
+ self.complete_interaction()
+ self.update_interacted_mask()
+
+ def on_prev_frame(self):
+ # self.tl_slide will trigger on setValue
+ self.cursur = max(0, self.cursur-1)
+ self.tl_slider.setValue(self.cursur)
+
+ def on_next_frame(self):
+ # self.tl_slide will trigger on setValue
+ self.cursur = min(self.cursur+1, self.num_frames-1)
+ self.tl_slider.setValue(self.cursur)
+
+ def on_play_video_timer(self):
+ self.cursur += 1
+ if self.cursur > self.num_frames-1:
+ self.cursur = 0
+ self.tl_slider.setValue(self.cursur)
+
+ def on_play_video(self):
+ if self.timer.isActive():
+ self.timer.stop()
+ self.play_button.setText('Play Video')
+ else:
+ self.timer.start(1000 / 30)
+ self.play_button.setText('Stop Video')
+
+ def on_reset_mask(self):
+ self.current_mask.fill(0)
+ if self.current_prob is not None:
+ self.current_prob.fill_(0)
+ self.curr_frame_dirty = True
+ self.save_current_mask()
+ self.reset_this_interaction()
+ self.show_current_frame()
+
+ def on_zoom_plus(self):
+ self.zoom_pixels -= 25
+ self.zoom_pixels = max(50, self.zoom_pixels)
+ self.update_minimap()
+
+ def on_zoom_minus(self):
+ self.zoom_pixels += 25
+ self.zoom_pixels = min(self.zoom_pixels, 300)
+ self.update_minimap()
+
+ def set_navi_enable(self, boolean):
+ self.zoom_p_button.setEnabled(boolean)
+ self.zoom_m_button.setEnabled(boolean)
+ self.run_button.setEnabled(boolean)
+ self.tl_slider.setEnabled(boolean)
+ self.play_button.setEnabled(boolean)
+ self.lcd.setEnabled(boolean)
+
+ def hit_number_key(self, number):
+ if number == self.current_object:
+ return
+ self.current_object = number
+ if self.fbrs_controller is not None:
+ self.fbrs_controller.unanchor()
+ self.console_push_text(f'Current object changed to {number}.')
+ self.clear_brush()
+ self.vis_brush(self.last_ex, self.last_ey)
+ self.update_interact_vis()
+ self.show_current_frame()
+
+ def clear_brush(self):
+ self.brush_vis_map.fill(0)
+ self.brush_vis_alpha.fill(0)
+
+ def vis_brush(self, ex, ey):
+ self.brush_vis_map = cv2.circle(self.brush_vis_map,
+ (int(round(ex)), int(round(ey))), self.brush_size//2+1, color_map[self.current_object], thickness=-1)
+ self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha,
+ (int(round(ex)), int(round(ey))), self.brush_size//2+1, 0.5, thickness=-1)
+
+ def on_mouse_press(self, event):
+ if self.is_pos_out_of_bound(event.x(), event.y()):
+ return
+
+ # mid-click
+ if (event.button() == Qt.MidButton):
+ ex, ey = self.get_scaled_pos(event.x(), event.y())
+ target_object = self.current_mask[int(ey),int(ex)]
+ if target_object in self.vis_target_objects:
+ self.vis_target_objects.remove(target_object)
+ else:
+ self.vis_target_objects.append(target_object)
+ self.console_push_text(f'Target objects for visualization changed to {self.vis_target_objects}')
+ self.show_current_frame()
+ return
+
+ self.right_click = (event.button() == Qt.RightButton)
+ self.pressed = True
+
+ h, w = self.height, self.width
+
+ self.load_current_torch_image_mask()
+ image = self.current_image_torch
+
+ last_interaction = self.interaction
+ new_interaction = None
+ if self.curr_interaction == 'Scribble':
+ if last_interaction is None or type(last_interaction) != ScribbleInteraction:
+ self.complete_interaction()
+ new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(),
+ (h, w), self.s2m_controller, self.num_objects)
+ elif self.curr_interaction == 'Free':
+ if last_interaction is None or type(last_interaction) != FreeInteraction:
+ self.complete_interaction()
+ new_interaction = FreeInteraction(image, self.current_mask, (h, w),
+ self.num_objects)
+ new_interaction.set_size(self.brush_size)
+ elif self.curr_interaction == 'Click':
+ if (last_interaction is None or type(last_interaction) != ClickInteraction
+ or last_interaction.tar_obj != self.current_object):
+ self.complete_interaction()
+ self.fbrs_controller.unanchor()
+ new_interaction = ClickInteraction(image, self.current_prob, (h, w),
+ self.fbrs_controller, self.current_object)
+
+ if new_interaction is not None:
+ self.interaction = new_interaction
+
+ # Just motion it as the first step
+ self.on_mouse_motion(event)
+
+ def on_mouse_motion(self, event):
+ ex, ey = self.get_scaled_pos(event.x(), event.y())
+ self.last_ex, self.last_ey = ex, ey
+ self.clear_brush()
+ # Visualize
+ self.vis_brush(ex, ey)
+ if self.pressed:
+ if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
+ obj = 0 if self.right_click else self.current_object
+ self.vis_map, self.vis_alpha = self.interaction.push_point(
+ ex, ey, obj, (self.vis_map, self.vis_alpha)
+ )
+ self.update_interact_vis()
+ self.update_minimap()
+
+ def update_interacted_mask(self):
+ self.current_prob = self.interacted_prob
+ self.current_mask = torch_prob_to_numpy_mask(self.interacted_prob)
+ self.show_current_frame()
+ self.save_current_mask()
+ self.curr_frame_dirty = False
+
+ def complete_interaction(self):
+ if self.interaction is not None:
+ self.clear_visualization()
+ self.interaction = None
+
+ def on_mouse_release(self, event):
+ if not self.pressed:
+ # this can happen when the initial press is out-of-bound
+ return
+
+ ex, ey = self.get_scaled_pos(event.x(), event.y())
+
+ self.console_push_text('%s interaction at frame %d.' % (self.curr_interaction, self.cursur))
+ interaction = self.interaction
+
+ if self.curr_interaction == 'Scribble' or self.curr_interaction == 'Free':
+ self.on_mouse_motion(event)
+ interaction.end_path()
+ if self.curr_interaction == 'Free':
+ self.clear_visualization()
+ elif self.curr_interaction == 'Click':
+ ex, ey = self.get_scaled_pos(event.x(), event.y())
+ self.vis_map, self.vis_alpha = interaction.push_point(ex, ey,
+ self.right_click, (self.vis_map, self.vis_alpha))
+
+ self.interacted_prob = interaction.predict()
+ self.update_interacted_mask()
+ self.update_gpu_usage()
+
+ self.pressed = self.right_click = False
+
+ def wheelEvent(self, event):
+ ex, ey = self.get_scaled_pos(event.x(), event.y())
+ if self.curr_interaction == 'Free':
+ self.brush_slider.setValue(self.brush_slider.value() + event.angleDelta().y()//30)
+ self.clear_brush()
+ self.vis_brush(ex, ey)
+ self.update_interact_vis()
+ self.update_minimap()
+
+ def update_gpu_usage(self):
+ info = torch.cuda.mem_get_info()
+ global_free, global_total = info
+ global_free /= (2**30)
+ global_total /= (2**30)
+ global_used = global_total - global_free
+
+ self.gpu_mem_gauge.setFormat(f'{global_used:.01f} GB / {global_total:.01f} GB')
+ self.gpu_mem_gauge.setValue(round(global_used/global_total*100))
+
+ used_by_torch = torch.cuda.max_memory_allocated() / (2**20)
+ self.torch_mem_gauge.setFormat(f'{used_by_torch:.0f} MB / {global_total:.01f} GB')
+ self.torch_mem_gauge.setValue(round(used_by_torch/global_total*100/1024))
+
+ def on_gpu_timer(self):
+ self.update_gpu_usage()
+
+ def update_memory_size(self):
+ try:
+ max_work_elements = self.processor.memory.max_work_elements
+ max_long_elements = self.processor.memory.max_long_elements
+
+ curr_work_elements = self.processor.memory.work_mem.size
+ curr_long_elements = self.processor.memory.long_mem.size
+
+ self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}')
+ self.work_mem_gauge.setValue(round(curr_work_elements/max_work_elements*100))
+
+ self.long_mem_gauge.setFormat(f'{curr_long_elements} / {max_long_elements}')
+ self.long_mem_gauge.setValue(round(curr_long_elements/max_long_elements*100))
+
+ except AttributeError:
+ self.work_mem_gauge.setFormat('Unknown')
+ self.long_mem_gauge.setFormat('Unknown')
+ self.work_mem_gauge.setValue(0)
+ self.long_mem_gauge.setValue(0)
+
+ def on_work_min_change(self):
+ if self.initialized:
+ self.work_mem_min.setValue(min(self.work_mem_min.value(), self.work_mem_max.value()-1))
+ self.update_config()
+
+ def on_work_max_change(self):
+ if self.initialized:
+ self.work_mem_max.setValue(max(self.work_mem_max.value(), self.work_mem_min.value()+1))
+ self.update_config()
+
+ def update_config(self):
+ if self.initialized:
+ self.config['min_mid_term_frames'] = self.work_mem_min.value()
+ self.config['max_mid_term_frames'] = self.work_mem_max.value()
+ self.config['max_long_term_elements'] = self.long_mem_max.value()
+ self.config['num_prototypes'] = self.num_prototypes_box.value()
+ self.config['mem_every'] = self.mem_every_box.value()
+
+ self.processor.update_config(self.config)
+
+ def on_clear_memory(self):
+ self.processor.clear_memory()
+ torch.cuda.empty_cache()
+ self.update_gpu_usage()
+ self.update_memory_size()
+
+ def _open_file(self, prompt):
+ options = QFileDialog.Options()
+ file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options)
+ return file_name
+
+ def on_import_mask(self):
+ file_name = self._open_file('Mask')
+ if len(file_name) == 0:
+ return
+
+ mask = self.res_man.read_external_image(file_name, size=(self.height, self.width))
+
+ shape_condition = (
+ (len(mask.shape) == 2) and
+ (mask.shape[-1] == self.width) and
+ (mask.shape[-2] == self.height)
+ )
+
+ object_condition = (
+ mask.max() <= self.num_objects
+ )
+
+ if not shape_condition:
+ self.console_push_text(f'Expected ({self.height}, {self.width}). Got {mask.shape} instead.')
+ elif not object_condition:
+ self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.')
+ else:
+ self.console_push_text(f'Mask file {file_name} loaded.')
+ self.current_image_torch = self.current_prob = None
+ self.current_mask = mask
+ self.show_current_frame()
+ self.save_current_mask()
+
+ def on_import_layer(self):
+ file_name = self._open_file('Layer')
+ if len(file_name) == 0:
+ return
+
+ self._try_load_layer(file_name)
+
+ def _try_load_layer(self, file_name):
+ try:
+ layer = self.res_man.read_external_image(file_name, size=(self.height, self.width))
+
+ if layer.shape[-1] == 3:
+ layer = np.concatenate([layer, np.ones_like(layer[:,:,0:1])*255], axis=-1)
+
+ condition = (
+ (len(layer.shape) == 3) and
+ (layer.shape[-1] == 4) and
+ (layer.shape[-2] == self.width) and
+ (layer.shape[-3] == self.height)
+ )
+
+ if not condition:
+ self.console_push_text(f'Expected ({self.height}, {self.width}, 4). Got {layer.shape}.')
+ else:
+ self.console_push_text(f'Layer file {file_name} loaded.')
+ self.overlay_layer = layer
+ self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
+ self.show_current_frame()
+ except FileNotFoundError:
+ self.console_push_text(f'{file_name} not found.')
+
+ def on_save_visualization_toggle(self):
+ self.save_visualization = self.save_visualization_checkbox.isChecked()
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/gui_utils.py b/Make-A-Protagonist/experts/XMem/inference/interact/gui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..daf852b30a84893c836d7c3350b727aeed5d0a6b
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/gui_utils.py
@@ -0,0 +1,40 @@
+from PyQt5.QtCore import Qt
+from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)
+
+
+def create_parameter_box(min_val, max_val, text, step=1, callback=None):
+ layout = QHBoxLayout()
+
+ dial = QSpinBox()
+ dial.setMaximumHeight(28)
+ dial.setMaximumWidth(150)
+ dial.setMinimum(min_val)
+ dial.setMaximum(max_val)
+ dial.setAlignment(Qt.AlignRight)
+ dial.setSingleStep(step)
+ dial.valueChanged.connect(callback)
+
+ label = QLabel(text)
+ label.setAlignment(Qt.AlignRight)
+
+ layout.addWidget(label)
+ layout.addWidget(dial)
+
+ return dial, layout
+
+
+def create_gauge(text):
+ layout = QHBoxLayout()
+
+ gauge = QProgressBar()
+ gauge.setMaximumHeight(28)
+ gauge.setMaximumWidth(200)
+ gauge.setAlignment(Qt.AlignCenter)
+
+ label = QLabel(text)
+ label.setAlignment(Qt.AlignRight)
+
+ layout.addWidget(label)
+ layout.addWidget(gauge)
+
+ return gauge, layout
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/interaction.py b/Make-A-Protagonist/experts/XMem/inference/interact/interaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f83f9d58a00cac079a7ba5c239196378603b64
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/interaction.py
@@ -0,0 +1,252 @@
+"""
+Contains all the types of interaction related to the GUI
+Not related to automatic evaluation in the DAVIS dataset
+
+You can inherit the Interaction class to create new interaction types
+undo is (sometimes partially) supported
+"""
+
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import cv2
+import time
+from .interactive_utils import color_map, index_numpy_to_one_hot_torch
+
+
+def aggregate_sbg(prob, keep_bg=False, hard=False):
+ device = prob.device
+ k, h, w = prob.shape
+ ex_prob = torch.zeros((k+1, h, w), device=device)
+ ex_prob[0] = 0.5
+ ex_prob[1:] = prob
+ ex_prob = torch.clamp(ex_prob, 1e-7, 1-1e-7)
+ logits = torch.log((ex_prob /(1-ex_prob)))
+
+ if hard:
+ # Very low temperature o((⊙﹏⊙))o 🥶
+ logits *= 1000
+
+ if keep_bg:
+ return F.softmax(logits, dim=0)
+ else:
+ return F.softmax(logits, dim=0)[1:]
+
+def aggregate_wbg(prob, keep_bg=False, hard=False):
+ k, h, w = prob.shape
+ new_prob = torch.cat([
+ torch.prod(1-prob, dim=0, keepdim=True),
+ prob
+ ], 0).clamp(1e-7, 1-1e-7)
+ logits = torch.log((new_prob /(1-new_prob)))
+
+ if hard:
+ # Very low temperature o((⊙﹏⊙))o 🥶
+ logits *= 1000
+
+ if keep_bg:
+ return F.softmax(logits, dim=0)
+ else:
+ return F.softmax(logits, dim=0)[1:]
+
+class Interaction:
+ def __init__(self, image, prev_mask, true_size, controller):
+ self.image = image
+ self.prev_mask = prev_mask
+ self.controller = controller
+ self.start_time = time.time()
+
+ self.h, self.w = true_size
+
+ self.out_prob = None
+ self.out_mask = None
+
+ def predict(self):
+ pass
+
+
+class FreeInteraction(Interaction):
+ def __init__(self, image, prev_mask, true_size, num_objects):
+ """
+ prev_mask should be index format numpy array
+ """
+ super().__init__(image, prev_mask, true_size, None)
+
+ self.K = num_objects
+
+ self.drawn_map = self.prev_mask.copy()
+ self.curr_path = [[] for _ in range(self.K + 1)]
+
+ self.size = None
+
+ def set_size(self, size):
+ self.size = size
+
+ """
+ k - object id
+ vis - a tuple (visualization map, pass through alpha). None if not needed.
+ """
+ def push_point(self, x, y, k, vis=None):
+ if vis is not None:
+ vis_map, vis_alpha = vis
+ selected = self.curr_path[k]
+ selected.append((x, y))
+ if len(selected) >= 2:
+ cv2.line(self.drawn_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ k, thickness=self.size)
+
+ # Plot visualization
+ if vis is not None:
+ # Visualization for drawing
+ if k == 0:
+ vis_map = cv2.line(vis_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ color_map[k], thickness=self.size)
+ else:
+ vis_map = cv2.line(vis_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ color_map[k], thickness=self.size)
+ # Visualization on/off boolean filter
+ vis_alpha = cv2.line(vis_alpha,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ 0.75, thickness=self.size)
+
+ if vis is not None:
+ return vis_map, vis_alpha
+
+ def end_path(self):
+ # Complete the drawing
+ self.curr_path = [[] for _ in range(self.K + 1)]
+
+ def predict(self):
+ self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
+ # self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
+ # self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
+ # self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
+ return self.out_prob
+
+class ScribbleInteraction(Interaction):
+ def __init__(self, image, prev_mask, true_size, controller, num_objects):
+ """
+ prev_mask should be in an indexed form
+ """
+ super().__init__(image, prev_mask, true_size, controller)
+
+ self.K = num_objects
+
+ self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8)
+ self.drawn_map.fill(255)
+ # background + k
+ self.curr_path = [[] for _ in range(self.K + 1)]
+ self.size = 3
+
+ """
+ k - object id
+ vis - a tuple (visualization map, pass through alpha). None if not needed.
+ """
+ def push_point(self, x, y, k, vis=None):
+ if vis is not None:
+ vis_map, vis_alpha = vis
+ selected = self.curr_path[k]
+ selected.append((x, y))
+ if len(selected) >= 2:
+ self.drawn_map = cv2.line(self.drawn_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ k, thickness=self.size)
+
+ # Plot visualization
+ if vis is not None:
+ # Visualization for drawing
+ if k == 0:
+ vis_map = cv2.line(vis_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ color_map[k], thickness=self.size)
+ else:
+ vis_map = cv2.line(vis_map,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ color_map[k], thickness=self.size)
+ # Visualization on/off boolean filter
+ vis_alpha = cv2.line(vis_alpha,
+ (int(round(selected[-2][0])), int(round(selected[-2][1]))),
+ (int(round(selected[-1][0])), int(round(selected[-1][1]))),
+ 0.75, thickness=self.size)
+
+ # Optional vis return
+ if vis is not None:
+ return vis_map, vis_alpha
+
+ def end_path(self):
+ # Complete the drawing
+ self.curr_path = [[] for _ in range(self.K + 1)]
+
+ def predict(self):
+ self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map)
+ self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
+ return self.out_prob
+
+
+class ClickInteraction(Interaction):
+ def __init__(self, image, prev_mask, true_size, controller, tar_obj):
+ """
+ prev_mask in a prob. form
+ """
+ super().__init__(image, prev_mask, true_size, controller)
+ self.tar_obj = tar_obj
+
+ # negative/positive for each object
+ self.pos_clicks = []
+ self.neg_clicks = []
+
+ self.out_prob = self.prev_mask.clone()
+
+ """
+ neg - Negative interaction or not
+ vis - a tuple (visualization map, pass through alpha). None if not needed.
+ """
+ def push_point(self, x, y, neg, vis=None):
+ # Clicks
+ if neg:
+ self.neg_clicks.append((x, y))
+ else:
+ self.pos_clicks.append((x, y))
+
+ # Do the prediction
+ self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg)
+
+ # Plot visualization
+ if vis is not None:
+ vis_map, vis_alpha = vis
+ # Visualization for clicks
+ if neg:
+ vis_map = cv2.circle(vis_map,
+ (int(round(x)), int(round(y))),
+ 2, color_map[0], thickness=-1)
+ else:
+ vis_map = cv2.circle(vis_map,
+ (int(round(x)), int(round(y))),
+ 2, color_map[self.tar_obj], thickness=-1)
+
+ vis_alpha = cv2.circle(vis_alpha,
+ (int(round(x)), int(round(y))),
+ 2, 1, thickness=-1)
+
+ # Optional vis return
+ return vis_map, vis_alpha
+
+ def predict(self):
+ self.out_prob = self.prev_mask.clone()
+ # a small hack to allow the interacting object to overwrite existing masks
+ # without remembering all the object probabilities
+ self.out_prob = torch.clamp(self.out_prob, max=0.9)
+ self.out_prob[self.tar_obj] = self.obj_mask
+ self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
+ return self.out_prob
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/interactive_utils.py b/Make-A-Protagonist/experts/XMem/inference/interact/interactive_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9961f63aab4f59323454f34d76241a743190198f
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/interactive_utils.py
@@ -0,0 +1,175 @@
+# Modifed from https://github.com/seoungwugoh/ivs-demo
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from util.palette import davis_palette
+from dataset.range_transform import im_normalization
+
+def image_to_torch(frame: np.ndarray, device='cuda'):
+ # frame: H*W*3 numpy array
+ frame = frame.transpose(2, 0, 1)
+ frame = torch.from_numpy(frame).float().to(device)/255
+ frame_norm = im_normalization(frame)
+ return frame_norm, frame
+
+def torch_prob_to_numpy_mask(prob):
+ mask = torch.argmax(prob, dim=0)
+ mask = mask.cpu().numpy().astype(np.uint8)
+ return mask
+
+def index_numpy_to_one_hot_torch(mask, num_classes):
+ mask = torch.from_numpy(mask).long()
+ return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()
+
+"""
+Some constants fro visualization
+"""
+color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
+# scales for better visualization
+color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
+color_map = color_map_np.tolist()
+if torch.cuda.is_available():
+ color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
+
+grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
+if torch.cuda.is_available():
+ grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
+
+def get_visualization(mode, image, mask, layer, target_object):
+ if mode == 'fade':
+ return overlay_davis(image, mask, fade=True)
+ elif mode == 'davis':
+ return overlay_davis(image, mask)
+ elif mode == 'light':
+ return overlay_davis(image, mask, 0.9)
+ elif mode == 'popup':
+ return overlay_popup(image, mask, target_object)
+ elif mode == 'layered':
+ if layer is None:
+ print('Layer file not given. Defaulting to DAVIS.')
+ return overlay_davis(image, mask)
+ else:
+ return overlay_layer(image, mask, layer, target_object)
+ else:
+ raise NotImplementedError
+
+def get_visualization_torch(mode, image, prob, layer, target_object):
+ if mode == 'fade':
+ return overlay_davis_torch(image, prob, fade=True)
+ elif mode == 'davis':
+ return overlay_davis_torch(image, prob)
+ elif mode == 'light':
+ return overlay_davis_torch(image, prob, 0.9)
+ elif mode == 'popup':
+ return overlay_popup_torch(image, prob, target_object)
+ elif mode == 'layered':
+ if layer is None:
+ print('Layer file not given. Defaulting to DAVIS.')
+ return overlay_davis_torch(image, prob)
+ else:
+ return overlay_layer_torch(image, prob, layer, target_object)
+ else:
+ raise NotImplementedError
+
+def overlay_davis(image, mask, alpha=0.5, fade=False):
+ """ Overlay segmentation on top of RGB image. from davis official"""
+ im_overlay = image.copy()
+
+ colored_mask = color_map_np[mask]
+ foreground = image*alpha + (1-alpha)*colored_mask
+ binary_mask = (mask > 0)
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+ if fade:
+ im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
+ return im_overlay.astype(image.dtype)
+
+def overlay_popup(image, mask, target_object):
+ # Keep foreground colored. Convert background to grayscale.
+ im_overlay = image.copy()
+
+ binary_mask = ~(np.isin(mask, target_object))
+ colored_region = (im_overlay[binary_mask]*grayscale_weights).sum(-1, keepdims=-1)
+ im_overlay[binary_mask] = colored_region
+ return im_overlay.astype(image.dtype)
+
+def overlay_layer(image, mask, layer, target_object):
+ # insert a layer between foreground and background
+ # The CPU version is less accurate because we are using the hard mask
+ # The GPU version has softer edges as it uses soft probabilities
+ obj_mask = (np.isin(mask, target_object)).astype(np.float32)
+ layer_alpha = layer[:, :, 3].astype(np.float32) / 255
+ layer_rgb = layer[:, :, :3]
+ background_alpha = np.maximum(obj_mask, layer_alpha)[:,:,np.newaxis]
+ obj_mask = obj_mask[:,:,np.newaxis]
+ im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 255)
+ return im_overlay.astype(image.dtype)
+
+def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
+ """ Overlay segmentation on top of RGB image. from davis official"""
+ # Changes the image in-place to avoid copying
+ image = image.permute(1, 2, 0)
+ im_overlay = image
+ mask = torch.argmax(mask, dim=0)
+
+ colored_mask = color_map_torch[mask]
+ foreground = image*alpha + (1-alpha)*colored_mask
+ binary_mask = (mask > 0)
+ # Compose image
+ im_overlay[binary_mask] = foreground[binary_mask]
+ if fade:
+ im_overlay[~binary_mask] = im_overlay[~binary_mask] * 0.6
+
+ im_overlay = (im_overlay*255).cpu().numpy()
+ im_overlay = im_overlay.astype(np.uint8)
+
+ return im_overlay
+
+def overlay_popup_torch(image, mask, target_object):
+ # Keep foreground colored. Convert background to grayscale.
+ image = image.permute(1, 2, 0)
+
+ if len(target_object) == 0:
+ obj_mask = torch.zeros_like(mask[0]).unsqueeze(2)
+ else:
+ # I should not need to convert this to numpy.
+ # uUsing list works most of the time but consistently fails
+ # if I include first object -> exclude it -> include it again.
+ # I check everywhere and it makes absolutely no sense.
+ # I am blaming this on PyTorch and calling it a day
+ obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0).unsqueeze(2)
+ gray_image = (image*grayscale_weights_torch).sum(-1, keepdim=True)
+ im_overlay = obj_mask*image + (1-obj_mask)*gray_image
+
+ im_overlay = (im_overlay*255).cpu().numpy()
+ im_overlay = im_overlay.astype(np.uint8)
+
+ return im_overlay
+
+def overlay_layer_torch(image, mask, layer, target_object):
+ # insert a layer between foreground and background
+ # The CPU version is less accurate because we are using the hard mask
+ # The GPU version has softer edges as it uses soft probabilities
+ image = image.permute(1, 2, 0)
+
+ if len(target_object) == 0:
+ obj_mask = torch.zeros_like(mask[0])
+ else:
+ # I should not need to convert this to numpy.
+ # uUsing list works most of the time but consistently fails
+ # if I include first object -> exclude it -> include it again.
+ # I check everywhere and it makes absolutely no sense.
+ # I am blaming this on PyTorch and calling it a day
+ obj_mask = mask[np.array(target_object,dtype=np.int32)].sum(0)
+ layer_alpha = layer[:, :, 3]
+ layer_rgb = layer[:, :, :3]
+ background_alpha = torch.maximum(obj_mask, layer_alpha).unsqueeze(2)
+ obj_mask = obj_mask.unsqueeze(2)
+ im_overlay = (image*(1-background_alpha) + layer_rgb*(1-obj_mask) + image*obj_mask).clip(0, 1)
+
+ im_overlay = (im_overlay*255).cpu().numpy()
+ im_overlay = im_overlay.astype(np.uint8)
+
+ return im_overlay
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/resource_manager.py b/Make-A-Protagonist/experts/XMem/inference/interact/resource_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f28af2e35a3ea29958e5eee4e19b26f1fa010b
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/resource_manager.py
@@ -0,0 +1,206 @@
+import os
+from os import path
+import shutil
+import collections
+
+import cv2
+from PIL import Image
+if not hasattr(Image, 'Resampling'): # Pillow<9.0
+ Image.Resampling = Image
+import numpy as np
+
+from util.palette import davis_palette
+import progressbar
+
+
+# https://bugs.python.org/issue28178
+# ah python ah why
+class LRU:
+ def __init__(self, func, maxsize=128):
+ self.cache = collections.OrderedDict()
+ self.func = func
+ self.maxsize = maxsize
+
+ def __call__(self, *args):
+ cache = self.cache
+ if args in cache:
+ cache.move_to_end(args)
+ return cache[args]
+ result = self.func(*args)
+ cache[args] = result
+ if len(cache) > self.maxsize:
+ cache.popitem(last=False)
+ return result
+
+ def invalidate(self, key):
+ self.cache.pop(key, None)
+
+
+class ResourceManager:
+ def __init__(self, config):
+ # determine inputs
+ images = config['images']
+ video = config['video']
+ self.workspace = config['workspace']
+ self.size = config['size']
+ self.palette = davis_palette
+
+ # create temporary workspace if not specified
+ if self.workspace is None:
+ if images is not None:
+ basename = path.basename(images)
+ elif video is not None:
+ basename = path.basename(video)[:-4]
+ else:
+ raise NotImplementedError(
+ 'Either images, video, or workspace has to be specified')
+
+ self.workspace = path.join('./workspace', basename)
+
+ print(f'Workspace is in: {self.workspace}')
+
+ # determine the location of input images
+ need_decoding = False
+ need_resizing = False
+ if path.exists(path.join(self.workspace, 'images')):
+ pass
+ elif images is not None:
+ need_resizing = True
+ elif video is not None:
+ # will decode video into frames later
+ need_decoding = True
+
+ # create workspace subdirectories
+ self.image_dir = path.join(self.workspace, 'images')
+ self.mask_dir = path.join(self.workspace, 'masks')
+ os.makedirs(self.image_dir, exist_ok=True)
+ os.makedirs(self.mask_dir, exist_ok=True)
+
+ # convert read functions to be buffered
+ self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
+ self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])
+
+ # extract frames from video
+ if need_decoding:
+ self._extract_frames(video)
+
+ # copy/resize existing images to the workspace
+ if need_resizing:
+ self._copy_resize_frames(images)
+
+ # read all frame names
+ self.names = sorted(os.listdir(self.image_dir))
+ self.names = [f[:-4] for f in self.names] # remove extensions
+ self.length = len(self.names)
+
+ assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'
+
+ print(f'{self.length} images found.')
+
+ self.height, self.width = self.get_image(0).shape[:2]
+ self.visualization_init = False
+
+ def _extract_frames(self, video):
+ cap = cv2.VideoCapture(video)
+ frame_index = 0
+ print(f'Extracting frames from {video} into {self.image_dir}...')
+ bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
+ while(cap.isOpened()):
+ _, frame = cap.read()
+ if frame is None:
+ break
+ if self.size > 0:
+ h, w = frame.shape[:2]
+ new_w = (w*self.size//min(w, h))
+ new_h = (h*self.size//min(w, h))
+ if new_w != w or new_h != h:
+ frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
+ cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame)
+ frame_index += 1
+ bar.update(frame_index)
+ bar.finish()
+ print('Done!')
+
+ def _copy_resize_frames(self, images):
+ image_list = os.listdir(images)
+ print(f'Copying/resizing frames into {self.image_dir}...')
+ for image_name in progressbar.progressbar(image_list):
+ if self.size < 0:
+ # just copy
+ shutil.copy2(path.join(images, image_name), self.image_dir)
+ else:
+ frame = cv2.imread(path.join(images, image_name))
+ h, w = frame.shape[:2]
+ new_w = (w*self.size//min(w, h))
+ new_h = (h*self.size//min(w, h))
+ if new_w != w or new_h != h:
+ frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA)
+ cv2.imwrite(path.join(self.image_dir, image_name), frame)
+ print('Done!')
+
+ def save_mask(self, ti, mask):
+ # mask should be uint8 H*W without channels
+ assert 0 <= ti < self.length
+ assert isinstance(mask, np.ndarray)
+
+ mask = Image.fromarray(mask)
+ mask.putpalette(self.palette)
+ mask.save(path.join(self.mask_dir, self.names[ti]+'.png'))
+ self.invalidate(ti)
+
+ def save_visualization(self, ti, image):
+ # image should be uint8 3*H*W
+ assert 0 <= ti < self.length
+ assert isinstance(image, np.ndarray)
+ if not self.visualization_init:
+ self.visualization_dir = path.join(self.workspace, 'visualization')
+ os.makedirs(self.visualization_dir, exist_ok=True)
+ self.visualization_init = True
+
+ image = Image.fromarray(image)
+ image.save(path.join(self.visualization_dir, self.names[ti]+'.jpg'))
+
+ def _get_image_unbuffered(self, ti):
+ # returns H*W*3 uint8 array
+ assert 0 <= ti < self.length
+
+ image = Image.open(path.join(self.image_dir, self.names[ti]+'.jpg'))
+ image = np.array(image)
+ return image
+
+ def _get_mask_unbuffered(self, ti):
+ # returns H*W uint8 array
+ assert 0 <= ti < self.length
+
+ mask_path = path.join(self.mask_dir, self.names[ti]+'.png')
+ if path.exists(mask_path):
+ mask = Image.open(mask_path)
+ mask = np.array(mask)
+ return mask
+ else:
+ return None
+
+ def read_external_image(self, file_name, size=None):
+ image = Image.open(file_name)
+ is_mask = image.mode in ['L', 'P']
+ if size is not None:
+ # PIL uses (width, height)
+ image = image.resize((size[1], size[0]),
+ resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC)
+ image = np.array(image)
+ return image
+
+ def invalidate(self, ti):
+ # the image buffer is never invalidated
+ self.get_mask.invalidate((ti,))
+
+ def __len__(self):
+ return self.length
+
+ @property
+ def h(self):
+ return self.height
+
+ @property
+ def w(self):
+ return self.width
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m/__init__.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m/_deeplab.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/_deeplab.py
new file mode 100644
index 0000000000000000000000000000000000000000..e663007dde9a56add1aa540be76cf2f5d81de82f
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/_deeplab.py
@@ -0,0 +1,180 @@
+# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .utils import _SimpleSegmentationModel
+
+
+__all__ = ["DeepLabV3"]
+
+
+class DeepLabV3(_SimpleSegmentationModel):
+ """
+ Implements DeepLabV3 model from
+ `"Rethinking Atrous Convolution for Semantic Image Segmentation"
+ `_.
+
+ Arguments:
+ backbone (nn.Module): the network used to compute the features for the model.
+ The backbone should return an OrderedDict[Tensor], with the key being
+ "out" for the last feature map used, and "aux" if an auxiliary classifier
+ is used.
+ classifier (nn.Module): module that takes the "out" element returned from
+ the backbone and returns a dense prediction.
+ aux_classifier (nn.Module, optional): auxiliary classifier used during training
+ """
+ pass
+
+class DeepLabHeadV3Plus(nn.Module):
+ def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
+ super(DeepLabHeadV3Plus, self).__init__()
+ self.project = nn.Sequential(
+ nn.Conv2d(low_level_channels, 48, 1, bias=False),
+ nn.BatchNorm2d(48),
+ nn.ReLU(inplace=True),
+ )
+
+ self.aspp = ASPP(in_channels, aspp_dilate)
+
+ self.classifier = nn.Sequential(
+ nn.Conv2d(304, 256, 3, padding=1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, num_classes, 1)
+ )
+ self._init_weight()
+
+ def forward(self, feature):
+ low_level_feature = self.project( feature['low_level'] )
+ output_feature = self.aspp(feature['out'])
+ output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
+ return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
+
+ def _init_weight(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+class DeepLabHead(nn.Module):
+ def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
+ super(DeepLabHead, self).__init__()
+
+ self.classifier = nn.Sequential(
+ ASPP(in_channels, aspp_dilate),
+ nn.Conv2d(256, 256, 3, padding=1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, num_classes, 1)
+ )
+ self._init_weight()
+
+ def forward(self, feature):
+ return self.classifier( feature['out'] )
+
+ def _init_weight(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+class AtrousSeparableConvolution(nn.Module):
+ """ Atrous Separable Convolution
+ """
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, bias=True):
+ super(AtrousSeparableConvolution, self).__init__()
+ self.body = nn.Sequential(
+ # Separable Conv
+ nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
+ # PointWise Conv
+ nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
+ )
+
+ self._init_weight()
+
+ def forward(self, x):
+ return self.body(x)
+
+ def _init_weight(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+class ASPPConv(nn.Sequential):
+ def __init__(self, in_channels, out_channels, dilation):
+ modules = [
+ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True)
+ ]
+ super(ASPPConv, self).__init__(*modules)
+
+class ASPPPooling(nn.Sequential):
+ def __init__(self, in_channels, out_channels):
+ super(ASPPPooling, self).__init__(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True))
+
+ def forward(self, x):
+ size = x.shape[-2:]
+ x = super(ASPPPooling, self).forward(x)
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
+
+class ASPP(nn.Module):
+ def __init__(self, in_channels, atrous_rates):
+ super(ASPP, self).__init__()
+ out_channels = 256
+ modules = []
+ modules.append(nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True)))
+
+ rate1, rate2, rate3 = tuple(atrous_rates)
+ modules.append(ASPPConv(in_channels, out_channels, rate1))
+ modules.append(ASPPConv(in_channels, out_channels, rate2))
+ modules.append(ASPPConv(in_channels, out_channels, rate3))
+ modules.append(ASPPPooling(in_channels, out_channels))
+
+ self.convs = nn.ModuleList(modules)
+
+ self.project = nn.Sequential(
+ nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Dropout(0.1),)
+
+ def forward(self, x):
+ res = []
+ for conv in self.convs:
+ res.append(conv(x))
+ res = torch.cat(res, dim=1)
+ return self.project(res)
+
+
+
+def convert_to_separable_conv(module):
+ new_module = module
+ if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
+ new_module = AtrousSeparableConvolution(module.in_channels,
+ module.out_channels,
+ module.kernel_size,
+ module.stride,
+ module.padding,
+ module.dilation,
+ module.bias)
+ for name, child in module.named_children():
+ new_module.add_module(name, convert_to_separable_conv(child))
+ return new_module
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_network.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4f9a3fc4fcc9cc4210485fe24e4d740464d3f8a
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_network.py
@@ -0,0 +1,65 @@
+# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
+
+from .utils import IntermediateLayerGetter
+from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
+from . import s2m_resnet
+
+def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
+
+ if output_stride==8:
+ replace_stride_with_dilation=[False, True, True]
+ aspp_dilate = [12, 24, 36]
+ else:
+ replace_stride_with_dilation=[False, False, True]
+ aspp_dilate = [6, 12, 18]
+
+ backbone = s2m_resnet.__dict__[backbone_name](
+ pretrained=pretrained_backbone,
+ replace_stride_with_dilation=replace_stride_with_dilation)
+
+ inplanes = 2048
+ low_level_planes = 256
+
+ if name=='deeplabv3plus':
+ return_layers = {'layer4': 'out', 'layer1': 'low_level'}
+ classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
+ elif name=='deeplabv3':
+ return_layers = {'layer4': 'out'}
+ classifier = DeepLabHead(inplanes , num_classes, aspp_dilate)
+ backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+ model = DeepLabV3(backbone, classifier)
+ return model
+
+def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
+
+ if backbone.startswith('resnet'):
+ model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
+ else:
+ raise NotImplementedError
+ return model
+
+
+# Deeplab v3
+def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
+ """Constructs a DeepLabV3 model with a ResNet-50 backbone.
+
+ Args:
+ num_classes (int): number of classes.
+ output_stride (int): output stride for deeplab.
+ pretrained_backbone (bool): If True, use the pretrained backbone.
+ """
+ return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
+
+
+# Deeplab v3+
+def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False):
+ """Constructs a DeepLabV3 model with a ResNet-50 backbone.
+
+ Args:
+ num_classes (int): number of classes.
+ output_stride (int): output stride for deeplab.
+ pretrained_backbone (bool): If True, use the pretrained backbone.
+ """
+ return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
+
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_resnet.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f1ce042c69daa9b18172a0aadf9bc1de6f300e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/s2m_resnet.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ModuleNotFoundError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+__all__ = ['ResNet', 'resnet50']
+
+
+model_urls = {
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(6, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m/utils.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2adecf63baa9c2db4cc70b04c25200f6bc0a6a6
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/s2m/utils.py
@@ -0,0 +1,78 @@
+# Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from collections import OrderedDict
+
+class _SimpleSegmentationModel(nn.Module):
+ def __init__(self, backbone, classifier):
+ super(_SimpleSegmentationModel, self).__init__()
+ self.backbone = backbone
+ self.classifier = classifier
+
+ def forward(self, x):
+ input_shape = x.shape[-2:]
+ features = self.backbone(x)
+ x = self.classifier(features)
+ x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
+ return x
+
+
+class IntermediateLayerGetter(nn.ModuleDict):
+ """
+ Module wrapper that returns intermediate layers from a model
+
+ It has a strong assumption that the modules have been registered
+ into the model in the same order as they are used.
+ This means that one should **not** reuse the same nn.Module
+ twice in the forward if you want this to work.
+
+ Additionally, it is only able to query submodules that are directly
+ assigned to the model. So if `model` is passed, `model.feature1` can
+ be returned, but not `model.feature1.layer2`.
+
+ Arguments:
+ model (nn.Module): model on which we will extract the features
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+
+ Examples::
+
+ >>> m = torchvision.models.resnet18(pretrained=True)
+ >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
+ >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
+ >>> {'layer1': 'feat1', 'layer3': 'feat2'})
+ >>> out = new_m(torch.rand(1, 3, 224, 224))
+ >>> print([(k, v.shape) for k, v in out.items()])
+ >>> [('feat1', torch.Size([1, 64, 56, 56])),
+ >>> ('feat2', torch.Size([1, 256, 14, 14]))]
+ """
+ def __init__(self, model, return_layers):
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+ raise ValueError("return_layers are not present in model")
+
+ orig_return_layers = return_layers
+ return_layers = {k: v for k, v in return_layers.items()}
+ layers = OrderedDict()
+ for name, module in model.named_children():
+ layers[name] = module
+ if name in return_layers:
+ del return_layers[name]
+ if not return_layers:
+ break
+
+ super(IntermediateLayerGetter, self).__init__(layers)
+ self.return_layers = orig_return_layers
+
+ def forward(self, x):
+ out = OrderedDict()
+ for name, module in self.named_children():
+ x = module(x)
+ if name in self.return_layers:
+ out_name = self.return_layers[name]
+ out[out_name] = x
+ return out
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/s2m_controller.py b/Make-A-Protagonist/experts/XMem/inference/interact/s2m_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..e222259eebdf3938290c476f85ba8c8d79fb626d
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/s2m_controller.py
@@ -0,0 +1,39 @@
+import torch
+import numpy as np
+from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M
+
+from util.tensor_util import pad_divide_by, unpad
+
+
+class S2MController:
+ """
+ A controller for Scribble-to-Mask (for user interaction, not for DAVIS)
+ Takes the image, previous mask, and scribbles to produce a new mask
+ ignore_class is usually 255
+ 0 is NOT the ignore class -- it is the label for the background
+ """
+ def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'):
+ self.s2m_net = s2m_net
+ self.num_objects = num_objects
+ self.ignore_class = ignore_class
+ self.device = device
+
+ def interact(self, image, prev_mask, scr_mask):
+ image = image.to(self.device, non_blocking=True)
+ prev_mask = prev_mask.unsqueeze(0)
+
+ h, w = image.shape[-2:]
+ unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device)
+
+ for ki in range(1, self.num_objects+1):
+ p_srb = (scr_mask==ki).astype(np.uint8)
+ n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8)
+
+ Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device)
+
+ inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1)
+ inputs, pads = pad_divide_by(inputs, 16)
+
+ unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads)
+
+ return unaggre_mask
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/interact/timer.py b/Make-A-Protagonist/experts/XMem/inference/interact/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d134aa180275528c0d485e6d237cd6832f62d77e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/interact/timer.py
@@ -0,0 +1,33 @@
+import time
+
+class Timer:
+ def __init__(self):
+ self._acc_time = 0
+ self._paused = True
+
+ def start(self):
+ if self._paused:
+ self.last_time = time.time()
+ self._paused = False
+ return self
+
+ def pause(self):
+ self.count()
+ self._paused = True
+ return self
+
+ def count(self):
+ if self._paused:
+ return self._acc_time
+ t = time.time()
+ self._acc_time += t - self.last_time
+ self.last_time = t
+ return self._acc_time
+
+ def format(self):
+ # count = int(self.count()*100)
+ # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100)
+ return '%03.2f' % self.count()
+
+ def __str__(self):
+ return self.format()
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/inference/kv_memory_store.py b/Make-A-Protagonist/experts/XMem/inference/kv_memory_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..33a332625f03b39f38f4b7162dcaddc8bafa262e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/kv_memory_store.py
@@ -0,0 +1,215 @@
+import torch
+from typing import List
+
+class KeyValueMemoryStore:
+ """
+ Works for key/value pairs type storage
+ e.g., working and long-term memory
+ """
+
+ """
+ An object group is created when new objects enter the video
+ Objects in the same group share the same temporal extent
+ i.e., objects initialized in the same frame are in the same group
+ For DAVIS/interactive, there is only one object group
+ For YouTubeVOS, there can be multiple object groups
+ """
+
+ def __init__(self, count_usage: bool):
+ self.count_usage = count_usage
+
+ # keys are stored in a single tensor and are shared between groups/objects
+ # values are stored as a list indexed by object groups
+ self.k = None
+ self.v = []
+ self.obj_groups = []
+ # for debugging only
+ self.all_objects = []
+
+ # shrinkage and selection are also single tensors
+ self.s = self.e = None
+
+ # usage
+ if self.count_usage:
+ self.use_count = self.life_count = None
+
+ def add(self, key, value, shrinkage, selection, objects: List[int]):
+ new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
+ new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7
+
+ # add the key
+ if self.k is None:
+ self.k = key
+ self.s = shrinkage
+ self.e = selection
+ if self.count_usage:
+ self.use_count = new_count
+ self.life_count = new_life
+ else:
+ self.k = torch.cat([self.k, key], -1)
+ if shrinkage is not None:
+ self.s = torch.cat([self.s, shrinkage], -1)
+ if selection is not None:
+ self.e = torch.cat([self.e, selection], -1)
+ if self.count_usage:
+ self.use_count = torch.cat([self.use_count, new_count], -1)
+ self.life_count = torch.cat([self.life_count, new_life], -1)
+
+ # add the value
+ if objects is not None:
+ # When objects is given, v is a tensor; used in working memory
+ assert isinstance(value, torch.Tensor)
+ # First consume objects that are already in the memory bank
+ # cannot use set here because we need to preserve order
+ # shift by one as background is not part of value
+ remaining_objects = [obj-1 for obj in objects]
+ for gi, group in enumerate(self.obj_groups):
+ for obj in group:
+ # should properly raise an error if there are overlaps in obj_groups
+ remaining_objects.remove(obj)
+ self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
+
+ # If there are remaining objects, add them as a new group
+ if len(remaining_objects) > 0:
+ new_group = list(remaining_objects)
+ self.v.append(value[new_group])
+ self.obj_groups.append(new_group)
+ self.all_objects.extend(new_group)
+
+ assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
+ else:
+ # When objects is not given, v is a list that already has the object groups sorted
+ # used in long-term memory
+ assert isinstance(value, list)
+ for gi, gv in enumerate(value):
+ if gv is None:
+ continue
+ if gi < self.num_groups:
+ self.v[gi] = torch.cat([self.v[gi], gv], -1)
+ else:
+ self.v.append(gv)
+
+ def update_usage(self, usage):
+ # increase all life count by 1
+ # increase use of indexed elements
+ if not self.count_usage:
+ return
+
+ self.use_count += usage.view_as(self.use_count)
+ self.life_count += 1
+
+ def sieve_by_range(self, start: int, end: int, min_size: int):
+ # keep only the elements *outside* of this range (with some boundary conditions)
+ # i.e., concat (a[:start], a[end:])
+ # min_size is only used for values, we do not sieve values under this size
+ # (because they are not consolidated)
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ self.k = self.k[:,:,:start]
+ if self.count_usage:
+ self.use_count = self.use_count[:,:,:start]
+ self.life_count = self.life_count[:,:,:start]
+ if self.s is not None:
+ self.s = self.s[:,:,:start]
+ if self.e is not None:
+ self.e = self.e[:,:,:start]
+
+ for gi in range(self.num_groups):
+ if self.v[gi].shape[-1] >= min_size:
+ self.v[gi] = self.v[gi][:,:,:start]
+ else:
+ self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
+ if self.count_usage:
+ self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
+ self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
+ if self.s is not None:
+ self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
+ if self.e is not None:
+ self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
+
+ for gi in range(self.num_groups):
+ if self.v[gi].shape[-1] >= min_size:
+ self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
+
+ def remove_obsolete_features(self, max_size: int):
+ # normalize with life duration
+ usage = self.get_usage().flatten()
+
+ values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True)
+ survived = (usage > values[-1])
+
+ self.k = self.k[:, :, survived]
+ self.s = self.s[:, :, survived] if self.s is not None else None
+ # Long-term memory does not store ek so this should not be needed
+ self.e = self.e[:, :, survived] if self.e is not None else None
+ if self.num_groups > 1:
+ raise NotImplementedError("""The current data structure does not support feature removal with
+ multiple object groups (e.g., some objects start to appear later in the video)
+ The indices for "survived" is based on keys but not all values are present for every key
+ Basically we need to remap the indices for keys to values
+ """)
+ for gi in range(self.num_groups):
+ self.v[gi] = self.v[gi][:, :, survived]
+
+ self.use_count = self.use_count[:, :, survived]
+ self.life_count = self.life_count[:, :, survived]
+
+ def get_usage(self):
+ # return normalized usage
+ if not self.count_usage:
+ raise RuntimeError('I did not count usage!')
+ else:
+ usage = self.use_count / self.life_count
+ return usage
+
+ def get_all_sliced(self, start: int, end: int):
+ # return k, sk, ek, usage in order, sliced by start and end
+
+ if end == 0:
+ # negative 0 would not work as the end index!
+ k = self.k[:,:,start:]
+ sk = self.s[:,:,start:] if self.s is not None else None
+ ek = self.e[:,:,start:] if self.e is not None else None
+ usage = self.get_usage()[:,:,start:]
+ else:
+ k = self.k[:,:,start:end]
+ sk = self.s[:,:,start:end] if self.s is not None else None
+ ek = self.e[:,:,start:end] if self.e is not None else None
+ usage = self.get_usage()[:,:,start:end]
+
+ return k, sk, ek, usage
+
+ def get_v_size(self, ni: int):
+ return self.v[ni].shape[2]
+
+ def engaged(self):
+ return self.k is not None
+
+ @property
+ def size(self):
+ if self.k is None:
+ return 0
+ else:
+ return self.k.shape[-1]
+
+ @property
+ def num_groups(self):
+ return len(self.v)
+
+ @property
+ def key(self):
+ return self.k
+
+ @property
+ def value(self):
+ return self.v
+
+ @property
+ def shrinkage(self):
+ return self.s
+
+ @property
+ def selection(self):
+ return self.e
+
diff --git a/Make-A-Protagonist/experts/XMem/inference/memory_manager.py b/Make-A-Protagonist/experts/XMem/inference/memory_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..abae24349cb61b4e7e07588375a409254302ab08
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/inference/memory_manager.py
@@ -0,0 +1,284 @@
+import torch
+import warnings
+
+from XMem.inference.kv_memory_store import KeyValueMemoryStore
+from XMem.model.memory_util import *
+
+
+class MemoryManager:
+ """
+ Manages all three memory stores and the transition between working/long-term memory
+ """
+ def __init__(self, config):
+ self.hidden_dim = config['hidden_dim']
+ self.top_k = config['top_k']
+
+ self.enable_long_term = config['enable_long_term']
+ self.enable_long_term_usage = config['enable_long_term_count_usage']
+ if self.enable_long_term:
+ self.max_mt_frames = config['max_mid_term_frames']
+ self.min_mt_frames = config['min_mid_term_frames']
+ self.num_prototypes = config['num_prototypes']
+ self.max_long_elements = config['max_long_term_elements']
+
+ # dimensions will be inferred from input later
+ self.CK = self.CV = None
+ self.H = self.W = None
+
+ # The hidden state will be stored in a single tensor for all objects
+ # B x num_objects x CH x H x W
+ self.hidden = None
+
+ self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
+ if self.enable_long_term:
+ self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
+
+ self.reset_config = True
+
+ def update_config(self, config):
+ self.reset_config = True
+ self.hidden_dim = config['hidden_dim']
+ self.top_k = config['top_k']
+
+ assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
+ assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'
+
+ self.enable_long_term_usage = config['enable_long_term_count_usage']
+ if self.enable_long_term:
+ self.max_mt_frames = config['max_mid_term_frames']
+ self.min_mt_frames = config['min_mid_term_frames']
+ self.num_prototypes = config['num_prototypes']
+ self.max_long_elements = config['max_long_term_elements']
+
+ def _readout(self, affinity, v):
+ # this function is for a single object group
+ return v @ affinity
+
+ def match_memory(self, query_key, selection):
+ # query_key: B x C^k x H x W
+ # selection: B x C^k x H x W
+ num_groups = self.work_mem.num_groups
+ h, w = query_key.shape[-2:]
+
+ query_key = query_key.flatten(start_dim=2)
+ selection = selection.flatten(start_dim=2) if selection is not None else None
+
+ """
+ Memory readout using keys
+ """
+
+ if self.enable_long_term and self.long_mem.engaged():
+ # Use long-term memory
+ long_mem_size = self.long_mem.size
+ memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
+ shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)
+
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
+ work_mem_similarity = similarity[:, long_mem_size:]
+ long_mem_similarity = similarity[:, :long_mem_size]
+
+ # get the usage with the first group
+ # the first group always have all the keys valid
+ affinity, usage = do_softmax(
+ torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
+ top_k=self.top_k, inplace=True, return_usage=True)
+ affinity = [affinity]
+
+ # compute affinity group by group as later groups only have a subset of keys
+ for gi in range(1, num_groups):
+ if gi < self.long_mem.num_groups:
+ # merge working and lt similarities before softmax
+ affinity_one_group = do_softmax(
+ torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
+ work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
+ top_k=self.top_k, inplace=True)
+ else:
+ # no long-term memory for this group
+ affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
+ top_k=self.top_k, inplace=(gi==num_groups-1))
+ affinity.append(affinity_one_group)
+
+ all_memory_value = []
+ for gi, gv in enumerate(self.work_mem.value):
+ # merge the working and lt values before readout
+ if gi < self.long_mem.num_groups:
+ all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
+ else:
+ all_memory_value.append(gv)
+
+ """
+ Record memory usage for working and long-term memory
+ """
+ # ignore the index return for long-term memory
+ work_usage = usage[:, long_mem_size:]
+ self.work_mem.update_usage(work_usage.flatten())
+
+ if self.enable_long_term_usage:
+ # ignore the index return for working memory
+ long_usage = usage[:, :long_mem_size]
+ self.long_mem.update_usage(long_usage.flatten())
+ else:
+ # No long-term memory
+ similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)
+
+ if self.enable_long_term:
+ affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
+ top_k=self.top_k, return_usage=True)
+
+ # Record memory usage for working memory
+ self.work_mem.update_usage(usage.flatten())
+ else:
+ affinity = do_softmax(similarity, inplace=(num_groups==1),
+ top_k=self.top_k, return_usage=False)
+
+ affinity = [affinity]
+
+ # compute affinity group by group as later groups only have a subset of keys
+ for gi in range(1, num_groups):
+ affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
+ top_k=self.top_k, inplace=(gi==num_groups-1))
+ affinity.append(affinity_one_group)
+
+ all_memory_value = self.work_mem.value
+
+ # Shared affinity within each group
+ all_readout_mem = torch.cat([
+ self._readout(affinity[gi], gv)
+ for gi, gv in enumerate(all_memory_value)
+ ], 0)
+
+ return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
+
+ def add_memory(self, key, shrinkage, value, objects, selection=None):
+ # key: 1*C*H*W
+ # value: 1*num_objects*C*H*W
+ # objects contain a list of object indices
+ if self.H is None or self.reset_config:
+ self.reset_config = False
+ self.H, self.W = key.shape[-2:]
+ self.HW = self.H*self.W
+ if self.enable_long_term:
+ # convert from num. frames to num. nodes
+ self.min_work_elements = self.min_mt_frames*self.HW
+ self.max_work_elements = self.max_mt_frames*self.HW
+
+ # key: 1*C*N
+ # value: num_objects*C*N
+ key = key.flatten(start_dim=2)
+ shrinkage = shrinkage.flatten(start_dim=2)
+ value = value[0].flatten(start_dim=2)
+
+ self.CK = key.shape[1]
+ self.CV = value.shape[1]
+
+ if selection is not None:
+ if not self.enable_long_term:
+ warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
+ selection = selection.flatten(start_dim=2)
+
+ self.work_mem.add(key, value, shrinkage, selection, objects)
+
+ # long-term memory cleanup
+ if self.enable_long_term:
+ # Do memory compressed if needed
+ if self.work_mem.size >= self.max_work_elements:
+ # Remove obsolete features if needed
+ if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
+ self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
+
+ self.compress_features()
+
+
+ def create_hidden_state(self, n, sample_key):
+ # n is the TOTAL number of objects
+ h, w = sample_key.shape[-2:]
+ if self.hidden is None:
+ self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
+ elif self.hidden.shape[1] != n:
+ self.hidden = torch.cat([
+ self.hidden,
+ torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
+ ], 1)
+
+ assert(self.hidden.shape[1] == n)
+
+ def set_hidden(self, hidden):
+ self.hidden = hidden
+
+ def get_hidden(self):
+ return self.hidden
+
+ def compress_features(self):
+ HW = self.HW
+ candidate_value = []
+ total_work_mem_size = self.work_mem.size
+ for gv in self.work_mem.value:
+ # Some object groups might be added later in the video
+ # So not all keys have values associated with all objects
+ # We need to keep track of the key->value validity
+ mem_size_in_this_group = gv.shape[-1]
+ if mem_size_in_this_group == total_work_mem_size:
+ # full LT
+ candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+ else:
+ # mem_size is smaller than total_work_mem_size, but at least HW
+ assert HW <= mem_size_in_this_group < total_work_mem_size
+ if mem_size_in_this_group > self.min_work_elements+HW:
+ # part of this object group still goes into LT
+ candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
+ else:
+ # this object group cannot go to the LT at all
+ candidate_value.append(None)
+
+ # perform memory consolidation
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
+ *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
+
+ # remove consolidated working memory
+ self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
+
+ # add to long-term memory
+ self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
+
+ def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
+ # keys: 1*C*N
+ # values: num_objects*C*N
+ N = candidate_key.shape[-1]
+
+ # find the indices with max usage
+ _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
+ prototype_indices = max_usage_indices.flatten()
+
+ # Prototypes are invalid for out-of-bound groups
+ validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]
+
+ prototype_key = candidate_key[:, :, prototype_indices]
+ prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None
+
+ """
+ Potentiation step
+ """
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)
+
+ # convert similarity to affinity
+ # need to do it group by group since the softmax normalization would be different
+ affinity = [
+ do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
+ for gi, gv in enumerate(candidate_value)
+ ]
+
+ # some values can be have all False validity. Weed them out.
+ affinity = [
+ aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
+ ]
+
+ # readout the values
+ prototype_value = [
+ self._readout(affinity[gi], gv) if affinity[gi] is not None else None
+ for gi, gv in enumerate(candidate_value)
+ ]
+
+ # readout the shrinkage term
+ prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
+
+ return prototype_key, prototype_value, prototype_shrinkage
diff --git a/Make-A-Protagonist/experts/XMem/model/__init__.py b/Make-A-Protagonist/experts/XMem/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/model/aggregate.py b/Make-A-Protagonist/experts/XMem/model/aggregate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7622391fb3ac9aa8b515df88cf3ea5297b367538
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/aggregate.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+
+# Soft aggregation from STM
+def aggregate(prob, dim, return_logits=False):
+ new_prob = torch.cat([
+ torch.prod(1-prob, dim=dim, keepdim=True),
+ prob
+ ], dim).clamp(1e-7, 1-1e-7)
+ logits = torch.log((new_prob /(1-new_prob)))
+ prob = F.softmax(logits, dim=dim)
+
+ if return_logits:
+ return logits, prob
+ else:
+ return prob
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/model/cbam.py b/Make-A-Protagonist/experts/XMem/model/cbam.py
new file mode 100644
index 0000000000000000000000000000000000000000..6423358429e2843b1f36ceb2bc1a485ea72b8eb4
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/cbam.py
@@ -0,0 +1,77 @@
+# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicConv(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+ super(BasicConv, self).__init__()
+ self.out_channels = out_planes
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+class ChannelGate(nn.Module):
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
+ super(ChannelGate, self).__init__()
+ self.gate_channels = gate_channels
+ self.mlp = nn.Sequential(
+ Flatten(),
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
+ nn.ReLU(),
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
+ )
+ self.pool_types = pool_types
+ def forward(self, x):
+ channel_att_sum = None
+ for pool_type in self.pool_types:
+ if pool_type=='avg':
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+ channel_att_raw = self.mlp( avg_pool )
+ elif pool_type=='max':
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+ channel_att_raw = self.mlp( max_pool )
+
+ if channel_att_sum is None:
+ channel_att_sum = channel_att_raw
+ else:
+ channel_att_sum = channel_att_sum + channel_att_raw
+
+ scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
+ return x * scale
+
+class ChannelPool(nn.Module):
+ def forward(self, x):
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
+
+class SpatialGate(nn.Module):
+ def __init__(self):
+ super(SpatialGate, self).__init__()
+ kernel_size = 7
+ self.compress = ChannelPool()
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
+ def forward(self, x):
+ x_compress = self.compress(x)
+ x_out = self.spatial(x_compress)
+ scale = torch.sigmoid(x_out) # broadcasting
+ return x * scale
+
+class CBAM(nn.Module):
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
+ super(CBAM, self).__init__()
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
+ self.no_spatial=no_spatial
+ if not no_spatial:
+ self.SpatialGate = SpatialGate()
+ def forward(self, x):
+ x_out = self.ChannelGate(x)
+ if not self.no_spatial:
+ x_out = self.SpatialGate(x_out)
+ return x_out
diff --git a/Make-A-Protagonist/experts/XMem/model/group_modules.py b/Make-A-Protagonist/experts/XMem/model/group_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..749ef2386a992a468b7cf631293ebd22036b2777
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/group_modules.py
@@ -0,0 +1,82 @@
+"""
+Group-specific modules
+They handle features that also depends on the mask.
+Features are typically of shape
+ batch_size * num_objects * num_channels * H * W
+
+All of them are permutation equivariant w.r.t. to the num_objects dimension
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def interpolate_groups(g, ratio, mode, align_corners):
+ batch_size, num_objects = g.shape[:2]
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
+ scale_factor=ratio, mode=mode, align_corners=align_corners)
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+ return g
+
+def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
+ return interpolate_groups(g, ratio, mode, align_corners)
+
+
+class GConv2D(nn.Conv2d):
+ def forward(self, g):
+ batch_size, num_objects = g.shape[:2]
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
+ return g.view(batch_size, num_objects, *g.shape[1:])
+
+
+class GroupResBlock(nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ if in_dim == out_dim:
+ self.downsample = None
+ else:
+ self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+
+ self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
+ self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
+
+ def forward(self, g):
+ out_g = self.conv1(F.relu(g))
+ out_g = self.conv2(F.relu(out_g))
+
+ if self.downsample is not None:
+ g = self.downsample(g)
+
+ return out_g + g
+
+
+class MainToGroupDistributor(nn.Module):
+ def __init__(self, x_transform=None, method='cat', reverse_order=False):
+ super().__init__()
+
+ self.x_transform = x_transform
+ self.method = method
+ self.reverse_order = reverse_order
+
+ def forward(self, x, g):
+ num_objects = g.shape[1]
+
+ if self.x_transform is not None:
+ x = self.x_transform(x)
+
+ if self.method == 'cat':
+ if self.reverse_order:
+ g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2)
+ else:
+ g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2)
+ elif self.method == 'add':
+ g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g
+ else:
+ raise NotImplementedError
+
+ return g
diff --git a/Make-A-Protagonist/experts/XMem/model/losses.py b/Make-A-Protagonist/experts/XMem/model/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a2894b6f5b330aa4baa56db226e8a59cb8c1ae
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/losses.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from collections import defaultdict
+
+
+def dice_loss(input_mask, cls_gt):
+ num_objects = input_mask.shape[1]
+ losses = []
+ for i in range(num_objects):
+ mask = input_mask[:,i].flatten(start_dim=1)
+ # background not in mask, so we add one to cls_gt
+ gt = (cls_gt==(i+1)).float().flatten(start_dim=1)
+ numerator = 2 * (mask * gt).sum(-1)
+ denominator = mask.sum(-1) + gt.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ losses.append(loss)
+ return torch.cat(losses).mean()
+
+
+# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
+class BootstrappedCE(nn.Module):
+ def __init__(self, start_warm, end_warm, top_p=0.15):
+ super().__init__()
+
+ self.start_warm = start_warm
+ self.end_warm = end_warm
+ self.top_p = top_p
+
+ def forward(self, input, target, it):
+ if it < self.start_warm:
+ return F.cross_entropy(input, target), 1.0
+
+ raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
+ num_pixels = raw_loss.numel()
+
+ if it > self.end_warm:
+ this_p = self.top_p
+ else:
+ this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
+ loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
+ return loss.mean(), this_p
+
+
+class LossComputer:
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])
+
+ def compute(self, data, num_objects, it):
+ losses = defaultdict(int)
+
+ b, t = data['rgb'].shape[:2]
+
+ losses['total_loss'] = 0
+ for ti in range(1, t):
+ for bi in range(b):
+ loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it)
+ losses['p'] += p / b / (t-1)
+ losses[f'ce_loss_{ti}'] += loss / b
+
+ losses['total_loss'] += losses['ce_loss_%d'%ti]
+ losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0])
+ losses['total_loss'] += losses[f'dice_loss_{ti}']
+
+ return losses
diff --git a/Make-A-Protagonist/experts/XMem/model/memory_util.py b/Make-A-Protagonist/experts/XMem/model/memory_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..faf6197b8c4ea990317476e2e3aeb8952a78aedf
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/memory_util.py
@@ -0,0 +1,80 @@
+import math
+import numpy as np
+import torch
+from typing import Optional
+
+
+def get_similarity(mk, ms, qk, qe):
+ # used for training/inference and memory reading/memory potentiation
+ # mk: B x CK x [N] - Memory keys
+ # ms: B x 1 x [N] - Memory shrinkage
+ # qk: B x CK x [HW/P] - Query keys
+ # qe: B x CK x [HW/P] - Query selection
+ # Dimensions in [] are flattened
+ CK = mk.shape[1]
+ mk = mk.flatten(start_dim=2)
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
+ qk = qk.flatten(start_dim=2)
+ qe = qe.flatten(start_dim=2) if qe is not None else None
+
+ if qe is not None:
+ # See appendix for derivation
+ # or you can just trust me ヽ(ー_ー )ノ
+ mk = mk.transpose(1, 2)
+ a_sq = (mk.pow(2) @ qe)
+ two_ab = 2 * (mk @ (qk * qe))
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
+ similarity = (-a_sq+two_ab-b_sq)
+ else:
+ # similar to STCN if we don't have the selection term
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
+ similarity = (-a_sq+two_ab)
+
+ if ms is not None:
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
+ else:
+ similarity = similarity / math.sqrt(CK) # B*N*HW
+
+ return similarity
+
+def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
+ # normalize similarity with top-k softmax
+ # similarity: B x N x [HW/P]
+ # use inplace with care
+ if top_k is not None:
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
+
+ x_exp = values.exp_()
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
+ if inplace:
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
+ affinity = similarity
+ else:
+ affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
+ else:
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
+ x_exp = torch.exp(similarity - maxes)
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
+ affinity = x_exp / x_exp_sum
+ indices = None
+
+ if return_usage:
+ return affinity, affinity.sum(dim=2)
+
+ return affinity
+
+def get_affinity(mk, ms, qk, qe):
+ # shorthand used in training with no top-k
+ similarity = get_similarity(mk, ms, qk, qe)
+ affinity = do_softmax(similarity)
+ return affinity
+
+def readout(affinity, mv):
+ B, CV, T, H, W = mv.shape
+
+ mo = mv.view(B, CV, T*H*W)
+ mem = torch.bmm(mo, affinity)
+ mem = mem.view(B, CV, H, W)
+
+ return mem
diff --git a/Make-A-Protagonist/experts/XMem/model/modules.py b/Make-A-Protagonist/experts/XMem/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..090fb9cefc40911a13d8b4730d98da22a2af92aa
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/modules.py
@@ -0,0 +1,250 @@
+"""
+modules.py - This file stores the rather boring network blocks.
+
+x - usually means features that only depends on the image
+g - usually means features that also depends on the mask.
+ They might have an extra "group" or "num_objects" dimension, hence
+ batch_size * num_objects * num_channels * H * W
+
+The trailing number of a variable usually denote the stride
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from XMem.model.group_modules import *
+from XMem.model import resnet
+from XMem.model.cbam import CBAM
+
+
+class FeatureFusionBlock(nn.Module):
+ def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
+ super().__init__()
+
+ self.distributor = MainToGroupDistributor()
+ self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim)
+ self.attention = CBAM(g_mid_dim)
+ self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
+
+ def forward(self, x, g):
+ batch_size, num_objects = g.shape[:2]
+
+ g = self.distributor(x, g)
+ g = self.block1(g)
+ r = self.attention(g.flatten(start_dim=0, end_dim=1))
+ r = r.view(batch_size, num_objects, *r.shape[1:])
+
+ g = self.block2(g+r)
+
+ return g
+
+
+class HiddenUpdater(nn.Module):
+ # Used in the decoder, multi-scale feature + GRU
+ def __init__(self, g_dims, mid_dim, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+
+ self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
+ self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
+ self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
+
+ self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g, h):
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
+
+ g = torch.cat([g, h], 2)
+
+ # defined slightly differently than standard GRU,
+ # namely the new value is generated before the forget gate.
+ # might provide better gradient but frankly it was initially just an
+ # implementation error that I never bothered fixing
+ values = self.transform(g)
+ forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
+ update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
+ new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
+ new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
+
+ return new_h
+
+
+class HiddenReinforcer(nn.Module):
+ # Used in the value encoder, a single GRU
+ def __init__(self, g_dim, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
+
+ nn.init.xavier_normal_(self.transform.weight)
+
+ def forward(self, g, h):
+ g = torch.cat([g, h], 2)
+
+ # defined slightly differently than standard GRU,
+ # namely the new value is generated before the forget gate.
+ # might provide better gradient but frankly it was initially just an
+ # implementation error that I never bothered fixing
+ values = self.transform(g)
+ forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
+ update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
+ new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
+ new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
+
+ return new_h
+
+
+class ValueEncoder(nn.Module):
+ def __init__(self, value_dim, hidden_dim, single_object=False):
+ super().__init__()
+
+ self.single_object = single_object
+ network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu # 1/2, 64
+ self.maxpool = network.maxpool
+
+ self.layer1 = network.layer1 # 1/4, 64
+ self.layer2 = network.layer2 # 1/8, 128
+ self.layer3 = network.layer3 # 1/16, 256
+
+ self.distributor = MainToGroupDistributor()
+ self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
+ if hidden_dim > 0:
+ self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
+ else:
+ self.hidden_reinforce = None
+
+ def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
+ # image_feat_f16 is the feature from the key encoder
+ if not self.single_object:
+ g = torch.stack([masks, others], 2)
+ else:
+ g = masks.unsqueeze(2)
+ g = self.distributor(image, g)
+
+ batch_size, num_objects = g.shape[:2]
+ g = g.flatten(start_dim=0, end_dim=1)
+
+ g = self.conv1(g)
+ g = self.bn1(g) # 1/2, 64
+ g = self.maxpool(g) # 1/4, 64
+ g = self.relu(g)
+
+ g = self.layer1(g) # 1/4
+ g = self.layer2(g) # 1/8
+ g = self.layer3(g) # 1/16
+
+ g = g.view(batch_size, num_objects, *g.shape[1:])
+ g = self.fuser(image_feat_f16, g)
+
+ if is_deep_update and self.hidden_reinforce is not None:
+ h = self.hidden_reinforce(g, h)
+
+ return g, h
+
+
+class KeyEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ network = resnet.resnet50(pretrained=True)
+ self.conv1 = network.conv1
+ self.bn1 = network.bn1
+ self.relu = network.relu # 1/2, 64
+ self.maxpool = network.maxpool
+
+ self.res2 = network.layer1 # 1/4, 256
+ self.layer2 = network.layer2 # 1/8, 512
+ self.layer3 = network.layer3 # 1/16, 1024
+
+ def forward(self, f):
+ x = self.conv1(f)
+ x = self.bn1(x)
+ x = self.relu(x) # 1/2, 64
+ x = self.maxpool(x) # 1/4, 64
+ f4 = self.res2(x) # 1/4, 256
+ f8 = self.layer2(f4) # 1/8, 512
+ f16 = self.layer3(f8) # 1/16, 1024
+
+ return f16, f8, f4
+
+
+class UpsampleBlock(nn.Module):
+ def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
+ super().__init__()
+ self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
+ self.distributor = MainToGroupDistributor(method='add')
+ self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
+ self.scale_factor = scale_factor
+
+ def forward(self, skip_f, up_g):
+ skip_f = self.skip_conv(skip_f)
+ g = upsample_groups(up_g, ratio=self.scale_factor)
+ g = self.distributor(skip_f, g)
+ g = self.out_conv(g)
+ return g
+
+
+class KeyProjection(nn.Module):
+ def __init__(self, in_dim, keydim):
+ super().__init__()
+
+ self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
+ # shrinkage
+ self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
+ # selection
+ self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
+
+ nn.init.orthogonal_(self.key_proj.weight.data)
+ nn.init.zeros_(self.key_proj.bias.data)
+
+ def forward(self, x, need_s, need_e):
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
+
+ return self.key_proj(x), shrinkage, selection
+
+
+class Decoder(nn.Module):
+ def __init__(self, val_dim, hidden_dim):
+ super().__init__()
+
+ self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512)
+ if hidden_dim > 0:
+ self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim)
+ else:
+ self.hidden_update = None
+
+ self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
+ self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
+
+ self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
+
+ def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
+ batch_size, num_objects = memory_readout.shape[:2]
+
+ if self.hidden_update is not None:
+ g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
+ else:
+ g16 = self.fuser(f16, memory_readout)
+
+ g8 = self.up_16_8(f8, g16)
+ g4 = self.up_8_4(f4, g8)
+ logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
+
+ if h_out and self.hidden_update is not None:
+ g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2)
+ hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
+ else:
+ hidden_state = None
+
+ logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
+
+ return hidden_state, logits
diff --git a/Make-A-Protagonist/experts/XMem/model/network.py b/Make-A-Protagonist/experts/XMem/model/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e130e2426871a080804553f9ae65f50ea82a88
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/network.py
@@ -0,0 +1,198 @@
+"""
+This file defines XMem, the highest level nn.Module interface
+During training, it is used by trainer.py
+During evaluation, it is used by inference_core.py
+
+It further depends on modules.py which gives more detailed implementations of sub-modules
+"""
+
+import torch
+import torch.nn as nn
+
+from XMem.model.aggregate import aggregate
+from XMem.model.modules import *
+from XMem.model.memory_util import *
+
+
+class XMem(nn.Module):
+ def __init__(self, config, model_path=None, map_location=None):
+ """
+ model_path/map_location are used in evaluation only
+ map_location is for converting models saved in cuda to cpu
+ """
+ super().__init__()
+ model_weights = self.init_hyperparameters(config, model_path, map_location)
+
+ self.single_object = config.get('single_object', False)
+ print(f'Single object mode: {self.single_object}')
+
+ self.key_encoder = KeyEncoder()
+ self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object)
+
+ # Projection from f16 feature space to key/value space
+ self.key_proj = KeyProjection(1024, self.key_dim)
+
+ self.decoder = Decoder(self.value_dim, self.hidden_dim)
+
+ if model_weights is not None:
+ self.load_weights(model_weights, init_as_zero_if_needed=True)
+
+ def encode_key(self, frame, need_sk=True, need_ek=True):
+ # Determine input shape
+ if len(frame.shape) == 5:
+ # shape is b*t*c*h*w
+ need_reshape = True
+ b, t = frame.shape[:2]
+ # flatten so that we can feed them into a 2D CNN
+ frame = frame.flatten(start_dim=0, end_dim=1)
+ elif len(frame.shape) == 4:
+ # shape is b*c*h*w
+ need_reshape = False
+ else:
+ raise NotImplementedError
+
+ f16, f8, f4 = self.key_encoder(frame)
+ key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
+
+ if need_reshape:
+ # B*C*T*H*W
+ key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
+ if shrinkage is not None:
+ shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous()
+ if selection is not None:
+ selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous()
+
+ # B*T*C*H*W
+ f16 = f16.view(b, t, *f16.shape[-3:])
+ f8 = f8.view(b, t, *f8.shape[-3:])
+ f4 = f4.view(b, t, *f4.shape[-3:])
+
+ return key, shrinkage, selection, f16, f8, f4
+
+ def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
+ num_objects = masks.shape[1]
+ if num_objects != 1:
+ others = torch.cat([
+ torch.sum(
+ masks[:, [j for j in range(num_objects) if i!=j]]
+ , dim=1, keepdim=True)
+ for i in range(num_objects)], 1)
+ else:
+ others = torch.zeros_like(masks)
+
+ g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update)
+
+ return g16, h16
+
+ # Used in training only.
+ # This step is replaced by MemoryManager in test time
+ def read_memory(self, query_key, query_selection, memory_key,
+ memory_shrinkage, memory_value):
+ """
+ query_key : B * CK * H * W
+ query_selection : B * CK * H * W
+ memory_key : B * CK * T * H * W
+ memory_shrinkage: B * 1 * T * H * W
+ memory_value : B * num_objects * CV * T * H * W
+ """
+ batch_size, num_objects = memory_value.shape[:2]
+ memory_value = memory_value.flatten(start_dim=1, end_dim=2)
+
+ affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection)
+ memory = readout(affinity, memory_value)
+ memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:])
+
+ return memory
+
+ def segment(self, multi_scale_features, memory_readout,
+ hidden_state, selector=None, h_out=True, strip_bg=True):
+
+ hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out)
+ prob = torch.sigmoid(logits)
+ if selector is not None:
+ prob = prob * selector
+
+ logits, prob = aggregate(prob, dim=1, return_logits=True)
+ if strip_bg:
+ # Strip away the background
+ prob = prob[:, 1:]
+
+ return hidden_state, logits, prob
+
+ def forward(self, mode, *args, **kwargs):
+ if mode == 'encode_key':
+ return self.encode_key(*args, **kwargs)
+ elif mode == 'encode_value':
+ return self.encode_value(*args, **kwargs)
+ elif mode == 'read_memory':
+ return self.read_memory(*args, **kwargs)
+ elif mode == 'segment':
+ return self.segment(*args, **kwargs)
+ else:
+ raise NotImplementedError
+
+ def init_hyperparameters(self, config, model_path=None, map_location=None):
+ """
+ Init three hyperparameters: key_dim, value_dim, and hidden_dim
+ If model_path is provided, we load these from the model weights
+ The actual parameters are then updated to the config in-place
+
+ Otherwise we load it either from the config or default
+ """
+ if model_path is not None:
+ # load the model and key/value/hidden dimensions with some hacks
+ # config is updated with the loaded parameters
+ model_weights = torch.load(model_path, map_location=map_location)
+ self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
+ self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
+ self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
+ if self.disable_hidden:
+ self.hidden_dim = 0
+ else:
+ self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3
+ print(f'Hyperparameters read from the model weights: '
+ f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}')
+ else:
+ model_weights = None
+ # load dimensions from config or default
+ if 'key_dim' not in config:
+ self.key_dim = 64
+ print(f'key_dim not found in config. Set to default {self.key_dim}')
+ else:
+ self.key_dim = config['key_dim']
+
+ if 'value_dim' not in config:
+ self.value_dim = 512
+ print(f'value_dim not found in config. Set to default {self.value_dim}')
+ else:
+ self.value_dim = config['value_dim']
+
+ if 'hidden_dim' not in config:
+ self.hidden_dim = 64
+ print(f'hidden_dim not found in config. Set to default {self.hidden_dim}')
+ else:
+ self.hidden_dim = config['hidden_dim']
+
+ self.disable_hidden = (self.hidden_dim <= 0)
+
+ config['key_dim'] = self.key_dim
+ config['value_dim'] = self.value_dim
+ config['hidden_dim'] = self.hidden_dim
+
+ return model_weights
+
+ def load_weights(self, src_dict, init_as_zero_if_needed=False):
+ # Maps SO weight (without other_mask) to MO weight (with other_mask)
+ for k in list(src_dict.keys()):
+ if k == 'value_encoder.conv1.weight':
+ if src_dict[k].shape[1] == 4:
+ print('Converting weights from single object to multiple objects.')
+ pads = torch.zeros((64,1,7,7), device=src_dict[k].device)
+ if not init_as_zero_if_needed:
+ print('Randomly initialized padding.')
+ nn.init.orthogonal_(pads)
+ else:
+ print('Zero-initialized padding.')
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
+
+ self.load_state_dict(src_dict)
diff --git a/Make-A-Protagonist/experts/XMem/model/resnet.py b/Make-A-Protagonist/experts/XMem/model/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..574626efcfd8c8c9b21e3b5a6ed0999ea698ef6d
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/resnet.py
@@ -0,0 +1,165 @@
+"""
+resnet.py - A modified ResNet structure
+We append extra channels to the first conv by some network surgery
+"""
+
+from collections import OrderedDict
+import math
+
+import torch
+import torch.nn as nn
+from torch.utils import model_zoo
+
+
+def load_weights_add_extra_dim(target, source_state, extra_dim=1):
+ new_dict = OrderedDict()
+
+ for k1, v1 in target.state_dict().items():
+ if not 'num_batches_tracked' in k1:
+ if k1 in source_state:
+ tar_v = source_state[k1]
+
+ if v1.shape != tar_v.shape:
+ # Init the new segmentation channel with zeros
+ # print(v1.shape, tar_v.shape)
+ c, _, w, h = v1.shape
+ pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device)
+ nn.init.orthogonal_(pads)
+ tar_v = torch.cat([tar_v, pads], 1)
+
+ new_dict[k1] = tar_v
+
+ target.load_state_dict(new_dict)
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
+ padding=dilation, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+def resnet18(pretrained=True, extra_dim=0):
+ model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
+ return model
+
+def resnet50(pretrained=True, extra_dim=0):
+ model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
+ if pretrained:
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
+ return model
+
diff --git a/Make-A-Protagonist/experts/XMem/model/trainer.py b/Make-A-Protagonist/experts/XMem/model/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..97db8650e9a36fd0e140e1ce8d8ccb6b26bac1b3
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/model/trainer.py
@@ -0,0 +1,234 @@
+"""
+trainer.py - warpper and utility functions for network training
+Compute loss, back-prop, update parameters, logging, etc.
+"""
+
+
+import os
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from model.network import XMem
+from model.losses import LossComputer
+from util.log_integrator import Integrator
+from util.image_saver import pool_pairs
+
+
+class XMemTrainer:
+ def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
+ self.config = config
+ self.num_frames = config['num_frames']
+ self.num_ref_frames = config['num_ref_frames']
+ self.deep_update_prob = config['deep_update_prob']
+ self.local_rank = local_rank
+
+ self.XMem = nn.parallel.DistributedDataParallel(
+ XMem(config).cuda(),
+ device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)
+
+ # Set up logger when local_rank=0
+ self.logger = logger
+ self.save_path = save_path
+ if logger is not None:
+ self.last_time = time.time()
+ self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()])))
+ self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
+ self.loss_computer = LossComputer(config)
+
+ self.train()
+ self.optimizer = optim.AdamW(filter(
+ lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
+ if config['amp']:
+ self.scaler = torch.cuda.amp.GradScaler()
+
+ # Logging info
+ self.log_text_interval = config['log_text_interval']
+ self.log_image_interval = config['log_image_interval']
+ self.save_network_interval = config['save_network_interval']
+ self.save_checkpoint_interval = config['save_checkpoint_interval']
+ if config['debug']:
+ self.log_text_interval = self.log_image_interval = 1
+
+ def do_pass(self, data, it=0):
+ # No need to store the gradient outside training
+ torch.set_grad_enabled(self._is_train)
+
+ for k, v in data.items():
+ if type(v) != list and type(v) != dict and type(v) != int:
+ data[k] = v.cuda(non_blocking=True)
+
+ out = {}
+ frames = data['rgb']
+ first_frame_gt = data['first_frame_gt'].float()
+ b = frames.shape[0]
+ num_filled_objects = [o.item() for o in data['info']['num_objects']]
+ num_objects = first_frame_gt.shape[2]
+ selector = data['selector'].unsqueeze(2).unsqueeze(2)
+
+ with torch.cuda.amp.autocast(enabled=self.config['amp']):
+ # image features never change, compute once
+ key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)
+
+ filler_one = torch.zeros(1, dtype=torch.int64)
+ hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
+ v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
+ values = v16.unsqueeze(3) # add the time dimension
+
+ for ti in range(1, self.num_frames):
+ if ti <= self.num_ref_frames:
+ ref_values = values
+ ref_keys = key[:,:,:ti]
+ ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
+ else:
+ # pick num_ref_frames random frames
+ # this is not very efficient but I think we would
+ # need broadcasting in gather which we don't have
+ indices = [
+ torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
+ for _ in range(b)]
+ ref_values = torch.stack([
+ values[bi, :, :, indices[bi]] for bi in range(b)
+ ], 0)
+ ref_keys = torch.stack([
+ key[bi, :, indices[bi]] for bi in range(b)
+ ], 0)
+ ref_shrinkage = torch.stack([
+ shrinkage[bi, :, indices[bi]] for bi in range(b)
+ ], 0) if shrinkage is not None else None
+
+ # Segment frame ti
+ memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
+ ref_keys, ref_shrinkage, ref_values)
+ hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
+ hidden, selector, h_out=(ti < (self.num_frames-1)))
+
+ # No need to encode the last frame
+ if ti < (self.num_frames-1):
+ is_deep_update = np.random.rand() < self.deep_update_prob
+ v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
+ values = torch.cat([values, v16.unsqueeze(3)], 3)
+
+ out[f'masks_{ti}'] = masks
+ out[f'logits_{ti}'] = logits
+
+ if self._do_log or self._is_train:
+ losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)
+
+ # Logging
+ if self._do_log:
+ self.integrator.add_dict(losses)
+ if self._is_train:
+ if it % self.log_image_interval == 0 and it != 0:
+ if self.logger is not None:
+ images = {**data, **out}
+ size = (384, 384)
+ self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)
+
+ if self._is_train:
+ if (it) % self.log_text_interval == 0 and it != 0:
+ if self.logger is not None:
+ self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
+ self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.log_text_interval, it)
+ self.last_time = time.time()
+ self.train_integrator.finalize('train', it)
+ self.train_integrator.reset_except_hooks()
+
+ if it % self.save_network_interval == 0 and it != 0:
+ if self.logger is not None:
+ self.save_network(it)
+
+ if it % self.save_checkpoint_interval == 0 and it != 0:
+ if self.logger is not None:
+ self.save_checkpoint(it)
+
+ # Backward pass
+ self.optimizer.zero_grad(set_to_none=True)
+ if self.config['amp']:
+ self.scaler.scale(losses['total_loss']).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ losses['total_loss'].backward()
+ self.optimizer.step()
+
+ self.scheduler.step()
+
+ def save_network(self, it):
+ if self.save_path is None:
+ print('Saving has been disabled.')
+ return
+
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
+ model_path = f'{self.save_path}_{it}.pth'
+ torch.save(self.XMem.module.state_dict(), model_path)
+ print(f'Network saved to {model_path}.')
+
+ def save_checkpoint(self, it):
+ if self.save_path is None:
+ print('Saving has been disabled.')
+ return
+
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
+ checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
+ checkpoint = {
+ 'it': it,
+ 'network': self.XMem.module.state_dict(),
+ 'optimizer': self.optimizer.state_dict(),
+ 'scheduler': self.scheduler.state_dict()}
+ torch.save(checkpoint, checkpoint_path)
+ print(f'Checkpoint saved to {checkpoint_path}.')
+
+ def load_checkpoint(self, path):
+ # This method loads everything and should be used to resume training
+ map_location = 'cuda:%d' % self.local_rank
+ checkpoint = torch.load(path, map_location={'cuda:0': map_location})
+
+ it = checkpoint['it']
+ network = checkpoint['network']
+ optimizer = checkpoint['optimizer']
+ scheduler = checkpoint['scheduler']
+
+ map_location = 'cuda:%d' % self.local_rank
+ self.XMem.module.load_state_dict(network)
+ self.optimizer.load_state_dict(optimizer)
+ self.scheduler.load_state_dict(scheduler)
+
+ print('Network weights, optimizer states, and scheduler states loaded.')
+
+ return it
+
+ def load_network_in_memory(self, src_dict):
+ self.XMem.module.load_weights(src_dict)
+ print('Network weight loaded from memory.')
+
+ def load_network(self, path):
+ # This method loads only the network weight and should be used to load a pretrained model
+ map_location = 'cuda:%d' % self.local_rank
+ src_dict = torch.load(path, map_location={'cuda:0': map_location})
+
+ self.load_network_in_memory(src_dict)
+ print(f'Network weight loaded from {path}')
+
+ def train(self):
+ self._is_train = True
+ self._do_log = True
+ self.integrator = self.train_integrator
+ self.XMem.eval()
+ return self
+
+ def val(self):
+ self._is_train = False
+ self._do_log = True
+ self.XMem.eval()
+ return self
+
+ def test(self):
+ self._is_train = False
+ self._do_log = False
+ self.XMem.eval()
+ return self
+
diff --git a/Make-A-Protagonist/experts/XMem/util/__init__.py b/Make-A-Protagonist/experts/XMem/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Make-A-Protagonist/experts/XMem/util/configuration.py b/Make-A-Protagonist/experts/XMem/util/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..890956b989ef9b2055b1faea7a850d335488a20d
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/configuration.py
@@ -0,0 +1,135 @@
+from argparse import ArgumentParser
+
+
+def none_or_default(x, default):
+ return x if x is not None else default
+
+class Configuration():
+ def parse(self, unknown_arg_ok=False):
+ parser = ArgumentParser()
+
+ # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
+ parser.add_argument('--benchmark', action='store_true')
+ parser.add_argument('--no_amp', action='store_true')
+
+ # Data parameters
+ parser.add_argument('--static_root', help='Static training data root', default='../static')
+ parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
+ parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube')
+ parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS')
+ parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16)
+
+ parser.add_argument('--key_dim', default=64, type=int)
+ parser.add_argument('--value_dim', default=512, type=int)
+ parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)
+
+ parser.add_argument('--deep_update_prob', default=0.2, type=float)
+
+ parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')
+
+ """
+ Stage-specific learning parameters
+ Batch sizes are effective -- you don't have to scale them when you scale the number processes
+ """
+ # Stage 0, static images
+ parser.add_argument('--s0_batch_size', default=16, type=int)
+ parser.add_argument('--s0_iterations', default=150000, type=int)
+ parser.add_argument('--s0_finetune', default=0, type=int)
+ parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
+ parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
+ parser.add_argument('--s0_num_ref_frames', default=2, type=int)
+ parser.add_argument('--s0_num_frames', default=3, type=int)
+ parser.add_argument('--s0_start_warm', default=20000, type=int)
+ parser.add_argument('--s0_end_warm', default=70000, type=int)
+
+ # Stage 1, BL30K
+ parser.add_argument('--s1_batch_size', default=8, type=int)
+ parser.add_argument('--s1_iterations', default=250000, type=int)
+ # fine-tune means fewer augmentations to train the sensory memory
+ parser.add_argument('--s1_finetune', default=0, type=int)
+ parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
+ parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
+ parser.add_argument('--s1_num_ref_frames', default=3, type=int)
+ parser.add_argument('--s1_num_frames', default=8, type=int)
+ parser.add_argument('--s1_start_warm', default=20000, type=int)
+ parser.add_argument('--s1_end_warm', default=70000, type=int)
+
+ # Stage 2, DAVIS+YoutubeVOS, longer
+ parser.add_argument('--s2_batch_size', default=8, type=int)
+ parser.add_argument('--s2_iterations', default=150000, type=int)
+ # fine-tune means fewer augmentations to train the sensory memory
+ parser.add_argument('--s2_finetune', default=10000, type=int)
+ parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
+ parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
+ parser.add_argument('--s2_num_ref_frames', default=3, type=int)
+ parser.add_argument('--s2_num_frames', default=8, type=int)
+ parser.add_argument('--s2_start_warm', default=20000, type=int)
+ parser.add_argument('--s2_end_warm', default=70000, type=int)
+
+ # Stage 3, DAVIS+YoutubeVOS, shorter
+ parser.add_argument('--s3_batch_size', default=8, type=int)
+ parser.add_argument('--s3_iterations', default=100000, type=int)
+ # fine-tune means fewer augmentations to train the sensory memory
+ parser.add_argument('--s3_finetune', default=10000, type=int)
+ parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
+ parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
+ parser.add_argument('--s3_num_ref_frames', default=3, type=int)
+ parser.add_argument('--s3_num_frames', default=8, type=int)
+ parser.add_argument('--s3_start_warm', default=20000, type=int)
+ parser.add_argument('--s3_end_warm', default=70000, type=int)
+
+ parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
+ parser.add_argument('--weight_decay', default=0.05, type=float)
+
+ # Loading
+ parser.add_argument('--load_network', help='Path to pretrained network weight only')
+ parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')
+
+ # Logging information
+ parser.add_argument('--log_text_interval', default=100, type=int)
+ parser.add_argument('--log_image_interval', default=1000, type=int)
+ parser.add_argument('--save_network_interval', default=25000, type=int)
+ parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
+ parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
+ parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')
+
+ # # Multiprocessing parameters, not set by users
+ # parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')
+
+ if unknown_arg_ok:
+ args, _ = parser.parse_known_args()
+ self.args = vars(args)
+ else:
+ self.args = vars(parser.parse_args())
+
+ self.args['amp'] = not self.args['no_amp']
+
+ # check if the stages are valid
+ stage_to_perform = list(self.args['stages'])
+ for s in stage_to_perform:
+ if s not in ['0', '1', '2', '3']:
+ raise NotImplementedError
+
+ def get_stage_parameters(self, stage):
+ parameters = {
+ 'batch_size': self.args['s%s_batch_size'%stage],
+ 'iterations': self.args['s%s_iterations'%stage],
+ 'finetune': self.args['s%s_finetune'%stage],
+ 'steps': self.args['s%s_steps'%stage],
+ 'lr': self.args['s%s_lr'%stage],
+ 'num_ref_frames': self.args['s%s_num_ref_frames'%stage],
+ 'num_frames': self.args['s%s_num_frames'%stage],
+ 'start_warm': self.args['s%s_start_warm'%stage],
+ 'end_warm': self.args['s%s_end_warm'%stage],
+ }
+
+ return parameters
+
+ def __getitem__(self, key):
+ return self.args[key]
+
+ def __setitem__(self, key, value):
+ self.args[key] = value
+
+ def __str__(self):
+ return str(self.args)
diff --git a/Make-A-Protagonist/experts/XMem/util/davis_subset.txt b/Make-A-Protagonist/experts/XMem/util/davis_subset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..875c2409d2cc4cfc4491ebf7703cb432b26678d8
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/davis_subset.txt
@@ -0,0 +1,60 @@
+bear
+bmx-bumps
+boat
+boxing-fisheye
+breakdance-flare
+bus
+car-turn
+cat-girl
+classic-car
+color-run
+crossing
+dance-jump
+dancing
+disc-jockey
+dog-agility
+dog-gooses
+dogs-scale
+drift-turn
+drone
+elephant
+flamingo
+hike
+hockey
+horsejump-low
+kid-football
+kite-walk
+koala
+lady-running
+lindy-hop
+longboard
+lucia
+mallard-fly
+mallard-water
+miami-surf
+motocross-bumps
+motorbike
+night-race
+paragliding
+planes-water
+rallye
+rhino
+rollerblade
+schoolgirls
+scooter-board
+scooter-gray
+sheep
+skate-park
+snowboard
+soccerball
+stroller
+stunt
+surf
+swing
+tennis
+tractor-sand
+train
+tuk-tuk
+upside-down
+varanus-cage
+walking
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/util/image_saver.py b/Make-A-Protagonist/experts/XMem/util/image_saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..c43d9de68b5c2fbe5b690c8aa59d073d3c217d19
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/image_saver.py
@@ -0,0 +1,136 @@
+import cv2
+import numpy as np
+
+import torch
+from dataset.range_transform import inv_im_trans
+from collections import defaultdict
+
+def tensor_to_numpy(image):
+ image_np = (image.numpy() * 255).astype('uint8')
+ return image_np
+
+def tensor_to_np_float(image):
+ image_np = image.numpy().astype('float32')
+ return image_np
+
+def detach_to_cpu(x):
+ return x.detach().cpu()
+
+def transpose_np(x):
+ return np.transpose(x, [1,2,0])
+
+def tensor_to_gray_im(x):
+ x = detach_to_cpu(x)
+ x = tensor_to_numpy(x)
+ x = transpose_np(x)
+ return x
+
+def tensor_to_im(x):
+ x = detach_to_cpu(x)
+ x = inv_im_trans(x).clamp(0, 1)
+ x = tensor_to_numpy(x)
+ x = transpose_np(x)
+ return x
+
+# Predefined key <-> caption dict
+key_captions = {
+ 'im': 'Image',
+ 'gt': 'GT',
+}
+
+"""
+Return an image array with captions
+keys in dictionary will be used as caption if not provided
+values should contain lists of cv2 images
+"""
+def get_image_array(images, grid_shape, captions={}):
+ h, w = grid_shape
+ cate_counts = len(images)
+ rows_counts = len(next(iter(images.values())))
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+
+ output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8)
+ col_cnt = 0
+ for k, v in images.items():
+
+ # Default as key value itself
+ caption = captions.get(k, k)
+
+ # Handles new line character
+ dy = 40
+ for i, line in enumerate(caption.split('\n')):
+ cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy),
+ font, 0.8, (255,255,255), 2, cv2.LINE_AA)
+
+ # Put images
+ for row_cnt, img in enumerate(v):
+ im_shape = img.shape
+ if len(im_shape) == 2:
+ img = img[..., np.newaxis]
+
+ img = (img * 255).astype('uint8')
+
+ output_image[(col_cnt+0)*w:(col_cnt+1)*w,
+ (row_cnt+1)*h:(row_cnt+2)*h, :] = img
+
+ col_cnt += 1
+
+ return output_image
+
+def base_transform(im, size):
+ im = tensor_to_np_float(im)
+ if len(im.shape) == 3:
+ im = im.transpose((1, 2, 0))
+ else:
+ im = im[:, :, None]
+
+ # Resize
+ if im.shape[1] != size:
+ im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
+
+ return im.clip(0, 1)
+
+def im_transform(im, size):
+ return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)
+
+def mask_transform(mask, size):
+ return base_transform(detach_to_cpu(mask), size=size)
+
+def out_transform(mask, size):
+ return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
+
+def pool_pairs(images, size, num_objects):
+ req_images = defaultdict(list)
+
+ b, t = images['rgb'].shape[:2]
+
+ # limit the number of images saved
+ b = min(2, b)
+
+ # find max num objects
+ max_num_objects = max(num_objects[:b])
+
+ GT_suffix = ''
+ for bi in range(b):
+ GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
+
+ for bi in range(b):
+ for ti in range(t):
+ req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size))
+ for oi in range(max_num_objects):
+ if ti == 0 or oi >= num_objects[bi]:
+ req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
+ # req_images['Mask_X8_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
+ # req_images['Mask_X16_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
+ else:
+ req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size))
+ # req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][2], size))
+ # req_images['Mask_X8_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][1], size))
+ # req_images['Mask_X16_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][0], size))
+ req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size))
+ # print((images['cls_gt'][bi,ti,0]==(oi+1)).shape)
+ # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape)
+
+
+ return get_image_array(req_images, size, key_captions)
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/util/load_subset.py b/Make-A-Protagonist/experts/XMem/util/load_subset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3191f4fef05cec04a11eafdfa42b34b98a35549e
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/load_subset.py
@@ -0,0 +1,16 @@
+"""
+load_subset.py - Presents a subset of data
+DAVIS - only the training set
+YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all
+"""
+
+
+def load_sub_davis(path='util/davis_subset.txt'):
+ with open(path, mode='r') as f:
+ subset = set(f.read().splitlines())
+ return subset
+
+def load_sub_yv(path='util/yv_subset.txt'):
+ with open(path, mode='r') as f:
+ subset = set(f.read().splitlines())
+ return subset
diff --git a/Make-A-Protagonist/experts/XMem/util/log_integrator.py b/Make-A-Protagonist/experts/XMem/util/log_integrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b26d53de98b16e145090bcddf2041a3f2d1394
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/log_integrator.py
@@ -0,0 +1,80 @@
+"""
+Integrate numerical values for some iterations
+Typically used for loss computation / logging to tensorboard
+Call finalize and create a new Integrator when you want to display/log
+"""
+
+import torch
+
+
+class Integrator:
+ def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
+ self.values = {}
+ self.counts = {}
+ self.hooks = [] # List is used here to maintain insertion order
+
+ self.logger = logger
+
+ self.distributed = distributed
+ self.local_rank = local_rank
+ self.world_size = world_size
+
+ def add_tensor(self, key, tensor):
+ if key not in self.values:
+ self.counts[key] = 1
+ if type(tensor) == float or type(tensor) == int:
+ self.values[key] = tensor
+ else:
+ self.values[key] = tensor.mean().item()
+ else:
+ self.counts[key] += 1
+ if type(tensor) == float or type(tensor) == int:
+ self.values[key] += tensor
+ else:
+ self.values[key] += tensor.mean().item()
+
+ def add_dict(self, tensor_dict):
+ for k, v in tensor_dict.items():
+ self.add_tensor(k, v)
+
+ def add_hook(self, hook):
+ """
+ Adds a custom hook, i.e. compute new metrics using values in the dict
+ The hook takes the dict as argument, and returns a (k, v) tuple
+ e.g. for computing IoU
+ """
+ if type(hook) == list:
+ self.hooks.extend(hook)
+ else:
+ self.hooks.append(hook)
+
+ def reset_except_hooks(self):
+ self.values = {}
+ self.counts = {}
+
+ # Average and output the metrics
+ def finalize(self, prefix, it, f=None):
+
+ for hook in self.hooks:
+ k, v = hook(self.values)
+ self.add_tensor(k, v)
+
+ for k, v in self.values.items():
+
+ if k[:4] == 'hide':
+ continue
+
+ avg = v / self.counts[k]
+
+ if self.distributed:
+ # Inplace operation
+ avg = torch.tensor(avg).cuda()
+ torch.distributed.reduce(avg, dst=0)
+
+ if self.local_rank == 0:
+ avg = (avg/self.world_size).cpu().item()
+ self.logger.log_metrics(prefix, k, avg, it, f)
+ else:
+ # Simple does it
+ self.logger.log_metrics(prefix, k, avg, it, f)
+
diff --git a/Make-A-Protagonist/experts/XMem/util/logger.py b/Make-A-Protagonist/experts/XMem/util/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0493b934072f8d327a9c3807f8e169b4510bc8d
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/logger.py
@@ -0,0 +1,101 @@
+"""
+Dumps things to tensorboard and console
+"""
+
+import os
+import warnings
+
+import torchvision.transforms as transforms
+from torch.utils.tensorboard import SummaryWriter
+
+
+def tensor_to_numpy(image):
+ image_np = (image.numpy() * 255).astype('uint8')
+ return image_np
+
+def detach_to_cpu(x):
+ return x.detach().cpu()
+
+def fix_width_trunc(x):
+ return ('{:.9s}'.format('{:0.9f}'.format(x)))
+
+class TensorboardLogger:
+ def __init__(self, short_id, id, git_info):
+ self.short_id = short_id
+ if self.short_id == 'NULL':
+ self.short_id = 'DEBUG'
+
+ if id is None:
+ self.no_log = True
+ warnings.warn('Logging has been disbaled.')
+ else:
+ self.no_log = False
+
+ self.inv_im_trans = transforms.Normalize(
+ mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
+ std=[1/0.229, 1/0.224, 1/0.225])
+
+ self.inv_seg_trans = transforms.Normalize(
+ mean=[-0.5/0.5],
+ std=[1/0.5])
+
+ log_path = os.path.join('.', 'saves', '%s' % id)
+ self.logger = SummaryWriter(log_path)
+
+ self.log_string('git', git_info)
+
+ def log_scalar(self, tag, x, step):
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ self.logger.add_scalar(tag, x, step)
+
+ def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
+ tag = l1_tag + '/' + l2_tag
+ text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
+ print(text)
+ if f is not None:
+ f.write(text + '\n')
+ f.flush()
+ self.log_scalar(tag, val, step)
+
+ def log_im(self, tag, x, step):
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ x = detach_to_cpu(x)
+ x = self.inv_im_trans(x)
+ x = tensor_to_numpy(x)
+ self.logger.add_image(tag, x, step)
+
+ def log_cv2(self, tag, x, step):
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ x = x.transpose((2, 0, 1))
+ self.logger.add_image(tag, x, step)
+
+ def log_seg(self, tag, x, step):
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ x = detach_to_cpu(x)
+ x = self.inv_seg_trans(x)
+ x = tensor_to_numpy(x)
+ self.logger.add_image(tag, x, step)
+
+ def log_gray(self, tag, x, step):
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ x = detach_to_cpu(x)
+ x = tensor_to_numpy(x)
+ self.logger.add_image(tag, x, step)
+
+ def log_string(self, tag, x):
+ print(tag, x)
+ if self.no_log:
+ warnings.warn('Logging has been disabled.')
+ return
+ self.logger.add_text(tag, x)
+
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/util/palette.py b/Make-A-Protagonist/experts/XMem/util/palette.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2541659563056b015b3d6e4c2b0accef3b4e831
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/palette.py
@@ -0,0 +1,3 @@
+davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
+
+youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
diff --git a/Make-A-Protagonist/experts/XMem/util/tensor_util.py b/Make-A-Protagonist/experts/XMem/util/tensor_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..05189d38e2b0b0d1d08bd7804b8e43418d6da637
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/tensor_util.py
@@ -0,0 +1,47 @@
+import torch.nn.functional as F
+
+
+def compute_tensor_iu(seg, gt):
+ intersection = (seg & gt).float().sum()
+ union = (seg | gt).float().sum()
+
+ return intersection, union
+
+def compute_tensor_iou(seg, gt):
+ intersection, union = compute_tensor_iu(seg, gt)
+ iou = (intersection + 1e-6) / (union + 1e-6)
+
+ return iou
+
+# STM
+def pad_divide_by(in_img, d):
+ h, w = in_img.shape[-2:]
+
+ if h % d > 0:
+ new_h = h + d - h % d
+ else:
+ new_h = h
+ if w % d > 0:
+ new_w = w + d - w % d
+ else:
+ new_w = w
+ lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
+ lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
+ out = F.pad(in_img, pad_array)
+ return out, pad_array
+
+def unpad(img, pad):
+ if len(img.shape) == 4:
+ if pad[2]+pad[3] > 0:
+ img = img[:,:,pad[2]:-pad[3],:]
+ if pad[0]+pad[1] > 0:
+ img = img[:,:,:,pad[0]:-pad[1]]
+ elif len(img.shape) == 3:
+ if pad[2]+pad[3] > 0:
+ img = img[:,pad[2]:-pad[3],:]
+ if pad[0]+pad[1] > 0:
+ img = img[:,:,pad[0]:-pad[1]]
+ else:
+ raise NotImplementedError
+ return img
\ No newline at end of file
diff --git a/Make-A-Protagonist/experts/XMem/util/yv_subset.txt b/Make-A-Protagonist/experts/XMem/util/yv_subset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a26e50a7b8e6233bf17c542b540765cd8a1c5716
--- /dev/null
+++ b/Make-A-Protagonist/experts/XMem/util/yv_subset.txt
@@ -0,0 +1,3464 @@
+003234408d
+0043f083b5
+0044fa5fba
+005a527edd
+0065b171f9
+00917dcfc4
+00a23ccf53
+00ad5016a4
+01082ae388
+011ac0a06f
+013099c098
+0155498c85
+01694ad9c8
+017ac35701
+01b80e8e1a
+01baa5a4e1
+01c3111683
+01c4cb5ffe
+01c76f0a82
+01c783268c
+01ed275c6e
+01ff60d1fa
+020cd28cd2
+02264db755
+0248626d9a
+02668dbffa
+0274193026
+02d28375aa
+02f3a5c4df
+031ccc99b1
+0321b18c10
+0348a45bca
+0355e92655
+0358b938c1
+0368107cf1
+0379ddf557
+038b2cc71d
+038c15a5dd
+03a06cc98a
+03a63e187f
+03c95b4dae
+03e2b57b0e
+04194e1248
+0444918a5f
+04460a7a52
+04474174a4
+0450095513
+045f00aed2
+04667fabaa
+04735c5030
+04990d1915
+04d62d9d98
+04f21da964
+04fbad476e
+04fe256562
+0503bf89c9
+0536c9eed0
+054acb238f
+05579ca250
+056c200404
+05774f3a2c
+058a7592c8
+05a0a513df
+05a569d8aa
+05aa652648
+05d7715782
+05e0b0f28f
+05fdbbdd7a
+05ffcfed85
+0630391881
+06840b2bbe
+068f7dce6f
+0693719753
+06ce2b51fb
+06e224798e
+06ee361788
+06fbb3fa2c
+0700264286
+070c918ca7
+07129e14a4
+07177017e9
+07238ffc58
+07353b2a89
+0738493cbf
+075926c651
+075c701292
+0762ea9a30
+07652ee4af
+076f206928
+077d32af19
+079049275c
+07913cdda7
+07a11a35e8
+07ac33b6df
+07b6e8fda8
+07c62c3d11
+07cc1c7d74
+080196ef01
+081207976e
+081ae4fa44
+081d8250cb
+082900c5d4
+0860df21e2
+0866d4c5e3
+0891ac2eb6
+08931bc458
+08aa2705d5
+08c8450db7
+08d50b926c
+08e1e4de15
+08e48c1a48
+08f561c65e
+08feb87790
+09049f6fe3
+092e4ff450
+09338adea8
+093c335ccc
+0970d28339
+0974a213dc
+097b471ed8
+0990941758
+09a348f4fa
+09a6841288
+09c5bad17b
+09c9ce80c7
+09ff54fef4
+0a23765d15
+0a275e7f12
+0a2f2bd294
+0a7a2514aa
+0a7b27fde9
+0a8c467cc3
+0ac8c560ae
+0b1627e896
+0b285c47f6
+0b34ec1d55
+0b5b5e8e5a
+0b68535614
+0b6f9105fc
+0b7dbfa3cb
+0b9cea51ca
+0b9d012be8
+0bcfc4177d
+0bd37b23c1
+0bd864064c
+0c11c6bf7b
+0c26bc77ac
+0c3a04798c
+0c44a9d545
+0c817cc390
+0ca839ee9a
+0cd7ac0ac0
+0ce06e0121
+0cfe974a89
+0d2fcc0dcd
+0d3aad05d2
+0d40b015f4
+0d97fba242
+0d9cc80d7e
+0dab85b6d3
+0db5c427a5
+0dbaf284f1
+0de4923598
+0df28a9101
+0e04f636c4
+0e05f0e232
+0e0930474b
+0e27472bea
+0e30020549
+0e621feb6c
+0e803c7d73
+0e9ebe4e3c
+0e9f2785ec
+0ea68d418b
+0eb403a222
+0ee92053d6
+0eefca067f
+0f17fa6fcb
+0f1ac8e9a3
+0f202e9852
+0f2ab8b1ff
+0f51a78756
+0f5fbe16b0
+0f6072077b
+0f6b69b2f4
+0f6c2163de
+0f74ec5599
+0f9683715b
+0fa7b59356
+0fb173695b
+0fc958cde2
+0fe7b1a621
+0ffcdb491c
+101caff7d4
+1022fe8417
+1032e80b37
+103f501680
+104e64565f
+104f1ab997
+106242403f
+10b31f5431
+10eced835e
+110d26fa3a
+1122c1d16a
+1145b49a5f
+11485838c2
+114e7676ec
+1157472b95
+115ee1072c
+1171141012
+117757b4b8
+1178932d2f
+117cc76bda
+1180cbf814
+1187bbd0e3
+1197e44b26
+119cf20728
+119dd54871
+11a0c3b724
+11a6ba8c94
+11c722a456
+11cbcb0b4d
+11ccf5e99d
+11ce6f452e
+11e53de6f2
+11feabe596
+120cb9514d
+12156b25b3
+122896672d
+1232b2f1d4
+1233ac8596
+1239c87234
+1250423f7c
+1257a1bc67
+125d1b19dd
+126d203967
+1295e19071
+12ad198c54
+12bddb2bcb
+12ec9b93ee
+12eebedc35
+132852e094
+1329409f2a
+13325cfa14
+134d06dbf9
+135625b53d
+13870016f9
+13960b3c84
+13adaad9d9
+13ae097e20
+13e3070469
+13f6a8c20d
+1416925cf2
+142d2621f5
+145d5d7c03
+145fdc3ac5
+1471274fa7
+14a6b5a139
+14c21cea0d
+14dae0dc93
+14f9bd22b5
+14fd28ae99
+15097d5d4e
+150ea711f2
+1514e3563f
+152aaa3a9e
+152b7d3bd7
+15617297cc
+15abbe0c52
+15d1fb3de5
+15f67b0fab
+161eb59aad
+16288ea47f
+164410ce62
+165c3c8cd4
+165c42b41b
+165ec9e22b
+1669502269
+16763cccbb
+16adde065e
+16af445362
+16afd538ad
+16c3fa4d5d
+16d1d65c27
+16e8599e94
+16fe9fb444
+1705796b02
+1724db7671
+17418e81ea
+175169edbb
+17622326fd
+17656bae77
+17b0d94172
+17c220e4f6
+17c7bcd146
+17cb4afe89
+17cd79a434
+17d18604c3
+17d8ca1a37
+17e33f4330
+17f7a6d805
+180abc8378
+183ba3d652
+185bf64702
+18913cc690
+1892651815
+189ac8208a
+189b44e92c
+18ac264b76
+18b245ab49
+18b5cebc34
+18bad52083
+18bb5144d5
+18c6f205c5
+1903f9ea15
+1917b209f2
+191e74c01d
+19367bb94e
+193ffaa217
+19696b67d3
+197f3ab6f3
+1981e763cc
+198afe39ae
+19a6e62b9b
+19b60d5335
+19c00c11f9
+19e061eb88
+19e8bc6178
+19ee80dac6
+1a25a9170a
+1a359a6c1a
+1a3e87c566
+1a5fe06b00
+1a6c0fbd1e
+1a6f3b5a4b
+1a8afbad92
+1a8bdc5842
+1a95752aca
+1a9c131cb7
+1aa3da3ee3
+1ab27ec7ea
+1abf16d21d
+1acd0f993b
+1ad202e499
+1af8d2395d
+1afd39a1fa
+1b2d31306f
+1b3fa67f0e
+1b43fa74b4
+1b73ea9fc2
+1b7e8bb255
+1b8680f8cd
+1b883843c0
+1b8898785b
+1b88ba1aa4
+1b96a498e5
+1bbc4c274f
+1bd87fe9ab
+1c4090c75b
+1c41934f84
+1c72b04b56
+1c87955a3a
+1c9f9eb792
+1ca240fede
+1ca5673803
+1cada35274
+1cb44b920d
+1cd10e62be
+1d3087d5e5
+1d3685150a
+1d6ff083aa
+1d746352a6
+1da256d146
+1da4e956b1
+1daf812218
+1dba687bce
+1dce57d05d
+1de4a9e537
+1dec5446c8
+1dfbe6f586
+1e1a18c45a
+1e1e42529d
+1e4be70796
+1eb60959c8
+1ec8b2566b
+1ecdc2941c
+1ee0ac70ff
+1ef8e17def
+1f1a2a9fc0
+1f1beb8daa
+1f2609ee13
+1f3876f8d0
+1f4ec0563d
+1f64955634
+1f7d31b5b2
+1f8014b7fd
+1f9c7d10f1
+1fa350df76
+1fc9538993
+1fe2f0ec59
+2000c02f9d
+20142b2f05
+201a8d75e5
+2023b3ee4f
+202b767bbc
+203594a418
+2038987336
+2039c3aecb
+204a90d81f
+207bc6cf01
+208833d1d1
+20c6d8b362
+20e3e52e0a
+2117fa0c14
+211bc5d102
+2120d9c3c3
+2125235a49
+21386f5978
+2142af8795
+215dfc0f73
+217bae91e5
+217c0d44e4
+219057c87b
+21d0edbf81
+21df87ad76
+21f1d089f5
+21f4019116
+222597030f
+222904eb5b
+223a0e0657
+223bd973ab
+22472f7395
+224e7c833e
+225aba51d9
+2261d421ea
+2263a8782b
+2268cb1ffd
+2268e93b0a
+2293c99f3f
+22a1141970
+22b13084b2
+22d9f5ab0c
+22f02efe3a
+232c09b75b
+2350d71b4b
+2376440551
+2383d8aafd
+238b84e67f
+238d4b86f6
+238d947c6b
+23993ce90d
+23b0c8a9ab
+23b3beafcc
+23d80299fe
+23f404a9fc
+240118e58a
+2431dec2fd
+24440e0ac7
+2457274dbc
+2465bf515d
+246b142c4d
+247d729e36
+2481ceafeb
+24866b4e6a
+2489d78320
+24ab0b83e8
+24b0868d92
+24b5207cd9
+24ddf05c03
+250116161c
+256ad2e3fc
+256bd83d5e
+256dcc8ab8
+2589956baa
+258b3b33c6
+25ad437e29
+25ae395636
+25c750c6db
+25d2c3fe5d
+25dc80db7c
+25f97e926f
+26011bc28b
+260846ffbe
+260dd9ad33
+267964ee57
+2680861931
+268ac7d3fc
+26b895d91e
+26bc786d4f
+26ddd2ef12
+26de3d18ca
+26f7784762
+2703e52a6a
+270ed80c12
+2719b742ab
+272f4163d0
+27303333e1
+27659fa7d6
+279214115d
+27a5f92a9c
+27cf2af1f3
+27f0d5f8a2
+28075f33c1
+281629cb41
+282b0d51f5
+282fcab00b
+28449fa0dc
+28475208ca
+285580b7c4
+285b69e223
+288c117201
+28a8eb9623
+28bf9c3cf3
+28c6b8f86a
+28c972dacd
+28d9fa6016
+28e392de91
+28f4a45190
+298c844fc9
+29a0356a2b
+29d779f9e3
+29dde5f12b
+29de7b6579
+29e630bdd0
+29f2332d30
+2a18873352
+2a3824ff31
+2a559dd27f
+2a5c09acbd
+2a63eb1524
+2a6a30a4ea
+2a6d9099d1
+2a821394e3
+2a8c5b1342
+2abc8d66d2
+2ac9ef904a
+2b08f37364
+2b351bfd7d
+2b659a49d7
+2b69ee5c26
+2b6c30bbbd
+2b88561cf2
+2b8b14954e
+2ba621c750
+2bab50f9a7
+2bb00c2434
+2bbde474ef
+2bdd82fb86
+2be06fb855
+2bf545c2f5
+2bffe4cf9a
+2c04b887b7
+2c05209105
+2c0ad8cf39
+2c11fedca8
+2c1a94ebfb
+2c1e8c8e2f
+2c29fabcf1
+2c2c076c01
+2c3ea7ee7d
+2c41fa0648
+2c44bb6d1c
+2c54cfbb78
+2c5537eddf
+2c6e63b7de
+2cb10c6a7e
+2cbcd5ccd1
+2cc5d9c5f6
+2cd01cf915
+2cdbf5f0a7
+2ce660f123
+2cf114677e
+2d01eef98e
+2d03593bdc
+2d183ac8c4
+2d33ad3935
+2d3991d83e
+2d4333577b
+2d4d015c64
+2d8f5e5025
+2d900bdb8e
+2d9a1a1d49
+2db0576a5c
+2dc0838721
+2dcc417f82
+2df005b843
+2df356de14
+2e00393d96
+2e03b8127a
+2e0f886168
+2e2bf37e6d
+2e42410932
+2ea78f46e4
+2ebb017a26
+2ee2edba2a
+2efb07554a
+2f17e4fc1e
+2f2c65c2f3
+2f2d9b33be
+2f309c206b
+2f53822e88
+2f53998171
+2f5b0c89b1
+2f680909e6
+2f710f66bd
+2f724132b9
+2f7e3517ae
+2f96f5fc6f
+2f97d9fecb
+2fbfa431ec
+2fc9520b53
+2fcd9f4c62
+2feb30f208
+2ff7f5744f
+30085a2cc6
+30176e3615
+301f72ee11
+3026bb2f61
+30318465dc
+3054ca937d
+306121e726
+3064ad91e8
+307444a47f
+307bbb7409
+30a20194ab
+30c35c64a4
+30dbdb2cd6
+30fc77d72f
+310021b58b
+3113140ee8
+3150b2ee57
+31539918c4
+318dfe2ce2
+3193da4835
+319f725ad9
+31bbd0d793
+322505c47f
+322b237865
+322da43910
+3245e049fb
+324c4c38f6
+324e35111a
+3252398f09
+327dc4cabf
+328d918c7d
+3290c0de97
+3299ae3116
+32a7cd687b
+33098cedb4
+3332334ac4
+334cb835ac
+3355e056eb
+33639a2847
+3373891cdc
+337975816b
+33e29d7e91
+34046fe4f2
+3424f58959
+34370a710f
+343bc6a65a
+3450382ef7
+3454303a08
+346aacf439
+346e92ff37
+34a5ece7dd
+34b109755a
+34d1b37101
+34dd2c70a7
+34efa703df
+34fbee00a6
+3504df2fda
+35195a56a1
+351c822748
+351cfd6bc5
+3543d8334c
+35573455c7
+35637a827f
+357a710863
+358bf16f9e
+35ab34cc34
+35c6235b8d
+35d01a438a
+3605019d3b
+3609bc3f88
+360e25da17
+36299c687c
+362c5bc56e
+3649228783
+365b0501ea
+365f459863
+369893f3ad
+369c9977e1
+369dde050a
+36c7dac02f
+36d5b1493b
+36f5cc68fd
+3735480d18
+374b479880
+375a49d38f
+375a5c0e09
+376bda9651
+377db65f60
+37c19d1087
+37d4ae24fc
+37ddce7f8b
+37e10d33af
+37e45c6247
+37fa0001e8
+3802d458c0
+382caa3cb4
+383bb93111
+388843df90
+38924f4a7f
+38b00f93d7
+38c197c10e
+38c9c3d801
+38eb2bf67f
+38fe9b3ed1
+390352cced
+390c51b987
+390ca6f1d6
+392bc0f8a1
+392ecb43bd
+3935291688
+3935e63b41
+394454fa9c
+394638fc8b
+39545e20b7
+397abeae8f
+3988074b88
+398f5d5f19
+39bc49a28c
+39befd99fb
+39c3c7bf55
+39d584b09f
+39f6f6ffb1
+3a079fb484
+3a0d3a81b7
+3a1d55d22b
+3a20a7583e
+3a2c1f66e5
+3a33f4d225
+3a3bf84b13
+3a4565e5ec
+3a4e32ed5e
+3a7ad86ce0
+3a7bdde9b8
+3a98867cbe
+3aa3f1c9e8
+3aa7fce8b6
+3aa876887d
+3ab807ded6
+3ab9b1a85a
+3adac8d7da
+3ae1a4016f
+3ae2deaec2
+3ae81609d6
+3af847e62f
+3b23792b84
+3b3b0af2ee
+3b512dad74
+3b6c7988f6
+3b6e983b5b
+3b74a0fc20
+3b7a50b80d
+3b96d3492f
+3b9ad0c5a9
+3b9ba0894a
+3bb4e10ed7
+3bd9a9b515
+3beef45388
+3c019c0a24
+3c090704aa
+3c2784fc0d
+3c47ab95f8
+3c4db32d74
+3c5ff93faf
+3c700f073e
+3c713cbf2f
+3c8320669c
+3c90d225ee
+3cadbcc404
+3cb9be84a5
+3cc37fd487
+3cc6f90cb2
+3cd5e035ef
+3cdf03531b
+3cdf828f59
+3d254b0bca
+3d5aeac5ba
+3d690473e1
+3d69fed2fb
+3d8997aeb6
+3db0d6b07e
+3db1ddb8cf
+3db907ac77
+3dcbc0635b
+3dd48ed55f
+3de4ac4ec4
+3decd63d88
+3e04a6be11
+3e108fb65a
+3e1448b01c
+3e16c19634
+3e2845307e
+3e38336da5
+3e3a819865
+3e3e4be915
+3e680622d7
+3e7d2aeb07
+3e7d8f363d
+3e91f10205
+3ea4c49bbe
+3eb39d11ab
+3ec273c8d5
+3ed3f91271
+3ee062a2fd
+3eede9782c
+3ef2fa99cb
+3efc6e9892
+3f0b0dfddd
+3f0c860359
+3f18728586
+3f3b15f083
+3f45a470ad
+3f4f3bc803
+3fd96c5267
+3fea675fab
+3fee8cbc9f
+3fff16d112
+401888b36c
+4019231330
+402316532d
+402680df52
+404d02e0c0
+40709263a8
+4083cfbe15
+40a96c5cb1
+40b8e50f82
+40f4026bf5
+4100b57a3a
+41059fdd0b
+41124e36de
+4122aba5f9
+413bab0f0d
+4164faee0b
+418035eec9
+4182d51532
+418bb97e10
+41a34c20e7
+41dab05200
+41ff6d5e2a
+420caf0859
+42264230ba
+425a0c96e0
+42da96b87c
+42eb5a5b0f
+42f17cd14d
+42f5c61c49
+42ffdcdee9
+432f9884f9
+43326d9940
+4350f3ab60
+4399ffade3
+43a6c21f37
+43b5555faa
+43d63b752a
+4416bdd6ac
+4444753edd
+444aa274e7
+444d4e0596
+446b8b5f7a
+4478f694bb
+44b1da0d87
+44b4dad8c9
+44b5ece1b9
+44d239b24e
+44eaf8f51e
+44f4f57099
+44f7422af2
+450787ac97
+4523656564
+4536c882e5
+453b65daa4
+454f227427
+45636d806a
+456fb9362e
+457e717a14
+45a89f35e1
+45bf0e947d
+45c36a9eab
+45d9fc1357
+45f8128b97
+4607f6c03c
+46146dfd39
+4620e66b1e
+4625f3f2d3
+462b22f263
+4634736113
+463c0f4fdd
+46565a75f8
+46630b55ae
+466839cb37
+466ba4ae0c
+4680236c9d
+46bf4e8709
+46e18e42f1
+46f5093c59
+47269e0499
+472da1c484
+47354fab09
+4743bb84a7
+474a796272
+4783d2ab87
+479cad5da3
+479f5d7ef6
+47a05fbd1d
+4804ee2767
+4810c3fbca
+482fb439c2
+48375af288
+484ab44de4
+485f3944cd
+4867b84887
+486a8ac57e
+486e69c5bd
+48812cf33e
+4894b3b9ea
+48bd66517d
+48d83b48a4
+49058178b8
+4918d10ff0
+4932911f80
+49405b7900
+49972c2d14
+499bf07002
+49b16e9377
+49c104258e
+49c879f82d
+49e7326789
+49ec3e406a
+49fbf0c98a
+4a0255c865
+4a088fe99a
+4a341402d0
+4a3471bdf5
+4a4b50571c
+4a50f3d2e9
+4a6e3faaa1
+4a7191f08a
+4a86fcfc30
+4a885fa3ef
+4a8af115de
+4aa2e0f865
+4aa9d6527f
+4abb74bb52
+4ae13de1cd
+4af8cb323f
+4b02c272b3
+4b19c529fb
+4b2974eff4
+4b3154c159
+4b54d2587f
+4b556740ff
+4b67aa9ef6
+4b97cc7b8d
+4baa1ed4aa
+4bc8c676bb
+4beaea4dbe
+4bf5763d24
+4bffa92b67
+4c25dfa8ec
+4c397b6fd4
+4c51e75d66
+4c7710908f
+4c9b5017be
+4ca2ffc361
+4cad2e93bc
+4cd427b535
+4cd9a4b1ef
+4cdfe3c2b2
+4cef87b649
+4cf208e9b3
+4cf5bc3e60
+4cfdd73249
+4cff5c9e42
+4d26d41091
+4d5c23c554
+4d67c59727
+4d983cad9f
+4da0d00b55
+4daa179861
+4dadd57153
+4db117e6c5
+4de4ce4dea
+4dfaee19e5
+4dfdd7fab0
+4e3f346aa5
+4e49c2a9c7
+4e4e06a749
+4e70279712
+4e72856cc7
+4e752f8075
+4e7a28907f
+4e824b9247
+4e82b1df57
+4e87a639bc
+4ea77bfd15
+4eb6fc23a2
+4ec9da329e
+4efb9a0720
+4f062fbc63
+4f35be0e0b
+4f37e86797
+4f414dd6e7
+4f424abded
+4f470cc3ae
+4f601d255a
+4f7386a1ab
+4f824d3dcd
+4f827b0751
+4f8db33a13
+4fa160f8a3
+4fa9c30a45
+4facd8f0e8
+4fca07ad01
+4fded94004
+4fdfef4dea
+4feb3ac01f
+4fffec8479
+500c835a86
+50168342bf
+50243cffdc
+5031d5a036
+504dd9c0fd
+50568fbcfb
+5069c7c5b3
+508189ac91
+50b6b3d4b7
+50c6f4fe3e
+50cce40173
+50efbe152f
+50f290b95d
+5104aa1fea
+5110dc72c0
+511e8ecd7f
+513aada14e
+5158d6e985
+5161e1fa57
+51794ddd58
+517d276725
+51a597ee04
+51b37b6d97
+51b5dc30a0
+51e85b347b
+51eea1fdac
+51eef778af
+51f384721c
+521cfadcb4
+52355da42f
+5247d4b160
+524b470fd0
+524cee1534
+5252195e8a
+5255c9ca97
+525928f46f
+526df007a7
+529b12de78
+52c7a3d653
+52c8ec0373
+52d225ed52
+52ee406d9e
+52ff1ccd4a
+53143511e8
+5316d11eb7
+53253f2362
+534a560609
+5352c4a70e
+536096501f
+536b17bcea
+5380eaabff
+5390a43a54
+53af427bb2
+53bf5964ce
+53c30110b5
+53cad8e44a
+53d9c45013
+53e274f1b5
+53e32d21ea
+540850e1c7
+540cb31cfe
+541c4da30f
+541d7935d7
+545468262b
+5458647306
+54657855cd
+547b3fb23b
+5497dc3712
+549c56f1d4
+54a4260bb1
+54b98b8d5e
+54e1054b0f
+54e8867b83
+54ebe34f6e
+5519b4ad13
+551acbffd5
+55341f42da
+5566ab97e1
+556c79bbf2
+5589637cc4
+558aa072f0
+559824b6f6
+55c1764e90
+55eda6c77e
+562d173565
+5665c024cb
+566cef4959
+5675d78833
+5678a91bd8
+567a2b4bd0
+569c282890
+56cc449917
+56e71f3e07
+56f09b9d92
+56fc0e8cf9
+571ca79c71
+57243657cf
+57246af7d1
+57427393e9
+574b682c19
+578f211b86
+5790ac295d
+579393912d
+57a344ab1a
+57bd3bcda4
+57bfb7fa4c
+57c010175e
+57c457cc75
+57c7fc2183
+57d5289a01
+58045fde85
+58163c37cd
+582d463e5c
+5851739c15
+585dd0f208
+587250f3c3
+589e4cc1de
+589f65f5d5
+58a07c17d5
+58adc6d8b6
+58b9bcf656
+58c374917e
+58fc75fd42
+5914c30f05
+59323787d5
+5937b08d69
+594065ddd7
+595a0ceea6
+59623ec40b
+597ff7ef78
+598935ef05
+598c2ad3b2
+59a6459751
+59b175e138
+59bf0a149f
+59d53d1649
+59e3e6fae7
+59fe33e560
+5a13a73fe5
+5a25c22770
+5a4a785006
+5a50640995
+5a75f7a1cf
+5a841e59ad
+5a91c5ab6d
+5ab49d9de0
+5aba1057fe
+5abe46ba6d
+5ac7c88d0c
+5aeb95cc7d
+5af15e4fc3
+5afe381ae4
+5b07b4229d
+5b1001cc4f
+5b1df237d2
+5b263013bf
+5b27d19f0b
+5b48ae16c5
+5b5babc719
+5baaebdf00
+5bab55cdbe
+5bafef6e79
+5bd1f84545
+5bddc3ba25
+5bdf7c20d2
+5bf23bc9d3
+5c01f6171a
+5c021681b7
+5c185cff1d
+5c42aba280
+5c44bf8ab6
+5c4c574894
+5c52fa4662
+5c6ea7dac3
+5c74315dc2
+5c7668855e
+5c83e96778
+5ca36173e4
+5cac477371
+5cb0cb1b2f
+5cb0cfb98f
+5cb49a19cf
+5cbf7dc388
+5d0e07d126
+5d1e24b6e3
+5d663000ff
+5da6b2dc5d
+5de9b90f24
+5e08de0ed7
+5e1011df9a
+5e1ce354fd
+5e35512dd7
+5e418b25f9
+5e4849935a
+5e4ee19663
+5e886ef78f
+5e8d00b974
+5e8d59dc31
+5ed838bd5c
+5edda6ee5a
+5ede4d2f7a
+5ede9767da
+5eec4d9fe5
+5eecf07824
+5eef7ed4f4
+5ef5860ac6
+5ef6573a99
+5f1193e72b
+5f29ced797
+5f32cf521e
+5f51876986
+5f6ebe94a9
+5f6f14977c
+5f808d0d2d
+5fb8aded6a
+5fba90767d
+5fd1c7a3df
+5fd3da9f68
+5fee2570ae
+5ff66140d6
+5ff8b85b53
+600803c0f6
+600be7f53e
+6024888af8
+603189a03c
+6057307f6e
+6061ddbb65
+606c86c455
+60c61cc2e5
+60e51ff1ae
+610e38b751
+61344be2f6
+6135e27185
+614afe7975
+614e571886
+614e7078db
+619812a1a7
+61b481a78b
+61c7172650
+61cf7e40d2
+61d08ef5a1
+61da008958
+61ed178ecb
+61f5d1282c
+61fd977e49
+621584cffe
+625817a927
+625892cf0b
+625b89d28a
+629995af95
+62a0840bb5
+62ad6e121c
+62d6ece152
+62ede7b2da
+62f025e1bc
+6316faaebc
+63281534dc
+634058dda0
+6353f09384
+6363c87314
+636e4872e0
+637681cd6b
+6376d49f31
+6377809ec2
+63936d7de5
+639bddef11
+63d37e9fd3
+63d90c2bae
+63e544a5d6
+63ebbcf874
+63fff40b31
+6406c72e4d
+64148128be
+6419386729
+643092bc41
+644081b88d
+64453cf61d
+644bad9729
+6454f548fd
+645913b63a
+64750b825f
+64a43876b7
+64dd6c83e3
+64e05bf46e
+64f55f1478
+650b0165e4
+651066ed39
+652b67d960
+653821d680
+6538d00d73
+65866dce22
+6589565c8c
+659832db64
+65ab7e1d98
+65b7dda462
+65bd5eb4f5
+65dcf115ab
+65e9825801
+65f9afe51c
+65ff12bcb5
+666b660284
+6671643f31
+668364b372
+66852243cb
+6693a52081
+669b572898
+66e98e78f5
+670f12e88f
+674c12c92d
+675c27208a
+675ed3e1ca
+67741db50a
+678a2357eb
+67b0f4d562
+67cfbff9b1
+67e717d6bd
+67ea169a3b
+67ea809e0e
+681249baa3
+683de643d9
+6846ac20df
+6848e012ef
+684bcd8812
+684dc1c40c
+685a1fa9cf
+686dafaac9
+68807d8601
+6893778c77
+6899d2dabe
+68a2fad4ab
+68cb45fda3
+68cc4a1970
+68dcb40675
+68ea4a8c3d
+68f6e7fbf0
+68fa8300b4
+69023db81f
+6908ccf557
+691a111e7c
+6927723ba5
+692ca0e1a2
+692eb57b63
+69340faa52
+693cbf0c9d
+6942f684ad
+6944fc833b
+69491c0ebf
+695b61a2b0
+6979b4d83f
+697d4fdb02
+69910460a4
+6997636670
+69a436750b
+69aebf7669
+69b8c17047
+69c67f109f
+69e0e7b868
+69ea9c09d1
+69f0af42a6
+6a078cdcc7
+6a37a91708
+6a42176f2e
+6a48e4aea8
+6a5977be3a
+6a5de0535f
+6a80d2e2e5
+6a96c8815d
+6a986084e2
+6aa8e50445
+6ab9dce449
+6abf0ba6b2
+6acc6049d9
+6adb31756c
+6ade215eb0
+6afb7d50e4
+6afd692f1a
+6b0b1044fe
+6b17c67633
+6b1b6ef28b
+6b1e04d00d
+6b2261888d
+6b25d6528a
+6b3a24395c
+6b685eb75b
+6b79be238c
+6b928b7ba6
+6b9c43c25a
+6ba99cc41f
+6bdab62bcd
+6bf2e853b1
+6bf584200f
+6bf95df2b9
+6c0949c51c
+6c11a5f11f
+6c23d89189
+6c4387daf5
+6c4ce479a4
+6c5123e4bc
+6c54265f16
+6c56848429
+6c623fac5f
+6c81b014e9
+6c99ea7c31
+6c9d29d509
+6c9e3b7d1a
+6ca006e283
+6caeb928d6
+6cb2ee722a
+6cbfd32c5e
+6cc791250b
+6cccc985e0
+6d12e30c48
+6d4bf200ad
+6d6d2b8843
+6d6eea5682
+6d7a3d0c21
+6d7efa9b9e
+6da21f5c91
+6da6adabc0
+6dd2827fbb
+6dd36705b9
+6df3637557
+6dfe55e9e5
+6e1a21ba55
+6e2f834767
+6e36e4929a
+6e4f460caf
+6e618d26b6
+6ead4670f7
+6eaff19b9f
+6eb2e1cd9e
+6eb30b3b5a
+6eca26c202
+6ecad29e52
+6ef0b44654
+6efcfe9275
+6f4789045c
+6f49f522ef
+6f67d7c4c4
+6f96e91d81
+6fc6fce380
+6fc9b44c00
+6fce7f3226
+6fdf1ca888
+702fd8b729
+70405185d2
+7053e4f41e
+707bf4ce41
+7082544248
+708535b72a
+7094ac0f60
+70a6b875fa
+70c3e97e41
+7106b020ab
+711dce6fe2
+7136a4453f
+7143fb084f
+714d902095
+7151c53b32
+715357be94
+7163b8085f
+716df1aa59
+71caded286
+71d2665f35
+71d67b9e19
+71e06dda39
+720b398b9c
+720e3fa04c
+720e7a5f1e
+721bb6f2cb
+722803f4f2
+72552a07c9
+726243a205
+72690ef572
+728cda9b65
+728e81c319
+72a810a799
+72acb8cdf6
+72b01281f9
+72cac683e4
+72cadebbce
+72cae058a5
+72d8dba870
+72e8d1c1ff
+72edc08285
+72f04f1a38
+731b825695
+7320b49b13
+732626383b
+732df1eb05
+73329902ab
+733798921e
+733824d431
+734ea0d7fb
+735a7cf7b9
+7367a42892
+7368d5c053
+73c6ae7711
+73e1852735
+73e4e5cc74
+73eac9156b
+73f8441a88
+7419e2ab3f
+74267f68b9
+7435690c8c
+747c44785c
+747f1b1f2f
+748b2d5c01
+74d4cee0a4
+74ec2b3073
+74ef677020
+750be4c4d8
+75172d4ac8
+75285a7eb1
+75504539c3
+7550949b1d
+7551cbd537
+75595b453d
+7559b4b0ec
+755bd1fbeb
+756f76f74d
+7570ca7f3c
+757a69746e
+757cac96c6
+7584129dc3
+75a058dbcd
+75b09ce005
+75cae39a8f
+75cee6caf0
+75cf58fb2c
+75d5c2f32a
+75eaf5669d
+75f7937438
+75f99bd3b3
+75fa586876
+7613df1f84
+762e1b3487
+76379a3e69
+764271f0f3
+764503c499
+7660005554
+7666351b84
+76693db153
+767856368b
+768671f652
+768802b80d
+76962c7ed2
+76a75f4eee
+76b90809f7
+770a441457
+772a0fa402
+772f2ffc3e
+774f6c2175
+77610860e0
+777e58ff3d
+77920f1708
+7799df28e7
+779e847a9a
+77ba4edc72
+77c834dc43
+77d8aa8691
+77e7f38f4d
+77eea6845e
+7806308f33
+78254660ea
+7828af8bff
+784398620a
+784d201b12
+78613981ed
+78896c6baf
+78aff3ebc0
+78c7c03716
+78d3676361
+78e29dd4c3
+78f1a1a54f
+79208585cd
+792218456c
+7923bad550
+794e6fc49f
+796e6762ce
+797cd21f71
+79921b21c2
+79a5778027
+79bc006280
+79bf95e624
+79d9e00c55
+79e20fc008
+79e9db913e
+79f014085e
+79fcbb433a
+7a13a5dfaa
+7a14bc9a36
+7a3c535f70
+7a446a51e9
+7a56e759c5
+7a5f46198d
+7a626ec98d
+7a802264c4
+7a8b5456ca
+7abdff3086
+7aecf9f7ac
+7b0fd09c28
+7b18b3db87
+7b39fe7371
+7b49e03d4c
+7b5388c9f1
+7b5cf7837f
+7b733d31d8
+7b74fd7b98
+7b918ccb8a
+7ba3ce3485
+7bb0abc031
+7bb5bb25cd
+7bb7dac673
+7bc7761b8c
+7bf3820566
+7c03a18ec1
+7c078f211b
+7c37d7991a
+7c4ec17eff
+7c649c2aaf
+7c73340ab7
+7c78a2266d
+7c88ce3c5b
+7ca6843a72
+7cc9258dee
+7cec7296ae
+7d0ffa68a4
+7d11b4450f
+7d1333fcbe
+7d18074fef
+7d18c8c716
+7d508fb027
+7d55f791f0
+7d74e3c2f6
+7d783f67a9
+7d83a5d854
+7dd409947e
+7de45f75e5
+7e0cd25696
+7e1922575c
+7e1e3bbcc1
+7e24023274
+7e2f212fd3
+7e6d1cc1f4
+7e7cdcb284
+7e9b6bef69
+7ea5b49283
+7eb2605d96
+7eb26b8485
+7ecd1f0c69
+7f02b3cfe2
+7f1723f0d5
+7f21063c3a
+7f3658460e
+7f54132e48
+7f559f9d4a
+7f5faedf8b
+7f838baf2b
+7fa5f527e3
+7ff84d66dd
+802b45c8c4
+804382b1ad
+804c558adb
+804f6338a4
+8056117b89
+806b6223ab
+8088bda461
+80b790703b
+80c4a94706
+80ce2e351b
+80db581acd
+80e12193df
+80e41b608f
+80f16b016d
+81541b3725
+8175486e6a
+8179095000
+8193671178
+81a58d2c6b
+81aa1286fb
+81dffd30fb
+8200245704
+823e7a86e8
+824973babb
+824ca5538f
+827171a845
+8273a03530
+827cf4f886
+82b865c7dd
+82c1517708
+82d15514d6
+82e117b900
+82fec06574
+832b5ef379
+83424c9fbf
+8345358fb8
+834b50b31b
+835e3b67d7
+836ea92b15
+837c618777
+838eb3bd89
+839381063f
+839bc71489
+83a8151377
+83ae88d217
+83ca8bcad0
+83ce590d7f
+83d3130ba0
+83d40bcba5
+83daba503a
+83de906ec0
+84044f37f3
+84696b5a5e
+84752191a3
+847eeeb2e0
+848e7835a0
+84a4b29286
+84a4bf147d
+84be115c09
+84d95c4350
+84e0922cf7
+84f0cfc665
+8515f6db22
+851f2f32c1
+852a4d6067
+854c48b02a
+857a387c86
+859633d56a
+85a4f4a639
+85ab85510c
+85b1eda0d9
+85dc1041c6
+85e081f3c7
+85f75187ad
+8604bb2b75
+860745b042
+863b4049d7
+8643de22d0
+8647d06439
+864ffce4fe
+8662d9441a
+8666521b13
+868d6a0685
+869fa45998
+86a40b655d
+86a8ae4223
+86b2180703
+86c85d27df
+86d3755680
+86e61829a1
+871015806c
+871e409c5c
+8744b861ce
+8749369ba0
+878a299541
+8792c193a0
+8799ab0118
+87d1f7d741
+882b9e4500
+885673ea17
+8859dedf41
+8873ab2806
+887a93b198
+8883e991a9
+8891aa6dfa
+8899d8cbcd
+88b8274d67
+88d3b80af6
+88ede83da2
+88f345941b
+890976d6da
+8909bde9ab
+8929c7d5d9
+89363acf76
+89379487e0
+8939db6354
+893f658345
+8953138465
+895c96d671
+895cbf96f9
+895e8b29a7
+898fa256c8
+89986c60be
+89b874547b
+89bdb021d5
+89c802ff9c
+89d6336c2b
+89ebb27334
+8a27e2407c
+8a31f7bca5
+8a4a2fc105
+8a5d6c619c
+8a75ad7924
+8aa817e4ed
+8aad0591eb
+8aca214360
+8ae168c71b
+8b0cfbab97
+8b3645d826
+8b3805dbd4
+8b473f0f5d
+8b4f6d1186
+8b4fb018b7
+8b518ee936
+8b523bdfd6
+8b52fb5fba
+8b91036e5c
+8b99a77ac5
+8ba04b1e7b
+8ba782192f
+8bbeaad78b
+8bd1b45776
+8bd7a2dda6
+8bdb091ccf
+8be56f165d
+8be950d00f
+8bf84e7d45
+8bffc4374b
+8bfff50747
+8c09867481
+8c0a3251c3
+8c3015cccb
+8c469815cf
+8c9ccfedc7
+8ca1af9f3c
+8ca3f6e6c1
+8ca6a4f60f
+8cac6900fe
+8cba221a1e
+8cbbe62ccd
+8d064b29e2
+8d167e7c08
+8d4ab94e1c
+8d81f6f899
+8d87897d66
+8dcccd2bd2
+8dcfb878a8
+8dd3ab71b9
+8dda6bf10f
+8ddd51ca94
+8dea22c533
+8def5bd3bf
+8e1848197c
+8e3a83cf2d
+8e478e73f3
+8e98ae3c84
+8ea6687ab0
+8eb0d315c1
+8ec10891f9
+8ec3065ec2
+8ecf51a971
+8eddbab9f7
+8ee198467a
+8ee2368f40
+8ef595ce82
+8f0a653ad7
+8f1204a732
+8f1600f7f6
+8f16366707
+8f1ce0a411
+8f2e05e814
+8f320d0e09
+8f3b4a84ad
+8f3fdad3da
+8f5d3622d8
+8f62a2c633
+8f81c9405a
+8f8c974d53
+8f918598b6
+8ff61619f6
+9002761b41
+90107941f3
+90118a42ee
+902bc16b37
+903e87e0d6
+9041a0f489
+9047bf3222
+9057bfa502
+90617b0954
+9076f4b6db
+9077e69b08
+909655b4a6
+909c2eca88
+909dbd1b76
+90bc4a319a
+90c7a87887
+90cc785ddd
+90d300f09b
+9101ea9b1b
+9108130458
+911ac9979b
+9151cad9b5
+9153762797
+91634ee0c9
+916942666f
+9198cfb4ea
+919ac864d6
+91b67d58d4
+91bb8df281
+91be106477
+91c33b4290
+91ca7dd9f3
+91d095f869
+91f107082e
+920329dd5e
+920c959958
+92128fbf4b
+9223dacb40
+923137bb7f
+9268e1f88a
+927647fe08
+9276f5ba47
+92a28cd233
+92b5c1fc6d
+92c46be756
+92dabbe3a0
+92e3159361
+92ebab216a
+934bdc2893
+9359174efc
+935d97dd2f
+935feaba1b
+93901858ee
+939378f6d6
+939bdf742e
+93a22bee7e
+93da9aeddf
+93e2feacce
+93e6f1fdf9
+93e811e393
+93e85d8fd3
+93f623d716
+93ff35e801
+94031f12f2
+94091a4873
+94125907e3
+9418653742
+941c870569
+94209c86f0
+9437c715eb
+9445c3eca2
+9467c8617c
+946d71fb5d
+948f3ae6fb
+9498baa359
+94a33abeab
+94bf1af5e3
+94cf3a8025
+94db712ac8
+94e4b66cff
+94e76cbaf6
+950be91db1
+952058e2d0
+952633c37f
+952ec313fe
+9533fc037c
+9574b81269
+9579b73761
+957f7bc48b
+958073d2b0
+9582e0eb33
+9584092d0b
+95b58b8004
+95bd88da55
+95f74a9959
+962781c601
+962f045bf5
+964ad23b44
+967b90590e
+967bffe201
+96825c4714
+968492136a
+9684ef9d64
+968c41829e
+96a856ef9a
+96dfc49961
+96e1a5b4f8
+96e6ff0917
+96fb88e9d7
+96fbe5fc23
+96fc924050
+9715cc83dc
+9720eff40f
+972c187c0d
+97476eb38d
+97659ed431
+9773492949
+97756b264f
+977bff0d10
+97ab569ff3
+97ba838008
+97d9d008c7
+97e59f09fa
+97eb642e56
+98043e2d14
+981ff580cf
+983e66cbfc
+984f0f1c36
+98595f2bb4
+985c3be474
+9869a12362
+986b5a5e18
+9877af5063
+98911292da
+9893a3cf77
+9893d9202d
+98a8b06e7f
+98ac6f93d9
+98b6974d12
+98ba3c9417
+98c7c00a19
+98d044f206
+98e909f9d1
+98fe7f0410
+990f2742c7
+992bd0779a
+994b9b47ba
+9955b76bf5
+9966f3adac
+997117a654
+999d53d841
+99c04108d3
+99c4277aee
+99c6b1acf2
+99dc8bb20b
+99fcba71e5
+99fecd4efb
+9a02c70ba2
+9a08e7a6f8
+9a2f2c0f86
+9a3254a76e
+9a3570a020
+9a39112493
+9a4e9fd399
+9a50af4bfb
+9a68631d24
+9a72318dbf
+9a767493b7
+9a7fc1548b
+9a84ccf6a7
+9a9c0e15b7
+9adf06d89b
+9b22b54ee4
+9b473fc8fe
+9b4f081782
+9b997664ba
+9bc454e109
+9bccfd04de
+9bce4583a2
+9bebf1b87f
+9bfc50d261
+9c166c86ff
+9c293ef4d7
+9c29c047b0
+9c3bc2e2a7
+9c3ce23bd1
+9c404cac0c
+9c5180d23a
+9c7feca6e4
+9caa49d3ff
+9cb2f1b646
+9ce6f765c3
+9cfee34031
+9d01f08ec6
+9d04c280b8
+9d12ceaddc
+9d15f8cb3c
+9d2101e9bf
+9d407c3aeb
+9ddefc6165
+9df0b1e298
+9e16f115d8
+9e249b4982
+9e29b1982c
+9e493e4773
+9e4c752cd0
+9e4de40671
+9e6319faeb
+9e6ddbb52d
+9eadcea74f
+9ecec5f8ea
+9efb47b595
+9f30bfe61e
+9f3734c3a4
+9f5b858101
+9f66640cda
+9f913803e9
+9f97bc74c8
+9fbad86e20
+9fc2bad316
+9fc5c3af78
+9fcb310255
+9fcc256871
+9fd2fd4d47
+a0071ae316
+a023141022
+a046399a74
+a066e739c1
+a06722ba82
+a07a15dd64
+a07b47f694
+a09c39472e
+a0b208fe2e
+a0b61c959e
+a0bc6c611d
+a0e6da5ba2
+a1193d6490
+a14ef483ff
+a14f709908
+a15ccc5658
+a16062456f
+a174e8d989
+a177c2733c
+a17c62e764
+a18ad065fc
+a1aaf63216
+a1bb65fb91
+a1bd8e5349
+a1dfdd0cac
+a2052e4f6c
+a20fd34693
+a21ffe4d81
+a22349e647
+a235d01ec1
+a24f63e8a2
+a2554c9f6d
+a263ce8a87
+a29bfc29ec
+a2a80072d4
+a2a800ab63
+a2bcd10a33
+a2bdaff3b0
+a2c146ab0d
+a2c996e429
+a2dc51ebe8
+a2e6608bfa
+a2f2a55f01
+a301869dea
+a31fccd2cc
+a34f440f33
+a35e0206da
+a36bdc4cab
+a36e8c79d8
+a378053b20
+a37db3a2b3
+a38950ebc2
+a39a0eb433
+a39c9bca52
+a3a945dc8c
+a3b40a0c1e
+a3b8588550
+a3c502bec3
+a3f2878017
+a3f4d58010
+a3f51855c3
+a402dc0dfe
+a4065a7eda
+a412bb2fef
+a416b56b53
+a41ec95906
+a43299e362
+a4757bd7af
+a48c53c454
+a49dcf9ad5
+a4a506521f
+a4ba7753d9
+a4bac06849
+a4f05d681c
+a50c10060f
+a50eb5a0ea
+a5122c6ec6
+a522b1aa79
+a590915345
+a5b5b59139
+a5b77abe43
+a5c2b2c3e1
+a5cd17bb11
+a5da03aef1
+a5dd11de0d
+a5ea2b93b6
+a5eaeac80b
+a5ec5b0265
+a5f350a87e
+a5f472caf4
+a6027a53cf
+a61715bb1b
+a61cf4389d
+a61d9bbd9b
+a6470dbbf5
+a64a40f3eb
+a653d5c23b
+a65bd23cb5
+a66e0b7ad4
+a66fc5053c
+a68259572b
+a6a810a92c
+a6bc36937f
+a6c3a374e9
+a6d8a4228d
+a6f4e0817f
+a71e0481f5
+a7203deb2d
+a7392d4438
+a73d3c3902
+a7491f1578
+a74b9ca19c
+a77b7a91df
+a78195a5f5
+a78758d4ce
+a7e6d6c29a
+a800d85e88
+a832fa8790
+a83d06410d
+a8999af004
+a8f78125b9
+a907b18df1
+a919392446
+a965504e88
+a96b84b8d2
+a973f239cd
+a977126596
+a9804f2a08
+a984e56893
+a99738f24c
+a99bdd0079
+a9c9c1517e
+a9cbf9c41b
+a9e42e3c0c
+aa07b7c1c0
+aa175e5ec7
+aa1a338630
+aa27d7b868
+aa45f1caaf
+aa49e46432
+aa51934e1b
+aa6287bb6c
+aa6d999971
+aa85278334
+aab33f0e2a
+aaba004362
+aade4cf385
+aae78feda4
+aaed233bf3
+aaff16c2db
+ab199e8dfb
+ab23b78715
+ab2e1b5577
+ab33a18ded
+ab45078265
+ab56201494
+ab90f0d24b
+abab2e6c20
+abb50c8697
+abbe2d15a0
+abbe73cd21
+abe61a11bb
+abeae8ce21
+ac2b431d5f
+ac2cb1b9eb
+ac31fcd6d0
+ac3d3a126d
+ac46bd8087
+ac783ef388
+acb73e4297
+acbf581760
+accafc3531
+acf2c4b745
+acf44293a2
+acf736a27b
+acff336758
+ad1fe56886
+ad28f9b9d9
+ad2de9f80e
+ad397527b2
+ad3d1cfbcb
+ad3fada9d9
+ad4108ee8e
+ad54468654
+ad573f7d31
+ad6255bc29
+ad65ebaa07
+ad97cc064a
+adabbd1cc4
+adb0b5a270
+adc648f890
+add21ee467
+adfd15ceef
+adfdd52eac
+ae01cdab63
+ae0b50ff4f
+ae13ee3d70
+ae1bcbd423
+ae20d09dea
+ae2cecf5f6
+ae3bc4a0ef
+ae499c7514
+ae628f2cd4
+ae8545d581
+ae93214fe6
+ae9cd16dbf
+aeba9ac967
+aebb242b5c
+aed4e0b4c4
+aedd71f125
+aef3e2cb0e
+af0b54cee3
+af3de54c7a
+af5fd24a36
+af8826d084
+af8ad72057
+afb71e22c5
+afcb331e1f
+afe1a35c1e
+b01080b5d3
+b05ad0d345
+b0623a6232
+b064dbd4b7
+b06ed37831
+b06f5888e6
+b08dcc490e
+b0a68228dc
+b0aece727f
+b0b0731606
+b0c7f11f9f
+b0cca8b830
+b0dd580a89
+b0de66ca08
+b0df7c5c5c
+b0f5295608
+b11099eb09
+b132a53086
+b1399fac64
+b13abc0c69
+b1457e3b5e
+b15bf4453b
+b179c4a82d
+b17ee70e8c
+b190b1aa65
+b19b3e22c0
+b19c561fab
+b1d1cd2e6e
+b1d7c03927
+b1d7fe2753
+b1f540a4bd
+b1fc9c64e1
+b1fcbb3ced
+b220939e93
+b22099b419
+b241e95235
+b2432ae86d
+b2456267df
+b247940d01
+b24af1c35c
+b24f600420
+b24fe36b2a
+b258fb0b7d
+b26b219919
+b26d9904de
+b274456ce1
+b27b28d581
+b2a26bc912
+b2a9c51e1b
+b2b0baf470
+b2b2756fe7
+b2ce7699e3
+b2edc76bd2
+b2f6b52100
+b30bf47bcd
+b34105a4e9
+b372a82edf
+b3779a1962
+b379ab4ff5
+b37a1d69e3
+b37c01396e
+b382b09e25
+b3996e4ba5
+b3d9ca2aee
+b3dde1e1e9
+b3eb7f05eb
+b40b25055c
+b41e0f1f19
+b44e32a42b
+b4805ae9cd
+b4807569a5
+b48efceb3e
+b493c25c7f
+b4b565aba1
+b4b715a15b
+b4d0c90bf4
+b4d84bc371
+b4e5ad97aa
+b4eaea9e6b
+b50f4b90d5
+b53f675641
+b54278cd43
+b554843889
+b573c0677a
+b58d853734
+b5943b18ab
+b5a09a83f3
+b5aae1fe25
+b5b9da5364
+b5eb64d419
+b5ebb1d000
+b5f1c0c96a
+b5f7fece90
+b6070de1bb
+b60a76fe73
+b61f998772
+b62c943664
+b63094ba0c
+b64fca8100
+b673e7dcfb
+b678b7db00
+b68fc1b217
+b69926d9fa
+b6a1df3764
+b6a4859528
+b6b4738b78
+b6b4f847b7
+b6b8d502d4
+b6bb00e366
+b6d65a9eef
+b6d79a0845
+b6e9ec577f
+b6ec609f7b
+b6f92a308d
+b70a2c0ab1
+b70a5a0d50
+b70c052f2f
+b70d231781
+b72ac6e10b
+b7302d8226
+b73867d769
+b751e767f2
+b76df6e059
+b77e5eddef
+b7a2c2c83c
+b7bcbe6466
+b7c2a469c4
+b7d69da8f0
+b7f31b7c36
+b7f675fb98
+b7fb871660
+b82e5ad1c9
+b841cfb932
+b84b8ae665
+b85b78ac2b
+b86c17caa6
+b86e50d82d
+b871db031a
+b87d56925a
+b8aaa59b75
+b8c03d1091
+b8c3210036
+b8e16df00b
+b8f34cf72e
+b8fb75864e
+b9004db86c
+b9166cbae9
+b920b256a6
+b938d79dff
+b93963f214
+b941aef1a0
+b94d34d14e
+b964c57da4
+b96a95bc7a
+b96c57d2c7
+b9b6bdde0c
+b9bcb3e0f2
+b9d3b92169
+b9dd4b306c
+b9f43ef41e
+ba1f03c811
+ba3a775d7b
+ba3c7f2a31
+ba3fcd417d
+ba5e1f4faa
+ba795f3089
+ba8a291e6a
+ba98512f97
+bac9db04f5
+baedae3442
+baff40d29d
+bb04e28695
+bb1b0ee89f
+bb1c770fe7
+bb1fc34f99
+bb2d220506
+bb334e5cdb
+bb337f9830
+bb721eb9aa
+bb87ff58bd
+bb89a6b18a
+bbaa9a036a
+bbb4302dda
+bbd31510cf
+bbe0256a75
+bc141b9ad5
+bc17ab8a99
+bc318160de
+bc3b9ee033
+bc4240b43c
+bc4ce49105
+bc4f71372d
+bc6b8d6371
+bcaad44ad7
+bcc241b081
+bcc5d8095e
+bcd1d39afb
+bd0d849da4
+bd0e9ed437
+bd2c94730f
+bd321d2be6
+bd3ec46511
+bd5b2e2848
+bd7e02b139
+bd96f9943a
+bda224cb25
+bda4a82837
+bdb74e333f
+bdccd69dde
+bddcc15521
+be116aab29
+be15e18f1e
+be1a284edb
+be2a367a7b
+be376082d0
+be3e3cffbd
+be5d1d89a0
+be8b72fe37
+be9b29e08e
+bea1f6e62c
+bea83281b5
+beb921a4c9
+bec5e9edcd
+beeb8a3f92
+bf2232b58d
+bf28751739
+bf443804e8
+bf461df850
+bf5374f122
+bf551a6f60
+bf8d0f5ada
+bf961167a6
+bfab1ad8f9
+bfcb05d88d
+bfd8f6e6c9
+bfd91d0742
+bfe262322f
+c013f42ed7
+c01878083f
+c01faff1ed
+c046fd0edb
+c053e35f97
+c079a6482d
+c0847b521a
+c0a1e06710
+c0e8d4635c
+c0e973ad85
+c0f49c6579
+c0f5b222d7
+c10d07c90d
+c1268d998c
+c130c3fc0c
+c14826ad5e
+c15b922281
+c16f09cb63
+c18e19d922
+c1c830a735
+c1e8aeea45
+c20a5ccc99
+c20fd5e597
+c219d6f8dc
+c2406ae462
+c26f7b5824
+c279e641ee
+c27adaeac5
+c2a35c1cda
+c2a9903b8b
+c2b62567c1
+c2b974ec8c
+c2baaff7bf
+c2be6900f2
+c304dd44d5
+c307f33da2
+c30a7b62c9
+c3128733ee
+c31fa6c598
+c325c8201e
+c32d4aa5d1
+c33f28249a
+c34365e2d7
+c3457af795
+c34d120a88
+c3509e728d
+c35e4fa6c4
+c36240d96f
+c3641dfc5a
+c37b17a4a9
+c39559ddf6
+c3b0c6e180
+c3b3d82e6c
+c3be369fdb
+c3bf1e40c2
+c3c760b015
+c3dd38bf98
+c3e4274614
+c3edc48cbd
+c41e6587f5
+c4272227b0
+c42917fe82
+c438858117
+c44676563f
+c44beb7472
+c45411dacb
+c4571bedc8
+c46deb2956
+c479ee052e
+c47d551843
+c49f07d46d
+c4cc40c1fc
+c4f256f5d5
+c4f5b1ddcc
+c4ff9b4885
+c52bce43db
+c544da6854
+c55784c766
+c557b69fbf
+c593a3f7ab
+c598faa682
+c5ab1f09c8
+c5b6da8602
+c5b9128d94
+c5e845c6b7
+c5fba7b341
+c60897f093
+c61fe6ed7c
+c62188c536
+c64035b2e2
+c69689f177
+c6a12c131f
+c6bb6d2d5c
+c6c18e860f
+c6d9526e0d
+c6e55c33f0
+c7030b28bd
+c70682c7cc
+c70f9be8c5
+c71f30d7b6
+c73c8e747f
+c760eeb8b3
+c7637cab0a
+c7a1a17308
+c7bf937af5
+c7c2860db3
+c7cef4aee2
+c7ebfc5d57
+c813dcf13c
+c82235a49a
+c82a7619a1
+c82ecb90cb
+c844f03dc7
+c8557963f3
+c89147e6e8
+c8a46ff0c8
+c8ab107dd5
+c8b869a04a
+c8c7b306a6
+c8c8b28781
+c8d79e3163
+c8edab0415
+c8f494f416
+c8f6cba9fd
+c909ceea97
+c9188f4980
+c922365dd4
+c92c8c3c75
+c937eb0b83
+c94b31b5e5
+c95cd17749
+c96379c03c
+c96465ee65
+c965afa713
+c9734b451f
+c9862d82dc
+c98b6fe013
+c9999b7c48
+c99e92aaf0
+c9b3a8fbda
+c9bf64e965
+c9c3cb3797
+c9d1c60cd0
+c9de9c22c4
+ca1828fa54
+ca346f17eb
+ca3787d3d3
+ca4b99cbac
+ca91c69e3b
+ca91e99105
+caa8e97f81
+caac5807f8
+cabba242c2
+cad5a656a9
+cad673e375
+cad8a85930
+cae7b0a02b
+cae7ef3184
+caeb6b6cbb
+caecf0a5db
+cb15312003
+cb2e35d610
+cb35a87504
+cb3f22b0cf
+cbb410da64
+cc8728052e
+cc892997b8
+cce03c2a9b
+cd47a23e31
+cd4dc03dc0
+cd5ae611da
+cd603bb9d1
+cd8f49734c
+cdc6b1c032
+cdcfe008ad
+cdd57027c2
+ce1af99b4b
+ce1bc5743a
+ce25872021
+ce2776f78f
+ce49b1f474
+ce4f0a266f
+ce5641b195
+ce6866aa19
+ce712ed3c9
+ce7d1c8117
+ce7dbeaa88
+ce9b015a5e
+cea7697b25
+cebbd826cf
+cec3415361
+cec41ad4f4
+ced49d26df
+ced7705ab2
+cef824a1e1
+cf13f5c95a
+cf4376a52d
+cf85ab28b5
+cfc2e50b9d
+cfcd571fff
+cfd9d4ae47
+cfda2dcce5
+cff035928b
+cff8191891
+d01608c2a5
+d01a8f1f83
+d021d68bca
+d04258ca14
+d0483573dc
+d04a90aaff
+d05279c0bd
+d0696bd5fc
+d072fda75b
+d0a83bcd9f
+d0ab39112e
+d0acde820f
+d0b4442c71
+d0c65e9e95
+d0fb600c73
+d107a1457c
+d123d674c1
+d14d1e9289
+d154e3388e
+d177e9878a
+d1802f69f8
+d182c4483a
+d195d31128
+d200838929
+d205e3cff5
+d247420c4c
+d2484bff33
+d26f6ed9b0
+d280fcd1cb
+d2857f0faa
+d292a50c7f
+d295ea2dc7
+d2a58b4fa6
+d2b026739a
+d2ebe0890f
+d2ede5d862
+d301ca58cc
+d3069da8bb
+d343d4a77d
+d355e634ef
+d367fb5253
+d36d16358e
+d38bc77e2c
+d38d1679e2
+d3932ad4bd
+d3987b2930
+d39934abe3
+d3ae1c3f4c
+d3b088e593
+d3e6e05e16
+d3eefae7c5
+d3f55f5ab8
+d3f5c309cc
+d4034a7fdf
+d4193011f3
+d429c67630
+d42c0ff975
+d44a764409
+d44e6acd1d
+d45158c175
+d454e8444f
+d45f62717e
+d48ebdcf74
+d49ab52a25
+d4a607ad81
+d4b063c7db
+d4da13e9ba
+d4dd1a7d00
+d4f4f7c9c3
+d521aba02e
+d535bb1b97
+d53b955f78
+d55cb7a205
+d55f247a45
+d5695544d8
+d5853d9b8b
+d5b6c6d94a
+d5cae12834
+d5df027f0c
+d5ee40e5d0
+d600046f73
+d632fd3510
+d6476cad55
+d65a7bae86
+d664c89912
+d689658f06
+d6917db4be
+d69967143e
+d699d3d798
+d69f757a3f
+d6ac0e065c
+d6c02bfda5
+d6c1b5749e
+d6e12ef6cc
+d6eed152c4
+d6faaaf726
+d704766646
+d708e1350c
+d7135cf104
+d7157a9f44
+d719cf9316
+d724134cfd
+d73a60a244
+d7411662da
+d74875ea7c
+d756f5a694
+d7572b7d8a
+d763bd6d96
+d7697c8b13
+d7797196b4
+d79c834768
+d7b34e5d73
+d7bb6b37a7
+d7c7e064a6
+d7fbf545b3
+d82a0aa15b
+d847e24abd
+d8596701b7
+d86101499c
+d87069ba86
+d87160957b
+d874654b52
+d88a403092
+d8aee40f3f
+d8e77a222d
+d8eb07c381
+d9010348a1
+d90e3cf281
+d92532c7b2
+d927fae122
+d95707bca8
+d973b31c00
+d991cb471d
+d992c69d37
+d99d770820
+d9b63abc11
+d9db6f1983
+d9e52be2d2
+d9edc82650
+da01070697
+da070ea4b7
+da080507b9
+da0e944cc4
+da28d94ff4
+da5d78b9d1
+da6003fc72
+da690fee9f
+da6c68708f
+da7a816676
+dac361e828
+dac71659b8
+dad980385d
+daebc12b77
+db0968cdd3
+db231a7100
+db59282ace
+db7f267c3f
+dba35b87fd
+dbba735a50
+dbca076acd
+dbd66dc3ac
+dbdc3c292b
+dbf4a5b32b
+dbfc417d28
+dc1745e0a2
+dc32a44804
+dc34b35e30
+dc504a4f79
+dc704dd647
+dc71bc6918
+dc7771b3be
+dcf8c93617
+dd0f4c9fb9
+dd415df125
+dd601f9a3f
+dd61d903df
+dd77583736
+dd8636bd8b
+dd9fe6c6ac
+ddb2da4c14
+ddcd450d47
+dde8e67fb4
+ddfc3f04d3
+de2ab79dfa
+de2f35b2fd
+de30990a51
+de36b216da
+de37403340
+de46e4943b
+de4ddbccb1
+de5e480f05
+de6a9382ca
+de74a601d3
+de827c510d
+ded6069f7b
+defb71c741
+df01f277f1
+df05214b82
+df0638b0a0
+df11931ffe
+df1b0e4620
+df20a8650d
+df2bc56d7c
+df365282c6
+df39a0d9df
+df3c430c24
+df5536cfb9
+df59cfd91d
+df5e2152b3
+df741313c9
+df7626172f
+df8ad5deb9
+df96aa609a
+df9705605c
+df9c91c4da
+dfc0d3d27a
+dfdbf91a99
+e00baaae9b
+e0a938c6e7
+e0b2ceee6f
+e0bdb5dfae
+e0be1f6e17
+e0c478f775
+e0de82caa7
+e0f217dd59
+e0f7208874
+e0fb58395e
+e1194c2e9d
+e11adcd05d
+e128124b9d
+e1495354e4
+e1561d6d4b
+e158805399
+e16945b951
+e19edcd34b
+e1a1544285
+e1ab7957f4
+e1d26d35be
+e1e957085b
+e1f14510fa
+e214b160f4
+e2167379b8
+e21acb20ab
+e221105579
+e22ddf8a1b
+e22de45950
+e22ffc469b
+e23cca5244
+e252f46f0b
+e25fa6cf39
+e26e486026
+e275760245
+e27bbedbfe
+e29e9868a8
+e2b37ff8af
+e2b608d309
+e2bef4da9a
+e2c87a6421
+e2ea25542c
+e2fb1d6497
+e2fcc99117
+e33c18412a
+e348377191
+e352cb59c8
+e36ac982f0
+e391bc981e
+e39e3e0a06
+e3bf38265f
+e3d5b2cd21
+e3d60e82d5
+e3e3245492
+e3e4134877
+e3f4635e03
+e4004ee048
+e402d1afa5
+e415093d27
+e41ceb5d81
+e424653b78
+e42b6d3dbb
+e42d60f0d4
+e436d0ff1e
+e43d7ae2c5
+e4428801bc
+e44e0b4917
+e470345ede
+e48e8b4263
+e4922e3726
+e4936852bb
+e495f32c60
+e499228f26
+e4af66e163
+e4b2095f58
+e4d19c8283
+e4d4872dab
+e4e2983570
+e4eaa63aab
+e4ef0a3a34
+e4f8e5f46e
+e4ffb6d0dd
+e53e21aa02
+e57f4f668b
+e588433c1e
+e597442c99
+e5abc0e96b
+e5be628030
+e5ce96a55d
+e5d6b70a9f
+e5fde1574c
+e625e1d27b
+e6261d2348
+e6267d46bc
+e6295f223f
+e63463d8c6
+e6387bd1e0
+e653883384
+e65f134e0b
+e668ef5664
+e672ccd250
+e674510b20
+e676107765
+e699da0cdf
+e6be243065
+e6deab5e0b
+e6f065f2b9
+e71629e7b5
+e72a7d7b0b
+e72f6104e1
+e75a466eea
+e76c55933f
+e7784ec8ad
+e78922e5e6
+e78d450a9c
+e7c6354e77
+e7c8de1fce
+e7ea10db28
+e803918710
+e8073a140b
+e828dd02db
+e845994987
+e8485a2615
+e85c5118a7
+e88b6736e4
+e8962324e3
+e8b3018d36
+e8cee8bf0b
+e8d97ebece
+e8da49ea6a
+e8ed1a3ccf
+e8f7904326
+e8f8341dec
+e8fa21eb13
+e90c10fc4c
+e914b8cac8
+e92b6bfea4
+e92e1b7623
+e93f83e512
+e9422ad240
+e9460b55f9
+e9502628f6
+e950befd5f
+e9582bdd1b
+e95e5afe0f
+e97cfac475
+e98d57d99c
+e98eda8978
+e99706b555
+e9bc0760ba
+e9d3c78bf3
+e9ec1b7ea8
+ea065cc205
+ea138b6617
+ea16d3fd48
+ea2545d64b
+ea286a581c
+ea320da917
+ea345f3627
+ea3b94a591
+ea444a37eb
+ea4a01216b
+ea5672ffa8
+eaa99191cb
+eaab4d746c
+eac7a59bc1
+ead5d3835a
+eaec65cfa7
+eaed1a87be
+eb2f821c6f
+eb383cb82e
+eb6992fe02
+eb6ac20a01
+eb6d7ab39e
+eb7921facd
+eb8fce51a6
+ebbb90e9f9
+ebbf5c9ee1
+ebc4ec32e6
+ebe56e5ef8
+ec1299aee4
+ec139ff675
+ec193e1a01
+ec28252938
+ec387be051
+ec3d4fac00
+ec4186ce12
+ec579c2f96
+ecae59b782
+ecb33a0448
+ece6bc9e92
+ecfedd4035
+ecfff22fd6
+ed3291c3d6
+ed3cd5308d
+ed3e6fc1a5
+ed72ae8825
+ed7455da68
+ed844e879f
+ed8f814b2b
+ed911a1f63
+ed9ff4f649
+eda8ab984b
+edb8878849
+edbfdfe1b4
+edd22c46a2
+edd663afa3
+ede3552eae
+edeab61ee0
+ee07583fc0
+ee316eaed6
+ee3f509537
+ee40a1e491
+ee4bf100f1
+ee6f9b01f9
+ee947ed771
+ee9706ac7f
+ee9a7840ae
+eeb90cb569
+eebf45e5c5
+eeed0c7d73
+ef0061a309
+ef07f1a655
+ef0a8e8f35
+ef232a2aed
+ef308ad2e9
+ef44945428
+ef45ce3035
+ef5dde449d
+ef5e770988
+ef6359cea3
+ef65268834
+ef6cb5eae0
+ef78972bc2
+ef8cfcfc4f
+ef96501dd0
+ef9a2e976b
+efb24f950f
+efce0c1868
+efe5ac6901
+efe828affa
+efea4e0523
+f0268aa627
+f0483250c8
+f04cf99ee6
+f05b189097
+f08928c6d3
+f09d74856f
+f0a7607d63
+f0ad38da27
+f0c34e1213
+f0c7f86c29
+f0dfa18ba7
+f0eb3179f7
+f119bab27d
+f14409b6a3
+f1489baff4
+f14c18cf6a
+f15c607b92
+f1af214222
+f1b77bd309
+f1ba9e1a3e
+f1d99239eb
+f1dc710cf4
+f1ec5c08fa
+f22648fe12
+f22d21f1f1
+f233257395
+f23e95dbe5
+f2445b1572
+f253b3486d
+f277c7a6a4
+f2ab2b84d6
+f2b7c9b1f3
+f2b83d5ce5
+f2c276018f
+f2cfd94d64
+f2dd6e3add
+f2e7653f16
+f2f333ad06
+f2f55d6713
+f2fdb6abec
+f305a56d9f
+f3085d6570
+f3325c3338
+f3400f1204
+f34497c932
+f34a56525e
+f36483c824
+f3704d5663
+f3734c4913
+f38e5aa5b4
+f3986fba44
+f3a0ffc7d9
+f3b24a7d28
+f3e6c35ec3
+f3fc0ea80b
+f40a683fbe
+f4207ca554
+f4377499c2
+f46184f393
+f46c2d0a6d
+f46c364dca
+f46f7a0b63
+f46fe141b0
+f470b9aeb0
+f47eb7437f
+f48b535719
+f49e4866ac
+f4aa882cfd
+f4daa3dbd5
+f4dd51ac35
+f507a1b9dc
+f51c5ac84b
+f52104164b
+f54c67b9bb
+f5966cadd2
+f5bddf5598
+f5d85cfd17
+f5e2e7d6a0
+f5f051e9b4
+f5f8a93a76
+f6283e8af5
+f635e9568b
+f6474735be
+f659251be2
+f66981af4e
+f6708fa398
+f697fe8e8f
+f6adb12c42
+f6c7906ca4
+f6cd0a8016
+f6d6f15ae7
+f6e501892c
+f6f59d986f
+f6fe8c90a5
+f714160545
+f74c3888d7
+f7782c430e
+f7783ae5f2
+f77ab47923
+f788a98327
+f7961ac1f0
+f7a71e7574
+f7a8521432
+f7afbf4947
+f7b7cd5f44
+f7cf4b4a39
+f7d49799ad
+f7e0c9bb83
+f7e5b84928
+f7e6bd58be
+f7f2a38ac6
+f7f6cb2d6d
+f83f19e796
+f85796a921
+f8603c26b2
+f8819b42ec
+f891f8eaa1
+f89288d10c
+f895ae8cc1
+f8b4ac12f1
+f8c3fb2b01
+f8c8de2764
+f8db369b40
+f8fcb6a78c
+f94aafdeef
+f95d217b70
+f9681d5103
+f9750192a4
+f9823a32c2
+f991ddb4c2
+f99d535567
+f9ae3d98b7
+f9b6217959
+f9bd1fabf5
+f9c68eaa64
+f9d3e04c4f
+f9daf64494
+f9e4cc5a0a
+f9ea6b7f31
+f9f3852526
+fa04c615cf
+fa08e00a56
+fa4370d74d
+fa67744af3
+fa88d48a92
+fa8b904cc9
+fa9526bdf1
+fa9b9d2426
+fad633fbe1
+faf5222dc3
+faff0e15f1
+fb08c64e8c
+fb23455a7f
+fb2e19fa6e
+fb34dfbb77
+fb47fcea1e
+fb49738155
+fb4cbc514b
+fb4e6062f7
+fb5ba7ad6e
+fb63cd1236
+fb81157a07
+fb92abdaeb
+fba22a6848
+fbaca0c9df
+fbc645f602
+fbd77444cd
+fbe53dc8e8
+fbe541dd73
+fbe8488798
+fbfd25174f
+fc28cb305e
+fc33b1ffd6
+fc6186f0bb
+fc918e3a40
+fc96cda9d8
+fc9832eea4
+fcb10d0f81
+fcd20a2509
+fcf637e3ab
+fcfd81727f
+fd31890379
+fd33551c28
+fd542da05e
+fd6789b3fe
+fd77828200
+fd7af75f4d
+fdb28d0fbb
+fdb3d1fb1e
+fdb8b04124
+fdc6e3d581
+fdfce7e6fc
+fe0f76d41b
+fe24b0677d
+fe3c02699d
+fe58b48235
+fe6a5596b8
+fe6c244f63
+fe7afec086
+fe985d510a
+fe9db35d15
+fea8ffcd36
+feb1080388
+fed208bfca
+feda5ad1c2
+feec95b386
+ff15a5eff6
+ff204daf4b
+ff25f55852
+ff2ada194f
+ff2ce142e8
+ff49d36d20
+ff5a1ec4f3
+ff66152b25
+ff692fdc56
+ff773b1a1e
+ff97129478
+ffb904207d
+ffc43fc345
+fffe5f8df6
diff --git a/Make-A-Protagonist/experts/blip_inference.py b/Make-A-Protagonist/experts/blip_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5f0114a9e79e4ffc9c1e8e940cffee12861006a
--- /dev/null
+++ b/Make-A-Protagonist/experts/blip_inference.py
@@ -0,0 +1,75 @@
+from PIL import Image
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+import torch
+import os
+from glob import glob
+import argparse
+from glob import glob
+
+from BLIP2.blip_video_model import Blip2ForVideoConditionalGeneration as Blip2ForConditionalGeneration
+
+from termcolor import colored, cprint
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-d", "--data_root", type=str, required=True)
+parser.add_argument("-fn" , "--frame_num", type=int, default=8)
+parser.add_argument("-fps" , "--frame_rate", type=int, default=1)
+args = parser.parse_args()
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Salesforce/blip2-flan-t5-xxl
+# Salesforce/blip2-opt-6.7b
+blip2_version = "Salesforce/blip2-flan-t5-xl"
+# blip2_version = "Salesforce/blip2-opt-6.7b"
+
+weight_dtype = torch.bfloat16 if "flan" in blip2_version else torch.float16
+# weight_dtype = torch.float16
+
+processor = Blip2Processor.from_pretrained(blip2_version)
+model = Blip2ForConditionalGeneration.from_pretrained(
+ blip2_version, torch_dtype=weight_dtype
+)
+model.to(device)
+
+
+if not os.path.isdir(args.data_root):
+ image_list = [args.data_root]
+else:
+ # ipdb.set_trace()
+ all_image_list = sorted(glob(os.path.join(args.data_root, "*.jpg"))) + sorted(glob(os.path.join(args.data_root, "*.png")))
+ image_list = [all_image_list[f] for f in range(0, args.frame_num*args.frame_rate, args.frame_rate)]
+ assert len(image_list) == args.frame_num
+
+
+images = []
+for image_path in image_list:
+ image = Image.open(image_path).convert("RGB")
+ images.append(image)
+
+def blip2_call(prompt=None, max_new_tokens=20):
+ inputs = processor(images, text=prompt, return_tensors="pt").to(device, weight_dtype)
+ generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+ if prompt is not None:
+ cprint(prompt, "red")
+ else:
+ cprint("No prompt", "red")
+
+ print(generated_text)
+
+
+## prompt captioning
+prompt = "this is a video of"
+
+print("Captioning")
+blip2_call(prompt, 20)
+
+
+prompt = "Question: what is the protagonist in this video? Answer: "
+
+blip2_call(prompt, 10)
+
+
diff --git a/Make-A-Protagonist/experts/controlnet_signals.py b/Make-A-Protagonist/experts/controlnet_signals.py
new file mode 100644
index 0000000000000000000000000000000000000000..90afe20063183216224ede4c087fd8fc13f8b695
--- /dev/null
+++ b/Make-A-Protagonist/experts/controlnet_signals.py
@@ -0,0 +1,86 @@
+import sys
+sys.path.insert(0, './')
+import os
+import cv2
+import torch
+from glob import glob
+
+import argparse
+from tqdm import tqdm
+from PIL import Image
+
+from controlnet_aux import MidasDetector, OpenposeDetector
+
+torch.set_grad_enabled(False)
+
+estimators = {
+ 'depth': MidasDetector,
+ 'openpose': OpenposeDetector,
+ 'openposefull': OpenposeDetector,
+}
+
+def get_base_argument_parser() -> argparse.ArgumentParser:
+ """get the base argument parser for inference scripts"""
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-d',
+ '--data',
+ type=str,
+ help='dir for images: data/dir/images',
+ default=None,
+ required=True,
+ )
+
+ parser.add_argument(
+ '-c',
+ '--which_cond',
+ type=str,
+ required=True,
+ help='which condition modality you want to test',
+ )
+
+ return parser
+
+
+
+def main():
+ parser = get_base_argument_parser()
+ opt = parser.parse_args()
+
+ which_cond = opt.which_cond
+
+
+ outdir = opt.data.replace("images", which_cond) ## path of save
+ os.makedirs(outdir, exist_ok=True)
+
+ opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ image_paths = sorted(glob(os.path.join(opt.data, "*.jpg"))) + sorted(glob(os.path.join(opt.data, "*.png")))
+ print("Processing video : {}, video length {}".format(opt.data, len(image_paths)))
+
+
+ # prepare models
+ cond_model = None
+ cond_model = estimators[which_cond].from_pretrained("lllyasviel/Annotators")#.
+
+ # inference
+ for test_idx, cond_path in enumerate(tqdm(image_paths)):
+ image = Image.open(cond_path).convert('RGB')
+ fname = os.path.basename(cond_path).split('.')[0] # *.jpg
+ width, height = image.size
+ if which_cond == 'depth':
+ new_w = width // 64 * 64
+ new_h = height // 64 * 64
+ image = image.resize((new_w, new_h))
+
+ if which_cond == 'openposefull':
+ cond = cond_model(image, hand_and_face=True)
+ else:
+ cond = cond_model(image)
+
+ cond.resize((width, height))
+ cond.save(os.path.join(outdir, f'{fname}.png'))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Make-A-Protagonist/experts/grounded_sam_inference.py b/Make-A-Protagonist/experts/grounded_sam_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dfb31126a3cc39f9a0c80d7c3d161ff2cefa9ac
--- /dev/null
+++ b/Make-A-Protagonist/experts/grounded_sam_inference.py
@@ -0,0 +1,256 @@
+import sys
+import os
+
+project_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(project_dir)
+import argparse
+import os
+import copy
+
+import numpy as np
+import json
+import torch
+from PIL import Image, ImageDraw, ImageFont
+
+# Grounding DINO
+import GroundedSAM.GroundingDINO.groundingdino.datasets.transforms as T
+from GroundedSAM.GroundingDINO.groundingdino.models import build_model
+from GroundedSAM.GroundingDINO.groundingdino.util import box_ops
+from GroundedSAM.GroundingDINO.groundingdino.util.slconfig import SLConfig
+from GroundedSAM.GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+# segment anything
+from GroundedSAM.segment_anything.segment_anything import build_sam, SamPredictor
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from glob import glob
+import ipdb
+import imageio
+from tqdm import tqdm
+
+
+'''
+processing multiple images with grounded sam
+only one text one time
+'''
+
+def load_image(image_path):
+ # load image
+ image_pil = Image.open(image_path).convert("RGB") # load image
+
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image, _ = transform(image_pil, None) # 3, h, w
+ return image_pil, image
+
+
+def load_model(model_config_path, model_checkpoint_path, device):
+ args = SLConfig.fromfile(model_config_path)
+ args.device = device
+ model = build_model(args)
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+ print(load_res)
+ _ = model.eval()
+ return model
+
+
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
+ caption = caption.lower()
+ caption = caption.strip()
+ if not caption.endswith("."):
+ caption = caption + "."
+ model = model.to(device)
+ image = image.to(device)
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
+ logits.shape[0]
+
+ # filter output
+ logits_filt = logits.clone()
+ boxes_filt = boxes.clone()
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
+ logits_filt.shape[0]
+
+ # get phrase
+ tokenlizer = model.tokenizer
+ tokenized = tokenlizer(caption)
+ # build pred
+ pred_phrases = []
+ for logit, box in zip(logits_filt, boxes_filt):
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+ if with_logits:
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+ else:
+ pred_phrases.append(pred_phrase)
+
+ return boxes_filt, pred_phrases, logits_filt
+
+def show_mask(mask, ax, random_color=False):
+ if random_color:
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+ else:
+ color = np.array([30/255, 144/255, 255/255, 0.6])
+ h, w = mask.shape[-2:]
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+ ax.imshow(mask_image)
+
+
+def show_box(box, ax, label):
+ x0, y0 = box[0], box[1]
+ w, h = box[2] - box[0], box[3] - box[1]
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+ ax.text(x0, y0, label)
+
+
+def save_mask_data(output_dir, mask_list, box_list, label_list):
+ value = 0 # 0 for background
+
+ mask_img = torch.zeros(mask_list.shape[-2:])
+ for idx, mask in enumerate(mask_list):
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
+ plt.figure(figsize=(10, 10))
+ plt.imshow(mask_img.numpy())
+ plt.axis('off')
+ plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
+
+ json_data = [{
+ 'value': value,
+ 'label': 'background'
+ }]
+ for label, box in zip(label_list, box_list):
+ value += 1
+ name, logit = label.split('(')
+ logit = logit[:-1] # the last is ')'
+ json_data.append({
+ 'value': value,
+ 'label': name,
+ 'logit': float(logit),
+ 'box': box.numpy().tolist(),
+ })
+ with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
+ json.dump(json_data, f)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
+ parser.add_argument("-d", "--data", type=str, required=True, help="path to image file")
+ parser.add_argument("-t", "--text_prompt", type=str, required=True, help="text prompt")
+ parser.add_argument(
+ "--output_dir", "-o", type=str, default="outputs", required=False, help="output directory"
+ )
+
+ parser.add_argument("--config", type=str,
+ default="experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py",
+ help="path to config file")
+ parser.add_argument(
+ "--grounded_checkpoint", type=str, default="checkpoints/groundingdino_swinb_cogcoor.pth", help="path to checkpoint file"
+ )
+ parser.add_argument(
+ "--sam_checkpoint", type=str, default="checkpoints/sam_vit_h_4b8939.pth", help="path to checkpoint file"
+ )
+
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+
+ parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
+
+ parser.add_argument("--masked_out", action='store_true', help="save the masked image")
+ args = parser.parse_args()
+
+ # cfg
+ config_file = args.config # change the path of the model config file
+ grounded_checkpoint = args.grounded_checkpoint # change the path of the model
+ sam_checkpoint = args.sam_checkpoint
+ # image_path = args.data
+ text_prompt = args.text_prompt
+ output_dir = os.path.dirname(os.path.dirname(args.data))
+ box_threshold = args.box_threshold
+ text_threshold = args.text_threshold
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # make dir
+ text_prompt_dir = "-".join(text_prompt.split(" "))
+
+ # text_prompt_dir
+ os.makedirs(output_dir, exist_ok=True)
+ # os.makedirs(os.path.join(output_dir, "raw"), exist_ok=True)
+ os.makedirs(os.path.join(output_dir, "{}.viz".format(text_prompt_dir)), exist_ok=True)
+ os.makedirs(os.path.join(output_dir, "{}.mask".format(text_prompt_dir)), exist_ok=True)
+
+ # load model
+ model = load_model(config_file, grounded_checkpoint, device=device)
+ # initialize SAM
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+
+ if os.path.isdir(args.data):
+ images = sorted(glob(os.path.join(args.data, "*.jpg"))) + sorted(glob(os.path.join(args.data, "*.png")))
+ else:
+ images = [args.data]
+
+ for image_path in tqdm(images):
+ fname = os.path.basename(image_path).split('.')[0]
+ # load image
+ image_pil, image = load_image(image_path)
+
+ # run grounding dino model
+ boxes_filt, pred_phrases, logits_filt = get_grounding_output(
+ model, image, text_prompt, box_threshold, text_threshold, device=device
+ )
+
+ image = cv2.imread(image_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ predictor.set_image(image)
+
+ size = image_pil.size
+ H, W = size[1], size[0]
+ for i in range(boxes_filt.size(0)):
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+ boxes_filt[i][2:] += boxes_filt[i][:2]
+
+ boxes_filt = boxes_filt.cpu()
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+ masks, _, _ = predictor.predict_torch(
+ point_coords = None,
+ point_labels = None,
+ boxes = transformed_boxes.to(device),
+ multimask_output = False,
+ )
+
+ # draw output image
+ plt.figure(figsize=(10, 10))
+ plt.imshow(image)
+ for mask in masks:
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+ for box, label in zip(boxes_filt, pred_phrases):
+ show_box(box.numpy(), plt.gca(), label)
+
+ plt.axis('off')
+ plt.savefig(
+ os.path.join(output_dir, "{}.viz".format(text_prompt_dir), fname + ".jpg"),
+ bbox_inches="tight", dpi=300, pad_inches=0.0
+ )
+
+ # ipdb.set_trace()
+ max_logit_index = logits_filt.max(-1)[0].argmax().item()
+ _mask = masks[max_logit_index,0].cpu().numpy().astype(np.uint8) * 255
+ imageio.imwrite(os.path.join(output_dir, "{}.mask".format(text_prompt_dir), fname + ".png"), _mask)
+
+ if args.masked_out:
+ masked_image = np.asarray(image_pil).astype(np.float32) * _mask[:,:,None].astype(np.float32) / 255
+ imageio.imwrite(os.path.join(output_dir, "masked_" + fname + ".png"), masked_image.astype(np.uint8))
+ # save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
+
diff --git a/Make-A-Protagonist/experts/grounded_sam_mask_out.py b/Make-A-Protagonist/experts/grounded_sam_mask_out.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e67fa89478cdec6640782cbbe9e0d39e572080
--- /dev/null
+++ b/Make-A-Protagonist/experts/grounded_sam_mask_out.py
@@ -0,0 +1,208 @@
+import sys
+import os
+
+project_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(project_dir)
+import argparse
+import os
+import copy
+
+import numpy as np
+import json
+import torch
+from PIL import Image, ImageDraw, ImageFont
+
+# Grounding DINO
+import GroundedSAM.GroundingDINO.groundingdino.datasets.transforms as T
+from GroundedSAM.GroundingDINO.groundingdino.models import build_model
+from GroundedSAM.GroundingDINO.groundingdino.util import box_ops
+from GroundedSAM.GroundingDINO.groundingdino.util.slconfig import SLConfig
+from GroundedSAM.GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+# segment anything
+from GroundedSAM.segment_anything.segment_anything import build_sam, SamPredictor
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from glob import glob
+import ipdb
+import imageio
+from tqdm import tqdm
+
+
+'''
+processing multiple images with grounded sam
+only one text one time
+'''
+
+def load_image(image_path):
+ # load image
+ image_pil = Image.open(image_path).convert("RGB") # load image
+
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image, _ = transform(image_pil, None) # 3, h, w
+ return image_pil, image
+
+def load_image_pil(image_pil):
+
+ transform = T.Compose(
+ [
+ T.RandomResize([800], max_size=1333),
+ T.ToTensor(),
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+ )
+ image, _ = transform(image_pil, None) # 3, h, w
+ return image_pil, image
+
+
+def load_model(model_config_path, model_checkpoint_path, device):
+ args = SLConfig.fromfile(model_config_path)
+ args.device = device
+ model = build_model(args)
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+ print(load_res)
+ _ = model.eval()
+ return model
+
+
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
+ caption = caption.lower()
+ caption = caption.strip()
+ if not caption.endswith("."):
+ caption = caption + "."
+ model = model.to(device)
+ image = image.to(device)
+ with torch.no_grad():
+ outputs = model(image[None], captions=[caption])
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
+ logits.shape[0]
+
+ # filter output
+ logits_filt = logits.clone()
+ boxes_filt = boxes.clone()
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
+ logits_filt.shape[0]
+
+ # get phrase
+ tokenlizer = model.tokenizer
+ tokenized = tokenlizer(caption)
+ # build pred
+ pred_phrases = []
+ for logit, box in zip(logits_filt, boxes_filt):
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+ if with_logits:
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+ else:
+ pred_phrases.append(pred_phrase)
+
+ return boxes_filt, pred_phrases, logits_filt
+
+def show_mask(mask, ax, random_color=False):
+ if random_color:
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+ else:
+ color = np.array([30/255, 144/255, 255/255, 0.6])
+ h, w = mask.shape[-2:]
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+ ax.imshow(mask_image)
+
+
+def show_box(box, ax, label):
+ x0, y0 = box[0], box[1]
+ w, h = box[2] - box[0], box[3] - box[1]
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+ ax.text(x0, y0, label)
+
+
+def save_mask_data(output_dir, mask_list, box_list, label_list):
+ value = 0 # 0 for background
+
+ mask_img = torch.zeros(mask_list.shape[-2:])
+ for idx, mask in enumerate(mask_list):
+ mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
+ plt.figure(figsize=(10, 10))
+ plt.imshow(mask_img.numpy())
+ plt.axis('off')
+ plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
+
+ json_data = [{
+ 'value': value,
+ 'label': 'background'
+ }]
+ for label, box in zip(label_list, box_list):
+ value += 1
+ name, logit = label.split('(')
+ logit = logit[:-1] # the last is ')'
+ json_data.append({
+ 'value': value,
+ 'label': name,
+ 'logit': float(logit),
+ 'box': box.numpy().tolist(),
+ })
+ with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
+ json.dump(json_data, f)
+
+
+def mask_out_reference_image(image, text_prompt):
+
+ # cfg
+ config_file = "Make-A-Protagonist/experts/GroundedSAM/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py"
+ grounded_checkpoint = 'checkpoints/groundingdino_swinb_cogcoor.pth'
+ sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
+
+ box_threshold = 0.3
+ text_threshold = 0.25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # load model
+ model = load_model(config_file, grounded_checkpoint, device=device)
+ # initialize SAM
+ predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+
+ image_pil, image = load_image_pil(image)
+
+ # run grounding dino model
+ boxes_filt, pred_phrases, logits_filt = get_grounding_output(
+ model, image, text_prompt, box_threshold, text_threshold, device=device
+ )
+ # ipdb.set_trace()
+ image = np.asarray(image_pil).astype(np.uint8)
+ predictor.set_image(image)
+
+ size = image_pil.size
+ H, W = size[1], size[0]
+ for i in range(boxes_filt.size(0)):
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+ boxes_filt[i][2:] += boxes_filt[i][:2]
+
+ boxes_filt = boxes_filt.cpu()
+ transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+ masks, _, _ = predictor.predict_torch(
+ point_coords = None,
+ point_labels = None,
+ boxes = transformed_boxes.to(device),
+ multimask_output = False,
+ )
+
+
+ # ipdb.set_trace()
+ max_logit_index = logits_filt.max(-1)[0].argmax().item()
+ _mask = masks[max_logit_index,0].cpu().numpy().astype(np.uint8) * 255
+ masked_image = np.asarray(image_pil).astype(np.float32) * _mask[:,:,None].astype(np.float32) / 255
+
+ return Image.fromarray(masked_image.astype(np.uint8))
+
+
diff --git a/Make-A-Protagonist/experts/xmem_inference.py b/Make-A-Protagonist/experts/xmem_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3313aefc265e008125030ab2b17ceacbdcbcaac2
--- /dev/null
+++ b/Make-A-Protagonist/experts/xmem_inference.py
@@ -0,0 +1,196 @@
+import sys
+import os
+
+project_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(project_dir)
+
+from argparse import ArgumentParser
+import shutil
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import numpy as np
+from PIL import Image
+
+from XMem.inference.data.test_datasets import CustomDataset
+from XMem.inference.data.mask_mapper import MaskMapper
+from XMem.model.network import XMem
+from XMem.inference.inference_core import InferenceCore
+
+from progressbar import progressbar
+
+
+
+"""
+Arguments loading
+"""
+parser = ArgumentParser()
+parser.add_argument('--model', default='checkpoints/XMem.pth')
+
+# Data options
+parser.add_argument('-d', '--data', default='data', required=True, help='the dir name to the images data/dir/images')
+parser.add_argument('-v', '--video', required=False, help='video name')
+parser.add_argument('--mask_dir', required=True, help='the dir name to the mask e.g., man.mask')
+
+parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='G')
+parser.add_argument('--split', help='val/test', default='val')
+parser.add_argument('--output', default=None)
+parser.add_argument('--save_all', action='store_true',
+ help='Save all frames. Useful only in YouTubeVOS/long-time video', )
+
+parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
+
+# Long-term memory options
+parser.add_argument('--disable_long_term', action='store_true')
+parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
+parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
+parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
+ type=int, default=10000)
+parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
+
+parser.add_argument('--top_k', type=int, default=30)
+parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
+parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
+
+# Multi-scale options
+parser.add_argument('--save_scores', action='store_true')
+parser.add_argument('--flip', action='store_true')
+parser.add_argument('--size', default=480, type=int,
+ help='Resize the shorter side to this size. -1 to use original resolution. ')
+
+args = parser.parse_args()
+config = vars(args)
+config['enable_long_term'] = not config['disable_long_term']
+
+if args.output is None:
+ args.output = args.data.replace('images', args.mask_dir)
+ print(f'Output path not provided. By default saving to the mask dir')
+
+os.makedirs(args.output, exist_ok=True)
+
+"""
+Data preparation
+"""
+
+out_path = args.output
+
+if args.dataset == 'G':
+ meta_dataset = CustomDataset(args.data, mask_dir=args.mask_dir, size=args.size)
+ if not args.save_all:
+ args.save_all = True
+ print('save_all is forced to be true in generic evaluation mode.')
+else:
+ raise NotImplementedError
+
+torch.autograd.set_grad_enabled(False)
+
+# Set up loader
+meta_loader = meta_dataset.get_datasets()
+
+# Load our checkpoint
+network = XMem(config, args.model).cuda().eval()
+if args.model is not None:
+ model_weights = torch.load(args.model)
+ network.load_weights(model_weights, init_as_zero_if_needed=True)
+else:
+ print('No model loaded.')
+
+total_process_time = 0
+total_frames = 0
+
+# Start eval
+for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
+
+ loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
+ vid_name = vid_reader.vid_name
+ vid_length = len(loader)
+ # no need to count usage for LT if the video is not that long anyway
+ config['enable_long_term_count_usage'] = (
+ config['enable_long_term'] and
+ (vid_length
+ / (config['max_mid_term_frames']-config['min_mid_term_frames'])
+ * config['num_prototypes'])
+ >= config['max_long_term_elements']
+ )
+
+ mapper = MaskMapper()
+ processor = InferenceCore(network, config=config)
+ first_mask_loaded = False
+
+ for ti, data in enumerate(loader):
+ with torch.cuda.amp.autocast(enabled=not args.benchmark):
+ rgb = data['rgb'].cuda()[0]
+ msk = data.get('mask')
+ info = data['info']
+ frame = info['frame'][0]
+ shape = info['shape']
+ need_resize = info['need_resize'][0]
+
+ """
+ For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
+ Seems to be very similar in testing as my previous timing method
+ with two cuda sync + time.time() in STCN though
+ """
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+
+ if not first_mask_loaded:
+ if msk is not None:
+ first_mask_loaded = True
+ else:
+ # no point to do anything without a mask
+ continue
+
+ if args.flip:
+ rgb = torch.flip(rgb, dims=[-1])
+ msk = torch.flip(msk, dims=[-1]) if msk is not None else None
+
+ # Map possibly non-continuous labels to continuous ones
+ if msk is not None:
+ msk, labels = mapper.convert_mask(msk[0].numpy())
+ msk = torch.Tensor(msk).cuda()
+ if need_resize:
+ msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
+ processor.set_all_labels(list(mapper.remappings.values()))
+ else:
+ labels = None
+
+ # Run the model on this frame
+ prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
+
+ # Upsample to original size if needed
+ if need_resize:
+ prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
+
+ end.record()
+ torch.cuda.synchronize()
+ total_process_time += (start.elapsed_time(end)/1000)
+ total_frames += 1
+
+ if args.flip:
+ prob = torch.flip(prob, dims=[-1])
+
+ # Probability mask -> index mask
+ out_mask = torch.argmax(prob, dim=0)
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
+
+ # Save the mask
+ if args.save_all or info['save'][0]:
+ this_out_path = out_path
+ os.makedirs(this_out_path, exist_ok=True)
+ out_mask = mapper.remap_index_mask(out_mask)
+ out_img = Image.fromarray(out_mask)
+ if vid_reader.get_palette() is not None:
+ out_img.putpalette(vid_reader.get_palette())
+ out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
+
+
+
+
+print(f'Total processing time: {total_process_time}')
+print(f'Total processed frames: {total_frames}')
+print(f'FPS: {total_frames / total_process_time}')
+print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
+
diff --git a/Make-A-Protagonist/makeaprotagonist/args_util.py b/Make-A-Protagonist/makeaprotagonist/args_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5458e109309cba56a9f29b6f751af3f20627718f
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/args_util.py
@@ -0,0 +1,113 @@
+
+from argparse import Action, ArgumentParser
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+ be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
+ brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
+ list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ if val == 'None':
+ return None
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count('(') == string.count(')')) and (
+ string.count('[') == string.count(']')), \
+ f'Imbalanced brackets exist in {string}'
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if ((char == ',') and (pre.count('(') == pre.count(')'))
+ and (pre.count('[') == pre.count(']'))):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip('\'\"').replace(' ', '')
+ is_tuple = False
+ if val.startswith('(') and val.endswith(')'):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith('[') and val.endswith(']'):
+ val = val[1:-1]
+ elif ',' not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1:]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
+
+
+def key_missing_assert(key, config):
+ if not key in config:
+ raise "Error ===> {} is not in dict".format(key)
+
+def config_merge_dict(opts, config):
+ for k, v in opts.items():
+ d = config
+ key_list = k.split('.')
+ for subkey in key_list[:-1]:
+ key_missing_assert(subkey, d)
+ d = d[subkey]
+ subkey = key_list[-1]
+ key_missing_assert(subkey, d)
+ d[subkey] = v
diff --git a/Make-A-Protagonist/makeaprotagonist/dataset/dataset.py b/Make-A-Protagonist/makeaprotagonist/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb453c9018e9d31e1771b937a615ee0ba684a62a
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/dataset/dataset.py
@@ -0,0 +1,160 @@
+
+import torch
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+
+from einops import rearrange
+import os
+import os.path as osp
+from glob import glob
+import imageio
+import cv2
+import numpy as np
+import random
+import ipdb
+
+class MakeAProtagonistDataset(Dataset):
+ def __init__(
+ self,
+ video_dir: str,
+ prompt: str,
+ condition: list[str] = 'openpose', ## type of condition used
+ video_suffix: str = '.jpg',
+ condition_suffix: str = '.png',
+ width: int = 512,
+ height: int = 512,
+ n_sample_frames: int = 8,
+ sample_start_idx: int = 0,
+ sample_frame_rate: int = 1,
+ random_sample: bool = False,
+ mask_dir: str = None,
+ **kwargs,
+ ):
+ self.video_dir = video_dir ## path to the video dir
+ self.video_path = osp.join(self.video_dir, 'images')
+
+ self.condition = condition
+ if isinstance(condition, str):
+ condition = [condition]
+ self.condition_path = {_condition: osp.join(self.video_dir, _condition) for _condition in condition}
+ self.video_suffix = video_suffix
+ self.condition_suffix = condition_suffix
+ self.random_sample = random_sample
+ self.mask_dir = mask_dir
+ if mask_dir:
+ self.mask_dir = osp.join(self.video_dir, mask_dir)
+
+ ## get frame path
+ frame_list_path = osp.join(self.video_dir, 'frame_list.txt')
+ if not osp.isfile(frame_list_path):
+ all_frames = sorted(glob(osp.join(self.video_path, '*')))
+ self.frame_list = []
+ with open(frame_list_path, 'w') as f:
+ for _frame_path in all_frames:
+ _frame_name = osp.basename(_frame_path).split('.')[0]
+ self.frame_list.append(_frame_name)
+ f.write(_frame_name + '\n')
+
+ else:
+ with open(frame_list_path, 'r') as f:
+ self.frame_list = f.read().splitlines()
+
+ self.video_length = len(self.frame_list)
+
+ self.prompt = prompt
+ self.prompt_ids = None
+
+ self.width = width
+ self.height = height
+ self.n_sample_frames = n_sample_frames
+ self.sample_start_idx = sample_start_idx
+ self.sample_frame_rate = sample_frame_rate
+ self.img_embeddings = []
+
+ print('Training on Video {} \t totally {} frames'.format(self.video_dir.split('/')[-1], self.video_length))
+
+ @torch.no_grad()
+ def preprocess_img_embedding(self, feature_extractor, image_encoder):
+ for f_name in self.frame_list:
+ image = imageio.imread(osp.join(self.video_path, f_name + self.video_suffix))
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values
+ image_embeds = image_encoder(image).image_embeds
+ self.img_embeddings.append(image_embeds[0]) # 1,768 --> 768
+
+
+ def __len__(self):
+ return 1
+
+ def __getitem__(self, index):
+ # load and sample video frames
+ video_indices = list(range(self.sample_start_idx, self.video_length, self.sample_frame_rate))
+ video = []
+ conditions = {_condition: [] for _condition in self.condition}
+
+ mask = []
+ if self.random_sample:
+ start_index = random.randint(0,len(video_indices) - self.n_sample_frames) ## [a,b] include both
+ else:
+ start_index = 0
+ sample_index = video_indices[start_index:start_index+self.n_sample_frames]
+ # ipdb.set_trace()
+ for _f_idx in sample_index:
+ _frame = imageio.imread(osp.join(self.video_path, self.frame_list[_f_idx] + self.video_suffix))
+ if self.mask_dir:
+ _mask = imageio.imread(osp.join(self.mask_dir, self.frame_list[_f_idx] + '.png')).astype(np.float32) ## H,W 0 and 255
+ _mask /= 255 # 0 and 1
+ else:
+ _mask = np.ones(_frame.shape[:2])
+ video.append(_frame)
+ mask.append(_mask)
+
+ for _control_type, _control_path in self.condition_path.items():
+ _condition = imageio.imread(osp.join(_control_path, self.frame_list[_f_idx] + self.condition_suffix)) ##
+ conditions[_control_type].append(_condition)
+
+ ref_idx = random.choice(sample_index) # idx random sample one ref image index from the select video clip
+
+ video = torch.from_numpy(np.stack(video, axis=0)).float() # f,h,w,c
+
+ video = rearrange(video, "f h w c -> f c h w")
+ video = F.interpolate(video, size=(self.height, self.width), mode='bilinear')
+
+ # ipdb.set_trace()
+ conditions_transform = {}
+ for _control_type, condition in conditions.items():
+ condition = torch.from_numpy(np.stack(condition, axis=0)).float() # f,h,w,c
+ condition = rearrange(condition, "f h w c -> f c h w")
+ condition = F.interpolate(condition, size=(self.height, self.width), mode='bilinear')
+ conditions_transform[_control_type] = condition / 255
+
+ mask = torch.from_numpy(np.stack(mask, axis=0)).float() # f,h,w
+ mask = rearrange(mask[:,:,:,None], "f h w c -> f c h w")
+ mask = F.interpolate(mask, size=(self.height, self.width), mode='nearest')
+
+ ref_img = imageio.imread(osp.join(self.video_path, self.frame_list[ref_idx] + self.video_suffix)) # read ref image
+ ref_img = torch.from_numpy(ref_img).float() # h,w,c convert to tensor
+ ref_img = ref_img.permute(2,0,1).unsqueeze(0).repeat(self.n_sample_frames,1,1,1) ## h,w,c -> c,h,w -> 1,c,h,w -> f,c,h,w
+ ref_img = F.interpolate(ref_img, size=(self.height, self.width), mode='bilinear')
+
+ ref_condition = torch.zeros_like(ref_img)
+ # ipdb.set_trace()
+ example = {
+ "pixel_values": (video / 127.5 - 1.0),
+ "conditions": conditions_transform,
+ # "prompt_ids": self.prompt_ids,
+ "ref_img": (ref_img / 127.5 - 1.0),
+ "ref_condition": ref_condition / 255,
+ "masks": mask,
+ "sample_indices": torch.LongTensor(sample_index),
+
+ }
+
+ ref_imbed = None
+ if len(self.img_embeddings):
+ ref_imbed = self.img_embeddings[ref_idx]
+ example["ref_imbed"] = ref_imbed
+
+
+ return example
+
+
diff --git a/Make-A-Protagonist/makeaprotagonist/models/attention.py b/Make-A-Protagonist/makeaprotagonist/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e9250d85982c7839907205353e21e37c94bc7a
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/models/attention.py
@@ -0,0 +1,560 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+from dataclasses import dataclass
+from typing import Optional, Callable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+
+
+
+from einops import rearrange, repeat
+import ipdb
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ # Input
+
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+ # SC-Attn
+ self.attn1 = SparseCausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # Temp-Attn
+ self.attn_temp = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+
+ causal_mask = None
+ ### causal attn mask
+ ## NOTE the following two lines are causal mask for temporal attn
+ causal_mask = torch.tril( torch.ones((hidden_states.size(1),hidden_states.size(1)), dtype=hidden_states.dtype, device=hidden_states.device) ) # f,f
+ causal_mask = (1.0 - causal_mask[None]) * -10000.0 # 1,f,f
+
+ hidden_states = self.attn_temp(norm_hidden_states, attention_mask=causal_mask) + hidden_states
+
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer. in diffuser 0.11 version, for Tune-A-Video original implementation
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+class SparseCausalAttention(CrossAttention):
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query) ## the name of this func is changed to `heads_to_batch_dim` in diffuser 0.15
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
+
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c")
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c")
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
diff --git a/Make-A-Protagonist/makeaprotagonist/models/resnet.py b/Make-A-Protagonist/makeaprotagonist/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9fbfc98fd73072107c131ad553731ee064031df
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/models/resnet.py
@@ -0,0 +1,270 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+import ipdb
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+class TemporalConv(nn.Conv1d):
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None) -> None:
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
+
+ # nn.init.dirac_(self.weight.data) # initialized to be identity
+ nn.init.zeros_(self.weight.data) # initialized to zeros
+ nn.init.zeros_(self.bias.data)
+
+ def forward(self, x):
+ # ipdb.set_trace()
+ _, c_dim, f_dim, h_dim, w_dim = x.size()
+
+ x = rearrange(x, 'b c f h w -> (b h w) c f')
+ x = super().forward(x)
+ x = rearrange(x, "(b h w) c f -> b c f h w", h=h_dim, w=w_dim)
+
+ return x
+
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ temporal_conv=False,
+
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.temp_conv1, self.temp_conv2 = None, None
+
+ if temporal_conv:
+ self.temp_conv1 = TemporalConv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.temp_conv2 = TemporalConv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, input_tensor, temb, temb_aux, masks):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ # ipdb.set_trace()
+ if self.temp_conv1 is not None:
+ hidden_states = hidden_states + self.temp_conv1(hidden_states)
+
+ if temb is not None:
+ ## temb F,C
+
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] # B,C,1,1,1
+ video_length = hidden_states.size(2)
+ if temb.size(0) == hidden_states.size(0)*video_length:
+ temb = temb.reshape(hidden_states.size(0), video_length, temb.size(1), 1, 1) # b,f,c,1,1
+ temb = temb.permute(0,2,1,3,4) # b,c,f,1,1
+
+ ## NOTE the shape of temb, hidden state and masks here # it is b,c,f,h,w
+ if masks is not None:
+ # ipdb.set_trace()
+ masks = torch.nn.functional.interpolate(masks, size=hidden_states.size()[-3:], mode="nearest")
+
+
+ if temb_aux is not None: ## this keeps the same no matter how many masks
+ ## NOTE masks should also be downsampled
+
+ temb_aux = self.time_emb_proj(self.nonlinearity(temb_aux))[:, :, None, None, None] # B,C,1,1,1
+
+ if temb_aux.size(0) == hidden_states.size(0)*video_length:
+ temb_aux = temb_aux.reshape(hidden_states.size(0), video_length, temb_aux.size(1), 1, 1) # b,f,c,1,1
+ temb_aux = temb_aux.permute(0,2,1,3,4)
+
+
+ temb = temb * masks + (1-masks) * temb_aux
+
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.temp_conv2 is not None:
+
+ hidden_states = hidden_states + self.temp_conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
diff --git a/Make-A-Protagonist/makeaprotagonist/models/unet.py b/Make-A-Protagonist/makeaprotagonist/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..95fb6bef9019953e34891b3f46f9968c92d19748
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/models/unet.py
@@ -0,0 +1,660 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union, Callable
+
+import os
+import json
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, GaussianFourierProjection
+from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+from .resnet import InflatedConv3d
+
+import ipdb
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+
+ time_embedding_type: str = "positional",
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+
+ temporal_conv=False,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.in_channels = in_channels
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ temporal_conv=temporal_conv
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+ self.conv_act = nn.SiLU()
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = InflatedConv3d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ try:
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+ except:
+ ipdb.set_trace()
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
+
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import UNet2DConditionModel
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+ >>> model = UNet2DConditionModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+ ... )
+ >>> model = model.to("cuda")
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ adapter_features: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ ## mask args
+ class_labels_aux: Optional[torch.Tensor] = None,
+ masks: Optional[torch.Tensor] = None,
+
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+ if adapter_features is None:
+ adapter_features = [None] * len(self.down_blocks)
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ t_emb = self.time_embedding(t_emb)
+
+ # ipdb.set_trace()
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ ## NOTE use torch.tensor(0) indicate not use image embedding
+ if class_labels.dim() == 0:
+ class_emb = 0
+ else:
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) # F,C
+ # class_emb = 0 ## NOTE this is for debugging, trying not use image embedding
+
+ # ipdb.set_trace()
+ if self.config.class_embeddings_concat: ## false
+ emb = torch.cat([t_emb, class_emb], dim=-1)
+ else:
+ if t_emb.size(0) == class_emb.size(0):
+ emb = t_emb + class_emb
+ elif t_emb.size(0) * sample.size(2) == class_emb.size(0):
+ # e_emb 2,C / class emb: 2F,C
+ class_emb = class_emb.reshape(t_emb.size(0), sample.size(2), class_emb.size(-1))
+ emb = t_emb[:, None] + class_emb
+ emb = emb.reshape(-1, emb.size(-1))
+ else:
+ emb = t_emb.repeat(2,1) + class_emb
+
+ ## NOTE aux embedding
+ emb_aux = None
+ if self.class_embedding is not None and class_labels_aux is not None:
+ # ipdb.set_trace()
+ if class_labels_aux is not None and masks is None:
+ raise ValueError("masks should be provided when class_labels_aux is given")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels_aux = self.time_proj(class_labels_aux)
+
+ class_emb_aux = self.class_embedding(class_labels_aux).to(dtype=self.dtype)
+
+ if self.config.class_embeddings_concat: ## false
+ emb_aux = torch.cat([t_emb, class_emb_aux], dim=-1)
+ else:
+ # emb_aux = t_emb + class_emb_aux
+ if t_emb.size(0) == class_emb_aux.size(0):
+ emb_aux = t_emb + class_emb_aux
+ else:
+ # e_emb 2,C / class emb: 2F,C
+ class_emb_aux = class_emb_aux.reshape(t_emb.size(0), sample.size(2), class_emb_aux.size(-1))
+ emb_aux = t_emb[:, None] + class_emb_aux
+ emb_aux = emb_aux.reshape(-1, emb_aux.size(-1))
+
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for down_id, downsample_block in enumerate(self.down_blocks):
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+
+ # ipdb.set_trace()
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ temb_aux=emb_aux,
+ masks=masks,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ adapter_feature=adapter_features[down_id]
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, temb_aux=emb_aux, masks=masks,adapter_feature=adapter_features[down_id])
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None: ## seems to fit controlnet
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, temb_aux=emb_aux, masks=masks, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ temb_aux=emb_aux,
+ masks=masks,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, temb_aux=emb_aux, masks=masks, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_temporal_conv=False):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ config["temporal_conv"] = use_temporal_conv
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+ for k, v in model.state_dict().items():
+ if '_temp.' in k:
+ state_dict.update({k: v})
+ if 'temp_conv' in k:
+ state_dict.update({k: v})
+
+ model.load_state_dict(state_dict)
+
+ return model
\ No newline at end of file
diff --git a/Make-A-Protagonist/makeaprotagonist/models/unet_blocks.py b/Make-A-Protagonist/makeaprotagonist/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccaa939698c7cb101e5e59a857024e995a54cd79
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/models/unet_blocks.py
@@ -0,0 +1,625 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+
+import torch
+from torch import nn
+
+from .attention import Transformer3DModel
+from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+import ipdb
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ temporal_conv=False,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ temporal_conv=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_conv=temporal_conv
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ temporal_conv=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, temb_aux=None, masks=None):
+ hidden_states = self.resnets[0](hidden_states, temb, temb_aux=temb_aux, masks=masks)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb, temb_aux=temb_aux, masks=masks)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ temporal_conv=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, adapter_feature=None, temb_aux=None, masks=None):
+ output_states = ()
+
+ # ipdb.set_trace()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, temb_aux, masks)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ # ipdb.set_trace()
+ if adapter_feature is not None:
+ hidden_states = hidden_states + adapter_feature
+ else:
+ hidden_states = resnet(hidden_states, temb, temb_aux=temb_aux, masks=masks)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ if adapter_feature is not None:
+ hidden_states = hidden_states + adapter_feature
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ temporal_conv=False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, adapter_feature=None, temb_aux=None, masks=None):
+ output_states = ()
+ # ipdb.set_trace()
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, temb_aux, masks)
+ # ipdb.set_trace()
+ if adapter_feature is not None:
+ hidden_states = hidden_states + adapter_feature
+
+ else:
+
+ hidden_states = resnet(hidden_states, temb, temb_aux, masks)
+ ## TODO check if it go through this
+ # ipdb.set_trace()
+ if adapter_feature is not None:
+ hidden_states = hidden_states + adapter_feature
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ temporal_conv=False
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ temb_aux=None,
+ masks=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, temb_aux, masks)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb, temb_aux, masks)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temporal_conv=False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ temporal_conv=temporal_conv
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, temb_aux=None, masks=None,):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, temb_aux, masks)
+ else:
+ hidden_states = resnet(hidden_states, temb, temb_aux, masks)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
diff --git a/Make-A-Protagonist/makeaprotagonist/pipelines/pipeline_stable_unclip_controlavideo.py b/Make-A-Protagonist/makeaprotagonist/pipelines/pipeline_stable_unclip_controlavideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c3e3ca996e227c5a7f7ad149250b47281abddc
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/pipelines/pipeline_stable_unclip_controlavideo.py
@@ -0,0 +1,1531 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+'''
+NOTE This is with prior
+When combined with an unCLIP prior, it can also be used for full text to image generation.
+
+NOTE this pipeline introduce controlnet into it
+'''
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+
+import PIL
+import torch
+import torch.nn as nn
+import numpy as np
+from dataclasses import dataclass
+
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+from einops import rearrange
+
+from diffusers.utils.import_utils import is_accelerate_available
+
+from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, PriorTransformer
+from diffusers.models.controlnet import ControlNetOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring, PIL_INTERPOLATION
+from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+from ..models.unet import UNet3DConditionModel
+from diffusers.utils import deprecate, logging, BaseOutput
+import ipdb
+import warnings
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import requests
+ >>> import torch
+ >>> from PIL import Image
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableUnCLIPImg2ImgPipeline
+
+ >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+ ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
+ ... ) # TODO update model path
+ >>> pipe = pipe.to("cuda")
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ >>> response = requests.get(url)
+ >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> init_image = init_image.resize((768, 512))
+
+ >>> prompt = "A fantasy landscape, trending on artstation"
+
+ >>> images = pipe(prompt, init_image).images
+ >>> images[0].save("fantasy_landscape.png")
+ ```
+"""
+
+
+@dataclass
+class TuneAVideoPipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+
+
+
+class MultiControlNetModel(ModelMixin):
+ r"""
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+ compatible with `ControlNetModel`.
+
+ Args:
+ controlnets (`List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ down_samples, mid_sample = controlnet(
+ sample,
+ timestep,
+ encoder_hidden_states,
+ image,
+ scale,
+ class_labels,
+ timestep_cond,
+ attention_mask,
+ cross_attention_kwargs,
+ return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
+ return down_block_res_samples, mid_block_res_sample
+
+
+
+class MakeAProtagonistStableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ """
+ Pipeline for text-guided image to image generation using stable unCLIP.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ feature_extractor ([`CLIPImageProcessor`]):
+ Feature extractor for image pre-processing before being encoded.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ CLIP vision model for encoding images.
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
+ embeddings after the noise has been applied.
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`KarrasDiffusionSchedulers`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ """
+
+ def __init__(
+ self,
+ # prior components
+ prior_tokenizer: CLIPTokenizer,
+ prior_text_encoder: CLIPTextModelWithProjection,
+ prior: PriorTransformer,
+ prior_scheduler: KarrasDiffusionSchedulers,
+ # image encoding components
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ # image noising components
+ image_normalizer: StableUnCLIPImageNormalizer,
+ image_noising_scheduler: KarrasDiffusionSchedulers,
+ # regular denoising components
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ unet: UNet3DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ # vae
+ vae: AutoencoderKL,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_encoder,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ vae=vae,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
+ when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list
+ models = [
+ self.image_encoder,
+ self.prior_text_encoder,
+ self.text_encoder,
+ self.unet,
+ self.vae,
+ self.controlnet,
+ ]
+ for cpu_offloaded_model in models:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.image_encoder, self.unet, self.vae, self.controlnet]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
+ def _encode_prior_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.prior_tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.prior_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.prior_tokenizer.batch_decode(
+ untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length]
+
+ prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = prior_text_encoder_output.text_embeds
+ prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.prior_tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.prior_tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder(
+ uncond_input.input_ids.to(device)
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
+ uncond_prior_text_encoder_hidden_states = (
+ negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
+ )
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prior_text_encoder_hidden_states = torch.cat(
+ [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
+ )
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def _encode_image(
+ self,
+ image,
+ device,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ noise_level,
+ generator,
+ image_embeds,
+ return_image_embeds=False
+ ):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if isinstance(image, PIL.Image.Image):
+ # the image embedding should repeated so it matches the total batch size of the prompt
+ repeat_by = batch_size
+ else:
+ # assume the image input is already properly batched and just needs to be repeated so
+ # it matches the num_videos_per_prompt.
+ #
+ # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched
+ # `image_embeds`. If those happen to be common use cases, let's think harder about
+ # what the expected dimensions of inputs should be and how we handle the encoding.
+ repeat_by = num_videos_per_prompt
+
+ if image_embeds is None:
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+
+ if return_image_embeds:
+ return image_embeds
+
+ image_embeds = self.noise_image_embeddings(
+ image_embeds=image_embeds,
+ noise_level=noise_level,
+ generator=generator,
+ )
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ image_embeds = image_embeds.unsqueeze(1)
+ bs_embed, seq_len, _ = image_embeds.shape
+ image_embeds = image_embeds.repeat(1, repeat_by, 1)
+ image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
+ image_embeds = image_embeds.squeeze(1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeds)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ # def decode_latents(self, latents):
+ # latents = 1 / self.vae.config.scaling_factor * latents
+ # image = self.vae.decode(latents).sample
+ # image = (image / 2 + 0.5).clamp(0, 1)
+ # # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ # image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ # return image
+
+ def decode_latents(self, latents):
+ latents = latents.to(self.vae.dtype)
+ video_length = latents.shape[2]
+ latents = 1 / self.vae.config.scaling_factor * latents
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ video = self.vae.decode(latents).sample
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ video = video.cpu().float().numpy()
+ return video
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
+ def prepare_prior_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the prior_scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image, ## this is for image embedding
+ control_image, ## this is for controlnet # the shape should be B,F,C,H,W
+ height,
+ width,
+ callback_steps,
+ noise_level,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
+ )
+
+ if prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined."
+ )
+
+ if prompt is not None and negative_prompt is not None:
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
+ raise ValueError(
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
+ )
+
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Please make sure to define only one of the two."
+ )
+
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+
+ if image is not None:
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ if isinstance(self.controlnet, ControlNetModel):
+ self.check_image(control_image, prompt, prompt_embeds)
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if not isinstance(control_image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in control_image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(control_image) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
+ )
+
+ for image_ in control_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if isinstance(self.controlnet, ControlNetModel):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ elif image_is_tensor:
+ image_batch_size = image.shape[0]
+ elif image_is_pil_list:
+ image_batch_size = len(image)
+ elif image_is_tensor_list:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+
+ def prepare_image(
+ self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+ ):
+ '''
+ image here should be batch wise video, B,F,C,H,W
+ '''
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance:
+ image = torch.cat([image] * 2)
+
+ return image
+
+
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents_shape(self, shape, dtype, device, generator, latents, scheduler):
+ # ipdb.set_trace()
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ ## B,4,F,H,W
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
+ def noise_image_embeddings(
+ self,
+ image_embeds: torch.Tensor,
+ noise_level: int,
+ noise: Optional[torch.FloatTensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ):
+ """
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
+ `noise_level` increases the variance in the final un-noised images.
+
+ The noise is applied in two ways
+ 1. A noise schedule is applied directly to the embeddings
+ 2. A vector of sinusoidal time embeddings are appended to the output.
+
+ In both cases, the amount of noise is controlled by the same `noise_level`.
+
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
+ """
+ # ipdb.set_trace()
+
+ if noise is None:
+ noise = randn_tensor(
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
+ )
+
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
+
+ self.image_normalizer.to(image_embeds.device)
+ image_embeds = self.image_normalizer.scale(image_embeds)
+
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
+
+ image_embeds = self.image_normalizer.unscale(image_embeds)
+
+ noise_level = get_timestep_embedding(
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
+ )
+
+ # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
+ # but we might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ noise_level = noise_level.to(image_embeds.dtype)
+
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
+
+ return image_embeds
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ prompt: Union[str, List[str]] = None,
+ control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+ video_length: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50, # 20 in image unclip pipline
+ guidance_scale: float = 7.5, # 10 in image unclip pipline
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ noise_level: int = 0,
+ image_embeds: Optional[torch.FloatTensor] = None,
+ adapter_features: Optional[torch.Tensor] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ controlnet_image_embeds_type: str = "empty", ## "image" for using image embedding for control net
+ ## text guidance
+ text_guidance_scale: float = 0, # NOTE CFG on no image embedding sample
+ # prior args
+ prior_num_inference_steps: int = 25,
+ prior_guidance_scale: float = 4.0,
+ prior_latents: Optional[torch.FloatTensor] = None,
+ interpolate_embed_weight: float = 1.0,
+ return_prior_embed: bool = False, ## return prior embedding for DDIM inv
+ prior_denoised_embeds: Optional[torch.FloatTensor] = None, ## the embedding used for the background, after denoising
+
+ ## mask args
+ masks: Optional[torch.FloatTensor] = None,
+ inverse_mask: bool = False, ## inverse mask of image embedding and text
+ start_step: int = -1, ## start to use mask
+ end_step: int = 1000, ## end to use mask
+ mask_mode: str = 'all',
+ mask_latent_fuse_mode: str = 'all',
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+ used or prompt is initialized to `""`.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
+ the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
+ latents in the denoising process such as in the standard stable diffusion text guided image variation
+ process.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 10.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ noise_level (`int`, *optional*, defaults to `0`):
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
+ image_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
+ the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
+ `latents`.
+
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
+ higher quality image at the expense of slower inference.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion
+ Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of
+ [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ prior_latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ embedding generation in the prior denoising process. Can be used to tweak the same generation with
+ different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied
+ random `generator`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if prompt is None and prompt_embeds is None:
+ prompt = len(image) * [""] if isinstance(image, list) else ""
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ assert not isinstance(prompt, list)
+ ## NOTE only support one prompt here
+
+ else:
+ if control_image.dim() == 5:
+ prompt_len = len(prompt) if isinstance(prompt, list) else 1
+ control_image = control_image.repeat(prompt_len,1,1,1,1) # B,F,3,H,W
+
+ if len(masks.unique()) == 1 and masks.unique()[0] == 1: ## if all ones, just ignore
+ masks = None
+
+ if masks is not None:
+ if not interpolate_embed_weight: ## is 0
+ warnings.warn( "Using mask should use image embedding combined with prior embedding. Now only prior embedding is used, the results should be the same with no mask")
+ # assert interpolate_embed_weight, "Using mask should use image embedding combined with prior embedding"
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ image=image,
+ control_image=control_image,
+ height=height,
+ width=width,
+ callback_steps=callback_steps,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ image_embeds=image_embeds,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ )
+
+ # ipdb.set_trace()
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ batch_size = batch_size * num_videos_per_prompt
+
+ device = self._execution_device
+
+ ## NOTE using prior denoised latents from the source embedding for partially editing
+
+ if prior_denoised_embeds is None:
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ prior_do_classifier_free_guidance = prior_guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=prior_do_classifier_free_guidance,
+ )
+ ## prior_prompt_embeds: text embeds
+ ## prior_text_encoder_hidden_states: last hidden state
+
+ # 4. Prepare prior timesteps
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ # 5. Prepare prior latent variables
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents_shape(
+ (batch_size, embedding_dim),
+ prior_prompt_embeds.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+ # ipdb.set_trace()
+
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)
+
+
+ # 7. Prior denoising loop
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
+ latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prior_prompt_embeds,
+ encoder_hidden_states=prior_text_encoder_hidden_states,
+ attention_mask=prior_text_mask,
+ ).predicted_image_embedding
+
+ if prior_do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ **prior_extra_step_kwargs,
+ ).prev_sample
+
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, prior_latents)
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ if return_prior_embed:
+ return prior_latents
+ # done prior
+ else:
+ prior_latents = prior_denoised_embeds
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
+ # ipdb.set_trace()
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ prompt_embeds_text, image_embeds_text = None, None
+ if text_guidance_scale:
+ prompt_embeds_text = prompt_embeds[prompt_embeds.size(0)//2:] ## with text
+
+ # 4. Encoder input image
+ noise_level = torch.tensor([noise_level], device=device)
+
+ if interpolate_embed_weight:
+ # assert image is not None, "interpolate image embedding with prior embedding requires the image"
+ image_embeds_given = self._encode_image(
+ image=image,
+ device=device,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ noise_level=noise_level,
+ generator=generator,
+ image_embeds=image_embeds,
+ return_image_embeds=True,
+ )
+ image_embeds = interpolate_embed_weight * image_embeds_given + (1-interpolate_embed_weight) * prior_latents
+ else:
+ image_embeds = prior_latents
+
+ image_embeds = self._encode_image(
+ image=image,
+ device=device,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ noise_level=noise_level,
+ generator=generator,
+ image_embeds=image_embeds,
+ ) # 2B,C
+
+ aux_latents = None
+ if masks is not None:
+ ## NOTE encode prior latent
+ aux_latents = self._encode_image(
+ image=None,
+ device=device,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ noise_level=noise_level,
+ generator=generator,
+ image_embeds=prior_latents,
+ ) # 2B,C
+
+ #
+ if text_guidance_scale:
+ image_embeds_text = image_embeds[:image_embeds.size(0)//2] # no image
+
+ # 5. Prepare control image
+ ## control image shape B,F,C,H,W
+
+ # assert control_image.dim() == 5
+ if isinstance(self.controlnet, ControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt,
+ num_images_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_videos_per_prompt,
+ num_images_per_prompt=num_videos_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ images.append(image_)
+
+ control_image = images
+ else:
+ assert False
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size=batch_size,
+ num_channels_latents=num_channels_latents,
+ video_length=video_length,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+
+ ## control_image: B,F,C,H,W
+ if isinstance(control_image, list):
+ control_image = [rearrange(c_image, "b f c h w -> (b f) c h w").to(device=self.controlnet.device, dtype=self.controlnet.dtype) for c_image in control_image]
+ else:
+ control_image = rearrange(control_image, "b f c h w -> (b f) c h w").to(device=self.controlnet.device, dtype=self.controlnet.dtype) ##
+ # aux_latents = torch.zeros_like(image_embeds)
+ if masks is not None:
+ # ipdb.set_trace()
+ masks = rearrange(masks, "b f c h w -> b c f h w").to(device=self.unet.device, dtype=self.unet.dtype)
+ masks = torch.nn.functional.interpolate(masks, size=latents.size()[-3:], mode="nearest")
+ # ipdb.set_trace()
+ if inverse_mask:
+ masks = 1 - masks
+ image_embeds, aux_latents = aux_latents, image_embeds
+
+ # ipdb.set_trace()
+ mask_mode_cfg = mask_mode ## mask_mode: emb / latent / all
+ mask_mode = mask_mode_cfg
+ mask_latent_fuse_mode = mask_latent_fuse_mode ## inverse or all
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ ## t is 1000 divided in to 50 steps
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ ## 2B,C,F,H,W
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # ipdb.set_trace()
+ # controlnet(s) inference
+ down_block_res_samples, mid_block_res_sample = None, None
+
+ latent_model_input_control = rearrange(latent_model_input, "b c f h w -> (b f) c h w").to(dtype=self.controlnet.dtype) ##
+
+ if controlnet_image_embeds_type == "image":
+ controlnet_image_embeds = image_embeds.repeat(video_length, 1)
+ else:
+ # controlnet_image_embeds = torch.zeros_like(image_embeds).repeat(video_length, 1)
+ # NOTE this is a support for frame wise image embedding
+ controlnet_image_embeds = torch.zeros_like(image_embeds[:latent_model_input.size(0)]).repeat(video_length, 1)
+
+ controlnet_image_embeds = controlnet_image_embeds.to(self.controlnet.dtype)
+ # ipdb.set_trace()
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ latent_model_input_control,
+ t,
+ class_labels=controlnet_image_embeds,
+ encoder_hidden_states=prompt_embeds.repeat(video_length, 1, 1),
+ controlnet_cond=control_image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+ down_block_res_samples = [rearrange(sample, "(b f) c h w -> b c f h w", f=video_length) for sample in down_block_res_samples]
+ mid_block_res_sample = rearrange(mid_block_res_sample, "(b f) c h w -> b c f h w", f=video_length)
+
+ # ipdb.set_trace()
+ if i >= start_step and i < end_step:
+ _aux_latents = aux_latents
+ _masks = masks
+ ## NOTE this can use emb in some steps and use mask latent in other steps
+ mask_mode = mask_mode_cfg
+ else:
+ _aux_latents, _masks = None, None
+ mask_mode = "emb"
+ # predict the noise residual
+
+ ## NOTE mask mode list, the first is for image content, using latent mask, the second is for text, no latent mask / inverse mask
+ value = torch.zeros_like(latents)
+ count = torch.zeros_like(latents)
+
+ # ipdb.set_trace()
+ if mask_mode == "latent":
+ cls_labels = [image_embeds, _aux_latents]
+ cls_labels_aux = [None, None]
+ cls_masks = [None, None]
+
+ if mask_latent_fuse_mode == "all":
+ latent_masks = [masks, torch.ones_like(masks)] # this is used for combining latents
+ else:
+ latent_masks = [masks, 1-masks] # this is used for combining latents
+
+ elif mask_mode == "all":
+ cls_labels = [image_embeds, image_embeds]
+ cls_labels_aux = [None, _aux_latents]
+ cls_masks = [None, _masks]
+
+ if mask_latent_fuse_mode == "all":
+ latent_masks = [_masks, torch.ones_like(_masks)] # this is used for combining latents
+ else:
+ latent_masks = [_masks, 1-_masks] # this is used for combining latents
+
+ else: ## this is the original version
+ cls_labels = [image_embeds]
+ cls_labels_aux = [_aux_latents]
+ cls_masks = [_masks]
+ latent_masks = [ torch.ones_like(latents)]
+
+ for _cls_labels, _cls_labels_aux, _cls_masks, _latent_masks in zip(cls_labels, cls_labels_aux, cls_masks, latent_masks):
+
+ latent_view = latents
+ latent_model_input = torch.cat([latent_view] * 2) if do_classifier_free_guidance else latent_view
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # (batch_size, 4, F, H, W)
+ # ipdb.set_trace()
+ ## if use all training embedding, _cls_labels: 2F,C -> reshape 2,F,C will get the correct result,
+ ## NOTE the batch size 0 is the unconditional
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ class_labels=_cls_labels,
+ class_labels_aux=_cls_labels_aux,
+ masks=_cls_masks,
+ adapter_features=adapter_features,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view, **extra_step_kwargs).prev_sample
+
+ # ipdb.set_trace()
+ ## _latent_masks 1,1,F,H,W
+ value += latents_view_denoised * _latent_masks
+ count += _latent_masks
+
+ assert (count > 0).all()
+ latents = torch.where(count > 0, value / count, value)
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ video = self.decode_latents(latents)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ # Convert to tensor
+ if output_type == "tensor":
+ video = torch.from_numpy(video)
+
+ if not return_dict:
+ return video
+
+ return TuneAVideoPipelineOutput(videos=video)
diff --git a/Make-A-Protagonist/makeaprotagonist/util.py b/Make-A-Protagonist/makeaprotagonist/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adda28145acb26cf1c92267d50c646f01703fd2
--- /dev/null
+++ b/Make-A-Protagonist/makeaprotagonist/util.py
@@ -0,0 +1,189 @@
+import os
+import imageio
+import numpy as np
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+import torchvision
+
+from tqdm import tqdm
+from einops import rearrange
+import ipdb
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, outputs, fps=fps)
+
+
+# DDIM Inversion
+@torch.no_grad()
+def init_prompt(prompt, pipeline):
+ uncond_input = pipeline.tokenizer(
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
+ return_tensors="pt"
+ )
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+ text_input = pipeline.tokenizer(
+ [prompt],
+ padding="max_length",
+ max_length=pipeline.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ return context
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+ timestep, next_timestep = min(
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+
+def get_noise_pred_single(latents, t, context, unet, image_embeds=None):
+ noise_pred = unet(latents, t, encoder_hidden_states=context, class_labels=image_embeds)["sample"]
+ return noise_pred
+
+
+@torch.no_grad()
+def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+ context = init_prompt(prompt, pipeline)
+ uncond_embeddings, cond_embeddings = context.chunk(2)
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ for i in tqdm(range(num_inv_steps)):
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
+ all_latent.append(latent)
+ return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+ return ddim_latents
+
+
+# DDIM Inversion
+@torch.no_grad()
+def init_image_embed(image_embeds, pipeline, noise_level, generator):
+ # ipdb.set_trace()
+ dtype = next(pipeline.image_encoder.parameters()).dtype
+ device = pipeline.image_encoder.device
+
+ noise_level = torch.tensor([noise_level], device=device)
+
+ image_embeds = pipeline.noise_image_embeddings(
+ image_embeds=image_embeds,
+ noise_level=noise_level,
+ generator=generator,
+ ) # 1,1024
+
+ return image_embeds
+
+def next_step_velocity(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+ timestep, next_timestep = min(
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_next = 1 - alpha_prod_t_next
+
+ next_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ next_pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * next_pred_epsilon
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+
+ return next_sample
+
+@torch.no_grad()
+def ddim_loop_unclip(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, image_embed, noise_level, generator):
+ context = init_prompt(prompt, pipeline)
+ uncond_embeddings, cond_embeddings = context.chunk(2)
+ assert image_embed is not None
+
+ if not image_embed.dim() == 0:
+ image_embeddings = init_image_embed(image_embed, pipeline, noise_level, generator)
+ else:
+ image_embeddings = image_embed
+
+
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ for i in tqdm(range(num_inv_steps)):
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet, image_embeddings)
+ latent = next_step_velocity(noise_pred, t, latent, ddim_scheduler)
+ all_latent.append(latent)
+ return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion_unclip(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", image_embed=None, noise_level=0, seed=0):
+ '''
+ generator should be fixed here for consistent latent estimation
+ '''
+ generator = torch.Generator(device=video_latent.device)
+ generator.manual_seed(seed)
+
+ ddim_latents = ddim_loop_unclip(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, image_embed, noise_level, generator)
+ return ddim_latents
+
+def next_step_sample(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ num_inference_steps: int, ddim_scheduler):
+ timestep, next_timestep = min(
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = model_output
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+def get_noise_pred_single_prior(latents, t, prior, prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask):
+ noise_pred = prior(latents, timestep=t, proj_embedding=prior_prompt_embeds,
+ encoder_hidden_states=prior_text_encoder_hidden_states,
+ attention_mask=prior_text_mask,).predicted_image_embedding
+ return noise_pred
+
+@torch.no_grad()
+def ddim_loop_prior(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+ prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = pipeline._encode_prior_prompt(prompt=prompt, device=pipeline.prior.device, num_images_per_prompt=1, do_classifier_free_guidance=False)
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ # ipdb.set_trace()
+ for i in tqdm(range(num_inv_steps)):
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+ noise_pred = get_noise_pred_single_prior(latent, t, pipeline.prior, prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask)
+ latent = next_step_sample(noise_pred, t, num_inv_steps, ddim_scheduler)
+ all_latent.append(latent)
+ return all_latent
+
+@torch.no_grad()
+def ddim_inversion_prior(pipeline, ddim_scheduler, latent, num_inv_steps, prompt=""):
+ # ipdb.set_trace()
+ ddim_latents = ddim_loop_prior(pipeline, ddim_scheduler, latent, num_inv_steps, prompt)
+ return ddim_latents
diff --git a/Make-A-Protagonist/requirements.txt b/Make-A-Protagonist/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0446d32b6db970d14ccc587c1809715e8bb91ab3
--- /dev/null
+++ b/Make-A-Protagonist/requirements.txt
@@ -0,0 +1,23 @@
+## CUDA 11.6 torch 1.13.1
+# torch==1.13.1+cu116
+# torchvision==0.14.1+cu116
+# torchaudio==0.13.1
+# xformers==0.0.17dev466 ## 0.0.16 not work
+
+## install diffusers
+pip install git+https://github.com/huggingface/diffusers.git
+transformers>=4.25.1
+bitsandbytes==0.35.4
+accelerate
+tensorboard
+modelcards
+omegaconf
+einops
+imageio
+ftfy
+opencv-python
+timm
+wandb
+ipdb
+matplotlib
+triton
\ No newline at end of file
diff --git a/Make-A-Protagonist/train.py b/Make-A-Protagonist/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d8d8229af05e728b811c5e3f85f0ae81b657011
--- /dev/null
+++ b/Make-A-Protagonist/train.py
@@ -0,0 +1,519 @@
+import argparse
+import datetime
+import logging
+import inspect
+import math
+import os
+from typing import Dict, Optional, Tuple
+from omegaconf import OmegaConf
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import numpy as np
+from PIL import Image
+
+import diffusers
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, PNDMScheduler, ControlNetModel, PriorTransformer, UnCLIPScheduler
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
+
+from makeaprotagonist.models.unet import UNet3DConditionModel
+from makeaprotagonist.dataset.dataset import MakeAProtagonistDataset
+from makeaprotagonist.util import save_videos_grid, ddim_inversion_unclip
+from makeaprotagonist.pipelines.pipeline_stable_unclip_controlavideo import MakeAProtagonistStableUnCLIPPipeline, MultiControlNetModel
+
+from einops import rearrange
+from makeaprotagonist.args_util import DictAction, config_merge_dict
+import ipdb
+import random
+from glob import glob
+import sys
+
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.15.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def main(
+ pretrained_model_path: str,
+ controlnet_pretrained_model_path: str,
+ output_dir: str,
+ train_data: Dict,
+ validation_data: Dict,
+ validation_steps: int = 100,
+ trainable_modules: Tuple[str] = (
+ "attn1.to_q",
+ "attn2.to_q",
+ "attn_temp",
+ ),
+ trainable_params: Tuple[str] = (),
+ train_batch_size: int = 1,
+ max_train_steps: int = 500,
+ learning_rate: float = 3e-5,
+ scale_lr: bool = False,
+ lr_scheduler: str = "constant",
+ lr_warmup_steps: int = 0,
+ adam_beta1: float = 0.9,
+ adam_beta2: float = 0.999,
+ adam_weight_decay: float = 1e-2,
+ adam_epsilon: float = 1e-08,
+ max_grad_norm: float = 1.0,
+ gradient_accumulation_steps: int = 1,
+ gradient_checkpointing: bool = True,
+ checkpointing_steps: int = 500,
+ resume_from_checkpoint: Optional[str] = None,
+ mixed_precision: Optional[str] = "fp16",
+ use_8bit_adam: bool = False,
+ enable_xformers_memory_efficient_attention: bool = True,
+ seed: Optional[int] = None,
+ adapter_config=None, # the config for adapter
+ use_temporal_conv=False, ## use temporal conv in resblocks
+):
+ *_, config = inspect.getargvalues(inspect.currentframe())
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ mixed_precision=mixed_precision,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if seed is not None:
+ set_seed(seed)
+
+ # Handle the output folder creation
+ if accelerator.is_main_process:
+ # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ # output_dir = os.path.join(output_dir, now)
+ os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
+ os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
+ prior_model_id = "kakaobrain/karlo-v1-alpha"
+ data_type = torch.float16
+ prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type)
+
+ prior_text_model_id = "openai/clip-vit-large-patch14"
+ prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)
+ prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)
+ prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler")
+ prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+
+ # image encoding components
+ feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
+ # image noising components
+ image_normalizer = StableUnCLIPImageNormalizer.from_pretrained(pretrained_model_path, subfolder="image_normalizer")
+ image_noising_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="image_noising_scheduler")
+ # regular denoising components
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_temporal_conv=use_temporal_conv)
+ noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+
+ # vae
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+
+ ## controlnet
+ assert not isinstance(controlnet_pretrained_model_path, str)
+ controlnet = MultiControlNetModel( [ControlNetModel.from_pretrained(_control_model_path) for _control_model_path in controlnet_pretrained_model_path] )
+
+ # Freeze vae and text_encoder and adapter
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ ## freeze image embed
+ image_encoder.requires_grad_(False)
+
+ unet.requires_grad_(False)
+ ## freeze controlnet
+ controlnet.requires_grad_(False)
+
+ ## freeze prior
+ prior.requires_grad_(False)
+ prior_text_model.requires_grad_(False)
+
+ for name, module in unet.named_modules():
+ if name.endswith(tuple(trainable_modules)):
+ for params in module.parameters():
+ params.requires_grad = True
+
+ if len(trainable_params):
+ for name, params in unet.named_parameters():
+ if name.endswith(tuple(trainable_params)):
+ params.requires_grad = True
+
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if scale_lr:
+ learning_rate = (
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=learning_rate,
+ betas=(adam_beta1, adam_beta2),
+ weight_decay=adam_weight_decay,
+ eps=adam_epsilon,
+ )
+
+ # Get the training dataset
+ train_dataset = MakeAProtagonistDataset(**train_data)
+
+ # Preprocessing the dataset
+ train_dataset.prompt_ids = tokenizer(
+ train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ ).input_ids[0]
+
+ train_dataset.preprocess_img_embedding(feature_extractor, image_encoder)
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=train_batch_size
+ )
+
+ # Get the validation pipeline
+ # validation_pipeline = TuneAVideoPipeline(
+ # vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+ # scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ # )
+ prior_val_scheduler = DDIMScheduler.from_config(prior_scheduler.config) if validation_data.get("prior_val_scheduler", "") == "DDIM" else prior_scheduler
+
+ validation_pipeline = MakeAProtagonistStableUnCLIPPipeline(
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_val_scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ )
+
+
+ validation_pipeline.enable_vae_slicing()
+ ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
+ ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
+
+ # Scheduler
+ lr_scheduler = get_scheduler(
+ lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
+ )
+
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move models to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ ## note controlnet use the unet dtype
+ controlnet.to(accelerator.device, dtype=weight_dtype)
+ ## prior
+ prior.to(accelerator.device, dtype=weight_dtype)
+ prior_text_model.to(accelerator.device, dtype=weight_dtype)
+
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2video-fine-tune")
+
+ # Train!
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if resume_from_checkpoint:
+ if resume_from_checkpoint != "latest":
+ path = os.path.basename(resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1]
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = global_step % num_update_steps_per_epoch
+
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ if not "noise_level" in validation_data:
+ validation_data.noise_level = train_data.noise_level
+
+ image_embed_drop = train_data.get("image_embed_drop", 0)
+
+ for epoch in range(first_epoch, num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert videos to latent space
+ pixel_values = batch["pixel_values"].to(weight_dtype)
+ video_length = pixel_values.shape[1]
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+ latents = vae.encode(pixel_values).latent_dist.sample()
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each video
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]
+
+ #
+ # ipdb.set_trace()
+ ref_imbed = batch["ref_imbed"].to(accelerator.device, dtype=weight_dtype) # 1,1,768
+ ##
+ if train_data.noise_level >= 1000:
+ train_noise = random.randint(0, 999)
+ else:
+ train_noise = train_data.noise_level
+
+ image_embeds = validation_pipeline.noise_image_embeddings(
+ image_embeds=ref_imbed,
+ noise_level=train_noise,
+ generator=None,
+ )
+
+ if random.random() < image_embed_drop:
+ image_embeds = torch.zeros_like(image_embeds)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.prediction_type == "v_prediction": ## use this for unclip model
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ # model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, adapter_features=adapter_features).sample
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, class_labels=image_embeds).sample
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
+ train_loss += avg_loss.item() / gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % validation_steps == 0:
+ if accelerator.is_main_process:
+
+ # ControlNet
+ conditions = [_condition.to(weight_dtype) for _, _condition in batch["conditions"].items()] # b f c h w
+ masks = batch["masks"].to(weight_dtype) # b,f,1,h,w
+ if not validation_data.get("use_masks", False):
+ masks = torch.ones_like(masks)
+
+ ddim_inv_latent = None
+ if validation_data.use_inv_latent: #
+ emb_dim = train_dataset.img_embeddings[0].size(0)
+ key_frame_embed = torch.zeros((1, emb_dim)).to(device=latents.device, dtype=latents.dtype) ## this is dim 0
+ ddim_inv_latent = ddim_inversion_unclip(
+ validation_pipeline, ddim_inv_scheduler, video_latent=latents,
+ num_inv_steps=validation_data.num_inv_steps, prompt="", image_embed=key_frame_embed, noise_level=validation_data.noise_level, seed=seed)[-1].to(weight_dtype)
+
+ set_noise = validation_data.pop("noise_level")
+ v_noise = set_noise
+
+ if not validation_data.get("interpolate_embed_weight", False):
+ validation_data.interpolate_embed_weight = 1.0
+
+
+ samples = []
+
+ generator = torch.Generator(device=accelerator.device)
+ generator.manual_seed(seed)
+
+ for idx, prompt in enumerate(validation_data.prompts):
+
+ _ref_image = Image.open(validation_data.ref_images[idx])
+ image_embed = None
+ ## prior latents
+ prior_embeds = None
+ prior_denoised_embeds = None
+ if validation_data.get("source_background", False):
+ ## using source background and changing the protagonist
+ prior_denoised_embeds = train_dataset.img_embeddings[0][None].to(device=latents.device, dtype=latents.dtype) # 1, 768 for UnCLIP-small
+
+ if validation_data.get("source_protagonist", False):
+ # using source protagonist and changing the background
+ sample_indices = batch["sample_indices"][0]
+ image_embed = [train_dataset.img_embeddings[idx] for idx in sample_indices]
+ image_embed = torch.stack(image_embed, dim=0).to(device=latents.device, dtype=latents.dtype) # F, 768 for UnCLIP-small # F,C
+ _ref_image = None
+
+ sample = validation_pipeline(image=_ref_image, prompt=prompt, control_image=conditions, generator=generator, latents=ddim_inv_latent, image_embeds=image_embed, noise_level=v_noise, masks=masks, prior_latents=prior_embeds, prior_denoised_embeds=prior_denoised_embeds, **validation_data).videos
+
+ save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}-seed{seed}/{idx}-{prompt}.gif")
+ samples.append(sample)
+
+ #
+ samples = [sample.float() for sample in samples]
+ samples = torch.concat(samples)
+ save_path = f"{output_dir}/samples/sample-{global_step}-s{validation_data.start_step}-e{validation_data.end_step}-seed{seed}.gif" # noise level and noise level for inv
+ save_videos_grid(samples, save_path, n_rows=len(samples))
+ logger.info(f"Saved samples to {save_path}")
+
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction,
+ help='Override some settings in the '
+ 'used config, the key-value pair in xxx=yyy format will be merged '
+ 'into config file. If the value to be overwritten is a list, it '
+ 'should be like key="[a,b]" or key=a,b It also allows nested '
+ 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
+ 'marks are necessary and that no white space is allowed.')
+
+ args = parser.parse_args()
+
+ ## read from cmd line
+ # ipdb.set_trace()
+ # Load the YAML configuration file
+ config = OmegaConf.load(args.config)
+ # Merge the command-line arguments with the configuration file
+ if args.options is not None:
+ # config = OmegaConf.merge(config, args.options)
+ config_merge_dict(args.options, config)
+
+ main(**config)
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f75d99f4adbe5b9e83aa8df14a2282bd2873562d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,10 @@
+---
+title: Make-A-Protagonist Inference
+colorFrom: red
+colorTo: purple
+sdk: docker
+pinned: false
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..ef91cfd15041835c3dbcb0e201c134d42c287a19
--- /dev/null
+++ b/app.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+import sys
+import warnings
+
+os.system("python -m pip install -e Make-A-Protagonist/experts/GroundedSAM/segment_anything")
+os.system("python -m pip install -e Make-A-Protagonist/experts/GroundedSAM/GroundingDINO")
+# os.system("pip install --upgrade diffusers[torch]")
+warnings.filterwarnings("ignore")
+
+import gradio as gr
+
+from inference import InferencePipeline
+
+
+class InferenceUtil:
+ def __init__(self, hf_token: str | None):
+ self.hf_token = hf_token
+
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
+ ## TODO the modelcard is in the readme of huggingface repo, should know how to write it
+ try:
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
+ except Exception:
+ return '', ''
+ # return ''
+ base_model = getattr(card.data, 'base_model', '')
+ protagonist = getattr(card.data, 'protagonist', '')
+ training_prompt = getattr(card.data, 'training_prompt', '')
+ return protagonist, training_prompt
+ # return training_prompt
+
+
+# TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'
+HF_TOKEN = os.getenv('HF_TOKEN')
+# print("HF Token ===> ", HF_TOKEN)
+pipe = InferencePipeline(HF_TOKEN)
+app = InferenceUtil(HF_TOKEN)
+
+with gr.Blocks(css='style.css') as demo:
+ # gr.Markdown(TITLE)
+
+ gr.HTML(
+ """
+
+
+ Make-A-Protagonist:
+
+ Generic Video Editing with An Ensemble of Experts
+
+
+
+
+ 1 National University of Singapore
+ 2 Huawei Noah's Ark Lab
+
+
+
+
+ TL;DR: The first framework for generic video editing with both visual and textual clues.
+
+
+ """)
+
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ model_id = gr.Dropdown(
+ label='Model ID',
+ choices=[
+ 'Make-A-Protagonist/ikun',
+ 'Make-A-Protagonist/huaqiang',
+ 'Make-A-Protagonist/yanzi',
+ 'Make-A-Protagonist/car-turn',
+ ],
+ value='Make-A-Protagonist/ikun')
+
+ with gr.Row():
+ base_model_used_for_training = gr.Textbox(
+ label='Protagonist', interactive=False, value='man')
+ prompt_used_for_training = gr.Textbox(
+ label='Training prompt', interactive=False, value='A man is playing basketball')
+ with gr.Box():
+ ref_image = gr.Image(label='Reference Image', type='pil', visible=True).style(height="auto")
+ ref_pro_prompt = gr.Textbox(label='Reference Image Protagonist Prompt',
+ max_lines=1,
+ placeholder='Example: "man"')
+
+ prompt = gr.Textbox(label='Prompt',
+ max_lines=1,
+ placeholder='Example: "A panda is surfing"')
+ video_length = gr.Slider(label='Video length',
+ minimum=4,
+ maximum=8,
+ step=1,
+ value=8)
+ fps = gr.Slider(label='FPS',
+ minimum=1,
+ maximum=8,
+ step=1,
+ value=4)
+ seed = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=100000,
+ step=1,
+ value=0)
+
+ with gr.Accordion('ControlNet Parameters', open=True):
+ control_pose = gr.Slider(label='Pose',
+ minimum=0,
+ maximum=1,
+ step=0.1,
+ value=.5)
+ control_depth = gr.Slider(label='Depth',
+ minimum=0,
+ maximum=1,
+ step=0.1,
+ value=.5)
+
+ with gr.Accordion('Editing Function', open=True):
+ with gr.Row():
+ source_pro = gr.Slider(label='Source Protagonist',
+ minimum=0,
+ maximum=1,
+ step=1,
+ value=0)
+ source_bg = gr.Slider(label='Source Background',
+ minimum=0,
+ maximum=1,
+ step=1,
+ value=0)
+
+ with gr.Accordion('Other Parameters', open=False):
+ num_steps = gr.Slider(label='Number of Steps',
+ minimum=0,
+ maximum=100,
+ step=1,
+ value=50)
+ guidance_scale = gr.Slider(label='CFG Scale',
+ minimum=0,
+ maximum=50,
+ step=0.1,
+ value=12.5)
+
+ noise_level = gr.Slider(label='Noise Level',
+ minimum=0,
+ maximum=999,
+ step=1,
+ value=0)
+
+
+ run_button = gr.Button('Generate')
+
+ gr.Markdown('''
+ - It takes a few minutes to download model first.
+ - It takes one minute to load model and conduct DDIM inverse
+ ''')
+ with gr.Column():
+ result = gr.Video(label='Result')
+ with gr.Row():
+ examples = [
+ [
+ 'Make-A-Protagonist/ikun',
+ 'A man is playing basketball on the beach, anime style.',
+ 8,
+ 4,
+ 33,
+ 50,
+ 12.5,
+ 'data/ikun/reference_images/zhongli.jpg',
+ 'man',
+ 0,
+ 0.5,
+ 0.5,
+ 0,
+ 0
+ ],
+
+ [
+ 'Make-A-Protagonist/huaqiang',
+ 'Elon Musk walking down the street.',
+ 8,
+ 4,
+ 33,
+ 50,
+ 12.5,
+ 'data/huaqiang/reference_images/musk.jpg',
+ 'man',
+ 0,
+ 0.5,
+ 0.5,
+ 0,
+ 1,
+ ],
+
+ [
+ 'Make-A-Protagonist/yanzi',
+ 'A panda walking down the snowy street.',
+ 8,
+ 4,
+ 33,
+ 50,
+ 12.5,
+ 'data/yanzi/reference_images/panda.jpeg',
+ 'panda',
+ 0,
+ 0.5,
+ 0.5,
+ 0,
+ 0
+ ],
+
+ [
+ 'Make-A-Protagonist/car-turn',
+ 'A car moving in the desert.',
+ 8,
+ 4,
+ 33,
+ 50,
+ 12.5,
+ 'data/car-turn/reference_images/audi.jpeg',
+ 'car',
+ 0,
+ 0.0,
+ 1.0,
+ 0,
+ 0
+ ],
+
+ [
+ 'Make-A-Protagonist/car-turn',
+ 'A Suzuki Jimny driving down a mountain road in the rain.',
+ 8,
+ 4,
+ 33,
+ 50,
+ 12.5,
+ 'data/car-turn/images/0000.jpg',
+ 'car',
+ 0,
+ 0.0,
+ 1.0,
+ 1,
+ 0
+ ],
+
+ ]
+ gr.Examples(examples=examples,
+ inputs=[
+ model_id,
+ prompt,
+ video_length,
+ fps,
+ seed,
+ num_steps,
+ guidance_scale,
+ ref_image,
+ ref_pro_prompt,
+ noise_level,
+ control_pose,
+ control_depth,
+ source_pro,
+ source_bg,
+ ],
+ outputs=result,
+ fn=pipe.run,
+ cache_examples=os.getenv('SYSTEM') == 'spaces')
+
+ model_id.change(fn=app.load_model_info,
+ inputs=model_id,
+ outputs=[
+ base_model_used_for_training,
+ prompt_used_for_training,
+ ])
+
+
+
+ inputs = [
+ model_id,
+ prompt,
+ video_length,
+ fps,
+ seed,
+ num_steps,
+ guidance_scale,
+ ref_image,
+ ref_pro_prompt,
+ noise_level,
+ control_pose,
+ control_depth,
+ source_pro,
+ source_bg,
+ ]
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
+
+demo.queue().launch(share=True)
diff --git a/data/bird-forest/images/0001.jpg b/data/bird-forest/images/0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..886b0362a820852a9a68fb92bf509b3246e3f455
Binary files /dev/null and b/data/bird-forest/images/0001.jpg differ
diff --git a/data/bird-forest/images/0002.jpg b/data/bird-forest/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4926cce82532120e4898f13d4aeecbda092a7a5d
Binary files /dev/null and b/data/bird-forest/images/0002.jpg differ
diff --git a/data/bird-forest/images/0003.jpg b/data/bird-forest/images/0003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d8e7555da79e9a4c0fd999c83119619cf3af667d
Binary files /dev/null and b/data/bird-forest/images/0003.jpg differ
diff --git a/data/bird-forest/images/0004.jpg b/data/bird-forest/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c37f3e036c22b8a41bbc800c38a0d527fd4936a6
Binary files /dev/null and b/data/bird-forest/images/0004.jpg differ
diff --git a/data/bird-forest/images/0005.jpg b/data/bird-forest/images/0005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bf0041ccb7b4a3841d4782b20d8186442753d557
Binary files /dev/null and b/data/bird-forest/images/0005.jpg differ
diff --git a/data/bird-forest/images/0006.jpg b/data/bird-forest/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..41ddee3a8504301018be2dd7bbad539e0d3ce417
Binary files /dev/null and b/data/bird-forest/images/0006.jpg differ
diff --git a/data/bird-forest/images/0007.jpg b/data/bird-forest/images/0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8fe2e7c5f2b370b4836aa78f489ab0e286cd6c13
Binary files /dev/null and b/data/bird-forest/images/0007.jpg differ
diff --git a/data/bird-forest/images/0008.jpg b/data/bird-forest/images/0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bddf1a543934ccf8d1466d49b2db8efaa3f786b3
Binary files /dev/null and b/data/bird-forest/images/0008.jpg differ
diff --git a/data/bird-forest/reference_images/eagle.jpeg b/data/bird-forest/reference_images/eagle.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..3a92621b4a4a224aa86e715f63273853f196e895
Binary files /dev/null and b/data/bird-forest/reference_images/eagle.jpeg differ
diff --git a/data/car-turn/depth/0000.png b/data/car-turn/depth/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..5011c15c46784eb3434f44b9cd9c071f4004f67c
Binary files /dev/null and b/data/car-turn/depth/0000.png differ
diff --git a/data/car-turn/depth/0006.png b/data/car-turn/depth/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b932c74a970062f166bcc6b8be99144d848d3ed
Binary files /dev/null and b/data/car-turn/depth/0006.png differ
diff --git a/data/car-turn/depth/0012.png b/data/car-turn/depth/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..43a8af417e6e177d2c5d0c8597b7ed5b6d079db4
Binary files /dev/null and b/data/car-turn/depth/0012.png differ
diff --git a/data/car-turn/depth/0018.png b/data/car-turn/depth/0018.png
new file mode 100644
index 0000000000000000000000000000000000000000..d79e2c34f639f1dd8f2052359149977ae4b8cc04
Binary files /dev/null and b/data/car-turn/depth/0018.png differ
diff --git a/data/car-turn/depth/0024.png b/data/car-turn/depth/0024.png
new file mode 100644
index 0000000000000000000000000000000000000000..c1085380491e6ba5e632c80225463644267755e4
Binary files /dev/null and b/data/car-turn/depth/0024.png differ
diff --git a/data/car-turn/depth/0030.png b/data/car-turn/depth/0030.png
new file mode 100644
index 0000000000000000000000000000000000000000..877db048e13c73aedef43f99773ab542d43a3049
Binary files /dev/null and b/data/car-turn/depth/0030.png differ
diff --git a/data/car-turn/depth/0036.png b/data/car-turn/depth/0036.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9728e3fb0e65eb70e1f17a01257458386865ad9
Binary files /dev/null and b/data/car-turn/depth/0036.png differ
diff --git a/data/car-turn/depth/0042.png b/data/car-turn/depth/0042.png
new file mode 100644
index 0000000000000000000000000000000000000000..43238dac40f4e462c4f3c2db9c274376eaf89378
Binary files /dev/null and b/data/car-turn/depth/0042.png differ
diff --git a/data/car-turn/frame_list.txt b/data/car-turn/frame_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..00139c4c2c3e374e5759ba0bbb839730d678c2d4
--- /dev/null
+++ b/data/car-turn/frame_list.txt
@@ -0,0 +1,8 @@
+0000
+0006
+0012
+0018
+0024
+0030
+0036
+0042
diff --git a/data/car-turn/images/0000.jpg b/data/car-turn/images/0000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cd5f0ffaa6faca4d00edd7535abbeaeefe9dea30
Binary files /dev/null and b/data/car-turn/images/0000.jpg differ
diff --git a/data/car-turn/images/0006.jpg b/data/car-turn/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..93b2d0eddecc8a265b25f55e71093e0d63d6a299
Binary files /dev/null and b/data/car-turn/images/0006.jpg differ
diff --git a/data/car-turn/images/0012.jpg b/data/car-turn/images/0012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..243d084909c37aca1cf072fdfc85175ca33c40c1
Binary files /dev/null and b/data/car-turn/images/0012.jpg differ
diff --git a/data/car-turn/images/0018.jpg b/data/car-turn/images/0018.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a57b7cffd8ecc0da27682fdb5394721677b966cc
Binary files /dev/null and b/data/car-turn/images/0018.jpg differ
diff --git a/data/car-turn/images/0024.jpg b/data/car-turn/images/0024.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1bec239adccb4e7c17884a0b5ae45791342dffcc
Binary files /dev/null and b/data/car-turn/images/0024.jpg differ
diff --git a/data/car-turn/images/0030.jpg b/data/car-turn/images/0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..eea477eac836fca93e793e101a4efa5176375383
Binary files /dev/null and b/data/car-turn/images/0030.jpg differ
diff --git a/data/car-turn/images/0036.jpg b/data/car-turn/images/0036.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6fd7af9e891f034920113d4eb00468bbef8e7c50
Binary files /dev/null and b/data/car-turn/images/0036.jpg differ
diff --git a/data/car-turn/images/0042.jpg b/data/car-turn/images/0042.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..24ed69ab5903e77e855132c677cb7bdf795fcb92
Binary files /dev/null and b/data/car-turn/images/0042.jpg differ
diff --git a/data/car-turn/openposefull/0000.png b/data/car-turn/openposefull/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0000.png differ
diff --git a/data/car-turn/openposefull/0006.png b/data/car-turn/openposefull/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0006.png differ
diff --git a/data/car-turn/openposefull/0012.png b/data/car-turn/openposefull/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0012.png differ
diff --git a/data/car-turn/openposefull/0018.png b/data/car-turn/openposefull/0018.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0018.png differ
diff --git a/data/car-turn/openposefull/0024.png b/data/car-turn/openposefull/0024.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0024.png differ
diff --git a/data/car-turn/openposefull/0030.png b/data/car-turn/openposefull/0030.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0030.png differ
diff --git a/data/car-turn/openposefull/0036.png b/data/car-turn/openposefull/0036.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0036.png differ
diff --git a/data/car-turn/openposefull/0042.png b/data/car-turn/openposefull/0042.png
new file mode 100644
index 0000000000000000000000000000000000000000..5220d0f9dbfa747cdb356761a73beb1b1f1c96bb
Binary files /dev/null and b/data/car-turn/openposefull/0042.png differ
diff --git a/data/car-turn/reference_images/audi.jpeg b/data/car-turn/reference_images/audi.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..f3c8f6143315ac71e65d424171b4a413ce85b60f
Binary files /dev/null and b/data/car-turn/reference_images/audi.jpeg differ
diff --git a/data/car-turn/suzuki-jimny.mask/0000.png b/data/car-turn/suzuki-jimny.mask/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..d6cd48a3766de703079376234c980ec19de8a643
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0000.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0006.png b/data/car-turn/suzuki-jimny.mask/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..f05e11a7b4088cf059012388625e2f26ff0366f2
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0006.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0012.png b/data/car-turn/suzuki-jimny.mask/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..715901b2baee46aff35960e50df4c28e368d8e11
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0012.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0018.png b/data/car-turn/suzuki-jimny.mask/0018.png
new file mode 100644
index 0000000000000000000000000000000000000000..18c595ffc39963243ec50f1144524245fe1b4c4a
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0018.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0024.png b/data/car-turn/suzuki-jimny.mask/0024.png
new file mode 100644
index 0000000000000000000000000000000000000000..060d01ba9bd143e3fe3fe165d0e475d7265fc992
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0024.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0030.png b/data/car-turn/suzuki-jimny.mask/0030.png
new file mode 100644
index 0000000000000000000000000000000000000000..07356c991585086c3d64d1ced34b87b36394be09
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0030.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0036.png b/data/car-turn/suzuki-jimny.mask/0036.png
new file mode 100644
index 0000000000000000000000000000000000000000..0bee513c723e62649a8e5b7c97b6241ee440db3a
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0036.png differ
diff --git a/data/car-turn/suzuki-jimny.mask/0042.png b/data/car-turn/suzuki-jimny.mask/0042.png
new file mode 100644
index 0000000000000000000000000000000000000000..932db26efcf1ae967f33608ad76a5e3fcfc6ea88
Binary files /dev/null and b/data/car-turn/suzuki-jimny.mask/0042.png differ
diff --git a/data/drift-turn/images/0000.jpg b/data/drift-turn/images/0000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..daac67ff6e56f708938bacc4cad2cf2971af1be4
Binary files /dev/null and b/data/drift-turn/images/0000.jpg differ
diff --git a/data/drift-turn/images/0001.jpg b/data/drift-turn/images/0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9faf9085269756b073defe7cd3d5db3797ffc67a
Binary files /dev/null and b/data/drift-turn/images/0001.jpg differ
diff --git a/data/drift-turn/images/0002.jpg b/data/drift-turn/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9685d46ade95b0f43b71a96aaaa3513cb1de7d33
Binary files /dev/null and b/data/drift-turn/images/0002.jpg differ
diff --git a/data/drift-turn/images/0003.jpg b/data/drift-turn/images/0003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8a35937bf74f63f5045f34ab3056c5c6c373da6b
Binary files /dev/null and b/data/drift-turn/images/0003.jpg differ
diff --git a/data/drift-turn/images/0004.jpg b/data/drift-turn/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..56403a88dbea43099e0945600d5f641ee0859e50
Binary files /dev/null and b/data/drift-turn/images/0004.jpg differ
diff --git a/data/drift-turn/images/0005.jpg b/data/drift-turn/images/0005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1b46dadfd5097d0b993cea5d05f0d0a9f148b15f
Binary files /dev/null and b/data/drift-turn/images/0005.jpg differ
diff --git a/data/drift-turn/images/0006.jpg b/data/drift-turn/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c3a3d888fb80091f41fc458ab59f951007b3629d
Binary files /dev/null and b/data/drift-turn/images/0006.jpg differ
diff --git a/data/drift-turn/images/0007.jpg b/data/drift-turn/images/0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a185eea36e52e403b48208b350590976e0d2608
Binary files /dev/null and b/data/drift-turn/images/0007.jpg differ
diff --git a/data/drift-turn/reference_images/aston.jpeg b/data/drift-turn/reference_images/aston.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..47ecb1e39175df2530a05eb9156541afdbf122bd
Binary files /dev/null and b/data/drift-turn/reference_images/aston.jpeg differ
diff --git a/data/huaqiang/depth/0000.png b/data/huaqiang/depth/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..cd4169a84a0421152d9e4cc54f6ff2e8e7efa6ba
Binary files /dev/null and b/data/huaqiang/depth/0000.png differ
diff --git a/data/huaqiang/depth/0002.png b/data/huaqiang/depth/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..f178e56d32551288997c3a04f968617628ac536e
Binary files /dev/null and b/data/huaqiang/depth/0002.png differ
diff --git a/data/huaqiang/depth/0004.png b/data/huaqiang/depth/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc9431d56c26656cbf1ae39c758aefd3c006fa4d
Binary files /dev/null and b/data/huaqiang/depth/0004.png differ
diff --git a/data/huaqiang/depth/0006.png b/data/huaqiang/depth/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..e666e53e3fdf54f72a579480e4fb8ba5b2a03905
Binary files /dev/null and b/data/huaqiang/depth/0006.png differ
diff --git a/data/huaqiang/depth/0008.png b/data/huaqiang/depth/0008.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e9c0d10cad347bdef2ccb6cfa8996708a12b697
Binary files /dev/null and b/data/huaqiang/depth/0008.png differ
diff --git a/data/huaqiang/depth/0010.png b/data/huaqiang/depth/0010.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab5cdc90d52b3299860a8169caef031f1b0c3765
Binary files /dev/null and b/data/huaqiang/depth/0010.png differ
diff --git a/data/huaqiang/depth/0012.png b/data/huaqiang/depth/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..c601b084b0ce5aa06a303ff8561825ac653ddfd8
Binary files /dev/null and b/data/huaqiang/depth/0012.png differ
diff --git a/data/huaqiang/depth/0014.png b/data/huaqiang/depth/0014.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ede808278ea1f4b9109f7b03b8c11bc97c45421
Binary files /dev/null and b/data/huaqiang/depth/0014.png differ
diff --git a/data/huaqiang/frame_list.txt b/data/huaqiang/frame_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..351eeb6d0995330821733ba935f4b2d24e2771e0
--- /dev/null
+++ b/data/huaqiang/frame_list.txt
@@ -0,0 +1,8 @@
+0000
+0002
+0004
+0006
+0008
+0010
+0012
+0014
diff --git a/data/huaqiang/images/0000.jpg b/data/huaqiang/images/0000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4f60946bdf75f4c4ccf45f7b0b41c7e6ddc5b701
Binary files /dev/null and b/data/huaqiang/images/0000.jpg differ
diff --git a/data/huaqiang/images/0002.jpg b/data/huaqiang/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a72c5db174b29e783f4b0b5fc2a811310a79190
Binary files /dev/null and b/data/huaqiang/images/0002.jpg differ
diff --git a/data/huaqiang/images/0004.jpg b/data/huaqiang/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4b2f7e4494c1fffeb180167d0de1672c9414bdf2
Binary files /dev/null and b/data/huaqiang/images/0004.jpg differ
diff --git a/data/huaqiang/images/0006.jpg b/data/huaqiang/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6442a5a77f7faec5d2a03b85e10068639cd7d8d6
Binary files /dev/null and b/data/huaqiang/images/0006.jpg differ
diff --git a/data/huaqiang/images/0008.jpg b/data/huaqiang/images/0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cddcd4f7cd70207df5220e31b30191fbd8be6773
Binary files /dev/null and b/data/huaqiang/images/0008.jpg differ
diff --git a/data/huaqiang/images/0010.jpg b/data/huaqiang/images/0010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b30e3276f14ac87900fdcb63d9f5d51c3f24e3f5
Binary files /dev/null and b/data/huaqiang/images/0010.jpg differ
diff --git a/data/huaqiang/images/0012.jpg b/data/huaqiang/images/0012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3c4e04aeaa5819b90173f707a69b41ccfe1bf848
Binary files /dev/null and b/data/huaqiang/images/0012.jpg differ
diff --git a/data/huaqiang/images/0014.jpg b/data/huaqiang/images/0014.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e36bdda2ba0d065226b333035d5583401f03ae67
Binary files /dev/null and b/data/huaqiang/images/0014.jpg differ
diff --git a/data/huaqiang/man.mask/0000.png b/data/huaqiang/man.mask/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..86d285c08b072e9c3bfe3e5f3d58fa1f8d087d97
Binary files /dev/null and b/data/huaqiang/man.mask/0000.png differ
diff --git a/data/huaqiang/man.mask/0002.png b/data/huaqiang/man.mask/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..383c47a1eb4f4a7443d99d8bc9b8ebcc352dce15
Binary files /dev/null and b/data/huaqiang/man.mask/0002.png differ
diff --git a/data/huaqiang/man.mask/0004.png b/data/huaqiang/man.mask/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..9889f13ca2cd91b39348c34cbf33b53fc74159be
Binary files /dev/null and b/data/huaqiang/man.mask/0004.png differ
diff --git a/data/huaqiang/man.mask/0006.png b/data/huaqiang/man.mask/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..95beb655acda6f24e5d5fbf074aa46da1201162a
Binary files /dev/null and b/data/huaqiang/man.mask/0006.png differ
diff --git a/data/huaqiang/man.mask/0008.png b/data/huaqiang/man.mask/0008.png
new file mode 100644
index 0000000000000000000000000000000000000000..f5797b734aae4cce520ad180b651a1ff25a8ed35
Binary files /dev/null and b/data/huaqiang/man.mask/0008.png differ
diff --git a/data/huaqiang/man.mask/0010.png b/data/huaqiang/man.mask/0010.png
new file mode 100644
index 0000000000000000000000000000000000000000..114ee627a37aefebd04b68b0e8e4172dc9c6bcf9
Binary files /dev/null and b/data/huaqiang/man.mask/0010.png differ
diff --git a/data/huaqiang/man.mask/0012.png b/data/huaqiang/man.mask/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed400a225c12c212455ae2e9d8f7025d8010904e
Binary files /dev/null and b/data/huaqiang/man.mask/0012.png differ
diff --git a/data/huaqiang/man.mask/0014.png b/data/huaqiang/man.mask/0014.png
new file mode 100644
index 0000000000000000000000000000000000000000..1aa48ee05fd0f0acfba4b9c3d0cc2253470085f3
Binary files /dev/null and b/data/huaqiang/man.mask/0014.png differ
diff --git a/data/huaqiang/man.mask/musk.png b/data/huaqiang/man.mask/musk.png
new file mode 100644
index 0000000000000000000000000000000000000000..73be61e68d7e164e8cac1b49c810a852925f10ba
Binary files /dev/null and b/data/huaqiang/man.mask/musk.png differ
diff --git a/data/huaqiang/openposefull/0000.png b/data/huaqiang/openposefull/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..17f924d13bac9f5bbd26dc9eff3d4a0db4c1d14a
Binary files /dev/null and b/data/huaqiang/openposefull/0000.png differ
diff --git a/data/huaqiang/openposefull/0002.png b/data/huaqiang/openposefull/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..835d873fcb9fa2f9f5084e080e4d9b9e74c59245
Binary files /dev/null and b/data/huaqiang/openposefull/0002.png differ
diff --git a/data/huaqiang/openposefull/0004.png b/data/huaqiang/openposefull/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..21d9bcffee4d09f71572cc01e0b6e492e2fcc5b1
Binary files /dev/null and b/data/huaqiang/openposefull/0004.png differ
diff --git a/data/huaqiang/openposefull/0006.png b/data/huaqiang/openposefull/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..979f319e0bb235dd3548e6845d5f230e8bdb192d
Binary files /dev/null and b/data/huaqiang/openposefull/0006.png differ
diff --git a/data/huaqiang/openposefull/0008.png b/data/huaqiang/openposefull/0008.png
new file mode 100644
index 0000000000000000000000000000000000000000..f6fdb2ea3d53fc7659cae6fd7e434894b142171e
Binary files /dev/null and b/data/huaqiang/openposefull/0008.png differ
diff --git a/data/huaqiang/openposefull/0010.png b/data/huaqiang/openposefull/0010.png
new file mode 100644
index 0000000000000000000000000000000000000000..eed3bfacc1ce00ceb36de53908e364fb7d6e6950
Binary files /dev/null and b/data/huaqiang/openposefull/0010.png differ
diff --git a/data/huaqiang/openposefull/0012.png b/data/huaqiang/openposefull/0012.png
new file mode 100644
index 0000000000000000000000000000000000000000..acc0dc99717574916bbed67f6c458609dd4d9d84
Binary files /dev/null and b/data/huaqiang/openposefull/0012.png differ
diff --git a/data/huaqiang/openposefull/0014.png b/data/huaqiang/openposefull/0014.png
new file mode 100644
index 0000000000000000000000000000000000000000..77d83f3fa8a7c375a95b0a77e2ab430cf4e1ea57
Binary files /dev/null and b/data/huaqiang/openposefull/0014.png differ
diff --git a/data/huaqiang/reference_images/musk.jpg b/data/huaqiang/reference_images/musk.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfdb917e26ce809a4225807b6d4f028da660ea
Binary files /dev/null and b/data/huaqiang/reference_images/musk.jpg differ
diff --git a/data/ikun/depth/0000.png b/data/ikun/depth/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f44231edea5b53c62147b55b135eb792fc87ddc
Binary files /dev/null and b/data/ikun/depth/0000.png differ
diff --git a/data/ikun/depth/0001.png b/data/ikun/depth/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e0adb41a72cb889e0c56b2356b7c9d04c827319
Binary files /dev/null and b/data/ikun/depth/0001.png differ
diff --git a/data/ikun/depth/0002.png b/data/ikun/depth/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..7930a31b568aaaf88de3ebd27ac7fe23dde6fd86
Binary files /dev/null and b/data/ikun/depth/0002.png differ
diff --git a/data/ikun/depth/0003.png b/data/ikun/depth/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..f58f498698544eefd72cf83355a29ac0e7150c7b
Binary files /dev/null and b/data/ikun/depth/0003.png differ
diff --git a/data/ikun/depth/0004.png b/data/ikun/depth/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..6920472bd7cd52ceeed38dae61c7d8212ee08693
Binary files /dev/null and b/data/ikun/depth/0004.png differ
diff --git a/data/ikun/depth/0005.png b/data/ikun/depth/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..e667935dc6ee00f6f346cf894343cb7e6a9f9b48
Binary files /dev/null and b/data/ikun/depth/0005.png differ
diff --git a/data/ikun/depth/0006.png b/data/ikun/depth/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed6c6422b0eb70082d66cafd90961b2cb04d887b
Binary files /dev/null and b/data/ikun/depth/0006.png differ
diff --git a/data/ikun/depth/0007.png b/data/ikun/depth/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..75775f24faa6fd9fd1723a417a66cc62799f95fb
Binary files /dev/null and b/data/ikun/depth/0007.png differ
diff --git a/data/ikun/frame_list.txt b/data/ikun/frame_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f35e95921ec451add29cb3a6e2192d85be3c66fb
--- /dev/null
+++ b/data/ikun/frame_list.txt
@@ -0,0 +1,8 @@
+0000
+0001
+0002
+0003
+0004
+0005
+0006
+0007
diff --git a/data/ikun/images/0000.jpg b/data/ikun/images/0000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..655d17c2fc30dfc333b8e962a071a01b40ecd9d7
Binary files /dev/null and b/data/ikun/images/0000.jpg differ
diff --git a/data/ikun/images/0001.jpg b/data/ikun/images/0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac9053ada81aedfb6547c95731c15c69b60244db
Binary files /dev/null and b/data/ikun/images/0001.jpg differ
diff --git a/data/ikun/images/0002.jpg b/data/ikun/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b16e02ab179be5d8f8bed6f9a5ff865faea1795b
Binary files /dev/null and b/data/ikun/images/0002.jpg differ
diff --git a/data/ikun/images/0003.jpg b/data/ikun/images/0003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ace48be8bf7937840356531fe4c2c7b5633dee72
Binary files /dev/null and b/data/ikun/images/0003.jpg differ
diff --git a/data/ikun/images/0004.jpg b/data/ikun/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f7cff748b6242441de91fcb9d08706a75e589276
Binary files /dev/null and b/data/ikun/images/0004.jpg differ
diff --git a/data/ikun/images/0005.jpg b/data/ikun/images/0005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a8ea3b9ac19109595e99225066093c89ad02370
Binary files /dev/null and b/data/ikun/images/0005.jpg differ
diff --git a/data/ikun/images/0006.jpg b/data/ikun/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..89272f479f0eea52890178f53b45dd5504b5c349
Binary files /dev/null and b/data/ikun/images/0006.jpg differ
diff --git a/data/ikun/images/0007.jpg b/data/ikun/images/0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..899af807f278f1abe0e6b36582dcf252194dba6f
Binary files /dev/null and b/data/ikun/images/0007.jpg differ
diff --git a/data/ikun/man.mask/0000.png b/data/ikun/man.mask/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..dee21da989aea51cc783966194c44218824cc1db
Binary files /dev/null and b/data/ikun/man.mask/0000.png differ
diff --git a/data/ikun/man.mask/0001.png b/data/ikun/man.mask/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..06bd550a6898426d958e9046e90702db942e7f07
Binary files /dev/null and b/data/ikun/man.mask/0001.png differ
diff --git a/data/ikun/man.mask/0002.png b/data/ikun/man.mask/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e802311191fb3f22bdad624d756c428303f537e
Binary files /dev/null and b/data/ikun/man.mask/0002.png differ
diff --git a/data/ikun/man.mask/0003.png b/data/ikun/man.mask/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e6b51cef0a92c3981b076be344e1e833cb26012
Binary files /dev/null and b/data/ikun/man.mask/0003.png differ
diff --git a/data/ikun/man.mask/0004.png b/data/ikun/man.mask/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a6012d943b7a9cc8f1f13186b9f89bd5aa258c0
Binary files /dev/null and b/data/ikun/man.mask/0004.png differ
diff --git a/data/ikun/man.mask/0005.png b/data/ikun/man.mask/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d1533a42a7ae3999a9115f4a9d9c27ca93be788
Binary files /dev/null and b/data/ikun/man.mask/0005.png differ
diff --git a/data/ikun/man.mask/0006.png b/data/ikun/man.mask/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..49e81bf711a4c834e8f42ba700e4123d372a24fb
Binary files /dev/null and b/data/ikun/man.mask/0006.png differ
diff --git a/data/ikun/man.mask/0007.png b/data/ikun/man.mask/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c90ad3bc860829facb6d3cf331ed9ee73fef778
Binary files /dev/null and b/data/ikun/man.mask/0007.png differ
diff --git a/data/ikun/man.mask/wt.png b/data/ikun/man.mask/wt.png
new file mode 100644
index 0000000000000000000000000000000000000000..17075a1499622d33ea26835a4f5f8d8a9df12b3c
Binary files /dev/null and b/data/ikun/man.mask/wt.png differ
diff --git a/data/ikun/man.mask/zhongli.png b/data/ikun/man.mask/zhongli.png
new file mode 100644
index 0000000000000000000000000000000000000000..70bd01a0a257d0fb374f9673cabf68a4035b0c23
Binary files /dev/null and b/data/ikun/man.mask/zhongli.png differ
diff --git a/data/ikun/man.mask/zhongli2.png b/data/ikun/man.mask/zhongli2.png
new file mode 100644
index 0000000000000000000000000000000000000000..78328e3e400eb1bfaea71a081955141e194b94e8
Binary files /dev/null and b/data/ikun/man.mask/zhongli2.png differ
diff --git a/data/ikun/openposefull/0000.png b/data/ikun/openposefull/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..1fb4e6f8319be7efec08c6ec6c07a99bf3c90912
Binary files /dev/null and b/data/ikun/openposefull/0000.png differ
diff --git a/data/ikun/openposefull/0001.png b/data/ikun/openposefull/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4bced5dcac9928b9a8e2cf2dceed403389f67d9
Binary files /dev/null and b/data/ikun/openposefull/0001.png differ
diff --git a/data/ikun/openposefull/0002.png b/data/ikun/openposefull/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b23f4d88e6ff3e6e25d88eb289fcd4a51d8bb5a
Binary files /dev/null and b/data/ikun/openposefull/0002.png differ
diff --git a/data/ikun/openposefull/0003.png b/data/ikun/openposefull/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..478ac6fee30b89816ae7d75dfb725067c93835d6
Binary files /dev/null and b/data/ikun/openposefull/0003.png differ
diff --git a/data/ikun/openposefull/0004.png b/data/ikun/openposefull/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..2d2a315ac949e7bc437215fd9591c36fb631d36a
Binary files /dev/null and b/data/ikun/openposefull/0004.png differ
diff --git a/data/ikun/openposefull/0005.png b/data/ikun/openposefull/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..4646a21c9480c3a771558758588b89c3b830eebe
Binary files /dev/null and b/data/ikun/openposefull/0005.png differ
diff --git a/data/ikun/openposefull/0006.png b/data/ikun/openposefull/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbff9218e8200903fd681c45054f0d9cdbbaf2fd
Binary files /dev/null and b/data/ikun/openposefull/0006.png differ
diff --git a/data/ikun/openposefull/0007.png b/data/ikun/openposefull/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..9929f05fe051f9d038daf5d32657881d429437c1
Binary files /dev/null and b/data/ikun/openposefull/0007.png differ
diff --git a/data/ikun/reference_images/wt.jpg b/data/ikun/reference_images/wt.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5ae22881173ce2a3d62536c0f0b1dfe75e59347f
--- /dev/null
+++ b/data/ikun/reference_images/wt.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b57c85a5b37590e80fff6d0db5627fffcd725f38411832ad78000970d96039c7
+size 2494721
diff --git a/data/ikun/reference_images/zhongli.jpg b/data/ikun/reference_images/zhongli.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bf878118027d228c981bb02259308031fe665521
Binary files /dev/null and b/data/ikun/reference_images/zhongli.jpg differ
diff --git a/data/motorbike/images/0001.jpg b/data/motorbike/images/0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1aca4f48a72016c0b0c3f2b9274e8528736ad275
Binary files /dev/null and b/data/motorbike/images/0001.jpg differ
diff --git a/data/motorbike/images/0002.jpg b/data/motorbike/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..18cf00175c8cc465318c385e49e5b3f5ce62888f
Binary files /dev/null and b/data/motorbike/images/0002.jpg differ
diff --git a/data/motorbike/images/0003.jpg b/data/motorbike/images/0003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87f6dd0fecb3e61e5aab2a02048c8382c5935bec
Binary files /dev/null and b/data/motorbike/images/0003.jpg differ
diff --git a/data/motorbike/images/0004.jpg b/data/motorbike/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9f39ef5dd9ef06808f50e0fe352656e3df653a9d
Binary files /dev/null and b/data/motorbike/images/0004.jpg differ
diff --git a/data/motorbike/images/0005.jpg b/data/motorbike/images/0005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..981aedbb308ff5ba94acaecfcc67dd6e47e6bc09
Binary files /dev/null and b/data/motorbike/images/0005.jpg differ
diff --git a/data/motorbike/images/0006.jpg b/data/motorbike/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..308d413cc96d52eb9959c2bb7064786d6a86eb47
Binary files /dev/null and b/data/motorbike/images/0006.jpg differ
diff --git a/data/motorbike/images/0007.jpg b/data/motorbike/images/0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19b4e3cfd59336eb21ec2b8898bd58b2a52506ca
Binary files /dev/null and b/data/motorbike/images/0007.jpg differ
diff --git a/data/motorbike/images/0008.jpg b/data/motorbike/images/0008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9178766c75312f1a0020e6b82d178a92d13c6ec9
Binary files /dev/null and b/data/motorbike/images/0008.jpg differ
diff --git a/data/motorbike/reference_images/pink-motor.png b/data/motorbike/reference_images/pink-motor.png
new file mode 100644
index 0000000000000000000000000000000000000000..7faa6415c537585a3347ca5417e3bfbebd0e151f
--- /dev/null
+++ b/data/motorbike/reference_images/pink-motor.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84d0bcfbfe17f7c2e1199583d1a03b79854e575e12ba3ec7e36e84744c339bca
+size 1313508
diff --git a/data/motorbike/reference_images/wanye.jpeg b/data/motorbike/reference_images/wanye.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..ad2d0029eee4297bf928aeb42418db5269128137
Binary files /dev/null and b/data/motorbike/reference_images/wanye.jpeg differ
diff --git a/data/yanzi/depth/0000.png b/data/yanzi/depth/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..61169e65f04c5a449b56a3c93ea0d91fa902c0fd
Binary files /dev/null and b/data/yanzi/depth/0000.png differ
diff --git a/data/yanzi/depth/0001.png b/data/yanzi/depth/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..cde58c6a951bec18a697d1834e9306a1166480f3
Binary files /dev/null and b/data/yanzi/depth/0001.png differ
diff --git a/data/yanzi/depth/0002.png b/data/yanzi/depth/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..65af5cc7a670ac7e3bc58b013214f821a2ce4860
Binary files /dev/null and b/data/yanzi/depth/0002.png differ
diff --git a/data/yanzi/depth/0003.png b/data/yanzi/depth/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..628ad9d7f2a8876c662c61ef23fe60e107722c19
Binary files /dev/null and b/data/yanzi/depth/0003.png differ
diff --git a/data/yanzi/depth/0004.png b/data/yanzi/depth/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..5d9728ff20210d1bf4dccd1a1fa0d4ed25b83aeb
Binary files /dev/null and b/data/yanzi/depth/0004.png differ
diff --git a/data/yanzi/depth/0005.png b/data/yanzi/depth/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..700b1f43f2e022b34f5c062008c76b6b99ee3887
Binary files /dev/null and b/data/yanzi/depth/0005.png differ
diff --git a/data/yanzi/depth/0006.png b/data/yanzi/depth/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..f97e78a0fdd3a37cd331134f1e303311eaec5da8
Binary files /dev/null and b/data/yanzi/depth/0006.png differ
diff --git a/data/yanzi/depth/0007.png b/data/yanzi/depth/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..a308de838ad81688b98d49cf6c1dcb7fded624a4
Binary files /dev/null and b/data/yanzi/depth/0007.png differ
diff --git a/data/yanzi/frame_list.txt b/data/yanzi/frame_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f35e95921ec451add29cb3a6e2192d85be3c66fb
--- /dev/null
+++ b/data/yanzi/frame_list.txt
@@ -0,0 +1,8 @@
+0000
+0001
+0002
+0003
+0004
+0005
+0006
+0007
diff --git a/data/yanzi/images/0000.jpg b/data/yanzi/images/0000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..acd4c20dab6e593aae82997204f498331268efc9
Binary files /dev/null and b/data/yanzi/images/0000.jpg differ
diff --git a/data/yanzi/images/0001.jpg b/data/yanzi/images/0001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0a7de986f425d77473eee5358adac43d73ea5109
Binary files /dev/null and b/data/yanzi/images/0001.jpg differ
diff --git a/data/yanzi/images/0002.jpg b/data/yanzi/images/0002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2a813c0beadee9d5ca1f0eb839b9771279d9dec0
Binary files /dev/null and b/data/yanzi/images/0002.jpg differ
diff --git a/data/yanzi/images/0003.jpg b/data/yanzi/images/0003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0d06988c68af9f13c968c09c3c767d0aa69b7048
Binary files /dev/null and b/data/yanzi/images/0003.jpg differ
diff --git a/data/yanzi/images/0004.jpg b/data/yanzi/images/0004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3da3e7d94b4f790edaaaf7d3b9169d18b317c5d0
Binary files /dev/null and b/data/yanzi/images/0004.jpg differ
diff --git a/data/yanzi/images/0005.jpg b/data/yanzi/images/0005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fff79a05318e8daacf8382cd61f45bbc3b000eae
Binary files /dev/null and b/data/yanzi/images/0005.jpg differ
diff --git a/data/yanzi/images/0006.jpg b/data/yanzi/images/0006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f6ddad051158fc4bfe8e79b962260f859148356
Binary files /dev/null and b/data/yanzi/images/0006.jpg differ
diff --git a/data/yanzi/images/0007.jpg b/data/yanzi/images/0007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4244c263b93aad38a3bd07a370503cd9541e850f
Binary files /dev/null and b/data/yanzi/images/0007.jpg differ
diff --git a/data/yanzi/man.mask/0000.png b/data/yanzi/man.mask/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..c221f6c845e4c06c632857a44723237b2f61a64b
Binary files /dev/null and b/data/yanzi/man.mask/0000.png differ
diff --git a/data/yanzi/man.mask/0001.png b/data/yanzi/man.mask/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..25fc370eb0c26da4535ce8fd082bd4275ff9322c
Binary files /dev/null and b/data/yanzi/man.mask/0001.png differ
diff --git a/data/yanzi/man.mask/0002.png b/data/yanzi/man.mask/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7d4807404ae5badeee27710b05da16bfab43685
Binary files /dev/null and b/data/yanzi/man.mask/0002.png differ
diff --git a/data/yanzi/man.mask/0003.png b/data/yanzi/man.mask/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..35883055a5c9b7971bd216349a52a01c4f1e5736
Binary files /dev/null and b/data/yanzi/man.mask/0003.png differ
diff --git a/data/yanzi/man.mask/0004.png b/data/yanzi/man.mask/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ad2c8a5fe5eee4047120a4f9470d66c10bc5b98
Binary files /dev/null and b/data/yanzi/man.mask/0004.png differ
diff --git a/data/yanzi/man.mask/0005.png b/data/yanzi/man.mask/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b6da03ab71e57e2c5fcfa29372b3d4bf7a5f1d3
Binary files /dev/null and b/data/yanzi/man.mask/0005.png differ
diff --git a/data/yanzi/man.mask/0006.png b/data/yanzi/man.mask/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c71ef6329ff97731076de419fd13baf895e7159
Binary files /dev/null and b/data/yanzi/man.mask/0006.png differ
diff --git a/data/yanzi/man.mask/0007.png b/data/yanzi/man.mask/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..c83437943495cb2d6f1a9ba66c22760e206109fa
Binary files /dev/null and b/data/yanzi/man.mask/0007.png differ
diff --git a/data/yanzi/openposefull/0000.png b/data/yanzi/openposefull/0000.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d92e605514a5fde5280e50d80215c99191bf0ed
Binary files /dev/null and b/data/yanzi/openposefull/0000.png differ
diff --git a/data/yanzi/openposefull/0001.png b/data/yanzi/openposefull/0001.png
new file mode 100644
index 0000000000000000000000000000000000000000..815e5aad1161293163f2d6202e4f5d3ad41ac415
Binary files /dev/null and b/data/yanzi/openposefull/0001.png differ
diff --git a/data/yanzi/openposefull/0002.png b/data/yanzi/openposefull/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bca62153c4f324cb361c04460200cc67487246b
Binary files /dev/null and b/data/yanzi/openposefull/0002.png differ
diff --git a/data/yanzi/openposefull/0003.png b/data/yanzi/openposefull/0003.png
new file mode 100644
index 0000000000000000000000000000000000000000..94cd75a48a399beb91402935b742d4ecbeb984ca
Binary files /dev/null and b/data/yanzi/openposefull/0003.png differ
diff --git a/data/yanzi/openposefull/0004.png b/data/yanzi/openposefull/0004.png
new file mode 100644
index 0000000000000000000000000000000000000000..de4471383cba030d8fb51abc9e94163df683074d
Binary files /dev/null and b/data/yanzi/openposefull/0004.png differ
diff --git a/data/yanzi/openposefull/0005.png b/data/yanzi/openposefull/0005.png
new file mode 100644
index 0000000000000000000000000000000000000000..47ef851977682d37a369e025d93f24633d997ddc
Binary files /dev/null and b/data/yanzi/openposefull/0005.png differ
diff --git a/data/yanzi/openposefull/0006.png b/data/yanzi/openposefull/0006.png
new file mode 100644
index 0000000000000000000000000000000000000000..c4db582bf96759e05c2fd121022a0fdbdcfe7f4c
Binary files /dev/null and b/data/yanzi/openposefull/0006.png differ
diff --git a/data/yanzi/openposefull/0007.png b/data/yanzi/openposefull/0007.png
new file mode 100644
index 0000000000000000000000000000000000000000..4df6779d1aa5ae49dbf0666f0c472c65ced81dc2
Binary files /dev/null and b/data/yanzi/openposefull/0007.png differ
diff --git a/data/yanzi/reference_images/panda.jpeg b/data/yanzi/reference_images/panda.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..9c4fb1fb7c7047a00d29277d977b5c4cb6cadd39
Binary files /dev/null and b/data/yanzi/reference_images/panda.jpeg differ
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..92e1474e7db577fdbbd76a6b3670acf6fa2cf3ef
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+import gc
+import pathlib
+import sys
+import tempfile
+import os
+import gradio as gr
+import imageio
+import PIL.Image
+import torch
+from diffusers.utils.import_utils import is_xformers_available
+from einops import rearrange
+from huggingface_hub import ModelCard
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
+from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, PNDMScheduler, ControlNetModel, PriorTransformer, UnCLIPScheduler
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from omegaconf import OmegaConf
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+
+sys.path.append('Make-A-Protagonist')
+
+from makeaprotagonist.models.unet import UNet3DConditionModel
+from makeaprotagonist.pipelines.pipeline_stable_unclip_controlavideo import MakeAProtagonistStableUnCLIPPipeline, MultiControlNetModel
+from makeaprotagonist.dataset.dataset import MakeAProtagonistDataset
+from makeaprotagonist.util import save_videos_grid, ddim_inversion_unclip, ddim_inversion_prior
+from experts.grounded_sam_mask_out import mask_out_reference_image
+
+
+import ipdb
+
+class InferencePipeline:
+ def __init__(self, hf_token: str | None = None):
+ self.hf_token = hf_token
+ self.pipe = None
+ self.device = torch.device(
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
+ self.model_id = None
+
+ self.conditions = None
+ self.masks = None
+ self.ddim_inv_latent = None
+ self.train_dataset, self.sample_indices = None, None
+
+ def clear(self) -> None:
+ self.model_id = None
+ del self.pipe
+ self.pipe = None
+ self.conditions = None
+ self.masks = None
+ self.ddim_inv_latent = None
+ self.train_dataset, self.sample_indices = None, None
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ @staticmethod
+ def check_if_model_is_local(model_id: str) -> bool:
+ return pathlib.Path(model_id).exists()
+
+ @staticmethod
+ def get_model_card(model_id: str,
+ hf_token: str | None = None) -> ModelCard:
+ if InferencePipeline.check_if_model_is_local(model_id):
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
+ else:
+ card_path = model_id
+ return ModelCard.load(card_path, token=hf_token)
+
+ @staticmethod
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
+ card = InferencePipeline.get_model_card(model_id, hf_token)
+ return card.data.base_model
+
+ @torch.no_grad()
+ def load_pipe(self, model_id: str, n_steps, seed) -> None:
+ if model_id == self.model_id:
+ return self.conditions, self.masks, self.ddim_inv_latent, self.train_dataset, self.sample_indices
+
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
+
+ pretrained_model_path = 'stabilityai/stable-diffusion-2-1-unclip-small'
+ # image encoding components
+ feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
+ # image noising components
+ image_normalizer = StableUnCLIPImageNormalizer.from_pretrained(pretrained_model_path, subfolder="image_normalizer", torch_dtype=torch.float16,)
+ image_noising_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="image_noising_scheduler")
+ # regular denoising components
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16,)
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=torch.float16,)
+ self.ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
+ self.ddim_inv_scheduler.set_timesteps(n_steps)
+
+ prior_model_id = "kakaobrain/karlo-v1-alpha"
+ data_type = torch.float16
+ prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type)
+
+ prior_text_model_id = "openai/clip-vit-large-patch14"
+ prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)
+ prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)
+ prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler")
+ prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+
+ controlnet_model_id = ['controlnet-2-1-unclip-small-openposefull', 'controlnet-2-1-unclip-small-depth']
+ controlnet = MultiControlNetModel( [ControlNetModel.from_pretrained('Make-A-Protagonist/controlnet-2-1-unclip-small', subfolder=subfolder_id, torch_dtype=torch.float16) for subfolder_id in controlnet_model_id] )
+
+ unet = UNet3DConditionModel.from_pretrained(
+ model_id,
+ subfolder='unet',
+ torch_dtype=torch.float16,
+ use_auth_token=self.hf_token)
+
+ # Freeze vae and text_encoder and adapter
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ ## freeze image embed
+ image_encoder.requires_grad_(False)
+
+ unet.requires_grad_(False)
+ ## freeze controlnet
+ controlnet.requires_grad_(False)
+
+ ## freeze prior
+ prior.requires_grad_(False)
+ prior_text_model.requires_grad_(False)
+
+ config_file = os.path.join('Make-A-Protagonist/configs', model_id.split('/')[-1] + '.yaml')
+ self.cfg = OmegaConf.load(config_file)
+
+ # def source_parsing(self, n_steps):
+ # ipdb.set_trace()
+ train_dataset = MakeAProtagonistDataset(**self.cfg)
+ train_dataset.preprocess_img_embedding(feature_extractor, image_encoder)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=1, num_workers=0,
+ )
+ image_encoder.to(dtype=data_type)
+ pipe = MakeAProtagonistStableUnCLIPPipeline(
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ image_normalizer=image_normalizer,
+ image_noising_scheduler=image_noising_scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ )
+
+ pipe = pipe.to(self.device)
+
+ if is_xformers_available():
+ pipe.unet.enable_xformers_memory_efficient_attention()
+ pipe.controlnet.enable_xformers_memory_efficient_attention()
+ self.pipe = pipe
+ self.model_id = model_id # type: ignore
+ self.vae = vae
+ # self.feature_extractor = feature_extractor
+ # self.image_encoder = image_encoder
+ ## ddim inverse for source video
+
+ batch = next(iter(train_dataloader))
+ weight_dtype = torch.float16
+ pixel_values = batch["pixel_values"].to(weight_dtype).to(self.device)
+ video_length = pixel_values.shape[1]
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
+ latents = self.vae.encode(pixel_values).latent_dist.sample()
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
+ latents = latents * self.vae.config.scaling_factor
+ # ControlNet
+ # ipdb.set_trace()
+ conditions = [_condition.to(weight_dtype).to(self.device) for _, _condition in batch["conditions"].items()] # b f c h w
+ masks = batch["masks"].to(weight_dtype).to(self.device) # b,f,1,h,w
+ emb_dim = train_dataset.img_embeddings[0].size(0)
+ key_frame_embed = torch.zeros((1, emb_dim)).to(device=latents.device, dtype=latents.dtype) ## this is dim 0
+ # ipdb.set_trace()
+ ddim_inv_latent = ddim_inversion_unclip(
+ self.pipe, self.ddim_inv_scheduler, video_latent=latents,
+ num_inv_steps=n_steps, prompt="", image_embed=key_frame_embed, noise_level=0, seed=seed)[-1].to(weight_dtype)
+ self.conditions = conditions
+ self.masks = masks
+ self.ddim_inv_latent = ddim_inv_latent
+ self.train_dataset = train_dataset
+ self.sample_indices = batch["sample_indices"][0]
+ return conditions, masks, ddim_inv_latent, train_dataset, batch["sample_indices"][0]
+
+ def run(
+ self,
+ model_id: str,
+ prompt: str,
+ video_length: int,
+ fps: int,
+ seed: int,
+ n_steps: int,
+ guidance_scale: float,
+ ref_image: PIL.Image.Image,
+ ref_pro_prompt: str,
+ noise_level: int,
+ control_pose: float,
+ control_depth: float,
+ source_pro: int = 0, # 0 or 1
+ source_bg: int = 0,
+ ) -> PIL.Image.Image:
+
+ if not torch.cuda.is_available():
+ raise gr.Error('CUDA is not available.')
+
+ torch.cuda.empty_cache()
+
+ conditions, masks, ddim_inv_latent, _, _ = self.load_pipe(model_id, n_steps, seed)
+ ## conditions [1,F,3,H,W]
+ ## masks [1,F,1,H,W]
+ ## ddim_inv_latent [1,4,F,H,W]
+ ## NOTE this is to deal with video length
+ conditions = [_condition[:,:video_length] for _condition in conditions]
+ masks = masks[:, :video_length]
+ ddim_inv_latent = ddim_inv_latent[:,:,:video_length]
+
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ ## TODO mask out reference image
+ # ipdb.set_trace()
+ ref_image = mask_out_reference_image(ref_image, ref_pro_prompt)
+ controlnet_conditioning_scale = [control_pose, control_depth]
+
+ prior_denoised_embeds = None
+ image_embed = None
+ if source_bg:
+ ## using source background and changing the protagonist
+ prior_denoised_embeds = self.train_dataset.img_embeddings[0][None].to(device=ddim_inv_latent.device, dtype=ddim_inv_latent.dtype) # 1, 768 for UnCLIP-small
+
+ if source_pro:
+ # using source protagonist and changing the background
+ sample_indices = self.sample_indices
+ image_embed = [self.train_dataset.img_embeddings[idx] for idx in sample_indices]
+ image_embed = torch.stack(image_embed, dim=0).to(device=ddim_inv_latent.device, dtype=ddim_inv_latent.dtype) # F, 768 for UnCLIP-small # F,C
+ ref_image = None
+
+ # ipdb.set_trace()
+ out = self.pipe(
+ image=ref_image,
+ prompt=prompt,
+ control_image=conditions,
+ video_length=video_length,
+ width=768,
+ height=768,
+ num_inference_steps=n_steps,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ ## ddim inversion
+ latents=ddim_inv_latent,
+ ## ref image embeds
+ noise_level=noise_level,
+ ## controlnet
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ ## mask
+ masks=masks,
+ mask_mode='all',
+ mask_latent_fuse_mode = 'all',
+ ## edit bg and pro
+ prior_latents=None,
+ image_embeds=image_embed, # keep pro
+ prior_denoised_embeds=prior_denoised_embeds # keep bg
+ )
+
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
+ frames = (frames * 255).to(torch.uint8).numpy()
+
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+ writer = imageio.get_writer(out_file.name, fps=fps)
+ for frame in frames:
+ writer.append_data(frame)
+ writer.close()
+
+ return out_file.name
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eca72a38c224dd251be52cef8ac9e1adb8186d95
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+accelerate==0.18.0
+bitsandbytes==0.35.4
+decord==0.6.0
+diffusers[torch]==0.15.0
+einops==0.6.0
+ftfy==6.1.1
+gradio==3.18.0
+huggingface-hub==0.12.0
+imageio==2.25.0
+imageio-ffmpeg==0.4.8
+omegaconf==2.3.0
+Pillow==9.4.0
+python-slugify==7.0.0
+tensorboard==2.11.2
+torch==1.13.1
+torchvision==0.14.1
+transformers==4.27.4
+triton==2.0.0.post1
+xformers==0.0.17
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..c4739b4ea5fc35e774a049e3dacc443f7f0eac19
--- /dev/null
+++ b/style.css
@@ -0,0 +1,3 @@
+h1 {
+ text-align: center;
+}