# Hugging Face Space status when this file was captured: Running on Zero
import os
import subprocess
import sys
from pathlib import Path
from threading import Thread

import gradio as gr
import spaces
import torch
from gradio_pdf import PDF
from pdf2image import convert_from_path
from transformers import (
    NougatProcessor,
    VisionEncoderDecoderModel,
    TextIteratorStreamer,
)
# Install flash-attn at startup, before the models that request
# `flash_attention_2` are loaded below.
#
# The pip process gets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE *in addition to* the
# current environment. Passing a bare dict as `env` would replace the whole
# environment (dropping PATH, HOME, proxy and virtualenv variables).
# `sys.executable -m pip` targets the interpreter running this app, and the
# argument list avoids going through a shell.
# The install stays best-effort: a non-zero exit status does not raise.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Registry of the selectable models: UI name -> [processor, model].
# Every entry is loaded eagerly when this module is imported.
#   * arabic-small-nougat is loaded with the library defaults.
#   * arabic-base-nougat and arabic-large-nougat are loaded in bfloat16, with
#     flash_attention_2 for the decoder and eager attention for the encoder.
models_supported = {
    "arabic-small-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat"),
        VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat"),
    ],
    "arabic-base-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-base-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-base-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
    "arabic-large-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-large-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
}
# NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
# decorator is visible on this function — confirm whether one is required on
# the hardware this app is deployed to.
def extract_text_from_image(image, model_name):
    """Stream the text generated for a single page image.

    Args:
        image: The page image, passed as-is to the model's processor.
        model_name: Key into ``models_supported`` selecting the
            processor/model pair to use.

    Yields:
        The accumulated output text so far; every yield contains everything
        generated up to that point, not just the newest chunk.

    Raises:
        gr.Error: If no image or no known model is given, or if generation
            fails in the worker thread.
    """
    # Fail with a readable message instead of a KeyError / processor crash
    # (the model dropdown starts with no selection).
    if image is None:
        raise gr.Error("Please provide an image.")
    if model_name not in models_supported:
        raise gr.Error("Please select a model.")

    print(f"Extracting text from image using model: {model_name}")
    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Inputs must match the model's dtype (bfloat16 for some entries) and device.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(
        device=device, dtype=model.dtype
    )
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    generation_kwargs = {
        "pixel_values": pixel_values,
        "min_length": 1,
        "max_new_tokens": context_length,
        "repetition_penalty": 1.5,
        "streamer": streamer,
    }
    failures = []

    def _generate():
        # Runs in the worker thread. If generation raises, the streamer would
        # never receive its end-of-stream signal and the consumer loop below
        # would block forever, so close the stream explicitly and keep the
        # exception for the caller's thread.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised below in the caller's thread
            failures.append(exc)
            streamer.end()

    # Generate in the background so text can be yielded while it is produced.
    thread = Thread(target=_generate)
    thread.start()

    output = ""
    for chunk in streamer:
        output += chunk
        yield output
    thread.join()

    if failures:
        raise gr.Error(f"Text extraction failed: {failures[0]}") from failures[0]
# NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
# decorator is visible on this function — confirm whether one is required on
# the hardware this app is deployed to.
def extract_text_from_pdf(pdf_path, model_name):
    """Stream the text generated for every page of a PDF.

    The PDF is rasterised with ``convert_from_path`` and each page image is run
    through the selected model in turn. Pages are separated by a blank line in
    the output.

    Args:
        pdf_path: Path of the PDF file to convert.
        model_name: Key into ``models_supported`` selecting the
            processor/model pair to use.

    Yields:
        The accumulated text of the whole document so far; every yield
        contains everything generated up to that point.

    Raises:
        gr.Error: If no PDF or no known model is given, or if generation
            fails in the worker thread.
    """
    # Fail with a readable message instead of a KeyError / converter crash
    # (the model dropdown starts with no selection).
    if not pdf_path:
        raise gr.Error("Please provide a PDF file.")
    if model_name not in models_supported:
        raise gr.Error("Please select a model.")

    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    print(f"Extracting text from PDF: {pdf_path}")
    images = convert_from_path(pdf_path)

    pdf_output = ""
    # NOTE(review): the original indentation was lost; the page separator and
    # the yield after it are assumed to belong inside the per-page loop.
    for image in images:
        # Inputs must match the model's dtype (bfloat16 for some entries) and device.
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(
            device=device, dtype=model.dtype
        )
        # One streamer per page, so no queue state is carried between pages.
        streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
        generation_kwargs = {
            "pixel_values": pixel_values,
            "min_length": 1,
            "max_new_tokens": context_length,
            "repetition_penalty": 1.5,
            "streamer": streamer,
        }
        failures = []

        def _generate(kwargs=generation_kwargs, stream=streamer, errors=failures):
            # Runs in the worker thread. Arguments are bound as defaults so the
            # closure uses this page's objects. If generation raises, close the
            # stream so the consumer loop below cannot block forever.
            try:
                model.generate(**kwargs)
            except Exception as exc:  # re-raised below in the caller's thread
                errors.append(exc)
                stream.end()

        # Generate in the background so text can be yielded while it is produced.
        thread = Thread(target=_generate)
        thread.start()

        for chunk in streamer:
            pdf_output += chunk
            yield pdf_output
        thread.join()

        if failures:
            raise gr.Error(f"Text extraction failed: {failures[0]}") from failures[0]

        pdf_output += "\n\n"
        yield pdf_output
