test / app.py
th0mascat's picture
Update app.py
dc52995 verified
raw
history blame
3.28 kB
import io
import logging

import cv2
import numpy as np
from PIL import Image
import pytesseract
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder
from mltu.transformers import ImageResizer
from mltu.configs import BaseModelConfigs
from textblob import TextBlob
from happytransformer import HappyTextToText, TTSettings
from transformers import AutoTokenizer, T5ForConditionalGeneration
from pydantic import BaseModel
# CoEdIT (T5) text-editing model and its tokenizer, used by /chat_prompt/.
_COEDIT_CHECKPOINT = "grammarly/coedit-large"
tokenizer = AutoTokenizer.from_pretrained(_COEDIT_CHECKPOINT)
chatModel = T5ForConditionalGeneration.from_pretrained(_COEDIT_CHECKPOINT)

# Handwriting-model configuration; `model_path` and `vocab` are read from it
# when the ImageToWordModel instance is built.
configs = BaseModelConfigs.load("./configs.yaml")

# Beam-search settings for the HappyTextToText grammar corrector
# ("vennify/t5-base-grammar-correction"). That corrector is commented out, so
# these settings are currently only referenced from disabled code.
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests — confirm this is intended.
origins = ["*"]
_cors_settings = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
class ImageToWordModel(OnnxInferenceModel):
    """ONNX recognizer that turns a single image into a text string.

    Network outputs are decoded with ``ctc_decoder`` against ``char_list``.
    ``self.model``, ``self.input_shape`` and ``self.input_name`` come from
    ``OnnxInferenceModel`` (defined outside this file).
    """

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alphabet the output indices are mapped back onto during decoding.
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Return the text recognised in ``image``.

        Args:
            image: a decoded image array (this file passes the result of
                ``cv2.imdecode``).

        Returns:
            The first string produced by ``ctc_decoder``.
        """
        # The first two input dims are presumably (height, width); the resizer
        # takes them as (width, height) — TODO confirm against the model.
        height, width = self.input_shape[:2]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, width, height)

        # Add a leading batch axis of size 1 and cast to the network's float32 input.
        batch = resized[np.newaxis, ...].astype(np.float32)
        outputs = self.model.run(None, {self.input_name: batch})

        decoded = ctc_decoder(outputs[0], self.char_list)
        return decoded[0]
# Handwriting recognizer shared by every request, built from configs.yaml.
model = ImageToWordModel(char_list=configs.vocab, model_path=configs.model_path)

# Most recent OCR result. Both extraction endpoints overwrite it and
# /chat_prompt/ reads it. It is module-global, so all clients share one value.
extracted_text: str = ""
@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognise handwritten text in an uploaded image with the ONNX model.

    On success the result is also stored in the module-level
    ``extracted_text`` so a later ``/chat_prompt/`` call can edit it.

    Returns:
        ``{"text": <recognised text>}`` on success, or
        ``{"error": ...}`` when the upload is empty or cannot be decoded as an
        image (``extracted_text`` is left unchanged in that case).
    """
    global extracted_text
    image_bytes = await image.read()
    # cv2.imdecode raises on an empty buffer and returns None for undecodable
    # data; both cases are rejected here instead of failing inside predict().
    img = (
        cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
        if image_bytes
        else None
    )
    if img is None:
        return {"error": "Invalid file format. Please upload an image."}
    extracted_text = model.predict(img)
    return {"text": extracted_text}
@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Run Tesseract OCR over an uploaded image.

    On success the result is also stored in the module-level
    ``extracted_text`` so a later ``/chat_prompt/`` call can edit it.

    Returns:
        ``{"text": <OCR text>}`` on success, or ``{"error": ...}`` when the
        upload is not declared as an image or cannot be decoded as one
        (``extracted_text`` is left unchanged in that case).
    """
    global extracted_text
    # content_type may be None when the client sends no Content-Type for the
    # part; treat that the same as a non-image type instead of crashing.
    content_type = image.content_type or ""
    if not content_type.startswith("image/"):
        return {"error": "Invalid file format. Please upload an image."}

    image_bytes = await image.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so corrupt data is caught here.
        img.load()
    except OSError:  # includes PIL.UnidentifiedImageError
        return {"error": "Invalid file format. Please upload an image."}

    extracted_text = pytesseract.image_to_string(img)
    return {"text": extracted_text}
class ChatPrompt(BaseModel):
    """Request body for ``/chat_prompt/``.

    ``prompt`` is prepended to the last extracted text before that text is
    sent to the editing model.
    """

    prompt: str
@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    """Edit the most recently extracted text with the CoEdIT T5 model.

    The model input is ``"<request.prompt>: <extracted_text>"``, where
    ``extracted_text`` is the module-level result of the last OCR call.

    Returns:
        ``{"edited_text": <decoded model output>}``.
    """
    # extracted_text is only read here, so no `global` declaration is needed.
    input_text = f"{request.prompt}: {extracted_text}"
    # Debug-level logging instead of print(): keeps user text off stdout by default.
    logging.getLogger(__name__).debug("chat_prompt input: %s", input_text)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # NOTE(review): generate() is a synchronous, compute-heavy call made inside
    # an async endpoint, so it blocks the event loop until it finishes.
    outputs = chatModel.generate(input_ids, max_length=256)
    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"edited_text": edited_text}