# Spaces: Runtime error  (page-status banner captured along with this file; not part of the program)
#!/usr/bin/env python
# coding: utf-8
# In[1]:
# import required libraries
import functools

import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
# In[2]:
# inference pipeline: answers a question about an uploaded chart image
# Hugging Face checkpoint used for both the processor and the model.
_MODEL_NAME = "google/deplot"
# Marker the decoded text is split on to obtain individual rows.
_ROW_SEPARATOR = "<0x0A>"


@functools.lru_cache(maxsize=1)
def _load_deplot():
    """Load the processor and model once and reuse them for every request."""
    processor = Pix2StructProcessor.from_pretrained(_MODEL_NAME)
    model = Pix2StructForConditionalGeneration.from_pretrained(_MODEL_NAME)
    return processor, model


def _split_rows(decoded):
    """Split decoded text into rows; rows containing "|" become lists of cells."""
    rows = []
    for line in decoded.split(_ROW_SEPARATOR):
        line = line.strip()
        if "|" in line:
            rows.append([cell.strip() for cell in line.split("|")])
        else:
            # Rows without a cell delimiter are kept as plain strings.
            rows.append(line)
    return rows


def query(image, user_question):
    """Run the chart model on an image and a question and parse its answer.

    Args:
        image: single image or batch of images. Only the first generated
            sequence is decoded, so only one answer is returned.
        user_question: user prompt question passed to the processor as text.

    Returns:
        A list with one entry per "<0x0A>"-separated row of the decoded
        output. A row containing "|" is a list of stripped cell strings;
        any other row is a single stripped string.
    """
    # Instantiating the processor and model is the expensive part, so it is
    # done once by the cached loader instead of on every request.
    processor, model = _load_deplot()
    # Process the inputs for prediction.
    inputs = processor(images=image, text=user_question, return_tensors="pt")
    predictions = model.generate(**inputs, max_new_tokens=512)
    decoded = processor.decode(predictions[0], skip_special_tokens=True)
    # Convert the flat decoded text into the row structure returned to the UI.
    return _split_rows(decoded)
# In[ ]:
# Gradio interface definition: configures the input/output page
# Input widgets: the chart image (requested as a PIL image) and the question text.
chart_input = gr.Image(label="Upload Here", type="pil")
question_input = gr.Textbox(label="Question?")

# Bundled demo inputs as [image path, question] pairs.
# A third sample is currently disabled:
#   ["./samples/sample3.webp", "What are the 2020 net sales?"]
sample_queries = [
    ["./samples/sample1.png", "Generate underlying data table of the figure"],
    ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
]

# Wire the widgets to `query`; flagging is turned off and example caching is on.
ui = gr.Interface(
    fn=query,
    inputs=[chart_input, question_input],
    outputs="list",
    title="Chart Q/A",
    examples=sample_queries,
    cache_examples=True,
    allow_flagging="never",
)

# Enable the request queue (with its API open), then start the server.
ui.queue(api_open=True)
ui.launch(inline=False, share=False, debug=True)