Spaces:
Paused
Paused
cathyxl
commited on
Commit
·
f239efc
1
Parent(s):
8fb958b
added
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- .gitignore +66 -0
- DATA.md +124 -0
- README.md +376 -5
- app.py +18 -0
- assert/data.png +3 -0
- assert/logo.png +3 -0
- assert/module.png +3 -0
- assert/performance.png +3 -0
- assert/teaser.jpg +3 -0
- assert/zeroshot.png +3 -0
- dataset/__init__.py +158 -0
- dataset/base_dataset.py +108 -0
- dataset/it_dataset.py +206 -0
- dataset/utils.py +41 -0
- dataset/video_utils.py +214 -0
- docs/PoolLLaVA_Report.pdf +3 -0
- example/1917.mp4 +3 -0
- example/bear.jpg +3 -0
- example/cooking.mp4 +3 -0
- example/dog.png +3 -0
- example/jesse_dance.mp4 +3 -0
- example/working.mp4 +3 -0
- example/yoga.mp4 +3 -0
- models/__init__.py +0 -0
- models/pllava/__init__.py +55 -0
- models/pllava/configuration_pllava.py +149 -0
- models/pllava/convert_pllava_weights_to_hf.py +1 -0
- models/pllava/modeling_pllava.py +626 -0
- models/pllava/processing_pllava.py +292 -0
- python_scripts/hf.py +80 -0
- requirements.no_torch.txt +244 -0
- requirements.torch.txt +4 -0
- requirements.txt +246 -0
- scripts/accel_config_deepspeed_zero2.yaml +21 -0
- scripts/accel_config_deepspeed_zero3_offload.yaml +22 -0
- scripts/accel_config_deepspeed_zero3_offload_multinode.yaml +25 -0
- scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml +25 -0
- scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml +25 -0
- scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml +23 -0
- scripts/accel_config_multigpu.yaml +16 -0
- scripts/accel_config_multinode.yaml +18 -0
- scripts/accel_config_singlegpu.yaml +16 -0
- scripts/demo.sh +32 -0
- scripts/eval.sh +104 -0
- scripts/eval_yiprompt.sh +53 -0
- scripts/gallery.sh +11 -0
- scripts/train_pllava.sh +34 -0
- scripts/train_pllava_13b.sh +50 -0
- scripts/train_pllava_34b.sh +50 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.mov filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# local #
|
2 |
+
tmp*/
|
3 |
+
cache/*
|
4 |
+
*/cache*/
|
5 |
+
tmp*.py
|
6 |
+
tmp*
|
7 |
+
*pickle
|
8 |
+
data/
|
9 |
+
|
10 |
+
# Zip Files/Packages #
|
11 |
+
*.7z
|
12 |
+
*.dmg
|
13 |
+
*.gz
|
14 |
+
*.iso
|
15 |
+
*.jar
|
16 |
+
*.rar
|
17 |
+
*.tar
|
18 |
+
*.zip
|
19 |
+
|
20 |
+
# Logs and databases #
|
21 |
+
*.log
|
22 |
+
*.sql
|
23 |
+
*.sqlite
|
24 |
+
.ipynb_checkpoints/
|
25 |
+
*.swp
|
26 |
+
*.vscode/
|
27 |
+
*.idea/
|
28 |
+
*.pyc
|
29 |
+
__pycache__
|
30 |
+
slurm*out
|
31 |
+
|
32 |
+
# OS files #
|
33 |
+
.DS_Store
|
34 |
+
.DS_Store?
|
35 |
+
._*
|
36 |
+
.Spotlight-V100
|
37 |
+
.Trashes
|
38 |
+
ehthumbs.db
|
39 |
+
Thumbs.db
|
40 |
+
|
41 |
+
|
42 |
+
.vim-arsync
|
43 |
+
scratch.norg
|
44 |
+
sync_to_red.sh
|
45 |
+
|
46 |
+
anno/
|
47 |
+
wandb/
|
48 |
+
logs/
|
49 |
+
accelerate_config/
|
50 |
+
*.pth
|
51 |
+
hf_*
|
52 |
+
|
53 |
+
# local folders
|
54 |
+
MODELS
|
55 |
+
DATAS
|
56 |
+
SAVED
|
57 |
+
EXPERIMENTS
|
58 |
+
REMOTE_HF
|
59 |
+
TEST
|
60 |
+
|
61 |
+
test_results
|
62 |
+
test_training
|
63 |
+
test_hdfs.py
|
64 |
+
magic_video_outputs/llava*
|
65 |
+
magic_video_outputs
|
66 |
+
pllava_video_outputs/
|
DATA.md
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Data
|
2 |
+
## Instruction Training Data
|
3 |
+
<!-- > *originated from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)* -->
|
4 |
+
|
5 |
+
|
6 |
+
For training, we leveraged the video instruction tuning data from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2).
|
7 |
+
|
8 |
+
#### 1. Download json annotation files from huggingface.
|
9 |
+
[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VideoChat2%20IT-blue)](https://huggingface.co/datasets/OpenGVLab/VideoChat2-IT)
|
10 |
+
|
11 |
+
<!-- > ![images](./assert/data.png) -->
|
12 |
+
|
13 |
+
#### 2. Download the raw videos from the following links.
|
14 |
+
The video directories can be found in tasks/train/instruction_data.py. You can also change them to your own saved paths.
|
15 |
+
|
16 |
+
- [VideoChat](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data): Based on [InternVid](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid), download the processed version directly [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/videochat2_conversation_videos.zip)
|
17 |
+
- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data)
|
18 |
+
- [Kinetics-710](https://github.com/OpenGVLab/UniFormerV2/blob/main/DATASET.md), download Kinetics 400/600/700 [here](https://openxlab.org.cn/datasets?keywords=kinetics).
|
19 |
+
- [SthSthV2](https://developer.qualcomm.com/software/ai-datasets/something-something): Option candidates were generated from [UMT](https://github.com/OpenGVLab/unmasked_teacher) top-20 predictions.
|
20 |
+
- [NExTQA](https://github.com/doc-doc/NExT-QA)
|
21 |
+
- [CLEVRER](https://clevrer.csail.mit.edu/)
|
22 |
+
- [WebVid](https://maxbain.com/webvid-dataset/)
|
23 |
+
- [YouCook2](https://youcook2.eecs.umich.edu/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/youcook_split_videos.zip).
|
24 |
+
- [TextVR](https://github.com/callsys/textvr)
|
25 |
+
- [TGIF](https://github.com/YunseokJANG/tgif-qa)
|
26 |
+
- [EgoQA](https://ego4d-data.org/), download the processed version [here](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/videochat2/data/egoqa_split_videos.zip).
|
27 |
+
|
28 |
+
#### 3. We also provide our processed json annotation files here.
|
29 |
+
|
30 |
+
[![Dataset meta](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-magic%5Fjsons-blue)](https://huggingface.co/datasets/cathyxl/magic_jsons)
|
31 |
+
|
32 |
+
|
33 |
+
<!-- We leveraged the training data from [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). We only used the video part for video instruct tuning. -->
|
34 |
+
|
35 |
+
## Evaluation Data & Others
|
36 |
+
Follow this section to obtain the evaluation open resources.
|
37 |
+
|
38 |
+
### VCGBench
|
39 |
+
|
40 |
+
We refer to the VideoChatGPT video question answering evaluation as VCGBench in this repo. We followed the original [repo](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main) to prepare the evaluation data.
|
41 |
+
|
42 |
+
### MVBench
|
43 |
+
We follow the original [Videochat2 repo](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2) in setting up the MVBench Evaluation. You can also find helpful resources at their [huggingface repo](https://huggingface.co/datasets/OpenGVLab/MVBench)
|
44 |
+
|
45 |
+
|
46 |
+
### Videoqabench
|
47 |
+
We refer to all other video question answering benchmarks as videoqabench in this repo. They are mainly prepared folloing the original repos. Each listed:
|
48 |
+
1. [MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) & [MSRVTT](https://github.com/xudejing/video-question-answering)
|
49 |
+
|
50 |
+
3. [Activity Net](https://github.com/MILVLG/activitynet-qa/tree/master)
|
51 |
+
4. [TGIF](https://github.com/raingo/TGIF-Release/tree/master)
|
52 |
+
|
53 |
+
Also other fantastic repo intergrating these benchmarks are helpful in the process of setting up the evaluation data:
|
54 |
+
- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main)
|
55 |
+
- [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava)
|
56 |
+
- [IG-VLM](https://github.com/imagegridworth/IG-VLM/tree/main)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
### Recaptioning
|
61 |
+
#### Inter4k
|
62 |
+
|
63 |
+
This is a dataset with 1000 samples of high resolution videos. We prepare the data folloing the instructions from their [official website](https://alexandrosstergiou.github.io/datasets/Inter4K/index.html)
|
64 |
+
|
65 |
+
#### Extending Reacptioning
|
66 |
+
The recaptioning part is designed to be extendable.
|
67 |
+
|
68 |
+
inference script [tasks/eval/recaption/pllava_recaption.py](tasks/eval/recaption/pllava_recaption.py) would use a dataset class [RecaptionDataset](tasks/eval/recaption/__init__.py#L197). The detailed information is kept in the data_list_info attribute as:
|
69 |
+
```
|
70 |
+
data_list_info = OrderedDict({
|
71 |
+
# "Panda70M": OrderedDict(
|
72 |
+
# json_relpath="Panda70M/annotations.json",
|
73 |
+
# prefix="DATAS/Recaption/Panda70M/videos",
|
74 |
+
# data_type="video",
|
75 |
+
# bound=False,
|
76 |
+
# key_rename_map={
|
77 |
+
# # 'caption': 'hint',
|
78 |
+
# },
|
79 |
+
# name_key='video_name',
|
80 |
+
# postfix=('mp4', 'mkv', 'webm'),
|
81 |
+
# recaption_type=RecaptionSample,
|
82 |
+
# ), # don't has start & end
|
83 |
+
"Inter4K": OrderedDict(
|
84 |
+
json_relpath="Inter4K/annotations.json",
|
85 |
+
prefix="DATAS/Recaption/Inter4K/60fps/UHD",
|
86 |
+
data_type="video",
|
87 |
+
bound=False,
|
88 |
+
key_rename_map={
|
89 |
+
# 'caption': 'hint',
|
90 |
+
},
|
91 |
+
name_key='video_name',
|
92 |
+
postfix=('mp4', 'mkv', 'webm'),
|
93 |
+
recaption_type=CaptionSample,
|
94 |
+
), # don't has start & end
|
95 |
+
})
|
96 |
+
```
|
97 |
+
It contains the path to a annotation json file where there is a list and each item of the list is a sample waiting for captioning. For example, the Inter4K/annotations.json is like:
|
98 |
+
```json
|
99 |
+
[
|
100 |
+
{
|
101 |
+
"video_name": "973"
|
102 |
+
},
|
103 |
+
...
|
104 |
+
]
|
105 |
+
```
|
106 |
+
and the directory DATAS/Recaption/Inter4K/60fps/UHD would look like:
|
107 |
+
```
|
108 |
+
$ ls DATAS/Recaption/Inter4K/60fps/UHD
|
109 |
+
1.mp4 134.mp4 170.mp4 ....
|
110 |
+
```
|
111 |
+
|
112 |
+
Naively, only the video is needed when captioning directly, therefore the annotation file only needs to contain the names of each video under the "prefix" directory.
|
113 |
+
|
114 |
+
Extending a dataset for captioning would consist of the folloing steps:
|
115 |
+
1. have all the videos downloaded
|
116 |
+
2. construct a annotation.json file with sepecific format.
|
117 |
+
3. configure the recaption dataset [here](tasks/eval/recaption/__init__.py#L197), where you would need to determine:
|
118 |
+
- json_relpath: the annotation relative path
|
119 |
+
- prefix: root directory for videos
|
120 |
+
- postfix: a list containing all the file extensions for these videos
|
121 |
+
|
122 |
+
The other options are experimental, so stick with the default setting as in Inter4k. The recommended length of video is around 5-20 seconds.
|
123 |
+
|
124 |
+
p.s. "bound" is to make sure the video pass to the model doesn't have scene transition or so. This part wasn't tested, so set the bound to false and make sure the original videos files are single clip of a video. But always feel free to discover and contribute to PLLaVA!
|
README.md
CHANGED
@@ -1,12 +1,383 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Plava 7b Demo
|
3 |
+
emoji: 👁
|
4 |
colorFrom: blue
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.27.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
<div align="center">
|
13 |
+
|
14 |
+
<h2><a href="https://pllava.github.io/">PLLaVA: Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning</a></h2>
|
15 |
+
|
16 |
+
[Lin Xu](https://scholar.google.com/citations?user=_Gu69coAAAAJ), [Yilin Zhao](https://ermu2001.github.io/me.io/), [Daquan Zhou](https://scholar.google.com/citations?user=DdCAbWwAAAAJ), [Zhijie Lin](https://scholar.google.com/citations?user=xXMj6_EAAAAJ), [See-Kiong Ng](https://scholar.google.com/citations?user=_wsommYAAAAJ), [Jiashi Feng](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en)
|
17 |
+
|
18 |
+
</div>
|
19 |
+
|
20 |
+
<!-- [![Paper](https://img.shields.io/badge/cs.CV-2311.17005-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2311.17005) -->
|
21 |
+
|
22 |
+
**Project Page: [PLLaVA](https://pllava.github.io/)**
|
23 |
+
|
24 |
+
[![arXiv](https://img.shields.io/badge/arXiv-2404.16994-b31b1b.svg)](https://arxiv.org/abs/2404.16994)
|
25 |
+
[![YouTube Video](https://img.shields.io/badge/YouTube-Video-red)](https://www.youtube.com/watch?v=nAEje8tu18U)
|
26 |
+
[![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b)
|
27 |
+
|
28 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-activitynet)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-activitynet?p=pllava-parameter-free-llava-extension-from-1)
|
29 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msrvtt-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msrvtt-qa?p=pllava-parameter-free-llava-extension-from-1)
|
30 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-msvd-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-msvd-qa?p=pllava-parameter-free-llava-extension-from-1)
|
31 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=pllava-parameter-free-llava-extension-from-1)
|
32 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/zeroshot-video-question-answer-on-tgif-qa)](https://paperswithcode.com/sota/zeroshot-video-question-answer-on-tgif-qa?p=pllava-parameter-free-llava-extension-from-1)
|
33 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-4)](https://paperswithcode.com/sota/video-based-generative-performance-4?p=pllava-parameter-free-llava-extension-from-1)
|
34 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-3)](https://paperswithcode.com/sota/video-based-generative-performance-3?p=pllava-parameter-free-llava-extension-from-1)
|
35 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=pllava-parameter-free-llava-extension-from-1)
|
36 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-2)](https://paperswithcode.com/sota/video-based-generative-performance-2?p=pllava-parameter-free-llava-extension-from-1)
|
37 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-1)](https://paperswithcode.com/sota/video-based-generative-performance-1?p=pllava-parameter-free-llava-extension-from-1)
|
38 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pllava-parameter-free-llava-extension-from-1/video-based-generative-performance-5)](https://paperswithcode.com/sota/video-based-generative-performance-5?p=pllava-parameter-free-llava-extension-from-1)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
![]()
|
47 |
+
<div align="center">
|
48 |
+
<a href="https://pllava.github.io">
|
49 |
+
<img src="assert/logo.png">
|
50 |
+
</a>
|
51 |
+
</div>
|
52 |
+
|
53 |
+
<div align="center">
|
54 |
+
<video src="https://github.com/magic-research/PLLaVA/assets/55656210/a6619702-12d3-489d-bfcc-0ef7105544b2" width="100%">
|
55 |
+
</div>
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
## Overview
|
63 |
+
|
64 |
+
Welcome to PLLAVA!
|
65 |
+
|
66 |
+
The primary purpose of this repository is to support research and the development of prototype models. It is designed to facilitate ease of experimentation and enable a clear overview of results. Please note that this repo is currently undergoing development and reconstruction.
|
67 |
+
|
68 |
+
It's important to mention that we have not optimized the response speed of the application or the frontend logic. Our goal is to maintain simplicity, clarity, and ease of development, making it accessible for both researchers and students. If you have suggestions or want to enhance the application's performance, please feel free to contact us or contribute to the project.
|
69 |
+
|
70 |
+
|
71 |
+
We've briefly introduce our work in section [PLLAVA](#%EF%B8%8F-pllava). For more details, feel free to read our paper. Check out section [Usage](#hammer-usage) to start using this repo. If you felt our works interesting, please star us, your support is all we want. If you find our work helpful, feel free to [cite](#page_facing_up-citation) us directly.
|
72 |
+
|
73 |
+
## :fire: Updates
|
74 |
+
|
75 |
+
- **2024/4/24**: Release:
|
76 |
+
- We are releasing our code/models/datasets.
|
77 |
+
|
78 |
+
## 🏖️ PLLAVA
|
79 |
+
<div align="center">
|
80 |
+
<a href="https://www.youtube.com/embed/nAEje8tu18U?si=GXxjgP93j77FzDbw">
|
81 |
+
<img src="assert/teaser.jpg">
|
82 |
+
</a>
|
83 |
+
</div>
|
84 |
+
|
85 |
+
|
86 |
+
### Abstract
|
87 |
+
|
88 |
+
Vision-language pre-training (VLP) has significantly elevated performance across a range of vision-language applications. Yet, the pre-training process for video-related tasks demands an exceptionally high degree of computational and data resources. This paper investigates a straightforward, highly efficient, and resource-light approach to adapting an existing image-language pre-training model for video data. Our preliminary experiments reveal that directly fine-tuning pre-trained image-language models with multiple frames on video datasets leads to performance saturation or even a drop in caption-related tasks. Besides, it is also vulnerable to prompts and tends to provide short descriptions. We conducted a deep analysis and observed that the performance saturation and the vulnerability might be related to the dominant patches that exist in some single video patches. We then propose a simple pooling strategy to smooth the feature distribution along the temporal dimension and thus reduce the dominant impacts from some extreme tokens. The new model is termed Pooling LLaVA, or PLLaVA in short. With the proposed pooling strategy, we achieve new state-of-the-art performance on all evaluated datasets. Notably, on the recent popular Video ChatGPT benchmark, PLLaVA achieves a score of 3.48 out of 5 on average of five evaluated dimensions, which is the new state-of-the-art score on the leaderboard and is 0.31 higher than the previous SOTA results from GPT4V (IG-VLM). On the latest multi-choice benchmark MVBench, PLLaVA achieves 58.1% accuracy on average across 20 sub-tasks, which is the new state-of-the-art result and is 14.5% higher than GPT4V (IG-VLM).
|
89 |
+
|
90 |
+
<div align="center"><img src="assert/module.png"></div>
|
91 |
+
|
92 |
+
|
93 |
+
### SEARCHING FOR OPTIMAL POOLING STRATEGY
|
94 |
+
There are two dimensions for the pooling strategy: the spatial dimension and the temporal dimension. We empirically found that reducing the spatial dimension with a larger temporal dimension could lead to better model performance, compared to reducing the temporal dimension directly.
|
95 |
+
|
96 |
+
<div align="center"><img src="assert/zeroshot.png"></div>
|
97 |
+
|
98 |
+
|
99 |
+
### STATE-OF-THE-ART PERFORMANCE
|
100 |
+
We compare the performance of PLLAVA with recent popular methods over both question-answer and captioning datasets. The results are shown below.
|
101 |
+
|
102 |
+
<div align="center"><img src="assert/performance.png"></div>
|
103 |
+
|
104 |
+
## :hammer: Usage
|
105 |
+
|
106 |
+
This section provides guidance on how to run, train, and evaluate our models.
|
107 |
+
|
108 |
+
### Install
|
109 |
+
First, you will need to set up the environment and download some pre-trained weights.
|
110 |
+
|
111 |
+
This repo is built up using [transformers](https://github.com/huggingface/transformers) for model construction along with [accelerate](https://github.com/huggingface/accelerate) for distributed training. Follow the instructions to install the needed environment.
|
112 |
+
|
113 |
+
1. Above all, the following environment set up is for python 3.10. If you choose to use conda for environment setup, we recommend creating the virtual environment with:
|
114 |
+
```bash
|
115 |
+
conda create -n pllava python=3.10
|
116 |
+
```
|
117 |
+
|
118 |
+
1. Firstly, install [pytorch](https://pytorch.org/) from the official website. The code runs on torch 2.2.1, cu118 or cu122. Select the version that suits your drive version.
|
119 |
+
|
120 |
+
```
|
121 |
+
torch 2.2.1+cu118
|
122 |
+
torchaudio 2.2.1+cu118
|
123 |
+
torchvision 0.17.1+cu118
|
124 |
+
```
|
125 |
+
|
126 |
+
If your driver version is higher than cu121, you could probably try installing with the following scripts:
|
127 |
+
```bash
|
128 |
+
pip install -r requirements.txt
|
129 |
+
```
|
130 |
+
|
131 |
+
Otherwise, you would need to install a torch for your server first, then install the other packages:
|
132 |
+
```bash
|
133 |
+
pip install -r requirements.torch.txt # decide your own requirements, (this is for cu11), or install torch directly following the official website.
|
134 |
+
pip install -r requirements.no_torch.txt # install the following
|
135 |
+
```
|
136 |
+
|
137 |
+
1. Prepare the model.
|
138 |
+
We prefer to have huggingface models explicitly downloaded to a MODELS directory. However, if you are familiar with huggingface-hub usage, feel free to organize the model yourself.
|
139 |
+
```
|
140 |
+
python python_scripts/hf.py
|
141 |
+
```
|
142 |
+
|
143 |
+
Here are some detailed information of the obtained models:
|
144 |
+
|
145 |
+
|
146 |
+
| Model | Link | Initialized From |
|
147 |
+
| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
|
148 |
+
| pllava-7b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-7b) | [llava-hf/llava-v1.6-vicuna-7b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) |
|
149 |
+
| pllava-13b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-13b) | [llava-hf/llava-v1.6-vicuna-13b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) |
|
150 |
+
| pllava-34b | [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm-dark.svg)](https://huggingface.co/ermu2001/pllava-34b) | [llava-hf/llava-v1.6-34b-hf · Hugging Face](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) |
|
151 |
+
|
152 |
+
The model directory should look like this, where you would only need the corresponding model's weights and directory.
|
153 |
+
|
154 |
+
```
|
155 |
+
$ tree MODELS
|
156 |
+
MODELS
|
157 |
+
|-- pllava-13b
|
158 |
+
| |-- added_tokens.json
|
159 |
+
| |-- config.json
|
160 |
+
| |-- generation_config.json
|
161 |
+
| |-- model-00001-of-00006.safetensors
|
162 |
+
| |-- model-00002-of-00006.safetensors
|
163 |
+
| |-- model-00003-of-00006.safetensors
|
164 |
+
| |-- model-00004-of-00006.safetensors
|
165 |
+
| |-- model-00005-of-00006.safetensors
|
166 |
+
| |-- model-00006-of-00006.safetensors
|
167 |
+
| |-- model.safetensors.index.json
|
168 |
+
| |-- preprocessor_config.json
|
169 |
+
| |-- processor_config.json
|
170 |
+
| |-- special_tokens_map.json
|
171 |
+
| |-- tokenizer.json
|
172 |
+
| |-- tokenizer.model
|
173 |
+
| `-- tokenizer_config.json
|
174 |
+
|-- pllava-34b
|
175 |
+
| |-- added_tokens.json
|
176 |
+
| |-- config.json
|
177 |
+
| |-- generation_config.json
|
178 |
+
| |-- model-00001-of-00015.safetensors
|
179 |
+
| |-- model-00002-of-00015.safetensors
|
180 |
+
| |-- model-00003-of-00015.safetensors
|
181 |
+
| |-- model-00004-of-00015.safetensors
|
182 |
+
| |-- model-00005-of-00015.safetensors
|
183 |
+
| |-- model-00006-of-00015.safetensors
|
184 |
+
| |-- model-00007-of-00015.safetensors
|
185 |
+
| |-- model-00008-of-00015.safetensors
|
186 |
+
| |-- model-00009-of-00015.safetensors
|
187 |
+
| |-- model-00010-of-00015.safetensors
|
188 |
+
| |-- model-00011-of-00015.safetensors
|
189 |
+
| |-- model-00012-of-00015.safetensors
|
190 |
+
| |-- model-00013-of-00015.safetensors
|
191 |
+
| |-- model-00014-of-00015.safetensors
|
192 |
+
| |-- model-00015-of-00015.safetensors
|
193 |
+
| |-- model.safetensors-deprecated
|
194 |
+
| |-- model.safetensors.index.json
|
195 |
+
| |-- preprocessor_config.json
|
196 |
+
| |-- processor_config.json
|
197 |
+
| |-- special_tokens_map.json
|
198 |
+
| |-- tokenizer.json
|
199 |
+
| |-- tokenizer.model
|
200 |
+
| `-- tokenizer_config.json
|
201 |
+
|-- pllava-7b
|
202 |
+
|-- added_tokens.json
|
203 |
+
|-- config.json
|
204 |
+
|-- generation_config.json
|
205 |
+
|-- model-00001-of-00003.safetensors
|
206 |
+
|-- model-00002-of-00003.safetensors
|
207 |
+
|-- model-00003-of-00003.safetensors
|
208 |
+
|-- model.safetensors.index.json
|
209 |
+
|-- preprocessor_config.json
|
210 |
+
|-- processor_config.json
|
211 |
+
|-- special_tokens_map.json
|
212 |
+
|-- tokenizer.json
|
213 |
+
|-- tokenizer.model
|
214 |
+
`-- tokenizer_config.json
|
215 |
+
```
|
216 |
+
|
217 |
+
With the above steps, you should be able to proceed on with the following usages.
|
218 |
+
|
219 |
+
### Run Application
|
220 |
+
|
221 |
+
To run our models, make sure you have downloaded a model pretrained weights from the huggingface spaces. Then, run the following scripts with the corresponding path input. Since we are only training with lora and the projector, the model to be run are determined with:
|
222 |
+
|
223 |
+
- **model_dir**: model directory, one with config.json as compatible with transformers. This refers to the base model's directory, for example "llava-hf/llava-v1.6-vicuna-7b-hf"/"ermu2001/pllava-7b"/"MODELS/pllava-7b". (default to: MODELS/plave-7b)
|
224 |
+
- **weights_dir**: your weights directory. could be the same as model_dir, but if you have a weights directory for the lora weights, you should set this weights_dir to that directory to load the lora weights. This directory should be local. Also, it would need to contain a config.json file within. (default to: ${model_dir}).
|
225 |
+
|
226 |
+
```bash
|
227 |
+
model_dir="model directory"
|
228 |
+
weights_dir="weights directory"
|
229 |
+
bash scripts/demo.sh ${model_dir} ${weights_dir}
|
230 |
+
```
|
231 |
+
|
232 |
+
Now check out the application demo and try play with PLLAVA!
|
233 |
+
|
234 |
+
### Train
|
235 |
+
|
236 |
+
Follow the following steps to reproduce our results or train your own variant:
|
237 |
+
|
238 |
+
#### 1. Data Preparation
|
239 |
+
|
240 |
+
To train our model from a starting Image-aligned Vision LLM, you would need to download the data first. Our data set up is mainly based on the original Videochat2's training data. Check out [Instruction Data](./DATA.md) to prepare the instruction training data. Ideally, setting up a root data directory and alter the code [here](./tasks/train/instruction_data.py#L6) would accomodate the data for training most smoothly.
|
241 |
+
|
242 |
+
#### 2. Start Training
|
243 |
+
|
244 |
+
Now you're only a few step away from starting the training. Follow the instructions:
|
245 |
+
|
246 |
+
##### Setup Accelerator
|
247 |
+
|
248 |
+
Customize a accelerate training config. For example, a simple config using multiple gpus with no distribution strategy (only torch DDP) would look like:
|
249 |
+
|
250 |
+
```yaml
|
251 |
+
compute_environment: LOCAL_MACHINE
|
252 |
+
debug: false
|
253 |
+
distributed_type: MULTI_GPU
|
254 |
+
downcast_bf16: 'no'
|
255 |
+
gpu_ids: all
|
256 |
+
machine_rank: 0
|
257 |
+
main_training_function: main
|
258 |
+
mixed_precision: bf16
|
259 |
+
num_machines: 1
|
260 |
+
num_processes: 8
|
261 |
+
rdzv_backend: static
|
262 |
+
same_network: true
|
263 |
+
tpu_env: []
|
264 |
+
tpu_use_cluster: false
|
265 |
+
tpu_use_sudo: false
|
266 |
+
use_cpu: false
|
267 |
+
```
|
268 |
+
|
269 |
+
Check out out the [Accelerate](https://huggingface.co/docs/accelerate/index) documents for more details.
|
270 |
+
|
271 |
+
##### Overwatch the training configuration
|
272 |
+
|
273 |
+
Next, you should go over a basic training configuration of the training process in [here](tasks/train/config_pllava_nframe.py). Then passing this file as the first arg to the [training script](tasks/train/train_pllava_nframe_accel.py) would utilize every arguments in the file. You can customize some of the hyper parameters for your own training process by passing them in the format of "key" "value" pair in the following arguments. A example training scripts could be find [here](scripts/train_pllava.sh).
|
274 |
+
|
275 |
+
We recommand customize a [configuration](tasks/train/config_pllava_nframe.py) to set up a customized training!
|
276 |
+
|
277 |
+
With the above steps, you would be able to start the training process. The output would be well organized in the output directory, each a qualified model directory to pass in to demo as weights_dir, since we are only saveing the lora weights and projector weights to avoide redundancy.
|
278 |
+
|
279 |
+
### Evaluation
|
280 |
+
|
281 |
+
This section mainly introduce how to reproduce the evaluation or evaluate your own model.
|
282 |
+
|
283 |
+
#### Set up Evaluation Data
|
284 |
+
|
285 |
+
Make sure you set up the "DATAS" directory as in [DATA.md](DATA.md), then you would be able to run the inference with fortune! The evaluation data directory of DATAS would look like:
|
286 |
+
|
287 |
+
```
|
288 |
+
DATAS/:
|
289 |
+
DATAS/VideoQA:
|
290 |
+
DATAS/VideoQA/TGIF_QA:
|
291 |
+
test_a.json
|
292 |
+
test_q.json
|
293 |
+
DATAS/VideoQA/TGIF_QA/videos:
|
294 |
+
tumblr_m4387mGrlc1r6m5e8o1_250.gif
|
295 |
+
...
|
296 |
+
DATAS/VideoQA/TGIF_QA/videos_mp4:
|
297 |
+
tumblr_m4387mGrlc1r6m5e8o1_250.mp4
|
298 |
+
...
|
299 |
+
DATAS/VideoQA/TGIF_QA/video_gif:
|
300 |
+
tumblr_m4387mGrlc1r6m5e8o1_250.gif
|
301 |
+
...
|
302 |
+
DATAS/VideoQA/MSVD_Zero_Shot_QA:
|
303 |
+
test_a.json
|
304 |
+
test_q.json
|
305 |
+
DATAS/VideoQA/MSVD_Zero_Shot_QA/videos:
|
306 |
+
-4wsuPCjDBc_5_15.avi
|
307 |
+
DATAS/VideoQA/MSVD_Zero_Shot_QA/msvd_qa:
|
308 |
+
DATAS/VideoQA/ActivityNet:
|
309 |
+
test_a.json
|
310 |
+
test_q.json
|
311 |
+
DATAS/VideoQA/ActivityNet/all_test:
|
312 |
+
v_--tFD65KaK4.mp4
|
313 |
+
...
|
314 |
+
DATAS/VideoQA/MSRVTT_Zero_Shot_QA:
|
315 |
+
test_a.json
|
316 |
+
test_q.json
|
317 |
+
DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos:
|
318 |
+
DATAS/VideoQA/MSRVTT_Zero_Shot_QA/videos/all:
|
319 |
+
video0.mp4
|
320 |
+
...
|
321 |
+
|
322 |
+
DATAS/MVBench:
|
323 |
+
...
|
324 |
+
|
325 |
+
DATAS/Recaption/Inter4K:
|
326 |
+
annotations.json
|
327 |
+
DATAS/Recaption/Inter4K/60fps:
|
328 |
+
DATAS/Recaption/Inter4K/60fps/UHD:
|
329 |
+
1.mp4
|
330 |
+
...
|
331 |
+
|
332 |
+
```
|
333 |
+
|
334 |
+
#### Start Evaluate
|
335 |
+
|
336 |
+
Once you have construted the evaluation data, you can start the evaluation as in [here](scripts/eval.sh). This script is for evaluating 7B/13B models. As pllava-34b model uses a slightly different prompting, it is evaluated with this [script](scripts/eval_yiprompt.sh).
|
337 |
+
|
338 |
+
```
|
339 |
+
bash scripts/eval.sh
|
340 |
+
```
|
341 |
+
|
342 |
+
Same as running the demo, you would need to determine the model_dir and weights_dir to evaluate the model. Feel free to comment out some commands and produce partial evaluation.
|
343 |
+
|
344 |
+
#### Overwatch the Results
|
345 |
+
|
346 |
+
The evaluation results would be shown to you with our results gallery demo:
|
347 |
+
|
348 |
+
```bash
|
349 |
+
bash scripts/gallery.sh
|
350 |
+
```
|
351 |
+
|
352 |
+
Feel free to use the compare version to compare differnt models' results or use the single gallery version to check out one model's results. They are basically the same. Check out this [script](scripts/gallery.sh) for more details
|
353 |
+
|
354 |
+
#### For Captioning and Recaptioning
|
355 |
+
Follow instructions at [DATA.md](DATA.md#extending-reacptioning) and you can extend the recaptioning data with a few steps.
|
356 |
+
|
357 |
+
Feel free to point out high quality dataset of videos, we would proceed on doing captioning on those datasets.
|
358 |
+
|
359 |
+
|
360 |
+
# :page_facing_up: Citation
|
361 |
+
|
362 |
+
If you find this project useful in your research, please consider cite:
|
363 |
+
|
364 |
+
```BibTeX
|
365 |
+
@misc{xu2024pllava,
|
366 |
+
title={PLLaVA : Parameter-free LLaVA Extension from Images to Videos for Video Dense Captioning},
|
367 |
+
author={Lin Xu and Yilin Zhao and Daquan Zhou and Zhijie Lin and See Kiong Ng and Jiashi Feng},
|
368 |
+
year={2024},
|
369 |
+
eprint={2404.16994},
|
370 |
+
archivePrefix={arXiv},
|
371 |
+
primaryClass={cs.CV}
|
372 |
+
}
|
373 |
+
```
|
374 |
+
|
375 |
+
# :dizzy: Acknowledgement
|
376 |
+
|
377 |
+
This code base is mainly built upon [Videochat2](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2). SALUTE.
|
378 |
+
|
379 |
+
We would also like to recognize and commend the following open source projects, thank you for your great contribution to the open source community:
|
380 |
+
|
381 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA): Fantastic Open Source Image LLM Model.
|
382 |
+
- [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main): Great Evaluation Benchmarking Framework.
|
383 |
+
- [VideoLlava](https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main/videollava):Video LLM repo with helpful resources.
|
app.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from huggingface_hub import snapshot_download
|
3 |
+
snapshot_download(
|
4 |
+
'ermu2001/pllava-7b',
|
5 |
+
local_dir='MODELS/pllava-7b',
|
6 |
+
repo_type='model',
|
7 |
+
local_dir_use_symlinks=True,
|
8 |
+
)
|
9 |
+
|
10 |
+
sys.argv.extend([
|
11 |
+
"--pretrained_model_name_or_path", "MODELS/pllava-7b",
|
12 |
+
"--num_frames", "16",
|
13 |
+
"--use_lora",
|
14 |
+
"--weight_dir", "MODELS/pllava-7b",
|
15 |
+
"--lora_alpha", "4",
|
16 |
+
"--conv_mode", "plain",
|
17 |
+
])
|
18 |
+
import tasks.eval.demo.pllava_demo
|
assert/data.png
ADDED
Git LFS Details
|
assert/logo.png
ADDED
Git LFS Details
|
assert/module.png
ADDED
Git LFS Details
|
assert/performance.png
ADDED
Git LFS Details
|
assert/teaser.jpg
ADDED
Git LFS Details
|
assert/zeroshot.png
ADDED
Git LFS Details
|
dataset/__init__.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
3 |
+
from torchvision import transforms
|
4 |
+
from torchvision.transforms import InterpolationMode
|
5 |
+
from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset
|
6 |
+
|
7 |
+
|
8 |
+
def get_media_type(dataset_config):
|
9 |
+
if len(dataset_config) == 3 and dataset_config[2] == "video":
|
10 |
+
return "video"
|
11 |
+
elif dataset_config[-1] == "only_video":
|
12 |
+
return "only_video"
|
13 |
+
else:
|
14 |
+
return "image"
|
15 |
+
|
16 |
+
|
17 |
+
def create_dataset(dataset_type, config):
|
18 |
+
if "clip" in config.model.get("vit_model", 'vit'):
|
19 |
+
mean = (0.485, 0.456, 0.406)
|
20 |
+
std = (0.229, 0.224, 0.225)
|
21 |
+
else:
|
22 |
+
vision_enc_name = config.model.vision_encoder.name
|
23 |
+
if "swin" in vision_enc_name or "vit" in vision_enc_name:
|
24 |
+
mean = (0.485, 0.456, 0.406)
|
25 |
+
std = (0.229, 0.224, 0.225)
|
26 |
+
elif "beit" in vision_enc_name:
|
27 |
+
mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning
|
28 |
+
std = (0.5, 0.5, 0.5)
|
29 |
+
elif "clip" in vision_enc_name:
|
30 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
31 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
32 |
+
else:
|
33 |
+
raise ValueError
|
34 |
+
|
35 |
+
normalize = transforms.Normalize(mean, std)
|
36 |
+
|
37 |
+
# loaded images and videos are torch.Tensor of torch.uint8 format,
|
38 |
+
# ordered as (T, 1 or 3, H, W) where T=1 for image
|
39 |
+
type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
|
40 |
+
|
41 |
+
if config.inputs.video_input.random_aug:
|
42 |
+
aug_transform = transforms.RandAugment()
|
43 |
+
else:
|
44 |
+
aug_transform = transforms.Lambda(lambda x: x)
|
45 |
+
|
46 |
+
train_transform = transforms.Compose(
|
47 |
+
[
|
48 |
+
aug_transform,
|
49 |
+
transforms.RandomResizedCrop(
|
50 |
+
config.inputs.image_res,
|
51 |
+
scale=(0.5, 1.0),
|
52 |
+
interpolation=InterpolationMode.BICUBIC,
|
53 |
+
),
|
54 |
+
transforms.RandomHorizontalFlip(),
|
55 |
+
type_transform,
|
56 |
+
normalize,
|
57 |
+
]
|
58 |
+
)
|
59 |
+
test_transform = transforms.Compose(
|
60 |
+
[
|
61 |
+
transforms.Resize(
|
62 |
+
(config.inputs.image_res, config.inputs.image_res),
|
63 |
+
interpolation=InterpolationMode.BICUBIC,
|
64 |
+
),
|
65 |
+
type_transform,
|
66 |
+
normalize,
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
|
71 |
+
video_only_dataset_kwargs_train = dict(
|
72 |
+
video_reader_type=video_reader_type,
|
73 |
+
sample_type=config.inputs.video_input.sample_type,
|
74 |
+
num_frames=config.inputs.video_input.num_frames,
|
75 |
+
num_tries=3, # false tolerance
|
76 |
+
)
|
77 |
+
|
78 |
+
if dataset_type == "pt_train":
|
79 |
+
raise ValueError("NOT PRETRAINING YET")
|
80 |
+
elif dataset_type in ["it_train"]:
|
81 |
+
# convert to list of lists
|
82 |
+
train_files = (
|
83 |
+
[config.train_file] if isinstance(config.train_file[0], str) else config.train_file
|
84 |
+
)
|
85 |
+
train_media_types = sorted(list({get_media_type(e) for e in train_files}))
|
86 |
+
|
87 |
+
train_datasets = []
|
88 |
+
for m in train_media_types:
|
89 |
+
dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
|
90 |
+
# dataset of the same media_type will be mixed in a single Dataset object
|
91 |
+
_train_files = [e for e in train_files if get_media_type(e) == m]
|
92 |
+
|
93 |
+
datasets = []
|
94 |
+
for train_file in _train_files:
|
95 |
+
dataset_kwargs = dict(
|
96 |
+
ann_file=train_file,
|
97 |
+
transform=train_transform,
|
98 |
+
mm_alone=config.preprocess.get("mm_alone", True),
|
99 |
+
add_second_msg=config.preprocess.get("add_second_msg", True),
|
100 |
+
skip_short_sample=config.preprocess.get("skip_short_sample", False),
|
101 |
+
clip_transform=config.preprocess.get("clip_transform", False),
|
102 |
+
random_shuffle=config.preprocess.get("random_shuffle", True),
|
103 |
+
system=config.preprocess.get("system", ""),
|
104 |
+
role=config.preprocess.get('roles', ("Human", "Assistant")),
|
105 |
+
end_signal=config.preprocess.get('end_signal', "###"),
|
106 |
+
begin_signal=config.preprocess.get('begin_signal', ""),
|
107 |
+
)
|
108 |
+
if m == "video":
|
109 |
+
video_only_dataset_kwargs_train.update({
|
110 |
+
"start_token": config.model.get("start_token", "<Video>"),
|
111 |
+
"end_token": config.model.get("end_token", "</Video>"),
|
112 |
+
})
|
113 |
+
dataset_kwargs.update(video_only_dataset_kwargs_train)
|
114 |
+
if "tgif" in train_file[1]:
|
115 |
+
video_only_dataset_kwargs_train.update({
|
116 |
+
"video_reader_type": "gif"
|
117 |
+
})
|
118 |
+
dataset_kwargs.update(video_only_dataset_kwargs_train)
|
119 |
+
elif "webvid" in train_file[1]:
|
120 |
+
video_only_dataset_kwargs_train.update({
|
121 |
+
"video_reader_type": "hdfs"
|
122 |
+
})
|
123 |
+
else:
|
124 |
+
video_only_dataset_kwargs_train.update({
|
125 |
+
"video_reader_type": "decord"
|
126 |
+
})
|
127 |
+
dataset_kwargs.update(video_only_dataset_kwargs_train)
|
128 |
+
datasets.append(dataset_cls(**dataset_kwargs))
|
129 |
+
dataset = ConcatDataset(datasets)
|
130 |
+
train_datasets.append(dataset)
|
131 |
+
return train_datasets
|
132 |
+
|
133 |
+
|
134 |
+
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
|
135 |
+
loaders = []
|
136 |
+
for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
|
137 |
+
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
|
138 |
+
):
|
139 |
+
if is_train:
|
140 |
+
shuffle = sampler is None
|
141 |
+
drop_last = True
|
142 |
+
else:
|
143 |
+
shuffle = False
|
144 |
+
drop_last = False
|
145 |
+
loader = DataLoader(
|
146 |
+
dataset,
|
147 |
+
batch_size=bs,
|
148 |
+
num_workers=n_worker,
|
149 |
+
pin_memory=False,
|
150 |
+
sampler=sampler,
|
151 |
+
shuffle=shuffle,
|
152 |
+
collate_fn=collate_fn,
|
153 |
+
drop_last=drop_last,
|
154 |
+
persistent_workers=True if n_worker > 0 else False,
|
155 |
+
)
|
156 |
+
loaders.append(loader)
|
157 |
+
return loaders
|
158 |
+
|
dataset/base_dataset.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import time
|
7 |
+
from dataset.utils import load_image_from_path
|
8 |
+
|
9 |
+
try:
|
10 |
+
from petrel_client.client import Client
|
11 |
+
has_client = True
|
12 |
+
except ImportError:
|
13 |
+
has_client = False
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class ImageVideoBaseDataset(Dataset):
|
19 |
+
"""Base class that implements the image and video loading methods"""
|
20 |
+
|
21 |
+
media_type = "video"
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
assert self.media_type in ["image", "video", "only_video"]
|
25 |
+
self.data_root = None
|
26 |
+
self.anno_list = (
|
27 |
+
None # list(dict), each dict contains {"image": str, # image or video path}
|
28 |
+
)
|
29 |
+
self.transform = None
|
30 |
+
self.video_reader = None
|
31 |
+
self.num_tries = None
|
32 |
+
|
33 |
+
self.client = None
|
34 |
+
if has_client:
|
35 |
+
self.client = Client('~/petreloss.conf')
|
36 |
+
|
37 |
+
def __getitem__(self, index):
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
raise NotImplementedError
|
42 |
+
|
43 |
+
def get_anno(self, index):
|
44 |
+
"""obtain the annotation for one media (video or image)
|
45 |
+
|
46 |
+
Args:
|
47 |
+
index (int): The media index.
|
48 |
+
|
49 |
+
Returns: dict.
|
50 |
+
- "image": the filename, video also use "image".
|
51 |
+
- "caption": The caption for this file.
|
52 |
+
|
53 |
+
"""
|
54 |
+
anno = self.anno_list[index]
|
55 |
+
if self.data_root is not None:
|
56 |
+
anno["image"] = os.path.join(self.data_root, anno["image"])
|
57 |
+
return anno
|
58 |
+
|
59 |
+
def load_and_transform_media_data(self, index, data_path):
|
60 |
+
if self.media_type == "image":
|
61 |
+
return self.load_and_transform_media_data_image(index, data_path, clip_transform=self.clip_transform)
|
62 |
+
else:
|
63 |
+
return self.load_and_transform_media_data_video(index, data_path, clip_transform=self.clip_transform)
|
64 |
+
|
65 |
+
def load_and_transform_media_data_image(self, index, data_path, clip_transform=False):
|
66 |
+
image = load_image_from_path(data_path, client=self.client)
|
67 |
+
if not clip_transform:
|
68 |
+
image = self.transform(image)
|
69 |
+
return image, index
|
70 |
+
|
71 |
+
def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, clip_transform=False):
|
72 |
+
for _ in range(self.num_tries):
|
73 |
+
try:
|
74 |
+
max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
|
75 |
+
if "webvid" in data_path:
|
76 |
+
hdfs_dir="hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/videogen_data/webvid_data/10M_full_train"
|
77 |
+
video_name = os.path.basename(data_path)
|
78 |
+
video_id, extension = os.path.splitext(video_name)
|
79 |
+
ind_file = os.path.join(hdfs_dir, self.keys_indexfile[video_id])
|
80 |
+
frames, frame_indices, fps = self.video_reader(ind_file, video_id, self.num_frames, self.sample_type,
|
81 |
+
max_num_frames=max_num_frames, client=self.client, clip=clip)
|
82 |
+
else:
|
83 |
+
frames, frame_indices, fps = self.video_reader(
|
84 |
+
data_path, self.num_frames, self.sample_type,
|
85 |
+
max_num_frames=max_num_frames, client=self.client, clip=clip
|
86 |
+
)
|
87 |
+
except Exception as e:
|
88 |
+
logger.warning(
|
89 |
+
f"Caught exception {e} when loading video {data_path}, "
|
90 |
+
f"randomly sample a new video as replacement"
|
91 |
+
)
|
92 |
+
index = random.randint(0, len(self) - 1)
|
93 |
+
ann = self.get_anno(index)
|
94 |
+
data_path = ann["image"]
|
95 |
+
continue
|
96 |
+
# shared aug for video frames
|
97 |
+
if not clip_transform:
|
98 |
+
frames = self.transform(frames)
|
99 |
+
if return_fps:
|
100 |
+
sec = [str(round(f / fps, 1)) for f in frame_indices]
|
101 |
+
return frames, index, sec
|
102 |
+
else:
|
103 |
+
return frames, index
|
104 |
+
else:
|
105 |
+
raise RuntimeError(
|
106 |
+
f"Failed to fetch video after {self.num_tries} tries. "
|
107 |
+
f"This might indicate that you have many corrupted videos."
|
108 |
+
)
|
dataset/it_dataset.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import sqlite3
|
5 |
+
import random
|
6 |
+
from os.path import basename
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import datetime
|
10 |
+
|
11 |
+
from dataset.base_dataset import ImageVideoBaseDataset
|
12 |
+
from dataset.video_utils import VIDEO_READER_FUNCS
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
IMAGE_TOKEN="<image>"
|
16 |
+
|
17 |
+
class ITImgTrainDataset(ImageVideoBaseDataset):
|
18 |
+
media_type = "image"
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self, ann_file, transform,
|
22 |
+
system="", role=("Human", "Assistant"),
|
23 |
+
mm_alone=True,
|
24 |
+
add_second_msg=True,
|
25 |
+
start_token="<Image>", end_token="</Image>",
|
26 |
+
random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle
|
27 |
+
begin_signal=None,
|
28 |
+
end_signal=None,
|
29 |
+
clip_transform=False,
|
30 |
+
skip_short_sample=False,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.mm_alone = mm_alone
|
34 |
+
self.clip_transform = clip_transform
|
35 |
+
if len(ann_file) == 3 and ann_file[2] == "video":
|
36 |
+
self.media_type = "video"
|
37 |
+
else:
|
38 |
+
self.media_type = "image"
|
39 |
+
self.label_file, self.data_root = ann_file[:2]
|
40 |
+
|
41 |
+
logger.info('Load json file')
|
42 |
+
with open(self.label_file, 'r') as f:
|
43 |
+
self.anno = json.load(f)
|
44 |
+
self.num_examples = len(self.anno)
|
45 |
+
self.transform = transform
|
46 |
+
annos = []
|
47 |
+
for ann in self.anno:
|
48 |
+
filename = ann['video'] if 'video' in ann else ann['image']
|
49 |
+
if self.media_type =='video' and "webvid" in self.data_root:
|
50 |
+
video_id, extension = os.path.splitext(os.path.basename(filename))
|
51 |
+
if video_id not in self.keys_indexfile:
|
52 |
+
pass
|
53 |
+
else:
|
54 |
+
annos.append(ann)
|
55 |
+
else:
|
56 |
+
|
57 |
+
if filename is None or filename=="None":
|
58 |
+
pass
|
59 |
+
else:
|
60 |
+
if os.path.exists(os.path.join(self.data_root, filename)):
|
61 |
+
annos.append(ann)
|
62 |
+
else:
|
63 |
+
...
|
64 |
+
self.anno = annos
|
65 |
+
self.num_examples = len(self.anno)
|
66 |
+
|
67 |
+
|
68 |
+
# prompt parameters
|
69 |
+
if system:
|
70 |
+
assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token."
|
71 |
+
# currently not support add start_token and end_token in the system, since the msg should be added properly
|
72 |
+
self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal
|
73 |
+
self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal
|
74 |
+
self.start_token = start_token
|
75 |
+
self.end_token = end_token
|
76 |
+
self.system = system
|
77 |
+
self.role = role
|
78 |
+
self.random_shuffle = random_shuffle
|
79 |
+
# instruction location and number
|
80 |
+
logger.info(f"Random shuffle: {self.random_shuffle}")
|
81 |
+
|
82 |
+
def get_anno(self, index):
|
83 |
+
filename = self.anno[index][self.media_type]
|
84 |
+
qa = self.anno[index]["QA"]
|
85 |
+
|
86 |
+
if "start" in self.anno[index] and "end" in self.anno[index]:
|
87 |
+
anno = {
|
88 |
+
"image": os.path.join(self.data_root, filename), "qa": qa,
|
89 |
+
"start": self.anno[index]["start"], "end": self.anno[index]["end"],
|
90 |
+
}
|
91 |
+
else:
|
92 |
+
anno = {"image": os.path.join(self.data_root, filename), "qa": qa}
|
93 |
+
return anno
|
94 |
+
|
95 |
+
def __len__(self):
|
96 |
+
return self.num_examples
|
97 |
+
|
98 |
+
def process_qa(self, qa, msg=""):
|
99 |
+
cur_instruction = ""
|
100 |
+
# randomly shuffle qa for conversation
|
101 |
+
if self.random_shuffle and len(qa) > 1:
|
102 |
+
random.shuffle(qa)
|
103 |
+
if "i" in qa[0].keys() and qa[0]["i"] != "":
|
104 |
+
cur_instruction = qa[0]["i"] + self.end_signal[0]
|
105 |
+
|
106 |
+
conversation = self.system
|
107 |
+
# add instruction as system message
|
108 |
+
if cur_instruction:
|
109 |
+
conversation += cur_instruction
|
110 |
+
|
111 |
+
# rstrip() for the extra " " in msg
|
112 |
+
if self.mm_alone:
|
113 |
+
conversation += (
|
114 |
+
self.begin_signal[0] + self.role[0] +
|
115 |
+
self.start_token + self.end_token + msg.rstrip() + self.end_signal[0]
|
116 |
+
)
|
117 |
+
|
118 |
+
for i, sentence in enumerate(qa):
|
119 |
+
q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"]
|
120 |
+
a = sentence["a"]
|
121 |
+
if q != "":
|
122 |
+
conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1])
|
123 |
+
else:
|
124 |
+
# no question, often in caption dataset
|
125 |
+
pass
|
126 |
+
conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1])
|
127 |
+
|
128 |
+
|
129 |
+
if cur_instruction:
|
130 |
+
cur_instruction += qa[0]["q"]
|
131 |
+
return conversation, cur_instruction.strip()
|
132 |
+
|
133 |
+
def __getitem__(self, index):
|
134 |
+
try:
|
135 |
+
ann = self.get_anno(index)
|
136 |
+
image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform)
|
137 |
+
conversation, instruction = self.process_qa(ann["qa"])
|
138 |
+
return image, conversation, instruction, index
|
139 |
+
except Exception as e:
|
140 |
+
logger.warning(f"Caught exception {e} when loading image {ann['image']}")
|
141 |
+
index = np.random.randint(0, len(self))
|
142 |
+
return self.__getitem__(index)
|
143 |
+
|
144 |
+
|
145 |
+
class ITVidTrainDataset(ITImgTrainDataset):
|
146 |
+
media_type = "video"
|
147 |
+
|
148 |
+
def __init__(
|
149 |
+
self, ann_file, transform,
|
150 |
+
num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3,
|
151 |
+
mm_alone=True,
|
152 |
+
system="", role=("Human", "Assistant"),
|
153 |
+
start_token="<Video>", end_token="</Video>",
|
154 |
+
add_second_msg=True,
|
155 |
+
random_shuffle=True,
|
156 |
+
begin_signal=None,
|
157 |
+
end_signal=None,
|
158 |
+
clip_transform=False,
|
159 |
+
skip_short_sample=False,
|
160 |
+
|
161 |
+
):
|
162 |
+
# "id index file for webvid"
|
163 |
+
if "webvid" in ann_file[1]:
|
164 |
+
with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f:
|
165 |
+
self.keys_indexfile = json.load(f) # the correponding index file for each webvid id
|
166 |
+
|
167 |
+
super().__init__(
|
168 |
+
ann_file, transform,
|
169 |
+
system=system, role=role,
|
170 |
+
mm_alone=mm_alone,
|
171 |
+
start_token=start_token, end_token=end_token,
|
172 |
+
random_shuffle=random_shuffle,
|
173 |
+
begin_signal=begin_signal,
|
174 |
+
end_signal=end_signal,
|
175 |
+
clip_transform=clip_transform,
|
176 |
+
skip_short_sample=skip_short_sample,
|
177 |
+
)
|
178 |
+
self.num_frames = num_frames
|
179 |
+
self.video_reader_type = video_reader_type
|
180 |
+
self.video_reader = VIDEO_READER_FUNCS[video_reader_type]
|
181 |
+
self.sample_type = sample_type
|
182 |
+
self.num_tries = num_tries
|
183 |
+
self.add_second_msg = add_second_msg
|
184 |
+
|
185 |
+
logger.info(f"Use {video_reader_type} for data in {ann_file}")
|
186 |
+
if add_second_msg:
|
187 |
+
logger.info(f"Add second message: The video contains X frames sampled at T seconds.")
|
188 |
+
|
189 |
+
def __getitem__(self, index):
|
190 |
+
try:
|
191 |
+
ann = self.get_anno(index)
|
192 |
+
|
193 |
+
msg = ""
|
194 |
+
clip = None
|
195 |
+
if "start" in ann and "end" in ann:
|
196 |
+
clip = [ann["start"], ann["end"]]
|
197 |
+
video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform)
|
198 |
+
if self.add_second_msg:
|
199 |
+
# " " should be added in the start and end
|
200 |
+
msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. "
|
201 |
+
conversation, instruction = self.process_qa(ann["qa"], msg)
|
202 |
+
return video, conversation, instruction, index
|
203 |
+
except Exception as e:
|
204 |
+
logger.warning(f"Caught exception {e} when loading video {ann['image']}")
|
205 |
+
index = np.random.randint(0, len(self))
|
206 |
+
return self.__getitem__(index)
|
dataset/utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.distributed import is_main_process, get_rank, get_world_size
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import numpy as np
|
6 |
+
from os.path import join
|
7 |
+
from tqdm import trange
|
8 |
+
from PIL import Image
|
9 |
+
from PIL import ImageFile
|
10 |
+
from torchvision.transforms import PILToTensor
|
11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
12 |
+
Image.MAX_IMAGE_PIXELS = None
|
13 |
+
|
14 |
+
|
15 |
+
def load_image_from_path(image_path, client):
|
16 |
+
if image_path.startswith('s3') or image_path.startswith('p2'):
|
17 |
+
value = client.Get(image_path)
|
18 |
+
img_bytes = np.frombuffer(value, dtype=np.uint8)
|
19 |
+
buff = io.BytesIO(img_bytes)
|
20 |
+
image = Image.open(buff).convert('RGB')
|
21 |
+
else:
|
22 |
+
image = Image.open(image_path).convert('RGB') # PIL Image
|
23 |
+
image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8
|
24 |
+
return image
|
25 |
+
|
26 |
+
def pre_text(text, max_l=None, pre_text=True):
|
27 |
+
if pre_text:
|
28 |
+
text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
|
29 |
+
text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
|
30 |
+
|
31 |
+
text = re.sub(r"\s{2,}", ' ', text)
|
32 |
+
text = text.rstrip('\n').strip(' ')
|
33 |
+
|
34 |
+
if max_l: # truncate
|
35 |
+
words = text.split(' ')
|
36 |
+
if len(words) > max_l:
|
37 |
+
text = ' '.join(words[:max_l])
|
38 |
+
else:
|
39 |
+
pass
|
40 |
+
return text
|
41 |
+
|
dataset/video_utils.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
|
3 |
+
"""
|
4 |
+
import random
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import av
|
8 |
+
import cv2
|
9 |
+
import decord
|
10 |
+
import imageio
|
11 |
+
from decord import VideoReader
|
12 |
+
|
13 |
+
# from dataloader import KVReader
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import math
|
17 |
+
# import tensorflow as tf
|
18 |
+
decord.bridge.set_bridge("torch")
|
19 |
+
|
20 |
+
import logging
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
|
24 |
+
"""
|
25 |
+
Converts a present time with the given time base and start_pts offset to seconds.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
time_in_seconds (float): The corresponding time in seconds.
|
29 |
+
|
30 |
+
https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
|
31 |
+
"""
|
32 |
+
if pts == math.inf:
|
33 |
+
return math.inf
|
34 |
+
|
35 |
+
return int(pts - start_pts) * time_base
|
36 |
+
|
37 |
+
|
38 |
+
def get_pyav_video_duration(video_reader):
|
39 |
+
video_stream = video_reader.streams.video[0]
|
40 |
+
video_duration = pts_to_secs(
|
41 |
+
video_stream.duration,
|
42 |
+
video_stream.time_base,
|
43 |
+
video_stream.start_time
|
44 |
+
)
|
45 |
+
return float(video_duration)
|
46 |
+
|
47 |
+
|
48 |
+
def get_frame_indices_by_fps():
|
49 |
+
pass
|
50 |
+
|
51 |
+
|
52 |
+
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
|
53 |
+
if sample in ["rand", "middle"]: # uniform sampling
|
54 |
+
acc_samples = min(num_frames, vlen)
|
55 |
+
# split the video into `acc_samples` intervals, and sample from each interval.
|
56 |
+
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
|
57 |
+
ranges = []
|
58 |
+
for idx, interv in enumerate(intervals[:-1]):
|
59 |
+
ranges.append((interv, intervals[idx + 1] - 1))
|
60 |
+
if sample == 'rand':
|
61 |
+
try:
|
62 |
+
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
|
63 |
+
except:
|
64 |
+
frame_indices = np.random.permutation(vlen)[:acc_samples]
|
65 |
+
frame_indices.sort()
|
66 |
+
frame_indices = list(frame_indices)
|
67 |
+
elif fix_start is not None:
|
68 |
+
frame_indices = [x[0] + fix_start for x in ranges]
|
69 |
+
elif sample == 'middle':
|
70 |
+
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
|
71 |
+
else:
|
72 |
+
raise NotImplementedError
|
73 |
+
|
74 |
+
if len(frame_indices) < num_frames: # padded with last frame
|
75 |
+
padded_frame_indices = [frame_indices[-1]] * num_frames
|
76 |
+
padded_frame_indices[:len(frame_indices)] = frame_indices
|
77 |
+
frame_indices = padded_frame_indices
|
78 |
+
elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
|
79 |
+
output_fps = float(sample[3:])
|
80 |
+
duration = float(vlen) / input_fps
|
81 |
+
delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
|
82 |
+
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
|
83 |
+
frame_indices = np.around(frame_seconds * input_fps).astype(int)
|
84 |
+
frame_indices = [e for e in frame_indices if e < vlen]
|
85 |
+
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
|
86 |
+
frame_indices = frame_indices[:max_num_frames]
|
87 |
+
# frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
|
88 |
+
else:
|
89 |
+
raise ValueError
|
90 |
+
return frame_indices
|
91 |
+
|
92 |
+
|
93 |
+
def read_frames_av(
|
94 |
+
video_path, num_frames, sample='rand', fix_start=None,
|
95 |
+
max_num_frames=-1, client=None, clip=None,
|
96 |
+
):
|
97 |
+
reader = av.open(video_path)
|
98 |
+
frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
|
99 |
+
vlen = len(frames)
|
100 |
+
duration = get_pyav_video_duration(reader)
|
101 |
+
fps = vlen / float(duration)
|
102 |
+
frame_indices = get_frame_indices(
|
103 |
+
num_frames, vlen, sample=sample, fix_start=fix_start,
|
104 |
+
input_fps=fps, max_num_frames=max_num_frames
|
105 |
+
)
|
106 |
+
frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8
|
107 |
+
frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
|
108 |
+
return frames, frame_indices, fps
|
109 |
+
|
110 |
+
|
111 |
+
def read_frames_gif(
|
112 |
+
video_path, num_frames, sample='rand', fix_start=None,
|
113 |
+
max_num_frames=-1, client=None, clip=None,
|
114 |
+
):
|
115 |
+
if video_path.startswith('s3') or video_path.startswith('p2'):
|
116 |
+
video_bytes = client.get(video_path)
|
117 |
+
gif = imageio.get_reader(io.BytesIO(video_bytes))
|
118 |
+
else:
|
119 |
+
gif = imageio.get_reader(video_path)
|
120 |
+
vlen = len(gif)
|
121 |
+
frame_indices = get_frame_indices(
|
122 |
+
num_frames, vlen, sample=sample, fix_start=fix_start,
|
123 |
+
max_num_frames=max_num_frames
|
124 |
+
)
|
125 |
+
frames = []
|
126 |
+
for index, frame in enumerate(gif):
|
127 |
+
# for index in frame_idxs:
|
128 |
+
if index in frame_indices:
|
129 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
130 |
+
frame = torch.from_numpy(frame).byte()
|
131 |
+
# # (H x W x C) to (C x H x W)
|
132 |
+
frame = frame.permute(2, 0, 1)
|
133 |
+
frames.append(frame)
|
134 |
+
frames = torch.stack(frames) # .float() / 255
|
135 |
+
|
136 |
+
return frames, frame_indices, 25. # for tgif
|
137 |
+
|
138 |
+
|
139 |
+
def read_frames_hdfs(ind_file, vid, num_frames, sample='rand',fix_start=None,
|
140 |
+
max_num_frames=-1, client=None, clip=None):
|
141 |
+
_context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
|
142 |
+
_sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
|
143 |
+
num_parallel_reader = 1
|
144 |
+
filename, extension = os.path.splitext(ind_file)
|
145 |
+
reader = KVReader(filename, num_parallel_reader)
|
146 |
+
key = vid
|
147 |
+
values = reader.read_many([key])
|
148 |
+
item = values[0]
|
149 |
+
contexts, sequences = tf.io.parse_single_sequence_example(
|
150 |
+
serialized=item,
|
151 |
+
context_features=_context_features,
|
152 |
+
sequence_features=_sequence_features)
|
153 |
+
|
154 |
+
# text = contexts['title'].numpy().decode("utf-8")
|
155 |
+
rawframes = sequences['data']
|
156 |
+
vlen = len(rawframes)
|
157 |
+
sample="rand"
|
158 |
+
|
159 |
+
frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
|
160 |
+
fix_start=fix_start,
|
161 |
+
max_num_frames=max_num_frames)
|
162 |
+
def read_image(raw_data):
|
163 |
+
return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()
|
164 |
+
|
165 |
+
frames = []
|
166 |
+
for index, frame in enumerate(rawframes):
|
167 |
+
if index in frame_indices:
|
168 |
+
frame = read_image(frame)
|
169 |
+
frame = torch.as_tensor(frame)
|
170 |
+
frames.append(frame)
|
171 |
+
|
172 |
+
frames = torch.stack(frames)
|
173 |
+
# print("in hdfs========>",frames[0])
|
174 |
+
frames = frames.permute(0, 3, 1, 2)
|
175 |
+
return frames, frame_indices, 25 # don't know the fps for index
|
176 |
+
|
177 |
+
|
178 |
+
def read_frames_decord(
|
179 |
+
video_path, num_frames, sample='rand', fix_start=None,
|
180 |
+
max_num_frames=-1, client=None, clip=None
|
181 |
+
):
|
182 |
+
if video_path.startswith('s3') or video_path.startswith('p2'):
|
183 |
+
video_bytes = client.get(video_path)
|
184 |
+
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
|
185 |
+
else:
|
186 |
+
video_reader = VideoReader(video_path, num_threads=1)
|
187 |
+
vlen = len(video_reader)
|
188 |
+
fps = video_reader.get_avg_fps()
|
189 |
+
duration = vlen / float(fps)
|
190 |
+
|
191 |
+
if clip:
|
192 |
+
start, end = clip
|
193 |
+
duration = end - start
|
194 |
+
vlen = int(duration * fps)
|
195 |
+
start_index = int(start * fps)
|
196 |
+
|
197 |
+
frame_indices = get_frame_indices(
|
198 |
+
num_frames, vlen, sample=sample, fix_start=fix_start,
|
199 |
+
input_fps=fps, max_num_frames=max_num_frames
|
200 |
+
)
|
201 |
+
if clip:
|
202 |
+
frame_indices = [f + start_index for f in frame_indices]
|
203 |
+
|
204 |
+
frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
|
205 |
+
frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8
|
206 |
+
return frames, frame_indices, float(fps)
|
207 |
+
|
208 |
+
|
209 |
+
VIDEO_READER_FUNCS = {
|
210 |
+
'av': read_frames_av,
|
211 |
+
'decord': read_frames_decord,
|
212 |
+
'gif': read_frames_gif,
|
213 |
+
'hdfs': read_frames_hdfs,
|
214 |
+
}
|
docs/PoolLLaVA_Report.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b9f175bd915cdc6f9791a95149992fde1f48ebfffa6c8bff9e6365b7186c57d
|
3 |
+
size 3850702
|
example/1917.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99f5f2a10985964ddc0555a8fa12b9d41f130b49ad62879a9e150d91834e93d5
|
3 |
+
size 1535936
|
example/bear.jpg
ADDED
Git LFS Details
|
example/cooking.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a1395530cc13c0441ae99ce66477f533f6009ebdb913064aec91e38eaf3b8e9
|
3 |
+
size 876622
|
example/dog.png
ADDED
Git LFS Details
|
example/jesse_dance.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f1fc41c6ebae0692726ea56b33ba711f21186fd4203ac54cd43a5cd898be4350
|
3 |
+
size 1221420
|
example/working.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:09372cdb6b0ea272868b4469d5067674670a948962f1236196e8f23e1f7ce764
|
3 |
+
size 4718899
|
example/yoga.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:74b65d9bec7f83e487b7f923076c01d476dd2ef7ed83928a696ab6f88c7751b7
|
3 |
+
size 776184
|
models/__init__.py
ADDED
File without changes
|
models/pllava/__init__.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import TYPE_CHECKING
|
15 |
+
|
16 |
+
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
17 |
+
|
18 |
+
|
19 |
+
_import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]}
|
20 |
+
|
21 |
+
try:
|
22 |
+
if not is_torch_available():
|
23 |
+
raise OptionalDependencyNotAvailable()
|
24 |
+
except OptionalDependencyNotAvailable:
|
25 |
+
pass
|
26 |
+
else:
|
27 |
+
_import_structure["modeling_pllava"] = [
|
28 |
+
"PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
29 |
+
"PllavaForConditionalGeneration",
|
30 |
+
"PllavaPreTrainedModel",
|
31 |
+
]
|
32 |
+
_import_structure["processing_pllava"] = ["PllavaProcessor"]
|
33 |
+
|
34 |
+
|
35 |
+
if TYPE_CHECKING:
|
36 |
+
from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig
|
37 |
+
|
38 |
+
try:
|
39 |
+
if not is_torch_available():
|
40 |
+
raise OptionalDependencyNotAvailable()
|
41 |
+
except OptionalDependencyNotAvailable:
|
42 |
+
pass
|
43 |
+
else:
|
44 |
+
from .modeling_pllava import (
|
45 |
+
PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
46 |
+
PllavaForConditionalGeneration,
|
47 |
+
PllavaPreTrainedModel,
|
48 |
+
)
|
49 |
+
from .processing_pllava import PllavaProcessor
|
50 |
+
|
51 |
+
|
52 |
+
else:
|
53 |
+
import sys
|
54 |
+
|
55 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
models/pllava/configuration_pllava.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
""" Llava model configuration"""
|
15 |
+
|
16 |
+
from transformers.configuration_utils import PretrainedConfig
|
17 |
+
from transformers.utils import logging
|
18 |
+
from transformers.models.auto import CONFIG_MAPPING
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
24 |
+
"llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class PllavaConfig(PretrainedConfig):
|
29 |
+
r"""
|
30 |
+
This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
|
31 |
+
Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
32 |
+
with the defaults will yield a similar configuration to that of the Llava-9B.
|
33 |
+
|
34 |
+
e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
|
35 |
+
|
36 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
37 |
+
documentation from [`PretrainedConfig`] for more information.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
vision_config (`LlavaVisionConfig`, *optional*):
|
41 |
+
Custom vision config or dict
|
42 |
+
text_config (`Union[AutoConfig, dict]`, *optional*):
|
43 |
+
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
|
44 |
+
ignore_index (`int`, *optional*, defaults to -100):
|
45 |
+
The ignore index for the loss function.
|
46 |
+
image_token_index (`int`, *optional*, defaults to 32000):
|
47 |
+
The image token index to encode the image prompt.
|
48 |
+
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
49 |
+
The activation function used by the multimodal projector.
|
50 |
+
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
51 |
+
The feature selection strategy used to select the vision feature from the CLIP backbone.
|
52 |
+
vision_feature_layer (`int`, *optional*, defaults to -2):
|
53 |
+
The index of the layer to select the vision feature.
|
54 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
55 |
+
Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
|
56 |
+
`inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
|
57 |
+
|
58 |
+
Example:
|
59 |
+
|
60 |
+
```python
|
61 |
+
>>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
|
62 |
+
|
63 |
+
>>> # Initializing a CLIP-vision config
|
64 |
+
>>> vision_config = CLIPVisionConfig()
|
65 |
+
|
66 |
+
>>> # Initializing a Llama config
|
67 |
+
>>> text_config = LlamaConfig()
|
68 |
+
|
69 |
+
>>> # Initializing a Llava llava-1.5-7b style configuration
|
70 |
+
>>> configuration = LlavaConfig(vision_config, text_config)
|
71 |
+
|
72 |
+
>>> # Initializing a model from the llava-1.5-7b style configuration
|
73 |
+
>>> model = LlavaForConditionalGeneration(configuration)
|
74 |
+
|
75 |
+
>>> # Accessing the model configuration
|
76 |
+
>>> configuration = model.config
|
77 |
+
```"""
|
78 |
+
|
79 |
+
model_type = "llava"
|
80 |
+
is_composition = False
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
vision_config=None,
|
85 |
+
text_config=None,
|
86 |
+
ignore_index=-100,
|
87 |
+
image_token_index=32000,
|
88 |
+
projector_hidden_act="gelu",
|
89 |
+
vision_feature_select_strategy="default",
|
90 |
+
vision_feature_layer=-2,
|
91 |
+
vocab_size=32000,
|
92 |
+
pooling_method='avg',
|
93 |
+
pooling_shape=(8, 16, 16),
|
94 |
+
frame_shape=(24, 24), # llava 1.5 pretrained frame shape
|
95 |
+
num_frames=1, # llava 1.5 pretrained frame shape
|
96 |
+
use_pooling=True,
|
97 |
+
gradient_checkpointing=False,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
self.ignore_index = ignore_index
|
101 |
+
self.image_token_index = image_token_index
|
102 |
+
self.projector_hidden_act = projector_hidden_act
|
103 |
+
self.vision_feature_select_strategy = vision_feature_select_strategy
|
104 |
+
self.vision_feature_layer = vision_feature_layer
|
105 |
+
self.vocab_size = vocab_size
|
106 |
+
self.use_pooling = use_pooling
|
107 |
+
self.gradient_checkpointing = gradient_checkpointing
|
108 |
+
|
109 |
+
self.vision_config = vision_config
|
110 |
+
|
111 |
+
self.pooling_method = pooling_method # should be in 'max', 'avg'
|
112 |
+
self.pooling_shape = pooling_shape #
|
113 |
+
self.frame_shape = frame_shape #
|
114 |
+
self.num_frames = num_frames
|
115 |
+
if isinstance(self.vision_config, dict):
|
116 |
+
vision_config["model_type"] = (
|
117 |
+
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
|
118 |
+
)
|
119 |
+
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
120 |
+
elif vision_config is None:
|
121 |
+
self.vision_config = CONFIG_MAPPING["clip_vision_model"](
|
122 |
+
intermediate_size=4096,
|
123 |
+
hidden_size=1024,
|
124 |
+
patch_size=14,
|
125 |
+
image_size=336,
|
126 |
+
num_hidden_layers=24,
|
127 |
+
num_attention_heads=16,
|
128 |
+
vocab_size=32000,
|
129 |
+
projection_dim=768,
|
130 |
+
)
|
131 |
+
self.vocab_size = self.vocab_size
|
132 |
+
|
133 |
+
self.text_config = text_config
|
134 |
+
|
135 |
+
if isinstance(self.text_config, dict):
|
136 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
137 |
+
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
138 |
+
self.vocab_size = self.text_config.vocab_size
|
139 |
+
self.text_config.gradient_checkpointing = self.gradient_checkpointing
|
140 |
+
|
141 |
+
elif text_config is None:
|
142 |
+
tmp_config = {"_attn_implementation":"flash_attention_2",
|
143 |
+
"gradient_checkpointing": self.gradient_checkpointing}
|
144 |
+
self.text_config = CONFIG_MAPPING["llama"](**tmp_config)
|
145 |
+
self.text_config.gradient_checkpointing = self.gradient_checkpointing
|
146 |
+
# self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code
|
147 |
+
|
148 |
+
|
149 |
+
super().__init__(**kwargs)
|
models/pllava/convert_pllava_weights_to_hf.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Not yet
|
models/pllava/modeling_pllava.py
ADDED
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch Llava model."""
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
import math
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
from torch import nn
|
23 |
+
import os
|
24 |
+
from transformers import PreTrainedModel
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.cache_utils import Cache
|
27 |
+
from transformers.modeling_outputs import ModelOutput
|
28 |
+
from transformers.utils import (
|
29 |
+
add_start_docstrings,
|
30 |
+
add_start_docstrings_to_model_forward,
|
31 |
+
logging,
|
32 |
+
replace_return_docstrings,
|
33 |
+
)
|
34 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
35 |
+
import einops
|
36 |
+
|
37 |
+
from .configuration_pllava import PllavaConfig
|
38 |
+
import pickle
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
_CONFIG_FOR_DOC = "LlavaConfig"
|
43 |
+
|
44 |
+
PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
45 |
+
"",
|
46 |
+
"",
|
47 |
+
"",
|
48 |
+
# See all Llava models at https://huggingface.co/models?filter=llava
|
49 |
+
]
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
|
54 |
+
class PllavaCausalLMOutputWithPast(ModelOutput):
|
55 |
+
"""
|
56 |
+
Base class for Llava causal language model (or autoregressive) outputs.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
60 |
+
Language modeling loss (for next-token prediction).
|
61 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
62 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
63 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
64 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
65 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
66 |
+
|
67 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
68 |
+
`past_key_values` input) to speed up sequential decoding.
|
69 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
70 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
71 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
72 |
+
|
73 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
74 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
75 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
76 |
+
sequence_length)`.
|
77 |
+
|
78 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
79 |
+
heads.
|
80 |
+
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
81 |
+
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
|
82 |
+
sequence_length, hidden_size)`.
|
83 |
+
|
84 |
+
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
|
85 |
+
"""
|
86 |
+
|
87 |
+
loss: Optional[torch.FloatTensor] = None
|
88 |
+
logits: torch.FloatTensor = None
|
89 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
90 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
91 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
92 |
+
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
93 |
+
|
94 |
+
class PllavaMultiModalProjector(nn.Module):
|
95 |
+
supported_highres = ['pad_crop_four', 'slide', ]
|
96 |
+
def __init__(self, config: PllavaConfig):
|
97 |
+
super().__init__()
|
98 |
+
self.use_pooling = config.use_pooling
|
99 |
+
self.frame_shape=config.frame_shape
|
100 |
+
self.num_frames = config.num_frames
|
101 |
+
self.pooling_shape = config.pooling_shape
|
102 |
+
|
103 |
+
self.pooling = nn.AdaptiveAvgPool3d(config.pooling_shape)
|
104 |
+
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
105 |
+
self.act = ACT2FN[config.projector_hidden_act]
|
106 |
+
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
107 |
+
|
108 |
+
def convert_Fembeddings2video(self, input, num_videos, frame_shape):
|
109 |
+
input = einops.rearrange(input,
|
110 |
+
'(num_videos num_frames) (h w) embed_dims -> num_videos embed_dims num_frames h w',
|
111 |
+
num_videos=num_videos, h=frame_shape[0])
|
112 |
+
return input
|
113 |
+
|
114 |
+
def convert_video2Fembeddings(self, input):
|
115 |
+
input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> (num_videos num_frames) (h w) embed_dims ', )
|
116 |
+
return input
|
117 |
+
|
118 |
+
def convert_video2MMembeddings(self, input):
|
119 |
+
input = einops.rearrange(input, 'num_videos embed_dims num_frames h w -> num_videos (num_frames h w) embed_dims ', )
|
120 |
+
return input
|
121 |
+
|
122 |
+
def forward(self, image_features, media_type, batch_size=None, num_videos=None):
|
123 |
+
frame_shape = self.frame_shape
|
124 |
+
num_frames = self.num_frames
|
125 |
+
assert media_type in ( 'video', 'image'), f'only image or video, but got media_type {media_type}'
|
126 |
+
hidden_states = image_features
|
127 |
+
|
128 |
+
if media_type == 'image':
|
129 |
+
hidden_states = hidden_states.repeat(num_frames, 1, 1)
|
130 |
+
|
131 |
+
total_frames, spatial_seqlen, embed_dims = hidden_states.shape
|
132 |
+
#TODO: temporal code, should ensure num_frames == total frames in data loading later
|
133 |
+
if total_frames < num_frames and self.use_pooling: #
|
134 |
+
multiplier = int(num_frames/total_frames)+1
|
135 |
+
hidden_states= hidden_states.repeat_interleave(multiplier, dim=0)[:num_frames]
|
136 |
+
total_frames, spatial_seqlen, embed_dims = hidden_states.shape
|
137 |
+
|
138 |
+
assert total_frames % num_frames == 0
|
139 |
+
assert frame_shape[0] * frame_shape[1] == spatial_seqlen
|
140 |
+
hidden_states = self.linear_1(hidden_states)
|
141 |
+
hidden_states = self.act(hidden_states)
|
142 |
+
hidden_states = self.linear_2(hidden_states)
|
143 |
+
hidden_states_videos = self.convert_Fembeddings2video(hidden_states, num_videos * batch_size, frame_shape)
|
144 |
+
hidden_states_videos = self.pooling(hidden_states_videos)
|
145 |
+
hidden_states = einops.rearrange(hidden_states_videos, 'batch_size_num_videos embed_dims num_frames h w -> batch_size_num_videos num_frames (h w) embed_dims', )
|
146 |
+
hidden_states = einops.rearrange(hidden_states, 'batch_size_num_videos num_frames hw embed_dims -> batch_size_num_videos (num_frames hw) embed_dims ')
|
147 |
+
return hidden_states
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
PLLAVA_START_DOCSTRING = r"""
|
152 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
153 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
154 |
+
etc.)
|
155 |
+
|
156 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
157 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
158 |
+
and behavior.
|
159 |
+
|
160 |
+
Parameters:
|
161 |
+
config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
|
162 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
163 |
+
load the weights associated with the model, only the configuration. Check out the
|
164 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
165 |
+
"""
|
166 |
+
|
167 |
+
|
168 |
+
@add_start_docstrings(
|
169 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
170 |
+
PLLAVA_START_DOCSTRING,
|
171 |
+
)
|
172 |
+
class PllavaPreTrainedModel(PreTrainedModel):
|
173 |
+
config_class = PllavaConfig
|
174 |
+
base_model_prefix = "model"
|
175 |
+
supports_gradient_checkpointing = True
|
176 |
+
_no_split_modules = ["LlavaVisionAttention"]
|
177 |
+
_skip_keys_device_placement = "past_key_values"
|
178 |
+
_supports_flash_attn_2 = True
|
179 |
+
|
180 |
+
def _init_weights(self, module):
|
181 |
+
# important: this ported version of Llava isn't meant for training from scratch - only
|
182 |
+
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
183 |
+
# https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
|
184 |
+
std = (
|
185 |
+
self.config.initializer_range
|
186 |
+
if hasattr(self.config, "initializer_range")
|
187 |
+
else self.config.text_config.initializer_range
|
188 |
+
)
|
189 |
+
|
190 |
+
if hasattr(module, "class_embedding"):
|
191 |
+
module.class_embedding.data.normal_(mean=0.0, std=std)
|
192 |
+
|
193 |
+
# if isinstance(module, (nn.Linear, nn.Conv2d)):
|
194 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
195 |
+
# if module.bias is not None:
|
196 |
+
# module.bias.data.zero_()
|
197 |
+
|
198 |
+
elif isinstance(module, nn.Embedding):
|
199 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
200 |
+
if module.padding_idx is not None:
|
201 |
+
module.weight.data[module.padding_idx].zero_()
|
202 |
+
|
203 |
+
elif isinstance(module, PllavaMultiModalProjector):
|
204 |
+
# module.register_embed.data.normal_(mean=0.0, std=std)
|
205 |
+
if self.config.register:
|
206 |
+
module.register_embed.data.zero_()
|
207 |
+
|
208 |
+
@property
|
209 |
+
def _supports_sdpa(self):
|
210 |
+
"""
|
211 |
+
Retrieve language_model's attribute to check whether the model supports
|
212 |
+
SDPA or not.
|
213 |
+
"""
|
214 |
+
return self.language_model._supports_sdpa
|
215 |
+
|
216 |
+
|
217 |
+
PLLAVA_INPUTS_DOCSTRING = r"""
|
218 |
+
Args:
|
219 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
220 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
221 |
+
it.
|
222 |
+
|
223 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
224 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
225 |
+
|
226 |
+
[What are input IDs?](../glossary#input-ids)
|
227 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
228 |
+
The tensors corresponding to the input images. Pixel values can be obtained using
|
229 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
|
230 |
+
[`CLIPImageProcessor`] for processing images).
|
231 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
232 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
233 |
+
|
234 |
+
- 1 for tokens that are **not masked**,
|
235 |
+
- 0 for tokens that are **masked**.
|
236 |
+
|
237 |
+
[What are attention masks?](../glossary#attention-mask)
|
238 |
+
|
239 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
240 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
241 |
+
|
242 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
243 |
+
`past_key_values`).
|
244 |
+
|
245 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
246 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
247 |
+
information on the default strategy.
|
248 |
+
|
249 |
+
- 1 indicates the head is **not masked**,
|
250 |
+
- 0 indicates the head is **masked**.
|
251 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
252 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
253 |
+
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
254 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
255 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
256 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
257 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
258 |
+
|
259 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
260 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
261 |
+
|
262 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
263 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
264 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
265 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
266 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
267 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
268 |
+
model's internal embedding lookup matrix.
|
269 |
+
use_cache (`bool`, *optional*):
|
270 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
271 |
+
`past_key_values`).
|
272 |
+
output_attentions (`bool`, *optional*):
|
273 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
274 |
+
tensors for more detail.
|
275 |
+
output_hidden_states (`bool`, *optional*):
|
276 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
277 |
+
more detail.
|
278 |
+
return_dict (`bool`, *optional*):
|
279 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
280 |
+
"""
|
281 |
+
|
282 |
+
|
283 |
+
@add_start_docstrings(
|
284 |
+
"""The LLAVA model which consists of a vision backbone and a language model.""",
|
285 |
+
PLLAVA_START_DOCSTRING,
|
286 |
+
)
|
287 |
+
class PllavaForConditionalGeneration(PllavaPreTrainedModel):
|
288 |
+
def __init__(self, config: PllavaConfig):
|
289 |
+
super().__init__(config)
|
290 |
+
self.config = config
|
291 |
+
self.vision_tower = AutoModel.from_config(config.vision_config)
|
292 |
+
self.multi_modal_projector = PllavaMultiModalProjector(config)
|
293 |
+
self.vocab_size = config.vocab_size
|
294 |
+
# self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="flash_attention_2")
|
295 |
+
self.language_model = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation="eager")
|
296 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.text_config.pad_token_id
|
297 |
+
assert self.pad_token_id is not None, 'provide the model with pad_token_id, this would be used to arranging new embedings'
|
298 |
+
self.post_init()
|
299 |
+
|
300 |
+
def get_input_embeddings(self):
|
301 |
+
return self.language_model.get_input_embeddings()
|
302 |
+
|
303 |
+
def set_input_embeddings(self, value):
|
304 |
+
self.language_model.set_input_embeddings(value)
|
305 |
+
|
306 |
+
def get_output_embeddings(self):
|
307 |
+
return self.language_model.get_output_embeddings()
|
308 |
+
|
309 |
+
def set_output_embeddings(self, new_embeddings):
|
310 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
311 |
+
|
312 |
+
def set_decoder(self, decoder):
|
313 |
+
self.language_model.set_decoder(decoder)
|
314 |
+
|
315 |
+
def get_decoder(self):
|
316 |
+
return self.language_model.get_decoder()
|
317 |
+
|
318 |
+
def tie_weights(self):
|
319 |
+
return self.language_model.tie_weights()
|
320 |
+
|
321 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
322 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
323 |
+
# update vocab size
|
324 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
325 |
+
self.config.vocab_size = model_embeds.num_embeddings
|
326 |
+
self.vocab_size = model_embeds.num_embeddings
|
327 |
+
return model_embeds
|
328 |
+
|
329 |
+
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
330 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
331 |
+
batch_size, sequence_length = input_ids.shape
|
332 |
+
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
333 |
+
# 1. Create a mask to know where special image tokens are
|
334 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
335 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
336 |
+
# Compute the maximum embed dimension
|
337 |
+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
338 |
+
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
339 |
+
|
340 |
+
# 2. Compute the positions where text should be written
|
341 |
+
# Calculate new positions for text tokens in merged image-text sequence.
|
342 |
+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
343 |
+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
344 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
345 |
+
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
346 |
+
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
347 |
+
if left_padding:
|
348 |
+
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
349 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
350 |
+
|
351 |
+
# 3. Create the full embedding, already padded to the maximum position
|
352 |
+
final_embedding = torch.zeros(
|
353 |
+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
354 |
+
)
|
355 |
+
final_attention_mask = torch.zeros(
|
356 |
+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
357 |
+
)
|
358 |
+
if labels is not None:
|
359 |
+
final_labels = torch.full(
|
360 |
+
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
361 |
+
)
|
362 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
363 |
+
# set the corresponding tensors into their correct target device.
|
364 |
+
target_device = inputs_embeds.device
|
365 |
+
batch_indices, non_image_indices, text_to_overwrite = (
|
366 |
+
batch_indices.to(target_device),
|
367 |
+
non_image_indices.to(target_device),
|
368 |
+
text_to_overwrite.to(target_device),
|
369 |
+
)
|
370 |
+
attention_mask = attention_mask.to(target_device)
|
371 |
+
|
372 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
373 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
374 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
375 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
376 |
+
if labels is not None:
|
377 |
+
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
378 |
+
|
379 |
+
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
380 |
+
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
381 |
+
image_to_overwrite &= image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)
|
382 |
+
|
383 |
+
# # somthing really weird here.
|
384 |
+
# temp1 = (image_to_overwrite.cumsum(-1) > nb_image_pad[:, None].to(target_device)) & image_to_overwrite
|
385 |
+
# # this is for right padding
|
386 |
+
# temp2 = (image_to_overwrite.cumsum(-1) <= num_special_image_tokens.max() * num_image_patches - nb_image_pad[:, None]) & image_to_overwrite
|
387 |
+
|
388 |
+
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
389 |
+
raise ValueError(
|
390 |
+
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
391 |
+
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
392 |
+
)
|
393 |
+
|
394 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
395 |
+
final_attention_mask |= image_to_overwrite
|
396 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
397 |
+
|
398 |
+
if labels is None:
|
399 |
+
final_labels = None
|
400 |
+
|
401 |
+
return final_embedding, final_attention_mask, final_labels, position_ids
|
402 |
+
|
403 |
+
@add_start_docstrings_to_model_forward(PLLAVA_INPUTS_DOCSTRING)
|
404 |
+
@replace_return_docstrings(output_type=PllavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
405 |
+
def forward(
|
406 |
+
self,
|
407 |
+
input_ids: torch.LongTensor = None,
|
408 |
+
pixel_values: torch.FloatTensor = None,
|
409 |
+
attention_mask: Optional[torch.Tensor] = None,
|
410 |
+
media_type: str = None,
|
411 |
+
position_ids: Optional[torch.LongTensor] = None,
|
412 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
413 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
414 |
+
vision_feature_layer: Optional[int] = None,
|
415 |
+
vision_feature_select_strategy: Optional[str] = None,
|
416 |
+
labels: Optional[torch.LongTensor] = None,
|
417 |
+
use_cache: Optional[bool] = None,
|
418 |
+
output_attentions: Optional[bool] = None,
|
419 |
+
output_hidden_states: Optional[bool] = None,
|
420 |
+
return_dict: Optional[bool] = None,
|
421 |
+
) -> Union[Tuple, PllavaCausalLMOutputWithPast]:
|
422 |
+
r"""
|
423 |
+
Args:
|
424 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
425 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
426 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
427 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
428 |
+
|
429 |
+
Returns:
|
430 |
+
|
431 |
+
Example:
|
432 |
+
|
433 |
+
```python
|
434 |
+
>>> from PIL import Image
|
435 |
+
>>> import requests
|
436 |
+
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
437 |
+
|
438 |
+
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
439 |
+
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
440 |
+
|
441 |
+
>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
|
442 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
443 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
444 |
+
|
445 |
+
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
446 |
+
|
447 |
+
>>> # Generate
|
448 |
+
>>> generate_ids = model.generate(**inputs, max_length=30)
|
449 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
450 |
+
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
|
451 |
+
```"""
|
452 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
453 |
+
output_hidden_states = (
|
454 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
455 |
+
)
|
456 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
457 |
+
vision_feature_layer = (
|
458 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
459 |
+
)
|
460 |
+
vision_feature_select_strategy = (
|
461 |
+
vision_feature_select_strategy
|
462 |
+
if vision_feature_select_strategy is not None
|
463 |
+
else self.config.vision_feature_select_strategy
|
464 |
+
)
|
465 |
+
|
466 |
+
if inputs_embeds is None:
|
467 |
+
# 1. Extra the input embeddings
|
468 |
+
no_img_input_ids = torch.where(input_ids!=self.config.image_token_index, input_ids, self.pad_token_id) # some model used up all the embeddings
|
469 |
+
inputs_embeds = self.get_input_embeddings()(no_img_input_ids)
|
470 |
+
batch_size = inputs_embeds.shape[0]
|
471 |
+
# 2. Merge text and images
|
472 |
+
if pixel_values is not None and input_ids.shape[1] != 1:
|
473 |
+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
474 |
+
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
|
475 |
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer] # ( b, img_seqlen, embed_dim)
|
476 |
+
if vision_feature_select_strategy == "default":
|
477 |
+
selected_image_feature = selected_image_feature[:, 1:]
|
478 |
+
elif vision_feature_select_strategy == "full":
|
479 |
+
raise ValueError("not implemented")
|
480 |
+
selected_image_feature = selected_image_feature
|
481 |
+
else:
|
482 |
+
raise ValueError(
|
483 |
+
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
484 |
+
)
|
485 |
+
|
486 |
+
image_features = self.multi_modal_projector(selected_image_feature,
|
487 |
+
media_type,
|
488 |
+
batch_size=batch_size,
|
489 |
+
num_videos=pixel_values.shape[0]//self.config.num_frames//batch_size,)
|
490 |
+
|
491 |
+
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
492 |
+
image_features, inputs_embeds, input_ids, attention_mask, labels
|
493 |
+
)
|
494 |
+
if labels is None:
|
495 |
+
labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
496 |
+
else:
|
497 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
498 |
+
# generation with cache
|
499 |
+
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
500 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
501 |
+
# that are set to 0
|
502 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
503 |
+
|
504 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
505 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
506 |
+
|
507 |
+
# Get the target length
|
508 |
+
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
509 |
+
|
510 |
+
extended_attention_mask = torch.ones(
|
511 |
+
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
512 |
+
dtype=attention_mask.dtype,
|
513 |
+
device=attention_mask.device,
|
514 |
+
)
|
515 |
+
|
516 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
517 |
+
# if one uses Llava + Fused modules where the cache on the
|
518 |
+
# first iteration is already big enough, or if one passes custom cache
|
519 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
520 |
+
new_batch_index = batch_index[valid_indices]
|
521 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
522 |
+
|
523 |
+
# Zero-out the places where we don't need to attend
|
524 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
525 |
+
|
526 |
+
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
527 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
528 |
+
|
529 |
+
outputs = self.language_model(
|
530 |
+
attention_mask=attention_mask,
|
531 |
+
position_ids=position_ids,
|
532 |
+
past_key_values=past_key_values,
|
533 |
+
inputs_embeds=inputs_embeds,
|
534 |
+
use_cache=use_cache,
|
535 |
+
output_attentions=output_attentions,
|
536 |
+
output_hidden_states=output_hidden_states,
|
537 |
+
return_dict=return_dict,
|
538 |
+
)
|
539 |
+
|
540 |
+
logits = outputs[0]
|
541 |
+
|
542 |
+
loss = None
|
543 |
+
if labels is not None:
|
544 |
+
# Shift so that tokens < n predict n
|
545 |
+
if attention_mask is not None:
|
546 |
+
shift_attention_mask = attention_mask[..., 1:]
|
547 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
548 |
+
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
549 |
+
else:
|
550 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
551 |
+
shift_labels = labels[..., 1:].contiguous()
|
552 |
+
# Flatten the tokens
|
553 |
+
loss_fct = nn.CrossEntropyLoss()
|
554 |
+
loss = loss_fct(
|
555 |
+
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
556 |
+
)
|
557 |
+
|
558 |
+
if not return_dict:
|
559 |
+
output = (logits,) + outputs[1:]
|
560 |
+
return (loss,) + output if loss is not None else output
|
561 |
+
|
562 |
+
return PllavaCausalLMOutputWithPast(
|
563 |
+
loss=loss,
|
564 |
+
logits=logits,
|
565 |
+
past_key_values=outputs.past_key_values,
|
566 |
+
hidden_states=outputs.hidden_states,
|
567 |
+
attentions=outputs.attentions,
|
568 |
+
)
|
569 |
+
|
570 |
+
def prepare_inputs_for_generation(
|
571 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
|
572 |
+
):
|
573 |
+
if past_key_values is not None:
|
574 |
+
if isinstance(past_key_values, Cache):
|
575 |
+
cache_length = past_key_values.get_seq_length()
|
576 |
+
past_length = past_key_values.seen_tokens
|
577 |
+
else:
|
578 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
579 |
+
|
580 |
+
# Keep only the unprocessed tokens:
|
581 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
582 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
583 |
+
# input)
|
584 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
585 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
586 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
587 |
+
# input_ids based on the past_length.
|
588 |
+
elif past_length < input_ids.shape[1]:
|
589 |
+
input_ids = input_ids[:, past_length:]
|
590 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
591 |
+
elif self.config.image_token_index in input_ids:
|
592 |
+
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
593 |
+
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
594 |
+
# older attention values, as their corresponding values are not part of the input.
|
595 |
+
if cache_length < past_length and attention_mask is not None:
|
596 |
+
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
597 |
+
|
598 |
+
position_ids = kwargs.get("position_ids", None)
|
599 |
+
if attention_mask is not None and position_ids is None:
|
600 |
+
# create position_ids on the fly for batch generation
|
601 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
602 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
603 |
+
if past_key_values:
|
604 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
605 |
+
|
606 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
607 |
+
if inputs_embeds is not None and past_key_values is None:
|
608 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
609 |
+
else:
|
610 |
+
model_inputs = {"input_ids": input_ids}
|
611 |
+
media_type = kwargs.get('media_type', None)
|
612 |
+
|
613 |
+
model_inputs.update(
|
614 |
+
{
|
615 |
+
"position_ids": position_ids,
|
616 |
+
"past_key_values": past_key_values,
|
617 |
+
"use_cache": kwargs.get("use_cache"),
|
618 |
+
"attention_mask": attention_mask,
|
619 |
+
"pixel_values": pixel_values,
|
620 |
+
"media_type": media_type,
|
621 |
+
}
|
622 |
+
)
|
623 |
+
return model_inputs
|
624 |
+
|
625 |
+
def _reorder_cache(self, *args, **kwargs):
|
626 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|
models/pllava/processing_pllava.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for Llava.
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
import itertools
|
21 |
+
from typing import List, Optional, Union
|
22 |
+
import PIL.Image
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from transformers import AutoTokenizer
|
26 |
+
from transformers.feature_extraction_utils import BatchFeature
|
27 |
+
from transformers.image_utils import (
|
28 |
+
ImageInput,
|
29 |
+
make_list_of_images,
|
30 |
+
valid_images,
|
31 |
+
infer_channel_dimension_format,
|
32 |
+
to_numpy_array,
|
33 |
+
get_image_size,
|
34 |
+
ChannelDimension,
|
35 |
+
)
|
36 |
+
from transformers.image_processing_utils import get_size_dict
|
37 |
+
from transformers.image_utils import PILImageResampling
|
38 |
+
from transformers.processing_utils import ProcessorMixin
|
39 |
+
from transformers.image_transforms import resize, pad, PaddingMode, to_channel_dimension_format, get_resize_output_image_size
|
40 |
+
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
41 |
+
from transformers.utils import TensorType
|
42 |
+
|
43 |
+
|
44 |
+
class PllavaProcessor(ProcessorMixin):
|
45 |
+
r"""
|
46 |
+
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
|
47 |
+
|
48 |
+
[`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
|
49 |
+
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
image_processor ([`CLIPImageProcessor`], *optional*):
|
53 |
+
The image processor is a required input.
|
54 |
+
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
55 |
+
The tokenizer is a required input.
|
56 |
+
"""
|
57 |
+
|
58 |
+
attributes = ["image_processor", "tokenizer"]
|
59 |
+
image_processor_class = "CLIPImageProcessor"
|
60 |
+
tokenizer_class = "AutoTokenizer"
|
61 |
+
|
62 |
+
def __init__(self, image_processor=None, tokenizer=None,
|
63 |
+
shortest_edge=336,
|
64 |
+
longest_edge=762,
|
65 |
+
center_pad=False):
|
66 |
+
self.shortest_edge = shortest_edge
|
67 |
+
self.longest_edge = longest_edge
|
68 |
+
self.center_pad = center_pad
|
69 |
+
super().__init__(image_processor, tokenizer)
|
70 |
+
|
71 |
+
def resize_crop_longshort(self, videos: list[list[np.ndarray]], input_data_format):
|
72 |
+
video_spatial_sizes = [get_image_size(images[0], input_data_format) for images in videos]
|
73 |
+
long_short_rates = [max(size) / min(size) for size in video_spatial_sizes]
|
74 |
+
min_long_short_rate = min(long_short_rates)
|
75 |
+
min_long_short_video_idx = long_short_rates.index(min_long_short_rate)
|
76 |
+
|
77 |
+
clip_resolution = self.image_processor.size['shortest_edge']
|
78 |
+
out_video_spatial_size = video_spatial_sizes[min_long_short_video_idx]
|
79 |
+
out_videos_short_edge = max(min(size) for size in video_spatial_sizes)
|
80 |
+
resize_longest_edge = max(max(size) for size in video_spatial_sizes)
|
81 |
+
resize_longest_edge = min(640, resize_longest_edge)
|
82 |
+
out_videos_short_edge = min(out_videos_short_edge, int(resize_longest_edge / min_long_short_rate))
|
83 |
+
out_videos_short_edge = max(out_videos_short_edge, clip_resolution)
|
84 |
+
|
85 |
+
|
86 |
+
if out_video_spatial_size[0] > out_video_spatial_size[1]: # h > w:
|
87 |
+
out_video_spatial_size = (int(out_videos_short_edge * min_long_short_rate), out_videos_short_edge )
|
88 |
+
else:
|
89 |
+
out_video_spatial_size = ( out_videos_short_edge, int(out_videos_short_edge * min_long_short_rate) )
|
90 |
+
videos = [
|
91 |
+
[self.resize(frame, input_data_format=input_data_format, shortest_edge=out_videos_short_edge, longest_edge=9999) for frame in frames]
|
92 |
+
for frames in videos
|
93 |
+
]
|
94 |
+
out_videos = []
|
95 |
+
for frames in videos:
|
96 |
+
out_frames = []
|
97 |
+
video_spatial_size = get_image_size(frames[0], input_data_format)
|
98 |
+
assert min(video_spatial_size) == out_videos_short_edge
|
99 |
+
overhead = (max(video_spatial_size) - max(out_video_spatial_size)) // 2
|
100 |
+
slice_start, slice_end = overhead // 2, overhead // 2 + max(out_video_spatial_size)
|
101 |
+
hslice, wslice = (slice(slice_start, slice_end), slice(None, None)) if video_spatial_size[0] > video_spatial_size[1] \
|
102 |
+
else (slice(None, None), slice(slice_start, slice_end)) # h > w
|
103 |
+
for frame in frames:
|
104 |
+
if input_data_format == ChannelDimension.FIRST:
|
105 |
+
out_frames.append(frame[..., hslice, wslice])
|
106 |
+
elif input_data_format == ChannelDimension.LAST:
|
107 |
+
out_frames.append(frame[..., hslice, wslice, :])
|
108 |
+
out_videos.append(out_frames)
|
109 |
+
|
110 |
+
return out_videos
|
111 |
+
|
112 |
+
@staticmethod
|
113 |
+
def _compute_num_blocks_and_overlaps(input_shape, resolution):
|
114 |
+
input_shape = np.array(input_shape)
|
115 |
+
resolution = np.array(resolution)
|
116 |
+
assert input_shape.max() >= resolution
|
117 |
+
num_blocks = np.ceil(input_shape / resolution).astype(np.int32).tolist()
|
118 |
+
overlaps = [0 if size % resolution==0
|
119 |
+
else int(np.floor((resolution - size % resolution) / (num_block - 1))) for num_block, size in zip(num_blocks, input_shape)]
|
120 |
+
return num_blocks, overlaps
|
121 |
+
|
122 |
+
def resize(
|
123 |
+
self,
|
124 |
+
image: np.ndarray,
|
125 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC, # type: ignore
|
126 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
127 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
128 |
+
shortest_edge: int = None,
|
129 |
+
longest_edge: int = None,
|
130 |
+
**kwargs,
|
131 |
+
) -> np.ndarray:
|
132 |
+
"""
|
133 |
+
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
134 |
+
resized to keep the input aspect ratio.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
image (`np.ndarray`):
|
138 |
+
Image to resize.
|
139 |
+
size (`Dict[str, int]`):
|
140 |
+
Size of the output image.
|
141 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
142 |
+
Resampling filter to use when resiizing the image.
|
143 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
144 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
145 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
146 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
147 |
+
"""
|
148 |
+
shortest_edge = getattr(self, 'shortest_edge', None) if shortest_edge is None else shortest_edge
|
149 |
+
longest_edge = getattr(self, 'longest_edge', None) if longest_edge is None else longest_edge
|
150 |
+
default_to_square = False
|
151 |
+
output_size = get_resize_output_image_size(
|
152 |
+
image,
|
153 |
+
size=shortest_edge,
|
154 |
+
default_to_square=default_to_square,
|
155 |
+
max_size=longest_edge,
|
156 |
+
input_data_format=input_data_format,
|
157 |
+
)
|
158 |
+
clip_resolution = self.image_processor.size['shortest_edge']
|
159 |
+
if min(output_size) < clip_resolution:
|
160 |
+
output_size = get_resize_output_image_size(
|
161 |
+
image,
|
162 |
+
size=shortest_edge,
|
163 |
+
default_to_square=default_to_square,
|
164 |
+
input_data_format=input_data_format,
|
165 |
+
)
|
166 |
+
return resize(
|
167 |
+
image,
|
168 |
+
size=output_size,
|
169 |
+
resample=resample,
|
170 |
+
data_format=data_format,
|
171 |
+
input_data_format=input_data_format,
|
172 |
+
**kwargs,
|
173 |
+
)
|
174 |
+
|
175 |
+
def __call__(
|
176 |
+
self,
|
177 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
178 |
+
images: ImageInput = None,
|
179 |
+
center_pad = None,
|
180 |
+
padding: Union[bool, str, PaddingStrategy] = False,
|
181 |
+
truncation: Union[bool, str, TruncationStrategy] = None,
|
182 |
+
max_length=None,
|
183 |
+
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
184 |
+
) -> BatchFeature:
|
185 |
+
"""
|
186 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
187 |
+
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
188 |
+
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
189 |
+
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
|
190 |
+
of the above two methods for more information.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
194 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
195 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
196 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
197 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
198 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
199 |
+
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
200 |
+
number of channels, H and W are image height and width.
|
201 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
202 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
203 |
+
index) among:
|
204 |
+
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
205 |
+
sequence if provided).
|
206 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
207 |
+
acceptable input length for the model if that argument is not provided.
|
208 |
+
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
209 |
+
lengths).
|
210 |
+
max_length (`int`, *optional*):
|
211 |
+
Maximum length of the returned list and optionally padding length (see above).
|
212 |
+
truncation (`bool`, *optional*):
|
213 |
+
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
214 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
215 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
216 |
+
|
217 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
218 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
219 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
220 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
224 |
+
|
225 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
226 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
227 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
228 |
+
`None`).
|
229 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
230 |
+
"""
|
231 |
+
data=dict()
|
232 |
+
if images is not None:
|
233 |
+
if isinstance(images, list) and isinstance(images[0], PIL.Image.Image):
|
234 |
+
videos = [images] # one video
|
235 |
+
else:
|
236 |
+
videos = images
|
237 |
+
|
238 |
+
pixel_values_list = []
|
239 |
+
videos = [[to_numpy_array(image) for image in make_list_of_images(images)] for images in videos]
|
240 |
+
# images = [self.resize(image, ) if min(get_image_size(image, input_data_format)) < clip_resolution else image for image in images]
|
241 |
+
input_data_format = infer_channel_dimension_format(videos[0][0])
|
242 |
+
videos = self.resize_crop_longshort(videos, input_data_format)
|
243 |
+
|
244 |
+
for images in videos:
|
245 |
+
if not valid_images(images):
|
246 |
+
raise ValueError(
|
247 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
248 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
249 |
+
)
|
250 |
+
|
251 |
+
center_pad = center_pad if center_pad is not None else self.center_pad
|
252 |
+
if center_pad:
|
253 |
+
images = [self.pad_to_square(image, 0, input_data_format, input_data_format) for image in images]
|
254 |
+
|
255 |
+
pixel_values = self.image_processor(images, return_tensors='np')["pixel_values"]
|
256 |
+
pixel_values_list.append(pixel_values)
|
257 |
+
|
258 |
+
pixel_values = np.concatenate(pixel_values_list)
|
259 |
+
data.update(pixel_values=pixel_values)
|
260 |
+
|
261 |
+
else:
|
262 |
+
data.update(pixel_values = None)
|
263 |
+
|
264 |
+
if text is not None:
|
265 |
+
text_inputs = self.tokenizer(
|
266 |
+
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
|
267 |
+
)
|
268 |
+
data.update(**text_inputs)
|
269 |
+
return BatchFeature(data, tensor_type=return_tensors)
|
270 |
+
|
271 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
272 |
+
def batch_decode(self, *args, **kwargs):
|
273 |
+
"""
|
274 |
+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
275 |
+
refer to the docstring of this method for more information.
|
276 |
+
"""
|
277 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
278 |
+
|
279 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
280 |
+
def decode(self, *args, **kwargs):
|
281 |
+
"""
|
282 |
+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
283 |
+
the docstring of this method for more information.
|
284 |
+
"""
|
285 |
+
return self.tokenizer.decode(*args, **kwargs)
|
286 |
+
|
287 |
+
@property
|
288 |
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
289 |
+
def model_input_names(self):
|
290 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
291 |
+
image_processor_input_names = self.image_processor.model_input_names
|
292 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
python_scripts/hf.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import multiprocessing
|
5 |
+
import functools
|
6 |
+
import huggingface_hub
|
7 |
+
from huggingface_hub import snapshot_download
|
8 |
+
|
9 |
+
|
10 |
+
def upload(repo_id, local_dir, path_in_repo, repo_type, token):
|
11 |
+
huggingface_hub.upload_folder(
|
12 |
+
repo_id=repo_id,
|
13 |
+
folder_path=local_dir,
|
14 |
+
path_in_repo=path_in_repo,
|
15 |
+
token=token,
|
16 |
+
repo_type=repo_type
|
17 |
+
)
|
18 |
+
|
19 |
+
def download(repo_id, local_dir, repo_type, token, filter_re=None):
|
20 |
+
files = huggingface_hub.list_repo_files(repo_id, repo_type=repo_type, token=token)
|
21 |
+
if filter_re is not None:
|
22 |
+
files = [file for file in files if re.search(filter_re, file) is not None]
|
23 |
+
pool = multiprocessing.Pool(8)
|
24 |
+
download_func = functools.partial(
|
25 |
+
huggingface_hub.hf_hub_download,
|
26 |
+
repo_id,
|
27 |
+
repo_type=repo_type,
|
28 |
+
local_dir=local_dir,
|
29 |
+
local_dir_use_symlinks=True,
|
30 |
+
token=token
|
31 |
+
)
|
32 |
+
pool.map(download_func, files)
|
33 |
+
print(f'downloaded files {files}')
|
34 |
+
|
35 |
+
|
36 |
+
def upload_file(repo_id, file_path, repo_type, token):
|
37 |
+
huggingface_hub.upload_file(
|
38 |
+
repo_id=repo_id,
|
39 |
+
path_or_fileobj=file_path,
|
40 |
+
path_in_repo=file_path,
|
41 |
+
token=token,
|
42 |
+
repo_type=repo_type,
|
43 |
+
)
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
read_token = '...'
|
47 |
+
write_token = '...'
|
48 |
+
repo_id = '...'
|
49 |
+
local_dir = '...'
|
50 |
+
repo_type = '...'
|
51 |
+
|
52 |
+
|
53 |
+
# #############
|
54 |
+
# # Examples on most simple hf usage
|
55 |
+
# # downlaod
|
56 |
+
# filters = []
|
57 |
+
# for filter_re in filters:
|
58 |
+
# download(repo_id,
|
59 |
+
# local_dir,
|
60 |
+
# repo_type,
|
61 |
+
# filter_re)
|
62 |
+
|
63 |
+
# # upload
|
64 |
+
# upload(repo_id, local_dir, local_dir, repo_type, write_token)
|
65 |
+
# #############
|
66 |
+
|
67 |
+
# download models
|
68 |
+
repo_ids = [
|
69 |
+
'ermu2001/pllava-7b',
|
70 |
+
'ermu2001/pllava-13b',
|
71 |
+
]
|
72 |
+
for repo_id in repo_ids:
|
73 |
+
local_dir = repo_id.replace('ermu2001', 'MODELS')
|
74 |
+
snapshot_download(
|
75 |
+
repo_id,
|
76 |
+
local_dir=local_dir,
|
77 |
+
repo_type='model',
|
78 |
+
local_dir_use_symlinks=True,
|
79 |
+
token=read_token,
|
80 |
+
)
|
requirements.no_torch.txt
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
accelerate==0.26.1
|
3 |
+
addict==2.4.0
|
4 |
+
aiofiles==23.2.1
|
5 |
+
aliyun-python-sdk-core==2.15.0
|
6 |
+
aliyun-python-sdk-kms==2.16.2
|
7 |
+
altair==5.2.0
|
8 |
+
annotated-types==0.6.0
|
9 |
+
antlr4-python3-runtime==4.9.3
|
10 |
+
anyio==4.3.0
|
11 |
+
anykeystore==0.2
|
12 |
+
apex==0.9.10.dev0
|
13 |
+
appdirs==1.4.4
|
14 |
+
argcomplete==3.2.3
|
15 |
+
attrs==23.2.0
|
16 |
+
av==10.0.0
|
17 |
+
beautifulsoup4==4.12.3
|
18 |
+
blessed==1.20.0
|
19 |
+
blessings==1.7
|
20 |
+
boto3==1.34.63
|
21 |
+
botocore==1.34.63
|
22 |
+
Brotli==1.1.0
|
23 |
+
cachetools==5.3.3
|
24 |
+
certifi==2024.2.2
|
25 |
+
cffi==1.16.0
|
26 |
+
charset-normalizer==3.3.2
|
27 |
+
click==8.1.7
|
28 |
+
colorama==0.4.6
|
29 |
+
contourpy==1.2.0
|
30 |
+
crcmod==1.7
|
31 |
+
cryptacular==1.6.2
|
32 |
+
cryptography==42.0.5
|
33 |
+
cycler==0.12.1
|
34 |
+
dacite==1.7.0
|
35 |
+
decorator==4.4.2
|
36 |
+
decord==0.6.0
|
37 |
+
deepspeed==0.14.0
|
38 |
+
defusedxml==0.7.1
|
39 |
+
Deprecated==1.2.14
|
40 |
+
dill==0.3.8
|
41 |
+
distro==1.9.0
|
42 |
+
dnspython==2.6.1
|
43 |
+
docker-pycreds==0.4.0
|
44 |
+
einops==0.6.1
|
45 |
+
exceptiongroup==1.2.0
|
46 |
+
fastapi==0.110.0
|
47 |
+
ffmpeg==1.4
|
48 |
+
ffmpy==0.3.2
|
49 |
+
fiftyone==0.23.6
|
50 |
+
fiftyone-brain==0.16.1
|
51 |
+
fiftyone_db==1.1.2
|
52 |
+
filelock==3.9.0
|
53 |
+
flash-attn==2.5.6
|
54 |
+
fonttools==4.49.0
|
55 |
+
fsspec==2024.2.0
|
56 |
+
ftfy==6.1.3
|
57 |
+
future==1.0.0
|
58 |
+
fvcore==0.1.5.post20221221
|
59 |
+
gdown==5.1.0
|
60 |
+
gitdb==4.0.11
|
61 |
+
GitPython==3.1.42
|
62 |
+
glob2==0.7
|
63 |
+
google-auth==2.28.2
|
64 |
+
google-auth-oauthlib==1.2.0
|
65 |
+
gpustat==1.1.1
|
66 |
+
gradio==4.21.0
|
67 |
+
gradio_client==0.12.0
|
68 |
+
graphql-core==3.2.3
|
69 |
+
greenlet==3.0.3
|
70 |
+
grpcio==1.62.1
|
71 |
+
h11==0.14.0
|
72 |
+
h2==4.1.0
|
73 |
+
hjson==3.1.0
|
74 |
+
hpack==4.0.0
|
75 |
+
httpcore==1.0.4
|
76 |
+
httpx==0.27.0
|
77 |
+
huggingface-hub==0.21.4
|
78 |
+
humanize==4.9.0
|
79 |
+
hupper==1.12.1
|
80 |
+
Hypercorn==0.16.0
|
81 |
+
hyperframe==6.0.1
|
82 |
+
idna==3.6
|
83 |
+
idscheck==2.3.0
|
84 |
+
imageio==2.27.0
|
85 |
+
imageio-ffmpeg==0.4.9
|
86 |
+
importlib_metadata==7.0.2
|
87 |
+
importlib_resources==6.3.0
|
88 |
+
inflate64==1.0.0
|
89 |
+
iopath==0.1.10
|
90 |
+
Jinja2==3.1.2
|
91 |
+
jmespath==0.10.0
|
92 |
+
joblib==1.3.2
|
93 |
+
jsonlines==4.0.0
|
94 |
+
jsonschema==4.21.1
|
95 |
+
jsonschema-specifications==2023.12.1
|
96 |
+
kaleido==0.2.1
|
97 |
+
kiwisolver==1.4.5
|
98 |
+
lazy_loader==0.3
|
99 |
+
Markdown==3.6
|
100 |
+
markdown-it-py==3.0.0
|
101 |
+
MarkupSafe==2.1.3
|
102 |
+
matplotlib==3.8.3
|
103 |
+
mdurl==0.1.2
|
104 |
+
mmcv-full==1.7.2
|
105 |
+
model-index==0.1.11
|
106 |
+
mongoengine==0.24.2
|
107 |
+
motor==3.3.2
|
108 |
+
moviepy==1.0.3
|
109 |
+
mpmath==1.3.0
|
110 |
+
multivolumefile==0.2.3
|
111 |
+
networkx==3.2.1
|
112 |
+
ninja==1.11.1.1
|
113 |
+
numpy
|
114 |
+
oauthlib==3.2.2
|
115 |
+
omegaconf==2.3.0
|
116 |
+
openai==1.14.0
|
117 |
+
opencv-python==4.9.0.80
|
118 |
+
opencv-python-headless==4.9.0.80
|
119 |
+
opendatalab==0.0.10
|
120 |
+
openmim==0.3.9
|
121 |
+
openxlab==0.0.36
|
122 |
+
ordered-set==4.1.0
|
123 |
+
orjson==3.9.15
|
124 |
+
oss2==2.17.0
|
125 |
+
packaging==24.0
|
126 |
+
pandas==1.5.3
|
127 |
+
PasteDeploy==3.1.0
|
128 |
+
pathtools==0.1.2
|
129 |
+
pbkdf2==1.3
|
130 |
+
peft==0.10.0
|
131 |
+
pillow==10.2.0
|
132 |
+
plaster==1.1.2
|
133 |
+
plaster-pastedeploy==1.0.1
|
134 |
+
platformdirs==4.2.0
|
135 |
+
plotly==5.20.0
|
136 |
+
portalocker==2.8.2
|
137 |
+
pprintpp==0.4.0
|
138 |
+
priority==2.0.0
|
139 |
+
proglog==0.1.10
|
140 |
+
protobuf==4.23.4
|
141 |
+
psutil==5.9.4
|
142 |
+
py-cpuinfo==9.0.0
|
143 |
+
py7zr==0.21.0
|
144 |
+
pyasn1==0.5.1
|
145 |
+
pyasn1-modules==0.3.0
|
146 |
+
pybcj==1.0.2
|
147 |
+
pycparser==2.21
|
148 |
+
pycryptodome==3.20.0
|
149 |
+
pycryptodomex==3.20.0
|
150 |
+
pydantic==2.6.4
|
151 |
+
pydantic_core==2.16.3
|
152 |
+
pydub==0.25.1
|
153 |
+
Pygments==2.17.2
|
154 |
+
pymongo==4.6.2
|
155 |
+
pynvml==11.5.0
|
156 |
+
pyparsing==3.1.2
|
157 |
+
pyppmd==1.1.0
|
158 |
+
pyramid==2.0.2
|
159 |
+
pyramid-mailer==0.15.1
|
160 |
+
PySocks==1.7.1
|
161 |
+
python-dateutil==2.9.0.post0
|
162 |
+
python-multipart==0.0.9
|
163 |
+
python3-openid==3.2.0
|
164 |
+
pytz==2023.4
|
165 |
+
PyYAML==6.0
|
166 |
+
pyzstd==0.15.9
|
167 |
+
rarfile==4.1
|
168 |
+
referencing==0.33.0
|
169 |
+
regex==2023.12.25
|
170 |
+
repoze.sendmail==4.4.1
|
171 |
+
requests==2.28.2
|
172 |
+
requests-oauthlib==1.4.0
|
173 |
+
retrying==1.3.4
|
174 |
+
rich==13.4.2
|
175 |
+
rpds-py==0.18.0
|
176 |
+
rsa==4.9
|
177 |
+
ruff==0.3.2
|
178 |
+
s3transfer==0.10.1
|
179 |
+
safetensors==0.4.2
|
180 |
+
scikit-image==0.22.0
|
181 |
+
scikit-learn==1.4.1.post1
|
182 |
+
scipy==1.10.1
|
183 |
+
semantic-version==2.10.0
|
184 |
+
sentencepiece==0.2.0
|
185 |
+
sentry-sdk==1.42.0
|
186 |
+
setproctitle==1.3.3
|
187 |
+
shellingham==1.5.4
|
188 |
+
six==1.16.0
|
189 |
+
smmap==5.0.1
|
190 |
+
sniffio==1.3.1
|
191 |
+
sortedcontainers==2.4.0
|
192 |
+
soupsieve==2.5
|
193 |
+
SQLAlchemy==2.0.28
|
194 |
+
sse-starlette==0.10.3
|
195 |
+
sseclient-py==1.8.0
|
196 |
+
starlette==0.36.3
|
197 |
+
strawberry-graphql==0.138.1
|
198 |
+
sympy==1.12
|
199 |
+
tabulate==0.9.0
|
200 |
+
taskgroup==0.0.0a4
|
201 |
+
tenacity==8.2.3
|
202 |
+
tensorboard==2.15.1
|
203 |
+
tensorboard-data-server==0.7.2
|
204 |
+
tensorboardX==2.6.2.2
|
205 |
+
termcolor==2.3.0
|
206 |
+
texttable==1.7.0
|
207 |
+
threadpoolctl==3.3.0
|
208 |
+
tifffile==2024.2.12
|
209 |
+
timm==0.6.12
|
210 |
+
tokenizers==0.15.2
|
211 |
+
tomli==2.0.1
|
212 |
+
tomlkit==0.12.0
|
213 |
+
toolz==0.12.1
|
214 |
+
tqdm==4.65.2
|
215 |
+
transaction==4.0
|
216 |
+
transformers==4.37.1
|
217 |
+
translationstring==1.4
|
218 |
+
triton==2.2.0
|
219 |
+
typer==0.9.0
|
220 |
+
typing_extensions==4.8.0
|
221 |
+
tzdata==2024.1
|
222 |
+
tzlocal==5.2
|
223 |
+
universal-analytics-python3==1.1.1
|
224 |
+
urllib3==1.26.18
|
225 |
+
uvicorn==0.28.0
|
226 |
+
velruse==1.1.1
|
227 |
+
venusian==3.1.0
|
228 |
+
voxel51-eta==0.12.6
|
229 |
+
wandb==0.14.0
|
230 |
+
wcwidth==0.2.13
|
231 |
+
WebOb==1.8.7
|
232 |
+
websockets==11.0.3
|
233 |
+
Werkzeug==3.0.1
|
234 |
+
wrapt==1.16.0
|
235 |
+
wsproto==1.2.0
|
236 |
+
WTForms==3.1.2
|
237 |
+
wtforms-recaptcha==0.3.2
|
238 |
+
xmltodict==0.13.0
|
239 |
+
yacs==0.1.8
|
240 |
+
yapf==0.40.2
|
241 |
+
zipp==3.18.1
|
242 |
+
zope.deprecation==5.0
|
243 |
+
zope.interface==6.2
|
244 |
+
zope.sqlalchemy==3.1
|
requirements.torch.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--index-url https://download.pytorch.org/whl/cu118
|
2 |
+
torch==2.2.1
|
3 |
+
torchaudio==2.2.1
|
4 |
+
torchvision==0.17.1
|
requirements.txt
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
accelerate==0.26.1
|
3 |
+
addict==2.4.0
|
4 |
+
aiofiles==23.2.1
|
5 |
+
aliyun-python-sdk-core==2.15.0
|
6 |
+
aliyun-python-sdk-kms==2.16.2
|
7 |
+
altair==5.2.0
|
8 |
+
annotated-types==0.6.0
|
9 |
+
antlr4-python3-runtime==4.9.3
|
10 |
+
anyio==4.3.0
|
11 |
+
anykeystore==0.2
|
12 |
+
apex==0.9.10.dev0
|
13 |
+
appdirs==1.4.4
|
14 |
+
argcomplete==3.2.3
|
15 |
+
attrs==23.2.0
|
16 |
+
av==10.0.0
|
17 |
+
beautifulsoup4==4.12.3
|
18 |
+
blessed==1.20.0
|
19 |
+
blessings==1.7
|
20 |
+
boto3==1.34.63
|
21 |
+
botocore==1.34.63
|
22 |
+
Brotli==1.1.0
|
23 |
+
cachetools==5.3.3
|
24 |
+
certifi==2024.2.2
|
25 |
+
cffi==1.16.0
|
26 |
+
charset-normalizer==3.3.2
|
27 |
+
click==8.1.7
|
28 |
+
colorama==0.4.6
|
29 |
+
contourpy==1.2.0
|
30 |
+
crcmod==1.7
|
31 |
+
cryptacular==1.6.2
|
32 |
+
cryptography==42.0.5
|
33 |
+
cycler==0.12.1
|
34 |
+
dacite==1.7.0
|
35 |
+
decorator==4.4.2
|
36 |
+
decord==0.6.0
|
37 |
+
deepspeed==0.14.0
|
38 |
+
defusedxml==0.7.1
|
39 |
+
Deprecated==1.2.14
|
40 |
+
dill==0.3.8
|
41 |
+
distro==1.9.0
|
42 |
+
dnspython==2.6.1
|
43 |
+
docker-pycreds==0.4.0
|
44 |
+
einops==0.6.1
|
45 |
+
exceptiongroup==1.2.0
|
46 |
+
fastapi==0.110.0
|
47 |
+
ffmpeg==1.4
|
48 |
+
ffmpy==0.3.2
|
49 |
+
fiftyone==0.23.6
|
50 |
+
fiftyone-brain==0.16.1
|
51 |
+
fiftyone_db==1.1.2
|
52 |
+
filelock==3.9.0
|
53 |
+
fonttools==4.49.0
|
54 |
+
fsspec==2024.2.0
|
55 |
+
ftfy==6.1.3
|
56 |
+
future==1.0.0
|
57 |
+
fvcore==0.1.5.post20221221
|
58 |
+
gdown==5.1.0
|
59 |
+
gitdb==4.0.11
|
60 |
+
GitPython==3.1.42
|
61 |
+
glob2==0.7
|
62 |
+
google-auth==2.28.2
|
63 |
+
google-auth-oauthlib==1.2.0
|
64 |
+
gpustat==1.1.1
|
65 |
+
gradio==4.21.0
|
66 |
+
gradio_client==0.12.0
|
67 |
+
graphql-core==3.2.3
|
68 |
+
greenlet==3.0.3
|
69 |
+
grpcio==1.62.1
|
70 |
+
h11==0.14.0
|
71 |
+
h2==4.1.0
|
72 |
+
hjson==3.1.0
|
73 |
+
hpack==4.0.0
|
74 |
+
httpcore==1.0.4
|
75 |
+
httpx==0.27.0
|
76 |
+
huggingface-hub==0.21.4
|
77 |
+
humanize==4.9.0
|
78 |
+
hupper==1.12.1
|
79 |
+
Hypercorn==0.16.0
|
80 |
+
hyperframe==6.0.1
|
81 |
+
idna==3.6
|
82 |
+
idscheck==2.3.0
|
83 |
+
imageio==2.27.0
|
84 |
+
imageio-ffmpeg==0.4.9
|
85 |
+
importlib_metadata==7.0.2
|
86 |
+
importlib_resources==6.3.0
|
87 |
+
inflate64==1.0.0
|
88 |
+
iopath==0.1.10
|
89 |
+
Jinja2==3.1.2
|
90 |
+
jmespath==0.10.0
|
91 |
+
joblib==1.3.2
|
92 |
+
jsonlines==4.0.0
|
93 |
+
jsonschema==4.21.1
|
94 |
+
jsonschema-specifications==2023.12.1
|
95 |
+
kaleido==0.2.1
|
96 |
+
kiwisolver==1.4.5
|
97 |
+
lazy_loader==0.3
|
98 |
+
Markdown==3.6
|
99 |
+
markdown-it-py==3.0.0
|
100 |
+
MarkupSafe==2.1.3
|
101 |
+
matplotlib==3.8.3
|
102 |
+
mdurl==0.1.2
|
103 |
+
mmcv-full==1.7.2
|
104 |
+
model-index==0.1.11
|
105 |
+
mongoengine==0.24.2
|
106 |
+
motor==3.3.2
|
107 |
+
moviepy==1.0.3
|
108 |
+
mpmath==1.3.0
|
109 |
+
multivolumefile==0.2.3
|
110 |
+
networkx==3.2.1
|
111 |
+
ninja==1.11.1.1
|
112 |
+
numpy==1.23.5
|
113 |
+
oauthlib==3.2.2
|
114 |
+
omegaconf==2.3.0
|
115 |
+
openai==1.14.0
|
116 |
+
opencv-python==4.9.0.80
|
117 |
+
opencv-python-headless==4.9.0.80
|
118 |
+
opendatalab==0.0.10
|
119 |
+
openmim==0.3.9
|
120 |
+
openxlab==0.0.36
|
121 |
+
ordered-set==4.1.0
|
122 |
+
orjson==3.9.15
|
123 |
+
oss2==2.17.0
|
124 |
+
packaging==24.0
|
125 |
+
pandas==1.5.3
|
126 |
+
PasteDeploy==3.1.0
|
127 |
+
pathtools==0.1.2
|
128 |
+
pbkdf2==1.3
|
129 |
+
peft==0.10.0
|
130 |
+
pillow==10.2.0
|
131 |
+
plaster==1.1.2
|
132 |
+
plaster-pastedeploy==1.0.1
|
133 |
+
platformdirs==4.2.0
|
134 |
+
plotly==5.20.0
|
135 |
+
portalocker==2.8.2
|
136 |
+
pprintpp==0.4.0
|
137 |
+
priority==2.0.0
|
138 |
+
proglog==0.1.10
|
139 |
+
protobuf==4.23.4
|
140 |
+
psutil==5.9.4
|
141 |
+
py-cpuinfo==9.0.0
|
142 |
+
py7zr==0.21.0
|
143 |
+
pyasn1==0.5.1
|
144 |
+
pyasn1-modules==0.3.0
|
145 |
+
pybcj==1.0.2
|
146 |
+
pycparser==2.21
|
147 |
+
pycryptodome==3.20.0
|
148 |
+
pycryptodomex==3.20.0
|
149 |
+
pydantic==2.6.4
|
150 |
+
pydantic_core==2.16.3
|
151 |
+
pydub==0.25.1
|
152 |
+
Pygments==2.17.2
|
153 |
+
pymongo==4.6.2
|
154 |
+
pynvml==11.5.0
|
155 |
+
pyparsing==3.1.2
|
156 |
+
pyppmd==1.1.0
|
157 |
+
pyramid==2.0.2
|
158 |
+
pyramid-mailer==0.15.1
|
159 |
+
PySocks==1.7.1
|
160 |
+
python-dateutil==2.9.0.post0
|
161 |
+
python-multipart==0.0.9
|
162 |
+
python3-openid==3.2.0
|
163 |
+
pytz==2023.4
|
164 |
+
PyYAML==6.0
|
165 |
+
pyzstd==0.15.9
|
166 |
+
rarfile==4.1
|
167 |
+
referencing==0.33.0
|
168 |
+
regex==2023.12.25
|
169 |
+
repoze.sendmail==4.4.1
|
170 |
+
requests==2.28.2
|
171 |
+
requests-oauthlib==1.4.0
|
172 |
+
retrying==1.3.4
|
173 |
+
rich==13.4.2
|
174 |
+
rpds-py==0.18.0
|
175 |
+
rsa==4.9
|
176 |
+
ruff==0.3.2
|
177 |
+
s3transfer==0.10.1
|
178 |
+
safetensors==0.4.2
|
179 |
+
scikit-image==0.22.0
|
180 |
+
scikit-learn==1.4.1.post1
|
181 |
+
scipy==1.10.1
|
182 |
+
semantic-version==2.10.0
|
183 |
+
sentencepiece==0.2.0
|
184 |
+
sentry-sdk==1.42.0
|
185 |
+
setproctitle==1.3.3
|
186 |
+
shellingham==1.5.4
|
187 |
+
six==1.16.0
|
188 |
+
smmap==5.0.1
|
189 |
+
sniffio==1.3.1
|
190 |
+
sortedcontainers==2.4.0
|
191 |
+
soupsieve==2.5
|
192 |
+
SQLAlchemy==2.0.28
|
193 |
+
sse-starlette==0.10.3
|
194 |
+
sseclient-py==1.8.0
|
195 |
+
starlette==0.36.3
|
196 |
+
strawberry-graphql==0.138.1
|
197 |
+
sympy==1.12
|
198 |
+
tabulate==0.9.0
|
199 |
+
taskgroup==0.0.0a4
|
200 |
+
tenacity==8.2.3
|
201 |
+
tensorboard==2.15.1
|
202 |
+
tensorboard-data-server==0.7.2
|
203 |
+
tensorboardX==2.6.2.2
|
204 |
+
termcolor==2.3.0
|
205 |
+
texttable==1.7.0
|
206 |
+
threadpoolctl==3.3.0
|
207 |
+
tifffile==2024.2.12
|
208 |
+
timm==0.6.12
|
209 |
+
tokenizers==0.15.2
|
210 |
+
tomli==2.0.1
|
211 |
+
tomlkit==0.12.0
|
212 |
+
toolz==0.12.1
|
213 |
+
torch==2.2.1
|
214 |
+
torchaudio==2.2.1
|
215 |
+
torchvision==0.17.1
|
216 |
+
tqdm==4.65.2
|
217 |
+
transaction==4.0
|
218 |
+
transformers
|
219 |
+
translationstring==1.4
|
220 |
+
triton==2.2.0
|
221 |
+
typer==0.9.0
|
222 |
+
typing_extensions==4.8.0
|
223 |
+
tzdata==2024.1
|
224 |
+
tzlocal==5.2
|
225 |
+
universal-analytics-python3==1.1.1
|
226 |
+
urllib3==1.26.18
|
227 |
+
uvicorn==0.28.0
|
228 |
+
velruse==1.1.1
|
229 |
+
venusian==3.1.0
|
230 |
+
voxel51-eta==0.12.6
|
231 |
+
wandb==0.14.0
|
232 |
+
wcwidth==0.2.13
|
233 |
+
WebOb==1.8.7
|
234 |
+
websockets==11.0.3
|
235 |
+
Werkzeug==3.0.1
|
236 |
+
wrapt==1.16.0
|
237 |
+
wsproto==1.2.0
|
238 |
+
WTForms==3.1.2
|
239 |
+
wtforms-recaptcha==0.3.2
|
240 |
+
xmltodict==0.13.0
|
241 |
+
yacs==0.1.8
|
242 |
+
yapf==0.40.2
|
243 |
+
zipp==3.18.1
|
244 |
+
zope.deprecation==5.0
|
245 |
+
zope.interface==6.2
|
246 |
+
zope.sqlalchemy==3.1
|
scripts/accel_config_deepspeed_zero2.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
gradient_accumulation_steps: 8
|
5 |
+
offload_optimizer_device: none
|
6 |
+
offload_param_device: none
|
7 |
+
zero3_init_flag: false
|
8 |
+
zero_stage: 2
|
9 |
+
distributed_type: DEEPSPEED
|
10 |
+
downcast_bf16: 'no'
|
11 |
+
machine_rank: 0
|
12 |
+
main_training_function: main
|
13 |
+
mixed_precision: bf16
|
14 |
+
num_machines: 1
|
15 |
+
num_processes: 4
|
16 |
+
rdzv_backend: static
|
17 |
+
same_network: true
|
18 |
+
tpu_env: []
|
19 |
+
tpu_use_cluster: false
|
20 |
+
tpu_use_sudo: false
|
21 |
+
use_cpu: false
|
scripts/accel_config_deepspeed_zero3_offload.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
gradient_accumulation_steps: 2
|
5 |
+
offload_optimizer_device: cpu
|
6 |
+
offload_param_device: cpu
|
7 |
+
zero3_init_flag: true
|
8 |
+
zero3_save_16bit_model: true
|
9 |
+
zero_stage: 3
|
10 |
+
distributed_type: DEEPSPEED
|
11 |
+
downcast_bf16: 'no'
|
12 |
+
machine_rank: 0
|
13 |
+
main_training_function: main
|
14 |
+
mixed_precision: bf16
|
15 |
+
num_machines: 1
|
16 |
+
num_processes: 8
|
17 |
+
rdzv_backend: static
|
18 |
+
same_network: true
|
19 |
+
tpu_env: []
|
20 |
+
tpu_use_cluster: false
|
21 |
+
tpu_use_sudo: false
|
22 |
+
use_cpu: false
|
scripts/accel_config_deepspeed_zero3_offload_multinode.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
deepspeed_multinode_launcher: standard
|
5 |
+
gradient_accumulation_steps: 2
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: true
|
9 |
+
zero3_save_16bit_model: true
|
10 |
+
zero_stage: 3
|
11 |
+
distributed_type: DEEPSPEED
|
12 |
+
downcast_bf16: 'no'
|
13 |
+
machine_rank: 0
|
14 |
+
main_process_ip: fdbd:dc61:18:8::20
|
15 |
+
main_process_port: 6876
|
16 |
+
main_training_function: main
|
17 |
+
mixed_precision: bf16
|
18 |
+
num_machines: 2
|
19 |
+
num_processes: 16
|
20 |
+
rdzv_backend: static
|
21 |
+
same_network: true
|
22 |
+
tpu_env: []
|
23 |
+
tpu_use_cluster: false
|
24 |
+
tpu_use_sudo: false
|
25 |
+
use_cpu: false
|
scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
deepspeed_multinode_launcher: standard
|
5 |
+
gradient_accumulation_steps: 2
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: true
|
9 |
+
zero3_save_16bit_model: true
|
10 |
+
zero_stage: 3
|
11 |
+
distributed_type: DEEPSPEED
|
12 |
+
downcast_bf16: 'no'
|
13 |
+
machine_rank: 0
|
14 |
+
main_process_ip: fdbd:dc61:18:8::20
|
15 |
+
main_process_port: 6876
|
16 |
+
main_training_function: main
|
17 |
+
mixed_precision: bf16
|
18 |
+
num_machines: 2
|
19 |
+
num_processes: 16
|
20 |
+
rdzv_backend: static
|
21 |
+
same_network: true
|
22 |
+
tpu_env: []
|
23 |
+
tpu_use_cluster: false
|
24 |
+
tpu_use_sudo: false
|
25 |
+
use_cpu: false
|
scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
deepspeed_multinode_launcher: standard
|
5 |
+
gradient_accumulation_steps: 2
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: true
|
9 |
+
zero3_save_16bit_model: true
|
10 |
+
zero_stage: 3
|
11 |
+
distributed_type: DEEPSPEED
|
12 |
+
downcast_bf16: 'no'
|
13 |
+
machine_rank: 1
|
14 |
+
main_process_ip: fdbd:dc61:18:8::20
|
15 |
+
main_process_port: 6876
|
16 |
+
main_training_function: main
|
17 |
+
mixed_precision: bf16
|
18 |
+
num_machines: 2
|
19 |
+
num_processes: 16
|
20 |
+
rdzv_backend: static
|
21 |
+
same_network: true
|
22 |
+
tpu_env: []
|
23 |
+
tpu_use_cluster: false
|
24 |
+
tpu_use_sudo: false
|
25 |
+
use_cpu: false
|
scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
gradient_accumulation_steps: 16
|
5 |
+
gradient_clipping: 1.0
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: true
|
9 |
+
zero3_save_16bit_model: true
|
10 |
+
zero_stage: 3
|
11 |
+
distributed_type: DEEPSPEED
|
12 |
+
downcast_bf16: 'no'
|
13 |
+
machine_rank: 0
|
14 |
+
main_training_function: main
|
15 |
+
mixed_precision: bf16
|
16 |
+
num_machines: 1
|
17 |
+
num_processes: 1
|
18 |
+
rdzv_backend: static
|
19 |
+
same_network: true
|
20 |
+
tpu_env: []
|
21 |
+
tpu_use_cluster: false
|
22 |
+
tpu_use_sudo: false
|
23 |
+
use_cpu: false
|
scripts/accel_config_multigpu.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
gpu_ids: 2,3,4,5
|
6 |
+
machine_rank: 0
|
7 |
+
main_training_function: main
|
8 |
+
mixed_precision: bf16
|
9 |
+
num_machines: 1
|
10 |
+
num_processes: 4
|
11 |
+
rdzv_backend: static
|
12 |
+
same_network: true
|
13 |
+
tpu_env: []
|
14 |
+
tpu_use_cluster: false
|
15 |
+
tpu_use_sudo: false
|
16 |
+
use_cpu: false
|
scripts/accel_config_multinode.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
gpu_ids: all
|
6 |
+
machine_rank: 1
|
7 |
+
main_process_ip: 10.193.16.150
|
8 |
+
main_process_port: 6784
|
9 |
+
main_training_function: main
|
10 |
+
mixed_precision: bf16
|
11 |
+
num_machines: 2
|
12 |
+
num_processes: 16
|
13 |
+
rdzv_backend: static
|
14 |
+
same_network: true
|
15 |
+
tpu_env: []
|
16 |
+
tpu_use_cluster: false
|
17 |
+
tpu_use_sudo: false
|
18 |
+
use_cpu: false
|
scripts/accel_config_singlegpu.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
gpu_ids: '0'
|
6 |
+
machine_rank: 0
|
7 |
+
main_training_function: main
|
8 |
+
mixed_precision: bf16
|
9 |
+
num_machines: 1
|
10 |
+
num_processes: 1
|
11 |
+
rdzv_backend: static
|
12 |
+
same_network: true
|
13 |
+
tpu_env: []
|
14 |
+
tpu_use_cluster: false
|
15 |
+
tpu_use_sudo: false
|
16 |
+
use_cpu: false
|
scripts/demo.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_dir=${1:-"MODELS/pllava-7b"}
|
2 |
+
weight_dir=${2:-"${model_dir}"}
|
3 |
+
num_frames=16
|
4 |
+
lora_alpha=4
|
5 |
+
|
6 |
+
echo Running DEMO from model_dir: ${model_dir}
|
7 |
+
echo Running DEMO from weights_dir: ${weight_dir}
|
8 |
+
echo Running DEMO On Devices: ${CUDA_VISIBLE_DEVICES}
|
9 |
+
|
10 |
+
|
11 |
+
# # 34B Need to Use dispatch for this large.
|
12 |
+
# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} python -m tasks.eval.demo.pllava_demo \
|
13 |
+
# --pretrained_model_name_or_path ${model_dir} \
|
14 |
+
# --num_frames ${num_frames} \
|
15 |
+
# --use_lora \
|
16 |
+
# --weight_dir ${weight_dir} \
|
17 |
+
# --lora_alpha ${lora_alpha} \
|
18 |
+
# --conv_mode eval_vcg_llava_next \
|
19 |
+
# --use_multi_gpus \
|
20 |
+
|
21 |
+
|
22 |
+
# 7B and 13B, There are problem if Model was split around A100 40G... Probably because some unkown bug in accelerate dispatch
|
23 |
+
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1"} python -m tasks.eval.demo.pllava_demo \
|
24 |
+
--pretrained_model_name_or_path ${model_dir} \
|
25 |
+
--num_frames ${num_frames} \
|
26 |
+
--use_lora \
|
27 |
+
--weight_dir ${weight_dir} \
|
28 |
+
--lora_alpha ${lora_alpha} \
|
29 |
+
--conv_mode plain \
|
30 |
+
--use_multi_gpus
|
31 |
+
|
32 |
+
|
scripts/eval.sh
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# export CUDA_VISIBLE_DEVICES=2,6,7
|
2 |
+
export OPENAI_API_KEY=...
|
3 |
+
num_frames=16
|
4 |
+
test_ratio=1
|
5 |
+
|
6 |
+
# 13b, uses offload thus saving the full model
|
7 |
+
model_dir=MODELS/pllava-13b
|
8 |
+
weight_dir=MODELS/pllava-13b
|
9 |
+
SAVE_DIR=test_results/test_pllava_13b
|
10 |
+
lora_alpha=4
|
11 |
+
conv_mode=eval_vcgbench
|
12 |
+
python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
|
13 |
+
--pretrained_model_name_or_path ${model_dir} \
|
14 |
+
--save_path ${SAVE_DIR}/vcgbench \
|
15 |
+
--num_frames ${num_frames} \
|
16 |
+
--use_lora \
|
17 |
+
--lora_alpha ${lora_alpha} \
|
18 |
+
--weight_dir ${weight_dir} \
|
19 |
+
--pooling_shape 16-12-12 \
|
20 |
+
--test_ratio ${test_ratio} \
|
21 |
+
--conv_mode ${conv_mode}
|
22 |
+
|
23 |
+
conv_mode=eval_mvbench
|
24 |
+
python -m tasks.eval.mvbench.pllava_eval_mvbench \
|
25 |
+
--pretrained_model_name_or_path ${model_dir} \
|
26 |
+
--save_path ${SAVE_DIR}/mvbench \
|
27 |
+
--use_lora \
|
28 |
+
--lora_alpha ${lora_alpha} \
|
29 |
+
--num_frames ${num_frames} \
|
30 |
+
--weight_dir ${weight_dir} \
|
31 |
+
--pooling_shape 16-12-12 \
|
32 |
+
--conv_mode ${conv_mode}
|
33 |
+
|
34 |
+
onv_mode=eval_videoqabench
|
35 |
+
python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
|
36 |
+
--pretrained_model_name_or_path ${model_dir} \
|
37 |
+
--save_path ${SAVE_DIR}/videoqabench \
|
38 |
+
--num_frames ${num_frames} \
|
39 |
+
--use_lora \
|
40 |
+
--lora_alpha ${lora_alpha} \
|
41 |
+
--weight_dir ${weight_dir} \
|
42 |
+
--test_ratio ${test_ratio} \
|
43 |
+
--conv_mode ${conv_mode}
|
44 |
+
|
45 |
+
|
46 |
+
conv_mode=eval_recaption
|
47 |
+
python -m tasks.eval.recaption.pllava_recaption \
|
48 |
+
--pretrained_model_name_or_path ${model_dir} \
|
49 |
+
--save_path ${SAVE_DIR}/recaption \
|
50 |
+
--num_frames ${num_frames} \
|
51 |
+
--use_lora \
|
52 |
+
--weight_dir ${weight_dir} \
|
53 |
+
--lora_alpha ${lora_alpha} \
|
54 |
+
--test_ratio ${test_ratio} \
|
55 |
+
--conv_mode ${conv_mode}
|
56 |
+
|
57 |
+
|
58 |
+
model_dir=MODELS/pllava-7b
|
59 |
+
weight_dir=MODELS/pllava-7b
|
60 |
+
SAVE_DIR=test_results/test_pllava_7b
|
61 |
+
lora_alpha=4
|
62 |
+
|
63 |
+
conv_mode=eval_vcgbench
|
64 |
+
python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
|
65 |
+
--pretrained_model_name_or_path ${model_dir} \
|
66 |
+
--save_path ${SAVE_DIR}/vcgbench \
|
67 |
+
--num_frames ${num_frames} \
|
68 |
+
--use_lora \
|
69 |
+
--lora_alpha ${lora_alpha} \
|
70 |
+
--weight_dir ${weight_dir} \
|
71 |
+
--pooling_shape 16-12-12 \
|
72 |
+
--test_ratio ${test_ratio}
|
73 |
+
|
74 |
+
|
75 |
+
conv_mode=eval_mvbench
|
76 |
+
python -m tasks.eval.mvbench.pllava_eval_mvbench \
|
77 |
+
--pretrained_model_name_or_path ${model_dir} \
|
78 |
+
--save_path ${SAVE_DIR}/mvbench \
|
79 |
+
--use_lora \
|
80 |
+
--lora_alpha ${lora_alpha} \
|
81 |
+
--num_frames ${num_frames} \
|
82 |
+
--weight_dir ${weight_dir} \
|
83 |
+
--pooling_shape 16-12-12
|
84 |
+
|
85 |
+
|
86 |
+
onv_mode=eval_videoqabench
|
87 |
+
python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
|
88 |
+
--pretrained_model_name_or_path ${model_dir} \
|
89 |
+
--save_path ${SAVE_DIR}/videoqabench \
|
90 |
+
--num_frames ${num_frames} \
|
91 |
+
--use_lora \
|
92 |
+
--lora_alpha ${lora_alpha} \
|
93 |
+
--weight_dir ${weight_dir} \
|
94 |
+
--test_ratio ${test_ratio}
|
95 |
+
|
96 |
+
conv_mode=eval_recaption
|
97 |
+
python -m tasks.eval.recaption.pllava_recaption \
|
98 |
+
--pretrained_model_name_or_path ${model_dir} \
|
99 |
+
--save_path ${SAVE_DIR}/recaption \
|
100 |
+
--num_frames ${num_frames} \
|
101 |
+
--use_lora \
|
102 |
+
--lora_alpha ${lora_alpha} \
|
103 |
+
--weight_dir ${weight_dir} \
|
104 |
+
--test_ratio ${test_ratio}
|
scripts/eval_yiprompt.sh
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# export CUDA_VISIBLE_DEVICES=0,3,4,5,6,7
|
2 |
+
export OPENAI_API_KEY=...
|
3 |
+
num_frames=16
|
4 |
+
test_ratio=200
|
5 |
+
|
6 |
+
model_dir=MODELS/pllava-34b
|
7 |
+
weight_dir=MODELS/pllava-34b
|
8 |
+
SAVE_DIR=test_results/test_pllava_34b
|
9 |
+
lora_alpha=4
|
10 |
+
conv_mode=eval_vcg_llavanext
|
11 |
+
python -m tasks.eval.vcgbench.pllava_eval_vcgbench \
|
12 |
+
--pretrained_model_name_or_path ${model_dir} \
|
13 |
+
--save_path ${SAVE_DIR}/vcgbench \
|
14 |
+
--num_frames ${num_frames} \
|
15 |
+
--use_lora \
|
16 |
+
--lora_alpha ${lora_alpha} \
|
17 |
+
--weight_dir ${weight_dir} \
|
18 |
+
--pooling_shape 16-12-12 \
|
19 |
+
--test_ratio ${test_ratio} \
|
20 |
+
--conv_mode $conv_mode
|
21 |
+
|
22 |
+
conv_mode=eval_mvbench_llavanext
|
23 |
+
python -m tasks.eval.mvbench.pllava_eval_mvbench \
|
24 |
+
--pretrained_model_name_or_path ${model_dir} \
|
25 |
+
--save_path ${SAVE_DIR}/mvbench \
|
26 |
+
--use_lora \
|
27 |
+
--lora_alpha ${lora_alpha} \
|
28 |
+
--num_frames ${num_frames} \
|
29 |
+
--weight_dir ${weight_dir} \
|
30 |
+
--pooling_shape 16-12-12 \
|
31 |
+
--conv_mode $conv_mode
|
32 |
+
|
33 |
+
conv_mode=eval_videoqa_llavanext
|
34 |
+
python -m tasks.eval.videoqabench.pllava_eval_videoqabench \
|
35 |
+
--pretrained_model_name_or_path ${model_dir} \
|
36 |
+
--save_path ${SAVE_DIR}/videoqabench \
|
37 |
+
--num_frames ${num_frames} \
|
38 |
+
--use_lora \
|
39 |
+
--lora_alpha ${lora_alpha} \
|
40 |
+
--weight_dir ${weight_dir} \
|
41 |
+
--test_ratio ${test_ratio} \
|
42 |
+
--conv_mode ${conv_mode}
|
43 |
+
|
44 |
+
conv_mode=eval_recaption_llavanext
|
45 |
+
python -m tasks.eval.recaption.pllava_recaption \
|
46 |
+
--pretrained_model_name_or_path ${model_dir} \
|
47 |
+
--save_path ${SAVE_DIR}/recaption \
|
48 |
+
--num_frames ${num_frames} \
|
49 |
+
--use_lora \
|
50 |
+
--weight_dir ${weight_dir} \
|
51 |
+
--lora_alpha ${lora_alpha} \
|
52 |
+
--test_ratio ${test_ratio} \
|
53 |
+
--conv_mode $conv_mode
|
scripts/gallery.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export OPENAI_API_KEY=...
|
2 |
+
SAVE_DIR=${1:-"test_results"}
|
3 |
+
|
4 |
+
# # gallery view
|
5 |
+
# python -m tasks.eval.show_gallery \
|
6 |
+
# --root_dir ${SAVE_DIR}
|
7 |
+
|
8 |
+
# # compare view
|
9 |
+
python -m tasks.eval.demo.show_compare \
|
10 |
+
--root_dir ${SAVE_DIR}
|
11 |
+
|
scripts/train_pllava.sh
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
2 |
+
which_python=$(which python)
|
3 |
+
echo "which python: ${which_python}"
|
4 |
+
export PYTHONPATH=${PYTHONPATH}:${which_python}
|
5 |
+
export PYTHONPATH=${PYTHONPATH}:.
|
6 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
7 |
+
|
8 |
+
OUTPUT_DIR=./pllava_video_outputs/test_train_7b_reconstruct
|
9 |
+
|
10 |
+
# # Naive Env
|
11 |
+
# rm -rf ${OUTPUT_DIR}
|
12 |
+
pooling_shape=(16,12,12)
|
13 |
+
accelerate launch --main_process_port 6876 --config_file scripts/accel_config_multigpu.yaml tasks/train/train_pllava_nframe_accel.py \
|
14 |
+
tasks/train/config_pllava_nframe.py \
|
15 |
+
output_dir ${OUTPUT_DIR} \
|
16 |
+
train_corpus videochat2_video \
|
17 |
+
save_steps 10000 \
|
18 |
+
num_workers 8 \
|
19 |
+
num_frames 16 \
|
20 |
+
model.pooling_method avg \
|
21 |
+
model.repo_id llava-hf/llava-v1.6-vicuna-7b-hf \
|
22 |
+
model.use_lora True \
|
23 |
+
model.pooling_shape $pooling_shape \
|
24 |
+
optimizer.lr 2e-5 \
|
25 |
+
scheduler.epochs 3 \
|
26 |
+
scheduler.warmup_ratio 0.2 \
|
27 |
+
scheduler.min_lr_multi 0.25 \
|
28 |
+
scheduler.is_videochat2_custom True \
|
29 |
+
preprocess.mm_alone False \
|
30 |
+
preprocess.random_shuffle False \
|
31 |
+
preprocess.add_second_msg False \
|
32 |
+
train_corpus videochat2_instruction_debug
|
33 |
+
|
34 |
+
|
scripts/train_pllava_13b.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
2 |
+
which_python=$(which python)
|
3 |
+
echo "which python: ${which_python}"
|
4 |
+
export PYTHONPATH=${PYTHONPATH}:${which_python}
|
5 |
+
export PYTHONPATH=${PYTHONPATH}:.
|
6 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
7 |
+
|
8 |
+
OUTPUT_DIR=./pllava_video_outputs/pllava_13b
|
9 |
+
|
10 |
+
|
11 |
+
pooling_shape=(16,12,12)
|
12 |
+
num_save_samples=80000
|
13 |
+
num_gpus=8
|
14 |
+
full_batch_size=128
|
15 |
+
batch_size=8
|
16 |
+
save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
|
17 |
+
ckpt_steps=$[$save_steps/10]
|
18 |
+
gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
|
19 |
+
echo $batch_size
|
20 |
+
echo $gradient_accumulation_steps
|
21 |
+
repo_id=llava-hf/llava-v1.6-vicuna-13b-hf
|
22 |
+
accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
|
23 |
+
tasks/train/config_pllava_nframe.py \
|
24 |
+
output_dir ${OUTPUT_DIR} \
|
25 |
+
train_corpus videochat2_instruction_debug \
|
26 |
+
save_steps $save_steps \
|
27 |
+
ckpt_steps $ckpt_steps \
|
28 |
+
num_workers 8 \
|
29 |
+
num_frames 16 \
|
30 |
+
gradient_accumulation_steps $gradient_accumulation_steps \
|
31 |
+
batch_size $batch_size \
|
32 |
+
deepspeed True \
|
33 |
+
model.pooling_method avg \
|
34 |
+
model.use_lora True \
|
35 |
+
model.use_pooling True \
|
36 |
+
model.repo_id $repo_id \
|
37 |
+
gradient_checkpointing True \
|
38 |
+
preprocess.center_pad False \
|
39 |
+
preprocess.clip_transform False \
|
40 |
+
optimizer.lr 2e-5 \
|
41 |
+
scheduler.epochs 3 \
|
42 |
+
scheduler.warmup_ratio 0.2 \
|
43 |
+
scheduler.min_lr_multi 0.25 \
|
44 |
+
model.pooling_shape $pooling_shape \
|
45 |
+
scheduler.is_videochat2_custom True \
|
46 |
+
preprocess.mm_alone False \
|
47 |
+
preprocess.random_shuffle False \
|
48 |
+
preprocess.add_second_msg False
|
49 |
+
|
50 |
+
|
scripts/train_pllava_34b.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
2 |
+
which_python=$(which python)
|
3 |
+
echo "which python: ${which_python}"
|
4 |
+
export PYTHONPATH=${PYTHONPATH}:${which_python}
|
5 |
+
export PYTHONPATH=${PYTHONPATH}:.
|
6 |
+
echo "PYTHONPATH: ${PYTHONPATH}"
|
7 |
+
|
8 |
+
machine_rank=${1:-"0"} # machine rank
|
9 |
+
|
10 |
+
OUTPUT_DIR=./pllava_video_outputs/pllava_34b_videchat2-video
|
11 |
+
|
12 |
+
pooling_shape=(16,12,12)
|
13 |
+
num_save_samples=80000
|
14 |
+
num_gpus=8
|
15 |
+
full_batch_size=128
|
16 |
+
batch_size=4
|
17 |
+
save_steps=$[$num_save_samples/($batch_size*$num_gpus)]
|
18 |
+
ckpt_steps=$[$save_steps/10]
|
19 |
+
gradient_accumulation_steps=$[$full_batch_size/($batch_size*$num_gpus)]
|
20 |
+
echo $batch_size
|
21 |
+
echo $gradient_accumulation_steps
|
22 |
+
repo_id=llava-hf/llava-v1.6-34b-hf
|
23 |
+
accelerate launch --main_process_port 6876 --config_file scripts/accel_config_deepspeed_zero3_offload.yaml tasks/train/train_pllava_nframe_accel.py \
|
24 |
+
tasks/train/config_pllava_nframe_yiprompt.py \
|
25 |
+
output_dir ${OUTPUT_DIR} \
|
26 |
+
train_corpus videochat2_instruction_debug \
|
27 |
+
save_steps $save_steps \
|
28 |
+
ckpt_steps $ckpt_steps \
|
29 |
+
num_workers 8 \
|
30 |
+
num_frames 16 \
|
31 |
+
deepspeed True \
|
32 |
+
gradient_accumulation_steps $gradient_accumulation_steps \
|
33 |
+
batch_size $batch_size \
|
34 |
+
model.pooling_method avg \
|
35 |
+
model.use_lora True \
|
36 |
+
model.use_pooling True \
|
37 |
+
model.repo_id $repo_id \
|
38 |
+
gradient_checkpointing True \
|
39 |
+
preprocess.center_pad False \
|
40 |
+
preprocess.clip_transform True \
|
41 |
+
optimizer.lr 2e-5 \
|
42 |
+
scheduler.epochs 3 \
|
43 |
+
scheduler.warmup_ratio 0.2 \
|
44 |
+
scheduler.min_lr_multi 0.25 \
|
45 |
+
model.pooling_shape $pooling_shape \
|
46 |
+
scheduler.is_videochat2_custom True \
|
47 |
+
preprocess.image_token_index 64002 \
|
48 |
+
preprocess.mm_alone False \
|
49 |
+
preprocess.random_shuffle False \
|
50 |
+
preprocess.add_second_msg False
|