|
--- |
|
pipeline_tag: text-generation |
|
--- |
|
# ChemVLM-8B: the 8B-parameter version of ChemVLM
|
## Citation |
|
Paper: https://arxiv.org/abs/2408.07246
|
|
|
``` |
|
@misc{li2024chemvlmexploringpowermultimodal, |
|
title={ChemVLM: Exploring the Power of Multimodal Large Language Models in Chemistry Area}, |
|
author={Junxian Li and Di Zhang and Xunzhi Wang and Zeying Hao and Jingdi Lei and Qian Tan and Cai Zhou and Wei Liu and Yaotian Yang and Xinrui Xiong and Weiyun Wang and Zhe Chen and Wenhai Wang and Wei Li and Shufei Zhang and Mao Su and Wanli Ouyang and Yuqiang Li and Dongzhan Zhou}, |
|
year={2024}, |
|
eprint={2408.07246}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.LG}, |
|
url={https://arxiv.org/abs/2408.07246}, |
|
} |
|
``` |
|
|
|
Codebase and datasets can be found at https://github.com/AI4Chem/ChemVlm. |
|
|
|
### Performance of ChemVLM-8B on several tasks
|
|
|
| Datasets | MMChemOCR | CMMU | MMCR-bench | Reaction type | |
|
| :----- | :----- | :----- |:----- |:----- | |
|
| Metrics | Tanimoto similarity / tanimoto@1.0 | Score (%, judged by GPT-4o) | Score (%, judged by GPT-4o) | Accuracy (%) |
|
| ChemVLM-8B scores | 81.75 / 57.69 | 52.7 (SOTA) | 33.6 | 16.79 |
|
|
|
|
|
Quick start (requires `transformers>=4.37.0`):
|
Update: you may also need to install the following packages:
|
``` |
|
# Extra runtime dependencies of the remote model code.
pip install sentencepiece
pip install einops
pip install timm
# Quote the version specifier: unquoted, the shell treats `>` as output
# redirection, installing an unpinned accelerate into a file named "=0.26.0".
pip install "accelerate>=0.26.0"
|
``` |
|
|
|
Code: |
|
```Python |
|
import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
# ImageNet per-channel (R, G, B) mean and standard deviation, used to
# normalise every image tile before it is fed to the model.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
|
|
|
def build_transform(input_size):
    """Return the preprocessing pipeline applied to every image tile.

    The pipeline runs these steps in order:

    1. Convert the PIL image to RGB if it is in another mode.
    2. Resize it to ``input_size`` x ``input_size`` with bicubic interpolation.
    3. Turn it into a tensor.
    4. Normalise it with the ImageNet mean/std constants.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the (columns, rows) grid whose shape best matches ``aspect_ratio``.

    Candidates are scanned in order. A candidate becomes the best when its
    ``columns / rows`` ratio is strictly closer to ``aspect_ratio`` than the
    current best.

    On an exact tie, the later candidate wins only if the source area
    ``width * height`` exceeds half the pixel area the candidate grid would
    cover (``0.5 * image_size**2 * columns * rows``). In other words, a bigger
    grid is only preferred for a sufficiently large image.

    Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate

    return chosen
|
|
|
|
|
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Tile an image into a grid of square crops whose layout fits its shape.

    Every (columns, rows) grid with ``min_num <= columns * rows <= max_num``
    is considered. ``find_closest_aspect_ratio`` picks the one whose
    columns/rows ratio best matches the image's width/height ratio.

    The image is resized to ``columns * image_size`` by ``rows * image_size``
    pixels and cut into ``columns * rows`` square tiles, listed row by row
    starting at the top-left.

    Args:
        image: Image object exposing ``size``, ``resize`` and ``crop``
            (a PIL image in this script).
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Side length, in pixels, of each square tile.
        use_thumbnail: When true and more than one tile was produced, also
            append a resize of the whole image to ``image_size`` x ``image_size``.

    Returns:
        List of tile images, with the optional thumbnail as the last item.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Every grid layout with an allowed tile count, ordered by tile count.
    layouts = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    layouts = sorted(layouts, key=lambda layout: layout[0] * layout[1])

    best_layout = find_closest_aspect_ratio(src_ratio, layouts, src_w, src_h, image_size)

    grid_w = image_size * best_layout[0]
    grid_h = image_size * best_layout[1]
    n_tiles = best_layout[0] * best_layout[1]
    per_row = grid_w // image_size

    # Stretch the image onto the chosen grid, then cut it tile by tile.
    canvas = image.resize((grid_w, grid_h))
    tiles = []
    for idx in range(n_tiles):
        row, col = divmod(idx, per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(canvas.crop(box))

    assert len(tiles) == n_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
|
|
|
|
|
def load_image(image_file, input_size=448, max_num=6):
    """Load an image file and turn it into a stacked batch of normalised tiles.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        input_size: Side length, in pixels, of each square tile.
        max_num: Maximum number of grid tiles. A thumbnail tile is appended
            on top of these whenever the image is split into more than one tile.

    Returns:
        A tensor stacking one transformed tile per entry. Given
        ``build_transform``, each entry is expected to be
        (3, input_size, input_size).
    """
    # Use a context manager so the file handle is released as soon as the
    # RGB copy has been made.
    with Image.open(image_file) as img:
        image = img.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    # Distinct loop name so the comprehension does not shadow `image`.
    pixel_values = [transform(tile) for tile in tiles]
    return torch.stack(pixel_values)
|
|
|
# Quick-start: describe the molecule shown in an image with ChemVLM-8B.
tokenizer = AutoTokenizer.from_pretrained('AI4Chem/ChemVLM-8B', trust_remote_code=True)

query = "Please describe the molecule in the image."
image_path = "your image path"  # replace with the path to your image
pixel_values = load_image(image_path, max_num=6).to(torch.bfloat16).cuda()

# Load the weights in bfloat16 and move them to the GPU. `.cuda()` already
# places the model on the device, so no separate `.to(device)` is needed.
model = AutoModelForCausalLM.from_pretrained(
    "AI4Chem/ChemVLM-8B",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval().cuda()

# `max_new_tokens` caps only the generated answer. Hugging Face's
# `max_length` would also count the (possibly long) multimodal prompt.
gen_kwargs = {"max_new_tokens": 1000, "do_sample": True, "temperature": 0.7, "top_p": 0.9}

response = model.chat(tokenizer, pixel_values, query, gen_kwargs)
print(response)
|
|
|
|
|
``` |
|
|
|
|