Update README.md
README.md (CHANGED)
````diff
@@ -77,12 +77,12 @@ torch.set_default_device('cpu') # or 'cuda'
 
 # create model
 model = AutoModelForCausalLM.from_pretrained(
-    'opencsg/opencsg-bunny-…
+    'opencsg/opencsg-bunny-v0.1-3B',
     torch_dtype=torch.float16,
     device_map='auto',
     trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(
-    '…
+    'opencsg/opencsg-bunny-v0.1-3B',
     trust_remote_code=True)
 
 # text prompt
````
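The first hunk completes the repository ids so that both `from_pretrained` calls point at `opencsg/opencsg-bunny-v0.1-3B`. A quick sanity check for the id, assuming the `huggingface_hub` package (not used in the README itself) is installed:

```python
from huggingface_hub import model_info

# raises a "repository not found" error if the repo id is misspelled
print(model_info('opencsg/opencsg-bunny-v0.1-3B').id)
```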
````diff
@@ -169,23 +169,53 @@ opencsg-bunny-v0.1-3B is a model based on bunny-v1_0-3B, which has been … on opencsg…
 
 # Model Usage
 
+```shell
+pip install torch transformers accelerate pillow
 ```
+
+```python
 import torch
+import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from PIL import Image
+import warnings
+
+# disable some warnings
+transformers.logging.set_verbosity_error()
+transformers.logging.disable_progress_bar()
+warnings.filterwarnings('ignore')
+
+# set device
+torch.set_default_device('cpu')  # or 'cuda'
+
+# create model
+model = AutoModelForCausalLM.from_pretrained(
+    'opencsg/opencsg-bunny-v0.1-3B',
+    torch_dtype=torch.float16,
+    device_map='auto',
+    trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    'opencsg/opencsg-bunny-v0.1-3B',
+    trust_remote_code=True)
 
-…
+# text prompt
+prompt = 'Why is the image funny?'
+text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
+text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
+input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
 
+# image, sample images can be found in the images folder
+image = Image.open('example_2.png')
+image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
 
+# generate
+output_ids = model.generate(
+    input_ids,
+    images=image_tensor,
+    max_new_tokens=100,
+    use_cache=True)[0]
 
-text = tokenizer.batch_decode(outputs)[0]
-print(text)
+print(tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip())
 ```
 # Training
 
````
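The added install line pulls in `accelerate`, which `device_map='auto'` relies on, and `pillow` for image loading. One caveat the snippet glosses over: `torch_dtype=torch.float16` combined with the default `'cpu'` device can be slow or unsupported for some CPU kernels. A hedged variant (not from the README) that picks the dtype by device:

```python
import torch
from transformers import AutoModelForCausalLM

# use fp16 only when a GPU is present; float32 is the safer CPU default
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    'opencsg/opencsg-bunny-v0.1-3B',
    torch_dtype=dtype,
    device_map='auto',
    trust_remote_code=True)
```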
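In the prompt-building step, the text is split on the `<image>` placeholder and re-joined around the single token id `-200`, the sentinel that the model's `trust_remote_code` implementation later swaps for image features; `model.process_images` likewise comes from that remote code. If you switch to `'cuda'`, the image tensor has to follow the weights. A sketch of the same call with an explicit device move, assuming `model.device` reflects where `device_map` placed the model:

```python
# drop-in replacement for the image_tensor line in the snippet above
image_tensor = model.process_images([image], model.config).to(
    dtype=model.dtype, device=model.device)
```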
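The final change also fixes the output handling: the old `tokenizer.batch_decode(outputs)[0]` decoded the whole sequence, prompt included, while the new code slices off the first `input_ids.shape[1]` positions so only the generated answer is printed. A toy illustration of that slice, with made-up token ids:

```python
import torch

# pretend a prompt of 4 ids followed by 3 generated ids
input_ids = torch.tensor([[11, 12, 13, 14]])
output_ids = torch.tensor([11, 12, 13, 14, 21, 22, 23])

answer_ids = output_ids[input_ids.shape[1]:]  # drop the echoed prompt
assert answer_ids.tolist() == [21, 22, 23]
```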