"""Gradio demo that reads text CAPTCHAs with an ONNX model.

On startup the private Hugging Face model repo ``docparser/captcha`` is cloned
(or pulled) into ``secret_models/`` next to this script, the ONNX model inside
it is loaded, and a Gradio interface is launched.
"""
import os
import pathlib

import gradio as gr
import onnx
import onnxruntime as rt
import torch
from huggingface_hub import Repository
from PIL import Image
from torchvision import transforms as T

from tokenizer_base import Tokenizer

# Directory containing this script. Every bundled path is resolved against it
# so the clone location and the model lookup agree no matter which working
# directory the process was launched from.
cwd = pathlib.Path(__file__).parent.resolve()

# Clone (or reuse) the private model repository and update it.
# token=True asks huggingface_hub to authenticate with the locally saved token.
repo = Repository(
    local_dir=os.path.join(cwd, "secret_models"),
    repo_type="model",
    clone_from="docparser/captcha",
    token=True,
)
repo.git_pull()

model_file = os.path.join(cwd, "secret_models", "captcha.onnx")
# Target size passed to T.Resize; a 2-tuple is interpreted as (height, width).
img_size = (32, 128)
# Alphabet handed to the tokenizer: digits plus ASCII lower- and upper-case letters.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
tokenizer_base = Tokenizer(charset)


def get_transform(img_size):
    """Build the preprocessing pipeline applied to every input image.

    Args:
        img_size: ``(height, width)`` the image is resized to.

    Returns:
        A ``T.Compose`` that resizes with bicubic interpolation, converts to a
        tensor, and normalizes each channel with mean 0.5 / std 0.5
        (mapping ``ToTensor``'s [0, 1] range to [-1, 1]).
    """
    transforms = []
    transforms.extend([
        T.Resize(img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ])
    return T.Compose(transforms)


def to_numpy(tensor):
    """Return *tensor* as a NumPy array on the CPU.

    Tensors that require grad are detached first, because ``.numpy()`` refuses
    to convert a tensor that is part of an autograd graph.
    """
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def initialize_model(model_file):
    """Validate the ONNX model and open an inference session for it.

    Args:
        model_file: Path to the ``.onnx`` file.

    Returns:
        ``(transform, ort_session)``: the preprocessing pipeline for
        ``img_size`` and an ``onnxruntime.InferenceSession`` for the model.
    """
    transform = get_transform(img_size)
    # Onnx model loading; check_model raises if the graph is malformed.
    onnx_model = onnx.load(model_file)
    onnx.checker.check_model(onnx_model)
    ort_session = rt.InferenceSession(model_file)
    return transform, ort_session


def get_text(img_org):
    """Predict the text contained in a CAPTCHA image.

    Uses the module-level ``transform`` and ``ort_session`` created by
    ``initialize_model`` below.

    Args:
        img_org: PIL image from the Gradio ``Image`` input, or ``None`` when
            the user submits without uploading an image.

    Returns:
        The decoded string for the image, or ``""`` if no image was given.
    """
    if img_org is None:
        # Gradio passes None for an empty image input; avoid crashing on .convert.
        return ""
    # Preprocess. Model expects a batch of images with shape (B, C, H, W),
    # so add a batch dimension of 1.
    x = transform(img_org.convert("RGB")).unsqueeze(0)
    # Compute ONNX Runtime output, feeding the model's first (only) input.
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    logits = ort_session.run(None, ort_inputs)[0]
    # Softmax over the last axis — presumably the character-class axis;
    # confirm against the model export.
    probs = torch.tensor(logits).softmax(-1)
    preds, probs = tokenizer_base.decode(probs)
    preds = preds[0]  # first (and only) item of the batch
    print(preds)
    return preds


transform, ort_session = initialize_model(model_file=model_file)

gr.Interface(
    get_text,
    inputs=gr.Image(type="pil"),
    # gr.outputs.* is the legacy namespace removed in Gradio 4; gr.Textbox
    # works as an output component in both 3.x and 4.x.
    outputs=gr.Textbox(),
    title="Text Captcha Reader",
    examples=[
        "8000.png", "11JW29.png", "2a8486.jpg", "2nbcx.png",
        "000679.png", "000HU.png", "00Uga.png.jpg", "00bAQwhAZU.jpg",
        "00h57kYf.jpg", "0EoHdtVb.png", "0JS21.png", "0p98z.png", "10010.png",
    ],
).launch()