Spaces:
Build error
Build error
File size: 6,833 Bytes
b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 8962d34 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 ffb7b59 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 b1f218b 24a2388 02634b8 b1f218b 02634b8 24a2388 02634b8 b1f218b 02634b8 24a2388 b1f218b 02634b8 24a2388 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
from ram import get_transform, inference_ram, inference_tag2text
from ram.models import ram, tag2text_caption
# Paths to the pretrained checkpoint files, relative to the working directory.
ram_checkpoint = "./ram_swin_large_14m.pth"
tag2text_checkpoint = "./tag2text_swin_14m.pth"
# Image size handed to get_transform() and to both model constructors.
image_size = 384
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def _clean_tag_string(text):
    """Trim surrounding spaces from *text* and collapse double spaces to one."""
    # The previous code replaced ' ' with ' ', which is a no-op; the intent is
    # to squeeze the doubled spaces that can appear in the model's tag string.
    return text.strip(' ').replace('  ', ' ')


@torch.no_grad()
def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
    """Tag a single image with either the RAM or the Tag2Text model.

    Args:
        raw_image: Input image. It is passed to ``transform`` and its ``size``
            attribute is logged.
        specified_tags: User-specified tags forwarded to ``inference_tag2text``.
            Ignored when ``tagging_model_type`` is ``"RAM"``.
        tagging_model_type: ``"RAM"`` selects the RAM path; any other value
            selects the Tag2Text path.
        tagging_model: The loaded model handed to the matching inference helper.
        transform: Callable that converts ``raw_image`` into a tensor; a batch
            dimension is added here before moving it to ``device``.

    Returns:
        ``(tags, tags_chinese)`` for RAM, otherwise ``(tags, caption)``.
    """
    print(f"Start processing, image size {raw_image.size}")

    # Add the batch dimension and move the tensor to the module-level device.
    image = transform(raw_image).unsqueeze(0).to(device)

    if tagging_model_type == "RAM":
        res = inference_ram(image, tagging_model)
        tags = _clean_tag_string(res[0])
        tags_chinese = _clean_tag_string(res[1])
        print("Tags: ", tags)
        print("标签: ", tags_chinese)
        return tags, tags_chinese

    # Any other model type is handled as Tag2Text.
    res = inference_tag2text(image, tagging_model, specified_tags)
    tags = _clean_tag_string(res[0])
    caption = res[2]  # index 1 of the result is not used by this demo
    print(f"Tags: {tags}")
    print(f"Caption: {caption}")
    return tags, caption
def inference_with_ram(img):
    """Run the RAM tagging path on ``img`` and return whatever ``inference`` returns.

    Uses the module-level ``ram_model`` and ``transform``, which are assigned
    in the ``__main__`` section of this script.
    """
    return inference(
        raw_image=img,
        specified_tags=None,  # RAM takes no user-specified tags
        tagging_model_type="RAM",
        tagging_model=ram_model,
        transform=transform,
    )
def inference_with_t2t(img, input_tags):
    """Run the Tag2Text path on ``img`` with the optional ``input_tags``.

    Uses the module-level ``tag2text_model`` and ``transform``, which are
    assigned in the ``__main__`` section of this script. Returns whatever
    ``inference`` returns.
    """
    return inference(
        raw_image=img,
        specified_tags=input_tags,
        tagging_model_type="Tag2Text",
        tagging_model=tag2text_model,
        transform=transform,
    )
if __name__ == "__main__":
    import gradio as gr

    # Build the preprocessing transform and load both models once at startup.
    # These stay module-level names because inference_with_ram() and
    # inference_with_t2t() read them as globals.
    transform = get_transform(image_size=image_size)
    ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
    tag2text_model = tag2text_caption(
        pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)

    def build_gui():
        """Assemble the two-tab Gradio demo and return the ``gr.Blocks`` app.

        One tab drives ``inference_with_ram`` and the other drives
        ``inference_with_t2t``. Each tab has an image input, a Run button, an
        optional Clear button, two output text boxes and cached examples.
        """
        # The HTML strings are kept flush-left so their content is unchanged.
        description = """
<center><strong><font size='10'>Recognize Anything Model</font></strong></center>
<br>
<p>Welcome to the <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything Model</a> / <a href='https://tag2text.github.io/Tag2Text' target='_blank'>Tag2Text Model</a> demo!</p>
<li>
<b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
</li>
<li>
<b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
</li>
<p><b>More over:</b> Combine with <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-SAM</a>, you can get <b>boxes and masks</b>! Please run <a href='https://github.com/xinyu1205/recognize-anything/blob/main/gui_demo.ipynb' target='_blank'>this notebook</a> to try out!</p>
<p>Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!</p>
"""  # noqa
        article = """
<p style='text-align: center'>
RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
<a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
|
<a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
</p>
"""  # noqa

        def _make_clear_button():
            """Return a ``gr.ClearButton``, or ``None`` when gradio lacks it."""
            try:
                return gr.ClearButton()
            except AttributeError:  # old gradio does not have ClearButton, not a big problem
                return None

        with gr.Blocks(title="Recognize Anything Model") as demo:
            ###############
            # components
            ###############
            gr.HTML(description)

            with gr.Tab(label="Recognize Anything Model"):
                with gr.Row():
                    with gr.Column():
                        ram_in_img = gr.Image(type="pil")
                        with gr.Row():
                            ram_btn_run = gr.Button(value="Run")
                            ram_btn_clear = _make_clear_button()
                    with gr.Column():
                        ram_out_tag = gr.Textbox(label="Tags")
                        ram_out_biaoqian = gr.Textbox(label="标签")
                gr.Examples(
                    examples=[
                        ["images/demo1.jpg"],
                        ["images/demo2.jpg"],
                        ["images/demo4.jpg"],
                    ],
                    fn=inference_with_ram,
                    inputs=[ram_in_img],
                    outputs=[ram_out_tag, ram_out_biaoqian],
                    cache_examples=True,
                )

            with gr.Tab(label="Tag2Text Model"):
                with gr.Row():
                    with gr.Column():
                        t2t_in_img = gr.Image(type="pil")
                        t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
                        with gr.Row():
                            t2t_btn_run = gr.Button(value="Run")
                            t2t_btn_clear = _make_clear_button()
                    with gr.Column():
                        t2t_out_tag = gr.Textbox(label="Tags")
                        t2t_out_cap = gr.Textbox(label="Caption")
                gr.Examples(
                    examples=[
                        ["images/demo4.jpg", ""],
                        ["images/demo4.jpg", "power line"],
                        ["images/demo4.jpg", "track, train"],
                    ],
                    fn=inference_with_t2t,
                    inputs=[t2t_in_img, t2t_in_tag],
                    outputs=[t2t_out_tag, t2t_out_cap],
                    cache_examples=True,
                )

            gr.HTML(article)

            ###############
            # events
            ###############
            # run inference
            ram_btn_run.click(
                fn=inference_with_ram,
                inputs=[ram_in_img],
                outputs=[ram_out_tag, ram_out_biaoqian],
            )
            t2t_btn_run.click(
                fn=inference_with_t2t,
                inputs=[t2t_in_img, t2t_in_tag],
                outputs=[t2t_out_tag, t2t_out_cap],
            )

            # clear: register the components only when a ClearButton exists
            if ram_btn_clear is not None:
                ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
            if t2t_btn_clear is not None:
                t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])

        return demo

    # ``launch(enable_queue=True)`` was deprecated in gradio 3.x and the keyword
    # was removed in gradio 4.0, where it raises TypeError. ``queue()`` is the
    # supported way to enable request queuing and works on both.
    build_gui().queue().launch()
|