import numpy as np
import torch
from torch import nn
import streamlit as st
import os
from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids


def run_prediction(sample, model, processor, mode, device=None):
    """Run the document model on one image and return the decoded output.

    Args:
        sample: Image to parse. It is converted to a float32 NumPy array
            before being passed to ``processor``.
        model: Encoder-decoder model whose ``generate`` produces the output.
        processor: Processor paired with ``model``. It supplies
            ``pixel_values``, the tokenizer for the decoder prompt, and
            ``batch_decode``.
        mode: Task selector. ``"OCR"`` picks the OCR prompt; any other value
            picks the element-annotation prompt.
        device: Device to run generation on. When omitted, the device that
            holds ``model``'s parameters is used.

    Returns:
        The first sequence produced by ``processor.batch_decode(outputs)``.
    """
    if device is None:
        # Previously this silently relied on a module-level ``device`` global
        # that only existed if the caller had defined it first.
        device = next(model.parameters()).device

    skip_tokens = get_table_token_ids(processor)
    no_repeat_ngram_size = 10

    # NOTE(review): both branches currently use an empty prompt, so ``mode``
    # has no effect on the output. Confirm whether task-specific prompt
    # tokens were intended here.
    if mode == "OCR":
        prompt = ""
    else:
        prompt = ""

    print("prompt:", prompt)
    print("no_repeat_ngram_size:", no_repeat_ngram_size)

    pixel_values = processor(
        np.array(sample, np.float32),
        return_tensors="pt",
    ).pixel_values
    decoder_input_ids = processor.tokenizer(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            # Repeated n-grams are blocked by the project's own processor, which
            # receives ``skip_tokens``. The built-in check is turned off below
            # (``no_repeat_ngram_size=0``) so it is not applied a second time.
            logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
            # Sampling combined with num_beams > 1 (beam-sample decoding), so
            # repeated runs on the same document can give different output.
            do_sample=True,
            top_p=0.92,
            top_k=5,
            no_repeat_ngram_size=0,
            num_beams=3,
            output_attentions=False,
            output_hidden_states=False,
        )

    prediction = processor.batch_decode(outputs)[0]
    print(prediction)
    return prediction


def _hf_token():
    """Return the Hugging Face token from ``HF_TOKEN``, or None when it is unset.

    Returning None (instead of raising KeyError) lets ``from_pretrained`` and
    ``hf_hub_download`` fall back to their default credential lookup.
    """
    return os.environ.get('HF_TOKEN')


def _load_processor(pre_trained_model):
    """Return the ``DonutProcessor`` for ``pre_trained_model``, cached per session.

    Streamlit re-executes the script on every interaction. Caching avoids
    reloading the processor each time, mirroring how the model is cached.
    """
    if 'processor' not in st.session_state:
        st.session_state['processor'] = DonutProcessor.from_pretrained(
            pre_trained_model, token=_hf_token()
        )
    return st.session_state['processor']


def _load_model(pre_trained_model, device):
    """Return the model with its replacement LM head on ``device``, cached per session.

    The checkpoint's ``decoder.lm_head`` is replaced by a three-layer stack of
    linear layers (in -> rank -> rank -> out). The weights for that stack are
    loaded from ``lm_head.pth`` in the same repository.
    """
    if 'model' in st.session_state:
        return st.session_state['model']

    from huggingface_hub import hf_hub_download

    token = _hf_token()
    model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model, token=token)
    lm_head_file = hf_hub_download(
        repo_id=pre_trained_model,
        filename="lm_head.pth",
        token=token,
    )

    rank = 128
    # nn.Linear weights are laid out (out_features, in_features). The new head
    # keeps the original head's input and output sizes around the rank-wide
    # intermediate layers.
    in_features = model.decoder.lm_head.weight.shape[1]
    out_features = model.decoder.lm_head.weight.shape[0]
    model.decoder.lm_head = nn.Sequential(
        nn.Linear(in_features, rank, bias=False),
        nn.Linear(rank, rank, bias=False),
        nn.Linear(rank, out_features, bias=True),
    )
    # map_location="cpu": a state dict saved from GPU tensors would otherwise
    # fail to load on CPU-only hosts. model.to(device) below moves it afterwards.
    # NOTE(review): torch.load unpickles the file. It is only as trustworthy as
    # the model repository it was downloaded from.
    model.decoder.lm_head.load_state_dict(torch.load(lm_head_file, map_location="cpu"))
    model.eval()
    model.to(device)
    st.session_state['model'] = model
    return model


logo = Image.open("./rsz_unstructured_logo.png")
st.image(logo)

st.markdown('''
### Chipper

Chipper is an OCR-free Document Understanding Transformer. It was pre-trained with over 1M documents from public sources and fine-tuned on a large range of documents.

At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, or production ML-ready pipelines.
Come and join us in our public repos and contribute! Each of your contributions and feedback holds great value and is very significant to the community.
''')

image_upload = None
with st.sidebar:
    # file upload
    uploaded_file = st.file_uploader("Upload a document")
    if uploaded_file is not None:
        try:
            image_upload = Image.open(BytesIO(uploaded_file.getvalue()))
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for data
            # it cannot decode. Report it rather than crashing the whole app.
            st.error("Could not read the uploaded file as an image; showing the sample document instead.")
    mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)

if image_upload is not None:
    image = image_upload
else:
    image = Image.open("./document.png")

st.image(image, caption='Your target document')

with st.spinner('Processing the document ...'):
    pre_trained_model = "unstructuredio/chipper-fast-fine-tuning"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = _load_processor(pre_trained_model)
    model = _load_model(pre_trained_model, device)

    st.info('Parsing document')
    parsed_info = run_prediction(image.convert("RGB"), model, processor, mode, device=device)

st.text('\nDocument:')
st.text_area('Output text', value=parsed_info, height=800)