File size: 3,319 Bytes
7ff20b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a03eeb1
 
 
 
 
 
 
7ff20b3
 
3dfa0b7
7ff20b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dfa0b7
7ff20b3
3dfa0b7
7ff20b3
 
 
 
 
 
 
 
 
 
 
 
 
3dfa0b7
7ff20b3
3dfa0b7
7ff20b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import io
import logging

import cv2
import numpy as np
from PIL import Image

import pytesseract

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware

from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder
from mltu.transformers import ImageResizer
from mltu.configs import BaseModelConfigs

from textblob import TextBlob
from happytransformer import HappyTextToText, TTSettings

from transformers import AutoTokenizer, T5ForConditionalGeneration
from pydantic import BaseModel

# CoEdIT text-editing model ("grammarly/coedit-large") used by /chat_prompt/.
# Tokenizer and weights are loaded at import time with cache_dir="./cache",
# so importing this module triggers the (potentially slow) model load.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large", cache_dir="./cache")
chatModel = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large", cache_dir="./cache")

# Settings for the handwriting model; model_path and vocab are read when the
# module-level `model` recognizer is constructed.
configs = BaseModelConfigs.load("./configs.yaml")

# Disabled grammar-correction model; the matching generate_text calls in the
# OCR endpoints are commented out as well.
#happy_tt = HappyTextToText("T5", "vennify/t5-base-grammar-correction")

# Beam-search settings (5 beams, output length 1..100) intended for happy_tt.
# NOTE(review): only referenced from commented-out code, so currently unused.
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=100)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site make credentialed cross-origin requests (Starlette echoes the
# request Origin when cookies are sent) — verify, and restrict to the known
# frontend origin(s) before deploying publicly.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ImageToWordModel(OnnxInferenceModel):
    """Handwriting recognizer: an ONNX network followed by CTC decoding.

    Extends mltu's ``OnnxInferenceModel`` with the character vocabulary that
    ``ctc_decoder`` needs to turn the raw network output into a string.
    """

    def __init__(self, char_list, *args, **kwargs):
        """Set up the underlying ONNX model and keep the decoding vocabulary.

        Args:
            char_list: Vocabulary handed to ``ctc_decoder`` (this file passes
                ``configs.vocab``).
            *args: Forwarded unchanged to ``OnnxInferenceModel``.
            **kwargs: Forwarded unchanged to ``OnnxInferenceModel`` (this file
                passes ``model_path``).
        """
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Return the text recognised in a single image.

        The steps are:

        1. Resize the image to the model's input size, keeping its aspect
           ratio.
        2. Add a leading batch axis of length 1 and cast to float32.
        3. Run the ONNX session.
        4. CTC-decode the first output and return the first decoded entry.

        Args:
            image: Image array; in this file it is the result of
                ``cv2.imdecode``.

        Returns:
            The decoded text for the single batch item.
        """
        # The first two input dims are reversed before being passed as the
        # resize target — presumably (height, width) -> (width, height);
        # TODO confirm against mltu's input_shape convention.
        target_size = self.input_shape[:2][::-1]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, *target_size)

        batch = np.expand_dims(resized, axis=0).astype(np.float32)
        feed = {self.input_name: batch}
        first_output = self.model.run(None, feed)[0]

        decoded = ctc_decoder(first_output, self.char_list)
        return decoded[0]


# Handwriting recognizer used by /extract_handwritten_text/, built once at
# import time from the model path and vocabulary in configs.yaml.
model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)
# Text produced by the most recent OCR request (either OCR endpoint assigns
# it); /chat_prompt/ reads it. NOTE(review): this is a single module-level
# value shared by every client, so concurrent users overwrite each other's text.
extracted_text: str = ""

@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognise handwritten text in an uploaded image with the ONNX model.

    On success the recognised text is stored in the module-level
    ``extracted_text`` (read later by ``/chat_prompt/``) and returned.

    Args:
        image: Uploaded file whose bytes are decoded with OpenCV as a colour
            image.

    Returns:
        ``{"text": <recognised text>}`` on success, or ``{"error": ...}``
        (same shape as ``/extract_text/``) when the upload is empty or cannot
        be decoded as an image; ``extracted_text`` is left unchanged then.
    """
    global extracted_text

    raw_bytes = await image.read()

    # cv2.imdecode does not accept an empty buffer, and returns None for bytes
    # it cannot decode; in both cases there is nothing to feed the model, and
    # passing None on would crash inside model.predict with a 500.
    decoded = None
    if raw_bytes:
        nparr = np.frombuffer(raw_bytes, np.uint8)
        decoded = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if decoded is None:
        return {"error": "Invalid file format. Please upload an image."}

    extracted_text = model.predict(decoded)
    return {"text": extracted_text}


@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Run Tesseract OCR on an uploaded image.

    On success the OCR result is stored in the module-level
    ``extracted_text`` (read later by ``/chat_prompt/``) and returned.

    Args:
        image: Uploaded file. It must declare an ``image/*`` content type and
            be readable by Pillow.

    Returns:
        ``{"text": <OCR result>}`` on success, or ``{"error": ...}`` when the
        content type is missing or not an image, or the bytes cannot be
        decoded; ``extracted_text`` is left unchanged in the error cases.
    """
    global extracted_text

    # content_type is None when the client omits the header; treat that as a
    # non-image instead of crashing on None.startswith.
    if not (image.content_type or "").startswith("image/"):
        return {"error": "Invalid file format. Please upload an image."}

    image_bytes = await image.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; force the decode here so corrupt or truncated
        # data is reported as a bad upload instead of a 500 from inside
        # pytesseract.
        img.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        return {"error": "Invalid file format. Please upload an image."}

    extracted_text = pytesseract.image_to_string(img)
    return {"text": extracted_text}

# JSON request body for /chat_prompt/. (Kept as comments rather than a class
# docstring, which pydantic would copy into the OpenAPI schema.)
class ChatPrompt(BaseModel):
    # Instruction joined to the extracted text as "<prompt>: <text>" before it
    # is sent to the CoEdIT model — presumably an edit instruction such as
    # "Fix the grammar"; confirm with the frontend.
    prompt: str

@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    """Apply an instruction to the most recently extracted text with CoEdIT.

    The request's ``prompt`` and the module-level ``extracted_text`` are
    joined as ``"<prompt>: <text>"``, tokenized, passed to
    ``chatModel.generate`` (max_length=256), and the first generated sequence
    is decoded without special tokens.

    Args:
        request: Body carrying the instruction in ``prompt``.

    Returns:
        ``{"edited_text": <decoded model output>}``.
    """
    # Only reads extracted_text, so no `global` declaration is needed.
    input_text = f"{request.prompt}: {extracted_text}"
    # Debug output goes through logging rather than an unconditional print of
    # user-supplied text on every request.
    logging.getLogger(__name__).debug("chat_prompt input: %s", input_text)

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # generate() is synchronous and compute-bound. Called directly inside an
    # async endpoint it would block the event loop, and with it every other
    # request, for the whole generation. Run it in the default executor;
    # tokenizer calls stay on the loop thread.
    # NOTE(review): concurrent requests may now run generate() in parallel on
    # the shared chatModel — confirm memory/device headroom, or guard with a
    # lock if needed.
    loop = asyncio.get_running_loop()
    outputs = await loop.run_in_executor(
        None, lambda: chatModel.generate(input_ids, max_length=256)
    )

    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"edited_text": edited_text}