# Snapshot: commit e3b994d, 4,257 bytes.
import os
import sys
import time
import torch
import numpy as np
import requests
import onnxruntime as ort
from PIL import Image
from io import BytesIO
from transformers import Qwen2VLConfig, AutoTokenizer
# ---- Command-line arguments -------------------------------------------------
# usage: <script> <model_path> <onnx_dir>
#   model_path: path/id handed to from_pretrained() for the config and tokenizer
#   onnx_dir:   directory containing the QwenVL_{A..E}_q4f16.onnx graphs
if len(sys.argv) < 3:
    # Fail with a readable usage message instead of a bare IndexError.
    sys.exit(f"usage: {sys.argv[0]} <model_path> <onnx_dir>")
model_path = sys.argv[1]
onnx_path = sys.argv[2]
# ---- Model config and tokenizer ---------------------------------------------
model_config = Qwen2VLConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Sequence capacity: used as the length of the token buffer and of the
# sequence axis of the KV caches elsewhere in this script.
max_length = 1024
num_attention_heads = model_config.num_attention_heads
num_key_value_heads = model_config.num_key_value_heads
head_dim = model_config.hidden_size // num_attention_heads  # width of one attention head
num_layers = model_config.num_hidden_layers
# ---- ONNX Runtime sessions --------------------------------------------------
# One InferenceSession per exported graph (QwenVL_A ... QwenVL_E), all created
# with the same options (full graph optimization).
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
models = ['A', 'B', 'C', 'D', 'E']
model_paths = {
    part: os.path.join(onnx_path, f'QwenVL_{part}_q4f16.onnx')
    for part in models
}
sessions = {
    part: ort.InferenceSession(onnx_file, sess_options=session_options)
    for part, onnx_file in model_paths.items()
}
# Declared graph inputs/outputs, looked up once per session.
graph_inputs = {part: sess.get_inputs() for part, sess in sessions.items()}
graph_outputs = {part: sess.get_outputs() for part, sess in sessions.items()}
# Name tables. A and C take a single input and, like B, yield a single output;
# B uses its first two inputs; D and E use every declared input/output in
# declaration order (values are paired with them positionally).
inputs = {
    'A': graph_inputs['A'][0].name,
    'B': [graph_inputs['B'][0].name, graph_inputs['B'][1].name],
    'C': graph_inputs['C'][0].name,
    'D': [arg.name for arg in graph_inputs['D']],
    'E': [arg.name for arg in graph_inputs['E']],
}
outputs = {
    'A': graph_outputs['A'][0].name,
    'B': graph_outputs['B'][0].name,
    'C': graph_outputs['C'][0].name,
    'D': [arg.name for arg in graph_outputs['D']],
    'E': [arg.name for arg in graph_outputs['E']],
}
# ---- Input image --------------------------------------------------------------
# Download the demo image and turn it into a (1, 3, 960, 960) float32 array
# with values in [0, 1] (HWC uint8 -> CHW, plus a leading batch axis).
image_url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg'
# A timeout keeps the script from hanging on a stalled connection, and
# raise_for_status() surfaces HTTP errors directly instead of letting PIL fail
# later on an HTML error page.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# Convert to RGB *before* resizing: Pillow always uses NEAREST when resizing
# mode "P"/"1" images, so resizing first degrades palette inputs.
image = Image.open(BytesIO(response.content)).convert('RGB').resize((960, 960))
image_array = np.expand_dims(np.transpose(np.array(image).astype(np.float32), (2, 0, 1)), axis=0) / 255.
# ---- Text prompt ----------------------------------------------------------------
# Chat-formatted user turn with an empty <|vision_start|><|vision_end|> span;
# the image embeddings are presumably merged into that span by graph D later.
prompt = "Describe this image."
formatted_prompt = f"\n<|im_start|>user\n<|vision_start|><|vision_end|>{prompt}<|im_end|>\n<|im_start|>assistant\n"
encoded = tokenizer(formatted_prompt, return_tensors='pt')
input_ids = encoded['input_ids']
prompt_len = input_ids.shape[1]
# Fixed-size int32 token buffer with max_length slots: prompt ids first,
# zeros after.
tokens = np.zeros(max_length, dtype=np.int32)
tokens[:prompt_len] = input_ids[0, :]
# Current sequence length (int64, shape (1,)) and a decode position starting at 0.
input_lengths = np.array([prompt_len], dtype=np.int64)
position = np.zeros(1, dtype=np.int64)
# Initialize caches
# Zero-filled float16 key cache shaped
# (num_layers, num_key_value_heads, max_length, head_dim); the value cache is an
# independent copy with the same shape.
key_cache = np.zeros((num_layers, num_key_value_heads, max_length, head_dim), dtype=np.float16)
value_cache = key_cache.copy()
# Process initial inputs
# Graph B maps the padded token buffer plus its valid length to hidden states
# (presumably token embeddings — confirm against the export script).
hidden_states = sessions['B'].run(
[outputs['B']],
{inputs['B'][0]: tokens, inputs['B'][1]: input_lengths}
)[0]
# Graph C is fed a scalar int32 0 and its single output is kept.
# NOTE(review): despite the name, `batch_size` is later overwritten by graph D's
# second output and passed to graph E; it presumably holds position ids — verify.
batch_size = np.array(0, dtype=np.int32)
batch_size, = sessions['C'].run([outputs['C']], {inputs['C']: batch_size})
# Process image features
# Graph A consumes the (1, 3, 960, 960) image tensor.
image_features = sessions['A'].run([outputs['A']], {inputs['A']: image_array})[0]
# Number of sequence positions reserved for the image; presumably the row count
# of graph A's output (a 10 x 10 grid) — confirm against the export script.
total_ids = 100  # 10 * 10 from original factors
# From here on, the sequence length counts the image positions too.
input_lengths += total_ids
# NOTE(review): input_lengths already includes total_ids at this point, so
# total_ids is subtracted twice here. This may be exactly what graph D expects
# as a split size — confirm before changing.
remaining_tokens = np.array(max_length - input_lengths[0] - total_ids, dtype=np.int32)
# NOTE(review): 5 is possibly the number of prompt tokens before
# <|vision_end|> ("\n", "<|im_start|>", "user", "\n", "<|vision_start|>") — confirm.
tokens_to_stop = np.array(input_lengths[0] - 5, dtype=np.int32)
# Graph D: values are paired with D's declared inputs positionally (zip stops
# silently if the counts differ). Exactly two outputs are unpacked; they
# replace hidden_states and batch_size.
hidden_states, batch_size = sessions['D'].run(
outputs['D'],
dict(zip(inputs['D'],
[hidden_states, image_features, input_lengths, tokens_to_stop, remaining_tokens]))
)
# ---- Autoregressive decoding ----------------------------------------------------
# Each step runs graph E on the current hidden states and KV caches, reads back
# one token, re-embeds it with graph B and streams its text to stdout.
max_new_tokens = 12  # generation budget (previously hard-coded as range(12))
stop_token_ids = (151643, 151645)  # token ids that end generation
start_time = time.time()  # note: times the decode loop only
for i in range(max_new_tokens):
    # Stop once `position` reaches the max_length slots allocated for the KV
    # caches. NOTE(review): assumes `position` is graph E's cache write offset,
    # which is how it is advanced below — confirm against the export.
    if position[0] >= max_length:
        break
    is_first_step = i == 0
    token, key_cache, value_cache = sessions['E'].run(
        outputs['E'],
        dict(zip(inputs['E'], [
            hidden_states,
            # -65504 (float16 minimum) on the first step, 0 afterwards;
            # presumably an attention-mask bias — confirm.
            np.array([-65504. if is_first_step else 0.], dtype=np.float16),
            key_cache,
            value_cache,
            position,
            input_lengths,
            batch_size,
            # NOTE(review): first-step value 1 - total_ids + 10 kept verbatim
            # (10 is presumably a grid width factor); later steps use position + 1.
            np.array([1 - total_ids + 10 if is_first_step else position[0] + 1], dtype=np.float16),
        ])),
    )
    # ONNX Runtime returns arrays. Extract a plain int so the stop check, the
    # buffer write and decoding do not rely on NumPy's deprecated implicit
    # size-1-array -> scalar conversion.
    token_id = int(np.asarray(token).item())
    if token_id in stop_token_ids:
        break
    if is_first_step:
        # After the first step the whole prompt (text + image positions) has
        # been consumed; presumably one token per step is fed from now on.
        position += input_lengths[0]
        input_lengths[0] = 1
    else:
        position += 1
    tokens[0] = token_id
    hidden_states = sessions['B'].run(
        [outputs['B']],
        {inputs['B'][0]: tokens, inputs['B'][1]: input_lengths},
    )[0]
    print(tokenizer.decode(token_id), end='', flush=True)
print(f"\nTotal time: {time.time() - start_time:.2f}s")