# Markdown shown at the top of the demo page (rendered with gr.Markdown).
model_description = """This is the official demo for the Arabic Nougat models. It is an end-to-end Markdown Extraction model that extracts text from images or PDFs and write them in Markdown.
There are three models available:
- [arabic-small-nougat](https://huggingface.co/MohamedRashad/arabic-small-nougat): A small model that is faster but less accurate (a finetune from [facebook/nougat-small](https://huggingface.co/facebook/nougat-small)).
- [arabic-base-nougat](https://huggingface.co/MohamedRashad/arabic-base-nougat): A base model that is more accurate but slower (a finetune from [facebook/nougat-base](https://huggingface.co/facebook/nougat-base)).
- [arabic-large-nougat](https://huggingface.co/MohamedRashad/arabic-large-nougat): The largest of the three (Made from scratch using [riotu-lab/Aranizer-PBE-86k](https://huggingface.co/riotu-lab/Aranizer-PBE-86k) tokenizer and a larger transformer decoder model).
**Disclaimer**: These models hallucinate text and are not perfect. They are trained on a mix of synthetic and real data and may not work well on all types of images.
"""
# Example inputs offered in the UI: every *.jpeg / *.pdf file sitting next to
# this script. Both lists hold plain string paths (previously the image list
# held Path objects while the PDF list held strings) and are sorted so the
# examples appear in a stable order regardless of directory listing order.
_APP_DIR = Path(__file__).parent
example_images = sorted(str(p) for p in _APP_DIR.glob("*.jpeg"))
example_pdfs = sorted(str(p) for p in _APP_DIR.glob("*.pdf"))
# Gradio UI: a page header, the model description, and two tabs (image / PDF).
# Each tab has an input, a model dropdown, a submit button and a right-to-left
# Markdown output that is updated as the extraction generator yields text.
#
# NOTE(review): the original indentation was lost; the nesting below (inputs in
# a Column, output beside it in the same Row, event wiring and examples at tab
# level) is a reconstruction — confirm against the deployed layout.
with gr.Blocks(title="Arabic Nougat") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>Arabic End-to-End Structured OCR for textbooks</h1>"
    )
    gr.Markdown(model_description)

    with gr.Tab("Extract Text from Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                # No model is preselected; the user has to pick one.
                model_dropdown = gr.Dropdown(
                    label="Model", choices=list(models_supported), value=None
                )
                image_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        image_submit_button.click(
            extract_text_from_image,
            inputs=[input_image, model_dropdown],
            outputs=output,
        )
        # Examples only fill in the image input; results are not cached.
        gr.Examples(
            examples=example_images,
            inputs=[input_image],
            outputs=output,
            fn=extract_text_from_image,
            cache_examples=False,
        )

    with gr.Tab("Extract Text from PDF"):
        with gr.Row():
            with gr.Column():
                input_pdf = PDF(label="Input PDF")
                # This tab has its own dropdown and output; the names are
                # deliberately rebound, as in the original.
                model_dropdown = gr.Dropdown(
                    label="Model", choices=list(models_supported), value=None
                )
                pdf_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        pdf_submit_button.click(
            extract_text_from_pdf,
            inputs=[input_pdf, model_dropdown],
            outputs=output,
        )
        # Examples only fill in the PDF input; results are not cached.
        gr.Examples(
            examples=example_pdfs,
            inputs=[input_pdf],
            outputs=output,
            fn=extract_text_from_pdf,
            cache_examples=False,
        )

# Enable the request queue, then start the server without a public share link.
demo.queue().launch(share=False)