truong-xuan-linh commited on
Commit
03484ca
·
1 Parent(s): a244fb6
.github/workflows/main.yml CHANGED
@@ -17,4 +17,4 @@ jobs:
17
  - name: Push to hub
18
  env:
19
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
- run: git push --force https://truong-xuan-linh:[email protected]/spaces/truong-xuan-linh/vietnamese-ocr master
 
17
  - name: Push to hub
18
  env:
19
  HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push --force https://truong-xuan-linh:[email protected]/spaces/truong-xuan-linh/vietnamese-ocr main
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.ipynb*
2
+ log
3
+ __pycache__
4
+ *test*
README.md CHANGED
@@ -1 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
1
  # vietnamese_ocr
 
1
+ ---
2
+ title: Vietnamese Ocr
3
+ emoji: 🌍
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.26.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
  # vietnamese_ocr
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+
4
+ #Trick to not init function multitime
5
+ if "ocr_detector" not in st.session_state:
6
+ print("INIT MODEL")
7
+ from src.setup import Setup
8
+ Setup().ocr_model_downloader()
9
+
10
+ from src.OCR import OCRDetector
11
+ st.session_state.ocr_detector = OCRDetector()
12
+ print("DONE INIT MODEL")
13
+
14
+ st.set_page_config(page_title="Vietnamese OCR", layout="wide", page_icon = "./storage/linhai.jpeg")
15
+ hide_menu_style = """
16
+ <style>
17
+ footer {visibility: hidden;}
18
+ </style>
19
+ """
20
+ st.markdown(hide_menu_style, unsafe_allow_html= True)
21
+
22
+ st.markdown(
23
+ """
24
+ <style>
25
+ [data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
26
+ width: 400px;
27
+ }
28
+ [data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
29
+ margin-left: -400px;
30
+ }
31
+
32
+ """,
33
+ unsafe_allow_html=True,
34
+ )
35
+
36
+ st.markdown("<h2 style='text-align: center; color: grey;'>Input: Image </h2>", unsafe_allow_html=True)
37
+ st.markdown("<h2 style='text-align: center; color: grey;'>Output: The Vietnamese or English text in the image (if any).</h2>", unsafe_allow_html=True)
38
+ left_col, right_col = st.columns(2)
39
+
40
+ #LEFT COLUMN
41
+ upload_image = left_col.file_uploader("Choose an image file", type=["jpg", "jpeg", "png", "webp", ])
42
+
43
+ if left_col.button("OCR Detect"):
44
+ image, texts, boxes = st.session_state.ocr_detector.text_detector(upload_image, is_local=True)
45
+ left_col.write("**RESULTS:** ")
46
+ left_col.write(texts)
47
+
48
+ #RIGHT COLUMN
49
+ visualize_image = st.session_state.ocr_detector.visualize_ocr(image, texts, boxes)
50
+ right_col.write("**ORIGIN IMAGE:** ")
51
+ right_col.image(image)
52
+ right_col.write("**OCR IMAGE:** ")
53
+ right_col.image(visualize_image)
config/config.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ ocr_model: https://drive.google.com/uc?id=1-Cdr1MAztczfMxpkekn0wIiZsY8NAnJS
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #commom
2
+ opencv-python>=4.1.1
3
+ numpy<=1.20.0
4
+ torch>=1.8.0
5
+ torchvision
6
+ unidecode
7
+ Pillow==9.4.0
8
+ PyYAML>=5.3.1
9
+ gdown==4.4.0
10
+ paddlepaddle>=2.3.1
11
+ paddleocr>=2.5.0.3
12
+ vietocr>=0.3.8
13
+ streamlit==1.26.0
src/OCR.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from paddleocr import PaddleOCR
2
+ from vietocr.tool.config import Cfg
3
+ from vietocr.tool.predictor import Predictor
4
+
5
+ import cv2
6
+ import requests
7
+ import unidecode
8
+ import numpy as np
9
+ from PIL import Image, ImageFont, ImageDraw
10
+
11
+ class OCRDetector:
12
+ def __init__(self) -> None:
13
+ self.paddle_ocr = PaddleOCR(lang='en', use_angle_cls=False)
14
+ # config['weights'] = './weights/transformerocr.pth'
15
+ self.config = Cfg.load_config_from_name('vgg_transformer')
16
+ self.config['weights'] = "./storage/ocr_model.pth"
17
+ self.config['cnn']['pretrained']=False
18
+ self.config['device'] = "cpu"
19
+ self.config['predictor']['beamsearch']=False
20
+ self.viet_ocr = Predictor(self.config)
21
+
22
+ def find_box(self, image):
23
+ '''Xác định box dựa vào mô hình paddle_ocr'''
24
+ result = self.paddle_ocr.ocr(image, cls = False)
25
+ result = result[0]
26
+ # Extracting detected components
27
+ boxes = [res[0] for res in result]
28
+ texts = [{"text": res[1][0], "score": res[1][1]} for res in result]
29
+
30
+ # scores = [res[1][1] for res in result]
31
+ return boxes, texts
32
+
33
+ def vietnamese_text(self, boxes, image):
34
+ '''Xác định text dựa vào mô hình viet_ocr'''
35
+ texts = []
36
+ for box in boxes:
37
+ A = box[0]
38
+ B = box[1]
39
+ C = box[2]
40
+ D = box[3]
41
+ y1 = min(A[1], B[1])
42
+ y1 = int(max(0, y1 - max(0, 10 - abs(A[1] - B[1]))))
43
+ y2 = max(C[1], D[1])
44
+ y2 = int(y2 + max(0, 10 - abs(A[1] - B[1])))
45
+ x1 = int(max(0, min(A[0], D[0]) ))
46
+ x2 = int(max(B[0], C[0]) )
47
+ cut_image = image[y1:y2, x1:x2]
48
+ cut_image = Image.fromarray(np.uint8(cut_image))
49
+ text, score = self.viet_ocr.predict(cut_image, return_prob=True)
50
+ texts.append({"text": text,
51
+ "score": score})
52
+ return texts
53
+
54
+ #Merge
55
+ def text_detector(self, image_path, is_local=False):
56
+ if is_local:
57
+ image = Image.open(image_path).convert("RGB")
58
+ else:
59
+ image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
60
+ image = np.array(image)
61
+ boxes, paddle_texts = self.find_box(image)
62
+ if not boxes:
63
+ return image, None, None
64
+ viet_texts = self.vietnamese_text(boxes, image)
65
+ results_texts = []
66
+ for i, viet_txt in enumerate(viet_texts):
67
+ if viet_txt["text"] != unidecode.unidecode(viet_txt["text"]):
68
+ results_texts.append(viet_txt)
69
+ else:
70
+ results_texts.append(paddle_texts[i])
71
+ if results_texts != []:
72
+ return image, results_texts, boxes
73
+ else:
74
+ return image, None, None
75
+
76
+
77
+ def visualize_ocr(self, image, texts, boxes):
78
+ if not texts:
79
+ return image
80
+
81
+ img = image.copy()
82
+ for box, text in zip(boxes, texts):
83
+ (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
84
+
85
+ h = y3 - y1
86
+ scl = max(h//1000,1)
87
+ font = ImageFont.truetype("./storage/Roboto-Black.ttf", 15*scl)
88
+ img = cv2.rectangle(img, (int(x1), int(y1)), (int(x3), int(y3)), (0, 255, 0), 1)
89
+
90
+ img_pil = Image.fromarray(img)
91
+ draw = ImageDraw.Draw(img_pil)
92
+ draw.text((int(x1), int(y1-h-3)), text["text"], font = font, fill = (51, 51, 255))
93
+ img = np.array(img_pil)
94
+ # img = cv2.putText(img, text["text"], (int(x1), int(y1)-3), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255,0,0), 1)
95
+ return img
src/setup.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+
4
+ class Setup():
5
+ def __init__(self) -> None:
6
+ self.config = yaml.load(open("./config/config.yml"), yaml.loader.SafeLoader)
7
+ self.ocr_model = self.config["ocr_model"]
8
+
9
+ def ocr_model_downloader(self) -> None:
10
+ os.system("python -m pip install gdown --upgrade")
11
+ import gdown
12
+ if "ocr_model.pth" not in os.listdir(("./storage")):
13
+ gdown.download(self.ocr_model, "./storage/ocr_model.pth", quiet=False)
storage/.keep ADDED
File without changes
storage/Roboto-Black.ttf ADDED
Binary file (168 kB). View file
 
storage/linhai.jpeg ADDED