diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..ed9abcb57f8e7935653adccce61bf1958981b803 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.pdiparams filter=lfs diff=lfs merge=lfs -text +*.pdmodel filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 21e6231b58d2001b5df9098f9a3a4a8c65b0ea1a..eb65153a0a20d9f35e2cd751d28d34ef058932e3 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ sdk: gradio sdk_version: 3.17.0 app_file: app.py pinned: false +duplicated_from: mertcobanov/deprem-ocr-2 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py index f98379b46abfede34f21d1a387fe754fd76ad8ea..6453371426464f37dae29da7bb8491c0da0f46af 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,4 @@ import gradio as gr -from easyocr import Reader -from PIL import Image -import io import json import csv import openai @@ -9,18 +6,64 @@ import ast import os from deta import Deta +import numpy as np +from ocr import utility +from ocr.detector import TextDetector +from ocr.recognizer import TextRecognizer + +# Global Detector and Recognizer +args = utility.parse_args() +text_recognizer = TextRecognizer(args) +text_detector = TextDetector(args) + +openai.api_key = os.getenv("API_KEY") + +args = utility.parse_args() +text_recognizer = TextRecognizer(args) +text_detector = TextDetector(args) + + +def apply_ocr(img): + # Detect text regions + dt_boxes, _ = text_detector(img) + + boxes = [] + for box in dt_boxes: + p1, p2, p3, p4 = box + x1 = min(p1[0], p2[0], p3[0], p4[0]) + y1 = min(p1[1], p2[1], p3[1], p4[1]) + x2 = max(p1[0], p2[0], p3[0], p4[0]) + y2 = max(p1[1], p2[1], p3[1], p4[1]) + boxes.append([x1, y1, x2, y2]) + + # Recognize text + img_list = [] + for i in range(len(boxes)): + x1, y1, x2, y2 = map(int, boxes[i]) + img_list.append(img.copy()[y1:y2, x1:x2]) + img_list.reverse() + + rec_res, _ = text_recognizer(img_list) + + # Postprocess + total_text = "" + for i in range(len(rec_res)): + total_text += rec_res[i][0] + " " + + total_text = total_text.strip() + return total_text -openai.api_key = os.getenv('API_KEY') -reader = Reader(["tr"]) def get_parsed_address(input_img): address_full_text = get_text(input_img) return openai_response(address_full_text) - + def get_text(input_img): - result = reader.readtext(input_img, detail=0) + input_img = np.array(input_img) + result = apply_ocr(input_img) + print(result) return " ".join(result) @@ -38,9 +81,10 @@ def get_json(mahalle, il, sokak, apartman): dump = json.dumps(adres, indent=4, ensure_ascii=False) return dump + def write_db(data_dict): # 2) initialize with a project key - deta_key = os.getenv('DETA_KEY') + deta_key = os.getenv("DETA_KEY") deta = Deta(deta_key) # 3) create and use as many DBs as you want! @@ -53,16 +97,17 @@ def text_dict(input): write_db(eval_result) return ( - str(eval_result['city']), - str(eval_result['distinct']), - str(eval_result['neighbourhood']), - str(eval_result['street']), - str(eval_result['address']), - str(eval_result['tel']), - str(eval_result['name_surname']), - str(eval_result['no']), + str(eval_result["city"]), + str(eval_result["distinct"]), + str(eval_result["neighbourhood"]), + str(eval_result["street"]), + str(eval_result["address"]), + str(eval_result["tel"]), + str(eval_result["name_surname"]), + str(eval_result["no"]), ) - + + def openai_response(ocr_input): prompt = f"""Tabular Data Extraction You are a highly intelligent and accurate tabular data extractor from plain text input and especially from emergency text that carries address information, your inputs can be text @@ -91,28 +136,31 @@ def openai_response(ocr_input): resp = eval(resp.replace("'{", "{").replace("}'", "}")) resp["input"] = ocr_input dict_keys = [ - 'city', - 'distinct', - 'neighbourhood', - 'street', - 'no', - 'tel', - 'name_surname', - 'address', - 'input', + "city", + "distinct", + "neighbourhood", + "street", + "no", + "tel", + "name_surname", + "address", + "input", ] for key in dict_keys: if key not in resp.keys(): - resp[key] = '' + resp[key] = "" return resp with gr.Blocks() as demo: gr.Markdown( - """ + """ # Enkaz Bildirme Uygulaması - """) - gr.Markdown("Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın.") + """ + ) + gr.Markdown( + "Bu uygulamada ekran görüntüsü sürükleyip bırakarak AFAD'a enkaz bildirimi yapabilirsiniz. Mesajı metin olarak da girebilirsiniz, tam adresi ayrıştırıp döndürür. API olarak kullanmak isterseniz sayfanın en altında use via api'ya tıklayın." + ) with gr.Row(): img_area = gr.Image(label="Ekran Görüntüsü yükleyin 👇") ocr_result = gr.Textbox(label="Metin yükleyin 👇 ") @@ -133,13 +181,23 @@ with gr.Blocks() as demo: with gr.Row(): no = gr.Textbox(label="Kapı No") + submit_button.click( + get_parsed_address, + inputs=img_area, + outputs=open_api_text, + api_name="upload_image", + ) - submit_button.click(get_parsed_address, inputs = img_area, outputs = open_api_text, api_name="upload_image") - - ocr_result.change(openai_response, ocr_result, open_api_text, api_name="upload-text") + ocr_result.change( + openai_response, ocr_result, open_api_text, api_name="upload-text" + ) - open_api_text.change(text_dict, open_api_text, [city, distinct, neighbourhood, street, address, tel, name_surname, no]) + open_api_text.change( + text_dict, + open_api_text, + [city, distinct, neighbourhood, street, address, tel, name_surname, no], + ) if __name__ == "__main__": - demo.launch() \ No newline at end of file + demo.launch() diff --git a/ocr/.gitignore b/ocr/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b6e47617de110dea7ca47e087ff1347cc2646eda --- /dev/null +++ b/ocr/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/ocr/README.md b/ocr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fed321cb18036f9bc580e3b2e20dc04b62fe1e15 --- /dev/null +++ b/ocr/README.md @@ -0,0 +1 @@ +# deprem-ocr \ No newline at end of file diff --git a/ocr/__init__.py b/ocr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams b/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams new file mode 100644 index 0000000000000000000000000000000000000000..269760eae0bfeb635f385c257020a920580f7d34 --- /dev/null +++ b/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e9518c6ab706fe87842a8de1c098f990e67f9212b67c9ef8bc4bca6dc17b91a +size 2377917 diff --git a/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info b/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info new file mode 100644 index 0000000000000000000000000000000000000000..622d87ba5e353458afb1250ceb916ab24e30ac8e Binary files /dev/null and b/ocr/ch_PP-OCRv3_det_infer/inference.pdiparams.info differ diff --git a/ocr/ch_PP-OCRv3_det_infer/inference.pdmodel b/ocr/ch_PP-OCRv3_det_infer/inference.pdmodel new file mode 100644 index 0000000000000000000000000000000000000000..f43665a3c2b0f86f3e948534063e903687a8ef63 --- /dev/null +++ b/ocr/ch_PP-OCRv3_det_infer/inference.pdmodel @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74b075e6cfbc8206dab2eee86a6a8bd015a7be612b2bf6d1a1ef878d31df84f7 +size 1413260 diff --git a/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams b/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams new file mode 100644 index 0000000000000000000000000000000000000000..b398d1f288cb7514cc069469e8e7df137747aaf4 --- /dev/null +++ b/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d99d4279f7c64471b8f0be426ee09a46c0f1ecb344406bf0bb9571f670e8d0c7 +size 10614098 diff --git a/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info b/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info new file mode 100644 index 0000000000000000000000000000000000000000..9e133bfd81a3c0ef1e39b7ca1acb933c5521eb22 Binary files /dev/null and b/ocr/ch_PP-OCRv3_rec_infer/inference.pdiparams.info differ diff --git a/ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel b/ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel new file mode 100644 index 0000000000000000000000000000000000000000..ba592fe80043fd68de46e183b44c27aab43e61dc --- /dev/null +++ b/ocr/ch_PP-OCRv3_rec_infer/inference.pdmodel @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9beb0b9520d34bde2a0f92581ed64db7e4d6c76abead8b859189ea72db9ee20 +size 1266415 diff --git a/ocr/detector.py b/ocr/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ad1578066b20de34ce7a11be6c6f456f2c26ce --- /dev/null +++ b/ocr/detector.py @@ -0,0 +1,248 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) + +os.environ["FLAGS_allocator_strategy"] = "auto_growth" + +import json +import sys +import time + +import cv2 +import numpy as np + +import utility +from postprocess import build_post_process +from ppocr.data import create_operators, transform + + +class TextDetector(object): + def __init__(self, args): + self.args = args + self.det_algorithm = args.det_algorithm + self.use_onnx = args.use_onnx + pre_process_list = [ + { + "DetResizeForTest": { + "limit_side_len": args.det_limit_side_len, + "limit_type": args.det_limit_type, + } + }, + { + "NormalizeImage": { + "std": [0.229, 0.224, 0.225], + "mean": [0.485, 0.456, 0.406], + "scale": "1./255.", + "order": "hwc", + } + }, + {"ToCHWImage": None}, + {"KeepKeys": {"keep_keys": ["image", "shape"]}}, + ] + postprocess_params = {} + if self.det_algorithm == "DB": + postprocess_params["name"] = "DBPostProcess" + postprocess_params["thresh"] = args.det_db_thresh + postprocess_params["box_thresh"] = args.det_db_box_thresh + postprocess_params["max_candidates"] = 1000 + postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio + postprocess_params["use_dilation"] = args.use_dilation + postprocess_params["score_mode"] = args.det_db_score_mode + elif self.det_algorithm == "EAST": + postprocess_params["name"] = "EASTPostProcess" + postprocess_params["score_thresh"] = args.det_east_score_thresh + postprocess_params["cover_thresh"] = args.det_east_cover_thresh + postprocess_params["nms_thresh"] = args.det_east_nms_thresh + elif self.det_algorithm == "SAST": + pre_process_list[0] = { + "DetResizeForTest": {"resize_long": args.det_limit_side_len} + } + postprocess_params["name"] = "SASTPostProcess" + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: + postprocess_params["sample_pts_num"] = 6 + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["sample_pts_num"] = 2 + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": + postprocess_params["name"] = "PSEPostProcess" + postprocess_params["thresh"] = args.det_pse_thresh + postprocess_params["box_thresh"] = args.det_pse_box_thresh + postprocess_params["min_area"] = args.det_pse_min_area + postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["scale"] = args.det_pse_scale + self.det_pse_box_type = args.det_pse_box_type + elif self.det_algorithm == "FCE": + pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}} + postprocess_params["name"] = "FCEPostProcess" + postprocess_params["scales"] = args.scales + postprocess_params["alpha"] = args.alpha + postprocess_params["beta"] = args.beta + postprocess_params["fourier_degree"] = args.fourier_degree + postprocess_params["box_type"] = args.det_fce_box_type + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + ( + self.predictor, + self.input_tensor, + self.output_tensors, + self.config, + ) = utility.create_predictor(args, "det") + + if self.use_onnx: + img_h, img_w = self.input_tensor.shape[2:] + if img_h is not None and img_w is not None and img_h > 0 and img_w > 0: + pre_process_list[0] = { + "DetResizeForTest": {"image_shape": [img_h, img_w]} + } + self.preprocess_op = create_operators(pre_process_list) + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + ori_im = img.copy() + data = {"image": img} + + st = time.time() + + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + if self.det_algorithm == "EAST": + preds["f_geo"] = outputs[0] + preds["f_score"] = outputs[1] + elif self.det_algorithm == "SAST": + preds["f_border"] = outputs[0] + preds["f_score"] = outputs[1] + preds["f_tco"] = outputs[2] + preds["f_tvo"] = outputs[3] + elif self.det_algorithm in ["DB", "PSE"]: + preds["maps"] = outputs[0] + elif self.det_algorithm == "FCE": + for i, output in enumerate(outputs): + preds["level_{}".format(i)] = output + else: + raise NotImplementedError + + # self.predictor.try_shrink_memory() + post_result = self.postprocess_op(preds, shape_list) + dt_boxes = post_result[0]["points"] + if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( + self.det_algorithm in ["PSE", "FCE"] + and self.postprocess_op.box_type == "poly" + ): + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + + et = time.time() + return dt_boxes, et - st + + +if __name__ == "__main__": + args = utility.parse_args() + image_file_list = ["images/y.png"] + text_detector = TextDetector(args) + count = 0 + total_time = 0 + draw_img_save = "./inference_results" + + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(2): + res = text_detector(img) + + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + + save_results = [] + for image_file in image_file_list: + img = cv2.imread(image_file) + + for _ in range(10): + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st + print(elapse * 1000) + if count > 0: + total_time += elapse + count += 1 + save_pred = ( + os.path.basename(image_file) + + "\t" + + str(json.dumps([x.tolist() for x in dt_boxes])) + + "\n" + ) + save_results.append(save_pred) + src_im = utility.draw_text_det_res(dt_boxes, image_file) + img_name_pure = os.path.split(image_file)[-1] + img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure)) + cv2.imwrite(img_path, src_im) + + with open(os.path.join(draw_img_save, "det_results.txt"), "w") as f: + f.writelines(save_results) + f.close() diff --git a/ocr/inference.py b/ocr/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ed30bbf97423751c3691551539ffbbcc2031d3 --- /dev/null +++ b/ocr/inference.py @@ -0,0 +1,68 @@ +import os + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import time +import requests +from io import BytesIO + +import utility +from detector import * +from recognizer import * + +# Global Detector and Recognizer +args = utility.parse_args() +text_recognizer = TextRecognizer(args) +text_detector = TextDetector(args) + + +def apply_ocr(img): + # Detect text regions + dt_boxes, _ = text_detector(img) + + boxes = [] + for box in dt_boxes: + p1, p2, p3, p4 = box + x1 = min(p1[0], p2[0], p3[0], p4[0]) + y1 = min(p1[1], p2[1], p3[1], p4[1]) + x2 = max(p1[0], p2[0], p3[0], p4[0]) + y2 = max(p1[1], p2[1], p3[1], p4[1]) + boxes.append([x1, y1, x2, y2]) + + # Recognize text + img_list = [] + for i in range(len(boxes)): + x1, y1, x2, y2 = map(int, boxes[i]) + img_list.append(img.copy()[y1:y2, x1:x2]) + img_list.reverse() + + rec_res, _ = text_recognizer(img_list) + + # Postprocess + total_text = "" + table = dict() + for i in range(len(rec_res)): + table[i] = { + "text": rec_res[i][0], + } + total_text += rec_res[i][0] + " " + + total_text = total_text.strip() + return total_text + + +def main(): + image_url = "https://i.ibb.co/kQvHGjj/aewrg.png" + response = requests.get(image_url) + img = np.array(Image.open(BytesIO(response.content)).convert("RGB")) + + t0 = time.time() + epoch = 1 + for _ in range(epoch): + ocr_text = apply_ocr(img) + print("Elapsed time:", (time.time() - t0) * 1000 / epoch, "ms") + print("Output:", ocr_text) + + +if __name__ == "__main__": + main() diff --git a/ocr/postprocess/__init__.py b/ocr/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92895892cfde7940299802e664bf48e2a91aecf9 --- /dev/null +++ b/ocr/postprocess/__init__.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy + +__all__ = ["build_post_process"] + +from .cls_postprocess import ClsPostProcess +from .db_postprocess import DBPostProcess, DistillationDBPostProcess +from .east_postprocess import EASTPostProcess +from .fce_postprocess import FCEPostProcess +from .pg_postprocess import PGPostProcess +from .rec_postprocess import ( + AttnLabelDecode, + CTCLabelDecode, + DistillationCTCLabelDecode, + NRTRLabelDecode, + PRENLabelDecode, + SARLabelDecode, + SEEDLabelDecode, + SRNLabelDecode, + TableLabelDecode, +) +from .sast_postprocess import SASTPostProcess +from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess +from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess + + +def build_post_process(config, global_config=None): + support_dict = [ + "DBPostProcess", + "EASTPostProcess", + "SASTPostProcess", + "FCEPostProcess", + "CTCLabelDecode", + "AttnLabelDecode", + "ClsPostProcess", + "SRNLabelDecode", + "PGPostProcess", + "DistillationCTCLabelDecode", + "TableLabelDecode", + "DistillationDBPostProcess", + "NRTRLabelDecode", + "SARLabelDecode", + "SEEDLabelDecode", + "VQASerTokenLayoutLMPostProcess", + "VQAReTokenLayoutLMPostProcess", + "PRENLabelDecode", + "DistillationSARLabelDecode", + ] + + if config["name"] == "PSEPostProcess": + from .pse_postprocess import PSEPostProcess + + support_dict.append("PSEPostProcess") + + config = copy.deepcopy(config) + module_name = config.pop("name") + if module_name == "None": + return + if global_config is not None: + config.update(global_config) + assert module_name in support_dict, Exception( + "post process only support {}".format(support_dict) + ) + module_class = eval(module_name)(**config) + return module_class diff --git a/ocr/postprocess/cls_postprocess.py b/ocr/postprocess/cls_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbebbcefaeac3fa0e530d806df0fa85b626144a --- /dev/null +++ b/ocr/postprocess/cls_postprocess.py @@ -0,0 +1,30 @@ +import paddle + + +class ClsPostProcess(object): + """Convert between text-label and text-index""" + + def __init__(self, label_list=None, key=None, **kwargs): + super(ClsPostProcess, self).__init__() + self.label_list = label_list + self.key = key + + def __call__(self, preds, label=None, *args, **kwargs): + if self.key is not None: + preds = preds[self.key] + + label_list = self.label_list + if label_list is None: + label_list = {idx: idx for idx in range(preds.shape[-1])} + + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + pred_idxs = preds.argmax(axis=1) + decode_out = [ + (label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs) + ] + if label is None: + return decode_out + label = [(label_list[idx], 1.0) for idx in label] + return decode_out, label diff --git a/ocr/postprocess/db_postprocess.py b/ocr/postprocess/db_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..2219b3f217c4c4a96ec848d33b6c154ccc26e4c8 --- /dev/null +++ b/ocr/postprocess/db_postprocess.py @@ -0,0 +1,207 @@ +from __future__ import absolute_import, division, print_function + +import cv2 +import numpy as np +import paddle +import pyclipper +from shapely.geometry import Polygon + + +class DBPostProcess(object): + """ + The post process for Differentiable Binarization (DB). + """ + + def __init__( + self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + **kwargs + ): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", + "fast", + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours( + (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE + ) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height + ) + boxes.append(box.astype(np.int16)) + scores.append(score) + return np.array(boxes, dtype=np.int16), scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + """ + box_score_fast: use bbox mean score as the mean score + """ + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] + + def box_score_slow(self, bitmap, contour): + """ + box_score_slow: use polyon mean score as the mean score + """ + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] + + def __call__(self, outs_dict, shape_list): + pred = outs_dict["maps"] + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel, + ) + else: + mask = segmentation[batch_index] + boxes, scores = self.boxes_from_bitmap( + pred[batch_index], mask, src_w, src_h + ) + + boxes_batch.append({"points": boxes}) + return boxes_batch + + +class DistillationDBPostProcess(object): + def __init__( + self, + model_name=["student"], + key=None, + thresh=0.3, + box_thresh=0.6, + max_candidates=1000, + unclip_ratio=1.5, + use_dilation=False, + score_mode="fast", + **kwargs + ): + self.model_name = model_name + self.key = key + self.post_process = DBPostProcess( + thresh=thresh, + box_thresh=box_thresh, + max_candidates=max_candidates, + unclip_ratio=unclip_ratio, + use_dilation=use_dilation, + score_mode=score_mode, + ) + + def __call__(self, predicts, shape_list): + results = {} + for k in self.model_name: + results[k] = self.post_process(predicts[k], shape_list=shape_list) + return results diff --git a/ocr/postprocess/east_postprocess.py b/ocr/postprocess/east_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..94ca238deafb0435295b8a68853d967ae1507a95 --- /dev/null +++ b/ocr/postprocess/east_postprocess.py @@ -0,0 +1,122 @@ +from __future__ import absolute_import, division, print_function + +import cv2 +import numpy as np +import paddle + +from .locality_aware_nms import nms_locality + + +class EASTPostProcess(object): + """ + The post process for EAST. + """ + + def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs): + + self.score_thresh = score_thresh + self.cover_thresh = cover_thresh + self.nms_thresh = nms_thresh + + def restore_rectangle_quad(self, origin, geometry): + """ + Restore rectangle from quadrangle. + """ + # quad + origin_concat = np.concatenate( + (origin, origin, origin, origin), axis=1 + ) # (n, 8) + pred_quads = origin_concat - geometry + pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2) + return pred_quads + + def detect( + self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2 + ): + """ + restore text boxes from score map and geo map + """ + + score_map = score_map[0] + geo_map = np.swapaxes(geo_map, 1, 0) + geo_map = np.swapaxes(geo_map, 1, 2) + # filter the score map + xy_text = np.argwhere(score_map > score_thresh) + if len(xy_text) == 0: + return [] + # sort the text boxes via the y axis + xy_text = xy_text[np.argsort(xy_text[:, 0])] + # restore quad proposals + text_box_restored = self.restore_rectangle_quad( + xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :] + ) + boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) + boxes[:, :8] = text_box_restored.reshape((-1, 8)) + boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] + + try: + import lanms + + boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh) + except: + print( + "you should install lanms by pip3 install lanms-nova to speed up nms_locality" + ) + boxes = nms_locality(boxes.astype(np.float64), nms_thresh) + if boxes.shape[0] == 0: + return [] + # Here we filter some low score boxes by the average score map, + # this is different from the orginal paper. + for i, box in enumerate(boxes): + mask = np.zeros_like(score_map, dtype=np.uint8) + cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1) + boxes[i, 8] = cv2.mean(score_map, mask)[0] + boxes = boxes[boxes[:, 8] > cover_thresh] + return boxes + + def sort_poly(self, p): + """ + Sort polygons. + """ + min_axis = np.argmin(np.sum(p, axis=1)) + p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]] + if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): + return p + else: + return p[[0, 3, 2, 1]] + + def __call__(self, outs_dict, shape_list): + score_list = outs_dict["f_score"] + geo_list = outs_dict["f_geo"] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + geo_list = geo_list.numpy() + img_num = len(shape_list) + dt_boxes_list = [] + for ino in range(img_num): + score = score_list[ino] + geo = geo_list[ino] + boxes = self.detect( + score_map=score, + geo_map=geo, + score_thresh=self.score_thresh, + cover_thresh=self.cover_thresh, + nms_thresh=self.nms_thresh, + ) + boxes_norm = [] + if len(boxes) > 0: + h, w = score.shape[1:] + src_h, src_w, ratio_h, ratio_w = shape_list[ino] + boxes = boxes[:, :8].reshape((-1, 4, 2)) + boxes[:, :, 0] /= ratio_w + boxes[:, :, 1] /= ratio_h + for i_box, box in enumerate(boxes): + box = self.sort_poly(box.astype(np.int32)) + if ( + np.linalg.norm(box[0] - box[1]) < 5 + or np.linalg.norm(box[3] - box[0]) < 5 + ): + continue + boxes_norm.append(box) + dt_boxes_list.append({"points": np.array(boxes_norm)}) + return dt_boxes_list diff --git a/ocr/postprocess/extract_textpoint_fast.py b/ocr/postprocess/extract_textpoint_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..977dce96d1bd9c90a6774fa5ea9ae8b6c56fa567 --- /dev/null +++ b/ocr/postprocess/extract_textpoint_fast.py @@ -0,0 +1,464 @@ +from __future__ import absolute_import, division, print_function + +from itertools import groupby + +import cv2 +import numpy as np +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos + ) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] + probs_seq = logits_seq + labels = np.argmax(probs_seq, axis=1) + dst_str = [k for k, v_ in groupby(labels) if k != C - 1] + detal = len(gather_info) // (pts_num - 1) + keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1] + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6): + """ + CTC decoder using multiple processes. + """ + decoder_str = [] + decoder_xys = [] + for gather_info in gather_info_list: + if len(gather_info) < pts_num: + continue + dst_str, xys_list = instance_ctc_greedy_decoder( + gather_info, logits_map, pts_num=pts_num + ) + dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str]) + if len(dst_str_readable) < 2: + continue + decoder_str.append(dst_str_readable) + decoder_xys.append(xys_list) + return decoder_str, decoder_xys + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction + ) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction + ) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len :, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = ( + np.round(left_start + left_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = ( + np.round(right_start + right_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len :, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = ( + np.round(left_start + left_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = ( + np.round(right_start + right_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + +def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0): + ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = ( + -shrink_ratio_of_width + * np.linalg.norm(left_quad[0] - left_quad[3]) + / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + ) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], + poly[point_num // 2 - 1], + poly[point_num // 2], + poly[point_num // 2 + 1], + ], + dtype=np.float32, + ) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm( + right_quad[0] - right_quad[3] + ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def restore_poly( + instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set +): + poly_list = [] + keep_str_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(keep_str) < 2: + print("--> too short, {}".format(keep_str)) + continue + + offset_expand = 1.0 + if valid_set == "totaltext": + offset_expand = 1.2 + + point_pair_list = [] + for y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) * offset_expand + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = ( + (ori_yx + offset)[:, ::-1] + * 4.0 + / np.array([ratio_w, ratio_h]).reshape(-1, 2) + ) + point_pair_list.append(point_pair) + + detected_poly = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2 + ) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + + keep_str_list.append(keep_str) + if valid_set == "partvgg": + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif valid_set == "totaltext": + poly_list.append(detected_poly) + else: + print("--> Not supported format.") + exit(-1) + return poly_list, keep_str_list + + +def generate_pivot_list_fast( + p_score, p_char_maps, f_direction, Lexicon_Table, score_thresh=0.5 +): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map.astype(np.uint8)) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8 + ) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + if len(pos_list) < 3: + continue + + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map + ) + all_pos_yxs.append(pos_list_sorted) + + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decoded_str, keep_yxs_list = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table + ) + return keep_yxs_list, decoded_str + + +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction + ) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction + ) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point diff --git a/ocr/postprocess/extract_textpoint_slow.py b/ocr/postprocess/extract_textpoint_slow.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5a902fb25817f5b17d224ee389c2c4325b2017 --- /dev/null +++ b/ocr/postprocess/extract_textpoint_slow.py @@ -0,0 +1,608 @@ +from __future__ import absolute_import, division, print_function + +import math +from itertools import groupby + +import cv2 +import numpy as np +from skimage.morphology._skeletonize import thin + + +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + +def point_pair2poly(point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + pair_length_list = [] + for point_pair in point_pair_list: + pair_length = np.linalg.norm(point_pair[0] - point_pair[1]) + pair_length_list.append(pair_length) + pair_length_list = np.array(pair_length_list) + pair_info = ( + pair_length_list.max(), + pair_length_list.min(), + pair_length_list.mean(), + ) + + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2), pair_info + + +def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + +def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = ( + -shrink_ratio_of_width + * np.linalg.norm(left_quad[0] - left_quad[3]) + / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + ) + left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], + poly[point_num // 2 - 1], + poly[point_num // 2], + poly[point_num // 2 + 1], + ], + dtype=np.float32, + ) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm( + right_quad[0] - right_quad[3] + ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos + ) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, logits_map, keep_blank_in_idxs=True): + """ + gather_info: [[x, y], [x, y] ...] + logits_map: H x W X (n_chars + 1) + """ + _, _, C = logits_map.shape + ys, xs = zip(*gather_info) + logits_seq = logits_map[list(ys), list(xs)] # n x 96 + probs_seq = softmax(logits_seq) + dst_str, keep_idx_list = ctc_greedy_decoder( + probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs + ) + keep_gather_list = [gather_info[idx] for idx in keep_idx_list] + return dst_str, keep_gather_list + + +def ctc_decoder_for_image(gather_info_list, logits_map, keep_blank_in_idxs=True): + """ + CTC decoder using multiple processes. + """ + decoder_results = [] + for gather_info in gather_info_list: + res = instance_ctc_greedy_decoder( + gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs + ) + decoder_results.append(res) + return decoder_results + + +def sort_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list, point_direction): + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 2) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction + ) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction + ) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point, np.array(sorted_direction) + + +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. + """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len :, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = ( + np.round(left_start + left_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = ( + np.round(right_start + right_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + binary_tcl_map: h x w + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len :, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + max_append_num = 2 * append_num + + left_list = [] + right_list = [] + for i in range(max_append_num): + ly, lx = ( + np.round(left_start + left_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ly < h and lx < w and (ly, lx) not in left_list: + if binary_tcl_map[ly, lx] > 0.5: + left_list.append((ly, lx)) + else: + break + + for i in range(max_append_num): + ry, rx = ( + np.round(right_start + right_step * (i + 1)) + .flatten() + .astype("int32") + .tolist() + ) + if ry < h and rx < w and (ry, rx) not in right_list: + if binary_tcl_map[ry, rx] > 0.5: + right_list.append((ry, rx)) + else: + break + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + +def generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_expand=True, + is_backbone=False, + image_id=0, +): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8 + ) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + pred_strs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 3: + continue + + if is_expand: + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map + ) + else: + pos_list_sorted, _ = sort_with_direction(pos_list, f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter backgroud points. + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True + ) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + pred_strs.append(decoded_str) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return pred_strs, instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list_horizontal( + p_score, p_char_maps, f_direction, score_thresh=0.5, is_backbone=False, image_id=0 +): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map_bi = (p_score > score_thresh) * 1.0 + instance_count, instance_label_map = cv2.connectedComponents( + p_tcl_map_bi.astype(np.uint8), connectivity=8 + ) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 5: + continue + + # add rule here + main_direction = extract_main_direction(pos_list, f_direction) # y x + reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x + is_h_angle = abs(np.sum(main_direction * reference_directin)) < math.cos( + math.pi / 180 * 70 + ) + + point_yxs = np.array(pos_list) + max_y, max_x = np.max(point_yxs, axis=0) + min_y, min_x = np.min(point_yxs, axis=0) + is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x) + + pos_list_final = [] + if is_h_len: + xs = np.unique(xs) + for x in xs: + ys = instance_label_map[:, x].copy().reshape((-1,)) + y = int(np.where(ys == instance_id)[0].mean()) + pos_list_final.append((y, x)) + else: + ys = np.unique(ys) + for y in ys: + xs = instance_label_map[y, :].copy().reshape((-1,)) + x = int(np.where(xs == instance_id)[0].mean()) + pos_list_final.append((y, x)) + + pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter backgroud points. + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True + ) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list_slow( + p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0, +): + """ + Warp all the function together. + """ + if is_curved: + return generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_expand=True, + is_backbone=is_backbone, + image_id=image_id, + ) + else: + return generate_pivot_list_horizontal( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_backbone=is_backbone, + image_id=image_id, + ) + + +# for refine module +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction + ) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction + ) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point + + +def generate_pivot_list_tt_inference( + p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0, +): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8 + ) + + # get TCL Instance + all_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + ### FIX-ME, eliminate outlier + if len(pos_list) < 3: + continue + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map + ) + pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id) + all_pos_yxs.append(pos_list_sorted_with_id) + return all_pos_yxs diff --git a/ocr/postprocess/fce_postprocess.py b/ocr/postprocess/fce_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8714e4b16ae63c1add49642e49613a00cdb9f9 --- /dev/null +++ b/ocr/postprocess/fce_postprocess.py @@ -0,0 +1,234 @@ +import cv2 +import numpy as np +import paddle +from numpy.fft import ifft + +from .poly_nms import * + + +def fill_hole(input_mask): + h, w = input_mask.shape + canvas = np.zeros((h + 2, w + 2), np.uint8) + canvas[1 : h + 1, 1 : w + 1] = input_mask.copy() + + mask = np.zeros((h + 4, w + 4), np.uint8) + + cv2.floodFill(canvas, mask, (0, 0), 1) + canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool) + + return ~canvas | input_mask + + +def fourier2poly(fourier_coeff, num_reconstr_points=50): + """Inverse Fourier transform + Args: + fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1), + with n and k being candidates number and Fourier degree + respectively. + num_reconstr_points (int): Number of reconstructed polygon points. + Returns: + Polygons (ndarray): The reconstructed polygons shaped (n, n') + """ + + a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype="complex") + k = (len(fourier_coeff[0]) - 1) // 2 + + a[:, 0 : k + 1] = fourier_coeff[:, k:] + a[:, -k:] = fourier_coeff[:, :k] + + poly_complex = ifft(a) * num_reconstr_points + polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2)) + polygon[:, :, 0] = poly_complex.real + polygon[:, :, 1] = poly_complex.imag + return polygon.astype("int32").reshape((len(fourier_coeff), -1)) + + +class FCEPostProcess(object): + """ + The post process for FCENet. + """ + + def __init__( + self, + scales, + fourier_degree=5, + num_reconstr_points=50, + decoding_type="fcenet", + score_thr=0.3, + nms_thr=0.1, + alpha=1.0, + beta=1.0, + box_type="poly", + **kwargs + ): + + self.scales = scales + self.fourier_degree = fourier_degree + self.num_reconstr_points = num_reconstr_points + self.decoding_type = decoding_type + self.score_thr = score_thr + self.nms_thr = nms_thr + self.alpha = alpha + self.beta = beta + self.box_type = box_type + + def __call__(self, preds, shape_list): + score_maps = [] + for key, value in preds.items(): + if isinstance(value, paddle.Tensor): + value = value.numpy() + cls_res = value[:, :4, :, :] + reg_res = value[:, 4:, :, :] + score_maps.append([cls_res, reg_res]) + + return self.get_boundary(score_maps, shape_list) + + def resize_boundary(self, boundaries, scale_factor): + """Rescale boundaries via scale_factor. + + Args: + boundaries (list[list[float]]): The boundary list. Each boundary + with size 2k+1 with k>=4. + scale_factor(ndarray): The scale factor of size (4,). + + Returns: + boundaries (list[list[float]]): The scaled boundaries. + """ + boxes = [] + scores = [] + for b in boundaries: + sz = len(b) + valid_boundary(b, True) + scores.append(b[-1]) + b = ( + ( + np.array(b[: sz - 1]) + * (np.tile(scale_factor[:2], int((sz - 1) / 2)).reshape(1, sz - 1)) + ) + .flatten() + .tolist() + ) + boxes.append(np.array(b).reshape([-1, 2])) + + return np.array(boxes, dtype=np.float32), scores + + def get_boundary(self, score_maps, shape_list): + assert len(score_maps) == len(self.scales) + boundaries = [] + for idx, score_map in enumerate(score_maps): + scale = self.scales[idx] + boundaries = boundaries + self._get_boundary_single(score_map, scale) + + # nms + boundaries = poly_nms(boundaries, self.nms_thr) + boundaries, scores = self.resize_boundary( + boundaries, (1 / shape_list[0, 2:]).tolist()[::-1] + ) + + boxes_batch = [dict(points=boundaries, scores=scores)] + return boxes_batch + + def _get_boundary_single(self, score_map, scale): + assert len(score_map) == 2 + assert score_map[1].shape[1] == 4 * self.fourier_degree + 2 + + return self.fcenet_decode( + preds=score_map, + fourier_degree=self.fourier_degree, + num_reconstr_points=self.num_reconstr_points, + scale=scale, + alpha=self.alpha, + beta=self.beta, + box_type=self.box_type, + score_thr=self.score_thr, + nms_thr=self.nms_thr, + ) + + def fcenet_decode( + self, + preds, + fourier_degree, + num_reconstr_points, + scale, + alpha=1.0, + beta=2.0, + box_type="poly", + score_thr=0.3, + nms_thr=0.1, + ): + """Decoding predictions of FCENet to instances. + + Args: + preds (list(Tensor)): The head output tensors. + fourier_degree (int): The maximum Fourier transform degree k. + num_reconstr_points (int): The points number of the polygon + reconstructed from predicted Fourier coefficients. + scale (int): The down-sample scale of the prediction. + alpha (float) : The parameter to calculate final scores. Score_{final} + = (Score_{text region} ^ alpha) + * (Score_{text center region}^ beta) + beta (float) : The parameter to calculate final score. + box_type (str): Boundary encoding type 'poly' or 'quad'. + score_thr (float) : The threshold used to filter out the final + candidates. + nms_thr (float) : The threshold of nms. + + Returns: + boundaries (list[list[float]]): The instance boundary and confidence + list. + """ + assert isinstance(preds, list) + assert len(preds) == 2 + assert box_type in ["poly", "quad"] + + cls_pred = preds[0][0] + tr_pred = cls_pred[0:2] + tcl_pred = cls_pred[2:] + + reg_pred = preds[1][0].transpose([1, 2, 0]) + x_pred = reg_pred[:, :, : 2 * fourier_degree + 1] + y_pred = reg_pred[:, :, 2 * fourier_degree + 1 :] + + score_pred = (tr_pred[1] ** alpha) * (tcl_pred[1] ** beta) + tr_pred_mask = (score_pred) > score_thr + tr_mask = fill_hole(tr_pred_mask) + + tr_contours, _ = cv2.findContours( + tr_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) # opencv4 + + mask = np.zeros_like(tr_mask) + boundaries = [] + for cont in tr_contours: + deal_map = mask.copy().astype(np.int8) + cv2.drawContours(deal_map, [cont], -1, 1, -1) + + score_map = score_pred * deal_map + score_mask = score_map > 0 + xy_text = np.argwhere(score_mask) + dxy = xy_text[:, 1] + xy_text[:, 0] * 1j + + x, y = x_pred[score_mask], y_pred[score_mask] + c = x + y * 1j + c[:, fourier_degree] = c[:, fourier_degree] + dxy + c *= scale + + polygons = fourier2poly(c, num_reconstr_points) + score = score_map[score_mask].reshape(-1, 1) + polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr) + + boundaries = boundaries + polygons + + boundaries = poly_nms(boundaries, nms_thr) + + if box_type == "quad": + new_boundaries = [] + for boundary in boundaries: + poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32) + score = boundary[-1] + points = cv2.boxPoints(cv2.minAreaRect(poly)) + points = np.int0(points) + new_boundaries.append(points.reshape(-1).tolist() + [score]) + boundaries = new_boundaries + + return boundaries diff --git a/ocr/postprocess/locality_aware_nms.py b/ocr/postprocess/locality_aware_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd9814cb853f3a5e2a88168b5be557b7f99a5fb --- /dev/null +++ b/ocr/postprocess/locality_aware_nms.py @@ -0,0 +1,198 @@ +""" +Locality aware nms. +This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py +""" + +import numpy as np +from shapely.geometry import Polygon + + +def intersection(g, p): + """ + Intersection. + """ + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + g = g.buffer(0) + p = p.buffer(0) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter / union + + +def intersection_iog(g, p): + """ + Intersection_iog. + """ + g = Polygon(g[:8].reshape((4, 2))) + p = Polygon(p[:8].reshape((4, 2))) + if not g.is_valid or not p.is_valid: + return 0 + inter = Polygon(g).intersection(Polygon(p)).area + # union = g.area + p.area - inter + union = p.area + if union == 0: + print("p_area is very small") + return 0 + else: + return inter / union + + +def weighted_merge(g, p): + """ + Weighted merge. + """ + g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8]) + g[8] = g[8] + p[8] + return g + + +def standard_nms(S, thres): + """ + Standard nms. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return S[keep] + + +def standard_nms_inds(S, thres): + """ + Standard nms, retun inds. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return keep + + +def nms(S, thres): + """ + nms. + """ + order = np.argsort(S[:, 8])[::-1] + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) + + inds = np.where(ovr <= thres)[0] + order = order[inds + 1] + + return keep + + +def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2): + """ + soft_nms + :para boxes_in, N x 9 (coords + score) + :para threshould, eliminate cases min score(0.001) + :para Nt_thres, iou_threshi + :para sigma, gaussian weght + :method, linear or gaussian + """ + boxes = boxes_in.copy() + N = boxes.shape[0] + if N is None or N < 1: + return np.array([]) + pos, maxpos = 0, 0 + weight = 0.0 + inds = np.arange(N) + tbox, sbox = boxes[0].copy(), boxes[0].copy() + for i in range(N): + maxscore = boxes[i, 8] + maxpos = i + tbox = boxes[i].copy() + ti = inds[i] + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 8]: + maxscore = boxes[pos, 8] + maxpos = pos + pos = pos + 1 + # add max box as a detection + boxes[i, :] = boxes[maxpos, :] + inds[i] = inds[maxpos] + # swap + boxes[maxpos, :] = tbox + inds[maxpos] = ti + tbox = boxes[i].copy() + pos = i + 1 + # NMS iteration + while pos < N: + sbox = boxes[pos].copy() + ts_iou_val = intersection(tbox, sbox) + if ts_iou_val > 0: + if method == 1: + if ts_iou_val > Nt_thres: + weight = 1 - ts_iou_val + else: + weight = 1 + elif method == 2: + weight = np.exp(-1.0 * ts_iou_val**2 / sigma) + else: + if ts_iou_val > Nt_thres: + weight = 0 + else: + weight = 1 + boxes[pos, 8] = weight * boxes[pos, 8] + # if box score falls below thresold, discard the box by + # swaping last box update N + if boxes[pos, 8] < threshold: + boxes[pos, :] = boxes[N - 1, :] + inds[pos] = inds[N - 1] + N = N - 1 + pos = pos - 1 + pos = pos + 1 + + return boxes[:N] + + +def nms_locality(polys, thres=0.3): + """ + locality aware nms of EAST + :param polys: a N*9 numpy array. first 8 coordinates, then prob + :return: boxes after nms + """ + S = [] + p = None + for g in polys: + if p is not None and intersection(g, p) > thres: + p = weighted_merge(g, p) + else: + if p is not None: + S.append(p) + p = g + if p is not None: + S.append(p) + + if len(S) == 0: + return np.array([]) + return standard_nms(np.array(S), thres) + + +if __name__ == "__main__": + # 343,350,448,135,474,143,369,359 + print(Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])).area) diff --git a/ocr/postprocess/pg_postprocess.py b/ocr/postprocess/pg_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..a344ac7b987446713b38bcc517f7273daf2c945f --- /dev/null +++ b/ocr/postprocess/pg_postprocess.py @@ -0,0 +1,189 @@ +from __future__ import absolute_import, division, print_function + +import os +import sys + +import paddle + +from .extract_textpoint_fast import * +from .extract_textpoint_slow import * + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, "..")) + + +class PGNet_PostProcess(object): + # two different post-process + def __init__( + self, character_dict_path, valid_set, score_thresh, outs_dict, shape_list + ): + self.Lexicon_Table = get_dict(character_dict_path) + self.valid_set = valid_set + self.score_thresh = score_thresh + self.outs_dict = outs_dict + self.shape_list = shape_list + + def pg_postprocess_fast(self): + p_score = self.outs_dict["f_score"] + p_border = self.outs_dict["f_border"] + p_char = self.outs_dict["f_char"] + p_direction = self.outs_dict["f_direction"] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + + src_h, src_w, ratio_h, ratio_w = self.shape_list[0] + instance_yxs_list, seq_strs = generate_pivot_list_fast( + p_score, + p_char, + p_direction, + self.Lexicon_Table, + score_thresh=self.score_thresh, + ) + poly_list, keep_str_list = restore_poly( + instance_yxs_list, + seq_strs, + p_border, + ratio_w, + ratio_h, + src_w, + src_h, + self.valid_set, + ) + data = { + "points": poly_list, + "texts": keep_str_list, + } + return data + + def pg_postprocess_slow(self): + p_score = self.outs_dict["f_score"] + p_border = self.outs_dict["f_border"] + p_char = self.outs_dict["f_char"] + p_direction = self.outs_dict["f_direction"] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() + else: + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + src_h, src_w, ratio_h, ratio_w = self.shape_list[0] + is_curved = self.valid_set == "totaltext" + char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow( + p_score, + p_char, + p_direction, + score_thresh=self.score_thresh, + is_backbone=True, + is_curved=is_curved, + ) + seq_strs = [] + for char_idx_set in char_seq_idx_set: + pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set]) + seq_strs.append(pr_str) + poly_list = [] + keep_str_list = [] + all_point_list = [] + all_point_pair_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(yx_center_line) == 1: + yx_center_line.append(yx_center_line[-1]) + + offset_expand = 1.0 + if self.valid_set == "totaltext": + offset_expand = 1.2 + + point_pair_list = [] + for batch_id, y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm(offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0 + ) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = ( + (ori_yx + offset)[:, ::-1] + * 4.0 + / np.array([ratio_w, ratio_h]).reshape(-1, 2) + ) + point_pair_list.append(point_pair) + + all_point_list.append( + [int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))] + ) + all_point_pair_list.append(point_pair.round().astype(np.int32).tolist()) + + detected_poly, pair_length_info = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2 + ) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + + if len(keep_str) < 2: + continue + + keep_str_list.append(keep_str) + detected_poly = np.round(detected_poly).astype("int32") + if self.valid_set == "partvgg": + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], : + ] + poly_list.append(detected_poly) + elif self.valid_set == "totaltext": + poly_list.append(detected_poly) + else: + print("--> Not supported format.") + exit(-1) + data = { + "points": poly_list, + "texts": keep_str_list, + } + return data + + +class PGPostProcess(object): + """ + The post process for PGNet. + """ + + def __init__(self, character_dict_path, valid_set, score_thresh, mode, **kwargs): + self.character_dict_path = character_dict_path + self.valid_set = valid_set + self.score_thresh = score_thresh + self.mode = mode + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def __call__(self, outs_dict, shape_list): + post = PGNet_PostProcess( + self.character_dict_path, + self.valid_set, + self.score_thresh, + outs_dict, + shape_list, + ) + if self.mode == "fast": + data = post.pg_postprocess_fast() + else: + data = post.pg_postprocess_slow() + return data diff --git a/ocr/postprocess/poly_nms.py b/ocr/postprocess/poly_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebb92e720632ba43be751d671d775936e5c1d74 --- /dev/null +++ b/ocr/postprocess/poly_nms.py @@ -0,0 +1,132 @@ +import numpy as np +from shapely.geometry import Polygon + + +def points2polygon(points): + """Convert k points to 1 polygon. + + Args: + points (ndarray or list): A ndarray or a list of shape (2k) + that indicates k points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(points, list): + points = np.array(points) + + assert isinstance(points, np.ndarray) + assert (points.size % 2 == 0) and (points.size >= 8) + + point_mat = points.reshape([-1, 2]) + return Polygon(point_mat) + + +def poly_intersection(poly_det, poly_gt, buffer=0.0001): + """Calculate the intersection area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + intersection_area (float): The intersection area between two polygons. + """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + + if buffer == 0: + poly_inter = poly_det & poly_gt + else: + poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer) + return poly_inter.area, poly_inter + + +def poly_union(poly_det, poly_gt): + """Calculate the union area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + union_area (float): The union area between two polygons. + """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + + area_det = poly_det.area + area_gt = poly_gt.area + area_inters, _ = poly_intersection(poly_det, poly_gt) + return area_det + area_gt - area_inters + + +def valid_boundary(x, with_score=True): + num = len(x) + if num < 8: + return False + if num % 2 == 0 and (not with_score): + return True + if num % 2 == 1 and with_score: + return True + + return False + + +def boundary_iou(src, target): + """Calculate the IOU between two boundaries. + + Args: + src (list): Source boundary. + target (list): Target boundary. + + Returns: + iou (float): The iou between two boundaries. + """ + assert valid_boundary(src, False) + assert valid_boundary(target, False) + src_poly = points2polygon(src) + target_poly = points2polygon(target) + + return poly_iou(src_poly, target_poly) + + +def poly_iou(poly_det, poly_gt): + """Calculate the IOU between two polygons. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + iou (float): The IOU between two polygons. + """ + assert isinstance(poly_det, Polygon) + assert isinstance(poly_gt, Polygon) + area_inters, _ = poly_intersection(poly_det, poly_gt) + area_union = poly_union(poly_det, poly_gt) + if area_union == 0: + return 0.0 + return area_inters / area_union + + +def poly_nms(polygons, threshold): + assert isinstance(polygons, list) + + polygons = np.array(sorted(polygons, key=lambda x: x[-1])) + + keep_poly = [] + index = [i for i in range(polygons.shape[0])] + + while len(index) > 0: + keep_poly.append(polygons[index[-1]].tolist()) + A = polygons[index[-1]][:-1] + index = np.delete(index, -1) + iou_list = np.zeros((len(index),)) + for i in range(len(index)): + B = polygons[index[i]][:-1] + iou_list[i] = boundary_iou(A, B) + remove_index = np.where(iou_list > threshold) + index = np.delete(index, remove_index) + + return keep_poly diff --git a/ocr/postprocess/pse_postprocess/__init__.py b/ocr/postprocess/pse_postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df8f0deabbb1a9f09a8db2849ef16e90ae1c061f --- /dev/null +++ b/ocr/postprocess/pse_postprocess/__init__.py @@ -0,0 +1 @@ +from .pse_postprocess import PSEPostProcess diff --git a/ocr/postprocess/pse_postprocess/pse/__init__.py b/ocr/postprocess/pse_postprocess/pse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6d6f601e79df62c26c5da6adf42c06b340cee3 --- /dev/null +++ b/ocr/postprocess/pse_postprocess/pse/__init__.py @@ -0,0 +1,20 @@ +import os +import subprocess +import sys + +python_path = sys.executable + +ori_path = os.getcwd() +os.chdir("ppocr/postprocess/pse_postprocess/pse") +if ( + subprocess.call("{} setup.py build_ext --inplace".format(python_path), shell=True) + != 0 +): + raise RuntimeError( + "Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+".format( + os.path.dirname(os.path.realpath(__file__)) + ) + ) +os.chdir(ori_path) + +from .pse import pse diff --git a/ocr/postprocess/pse_postprocess/pse/pse.pyx b/ocr/postprocess/pse_postprocess/pse/pse.pyx new file mode 100644 index 0000000000000000000000000000000000000000..40acbeec7ada0d75a8efa0ac328c59caf0eb9c14 --- /dev/null +++ b/ocr/postprocess/pse_postprocess/pse/pse.pyx @@ -0,0 +1,72 @@ + +import cv2 +import numpy as np + +cimport cython +cimport libcpp +cimport libcpp.pair +cimport libcpp.queue +cimport numpy as np +from libcpp.pair cimport * +from libcpp.queue cimport * + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, + np.ndarray[np.int32_t, ndim=2] label, + int kernel_num, + int label_num, + float min_area=0): + cdef np.ndarray[np.int32_t, ndim=2] pred + pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) + + for label_idx in range(1, label_num): + if np.sum(label == label_idx) < min_area: + label[label == label_idx] = 0 + + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ + queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() + cdef np.int16_t* dx = [-1, 1, 0, 0] + cdef np.int16_t* dy = [0, 0, -1, 1] + cdef np.int16_t tmpx, tmpy + + points = np.array(np.where(label > 0)).transpose((1, 0)) + for point_idx in range(points.shape[0]): + tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = label[tmpx, tmpy] + + cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur + cdef int cur_label + for kernel_idx in range(kernel_num - 1, -1, -1): + while not que.empty(): + cur = que.front() + que.pop() + cur_label = pred[cur.first, cur.second] + + is_edge = True + for j in range(4): + tmpx = cur.first + dx[j] + tmpy = cur.second + dy[j] + if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: + continue + if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: + continue + + que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) + pred[tmpx, tmpy] = cur_label + is_edge = False + if is_edge: + nxt_que.push(cur) + + que, nxt_que = nxt_que, que + + return pred + +def pse(kernels, min_area): + kernel_num = kernels.shape[0] + label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) + return _pse(kernels[:-1], label, kernel_num, label_num, min_area) \ No newline at end of file diff --git a/ocr/postprocess/pse_postprocess/pse/setup.py b/ocr/postprocess/pse_postprocess/pse/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..65eb77413c377c1a1461a356cc8d841f8597f0ea --- /dev/null +++ b/ocr/postprocess/pse_postprocess/pse/setup.py @@ -0,0 +1,19 @@ +from distutils.core import Extension, setup + +import numpy +from Cython.Build import cythonize + +setup( + ext_modules=cythonize( + Extension( + "pse", + sources=["pse.pyx"], + language="c++", + include_dirs=[numpy.get_include()], + library_dirs=[], + libraries=[], + extra_compile_args=["-O3"], + extra_link_args=[], + ) + ) +) diff --git a/ocr/postprocess/pse_postprocess/pse_postprocess.py b/ocr/postprocess/pse_postprocess/pse_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..7a07830ac2df19d667bdef3994c57fa6fef2cadb --- /dev/null +++ b/ocr/postprocess/pse_postprocess/pse_postprocess.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import, division, print_function + +import cv2 +import numpy as np +import paddle +from paddle.nn import functional as F + +from .pse import pse + + +class PSEPostProcess(object): + """ + The post process for PSE. + """ + + def __init__( + self, + thresh=0.5, + box_thresh=0.85, + min_area=16, + box_type="quad", + scale=4, + **kwargs + ): + assert box_type in ["quad", "poly"], "Only quad and poly is supported" + self.thresh = thresh + self.box_thresh = box_thresh + self.min_area = min_area + self.box_type = box_type + self.scale = scale + + def __call__(self, outs_dict, shape_list): + pred = outs_dict["maps"] + if not isinstance(pred, paddle.Tensor): + pred = paddle.to_tensor(pred) + pred = F.interpolate(pred, scale_factor=4 // self.scale, mode="bilinear") + + score = F.sigmoid(pred[:, 0, :, :]) + + kernels = (pred > self.thresh).astype("float32") + text_mask = kernels[:, 0, :, :] + kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask + + score = score.numpy() + kernels = kernels.numpy().astype(np.uint8) + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + boxes, scores = self.boxes_from_bitmap( + score[batch_index], kernels[batch_index], shape_list[batch_index] + ) + + boxes_batch.append({"points": boxes, "scores": scores}) + return boxes_batch + + def boxes_from_bitmap(self, score, kernels, shape): + label = pse(kernels, self.min_area) + return self.generate_box(score, label, shape) + + def generate_box(self, score, label, shape): + src_h, src_w, ratio_h, ratio_w = shape + label_num = np.max(label) + 1 + + boxes = [] + scores = [] + for i in range(1, label_num): + ind = label == i + points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1] + + if points.shape[0] < self.min_area: + label[ind] = 0 + continue + + score_i = np.mean(score[ind]) + if score_i < self.box_thresh: + label[ind] = 0 + continue + + if self.box_type == "quad": + rect = cv2.minAreaRect(points) + bbox = cv2.boxPoints(rect) + elif self.box_type == "poly": + box_height = np.max(points[:, 1]) + 10 + box_width = np.max(points[:, 0]) + 10 + + mask = np.zeros((box_height, box_width), np.uint8) + mask[points[:, 1], points[:, 0]] = 255 + + contours, _ = cv2.findContours( + mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + bbox = np.squeeze(contours[0], 1) + else: + raise NotImplementedError + + bbox[:, 0] = np.clip(np.round(bbox[:, 0] / ratio_w), 0, src_w) + bbox[:, 1] = np.clip(np.round(bbox[:, 1] / ratio_h), 0, src_h) + boxes.append(bbox) + scores.append(score_i) + return boxes, scores diff --git a/ocr/postprocess/rec_postprocess.py b/ocr/postprocess/rec_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ef86f973465949119da70e968b613b6c0e8902 --- /dev/null +++ b/ocr/postprocess/rec_postprocess.py @@ -0,0 +1,731 @@ +import re + +import numpy as np +import paddle +from paddle.nn import functional as F + + +class BaseRecLabelDecode(object): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False): + self.beg_str = "sos" + self.end_str = "eos" + + self.character_str = [] + if character_dict_path is None: + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + else: + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + self.character_str.append(line) + if use_space_char: + self.character_str.append(" ") + dict_character = list(self.character_str) + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def add_special_char(self, dict_character): + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + selection = np.ones(len(text_index[batch_idx]), dtype=bool) + if is_remove_duplicate: + selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] + for ignored_token in ignored_tokens: + selection &= text_index[batch_idx] != ignored_token + + char_list = [ + self.character[text_id] for text_id in text_index[batch_idx][selection] + ] + if text_prob is not None: + conf_list = text_prob[batch_idx][selection] + else: + conf_list = [1] * len(selection) + if len(conf_list) == 0: + conf_list = [0] + + text = "".join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple) or isinstance(preds, list): + preds = preds[-1] + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + if label is None: + return text + label = self.decode(label) + return text, label + + def add_special_char(self, dict_character): + dict_character = ["blank"] + dict_character + return dict_character + + +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__( + self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs + ): + super(DistillationCTCLabelDecode, self).__init__( + character_dict_path, use_space_char + ) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred["ctc"] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class NRTRLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=True, **kwargs): + super(NRTRLabelDecode, self).__init__(character_dict_path, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + + if len(preds) == 2: + preds_id = preds[0] + preds_prob = preds[1] + if isinstance(preds_id, paddle.Tensor): + preds_id = preds_id.numpy() + if isinstance(preds_prob, paddle.Tensor): + preds_prob = preds_prob.numpy() + if preds_id[0][0] == 2: + preds_idx = preds_id[:, 1:] + preds_prob = preds_prob[:, 1:] + else: + preds_idx = preds_id + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + else: + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label[:, 1:]) + return text, label + + def add_special_char(self, dict_character): + dict_character = ["blank", "", "", ""] + dict_character + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == 3: # end + break + try: + char_list.append(self.character[int(text_index[batch_idx][idx])]) + except: + continue + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = "".join(char_list) + result_list.append((text.lower(), np.mean(conf_list).tolist())) + return result_list + + +class AttnLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(AttnLabelDecode, self).__init__(character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): + continue + char_list.append(self.character[int(text_index[batch_idx][idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = "".join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + +class SEEDLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(SEEDLabelDecode, self).__init__(character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + self.padding_str = "padding" + self.end_str = "eos" + self.unknown = "unknown" + dict_character = dict_character + [self.end_str, self.padding_str, self.unknown] + return dict_character + + def get_ignored_tokens(self): + end_idx = self.get_beg_end_flag_idx("eos") + return [end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "sos": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "eos": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + [end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): + continue + char_list.append(self.character[int(text_index[batch_idx][idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = "".join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + """ + text = self.decode(text) + if label is None: + return text + else: + label = self.decode(label, is_remove_duplicate=False) + return text, label + """ + preds_idx = preds["rec_pred"] + if isinstance(preds_idx, paddle.Tensor): + preds_idx = preds_idx.numpy() + if "rec_pred_scores" in preds: + preds_idx = preds["rec_pred"] + preds_prob = preds["rec_pred_scores"] + else: + preds_idx = preds["rec_pred"].argmax(axis=2) + preds_prob = preds["rec_pred"].max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + +class SRNLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(SRNLabelDecode, self).__init__(character_dict_path, use_space_char) + self.max_text_length = kwargs.get("max_text_length", 25) + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds["predict"] + char_num = len(self.character_str) + 2 + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) + + preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) + + text = self.decode(preds_idx, preds_prob) + + if label is None: + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + return text + label = self.decode(label) + return text, label + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): + continue + char_list.append(self.character[int(text_index[batch_idx][idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = "".join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + +class TableLabelDecode(object): + """ """ + + def __init__(self, character_dict_path, **kwargs): + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + self.dict_idx_character = {} + for i, char in enumerate(list_character): + self.dict_idx_character[i] = char + self.dict_character[char] = i + self.dict_elem = {} + self.dict_idx_elem = {} + for i, elem in enumerate(list_elem): + self.dict_idx_elem[i] = elem + self.dict_elem[elem] = i + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode("utf-8").strip("\n").strip("\r\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode("utf-8").strip("\n").strip("\r\n") + list_character.append(character) + for eno in range(1 + character_num, 1 + character_num + elem_num): + elem = lines[eno].decode("utf-8").strip("\n").strip("\r\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def __call__(self, preds): + structure_probs = preds["structure_probs"] + loc_preds = preds["loc_preds"] + if isinstance(structure_probs, paddle.Tensor): + structure_probs = structure_probs.numpy() + if isinstance(loc_preds, paddle.Tensor): + loc_preds = loc_preds.numpy() + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + ( + structure_str, + structure_pos, + result_score_list, + result_elem_idx_list, + ) = self.decode(structure_idx, structure_probs, "elem") + res_html_code_list = [] + res_loc_list = [] + batch_num = len(structure_str) + for bno in range(batch_num): + res_loc = [] + for sno in range(len(structure_str[bno])): + text = structure_str[bno][sno] + if text in ["", " 0 and tmp_elem_idx == end_idx: + break + if tmp_elem_idx in ignored_tokens: + continue + + char_list.append(current_dict[tmp_elem_idx]) + elem_pos_list.append(idx) + score_list.append(structure_probs[batch_idx, idx]) + elem_idx_list.append(tmp_elem_idx) + result_list.append(char_list) + result_pos_list.append(elem_pos_list) + result_score_list.append(score_list) + result_elem_idx_list.append(elem_idx_list) + return result_list, result_pos_list, result_score_list, result_elem_idx_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = self.dict_character[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_character[self.end_str] + else: + assert False, ( + "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end + ) + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = self.dict_elem[self.beg_str] + elif beg_or_end == "end": + idx = self.dict_elem[self.end_str] + else: + assert False, ( + "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end + ) + else: + assert False, "Unsupport type %s in char_or_elem" % char_or_elem + return idx + + +class SARLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(SARLabelDecode, self).__init__(character_dict_path, use_space_char) + + self.rm_symbol = kwargs.get("rm_symbol", False) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """convert text-index into text-label.""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(self.end_idx): + if text_prob is None and idx == 0: + continue + else: + break + if is_remove_duplicate: + # only for predict + if ( + idx > 0 + and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] + ): + continue + char_list.append(self.character[int(text_index[batch_idx][idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = "".join(char_list) + if self.rm_symbol: + comp = re.compile("[^A-Z^a-z^0-9^\u4e00-\u9fa5]") + text = text.lower() + text = comp.sub("", text) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def get_ignored_tokens(self): + return [self.padding_idx] + + +class DistillationSARLabelDecode(SARLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__( + self, + character_dict_path=None, + use_space_char=False, + model_name=["student"], + key=None, + multi_head=False, + **kwargs + ): + super(DistillationSARLabelDecode, self).__init__( + character_dict_path, use_space_char + ) + if not isinstance(model_name, list): + model_name = [model_name] + self.model_name = model_name + + self.key = key + self.multi_head = multi_head + + def __call__(self, preds, label=None, *args, **kwargs): + output = dict() + for name in self.model_name: + pred = preds[name] + if self.key is not None: + pred = pred[self.key] + if self.multi_head and isinstance(pred, dict): + pred = pred["sar"] + output[name] = super().__call__(pred, label=label, *args, **kwargs) + return output + + +class PRENLabelDecode(BaseRecLabelDecode): + """Convert between text-label and text-index""" + + def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): + super(PRENLabelDecode, self).__init__(character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + padding_str = "" # 0 + end_str = "" # 1 + unknown_str = "" # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def decode(self, text_index, text_prob=None): + """convert text-index into text-label.""" + result_list = [] + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] == self.end_idx: + break + if text_index[batch_idx][idx] in [self.padding_idx, self.unknown_idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = "".join(char_list) + if len(text) > 0: + result_list.append((text, np.mean(conf_list).tolist())) + else: + # here confidence of empty recog result is 1 + result_list.append(("", 1)) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label) + return text, label diff --git a/ocr/postprocess/sast_postprocess.py b/ocr/postprocess/sast_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..52304bdcfa16b88c68cde18a239bcb15e63dd146 --- /dev/null +++ b/ocr/postprocess/sast_postprocess.py @@ -0,0 +1,355 @@ +from __future__ import absolute_import, division, print_function + +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, "..")) + +import time + +import cv2 +import numpy as np +import paddle + +from .locality_aware_nms import nms_locality + + +class SASTPostProcess(object): + """ + The post process for SAST. + """ + + def __init__( + self, + score_thresh=0.5, + nms_thresh=0.2, + sample_pts_num=2, + shrink_ratio_of_width=0.3, + expand_scale=1.0, + tcl_map_thresh=0.5, + **kwargs + ): + + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.sample_pts_num = sample_pts_num + self.shrink_ratio_of_width = shrink_ratio_of_width + self.expand_scale = expand_scale + self.tcl_map_thresh = tcl_map_thresh + + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + + def point_pair2poly(self, point_pair_list): + """ + Transfer vertical point_pairs into poly point in clockwise. + """ + # constract poly + point_num = len(point_pair_list) * 2 + point_list = [0] * point_num + for idx, point_pair in enumerate(point_pair_list): + point_list[idx] = point_pair[0] + point_list[point_num - 1 - idx] = point_pair[1] + return np.array(point_list).reshape(-1, 2) + + def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32 + ) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): + """ + expand poly along width. + """ + point_num = poly.shape[0] + left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) + left_ratio = ( + -shrink_ratio_of_width + * np.linalg.norm(left_quad[0] - left_quad[3]) + / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) + ) + left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0) + right_quad = np.array( + [ + poly[point_num // 2 - 2], + poly[point_num // 2 - 1], + poly[point_num // 2], + poly[point_num // 2 + 1], + ], + dtype=np.float32, + ) + right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm( + right_quad[0] - right_quad[3] + ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) + right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio) + poly[0] = left_quad_expand[0] + poly[-1] = left_quad_expand[-1] + poly[point_num // 2 - 1] = right_quad_expand[1] + poly[point_num // 2] = right_quad_expand[2] + return poly + + def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): + """Restore quad.""" + xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) + xy_text = xy_text[:, ::-1] # (n, 2) + + # Sort the text boxes via the y axis + xy_text = xy_text[np.argsort(xy_text[:, 1])] + + scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + scores = scores[:, np.newaxis] + + # Restore + point_num = int(tvo_map.shape[-1] / 2) + assert point_num == 4 + tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] + xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2) + quads = xy_text_tile - tvo_map + + return scores, quads, xy_text + + def quad_area(self, quad): + """ + compute area of a quad. + """ + edge = [ + (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), + (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), + (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), + (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]), + ] + return np.sum(edge) / 2.0 + + def nms(self, dets): + if self.is_python35: + import lanms + + dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh) + else: + dets = nms_locality(dets, self.nms_thresh) + return dets + + def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map): + """ + Cluster pixels in tcl_map based on quads. + """ + instance_count = quads.shape[0] + 1 # contain background + instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) + if instance_count == 1: + return instance_count, instance_label_map + + # predict text center + xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) + n = xy_text.shape[0] + xy_text = xy_text[:, ::-1] # (n, 2) + tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2) + pred_tc = xy_text - tco + + # get gt text center + m = quads.shape[0] + gt_tc = np.mean(quads, axis=1) # (m, 2) + + pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2) + gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2) + dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m) + xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,) + + instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign + return instance_count, instance_label_map + + def estimate_sample_pts_num(self, quad, xy_text): + """ + Estimate sample points number. + """ + eh = ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) + ) / 2.0 + ew = ( + np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]) + ) / 2.0 + + dense_sample_pts_num = max(2, int(ew)) + dense_xy_center_line = xy_text[ + np.linspace( + 0, + xy_text.shape[0] - 1, + dense_sample_pts_num, + endpoint=True, + dtype=np.float32, + ).astype(np.int32) + ] + + dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1] + estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1)) + + sample_pts_num = max(2, int(estimate_arc_len / eh)) + return sample_pts_num + + def detect_sast( + self, + tcl_map, + tvo_map, + tbo_map, + tco_map, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=0.3, + tcl_map_thresh=0.5, + offset_expand=1.0, + out_strid=4.0, + ): + """ + first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys + """ + # restore quad + scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map) + dets = np.hstack((quads, scores)).astype(np.float32, copy=False) + dets = self.nms(dets) + if dets.shape[0] == 0: + return [] + quads = dets[:, :-1].reshape(-1, 4, 2) + + # Compute quad area + quad_areas = [] + for quad in quads: + quad_areas.append(-self.quad_area(quad)) + + # instance segmentation + # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) + instance_count, instance_label_map = self.cluster_by_quads_tco( + tcl_map, tcl_map_thresh, quads, tco_map + ) + + # restore single poly with tcl instance. + poly_list = [] + for instance_idx in range(1, instance_count): + xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1] + quad = quads[instance_idx - 1] + q_area = quad_areas[instance_idx - 1] + if q_area < 5: + continue + + # + len1 = float(np.linalg.norm(quad[0] - quad[1])) + len2 = float(np.linalg.norm(quad[1] - quad[2])) + min_len = min(len1, len2) + if min_len < 3: + continue + + # filter small CC + if xy_text.shape[0] <= 0: + continue + + # filter low confidence instance + xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] + if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: + # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: + continue + + # sort xy_text + left_center_pt = np.array( + [[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]] + ) # (1, 2) + right_center_pt = np.array( + [[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]] + ) # (1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / ( + np.linalg.norm(right_center_pt - left_center_pt) + 1e-6 + ) + proj_value = np.sum(xy_text * proj_unit_vec, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # Sample pts in tcl map + if self.sample_pts_num == 0: + sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) + else: + sample_pts_num = self.sample_pts_num + xy_center_line = xy_text[ + np.linspace( + 0, + xy_text.shape[0] - 1, + sample_pts_num, + endpoint=True, + dtype=np.float32, + ).astype(np.int32) + ] + + point_pair_list = [] + for x, y in xy_center_line: + # get corresponding offset + offset = tbo_map[y, x, :].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm(offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0 + ) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + # original point + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = ( + (ori_yx + offset)[:, ::-1] + * out_strid + / np.array([ratio_w, ratio_h]).reshape(-1, 2) + ) + point_pair_list.append(point_pair) + + # ndarry: (x, 2), expand poly along width + detected_poly = self.point_pair2poly(point_pair_list) + detected_poly = self.expand_poly_along_width( + detected_poly, shrink_ratio_of_width + ) + detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) + poly_list.append(detected_poly) + + return poly_list + + def __call__(self, outs_dict, shape_list): + score_list = outs_dict["f_score"] + border_list = outs_dict["f_border"] + tvo_list = outs_dict["f_tvo"] + tco_list = outs_dict["f_tco"] + if isinstance(score_list, paddle.Tensor): + score_list = score_list.numpy() + border_list = border_list.numpy() + tvo_list = tvo_list.numpy() + tco_list = tco_list.numpy() + + img_num = len(shape_list) + poly_lists = [] + for ino in range(img_num): + p_score = score_list[ino].transpose((1, 2, 0)) + p_border = border_list[ino].transpose((1, 2, 0)) + p_tvo = tvo_list[ino].transpose((1, 2, 0)) + p_tco = tco_list[ino].transpose((1, 2, 0)) + src_h, src_w, ratio_h, ratio_w = shape_list[ino] + + poly_list = self.detect_sast( + p_score, + p_tvo, + p_border, + p_tco, + ratio_w, + ratio_h, + src_w, + src_h, + shrink_ratio_of_width=self.shrink_ratio_of_width, + tcl_map_thresh=self.tcl_map_thresh, + offset_expand=self.expand_scale, + ) + poly_lists.append({"points": np.array(poly_list)}) + + return poly_lists diff --git a/ocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ocr/postprocess/vqa_token_re_layoutlm_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3fc5f4841983f7390d0d375e1ed042e7fda9cf --- /dev/null +++ b/ocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -0,0 +1,36 @@ +class VQAReTokenLayoutLMPostProcess(object): + """Convert between text-label and text-index""" + + def __init__(self, **kwargs): + super(VQAReTokenLayoutLMPostProcess, self).__init__() + + def __call__(self, preds, label=None, *args, **kwargs): + if label is not None: + return self._metric(preds, label) + else: + return self._infer(preds, *args, **kwargs) + + def _metric(self, preds, label): + return preds["pred_relations"], label[6], label[5] + + def _infer(self, preds, *args, **kwargs): + ser_results = kwargs["ser_results"] + entity_idx_dict_batch = kwargs["entity_idx_dict_batch"] + pred_relations = preds["pred_relations"] + + # merge relations and ocr info + results = [] + for pred_relation, ser_result, entity_idx_dict in zip( + pred_relations, ser_results, entity_idx_dict_batch + ): + result = [] + used_tail_id = [] + for relation in pred_relation: + if relation["tail_id"] in used_tail_id: + continue + used_tail_id.append(relation["tail_id"]) + ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]] + ocr_info_tail = ser_result[entity_idx_dict[relation["tail_id"]]] + result.append((ocr_info_head, ocr_info_tail)) + results.append(result) + return results diff --git a/ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..b888b0ccabcf2830aa2d4dde16d3199a7b643f25 --- /dev/null +++ b/ocr/postprocess/vqa_token_ser_layoutlm_postprocess.py @@ -0,0 +1,96 @@ +import numpy as np +import paddle + + +def load_vqa_bio_label_maps(label_map_path): + with open(label_map_path, "r", encoding="utf-8") as fin: + lines = fin.readlines() + lines = [line.strip() for line in lines] + if "O" not in lines: + lines.insert(0, "O") + labels = [] + for line in lines: + if line == "O": + labels.append("O") + else: + labels.append("B-" + line) + labels.append("I-" + line) + label2id_map = {label: idx for idx, label in enumerate(labels)} + id2label_map = {idx: label for idx, label in enumerate(labels)} + return label2id_map, id2label_map + + +class VQASerTokenLayoutLMPostProcess(object): + """Convert between text-label and text-index""" + + def __init__(self, class_path, **kwargs): + super(VQASerTokenLayoutLMPostProcess, self).__init__() + label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path) + + self.label2id_map_for_draw = dict() + for key in label2id_map: + if key.startswith("I-"): + self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]] + else: + self.label2id_map_for_draw[key] = label2id_map[key] + + self.id2label_map_for_show = dict() + for key in self.label2id_map_for_draw: + val = self.label2id_map_for_draw[key] + if key == "O": + self.id2label_map_for_show[val] = key + if key.startswith("B-") or key.startswith("I-"): + self.id2label_map_for_show[val] = key[2:] + else: + self.id2label_map_for_show[val] = key + + def __call__(self, preds, batch=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + if batch is not None: + return self._metric(preds, batch[1]) + else: + return self._infer(preds, **kwargs) + + def _metric(self, preds, label): + pred_idxs = preds.argmax(axis=2) + decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + + for i in range(pred_idxs.shape[0]): + for j in range(pred_idxs.shape[1]): + if label[i, j] != -100: + label_decode_out_list[i].append(self.id2label_map[label[i, j]]) + decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]]) + return decode_out_list, label_decode_out_list + + def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos): + results = [] + + for pred, attention_mask, segment_offset_id, ocr_info in zip( + preds, attention_masks, segment_offset_ids, ocr_infos + ): + pred = np.argmax(pred, axis=1) + pred = [self.id2label_map[idx] for idx in pred] + + for idx in range(len(segment_offset_id)): + if idx == 0: + start_id = 0 + else: + start_id = segment_offset_id[idx - 1] + + end_id = segment_offset_id[idx] + + curr_pred = pred[start_id:end_id] + curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred] + + if len(curr_pred) <= 0: + pred_id = 0 + else: + counts = np.bincount(curr_pred) + pred_id = np.argmax(counts) + ocr_info[idx]["pred_id"] = int(pred_id) + ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)] + results.append(ocr_info) + return results diff --git a/ocr/ppocr/__init__.py b/ocr/ppocr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ocr/ppocr/data/__init__.py b/ocr/ppocr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d263a79b571b91dfb562d28799c5ff86db6c52 --- /dev/null +++ b/ocr/ppocr/data/__init__.py @@ -0,0 +1,79 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import signal +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, "../.."))) + +import copy + +from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler + +from .imaug import create_operators, transform + +__all__ = ["build_dataloader", "transform", "create_operators"] + + +def term_mp(sig_num, frame): + """kill all child processes""" + pid = os.getpid() + pgid = os.getpgid(os.getpid()) + print("main proc {} exit, kill process group " "{}".format(pid, pgid)) + os.killpg(pgid, signal.SIGKILL) + + +def build_dataloader(config, mode, device, logger, seed=None): + config = copy.deepcopy(config) + + support_dict = ["SimpleDataSet", "LMDBDataSet", "PGDataSet", "PubTabDataSet"] + module_name = config[mode]["dataset"]["name"] + assert module_name in support_dict, Exception( + "DataSet only support {}".format(support_dict) + ) + assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test." + + dataset = eval(module_name)(config, mode, logger, seed) + loader_config = config[mode]["loader"] + batch_size = loader_config["batch_size_per_card"] + drop_last = loader_config["drop_last"] + shuffle = loader_config["shuffle"] + num_workers = loader_config["num_workers"] + if "use_shared_memory" in loader_config.keys(): + use_shared_memory = loader_config["use_shared_memory"] + else: + use_shared_memory = True + + if mode == "Train": + # Distribute data to multiple cards + batch_sampler = DistributedBatchSampler( + dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) + else: + # Distribute data to single card + batch_sampler = BatchSampler( + dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) + + if "collate_fn" in loader_config: + from . import collate_fn + + collate_fn = getattr(collate_fn, loader_config["collate_fn"])() + else: + collate_fn = None + data_loader = DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + places=device, + num_workers=num_workers, + return_list=True, + use_shared_memory=use_shared_memory, + collate_fn=collate_fn, + ) + + # support exit using ctrl+c + signal.signal(signal.SIGINT, term_mp) + signal.signal(signal.SIGTERM, term_mp) + + return data_loader diff --git a/ocr/ppocr/data/collate_fn.py b/ocr/ppocr/data/collate_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..d3793e6065de27b7509b1d5d0f4941d5c895f005 --- /dev/null +++ b/ocr/ppocr/data/collate_fn.py @@ -0,0 +1,59 @@ +import numbers +from collections import defaultdict + +import numpy as np +import paddle + + +class DictCollator(object): + """ + data batch + """ + + def __call__(self, batch): + # todo:support batch operators + data_dict = defaultdict(list) + to_tensor_keys = [] + for sample in batch: + for k, v in sample.items(): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if k not in to_tensor_keys: + to_tensor_keys.append(k) + data_dict[k].append(v) + for k in to_tensor_keys: + data_dict[k] = paddle.to_tensor(data_dict[k]) + return data_dict + + +class ListCollator(object): + """ + data batch + """ + + def __call__(self, batch): + # todo:support batch operators + data_dict = defaultdict(list) + to_tensor_idxs = [] + for sample in batch: + for idx, v in enumerate(sample): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if idx not in to_tensor_idxs: + to_tensor_idxs.append(idx) + data_dict[idx].append(v) + for idx in to_tensor_idxs: + data_dict[idx] = paddle.to_tensor(data_dict[idx]) + return list(data_dict.values()) + + +class SSLRotateCollate(object): + """ + bach: [ + [(4*3xH*W), (4,)] + [(4*3xH*W), (4,)] + ... + ] + """ + + def __call__(self, batch): + output = [np.concatenate(d, axis=0) for d in zip(*batch)] + return output diff --git a/ocr/ppocr/data/imaug/ColorJitter.py b/ocr/ppocr/data/imaug/ColorJitter.py new file mode 100644 index 0000000000000000000000000000000000000000..192ac74d29416546c45463db1f9b89ab1c3aa36e --- /dev/null +++ b/ocr/ppocr/data/imaug/ColorJitter.py @@ -0,0 +1,14 @@ +from paddle.vision.transforms import ColorJitter as pp_ColorJitter + +__all__ = ["ColorJitter"] + + +class ColorJitter(object): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs): + self.aug = pp_ColorJitter(brightness, contrast, saturation, hue) + + def __call__(self, data): + image = data["image"] + image = self.aug(image) + data["image"] = image + return data diff --git a/ocr/ppocr/data/imaug/__init__.py b/ocr/ppocr/data/imaug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7adefa4c54109a45dd956a2ceb1592d723bae58d --- /dev/null +++ b/ocr/ppocr/data/imaug/__init__.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +from .ColorJitter import ColorJitter +from .copy_paste import CopyPaste +from .east_process import * +from .fce_aug import * +from .fce_targets import FCENetTargets +from .gen_table_mask import * +from .iaa_augment import IaaAugment +from .label_ops import * +from .make_border_map import MakeBorderMap +from .make_pse_gt import MakePseGt +from .make_shrink_map import MakeShrinkMap +from .operators import * +from .pg_process import * +from .randaugment import RandAugment +from .random_crop_data import EastRandomCropData, RandomCropImgMask +from .rec_img_aug import ( + ClsResizeImg, + NRTRRecResizeImg, + PRENResizeImg, + RecAug, + RecConAug, + RecResizeImg, + SARRecResizeImg, + SRNRecResizeImg, +) +from .sast_process import * +from .ssl_img_aug import SSLRotateResize +from .vqa import * + + +def transform(data, ops=None): + """transform""" + if ops is None: + ops = [] + for op in ops: + data = op(data) + if data is None: + return None + return data + + +def create_operators(op_param_list, global_config=None): + """ + create operators based on the config + + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance(op_param_list, list), "operator config should be a list" + ops = [] + for operator in op_param_list: + assert isinstance(operator, dict) and len(operator) == 1, "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + if global_config is not None: + param.update(global_config) + op = eval(op_name)(**param) + ops.append(op) + return ops diff --git a/ocr/ppocr/data/imaug/copy_paste.py b/ocr/ppocr/data/imaug/copy_paste.py new file mode 100644 index 0000000000000000000000000000000000000000..9607fd7d2d4e0a491246d6df011fb91837fa7b8d --- /dev/null +++ b/ocr/ppocr/data/imaug/copy_paste.py @@ -0,0 +1,167 @@ +import os +import random + +import cv2 +import numpy as np +from PIL import Image +from shapely.geometry import Polygon + +from ppocr.data.imaug.iaa_augment import IaaAugment +from ppocr.data.imaug.random_crop_data import is_poly_outside_rect +from utility import get_rotate_crop_image + + +class CopyPaste(object): + def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs): + self.ext_data_num = 1 + self.objects_paste_ratio = objects_paste_ratio + self.limit_paste = limit_paste + augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}] + self.aug = IaaAugment(augmenter_args) + + def __call__(self, data): + point_num = data["polys"].shape[1] + src_img = data["image"] + src_polys = data["polys"].tolist() + src_texts = data["texts"] + src_ignores = data["ignore_tags"].tolist() + ext_data = data["ext_data"][0] + ext_image = ext_data["image"] + ext_polys = ext_data["polys"] + ext_texts = ext_data["texts"] + ext_ignores = ext_data["ignore_tags"] + + indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] + select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30)) + + random.shuffle(indexs) + select_idxs = indexs[:select_num] + select_polys = ext_polys[select_idxs] + select_ignores = ext_ignores[select_idxs] + + src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) + ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) + src_img = Image.fromarray(src_img).convert("RGBA") + for idx, poly, tag in zip(select_idxs, select_polys, select_ignores): + box_img = get_rotate_crop_image(ext_image, poly) + + src_img, box = self.paste_img(src_img, box_img, src_polys) + if box is not None: + box = box.tolist() + for _ in range(len(box), point_num): + box.append(box[-1]) + src_polys.append(box) + src_texts.append(ext_texts[idx]) + src_ignores.append(tag) + src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) + h, w = src_img.shape[:2] + src_polys = np.array(src_polys) + src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w) + src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) + data["image"] = src_img + data["polys"] = src_polys + data["texts"] = src_texts + data["ignore_tags"] = np.array(src_ignores) + return data + + def paste_img(self, src_img, box_img, src_polys): + box_img_pil = Image.fromarray(box_img).convert("RGBA") + src_w, src_h = src_img.size + box_w, box_h = box_img_pil.size + + angle = np.random.randint(0, 360) + box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) + box = rotate_bbox(box_img, box, angle)[0] + box_img_pil = box_img_pil.rotate(angle, expand=1) + box_w, box_h = box_img_pil.width, box_img_pil.height + if src_w - box_w < 0 or src_h - box_h < 0: + return src_img, None + + paste_x, paste_y = self.select_coord( + src_polys, box, src_w - box_w, src_h - box_h + ) + if paste_x is None: + return src_img, None + box[:, 0] += paste_x + box[:, 1] += paste_y + r, g, b, A = box_img_pil.split() + src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) + + return src_img, box + + def select_coord(self, src_polys, box, endx, endy): + if self.limit_paste: + xmin, ymin, xmax, ymax = ( + box[:, 0].min(), + box[:, 1].min(), + box[:, 0].max(), + box[:, 1].max(), + ) + for _ in range(50): + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + xmin1 = xmin + paste_x + xmax1 = xmax + paste_x + ymin1 = ymin + paste_y + ymax1 = ymax + paste_y + + num_poly_in_rect = 0 + for poly in src_polys: + if not is_poly_outside_rect( + poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1 + ): + num_poly_in_rect += 1 + break + if num_poly_in_rect == 0: + return paste_x, paste_y + return None, None + else: + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + return paste_x, paste_y + + +def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + +def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + +def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + +def rotate_bbox(img, text_polys, angle, scale=1): + """ + from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py + Args: + img: np.ndarray + text_polys: np.ndarray N*4*2 + angle: int + scale: int + + Returns: + + """ + w = img.shape[1] + h = img.shape[0] + + rangle = np.deg2rad(angle) + nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w) + nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w) + rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale) + rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) + rot_mat[0, 2] += rot_move[0] + rot_mat[1, 2] += rot_move[1] + + # ---------------------- rotate box ---------------------- + rot_text_polys = list() + for bbox in text_polys: + point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) + point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) + point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) + point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) + rot_text_polys.append([point1, point2, point3, point4]) + return np.array(rot_text_polys, dtype=np.float32) diff --git a/ocr/ppocr/data/imaug/east_process.py b/ocr/ppocr/data/imaug/east_process.py new file mode 100644 index 0000000000000000000000000000000000000000..11eb5584abc31fda3521ee4052307921d072e534 --- /dev/null +++ b/ocr/ppocr/data/imaug/east_process.py @@ -0,0 +1,427 @@ +import math + +import cv2 +import numpy as np + +__all__ = ["EASTProcessTrain"] + + +class EASTProcessTrain(object): + def __init__( + self, + image_shape=[512, 512], + background_ratio=0.125, + min_crop_side_ratio=0.1, + min_text_size=10, + **kwargs + ): + self.input_size = image_shape[1] + self.random_scale = np.array([0.5, 1, 2.0, 3.0]) + self.background_ratio = background_ratio + self.min_crop_side_ratio = min_crop_side_ratio + self.min_text_size = min_text_size + + def preprocess(self, im): + input_size = self.input_size + im_shape = im.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + im_scale = float(input_size) / float(im_size_max) + im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale) + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + # im = im[:, :, ::-1].astype(np.float32) + im = im / 255 + im -= img_mean + im /= img_std + new_h, new_w, _ = im.shape + im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32) + im_padded[:new_h, :new_w, :] = im + im_padded = im_padded.transpose((2, 0, 1)) + im_padded = im_padded[np.newaxis, :] + return im_padded, im_scale + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if 0.333 < rand_degree_ratio < 0.666: + rand_degree_cnt = 2 + elif rand_degree_ratio > 0.666: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): + sx, sy = wordBB[j][0], wordBB[j][1] + dx = ( + math.cos(rot_angle) * (sx - cx) + - math.sin(rot_angle) * (sy - cy) + + ncx + ) + dy = ( + math.sin(rot_angle) * (sx - cx) + + math.cos(rot_angle) * (sy - cy) + + ncy + ) + poly.append([dx, dy]) + dst_polys.append(poly) + dst_polys = np.array(dst_polys, dtype=np.float32) + return dst_im, dst_polys + + def polygon_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [ + (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]), + ] + return np.sum(edge) / 2.0 + + def check_and_validate_polys(self, polys, tags, img_height, img_width): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + h, w = img_height, img_width + if polys.shape[0] == 0: + return polys + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + for poly, tag in zip(polys, tags): + p_area = self.polygon_area(poly) + # invalid poly + if abs(p_area) < 1: + continue + if p_area > 0: + #'poly in wrong direction' + if not tag: + tag = True # reversed cases should be ignore + poly = poly[(0, 3, 2, 1), :] + validated_polys.append(poly) + validated_tags.append(tag) + return np.array(validated_polys), np.array(validated_tags) + + def draw_img_polys(self, img, polys): + if len(img.shape) == 4: + img = np.squeeze(img, axis=0) + if img.shape[0] == 3: + img = img.transpose((1, 2, 0)) + img[:, :, 2] += 123.68 + img[:, :, 1] += 116.78 + img[:, :, 0] += 103.94 + cv2.imwrite("tmp.jpg", img) + img = cv2.imread("tmp.jpg") + for box in polys: + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2) + import random + + ino = random.randint(0, 100) + cv2.imwrite("tmp_%d.jpg" % ino, img) + return + + def shrink_poly(self, poly, r): + """ + fit a poly inside the origin poly, maybe bugs here... + used for generate the score map + :param poly: the text poly + :param r: r in the paper + :return: the shrinked poly + """ + # shrink ratio + R = 0.3 + # find the longer pair + dist0 = np.linalg.norm(poly[0] - poly[1]) + dist1 = np.linalg.norm(poly[2] - poly[3]) + dist2 = np.linalg.norm(poly[0] - poly[3]) + dist3 = np.linalg.norm(poly[1] - poly[2]) + if dist0 + dist1 > dist2 + dist3: + # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + ## p0, p3 + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + else: + ## p0, p3 + # print poly + theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) + poly[0][0] += R * r[0] * np.sin(theta) + poly[0][1] += R * r[0] * np.cos(theta) + poly[3][0] -= R * r[3] * np.sin(theta) + poly[3][1] -= R * r[3] * np.cos(theta) + ## p1, p2 + theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) + poly[1][0] += R * r[1] * np.sin(theta) + poly[1][1] += R * r[1] * np.cos(theta) + poly[2][0] -= R * r[2] * np.sin(theta) + poly[2][1] -= R * r[2] * np.cos(theta) + ## p0, p1 + theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) + poly[0][0] += R * r[0] * np.cos(theta) + poly[0][1] += R * r[0] * np.sin(theta) + poly[1][0] -= R * r[1] * np.cos(theta) + poly[1][1] -= R * r[1] * np.sin(theta) + ## p2, p3 + theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) + poly[3][0] += R * r[3] * np.cos(theta) + poly[3][1] += R * r[3] * np.sin(theta) + poly[2][0] -= R * r[2] * np.cos(theta) + poly[2][1] -= R * r[2] * np.sin(theta) + return poly + + def generate_quad(self, im_size, polys, tags): + """ + Generate quadrangle. + """ + h, w = im_size + poly_mask = np.zeros((h, w), dtype=np.uint8) + score_map = np.zeros((h, w), dtype=np.uint8) + # (x1, y1, ..., x4, y4, short_edge_norm) + geo_map = np.zeros((h, w, 9), dtype=np.float32) + # mask used during traning, to ignore some hard areas + training_mask = np.ones((h, w), dtype=np.uint8) + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + r = [None, None, None, None] + for i in range(4): + dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4]) + dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4]) + r[i] = min(dist1, dist2) + # score map + shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[ + np.newaxis, :, : + ] + cv2.fillPoly(score_map, shrinked_poly, 1) + cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1) + # if the poly is too small, then ignore it during training + poly_h = min( + np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]) + ) + poly_w = min( + np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]) + ) + if min(poly_h, poly_w) < self.min_text_size: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + + if tag: + cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) + + xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1)) + # geo map. + y_in_poly = xy_in_poly[:, 0] + x_in_poly = xy_in_poly[:, 1] + poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w) + poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h) + for pno in range(4): + geo_channel_beg = pno * 2 + geo_map[y_in_poly, x_in_poly, geo_channel_beg] = ( + x_in_poly - poly[pno, 0] + ) + geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = ( + y_in_poly - poly[pno, 1] + ) + geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0) + return score_map, geo_map, training_mask + + def crop_area(self, im, polys, tags, crop_background=False, max_tries=50): + """ + make random crop from the input image + :param im: + :param polys: + :param tags: + :param crop_background: + :param max_tries: + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w : maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h : maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags + + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if ( + xmax - xmin < self.min_crop_side_ratio * w + or ymax - ymin < self.min_crop_side_ratio * h + ): + # area too small + continue + if polys.shape[0] != 0: + poly_axis_in_area = ( + (polys[:, :, 0] >= xmin) + & (polys[:, :, 0] <= xmax) + & (polys[:, :, 1] >= ymin) + & (polys[:, :, 1] <= ymax) + ) + selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + + if len(selected_polys) == 0: + # no text in this area + if crop_background: + im = im[ymin : ymax + 1, xmin : xmax + 1, :] + polys = [] + tags = [] + return im, polys, tags + else: + continue + + im = im[ymin : ymax + 1, xmin : xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags + return im, polys, tags + + def crop_background_infor(self, im, text_polys, text_tags): + im, text_polys, text_tags = self.crop_area( + im, text_polys, text_tags, crop_background=True + ) + + if len(text_polys) > 0: + return None + # pad and resize image + input_size = self.input_size + im, ratio = self.preprocess(im) + score_map = np.zeros((input_size, input_size), dtype=np.float32) + geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32) + training_mask = np.ones((input_size, input_size), dtype=np.float32) + return im, score_map, geo_map, training_mask + + def crop_foreground_infor(self, im, text_polys, text_tags): + im, text_polys, text_tags = self.crop_area( + im, text_polys, text_tags, crop_background=False + ) + + if text_polys.shape[0] == 0: + return None + # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + # pad and resize image + input_size = self.input_size + im, ratio = self.preprocess(im) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + _, _, new_h, new_w = im.shape + # print(im.shape) + # self.draw_img_polys(im, text_polys) + score_map, geo_map, training_mask = self.generate_quad( + (new_h, new_w), text_polys, text_tags + ) + return im, score_map, geo_map, training_mask + + def __call__(self, data): + im = data["image"] + text_polys = data["polys"] + text_tags = data["ignore_tags"] + if im is None: + return None + if text_polys.shape[0] == 0: + return None + + # add rotate cases + if np.random.rand() < 0.5: + im, text_polys = self.rotate_im_poly(im, text_polys) + h, w, _ = im.shape + text_polys, text_tags = self.check_and_validate_polys( + text_polys, text_tags, h, w + ) + if text_polys.shape[0] == 0: + return None + + # random scale this image + rd_scale = np.random.choice(self.random_scale) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + if np.random.rand() < self.background_ratio: + outs = self.crop_background_infor(im, text_polys, text_tags) + else: + outs = self.crop_foreground_infor(im, text_polys, text_tags) + + if outs is None: + return None + im, score_map, geo_map, training_mask = outs + score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32) + geo_map = np.swapaxes(geo_map, 1, 2) + geo_map = np.swapaxes(geo_map, 1, 0) + geo_map = geo_map[:, ::4, ::4].astype(np.float32) + training_mask = training_mask[np.newaxis, ::4, ::4] + training_mask = training_mask.astype(np.float32) + + data["image"] = im[0] + data["score_map"] = score_map + data["geo_map"] = geo_map + data["training_mask"] = training_mask + return data diff --git a/ocr/ppocr/data/imaug/fce_aug.py b/ocr/ppocr/data/imaug/fce_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..e44e095321cfff0ddfafc2661eace3d0b5457ec7 --- /dev/null +++ b/ocr/ppocr/data/imaug/fce_aug.py @@ -0,0 +1,563 @@ +import math +import os + +import cv2 +import numpy as np +from PIL import Image, ImageDraw +from shapely.geometry import Polygon + +from postprocess.poly_nms import poly_intersection + + +class RandomScaling: + def __init__(self, size=800, scale=(3.0 / 4, 5.0 / 2), **kwargs): + """Random scale the image while keeping aspect. + + Args: + size (int) : Base size before scaling. + scale (tuple(float)) : The range of scaling. + """ + assert isinstance(size, int) + assert isinstance(scale, float) or isinstance(scale, tuple) + self.size = size + self.scale = scale if isinstance(scale, tuple) else (1 - scale, 1 + scale) + + def __call__(self, data): + image = data["image"] + text_polys = data["polys"] + h, w, _ = image.shape + + aspect_ratio = np.random.uniform(min(self.scale), max(self.scale)) + scales = self.size * 1.0 / max(h, w) * aspect_ratio + scales = np.array([scales, scales]) + out_size = (int(h * scales[1]), int(w * scales[0])) + image = cv2.resize(image, out_size[::-1]) + + data["image"] = image + text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1] + text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0] + data["polys"] = text_polys + + return data + + +class RandomCropFlip: + def __init__( + self, pad_ratio=0.1, crop_ratio=0.5, iter_num=1, min_area_ratio=0.2, **kwargs + ): + """Random crop and flip a patch of the image. + + Args: + crop_ratio (float): The ratio of cropping. + iter_num (int): Number of operations. + min_area_ratio (float): Minimal area ratio between cropped patch + and original image. + """ + assert isinstance(crop_ratio, float) + assert isinstance(iter_num, int) + assert isinstance(min_area_ratio, float) + + self.pad_ratio = pad_ratio + self.epsilon = 1e-2 + self.crop_ratio = crop_ratio + self.iter_num = iter_num + self.min_area_ratio = min_area_ratio + + def __call__(self, results): + for i in range(self.iter_num): + results = self.random_crop_flip(results) + + return results + + def random_crop_flip(self, results): + image = results["image"] + polygons = results["polys"] + ignore_tags = results["ignore_tags"] + if len(polygons) == 0: + return results + + if np.random.random() >= self.crop_ratio: + return results + + h, w, _ = image.shape + area = h * w + pad_h = int(h * self.pad_ratio) + pad_w = int(w * self.pad_ratio) + h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h, pad_w) + if len(h_axis) == 0 or len(w_axis) == 0: + return results + + attempt = 0 + while attempt < 50: + attempt += 1 + polys_keep = [] + polys_new = [] + ignore_tags_keep = [] + ignore_tags_new = [] + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio: + # area too small + continue + + pts = np.stack( + [[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]] + ).T.astype(np.int32) + pp = Polygon(pts) + fail_flag = False + for polygon, ignore_tag in zip(polygons, ignore_tags): + ppi = Polygon(polygon.reshape(-1, 2)) + ppiou, _ = poly_intersection(ppi, pp, buffer=0) + if ( + np.abs(ppiou - float(ppi.area)) > self.epsilon + and np.abs(ppiou) > self.epsilon + ): + fail_flag = True + break + elif np.abs(ppiou - float(ppi.area)) < self.epsilon: + polys_new.append(polygon) + ignore_tags_new.append(ignore_tag) + else: + polys_keep.append(polygon) + ignore_tags_keep.append(ignore_tag) + + if fail_flag: + continue + else: + break + + cropped = image[ymin:ymax, xmin:xmax, :] + select_type = np.random.randint(3) + if select_type == 0: + img = np.ascontiguousarray(cropped[:, ::-1]) + elif select_type == 1: + img = np.ascontiguousarray(cropped[::-1, :]) + else: + img = np.ascontiguousarray(cropped[::-1, ::-1]) + image[ymin:ymax, xmin:xmax, :] = img + results["img"] = image + + if len(polys_new) != 0: + height, width, _ = cropped.shape + if select_type == 0: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + polys_new[idx] = poly + elif select_type == 1: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = poly + else: + for idx, polygon in enumerate(polys_new): + poly = polygon.reshape(-1, 2) + poly[:, 0] = width - poly[:, 0] + 2 * xmin + poly[:, 1] = height - poly[:, 1] + 2 * ymin + polys_new[idx] = poly + polygons = polys_keep + polys_new + ignore_tags = ignore_tags_keep + ignore_tags_new + results["polys"] = np.array(polygons) + results["ignore_tags"] = ignore_tags + + return results + + def generate_crop_target(self, image, all_polys, pad_h, pad_w): + """Generate crop target and make sure not to crop the polygon + instances. + + Args: + image (ndarray): The image waited to be crop. + all_polys (list[list[ndarray]]): All polygons including ground + truth polygons and ground truth ignored polygons. + pad_h (int): Padding length of height. + pad_w (int): Padding length of width. + Returns: + h_axis (ndarray): Vertical cropping range. + w_axis (ndarray): Horizontal cropping range. + """ + h, w, _ = image.shape + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + + text_polys = [] + for polygon in all_polys: + rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2)) + box = cv2.boxPoints(rect) + box = np.int0(box) + text_polys.append([box[0], box[1], box[2], box[3]]) + + polys = np.array(text_polys, dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w : maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h : maxy + pad_h] = 1 + + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + return h_axis, w_axis + + +class RandomCropPolyInstances: + """Randomly crop images and make sure to contain at least one intact + instance.""" + + def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs): + super().__init__() + self.crop_ratio = crop_ratio + self.min_side_ratio = min_side_ratio + + def sample_valid_start_end(self, valid_array, min_len, max_start, min_end): + + assert isinstance(min_len, int) + assert len(valid_array) > min_len + + start_array = valid_array.copy() + max_start = min(len(start_array) - min_len, max_start) + start_array[max_start:] = 0 + start_array[0] = 1 + diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + start = np.random.randint(region_starts[region_ind], region_ends[region_ind]) + + end_array = valid_array.copy() + min_end = max(start + min_len, min_end) + end_array[:min_end] = 0 + end_array[-1] = 1 + diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + end = np.random.randint(region_starts[region_ind], region_ends[region_ind]) + return start, end + + def sample_crop_box(self, img_size, results): + """Generate crop box and make sure not to crop the polygon instances. + + Args: + img_size (tuple(int)): The image size (h, w). + results (dict): The results dict. + """ + + assert isinstance(img_size, tuple) + h, w = img_size[:2] + + key_masks = results["polys"] + + x_valid_array = np.ones(w, dtype=np.int32) + y_valid_array = np.ones(h, dtype=np.int32) + + selected_mask = key_masks[np.random.randint(0, len(key_masks))] + selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32) + max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0) + min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1) + max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) + min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) + + for mask in key_masks: + mask = mask.reshape((-1, 2)).astype(np.int32) + clip_x = np.clip(mask[:, 0], 0, w - 1) + clip_y = np.clip(mask[:, 1], 0, h - 1) + min_x, max_x = np.min(clip_x), np.max(clip_x) + min_y, max_y = np.min(clip_y), np.max(clip_y) + + x_valid_array[min_x - 2 : max_x + 3] = 0 + y_valid_array[min_y - 2 : max_y + 3] = 0 + + min_w = int(w * self.min_side_ratio) + min_h = int(h * self.min_side_ratio) + + x1, x2 = self.sample_valid_start_end( + x_valid_array, min_w, max_x_start, min_x_end + ) + y1, y2 = self.sample_valid_start_end( + y_valid_array, min_h, max_y_start, min_y_end + ) + + return np.array([x1, y1, x2, y2]) + + def crop_img(self, img, bbox): + assert img.ndim == 3 + h, w, _ = img.shape + assert 0 <= bbox[1] < bbox[3] <= h + assert 0 <= bbox[0] < bbox[2] <= w + return img[bbox[1] : bbox[3], bbox[0] : bbox[2]] + + def __call__(self, results): + image = results["image"] + polygons = results["polys"] + ignore_tags = results["ignore_tags"] + if len(polygons) < 1: + return results + + if np.random.random_sample() < self.crop_ratio: + + crop_box = self.sample_crop_box(image.shape, results) + img = self.crop_img(image, crop_box) + results["image"] = img + # crop and filter masks + x1, y1, x2, y2 = crop_box + w = max(x2 - x1, 1) + h = max(y2 - y1, 1) + polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1 + polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1 + + valid_masks_list = [] + valid_tags_list = [] + for ind, polygon in enumerate(polygons): + if ( + (polygon[:, ::2] > -4).all() + and (polygon[:, ::2] < w + 4).all() + and (polygon[:, 1::2] > -4).all() + and (polygon[:, 1::2] < h + 4).all() + ): + polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w) + polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h) + valid_masks_list.append(polygon) + valid_tags_list.append(ignore_tags[ind]) + + results["polys"] = np.array(valid_masks_list) + results["ignore_tags"] = valid_tags_list + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +class RandomRotatePolyInstances: + def __init__( + self, + rotate_ratio=0.5, + max_angle=10, + pad_with_fixed_color=False, + pad_value=(0, 0, 0), + **kwargs + ): + """Randomly rotate images and polygon masks. + + Args: + rotate_ratio (float): The ratio of samples to operate rotation. + max_angle (int): The maximum rotation angle. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rotated image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def rotate(self, center, points, theta, center_shift=(0, 0)): + # rotate points. + (center_x, center_y) = center + center_y = -center_y + x, y = points[:, ::2], points[:, 1::2] + y = -y + + theta = theta / 180 * math.pi + cos = math.cos(theta) + sin = math.sin(theta) + + x = x - center_x + y = y - center_y + + _x = center_x + x * cos - y * sin + center_shift[0] + _y = -(center_y + x * sin + y * cos) + center_shift[1] + + points[:, ::2], points[:, 1::2] = _x, _y + return points + + def cal_canvas_size(self, ori_size, degree): + assert isinstance(ori_size, tuple) + angle = degree * math.pi / 180.0 + h, w = ori_size[:2] + + cos = math.cos(angle) + sin = math.sin(angle) + canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos)) + canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin)) + + canvas_size = (canvas_h, canvas_w) + return canvas_size + + def sample_angle(self, max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + def rotate_img(self, img, angle, canvas_size): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2) + rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2) + + if self.pad_with_fixed_color: + target_img = cv2.warpAffine( + img, + rotation_matrix, + (canvas_size[1], canvas_size[0]), + flags=cv2.INTER_NEAREST, + borderValue=self.pad_value, + ) + else: + mask = np.zeros_like(img) + (h_ind, w_ind) = ( + np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8), + ) + img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)] + img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0])) + + mask = cv2.warpAffine( + mask, + rotation_matrix, + (canvas_size[1], canvas_size[0]), + borderValue=[1, 1, 1], + ) + target_img = cv2.warpAffine( + img, + rotation_matrix, + (canvas_size[1], canvas_size[0]), + borderValue=[0, 0, 0], + ) + target_img = target_img + img_cut * mask + + return target_img + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + image = results["image"] + polygons = results["polys"] + h, w = image.shape[:2] + + angle = self.sample_angle(self.max_angle) + canvas_size = self.cal_canvas_size((h, w), angle) + center_shift = ( + int((canvas_size[1] - w) / 2), + int((canvas_size[0] - h) / 2), + ) + image = self.rotate_img(image, angle, canvas_size) + results["image"] = image + # rotate polygons + rotated_masks = [] + for mask in polygons: + rotated_mask = self.rotate((w / 2, h / 2), mask, angle, center_shift) + rotated_masks.append(rotated_mask) + results["polys"] = np.array(rotated_masks) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +class SquareResizePad: + def __init__( + self, + target_size, + pad_ratio=0.6, + pad_with_fixed_color=False, + pad_value=(0, 0, 0), + **kwargs + ): + """Resize or pad images to be square shape. + + Args: + target_size (int): The target size of square shaped image. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rescales image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + assert isinstance(target_size, int) + assert isinstance(pad_ratio, float) + assert isinstance(pad_with_fixed_color, bool) + assert isinstance(pad_value, tuple) + + self.target_size = target_size + self.pad_ratio = pad_ratio + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def resize_img(self, img, keep_ratio=True): + h, w, _ = img.shape + if keep_ratio: + t_h = self.target_size if h >= w else int(h * self.target_size / w) + t_w = self.target_size if h <= w else int(w * self.target_size / h) + else: + t_h = t_w = self.target_size + img = cv2.resize(img, (t_w, t_h)) + return img, (t_h, t_w) + + def square_pad(self, img): + h, w = img.shape[:2] + if h == w: + return img, (0, 0) + pad_size = max(h, w) + if self.pad_with_fixed_color: + expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8) + expand_img[:] = self.pad_value + else: + (h_ind, w_ind) = ( + np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8), + ) + img_cut = img[h_ind : (h_ind + h // 9), w_ind : (w_ind + w // 9)] + expand_img = cv2.resize(img_cut, (pad_size, pad_size)) + if h > w: + y0, x0 = 0, (h - w) // 2 + else: + y0, x0 = (w - h) // 2, 0 + expand_img[y0 : y0 + h, x0 : x0 + w] = img + offset = (x0, y0) + + return expand_img, offset + + def square_pad_mask(self, points, offset): + x0, y0 = offset + pad_points = points.copy() + pad_points[::2] = pad_points[::2] + x0 + pad_points[1::2] = pad_points[1::2] + y0 + return pad_points + + def __call__(self, results): + image = results["image"] + polygons = results["polys"] + h, w = image.shape[:2] + + if np.random.random_sample() < self.pad_ratio: + image, out_size = self.resize_img(image, keep_ratio=True) + image, offset = self.square_pad(image) + else: + image, out_size = self.resize_img(image, keep_ratio=False) + offset = (0, 0) + results["image"] = image + try: + polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[1] / w + offset[0] + polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[0] / h + offset[1] + except: + pass + results["polys"] = polygons + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/ocr/ppocr/data/imaug/fce_targets.py b/ocr/ppocr/data/imaug/fce_targets.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0d7eb5a8b9ea8753dee5d26156b5fae93604e6 --- /dev/null +++ b/ocr/ppocr/data/imaug/fce_targets.py @@ -0,0 +1,671 @@ +import cv2 +import numpy as np +from numpy.fft import fft +from numpy.linalg import norm + + +def vector_slope(vec): + assert len(vec) == 2 + return abs(vec[1] / (vec[0] + 1e-8)) + + +class FCENetTargets: + """Generate the ground truth targets of FCENet: Fourier Contour Embedding + for Arbitrary-Shaped Text Detection. + + [https://arxiv.org/abs/2104.10442] + + Args: + fourier_degree (int): The maximum Fourier transform degree k. + resample_step (float): The step size for resampling the text center + line (TCL). It's better not to exceed half of the minimum width. + center_region_shrink_ratio (float): The shrink ratio of text center + region. + level_size_divisors (tuple(int)): The downsample ratio on each level. + level_proportion_range (tuple(tuple(int))): The range of text sizes + assigned to each level. + """ + + def __init__( + self, + fourier_degree=5, + resample_step=4.0, + center_region_shrink_ratio=0.3, + level_size_divisors=(8, 16, 32), + level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)), + orientation_thr=2.0, + **kwargs + ): + + super().__init__() + assert isinstance(level_size_divisors, tuple) + assert isinstance(level_proportion_range, tuple) + assert len(level_size_divisors) == len(level_proportion_range) + self.fourier_degree = fourier_degree + self.resample_step = resample_step + self.center_region_shrink_ratio = center_region_shrink_ratio + self.level_size_divisors = level_size_divisors + self.level_proportion_range = level_proportion_range + + self.orientation_thr = orientation_thr + + def vector_angle(self, vec1, vec2): + if vec1.ndim > 1: + unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1)) + else: + unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8) + if vec2.ndim > 1: + unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1)) + else: + unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8) + return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0)) + + def resample_line(self, line, n): + """Resample n points on a line. + + Args: + line (ndarray): The points composing a line. + n (int): The resampled points number. + + Returns: + resampled_line (ndarray): The points composing the resampled line. + """ + + assert line.ndim == 2 + assert line.shape[0] >= 2 + assert line.shape[1] == 2 + assert isinstance(n, int) + assert n > 0 + + length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)] + total_length = sum(length_list) + length_cumsum = np.cumsum([0.0] + length_list) + delta_length = total_length / (float(n) + 1e-8) + + current_edge_ind = 0 + resampled_line = [line[0]] + + for i in range(1, n): + current_line_len = i * delta_length + + while current_line_len >= length_cumsum[current_edge_ind + 1]: + current_edge_ind += 1 + current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind] + end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind] + current_point = ( + line[current_edge_ind] + + (line[current_edge_ind + 1] - line[current_edge_ind]) + * end_shift_ratio + ) + resampled_line.append(current_point) + + resampled_line.append(line[-1]) + resampled_line = np.array(resampled_line) + + return resampled_line + + def reorder_poly_edge(self, points): + """Get the respective points composing head edge, tail edge, top + sideline and bottom sideline. + + Args: + points (ndarray): The points composing a text polygon. + + Returns: + head_edge (ndarray): The two points composing the head edge of text + polygon. + tail_edge (ndarray): The two points composing the tail edge of text + polygon. + top_sideline (ndarray): The points composing top curved sideline of + text polygon. + bot_sideline (ndarray): The points composing bottom curved sideline + of text polygon. + """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + + head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr) + head_edge, tail_edge = points[head_inds], points[tail_inds] + + pad_points = np.vstack([points, points]) + if tail_inds[1] < 1: + tail_inds[1] = len(points) + sideline1 = pad_points[head_inds[1] : tail_inds[1]] + sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))] + sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0) + + if sideline_mean_shift[1] > 0: + top_sideline, bot_sideline = sideline2, sideline1 + else: + top_sideline, bot_sideline = sideline1, sideline2 + + return head_edge, tail_edge, top_sideline, bot_sideline + + def find_head_tail(self, points, orientation_thr): + """Find the head edge and tail edge of a text polygon. + + Args: + points (ndarray): The points composing a text polygon. + orientation_thr (float): The threshold for distinguishing between + head edge and tail edge among the horizontal and vertical edges + of a quadrangle. + + Returns: + head_inds (list): The indexes of two points composing head edge. + tail_inds (list): The indexes of two points composing tail edge. + """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + assert isinstance(orientation_thr, float) + + if len(points) > 4: + pad_points = np.vstack([points, points[0]]) + edge_vec = pad_points[1:] - pad_points[:-1] + + theta_sum = [] + adjacent_vec_theta = [] + for i, edge_vec1 in enumerate(edge_vec): + adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]] + adjacent_edge_vec = edge_vec[adjacent_ind] + temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec)) + temp_adjacent_theta = self.vector_angle( + adjacent_edge_vec[0], adjacent_edge_vec[1] + ) + theta_sum.append(temp_theta_sum) + adjacent_vec_theta.append(temp_adjacent_theta) + theta_sum_score = np.array(theta_sum) / np.pi + adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi + poly_center = np.mean(points, axis=0) + edge_dist = np.maximum( + norm(pad_points[1:] - poly_center, axis=-1), + norm(pad_points[:-1] - poly_center, axis=-1), + ) + dist_score = edge_dist / np.max(edge_dist) + position_score = np.zeros(len(edge_vec)) + score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score + score += 0.35 * dist_score + if len(points) % 2 == 0: + position_score[(len(score) // 2 - 1)] += 1 + position_score[-1] += 1 + score += 0.1 * position_score + pad_score = np.concatenate([score, score]) + score_matrix = np.zeros((len(score), len(score) - 3)) + x = np.arange(len(score) - 3) / float(len(score) - 4) + gaussian = ( + 1.0 + / (np.sqrt(2.0 * np.pi) * 0.5) + * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2) + ) + gaussian = gaussian / np.max(gaussian) + for i in range(len(score)): + score_matrix[i, :] = ( + score[i] + + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3 + ) + + head_start, tail_increment = np.unravel_index( + score_matrix.argmax(), score_matrix.shape + ) + tail_start = (head_start + tail_increment + 2) % len(points) + head_end = (head_start + 1) % len(points) + tail_end = (tail_start + 1) % len(points) + + if head_end > tail_end: + head_start, tail_start = tail_start, head_start + head_end, tail_end = tail_end, head_end + head_inds = [head_start, head_end] + tail_inds = [tail_start, tail_end] + else: + if vector_slope(points[1] - points[0]) + vector_slope( + points[3] - points[2] + ) < vector_slope(points[2] - points[1]) + vector_slope( + points[0] - points[3] + ): + horizontal_edge_inds = [[0, 1], [2, 3]] + vertical_edge_inds = [[3, 0], [1, 2]] + else: + horizontal_edge_inds = [[3, 0], [1, 2]] + vertical_edge_inds = [[0, 1], [2, 3]] + + vertical_len_sum = norm( + points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]] + ) + norm( + points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]] + ) + horizontal_len_sum = norm( + points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]] + ) + norm( + points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]] + ) + + if vertical_len_sum > horizontal_len_sum * orientation_thr: + head_inds = horizontal_edge_inds[0] + tail_inds = horizontal_edge_inds[1] + else: + head_inds = vertical_edge_inds[0] + tail_inds = vertical_edge_inds[1] + + return head_inds, tail_inds + + def resample_sidelines(self, sideline1, sideline2, resample_step): + """Resample two sidelines to be of the same points number according to + step size. + + Args: + sideline1 (ndarray): The points composing a sideline of a text + polygon. + sideline2 (ndarray): The points composing another sideline of a + text polygon. + resample_step (float): The resampled step size. + + Returns: + resampled_line1 (ndarray): The resampled line 1. + resampled_line2 (ndarray): The resampled line 2. + """ + + assert sideline1.ndim == sideline2.ndim == 2 + assert sideline1.shape[1] == sideline2.shape[1] == 2 + assert sideline1.shape[0] >= 2 + assert sideline2.shape[0] >= 2 + assert isinstance(resample_step, float) + + length1 = sum( + [norm(sideline1[i + 1] - sideline1[i]) for i in range(len(sideline1) - 1)] + ) + length2 = sum( + [norm(sideline2[i + 1] - sideline2[i]) for i in range(len(sideline2) - 1)] + ) + + total_length = (length1 + length2) / 2 + resample_point_num = max(int(float(total_length) / resample_step), 1) + + resampled_line1 = self.resample_line(sideline1, resample_point_num) + resampled_line2 = self.resample_line(sideline2, resample_point_num) + + return resampled_line1, resampled_line2 + + def generate_center_region_mask(self, img_size, text_polys): + """Generate text center region mask. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + center_region_mask (ndarray): The text center region mask. + """ + + assert isinstance(img_size, tuple) + # assert check_argument.is_2dlist(text_polys) + + h, w = img_size + + center_region_mask = np.zeros((h, w), np.uint8) + + center_region_boxes = [] + for poly in text_polys: + # assert len(poly) == 1 + polygon_points = poly.reshape(-1, 2) + _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points) + resampled_top_line, resampled_bot_line = self.resample_sidelines( + top_line, bot_line, self.resample_step + ) + resampled_bot_line = resampled_bot_line[::-1] + center_line = (resampled_top_line + resampled_bot_line) / 2 + + line_head_shrink_len = ( + norm(resampled_top_line[0] - resampled_bot_line[0]) / 4.0 + ) + line_tail_shrink_len = ( + norm(resampled_top_line[-1] - resampled_bot_line[-1]) / 4.0 + ) + head_shrink_num = int(line_head_shrink_len // self.resample_step) + tail_shrink_num = int(line_tail_shrink_len // self.resample_step) + if len(center_line) > head_shrink_num + tail_shrink_num + 2: + center_line = center_line[ + head_shrink_num : len(center_line) - tail_shrink_num + ] + resampled_top_line = resampled_top_line[ + head_shrink_num : len(resampled_top_line) - tail_shrink_num + ] + resampled_bot_line = resampled_bot_line[ + head_shrink_num : len(resampled_bot_line) - tail_shrink_num + ] + + for i in range(0, len(center_line) - 1): + tl = ( + center_line[i] + + (resampled_top_line[i] - center_line[i]) + * self.center_region_shrink_ratio + ) + tr = ( + center_line[i + 1] + + (resampled_top_line[i + 1] - center_line[i + 1]) + * self.center_region_shrink_ratio + ) + br = ( + center_line[i + 1] + + (resampled_bot_line[i + 1] - center_line[i + 1]) + * self.center_region_shrink_ratio + ) + bl = ( + center_line[i] + + (resampled_bot_line[i] - center_line[i]) + * self.center_region_shrink_ratio + ) + current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32) + center_region_boxes.append(current_center_box) + + cv2.fillPoly(center_region_mask, center_region_boxes, 1) + return center_region_mask + + def resample_polygon(self, polygon, n=400): + """Resample one polygon with n points on its boundary. + + Args: + polygon (list[float]): The input polygon. + n (int): The number of resampled points. + Returns: + resampled_polygon (list[float]): The resampled polygon. + """ + length = [] + + for i in range(len(polygon)): + p1 = polygon[i] + if i == len(polygon) - 1: + p2 = polygon[0] + else: + p2 = polygon[i + 1] + length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5) + + total_length = sum(length) + n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n + n_on_each_line = n_on_each_line.astype(np.int32) + new_polygon = [] + + for i in range(len(polygon)): + num = n_on_each_line[i] + p1 = polygon[i] + if i == len(polygon) - 1: + p2 = polygon[0] + else: + p2 = polygon[i + 1] + + if num == 0: + continue + + dxdy = (p2 - p1) / num + for j in range(num): + point = p1 + dxdy * j + new_polygon.append(point) + + return np.array(new_polygon) + + def normalize_polygon(self, polygon): + """Normalize one polygon so that its start point is at right most. + + Args: + polygon (list[float]): The origin polygon. + Returns: + new_polygon (lost[float]): The polygon with start point at right. + """ + temp_polygon = polygon - polygon.mean(axis=0) + x = np.abs(temp_polygon[:, 0]) + y = temp_polygon[:, 1] + index_x = np.argsort(x) + index_y = np.argmin(y[index_x[:8]]) + index = index_x[index_y] + new_polygon = np.concatenate([polygon[index:], polygon[:index]]) + return new_polygon + + def poly2fourier(self, polygon, fourier_degree): + """Perform Fourier transformation to generate Fourier coefficients ck + from polygon. + + Args: + polygon (ndarray): An input polygon. + fourier_degree (int): The maximum Fourier degree K. + Returns: + c (ndarray(complex)): Fourier coefficients. + """ + points = polygon[:, 0] + polygon[:, 1] * 1j + c_fft = fft(points) / len(points) + c = np.hstack((c_fft[-fourier_degree:], c_fft[: fourier_degree + 1])) + return c + + def clockwise(self, c, fourier_degree): + """Make sure the polygon reconstructed from Fourier coefficients c in + the clockwise direction. + + Args: + polygon (list[float]): The origin polygon. + Returns: + new_polygon (lost[float]): The polygon in clockwise point order. + """ + if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]): + return c + elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]): + return c[::-1] + else: + if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]): + return c + else: + return c[::-1] + + def cal_fourier_signature(self, polygon, fourier_degree): + """Calculate Fourier signature from input polygon. + + Args: + polygon (ndarray): The input polygon. + fourier_degree (int): The maximum Fourier degree K. + Returns: + fourier_signature (ndarray): An array shaped (2k+1, 2) containing + real part and image part of 2k+1 Fourier coefficients. + """ + resampled_polygon = self.resample_polygon(polygon) + resampled_polygon = self.normalize_polygon(resampled_polygon) + + fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree) + fourier_coeff = self.clockwise(fourier_coeff, fourier_degree) + + real_part = np.real(fourier_coeff).reshape((-1, 1)) + image_part = np.imag(fourier_coeff).reshape((-1, 1)) + fourier_signature = np.hstack([real_part, image_part]) + + return fourier_signature + + def generate_fourier_maps(self, img_size, text_polys): + """Generate Fourier coefficient maps. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + fourier_real_map (ndarray): The Fourier coefficient real part maps. + fourier_image_map (ndarray): The Fourier coefficient image part + maps. + """ + + assert isinstance(img_size, tuple) + + h, w = img_size + k = self.fourier_degree + real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32) + imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32) + + for poly in text_polys: + mask = np.zeros((h, w), dtype=np.uint8) + polygon = np.array(poly).reshape((1, -1, 2)) + cv2.fillPoly(mask, polygon.astype(np.int32), 1) + fourier_coeff = self.cal_fourier_signature(polygon[0], k) + for i in range(-k, k + 1): + if i != 0: + real_map[i + k, :, :] = ( + mask * fourier_coeff[i + k, 0] + + (1 - mask) * real_map[i + k, :, :] + ) + imag_map[i + k, :, :] = ( + mask * fourier_coeff[i + k, 1] + + (1 - mask) * imag_map[i + k, :, :] + ) + else: + yx = np.argwhere(mask > 0.5) + k_ind = np.ones((len(yx)), dtype=np.int64) * k + y, x = yx[:, 0], yx[:, 1] + real_map[k_ind, y, x] = fourier_coeff[k, 0] - x + imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y + + return real_map, imag_map + + def generate_text_region_mask(self, img_size, text_polys): + """Generate text center region mask and geometry attribute maps. + + Args: + img_size (tuple): The image size (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + text_region_mask (ndarray): The text region mask. + """ + + assert isinstance(img_size, tuple) + + h, w = img_size + text_region_mask = np.zeros((h, w), dtype=np.uint8) + + for poly in text_polys: + polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2)) + cv2.fillPoly(text_region_mask, polygon, 1) + + return text_region_mask + + def generate_effective_mask(self, mask_size: tuple, polygons_ignore): + """Generate effective mask by setting the ineffective regions to 0 and + effective regions to 1. + + Args: + mask_size (tuple): The mask size. + polygons_ignore (list[[ndarray]]: The list of ignored text + polygons. + + Returns: + mask (ndarray): The effective mask of (height, width). + """ + + mask = np.ones(mask_size, dtype=np.uint8) + + for poly in polygons_ignore: + instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2) + cv2.fillPoly(mask, instance, 0) + + return mask + + def generate_level_targets(self, img_size, text_polys, ignore_polys): + """Generate ground truth target on each level. + + Args: + img_size (list[int]): Shape of input image. + text_polys (list[list[ndarray]]): A list of ground truth polygons. + ignore_polys (list[list[ndarray]]): A list of ignored polygons. + Returns: + level_maps (list(ndarray)): A list of ground target on each level. + """ + h, w = img_size + lv_size_divs = self.level_size_divisors + lv_proportion_range = self.level_proportion_range + lv_text_polys = [[] for i in range(len(lv_size_divs))] + lv_ignore_polys = [[] for i in range(len(lv_size_divs))] + level_maps = [] + for poly in text_polys: + polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2)) + _, _, box_w, box_h = cv2.boundingRect(polygon) + proportion = max(box_h, box_w) / (h + 1e-8) + + for ind, proportion_range in enumerate(lv_proportion_range): + if proportion_range[0] < proportion < proportion_range[1]: + lv_text_polys[ind].append(poly / lv_size_divs[ind]) + + for ignore_poly in ignore_polys: + polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2)) + _, _, box_w, box_h = cv2.boundingRect(polygon) + proportion = max(box_h, box_w) / (h + 1e-8) + + for ind, proportion_range in enumerate(lv_proportion_range): + if proportion_range[0] < proportion < proportion_range[1]: + lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind]) + + for ind, size_divisor in enumerate(lv_size_divs): + current_level_maps = [] + level_img_size = (h // size_divisor, w // size_divisor) + + text_region = self.generate_text_region_mask( + level_img_size, lv_text_polys[ind] + )[None] + current_level_maps.append(text_region) + + center_region = self.generate_center_region_mask( + level_img_size, lv_text_polys[ind] + )[None] + current_level_maps.append(center_region) + + effective_mask = self.generate_effective_mask( + level_img_size, lv_ignore_polys[ind] + )[None] + current_level_maps.append(effective_mask) + + fourier_real_map, fourier_image_maps = self.generate_fourier_maps( + level_img_size, lv_text_polys[ind] + ) + current_level_maps.append(fourier_real_map) + current_level_maps.append(fourier_image_maps) + + level_maps.append(np.concatenate(current_level_maps)) + + return level_maps + + def generate_targets(self, results): + """Generate the ground truth targets for FCENet. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. + """ + + assert isinstance(results, dict) + image = results["image"] + polygons = results["polys"] + ignore_tags = results["ignore_tags"] + h, w, _ = image.shape + + polygon_masks = [] + polygon_masks_ignore = [] + for tag, polygon in zip(ignore_tags, polygons): + if tag is True: + polygon_masks_ignore.append(polygon) + else: + polygon_masks.append(polygon) + + level_maps = self.generate_level_targets( + (h, w), polygon_masks, polygon_masks_ignore + ) + + mapping = { + "p3_maps": level_maps[0], + "p4_maps": level_maps[1], + "p5_maps": level_maps[2], + } + for key, value in mapping.items(): + results[key] = value + + return results + + def __call__(self, results): + results = self.generate_targets(results) + return results diff --git a/ocr/ppocr/data/imaug/gen_table_mask.py b/ocr/ppocr/data/imaug/gen_table_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7010d9b36b95d1912c73fe7a66e509edc88f58 --- /dev/null +++ b/ocr/ppocr/data/imaug/gen_table_mask.py @@ -0,0 +1,228 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import cv2 +import numpy as np + + +class GenTableMask(object): + """gen table mask""" + + def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs): + self.shrink_h_max = 5 + self.shrink_w_max = 5 + self.mask_type = mask_type + + def projection(self, erosion, h, w, spilt_threshold=0): + # 水平投影 + projection_map = np.ones_like(erosion) + project_val_array = [0 for _ in range(0, h)] + + for j in range(0, h): + for i in range(0, w): + if erosion[j, i] == 255: + project_val_array[j] += 1 + # 根据数组,获取切割点 + start_idx = 0 # 记录进入字符区的索引 + end_idx = 0 # 记录进入空白区域的索引 + in_text = False # 是否遍历到了字符区内 + box_list = [] + for i in range(len(project_val_array)): + if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 + in_text = True + start_idx = i + elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 + end_idx = i + in_text = False + if end_idx - start_idx <= 2: + continue + box_list.append((start_idx, end_idx + 1)) + + if in_text: + box_list.append((start_idx, h - 1)) + # 绘制投影直方图 + for j in range(0, h): + for i in range(0, project_val_array[j]): + projection_map[j, i] = 0 + return box_list, projection_map + + def projection_cx(self, box_img): + box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) + h, w = box_gray_img.shape + # 灰度图片进行二值化处理 + ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) + # 纵向腐蚀 + if h < w: + kernel = np.ones((2, 1), np.uint8) + erode = cv2.erode(thresh1, kernel, iterations=1) + else: + erode = thresh1 + # 水平膨胀 + kernel = np.ones((1, 5), np.uint8) + erosion = cv2.dilate(erode, kernel, iterations=1) + # 水平投影 + projection_map = np.ones_like(erosion) + project_val_array = [0 for _ in range(0, h)] + + for j in range(0, h): + for i in range(0, w): + if erosion[j, i] == 255: + project_val_array[j] += 1 + # 根据数组,获取切割点 + start_idx = 0 # 记录进入字符区的索引 + end_idx = 0 # 记录进入空白区域的索引 + in_text = False # 是否遍历到了字符区内 + box_list = [] + spilt_threshold = 0 + for i in range(len(project_val_array)): + if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 + in_text = True + start_idx = i + elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 + end_idx = i + in_text = False + if end_idx - start_idx <= 2: + continue + box_list.append((start_idx, end_idx + 1)) + + if in_text: + box_list.append((start_idx, h - 1)) + # 绘制投影直方图 + for j in range(0, h): + for i in range(0, project_val_array[j]): + projection_map[j, i] = 0 + split_bbox_list = [] + if len(box_list) > 1: + for i, (h_start, h_end) in enumerate(box_list): + if i == 0: + h_start = 0 + if i == len(box_list): + h_end = h + word_img = erosion[h_start : h_end + 1, :] + word_h, word_w = word_img.shape + w_split_list, w_projection_map = self.projection( + word_img.T, word_w, word_h + ) + w_start, w_end = w_split_list[0][0], w_split_list[-1][1] + if h_start > 0: + h_start -= 1 + h_end += 1 + word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :] + split_bbox_list.append([w_start, h_start, w_end, h_end]) + else: + split_bbox_list.append([0, 0, w, h]) + return split_bbox_list + + def shrink_bbox(self, bbox): + left, top, right, bottom = bbox + sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max) + sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max) + left_new = left + sh_w + right_new = right - sh_w + top_new = top + sh_h + bottom_new = bottom - sh_h + if left_new >= right_new: + left_new = left + right_new = right + if top_new >= bottom_new: + top_new = top + bottom_new = bottom + return [left_new, top_new, right_new, bottom_new] + + def __call__(self, data): + img = data["image"] + cells = data["cells"] + height, width = img.shape[0:2] + if self.mask_type == 1: + mask_img = np.zeros((height, width), dtype=np.float32) + else: + mask_img = np.zeros((height, width, 3), dtype=np.float32) + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]["bbox"] + left, top, right, bottom = bbox + box_img = img[top:bottom, left:right, :].copy() + split_bbox_list = self.projection_cx(box_img) + for sno in range(len(split_bbox_list)): + split_bbox_list[sno][0] += left + split_bbox_list[sno][1] += top + split_bbox_list[sno][2] += left + split_bbox_list[sno][3] += top + + for sno in range(len(split_bbox_list)): + left, top, right, bottom = split_bbox_list[sno] + left, top, right, bottom = self.shrink_bbox( + [left, top, right, bottom] + ) + if self.mask_type == 1: + mask_img[top:bottom, left:right] = 1.0 + data["mask_img"] = mask_img + else: + mask_img[top:bottom, left:right, :] = (255, 255, 255) + data["image"] = mask_img + return data + + +class ResizeTableImage(object): + def __init__(self, max_len, **kwargs): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + + def get_img_bbox(self, cells): + bbox_list = [] + if len(cells) == 0: + return bbox_list + cell_num = len(cells) + for cno in range(cell_num): + if "bbox" in cells[cno]: + bbox = cells[cno]["bbox"] + bbox_list.append(bbox) + return bbox_list + + def resize_img_table(self, img, bbox_list, max_len): + height, width = img.shape[0:2] + ratio = max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + img_new = cv2.resize(img, (resize_w, resize_h)) + bbox_list_new = [] + for bno in range(len(bbox_list)): + left, top, right, bottom = bbox_list[bno].copy() + left = int(left * ratio) + top = int(top * ratio) + right = int(right * ratio) + bottom = int(bottom * ratio) + bbox_list_new.append([left, top, right, bottom]) + return img_new, bbox_list_new + + def __call__(self, data): + img = data["image"] + if "cells" not in data: + cells = [] + else: + cells = data["cells"] + bbox_list = self.get_img_bbox(cells) + img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) + data["image"] = img_new + cell_num = len(cells) + bno = 0 + for cno in range(cell_num): + if "bbox" in data["cells"][cno]: + data["cells"][cno]["bbox"] = bbox_list_new[bno] + bno += 1 + data["max_len"] = self.max_len + return data + + +class PaddingTableImage(object): + def __init__(self, **kwargs): + super(PaddingTableImage, self).__init__() + + def __call__(self, data): + img = data["image"] + max_len = data["max_len"] + padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data["image"] = padding_img + return data diff --git a/ocr/ppocr/data/imaug/iaa_augment.py b/ocr/ppocr/data/imaug/iaa_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..07f02a3e9e850982721e5572822d62fb6b89f808 --- /dev/null +++ b/ocr/ppocr/data/imaug/iaa_augment.py @@ -0,0 +1,72 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import imgaug +import imgaug.augmenters as iaa +import numpy as np + + +class AugmenterBuilder(object): + def __init__(self): + pass + + def build(self, args, root=True): + if args is None or len(args) == 0: + return None + elif isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + else: + return getattr(iaa, args[0])( + *[self.to_tuple_if_list(a) for a in args[1:]] + ) + elif isinstance(args, dict): + cls = getattr(iaa, args["type"]) + return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()}) + else: + raise RuntimeError("unknown augmenter arg: " + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj + + +class IaaAugment: + def __init__(self, augmenter_args=None, **kwargs): + if augmenter_args is None: + augmenter_args = [ + {"type": "Fliplr", "args": {"p": 0.5}}, + {"type": "Affine", "args": {"rotate": [-10, 10]}}, + {"type": "Resize", "args": {"size": [0.5, 3]}}, + ] + self.augmenter = AugmenterBuilder().build(augmenter_args) + + def __call__(self, data): + image = data["image"] + shape = image.shape + + if self.augmenter: + aug = self.augmenter.to_deterministic() + data["image"] = aug.augment_image(image) + data = self.may_augment_annotation(aug, data, shape) + return data + + def may_augment_annotation(self, aug, data, shape): + if aug is None: + return data + + line_polys = [] + for poly in data["polys"]: + new_poly = self.may_augment_poly(aug, shape, poly) + line_polys.append(new_poly) + data["polys"] = np.array(line_polys) + return data + + def may_augment_poly(self, aug, img_shape, poly): + keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] + keypoints = aug.augment_keypoints( + [imgaug.KeypointsOnImage(keypoints, shape=img_shape)] + )[0].keypoints + poly = [(p.x, p.y) for p in keypoints] + return poly diff --git a/ocr/ppocr/data/imaug/label_ops.py b/ocr/ppocr/data/imaug/label_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..eff417307dac09cd4d4bc93c16c78df8fbe72088 --- /dev/null +++ b/ocr/ppocr/data/imaug/label_ops.py @@ -0,0 +1,1046 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy +import json + +import numpy as np +from shapely.geometry import LineString, Point, Polygon + + +class ClsLabelEncode(object): + def __init__(self, label_list, **kwargs): + self.label_list = label_list + + def __call__(self, data): + label = data["label"] + if label not in self.label_list: + return None + label = self.label_list.index(label) + data["label"] = label + return data + + +class DetLabelEncode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data["label"] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]["points"] + txt = label[bno]["transcription"] + boxes.append(box) + txts.append(txt) + if txt in ["*", "###"]: + txt_tags.append(True) + else: + txt_tags.append(False) + if len(boxes) == 0: + return None + boxes = self.expand_points_num(boxes) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + + data["polys"] = boxes + data["texts"] = txts + data["ignore_tags"] = txt_tags + return data + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + return rect + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +class BaseRecLabelEncode(object): + """Convert between text-label and text-index""" + + def __init__(self, max_text_length, character_dict_path=None, use_space_char=False): + + self.max_text_len = max_text_length + self.beg_str = "sos" + self.end_str = "eos" + self.lower = False + + if character_dict_path is None: + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + self.lower = True + else: + self.character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + self.character_str.append(line) + if use_space_char: + self.character_str.append(" ") + dict_character = list(self.character_str) + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def add_special_char(self, dict_character): + return dict_character + + def encode(self, text): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + + output: + text: concatenated text index for CTCLoss. + [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] + length: length of each text. [batch_size] + """ + if len(text) == 0 or len(text) > self.max_text_len: + return None + if self.lower: + text = text.lower() + text_list = [] + for char in text: + if char not in self.dict: + continue + text_list.append(self.dict[char]) + if len(text_list) == 0: + return None + return text_list + + +class NRTRLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + + super(NRTRLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 1: + return None + data["length"] = np.array(len(text)) + text.insert(0, 2) + text.append(3) + text = text + [0] * (self.max_text_len - len(text)) + data["label"] = np.array(text) + return data + + def add_special_char(self, dict_character): + dict_character = ["blank", "", "", ""] + dict_character + return dict_character + + +class CTCLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(CTCLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + if text is None: + return None + data["length"] = np.array(len(text)) + text = text + [0] * (self.max_text_len - len(text)) + data["label"] = np.array(text) + + label = [0] * len(self.character) + for x in text: + label[x] += 1 + data["label_ace"] = np.array(label) + return data + + def add_special_char(self, dict_character): + dict_character = ["blank"] + dict_character + return dict_character + + +class E2ELabelEncodeTest(BaseRecLabelEncode): + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(E2ELabelEncodeTest, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def __call__(self, data): + import json + + padnum = len(self.dict) + label = data["label"] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]["points"] + txt = label[bno]["transcription"] + boxes.append(box) + txts.append(txt) + if txt in ["*", "###"]: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + data["polys"] = boxes + data["ignore_tags"] = txt_tags + temp_texts = [] + for text in txts: + text = text.lower() + text = self.encode(text) + if text is None: + return None + text = text + [padnum] * (self.max_text_len - len(text)) # use 36 to pad + temp_texts.append(text) + data["texts"] = np.array(temp_texts) + return data + + +class E2ELabelEncodeTrain(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + import json + + label = data["label"] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]["points"] + txt = label[bno]["transcription"] + boxes.append(box) + txts.append(txt) + if txt in ["*", "###"]: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + + data["polys"] = boxes + data["texts"] = txts + data["ignore_tags"] = txt_tags + return data + + +class KieLabelEncode(object): + def __init__(self, character_dict_path, norm=10, directed=False, **kwargs): + super(KieLabelEncode, self).__init__() + self.dict = dict({"": 0}) + with open(character_dict_path, "r", encoding="utf-8") as fr: + idx = 1 + for line in fr: + char = line.strip() + self.dict[char] = idx + idx += 1 + self.norm = norm + self.directed = directed + + def compute_relation(self, boxes): + """Compute relation between every two boxes.""" + x1s, y1s = boxes[:, 0:1], boxes[:, 1:2] + x2s, y2s = boxes[:, 4:5], boxes[:, 5:6] + ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1) + dxs = (x1s[:, 0][None] - x1s) / self.norm + dys = (y1s[:, 0][None] - y1s) / self.norm + xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs + whs = ws / hs + np.zeros_like(xhhs) + relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1) + bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32) + return relations, bboxes + + def pad_text_indices(self, text_inds): + """Pad text index to same length.""" + max_len = 300 + recoder_len = max([len(text_ind) for text_ind in text_inds]) + padded_text_inds = -np.ones((len(text_inds), max_len), np.int32) + for idx, text_ind in enumerate(text_inds): + padded_text_inds[idx, : len(text_ind)] = np.array(text_ind) + return padded_text_inds, recoder_len + + def list_to_numpy(self, ann_infos): + """Convert bboxes, relations, texts and labels to ndarray.""" + boxes, text_inds = ann_infos["points"], ann_infos["text_inds"] + boxes = np.array(boxes, np.int32) + relations, bboxes = self.compute_relation(boxes) + + labels = ann_infos.get("labels", None) + if labels is not None: + labels = np.array(labels, np.int32) + edges = ann_infos.get("edges", None) + if edges is not None: + labels = labels[:, None] + edges = np.array(edges) + edges = (edges[:, None] == edges[None, :]).astype(np.int32) + if self.directed: + edges = (edges & labels == 1).astype(np.int32) + np.fill_diagonal(edges, -1) + labels = np.concatenate([labels, edges], -1) + padded_text_inds, recoder_len = self.pad_text_indices(text_inds) + max_num = 300 + temp_bboxes = np.zeros([max_num, 4]) + h, _ = bboxes.shape + temp_bboxes[:h, :] = bboxes + + temp_relations = np.zeros([max_num, max_num, 5]) + temp_relations[:h, :h, :] = relations + + temp_padded_text_inds = np.zeros([max_num, max_num]) + temp_padded_text_inds[:h, :] = padded_text_inds + + temp_labels = np.zeros([max_num, max_num]) + temp_labels[:h, : h + 1] = labels + + tag = np.array([h, recoder_len]) + return dict( + image=ann_infos["image"], + points=temp_bboxes, + relations=temp_relations, + texts=temp_padded_text_inds, + labels=temp_labels, + tag=tag, + ) + + def convert_canonical(self, points_x, points_y): + + assert len(points_x) == 4 + assert len(points_y) == 4 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + + polygon = Polygon([(p.x, p.y) for p in points]) + min_x, min_y, _, _ = polygon.bounds + points_to_lefttop = [ + LineString([points[i], Point(min_x, min_y)]) for i in range(4) + ] + distances = np.array([line.length for line in points_to_lefttop]) + sort_dist_idx = np.argsort(distances) + lefttop_idx = sort_dist_idx[0] + + if lefttop_idx == 0: + point_orders = [0, 1, 2, 3] + elif lefttop_idx == 1: + point_orders = [1, 2, 3, 0] + elif lefttop_idx == 2: + point_orders = [2, 3, 0, 1] + else: + point_orders = [3, 0, 1, 2] + + sorted_points_x = [points_x[i] for i in point_orders] + sorted_points_y = [points_y[j] for j in point_orders] + + return sorted_points_x, sorted_points_y + + def sort_vertex(self, points_x, points_y): + + assert len(points_x) == 4 + assert len(points_y) == 4 + + x = np.array(points_x) + y = np.array(points_y) + center_x = np.sum(x) * 0.25 + center_y = np.sum(y) * 0.25 + + x_arr = np.array(x - center_x) + y_arr = np.array(y - center_y) + + angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi + sort_idx = np.argsort(angle) + + sorted_points_x, sorted_points_y = [], [] + for i in range(4): + sorted_points_x.append(points_x[sort_idx[i]]) + sorted_points_y.append(points_y[sort_idx[i]]) + + return self.convert_canonical(sorted_points_x, sorted_points_y) + + def __call__(self, data): + import json + + label = data["label"] + annotations = json.loads(label) + boxes, texts, text_inds, labels, edges = [], [], [], [], [] + for ann in annotations: + box = ann["points"] + x_list = [box[i][0] for i in range(4)] + y_list = [box[i][1] for i in range(4)] + sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list) + sorted_box = [] + for x, y in zip(sorted_x_list, sorted_y_list): + sorted_box.append(x) + sorted_box.append(y) + boxes.append(sorted_box) + text = ann["transcription"] + texts.append(ann["transcription"]) + text_ind = [self.dict[c] for c in text if c in self.dict] + text_inds.append(text_ind) + if "label" in ann.keys(): + labels.append(ann["label"]) + elif "key_cls" in ann.keys(): + labels.append(ann["key_cls"]) + else: + raise ValueError( + "Cannot found 'key_cls' in ann.keys(), please check your training annotation." + ) + edges.append(ann.get("edge", 0)) + ann_infos = dict( + image=data["image"], + points=boxes, + texts=texts, + text_inds=text_inds, + edges=edges, + labels=labels, + ) + + return self.list_to_numpy(ann_infos) + + +class AttnLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(AttnLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len: + return None + data["length"] = np.array(len(text)) + text = ( + [0] + + text + + [len(self.character) - 1] + + [0] * (self.max_text_len - len(text) - 2) + ) + data["label"] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + +class SEEDLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(SEEDLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def add_special_char(self, dict_character): + self.padding = "padding" + self.end_str = "eos" + self.unknown = "unknown" + dict_character = dict_character + [self.end_str, self.padding, self.unknown] + return dict_character + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len: + return None + data["length"] = np.array(len(text)) + 1 # conclude eos + text = ( + text + + [len(self.character) - 3] + + [len(self.character) - 2] * (self.max_text_len - len(text) - 1) + ) + data["label"] = np.array(text) + return data + + +class SRNLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, + max_text_length=25, + character_dict_path=None, + use_space_char=False, + **kwargs + ): + super(SRNLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + char_num = len(self.character) + if text is None: + return None + if len(text) > self.max_text_len: + return None + data["length"] = np.array(len(text)) + text = text + [char_num - 1] * (self.max_text_len - len(text)) + data["label"] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" % beg_or_end + return idx + + +class TableLabelEncode(object): + """Convert between text-label and text-index""" + + def __init__( + self, + max_text_length, + max_elem_length, + max_cell_num, + character_dict_path, + span_weight=1.0, + **kwargs + ): + self.max_text_length = max_text_length + self.max_elem_length = max_elem_length + self.max_cell_num = max_cell_num + list_character, list_elem = self.load_char_elem_dict(character_dict_path) + list_character = self.add_special_char(list_character) + list_elem = self.add_special_char(list_elem) + self.dict_character = {} + for i, char in enumerate(list_character): + self.dict_character[char] = i + self.dict_elem = {} + for i, elem in enumerate(list_elem): + self.dict_elem[elem] = i + self.span_weight = span_weight + + def load_char_elem_dict(self, character_dict_path): + list_character = [] + list_elem = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + substr = lines[0].decode("utf-8").strip("\r\n").split("\t") + character_num = int(substr[0]) + elem_num = int(substr[1]) + for cno in range(1, 1 + character_num): + character = lines[cno].decode("utf-8").strip("\r\n") + list_character.append(character) + for eno in range(1 + character_num, 1 + character_num + elem_num): + elem = lines[eno].decode("utf-8").strip("\r\n") + list_elem.append(elem) + return list_character, list_elem + + def add_special_char(self, list_character): + self.beg_str = "sos" + self.end_str = "eos" + list_character = [self.beg_str] + list_character + [self.end_str] + return list_character + + def get_span_idx_list(self): + span_idx_list = [] + for elem in self.dict_elem: + if "span" in elem: + span_idx_list.append(self.dict_elem[elem]) + return span_idx_list + + def __call__(self, data): + cells = data["cells"] + structure = data["structure"]["tokens"] + structure = self.encode(structure, "elem") + if structure is None: + return None + elem_num = len(structure) + structure = [0] + structure + [len(self.dict_elem) - 1] + structure = structure + [0] * (self.max_elem_length + 2 - len(structure)) + structure = np.array(structure) + data["structure"] = structure + elem_char_idx1 = self.dict_elem[""] + elem_char_idx2 = self.dict_elem[" 0: + span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) + span_weight = min(max(span_weight, 1.0), self.span_weight) + for cno in range(len(cells)): + if "bbox" in cells[cno]: + bbox = cells[cno]["bbox"].copy() + bbox[0] = bbox[0] * 1.0 / img_width + bbox[1] = bbox[1] * 1.0 / img_height + bbox[2] = bbox[2] * 1.0 / img_width + bbox[3] = bbox[3] * 1.0 / img_height + td_idx = td_idx_list[cno] + bbox_list[td_idx] = bbox + bbox_list_mask[td_idx] = 1.0 + cand_span_idx = td_idx + 1 + if cand_span_idx < (self.max_elem_length + 2): + if structure[cand_span_idx] in span_idx_list: + structure_mask[cand_span_idx] = span_weight + + data["bbox_list"] = bbox_list + data["bbox_list_mask"] = bbox_list_mask + data["structure_mask"] = structure_mask + char_beg_idx = self.get_beg_end_flag_idx("beg", "char") + char_end_idx = self.get_beg_end_flag_idx("end", "char") + elem_beg_idx = self.get_beg_end_flag_idx("beg", "elem") + elem_end_idx = self.get_beg_end_flag_idx("end", "elem") + data["sp_tokens"] = np.array( + [ + char_beg_idx, + char_end_idx, + elem_beg_idx, + elem_end_idx, + elem_char_idx1, + elem_char_idx2, + self.max_text_length, + self.max_elem_length, + self.max_cell_num, + elem_num, + ] + ) + return data + + def encode(self, text, char_or_elem): + """convert text-label into text-index.""" + if char_or_elem == "char": + max_len = self.max_text_length + current_dict = self.dict_character + else: + max_len = self.max_elem_length + current_dict = self.dict_elem + if len(text) > max_len: + return None + if len(text) == 0: + if char_or_elem == "char": + return [self.dict_character["space"]] + else: + return None + text_list = [] + for char in text: + if char not in current_dict: + return None + text_list.append(current_dict[char]) + if len(text_list) == 0: + if char_or_elem == "char": + return [self.dict_character["space"]] + else: + return None + return text_list + + def get_ignored_tokens(self, char_or_elem): + beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) + end_idx = self.get_beg_end_flag_idx("end", char_or_elem) + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): + if char_or_elem == "char": + if beg_or_end == "beg": + idx = np.array(self.dict_character[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_character[self.end_str]) + else: + assert False, ( + "Unsupport type %s in get_beg_end_flag_idx of char" % beg_or_end + ) + elif char_or_elem == "elem": + if beg_or_end == "beg": + idx = np.array(self.dict_elem[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict_elem[self.end_str]) + else: + assert False, ( + "Unsupport type %s in get_beg_end_flag_idx of elem" % beg_or_end + ) + else: + assert False, "Unsupport type %s in char_or_elem" % char_or_elem + return idx + + +class SARLabelEncode(BaseRecLabelEncode): + """Convert between text-label and text-index""" + + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(SARLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def add_special_char(self, dict_character): + beg_end_str = "" + unknown_str = "" + padding_str = "" + dict_character = dict_character + [unknown_str] + self.unknown_idx = len(dict_character) - 1 + dict_character = dict_character + [beg_end_str] + self.start_idx = len(dict_character) - 1 + self.end_idx = len(dict_character) - 1 + dict_character = dict_character + [padding_str] + self.padding_idx = len(dict_character) - 1 + + return dict_character + + def __call__(self, data): + text = data["label"] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 1: + return None + data["length"] = np.array(len(text)) + target = [self.start_idx] + text + [self.end_idx] + padded_text = [self.padding_idx for _ in range(self.max_text_len)] + + padded_text[: len(target)] = target + data["label"] = np.array(padded_text) + return data + + def get_ignored_tokens(self): + return [self.padding_idx] + + +class PRENLabelEncode(BaseRecLabelEncode): + def __init__( + self, max_text_length, character_dict_path, use_space_char=False, **kwargs + ): + super(PRENLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + def add_special_char(self, dict_character): + padding_str = "" # 0 + end_str = "" # 1 + unknown_str = "" # 2 + + dict_character = [padding_str, end_str, unknown_str] + dict_character + self.padding_idx = 0 + self.end_idx = 1 + self.unknown_idx = 2 + + return dict_character + + def encode(self, text): + if len(text) == 0 or len(text) >= self.max_text_len: + return None + if self.lower: + text = text.lower() + text_list = [] + for char in text: + if char not in self.dict: + text_list.append(self.unknown_idx) + else: + text_list.append(self.dict[char]) + text_list.append(self.end_idx) + if len(text_list) < self.max_text_len: + text_list += [self.padding_idx] * (self.max_text_len - len(text_list)) + return text_list + + def __call__(self, data): + text = data["label"] + encoded_text = self.encode(text) + if encoded_text is None: + return None + data["label"] = np.array(encoded_text) + return data + + +class VQATokenLabelEncode(object): + """ + Label encode for NLP VQA methods + """ + + def __init__( + self, + class_path, + contains_re=False, + add_special_ids=False, + algorithm="LayoutXLM", + infer_mode=False, + ocr_engine=None, + **kwargs + ): + super(VQATokenLabelEncode, self).__init__() + from paddlenlp.transformers import ( + LayoutLMTokenizer, + LayoutLMv2Tokenizer, + LayoutXLMTokenizer, + ) + + from ppocr.utils.utility import load_vqa_bio_label_maps + + tokenizer_dict = { + "LayoutXLM": { + "class": LayoutXLMTokenizer, + "pretrained_model": "layoutxlm-base-uncased", + }, + "LayoutLM": { + "class": LayoutLMTokenizer, + "pretrained_model": "layoutlm-base-uncased", + }, + "LayoutLMv2": { + "class": LayoutLMv2Tokenizer, + "pretrained_model": "layoutlmv2-base-uncased", + }, + } + self.contains_re = contains_re + tokenizer_config = tokenizer_dict[algorithm] + self.tokenizer = tokenizer_config["class"].from_pretrained( + tokenizer_config["pretrained_model"] + ) + self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path) + self.add_special_ids = add_special_ids + self.infer_mode = infer_mode + self.ocr_engine = ocr_engine + + def __call__(self, data): + # load bbox and label info + ocr_info = self._load_ocr_info(data) + + height, width, _ = data["image"].shape + + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + segment_offset_id = [] + gt_label_list = [] + + entities = [] + + # for re + train_re = self.contains_re and not self.infer_mode + if train_re: + relations = [] + id2label = {} + entity_id_to_index_map = {} + empty_entity = set() + + data["ocr_info"] = copy.deepcopy(ocr_info) + + for info in ocr_info: + if train_re: + # for re + if len(info["text"]) == 0: + empty_entity.add(info["id"]) + continue + id2label[info["id"]] = info["label"] + relations.extend([tuple(sorted(l)) for l in info["linking"]]) + # smooth_box + bbox = self._smooth_box(info["bbox"], height, width) + + text = info["text"] + encode_res = self.tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True + ) + + if not self.add_special_ids: + # TODO: use tok.all_special_ids to remove + encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] + encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] + # parse label + if not self.infer_mode: + label = info["label"] + gt_label = self._parse_label(label, encode_res) + + # construct entities for re + if train_re: + if gt_label[0] != self.label2id_map["O"]: + entity_id_to_index_map[info["id"]] = len(entities) + label = label.upper() + entities.append( + { + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": label.upper(), + } + ) + else: + entities.append( + { + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": "O", + } + ) + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + words_list.append(text) + segment_offset_id.append(len(input_ids_list)) + if not self.infer_mode: + gt_label_list.extend(gt_label) + + data["input_ids"] = input_ids_list + data["token_type_ids"] = token_type_ids_list + data["bbox"] = bbox_list + data["attention_mask"] = [1] * len(input_ids_list) + data["labels"] = gt_label_list + data["segment_offset_id"] = segment_offset_id + data["tokenizer_params"] = dict( + padding_side=self.tokenizer.padding_side, + pad_token_type_id=self.tokenizer.pad_token_type_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + data["entities"] = entities + + if train_re: + data["relations"] = relations + data["id2label"] = id2label + data["empty_entity"] = empty_entity + data["entity_id_to_index_map"] = entity_id_to_index_map + return data + + def _load_ocr_info(self, data): + def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + + if self.infer_mode: + ocr_result = self.ocr_engine.ocr(data["image"], cls=False) + ocr_info = [] + for res in ocr_result: + ocr_info.append( + { + "text": res[1][0], + "bbox": trans_poly_to_bbox(res[0]), + "poly": res[0], + } + ) + return ocr_info + else: + info = data["label"] + # read text info + info_dict = json.loads(info) + return info_dict["ocr_info"] + + def _smooth_box(self, bbox, height, width): + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + return bbox + + def _parse_label(self, label, encode_res): + gt_label = [] + if label.lower() == "other": + gt_label.extend([0] * len(encode_res["input_ids"])) + else: + gt_label.append(self.label2id_map[("b-" + label).upper()]) + gt_label.extend( + [self.label2id_map[("i-" + label).upper()]] + * (len(encode_res["input_ids"]) - 1) + ) + return gt_label + + +class MultiLabelEncode(BaseRecLabelEncode): + def __init__( + self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs + ): + super(MultiLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char + ) + + self.ctc_encode = CTCLabelEncode( + max_text_length, character_dict_path, use_space_char, **kwargs + ) + self.sar_encode = SARLabelEncode( + max_text_length, character_dict_path, use_space_char, **kwargs + ) + + def __call__(self, data): + + data_ctc = copy.deepcopy(data) + data_sar = copy.deepcopy(data) + data_out = dict() + data_out["img_path"] = data.get("img_path", None) + data_out["image"] = data["image"] + ctc = self.ctc_encode.__call__(data_ctc) + sar = self.sar_encode.__call__(data_sar) + if ctc is None or sar is None: + return None + data_out["label_ctc"] = ctc["label"] + data_out["label_sar"] = sar["label"] + data_out["length"] = ctc["length"] + return data_out diff --git a/ocr/ppocr/data/imaug/make_border_map.py b/ocr/ppocr/data/imaug/make_border_map.py new file mode 100644 index 0000000000000000000000000000000000000000..175462db6bb56ef8d68b3a2eff78f0476ff59d25 --- /dev/null +++ b/ocr/ppocr/data/imaug/make_border_map.py @@ -0,0 +1,155 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import cv2 +import numpy as np + +np.seterr(divide="ignore", invalid="ignore") +import warnings + +import pyclipper +from shapely.geometry import Polygon + +warnings.simplefilter("ignore") + +__all__ = ["MakeBorderMap"] + + +class MakeBorderMap(object): + def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs): + self.shrink_ratio = shrink_ratio + self.thresh_min = thresh_min + self.thresh_max = thresh_max + + def __call__(self, data): + + img = data["image"] + text_polys = data["polys"] + ignore_tags = data["ignore_tags"] + + canvas = np.zeros(img.shape[:2], dtype=np.float32) + mask = np.zeros(img.shape[:2], dtype=np.float32) + + for i in range(len(text_polys)): + if ignore_tags[i]: + continue + self.draw_border_map(text_polys[i], canvas, mask=mask) + canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min + + data["threshold_map"] = canvas + data["threshold_mask"] = mask + return data + + def draw_border_map(self, polygon, canvas, mask): + polygon = np.array(polygon) + assert polygon.ndim == 2 + assert polygon.shape[1] == 2 + + polygon_shape = Polygon(polygon) + if polygon_shape.area <= 0: + return + distance = ( + polygon_shape.area + * (1 - np.power(self.shrink_ratio, 2)) + / polygon_shape.length + ) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + + padded_polygon = np.array(padding.Execute(distance)[0]) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), (height, width) + ) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width) + ) + + distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self._distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax( + 1 + - distance_map[ + ymin_valid - ymin : ymax_valid - ymax + height, + xmin_valid - xmin : xmax_valid - xmax + width, + ], + canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1], + ) + + def _distance(self, xs, ys, point_1, point_2): + """ + compute the distance from point to a line + ys: coordinates in the first axis + xs: coordinates in the second axis + point_1, point_2: (x, y), the end of the line + """ + height, width = xs.shape[:2] + square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1]) + square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1]) + square_distance = np.square(point_1[0] - point_2[0]) + np.square( + point_1[1] - point_2[1] + ) + + cosin = (square_distance - square_distance_1 - square_distance_2) / ( + 2 * np.sqrt(square_distance_1 * square_distance_2) + ) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + result = np.sqrt( + square_distance_1 * square_distance_2 * square_sin / square_distance + ) + + result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[ + cosin < 0 + ] + # self.extend_line(point_1, point_2, result) + return result + + def extend_line(self, point_1, point_2, result, shrink_ratio): + ex_point_1 = ( + int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))), + int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))), + ) + cv2.line( + result, + tuple(ex_point_1), + tuple(point_1), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0, + ) + ex_point_2 = ( + int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))), + int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))), + ) + cv2.line( + result, + tuple(ex_point_2), + tuple(point_2), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0, + ) + return ex_point_1, ex_point_2 diff --git a/ocr/ppocr/data/imaug/make_pse_gt.py b/ocr/ppocr/data/imaug/make_pse_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..7a873852dd51e9f60f3bbf4f277e737ad6f054d4 --- /dev/null +++ b/ocr/ppocr/data/imaug/make_pse_gt.py @@ -0,0 +1,88 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +__all__ = ["MakePseGt"] + + +class MakePseGt(object): + def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs): + self.kernel_num = kernel_num + self.min_shrink_ratio = min_shrink_ratio + self.size = size + + def __call__(self, data): + + image = data["image"] + text_polys = data["polys"] + ignore_tags = data["ignore_tags"] + + h, w, _ = image.shape + short_edge = min(h, w) + if short_edge < self.size: + # keep short_size >= self.size + scale = self.size / short_edge + image = cv2.resize(image, dsize=None, fx=scale, fy=scale) + text_polys *= scale + + gt_kernels = [] + for i in range(1, self.kernel_num + 1): + # s1->sn, from big to small + rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i + text_kernel, ignore_tags = self.generate_kernel( + image.shape[0:2], rate, text_polys, ignore_tags + ) + gt_kernels.append(text_kernel) + + training_mask = np.ones(image.shape[0:2], dtype="uint8") + for i in range(text_polys.shape[0]): + if ignore_tags[i]: + cv2.fillPoly( + training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0 + ) + + gt_kernels = np.array(gt_kernels) + gt_kernels[gt_kernels > 0] = 1 + + data["image"] = image + data["polys"] = text_polys + data["gt_kernels"] = gt_kernels[0:] + data["gt_text"] = gt_kernels[0] + data["mask"] = training_mask.astype("float32") + return data + + def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None): + """ + Refer to part of the code: + https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py + """ + + h, w = img_size + text_kernel = np.zeros((h, w), dtype=np.float32) + for i, poly in enumerate(text_polys): + polygon = Polygon(poly) + distance = ( + polygon.area + * (1 - shrink_ratio * shrink_ratio) + / (polygon.length + 1e-6) + ) + subject = [tuple(l) for l in poly] + pco = pyclipper.PyclipperOffset() + pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = np.array(pco.Execute(-distance)) + + if len(shrinked) == 0 or shrinked.size == 0: + if ignore_tags is not None: + ignore_tags[i] = True + continue + try: + shrinked = np.array(shrinked[0]).reshape(-1, 2) + except: + if ignore_tags is not None: + ignore_tags[i] = True + continue + cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1) + return text_kernel, ignore_tags diff --git a/ocr/ppocr/data/imaug/make_shrink_map.py b/ocr/ppocr/data/imaug/make_shrink_map.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8992515edb47398428b63dbb85aa6cecea4f5a --- /dev/null +++ b/ocr/ppocr/data/imaug/make_shrink_map.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +__all__ = ["MakeShrinkMap"] + + +class MakeShrinkMap(object): + r""" + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + """ + + def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs): + self.min_text_size = min_text_size + self.shrink_ratio = shrink_ratio + + def __call__(self, data): + image = data["image"] + text_polys = data["polys"] + ignore_tags = data["ignore_tags"] + + h, w = image.shape[:2] + text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w) + gt = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + for i in range(len(text_polys)): + polygon = text_polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + polygon_shape = Polygon(polygon) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = [] + + # Increase the shrink ratio every time we get multiple polygon returned back + possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio) + np.append(possible_ratios, 1) + # print(possible_ratios) + for ratio in possible_ratios: + # print(f"Change shrink ratio to {ratio}") + distance = ( + polygon_shape.area + * (1 - np.power(ratio, 2)) + / polygon_shape.length + ) + shrinked = padding.Execute(-distance) + if len(shrinked) == 1: + break + + if shrinked == []: + cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + + for each_shirnk in shrinked: + shirnk = np.array(each_shirnk).reshape(-1, 2) + cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1) + + data["shrink_map"] = gt + data["shrink_mask"] = mask + return data + + def validate_polygons(self, polygons, ignore_tags, h, w): + """ + polygons (numpy.array, required): of shape (num_instances, num_points, 2) + """ + if len(polygons) == 0: + return polygons, ignore_tags + assert len(polygons) == len(ignore_tags) + for polygon in polygons: + polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) + + for i in range(len(polygons)): + area = self.polygon_area(polygons[i]) + if abs(area) < 1: + ignore_tags[i] = True + if area > 0: + polygons[i] = polygons[i][::-1, :] + return polygons, ignore_tags + + def polygon_area(self, polygon): + """ + compute polygon area + """ + area = 0 + q = polygon[-1] + for p in polygon: + area += p[0] * q[1] - p[1] * q[0] + q = p + return area / 2.0 diff --git a/ocr/ppocr/data/imaug/operators.py b/ocr/ppocr/data/imaug/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..e91a2b9d9dd01f3f600c3caacb3bbc032cd0048b --- /dev/null +++ b/ocr/ppocr/data/imaug/operators.py @@ -0,0 +1,458 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import math +import sys + +import cv2 +import numpy as np +import six + + +class DecodeImage(object): + """decode image""" + + def __init__( + self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs + ): + self.img_mode = img_mode + self.channel_first = channel_first + self.ignore_orientation = ignore_orientation + + def __call__(self, data): + img = data["image"] + if six.PY2: + assert ( + type(img) is str and len(img) > 0 + ), "invalid input 'img' in DecodeImage" + else: + assert ( + type(img) is bytes and len(img) > 0 + ), "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype="uint8") + if self.ignore_orientation: + img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR) + else: + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == "GRAY": + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == "RGB": + assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data["image"] = img + return data + + +class NRTRDecodeImage(object): + """decode image""" + + def __init__(self, img_mode="RGB", channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data["image"] + if six.PY2: + assert ( + type(img) is str and len(img) > 0 + ), "invalid input 'img' in DecodeImage" + else: + assert ( + type(img) is bytes and len(img) > 0 + ), "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype="uint8") + + img = cv2.imdecode(img, 1) + + if img is None: + return None + if self.img_mode == "GRAY": + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == "RGB": + assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape) + img = img[:, :, ::-1] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if self.channel_first: + img = img.transpose((2, 0, 1)) + data["image"] = img + return data + + +class NormalizeImage(object): + """normalize image such as substract mean, divide std""" + + def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") + + def __call__(self, data): + img = data["image"] + from PIL import Image + + if isinstance(img, Image.Image): + img = np.array(img) + assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" + data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std + return data + + +class ToCHWImage(object): + """convert hwc image to chw image""" + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data["image"] + from PIL import Image + + if isinstance(img, Image.Image): + img = np.array(img) + data["image"] = img.transpose((2, 0, 1)) + return data + + +class Fasttext(object): + def __init__(self, path="None", **kwargs): + import fasttext + + self.fast_model = fasttext.load_model(path) + + def __call__(self, data): + label = data["label"] + fast_label = self.fast_model[label] + data["fast_label"] = fast_label + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class Pad(object): + def __init__(self, size=None, size_div=32, **kwargs): + if size is not None and not isinstance(size, (int, list, tuple)): + raise TypeError( + "Type of target_size is invalid. Now is {}".format(type(size)) + ) + if isinstance(size, int): + size = [size, size] + self.size = size + self.size_div = size_div + + def __call__(self, data): + + img = data["image"] + img_h, img_w = img.shape[0], img.shape[1] + if self.size: + resize_h2, resize_w2 = self.size + assert ( + img_h < resize_h2 and img_w < resize_w2 + ), "(h, w) of target size should be greater than (img_h, img_w)" + else: + resize_h2 = max( + int(math.ceil(img.shape[0] / self.size_div) * self.size_div), + self.size_div, + ) + resize_w2 = max( + int(math.ceil(img.shape[1] / self.size_div) * self.size_div), + self.size_div, + ) + img = cv2.copyMakeBorder( + img, + 0, + resize_h2 - img_h, + 0, + resize_w2 - img_w, + cv2.BORDER_CONSTANT, + value=0, + ) + data["image"] = img + return data + + +class Resize(object): + def __init__(self, size=(640, 640), **kwargs): + self.size = size + + def resize_image(self, img): + resize_h, resize_w = self.size + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + return img, [ratio_h, ratio_w] + + def __call__(self, data): + img = data["image"] + if "polys" in data: + text_polys = data["polys"] + + img_resize, [ratio_h, ratio_w] = self.resize_image(img) + if "polys" in data: + new_boxes = [] + for box in text_polys: + new_box = [] + for cord in box: + new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) + new_boxes.append(new_box) + data["polys"] = np.array(new_boxes, dtype=np.float32) + data["image"] = img_resize + return data + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + if "image_shape" in kwargs: + self.image_shape = kwargs["image_shape"] + self.resize_type = 1 + elif "limit_side_len" in kwargs: + self.limit_side_len = kwargs["limit_side_len"] + self.limit_type = kwargs.get("limit_type", "min") + elif "resize_long" in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get("resize_long", 960) + else: + self.limit_side_len = 736 + self.limit_type = "min" + + def __call__(self, data): + img = data["image"] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + data["image"] = img + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == "max": + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1.0 + elif self.limit_type == "min": + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1.0 + elif self.limit_type == "resize_long": + ratio = float(limit_side_len) / max(h, w) + else: + raise Exception("not support limit type, image ") + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] + + +class E2EResizeForTest(object): + def __init__(self, **kwargs): + super(E2EResizeForTest, self).__init__() + self.max_side_len = kwargs["max_side_len"] + self.valid_set = kwargs["valid_set"] + + def __call__(self, data): + img = data["image"] + src_h, src_w, _ = img.shape + if self.valid_set == "totaltext": + im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext( + img, max_side_len=self.max_side_len + ) + else: + im_resized, (ratio_h, ratio_w) = self.resize_image( + img, max_side_len=self.max_side_len + ) + data["image"] = im_resized + data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_for_totaltext(self, im, max_side_len=512): + + h, w, _ = im.shape + resize_w = w + resize_h = h + ratio = 1.25 + if h * ratio > max_side_len: + ratio = float(max_side_len) / resize_h + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + return im, (ratio_h, ratio_w) + + def resize_image(self, im, max_side_len=512): + """ + resize image to a size multiple of max_stride which is required by the network + :param im: the resized image + :param max_side_len: limit of max image size to avoid out of memory in gpu + :return: the resized image and the resize ratio + """ + h, w, _ = im.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(max_side_len) / resize_h + else: + ratio = float(max_side_len) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(im, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return im, (ratio_h, ratio_w) + + +class KieResize(object): + def __init__(self, **kwargs): + super(KieResize, self).__init__() + self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1] + + def __call__(self, data): + img = data["image"] + points = data["points"] + src_h, src_w, _ = img.shape + ( + im_resized, + scale_factor, + [ratio_h, ratio_w], + [new_h, new_w], + ) = self.resize_image(img) + resize_points = self.resize_boxes(img, points, scale_factor) + data["ori_image"] = img + data["ori_boxes"] = points + data["points"] = resize_points + data["image"] = im_resized + data["shape"] = np.array([new_h, new_w]) + return data + + def resize_image(self, img): + norm_img = np.zeros([1024, 1024, 3], dtype="float32") + scale = [512, 1024] + h, w = img.shape[:2] + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) + resize_w, resize_h = int(w * float(scale_factor) + 0.5), int( + h * float(scale_factor) + 0.5 + ) + max_stride = 32 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + im = cv2.resize(img, (resize_w, resize_h)) + new_h, new_w = im.shape[:2] + w_scale = new_w / w + h_scale = new_h / h + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + norm_img[:new_h, :new_w, :] = im + return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w] + + def resize_boxes(self, im, points, scale_factor): + points = points * scale_factor + img_shape = im.shape[:2] + points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1]) + points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0]) + return points diff --git a/ocr/ppocr/data/imaug/pg_process.py b/ocr/ppocr/data/imaug/pg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..63721f1a0c10128607edd773679d2f4bf1dad304 --- /dev/null +++ b/ocr/ppocr/data/imaug/pg_process.py @@ -0,0 +1,961 @@ +import math + +import cv2 +import numpy as np + +__all__ = ["PGProcessTrain"] + + +class PGProcessTrain(object): + def __init__( + self, + character_dict_path, + max_text_length, + max_text_nums, + tcl_len, + batch_size=14, + min_crop_size=24, + min_text_size=4, + max_text_size=512, + **kwargs + ): + self.tcl_len = tcl_len + self.max_text_length = max_text_length + self.max_text_nums = max_text_nums + self.batch_size = batch_size + self.min_crop_size = min_crop_size + self.min_text_size = min_text_size + self.max_text_size = max_text_size + self.Lexicon_Table = self.get_dict(character_dict_path) + self.pad_num = len(self.Lexicon_Table) + self.img_id = 0 + + def get_dict(self, character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + def quad_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [ + (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]), + ] + return np.sum(edge) / 2.0 + + def gen_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + rect = cv2.minAreaRect( + poly.astype(np.int32) + ) # (center (x,y), (width, height), angle of rotation) + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = ( + np.linalg.norm(box[(i + 0) % 4] - poly[0]) + + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + ) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad + + def check_and_validate_polys(self, polys, tags, im_size): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + (h, w) = im_size + if polys.shape[0] == 0: + return polys, np.array([]), np.array([]) + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + hv_tags = [] + for poly, tag in zip(polys, tags): + quad = self.gen_quad_from_poly(poly) + p_area = self.quad_area(quad) + if abs(p_area) < 1: + print("invalid poly") + continue + if p_area > 0: + if tag == False: + print("poly in wrong direction") + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :] + quad = quad[(0, 3, 2, 1), :] + + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm( + quad[3] - quad[2] + ) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm( + quad[1] - quad[2] + ) + hv_tag = 1 + + if len_w * 2.0 < len_h: + hv_tag = 0 + + validated_polys.append(poly) + validated_tags.append(tag) + hv_tags.append(hv_tag) + return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags) + + def crop_area( + self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25 + ): + """ + make random crop from the input image + :param im: + :param polys: [b,4,2] + :param tags: + :param crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w : maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h : maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags, txts + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size: + continue + if polys.shape[0] != 0: + poly_axis_in_area = ( + (polys[:, :, 0] >= xmin) + & (polys[:, :, 0] <= xmax) + & (polys[:, :, 1] >= ymin) + & (polys[:, :, 1] <= ymax) + ) + selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + return ( + im[ymin : ymax + 1, xmin : xmax + 1, :], + polys[selected_polys], + tags[selected_polys], + hv_tags[selected_polys], + txts, + ) + else: + continue + im = im[ymin : ymax + 1, xmin : xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + txts_tmp = [] + for selected_poly in selected_polys: + txts_tmp.append(txts[selected_poly]) + txts = txts_tmp + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags, txts + + return im, polys, tags, hv_tags, txts + + def fit_and_gather_tcl_points_v2( + self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3, + ): + """ + Find the center point of poly as key_points, then fit and gather. + """ + key_point_xys = [] + point_num = poly.shape[0] + for idx in range(point_num // 2): + center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0 + key_point_xys.append(center_point) + + tmp_image = np.zeros( + shape=( + max_h, + max_w, + ), + dtype="float32", + ) + cv2.polylines(tmp_image, [np.array(key_point_xys).astype("int32")], False, 1.0) + ys, xs = np.where(tmp_image > 0) + xy_text = np.array(list(zip(xs, ys)), dtype="float32") + + left_center_pt = ((min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) + right_center_pt = ((min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2) + proj_unit_vec = (right_center_pt - left_center_pt) / ( + np.linalg.norm(right_center_pt - left_center_pt) + 1e-6 + ) + proj_unit_vec_tile = np.tile(proj_unit_vec, (xy_text.shape[0], 1)) # (n, 2) + left_center_pt_tile = np.tile(left_center_pt, (xy_text.shape[0], 1)) # (n, 2) + xy_text_to_left_center = xy_text - left_center_pt_tile + proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1) + xy_text = xy_text[np.argsort(proj_value)] + + # convert to np and keep the num of point not greater then fixed_point_num + pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + if np.random.rand() < 0.2 and reference_height >= 3: + dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3 + random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape([keep, 1]) + pos_info += random_float + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + + def generate_direction_map(self, poly_quads, n_char, direction_map): + """ """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = ( + np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]) + ) / 2.0 + quad_h = ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1]) + ) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / n_char, 1.0) + average_height = max(sum(height_list) / len(height_list), 1.0) + k = 1 + for quad in poly_quads: + direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = ( + direct_vector_full + / (np.linalg.norm(direct_vector_full) + 1e-6) + * norm_width + ) + direction_label = tuple( + map(float, [direct_vector[0], direct_vector[1], 1.0 / average_height]) + ) + cv2.fillPoly( + direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label, + ) + k += 1 + return direction_map + + def calculate_average_height(self, poly_quads): + """ """ + height_list = [] + for quad in poly_quads: + quad_h = ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1]) + ) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def generate_tcl_ctc_label( + self, + h, + w, + polys, + tags, + text_strs, + ds_ratio, + tcl_ratio=0.3, + shrink_ratio_of_width=0.15, + ): + """ + Generate polygon. + """ + score_map_big = np.zeros( + ( + h, + w, + ), + dtype=np.float32, + ) + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros( + ( + h, + w, + ), + dtype=np.float32, + ) + score_label_map = np.zeros( + ( + h, + w, + ), + dtype=np.float32, + ) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones( + ( + h, + w, + ), + dtype=np.float32, + ) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3] + ).astype(np.float32) + + label_idx = 0 + score_label_map_text_label_list = [] + pos_list, pos_mask, label_list = [], [], [] + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2]) + ) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3]) + ) + + if ( + min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio + ): + continue + + if tag: + cv2.fillPoly( + training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15 + ) + else: + text_label = text_strs[poly_idx] + text_label = self.prepare_text_label(text_label, self.Lexicon_Table) + + text_label_index_list = [ + [self.Lexicon_Table.index(c_)] + for c_ in text_label + if c_ in self.Lexicon_Table + ] + if len(text_label_index_list) < 1: + continue + + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio, + ) + + cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) + cv2.fillPoly( + score_map_big, np.round(stcl_quads / ds_ratio).astype(np.int32), 1.0 + ) + + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), + 1.0, + ) + tbo_map = self.gen_quad_tbo( + poly_quads[quad_index[idx]], quad_mask, tbo_map + ) + + # score label map and score_label_map_text_label_list for refine + if label_idx == 0: + text_pos_list_ = [ + [len(self.Lexicon_Table)], + ] + score_label_map_text_label_list.append(text_pos_list_) + + label_idx += 1 + cv2.fillPoly( + score_label_map, np.round(poly_quads).astype(np.int32), label_idx + ) + score_label_map_text_label_list.append(text_label_index_list) + + # direction info, fix-me + n_char = len(text_label_index_list) + direction_map = self.generate_direction_map( + poly_quads, n_char, direction_map + ) + + # pos info + average_shrink_height = self.calculate_average_height(stcl_quads) + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height, + ) + + label_l = text_label_index_list + if len(text_label_index_list) < 2: + continue + + pos_list.append(pos_l) + pos_mask.append(pos_m) + label_list.append(label_l) + + # use big score_map for smooth tcl lines + score_map_big_resized = cv2.resize( + score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio + ) + score_map = np.array(score_map_big_resized > 1e-3, dtype="float32") + + return ( + score_map, + score_label_map, + tbo_map, + direction_map, + training_mask, + pos_list, + pos_mask, + label_list, + score_label_map_text_label_list, + ) + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6 + ) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect( + poly.astype(np.int32) + ) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = ( + np.linalg.norm(box[(i + 0) % 4] - poly[0]) + + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + ) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32 + ) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width( + self, quads, shrink_ratio_of_width, expand_height_ratio=1.0 + ): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. + left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio + + shrink_length = ( + min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width + ) + # shrinking length + upper_len_left = shrink_length + upper_len_right = sum(upper_edge_list) - shrink_length + + left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) + left_quad = self.shrink_quad_along_width( + quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1 + ) + right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) + right_quad = self.shrink_quad_along_width( + quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio + ) + + out_quad_list = [] + if left_idx == right_idx: + out_quad_list.append( + [left_quad[0], right_quad[1], right_quad[2], left_quad[3]] + ) + else: + out_quad_list.append(left_quad) + for idx in range(left_idx + 1, right_idx): + out_quad_list.append(quads[idx]) + out_quad_list.append(right_quad) + + return np.array(out_quad_list), list(range(left_idx, right_idx + 1)) + + def prepare_text_label(self, label_str, Lexicon_Table): + """ + Prepare text lablel by given Lexicon_Table. + """ + if len(Lexicon_Table) == 36: + return label_str.lower() + else: + return label_str + + def vector_angle(self, A, B): + """ + Calculate the angle between vector AB and x-axis positive direction. + """ + AB = np.array([B[1] - A[1], B[0] - A[0]]) + return np.arctan2(*AB) + + def theta_line_cross_point(self, theta, point): + """ + Calculate the line through given point and angle in ax + by + c =0 form. + """ + x, y = point + cos = np.cos(theta) + sin = np.sin(theta) + return [sin, -cos, cos * y - sin * x] + + def line_cross_two_point(self, A, B): + """ + Calculate the line through given point A and B in ax + by + c =0 form. + """ + angle = self.vector_angle(A, B) + return self.theta_line_cross_point(angle, A) + + def average_angle(self, poly): + """ + Calculate the average angle between left and right edge in given poly. + """ + p0, p1, p2, p3 = poly + angle30 = self.vector_angle(p3, p0) + angle21 = self.vector_angle(p2, p1) + return (angle30 + angle21) / 2 + + def line_cross_point(self, line1, line2): + """ + line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2 + """ + a1, b1, c1 = line1 + a2, b2, c2 = line2 + d = a1 * b2 - a2 * b1 + + if d == 0: + print("Cross point does not exist") + return np.array([0, 0], dtype=np.float32) + else: + x = (b1 * c2 - b2 * c1) / d + y = (a2 * c1 - a1 * c2) / d + + return np.array([x, y], dtype=np.float32) + + def quad2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. (4, 2) + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair + p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair + return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) + + def poly2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + tcl_poly = np.zeros_like(poly) + point_num = poly.shape[0] + + for idx in range(point_num // 2): + point_pair = ( + poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair + ) + tcl_poly[idx] = point_pair[0] + tcl_poly[point_num - 1 - idx] = point_pair[1] + return tcl_poly + + def gen_quad_tbo(self, quad, tcl_mask, tbo_map): + """ + Generate tbo_map for give quad. + """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) + ) + quad_w = 0.5 * ( + np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]) + ) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. + """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append( + (np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]] + ) + + return np.array(quad_list) + + def rotate_im_poly(self, im, text_polys): + """ + rotate image with 90 / 180 / 270 degre + """ + im_w, im_h = im.shape[1], im.shape[0] + dst_im = im.copy() + dst_polys = [] + rand_degree_ratio = np.random.rand() + rand_degree_cnt = 1 + if rand_degree_ratio > 0.5: + rand_degree_cnt = 3 + for i in range(rand_degree_cnt): + dst_im = np.rot90(dst_im) + rot_degree = -90 * rand_degree_cnt + rot_angle = rot_degree * math.pi / 180.0 + n_poly = text_polys.shape[0] + cx, cy = 0.5 * im_w, 0.5 * im_h + ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0] + for i in range(n_poly): + wordBB = text_polys[i] + poly = [] + for j in range(4): # 16->4 + sx, sy = wordBB[j][0], wordBB[j][1] + dx = ( + math.cos(rot_angle) * (sx - cx) + - math.sin(rot_angle) * (sy - cy) + + ncx + ) + dy = ( + math.sin(rot_angle) * (sx - cx) + + math.cos(rot_angle) * (sy - cy) + + ncy + ) + poly.append([dx, dy]) + dst_polys.append(poly) + return dst_im, np.array(dst_polys, dtype=np.float32) + + def __call__(self, data): + input_size = 512 + im = data["image"] + text_polys = data["polys"] + text_tags = data["ignore_tags"] + text_strs = data["texts"] + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w) + ) + if text_polys.shape[0] <= 0: + return None + # set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, text_polys, text_tags, hv_tags, text_strs, crop_background=False + ) + + if text_polys.shape[0] == 0: + return None + # # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + # resize image + std_ratio = float(input_size) / max(new_w, new_h) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + # add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + # add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + # add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] + new_h, new_w, _ = im.shape + if min(new_w, new_h) < input_size * 0.5: + return None + im_padded = np.ones((input_size, input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = input_size - new_h + del_w = input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + ( + score_map, + score_label_map, + border_map, + direction_map, + training_mask, + pos_list, + pos_mask, + label_list, + score_label_map_text_label, + ) = self.generate_tcl_ctc_label( + input_size, input_size, text_polys, text_tags, text_strs, 0.25 + ) + if len(label_list) <= 0: # eliminate negative samples + return None + pos_list_temp = np.zeros([64, 3]) + pos_mask_temp = np.zeros([64, 1]) + label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num + + for i, label in enumerate(label_list): + n = len(label) + if n > self.max_text_length: + label_list[i] = label[: self.max_text_length] + continue + while n < self.max_text_length: + label.append([self.pad_num]) + n += 1 + + for i in range(len(label_list)): + label_list[i] = np.array(label_list[i]) + + if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums: + return None + for __ in range(self.max_text_nums - len(pos_list), 0, -1): + pos_list.append(pos_list_temp) + pos_mask.append(pos_mask_temp) + label_list.append(label_list_temp) + + if self.img_id == self.batch_size - 1: + self.img_id = 0 + else: + self.img_id += 1 + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= 255.0 * 0.229 + im_padded[:, :, 1] /= 255.0 * 0.224 + im_padded[:, :, 0] /= 255.0 * 0.225 + im_padded = im_padded.transpose((2, 0, 1)) + images = im_padded[::-1, :, :] + tcl_maps = score_map[np.newaxis, :, :] + tcl_label_maps = score_label_map[np.newaxis, :, :] + border_maps = border_map.transpose((2, 0, 1)) + direction_maps = direction_map.transpose((2, 0, 1)) + training_masks = training_mask[np.newaxis, :, :] + pos_list = np.array(pos_list) + pos_mask = np.array(pos_mask) + label_list = np.array(label_list) + data["images"] = images + data["tcl_maps"] = tcl_maps + data["tcl_label_maps"] = tcl_label_maps + data["border_maps"] = border_maps + data["direction_maps"] = direction_maps + data["training_masks"] = training_masks + data["label_list"] = label_list + data["pos_list"] = pos_list + data["pos_mask"] = pos_mask + return data diff --git a/ocr/ppocr/data/imaug/randaugment.py b/ocr/ppocr/data/imaug/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..46f1df7e82f4f8b80cb62db308bcbac141266b63 --- /dev/null +++ b/ocr/ppocr/data/imaug/randaugment.py @@ -0,0 +1,125 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import random + +import numpy as np +import six +from PIL import Image, ImageEnhance, ImageOps + + +class RawRandAugment(object): + def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128), **kwargs): + self.num_layers = num_layers + self.magnitude = magnitude + self.max_level = 10 + + abso_level = self.magnitude / self.max_level + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 150.0 / 331 * abso_level, + "translateY": 150.0 / 331 * abso_level, + "rotate": 30 * abso_level, + "color": 0.9 * abso_level, + "posterize": int(4.0 * abso_level), + "solarize": 256.0 * abso_level, + "contrast": 0.9 * abso_level, + "sharpness": 0.9 * abso_level, + "brightness": 0.9 * abso_level, + "autocontrast": 0, + "equalize": 0, + "invert": 0, + } + + # from https://stackoverflow.com/questions/5252170/ + # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite( + rot, Image.new("RGBA", rot.size, (128,) * 4), rot + ).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0), + fillcolor=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), + fillcolor=fillcolor, + ), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1]) + ), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1]) + ), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1]) + ), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + def __call__(self, img): + avaiable_op_names = list(self.level_map.keys()) + for layer_num in range(self.num_layers): + op_name = np.random.choice(avaiable_op_names) + img = self.func[op_name](img, self.level_map[op_name]) + return img + + +class RandAugment(RawRandAugment): + """RandAugment wrapper to auto fit different img types""" + + def __init__(self, prob=0.5, *args, **kwargs): + self.prob = prob + if six.PY2: + super(RandAugment, self).__init__(*args, **kwargs) + else: + super().__init__(*args, **kwargs) + + def __call__(self, data): + if np.random.rand() > self.prob: + return data + img = data["image"] + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + + if six.PY2: + img = super(RandAugment, self).__call__(img) + else: + img = super().__call__(img) + + if isinstance(img, Image.Image): + img = np.asarray(img) + data["image"] = img + return data diff --git a/ocr/ppocr/data/imaug/random_crop_data.py b/ocr/ppocr/data/imaug/random_crop_data.py new file mode 100644 index 0000000000000000000000000000000000000000..fb0903947784978a4439c0e4e2a121ce87e85cf8 --- /dev/null +++ b/ocr/ppocr/data/imaug/random_crop_data.py @@ -0,0 +1,218 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import random + +import cv2 +import numpy as np + + +def is_poly_in_rect(poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + +def is_poly_outside_rect(poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + +def split_regions(axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + +def random_select(axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + +def region_wise_random_select(regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + +def crop_area(im, text_polys, min_crop_side_ratio, max_tries): + h, w, _ = im.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in text_polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = split_regions(h_axis) + w_regions = split_regions(w_axis) + + for i in range(max_tries): + if len(w_regions) > 1: + xmin, xmax = region_wise_random_select(w_regions, w) + else: + xmin, xmax = random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = region_wise_random_select(h_regions, h) + else: + ymin, ymax = random_select(h_axis, h) + + if ( + xmax - xmin < min_crop_side_ratio * w + or ymax - ymin < min_crop_side_ratio * h + ): + # area too small + continue + num_poly_in_rect = 0 + for poly in text_polys: + if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h + + +class EastRandomCropData(object): + def __init__( + self, + size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1, + keep_ratio=True, + **kwargs + ): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.keep_ratio = keep_ratio + + def __call__(self, data): + img = data["image"] + text_polys = data["polys"] + ignore_tags = data["ignore_tags"] + texts = data["texts"] + all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = crop_area( + img, all_care_polys, self.min_crop_side_ratio, self.max_tries + ) + # crop 图片 保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + if self.keep_ratio: + padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], (w, h) + ) + img = padimg + else: + img = cv2.resize( + img[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], + tuple(self.size), + ) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + for poly, text, tag in zip(text_polys, texts, ignore_tags): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + data["image"] = img + data["polys"] = np.array(text_polys_crop) + data["ignore_tags"] = ignore_tags_crop + data["texts"] = texts_crop + return data + + +class RandomCropImgMask(object): + def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs): + self.size = size + self.main_key = main_key + self.crop_keys = crop_keys + self.p = p + + def __call__(self, data): + image = data["image"] + + h, w = image.shape[0:2] + th, tw = self.size + if w == tw and h == th: + return data + + mask = data[self.main_key] + if np.max(mask) > 0 and random.random() > self.p: + # make sure to crop the text region + tl = np.min(np.where(mask > 0), axis=1) - (th, tw) + tl[tl < 0] = 0 + br = np.max(np.where(mask > 0), axis=1) - (th, tw) + br[br < 0] = 0 + + br[0] = min(br[0], h - th) + br[1] = min(br[1], w - tw) + + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + i = random.randint(0, h - th) if h - th > 0 else 0 + j = random.randint(0, w - tw) if w - tw > 0 else 0 + + # return i, j, th, tw + for k in data: + if k in self.crop_keys: + if len(data[k].shape) == 3: + if np.argmin(data[k].shape) == 0: + img = data[k][:, i : i + th, j : j + tw] + if img.shape[1] != img.shape[2]: + a = 1 + elif np.argmin(data[k].shape) == 2: + img = data[k][i : i + th, j : j + tw, :] + if img.shape[1] != img.shape[0]: + a = 1 + else: + img = data[k] + else: + img = data[k][i : i + th, j : j + tw] + if img.shape[0] != img.shape[1]: + a = 1 + data[k] = img + return data diff --git a/ocr/ppocr/data/imaug/rec_img_aug.py b/ocr/ppocr/data/imaug/rec_img_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..fea4f95c3b501a8cbcdebe2dc3d91c02766a8b62 --- /dev/null +++ b/ocr/ppocr/data/imaug/rec_img_aug.py @@ -0,0 +1,631 @@ +import math +import random + +import cv2 +import numpy as np +from PIL import Image + +from .text_image_aug import tia_distort, tia_perspective, tia_stretch + + +class RecAug(object): + def __init__(self, use_tia=True, aug_prob=0.4, **kwargs): + self.use_tia = use_tia + self.aug_prob = aug_prob + + def __call__(self, data): + img = data["image"] + img = warp(img, 10, self.use_tia, self.aug_prob) + data["image"] = img + return data + + +class RecConAug(object): + def __init__( + self, + prob=0.5, + image_shape=(32, 320, 3), + max_text_length=25, + ext_data_num=1, + **kwargs + ): + self.ext_data_num = ext_data_num + self.prob = prob + self.max_text_length = max_text_length + self.image_shape = image_shape + self.max_wh_ratio = self.image_shape[1] / self.image_shape[0] + + def merge_ext_data(self, data, ext_data): + ori_w = round( + data["image"].shape[1] / data["image"].shape[0] * self.image_shape[0] + ) + ext_w = round( + ext_data["image"].shape[1] + / ext_data["image"].shape[0] + * self.image_shape[0] + ) + data["image"] = cv2.resize(data["image"], (ori_w, self.image_shape[0])) + ext_data["image"] = cv2.resize(ext_data["image"], (ext_w, self.image_shape[0])) + data["image"] = np.concatenate([data["image"], ext_data["image"]], axis=1) + data["label"] += ext_data["label"] + return data + + def __call__(self, data): + rnd_num = random.random() + if rnd_num > self.prob: + return data + for idx, ext_data in enumerate(data["ext_data"]): + if len(data["label"]) + len(ext_data["label"]) > self.max_text_length: + break + concat_ratio = ( + data["image"].shape[1] / data["image"].shape[0] + + ext_data["image"].shape[1] / ext_data["image"].shape[0] + ) + if concat_ratio > self.max_wh_ratio: + break + data = self.merge_ext_data(data, ext_data) + data.pop("ext_data") + return data + + +class ClsResizeImg(object): + def __init__(self, image_shape, **kwargs): + self.image_shape = image_shape + + def __call__(self, data): + img = data["image"] + norm_img, _ = resize_norm_img(img, self.image_shape) + data["image"] = norm_img + return data + + +class NRTRRecResizeImg(object): + def __init__(self, image_shape, resize_type, padding=False, **kwargs): + self.image_shape = image_shape + self.resize_type = resize_type + self.padding = padding + + def __call__(self, data): + img = data["image"] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + image_shape = self.image_shape + if self.padding: + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + norm_img = np.expand_dims(resized_image, -1) + norm_img = norm_img.transpose((2, 0, 1)) + resized_image = norm_img.astype(np.float32) / 128.0 - 1.0 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + data["image"] = padding_im + return data + if self.resize_type == "PIL": + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize(self.image_shape, Image.ANTIALIAS) + img = np.array(img) + if self.resize_type == "OpenCV": + img = cv2.resize(img, self.image_shape) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + data["image"] = norm_img.astype(np.float32) / 128.0 - 1.0 + return data + + +class RecResizeImg(object): + def __init__( + self, + image_shape, + infer_mode=False, + character_dict_path="./ppocr/utils/ppocr_keys_v1.txt", + padding=True, + **kwargs + ): + self.image_shape = image_shape + self.infer_mode = infer_mode + self.character_dict_path = character_dict_path + self.padding = padding + + def __call__(self, data): + img = data["image"] + if self.infer_mode and self.character_dict_path is not None: + norm_img, valid_ratio = resize_norm_img_chinese(img, self.image_shape) + else: + norm_img, valid_ratio = resize_norm_img(img, self.image_shape, self.padding) + data["image"] = norm_img + data["valid_ratio"] = valid_ratio + return data + + +class SRNRecResizeImg(object): + def __init__(self, image_shape, num_heads, max_text_length, **kwargs): + self.image_shape = image_shape + self.num_heads = num_heads + self.max_text_length = max_text_length + + def __call__(self, data): + img = data["image"] + norm_img = resize_norm_img_srn(img, self.image_shape) + data["image"] = norm_img + [ + encoder_word_pos, + gsrm_word_pos, + gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2, + ] = srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length) + + data["encoder_word_pos"] = encoder_word_pos + data["gsrm_word_pos"] = gsrm_word_pos + data["gsrm_slf_attn_bias1"] = gsrm_slf_attn_bias1 + data["gsrm_slf_attn_bias2"] = gsrm_slf_attn_bias2 + return data + + +class SARRecResizeImg(object): + def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs): + self.image_shape = image_shape + self.width_downsample_ratio = width_downsample_ratio + + def __call__(self, data): + img = data["image"] + norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar( + img, self.image_shape, self.width_downsample_ratio + ) + data["image"] = norm_img + data["resized_shape"] = resize_shape + data["pad_shape"] = pad_shape + data["valid_ratio"] = valid_ratio + return data + + +class PRENResizeImg(object): + def __init__(self, image_shape, **kwargs): + """ + Accroding to original paper's realization, it's a hard resize method here. + So maybe you should optimize it to fit for your task better. + """ + self.dst_h, self.dst_w = image_shape + + def __call__(self, data): + img = data["image"] + resized_img = cv2.resize( + img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR + ) + resized_img = resized_img.transpose((2, 0, 1)) / 255 + resized_img -= 0.5 + resized_img /= 0.5 + data["image"] = resized_img.astype(np.float32) + return data + + +def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype("float32") + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + +def resize_norm_img(img, image_shape, padding=True): + imgC, imgH, imgW = image_shape + h = img.shape[0] + w = img.shape[1] + if not padding: + resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_w = imgW + else: + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype("float32") + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio + + +def resize_norm_img_chinese(img, image_shape): + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + max_wh_ratio = imgW * 1.0 / imgH + h, w = img.shape[0], img.shape[1] + ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, ratio) + imgW = int(imgH * max_wh_ratio) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype("float32") + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + valid_ratio = min(1.0, float(resized_w / imgW)) + return padding_im, valid_ratio + + +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0 : img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + +def srn_other_inputs(image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = ( + np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64") + ) + gsrm_word_pos = ( + np.array(range(0, max_text_length)) + .reshape((max_text_length, 1)) + .astype("int64") + ) + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [1, max_text_length, max_text_length] + ) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [1, max_text_length, max_text_length] + ) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [num_heads, 1, 1]) * [-1e9] + + return [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] + + +def flag(): + """ + flag + """ + return 1 if random.random() > 0.5000001 else -1 + + +def cvtColor(img): + """ + cvtColor + """ + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + delta = 0.001 * random.random() * flag() + hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta) + new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + return new_img + + +def blur(img): + """ + blur + """ + h, w, _ = img.shape + if h > 10 and w > 10: + return cv2.GaussianBlur(img, (5, 5), 1) + else: + return img + + +def jitter(img): + """ + jitter + """ + w, h, _ = img.shape + if h > 10 and w > 10: + thres = min(w, h) + s = int(random.random() * thres * 0.01) + src_img = img.copy() + for i in range(s): + img[i:, i:, :] = src_img[: w - i, : h - i, :] + return img + else: + return img + + +def add_gasuss_noise(image, mean=0, var=0.1): + """ + Gasuss noise + """ + + noise = np.random.normal(mean, var**0.5, image.shape) + out = image + 0.5 * noise + out = np.clip(out, 0, 255) + out = np.uint8(out) + return out + + +def get_crop(image): + """ + random crop + """ + h, w, _ = image.shape + top_min = 1 + top_max = 8 + top_crop = int(random.randint(top_min, top_max)) + top_crop = min(top_crop, h - 1) + crop_img = image.copy() + ratio = random.randint(0, 1) + if ratio: + crop_img = crop_img[top_crop:h, :, :] + else: + crop_img = crop_img[0 : h - top_crop, :, :] + return crop_img + + +class Config: + """ + Config + """ + + def __init__(self, use_tia): + self.anglex = random.random() * 30 + self.angley = random.random() * 15 + self.anglez = random.random() * 10 + self.fov = 42 + self.r = 0 + self.shearx = random.random() * 0.3 + self.sheary = random.random() * 0.05 + self.borderMode = cv2.BORDER_REPLICATE + self.use_tia = use_tia + + def make(self, w, h, ang): + """ + make + """ + self.anglex = random.random() * 5 * flag() + self.angley = random.random() * 5 * flag() + self.anglez = -1 * random.random() * int(ang) * flag() + self.fov = 42 + self.r = 0 + self.shearx = 0 + self.sheary = 0 + self.borderMode = cv2.BORDER_REPLICATE + self.w = w + self.h = h + + self.perspective = self.use_tia + self.stretch = self.use_tia + self.distort = self.use_tia + + self.crop = True + self.affine = False + self.reverse = True + self.noise = True + self.jitter = True + self.blur = True + self.color = True + + +def rad(x): + """ + rad + """ + return x * np.pi / 180 + + +def get_warpR(config): + """ + get_warpR + """ + anglex, angley, anglez, fov, w, h, r = ( + config.anglex, + config.angley, + config.anglez, + config.fov, + config.w, + config.h, + config.r, + ) + if w > 69 and w < 112: + anglex = anglex * 1.5 + + z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2)) + # Homogeneous coordinate transformation matrix + rx = np.array( + [ + [1, 0, 0, 0], + [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], + [ + 0, + -np.sin(rad(anglex)), + np.cos(rad(anglex)), + 0, + ], + [0, 0, 0, 1], + ], + np.float32, + ) + ry = np.array( + [ + [np.cos(rad(angley)), 0, np.sin(rad(angley)), 0], + [0, 1, 0, 0], + [ + -np.sin(rad(angley)), + 0, + np.cos(rad(angley)), + 0, + ], + [0, 0, 0, 1], + ], + np.float32, + ) + rz = np.array( + [ + [np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0], + [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + np.float32, + ) + r = rx.dot(ry).dot(rz) + # generate 4 points + pcenter = np.array([h / 2, w / 2, 0, 0], np.float32) + p1 = np.array([0, 0, 0, 0], np.float32) - pcenter + p2 = np.array([w, 0, 0, 0], np.float32) - pcenter + p3 = np.array([0, h, 0, 0], np.float32) - pcenter + p4 = np.array([w, h, 0, 0], np.float32) - pcenter + dst1 = r.dot(p1) + dst2 = r.dot(p2) + dst3 = r.dot(p3) + dst4 = r.dot(p4) + list_dst = np.array([dst1, dst2, dst3, dst4]) + org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32) + dst = np.zeros((4, 2), np.float32) + # Project onto the image plane + dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0] + dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1] + + warpR = cv2.getPerspectiveTransform(org, dst) + + dst1, dst2, dst3, dst4 = dst + r1 = int(min(dst1[1], dst2[1])) + r2 = int(max(dst3[1], dst4[1])) + c1 = int(min(dst1[0], dst3[0])) + c2 = int(max(dst2[0], dst4[0])) + + try: + ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1)) + + dx = -c1 + dy = -r1 + T1 = np.float32([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0 / ratio]]) + ret = T1.dot(warpR) + except: + ratio = 1.0 + T1 = np.float32([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]]) + ret = T1 + return ret, (-r1, -c1), ratio, dst + + +def get_warpAffine(config): + """ + get_warpAffine + """ + anglez = config.anglez + rz = np.array( + [ + [np.cos(rad(anglez)), np.sin(rad(anglez)), 0], + [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0], + ], + np.float32, + ) + return rz + + +def warp(img, ang, use_tia=True, prob=0.4): + """ + warp + """ + h, w, _ = img.shape + config = Config(use_tia=use_tia) + config.make(w, h, ang) + new_img = img + + if config.distort: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_distort(new_img, random.randint(3, 6)) + + if config.stretch: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_stretch(new_img, random.randint(3, 6)) + + if config.perspective: + if random.random() <= prob: + new_img = tia_perspective(new_img) + + if config.crop: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = get_crop(new_img) + + if config.blur: + if random.random() <= prob: + new_img = blur(new_img) + if config.color: + if random.random() <= prob: + new_img = cvtColor(new_img) + if config.jitter: + new_img = jitter(new_img) + if config.noise: + if random.random() <= prob: + new_img = add_gasuss_noise(new_img) + if config.reverse: + if random.random() <= prob: + new_img = 255 - new_img + return new_img diff --git a/ocr/ppocr/data/imaug/sast_process.py b/ocr/ppocr/data/imaug/sast_process.py new file mode 100644 index 0000000000000000000000000000000000000000..939752402d4adbf5b85dd2a3ecb7f2b3cf1ecd7e --- /dev/null +++ b/ocr/ppocr/data/imaug/sast_process.py @@ -0,0 +1,792 @@ +import math + +import cv2 +import numpy as np + +__all__ = ["SASTProcessTrain"] + + +class SASTProcessTrain(object): + def __init__( + self, + image_shape=[512, 512], + min_crop_size=24, + min_crop_side_ratio=0.3, + min_text_size=10, + max_text_size=512, + **kwargs + ): + self.input_size = image_shape[1] + self.min_crop_size = min_crop_size + self.min_crop_side_ratio = min_crop_side_ratio + self.min_text_size = min_text_size + self.max_text_size = max_text_size + + def quad_area(self, poly): + """ + compute area of a polygon + :param poly: + :return: + """ + edge = [ + (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]), + ] + return np.sum(edge) / 2.0 + + def gen_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if True: + rect = cv2.minAreaRect( + poly.astype(np.int32) + ) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = ( + np.linalg.norm(box[(i + 0) % 4] - poly[0]) + + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + ) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad + + def check_and_validate_polys(self, polys, tags, xxx_todo_changeme): + """ + check so that the text poly is in the same direction, + and also filter some invalid polygons + :param polys: + :param tags: + :return: + """ + (h, w) = xxx_todo_changeme + if polys.shape[0] == 0: + return polys, np.array([]), np.array([]) + polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) + polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) + + validated_polys = [] + validated_tags = [] + hv_tags = [] + for poly, tag in zip(polys, tags): + quad = self.gen_quad_from_poly(poly) + p_area = self.quad_area(quad) + if abs(p_area) < 1: + print("invalid poly") + continue + if p_area > 0: + if tag == False: + print("poly in wrong direction") + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :] + quad = quad[(0, 3, 2, 1), :] + + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm( + quad[3] - quad[2] + ) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm( + quad[1] - quad[2] + ) + hv_tag = 1 + + if len_w * 2.0 < len_h: + hv_tag = 0 + + validated_polys.append(poly) + validated_tags.append(tag) + hv_tags.append(hv_tag) + return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags) + + def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25): + """ + make random crop from the input image + :param im: + :param polys: + :param tags: + :param crop_background: + :param max_tries: 50 -> 25 + :return: + """ + h, w, _ = im.shape + pad_h = h // 10 + pad_w = w // 10 + h_array = np.zeros((h + pad_h * 2), dtype=np.int32) + w_array = np.zeros((w + pad_w * 2), dtype=np.int32) + for poly in polys: + poly = np.round(poly, decimals=0).astype(np.int32) + minx = np.min(poly[:, 0]) + maxx = np.max(poly[:, 0]) + w_array[minx + pad_w : maxx + pad_w] = 1 + miny = np.min(poly[:, 1]) + maxy = np.max(poly[:, 1]) + h_array[miny + pad_h : maxy + pad_h] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + if len(h_axis) == 0 or len(w_axis) == 0: + return im, polys, tags, hv_tags + for i in range(max_tries): + xx = np.random.choice(w_axis, size=2) + xmin = np.min(xx) - pad_w + xmax = np.max(xx) - pad_w + xmin = np.clip(xmin, 0, w - 1) + xmax = np.clip(xmax, 0, w - 1) + yy = np.random.choice(h_axis, size=2) + ymin = np.min(yy) - pad_h + ymax = np.max(yy) - pad_h + ymin = np.clip(ymin, 0, h - 1) + ymax = np.clip(ymax, 0, h - 1) + # if xmax - xmin < ARGS.min_crop_side_ratio * w or \ + # ymax - ymin < ARGS.min_crop_side_ratio * h: + if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size: + # area too small + continue + if polys.shape[0] != 0: + poly_axis_in_area = ( + (polys[:, :, 0] >= xmin) + & (polys[:, :, 0] <= xmax) + & (polys[:, :, 1] >= ymin) + & (polys[:, :, 1] <= ymax) + ) + selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + else: + selected_polys = [] + if len(selected_polys) == 0: + # no text in this area + if crop_background: + return ( + im[ymin : ymax + 1, xmin : xmax + 1, :], + polys[selected_polys], + tags[selected_polys], + hv_tags[selected_polys], + ) + else: + continue + im = im[ymin : ymax + 1, xmin : xmax + 1, :] + polys = polys[selected_polys] + tags = tags[selected_polys] + hv_tags = hv_tags[selected_polys] + polys[:, :, 0] -= xmin + polys[:, :, 1] -= ymin + return im, polys, tags, hv_tags + + return im, polys, tags, hv_tags + + def generate_direction_map(self, poly_quads, direction_map): + """ """ + width_list = [] + height_list = [] + for quad in poly_quads: + quad_w = ( + np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]) + ) / 2.0 + quad_h = ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1]) + ) / 2.0 + width_list.append(quad_w) + height_list.append(quad_h) + norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0) + average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0) + + for quad in poly_quads: + direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = ( + direct_vector_full + / (np.linalg.norm(direct_vector_full) + 1e-6) + * norm_width + ) + direction_label = tuple( + map( + float, + [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)], + ) + ) + cv2.fillPoly( + direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label, + ) + return direction_map + + def calculate_average_height(self, poly_quads): + """ """ + height_list = [] + for quad in poly_quads: + quad_h = ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1]) + ) / 2.0 + height_list.append(quad_h) + average_height = max(sum(height_list) / len(height_list), 1.0) + return average_height + + def generate_tcl_label( + self, hw, polys, tags, ds_ratio, tcl_ratio=0.3, shrink_ratio_of_width=0.15 + ): + """ + Generate polygon. + """ + h, w = hw + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + + score_map = np.zeros( + ( + h, + w, + ), + dtype=np.float32, + ) + tbo_map = np.zeros((h, w, 5), dtype=np.float32) + training_mask = np.ones( + ( + h, + w, + ), + dtype=np.float32, + ) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3] + ).astype(np.float32) + + for poly_idx, poly_tag in enumerate(zip(polys, tags)): + poly = poly_tag[0] + tag = poly_tag[1] + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2]) + ) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3]) + ) + + if ( + min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio + or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio + ): + continue + + if tag: + # continue + cv2.fillPoly( + training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15 + ) + else: + tcl_poly = self.poly2tcl(poly, tcl_ratio) + tcl_quads = self.poly2quads(tcl_poly) + poly_quads = self.poly2quads(poly) + # stcl map + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio, + ) + # generate tcl map + cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) + + # generate tbo map + for idx, quad in enumerate(stcl_quads): + quad_mask = np.zeros((h, w), dtype=np.float32) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), + 1.0, + ) + tbo_map = self.gen_quad_tbo( + poly_quads[quad_index[idx]], quad_mask, tbo_map + ) + return score_map, tbo_map, training_mask + + def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25): + """ + Generate tcl map, tvo map and tbo map. + """ + h, w = hw + h, w = int(h * ds_ratio), int(w * ds_ratio) + polys = polys * ds_ratio + poly_mask = np.zeros((h, w), dtype=np.float32) + + tvo_map = np.ones((9, h, w), dtype=np.float32) + tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1)) + tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T + poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32) + + # tco map + tco_map = np.ones((3, h, w), dtype=np.float32) + tco_map[0] = np.tile(np.arange(0, w), (h, 1)) + tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T + poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32) + + poly_short_edge_map = np.ones((h, w), dtype=np.float32) + + for poly, poly_tag in zip(polys, tags): + + if poly_tag == True: + continue + + # adjust point order for vertical poly + poly = self.adjust_point(poly) + + # generate min_area_quad + min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2]) + ) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3]) + ) + + # generate tcl map and text, 128 * 128 + tcl_poly = self.poly2tcl(poly, tcl_ratio) + + # generate poly_tv_xy_map + for idx in range(4): + cv2.fillPoly( + poly_tv_xy_map[2 * idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 0], 0), w)), + ) + cv2.fillPoly( + poly_tv_xy_map[2 * idx + 1], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 1], 0), h)), + ) + + # generate poly_tc_xy_map + for idx in range(2): + cv2.fillPoly( + poly_tc_xy_map[idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(center_point[idx]), + ) + + # generate poly_short_edge_map + cv2.fillPoly( + poly_short_edge_map, + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(max(min(min_area_quad_h, min_area_quad_w), 1.0)), + ) + + # generate poly_mask and training_mask + cv2.fillPoly( + poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1 + ) + + tvo_map *= poly_mask + tvo_map[:8] -= poly_tv_xy_map + tvo_map[-1] /= poly_short_edge_map + tvo_map = tvo_map.transpose((1, 2, 0)) + + tco_map *= poly_mask + tco_map[:2] -= poly_tc_xy_map + tco_map[-1] /= poly_short_edge_map + tco_map = tco_map.transpose((1, 2, 0)) + + return tvo_map, tco_map + + def adjust_point(self, poly): + """ + adjust point order. + """ + point_num = poly.shape[0] + if point_num == 4: + len_1 = np.linalg.norm(poly[0] - poly[1]) + len_2 = np.linalg.norm(poly[1] - poly[2]) + len_3 = np.linalg.norm(poly[2] - poly[3]) + len_4 = np.linalg.norm(poly[3] - poly[0]) + + if (len_1 + len_3) * 1.5 < (len_2 + len_4): + poly = poly[[1, 2, 3, 0], :] + + elif point_num > 4: + vector_1 = poly[0] - poly[1] + vector_2 = poly[1] - poly[2] + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6 + ) + theta = np.arccos(np.round(cos_theta, decimals=4)) + + if abs(theta) > (70 / 180 * math.pi): + index = list(range(1, point_num)) + [0] + poly = poly[np.array(index), :] + return poly + + def gen_min_area_quad_from_poly(self, poly): + """ + Generate min area quad from poly. + """ + point_num = poly.shape[0] + min_area_quad = np.zeros((4, 2), dtype=np.float32) + if point_num == 4: + min_area_quad = poly + center_point = np.sum(poly, axis=0) / 4 + else: + rect = cv2.minAreaRect( + poly.astype(np.int32) + ) # (center (x,y), (width, height), angle of rotation) + center_point = rect[0] + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = ( + np.linalg.norm(box[(i + 0) % 4] - poly[0]) + + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + ) + if dist < min_dist: + min_dist = dist + first_point_idx = i + + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] + + return min_area_quad, center_point + + def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0): + """ + Generate shrink_quad_along_width. + """ + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32 + ) + p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair + p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair + return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) + + def shrink_poly_along_width( + self, quads, shrink_ratio_of_width, expand_height_ratio=1.0 + ): + """ + shrink poly with given length. + """ + upper_edge_list = [] + + def get_cut_info(edge_len_list, cut_len): + for idx, edge_len in enumerate(edge_len_list): + cut_len -= edge_len + if cut_len <= 0.000001: + ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx] + return idx, ratio + + for quad in quads: + upper_edge_len = np.linalg.norm(quad[0] - quad[1]) + upper_edge_list.append(upper_edge_len) + + # length of left edge and right edge. + left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio + + shrink_length = ( + min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width + ) + # shrinking length + upper_len_left = shrink_length + upper_len_right = sum(upper_edge_list) - shrink_length + + left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) + left_quad = self.shrink_quad_along_width( + quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1 + ) + right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) + right_quad = self.shrink_quad_along_width( + quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio + ) + + out_quad_list = [] + if left_idx == right_idx: + out_quad_list.append( + [left_quad[0], right_quad[1], right_quad[2], left_quad[3]] + ) + else: + out_quad_list.append(left_quad) + for idx in range(left_idx + 1, right_idx): + out_quad_list.append(quads[idx]) + out_quad_list.append(right_quad) + + return np.array(out_quad_list), list(range(left_idx, right_idx + 1)) + + def vector_angle(self, A, B): + """ + Calculate the angle between vector AB and x-axis positive direction. + """ + AB = np.array([B[1] - A[1], B[0] - A[0]]) + return np.arctan2(*AB) + + def theta_line_cross_point(self, theta, point): + """ + Calculate the line through given point and angle in ax + by + c =0 form. + """ + x, y = point + cos = np.cos(theta) + sin = np.sin(theta) + return [sin, -cos, cos * y - sin * x] + + def line_cross_two_point(self, A, B): + """ + Calculate the line through given point A and B in ax + by + c =0 form. + """ + angle = self.vector_angle(A, B) + return self.theta_line_cross_point(angle, A) + + def average_angle(self, poly): + """ + Calculate the average angle between left and right edge in given poly. + """ + p0, p1, p2, p3 = poly + angle30 = self.vector_angle(p3, p0) + angle21 = self.vector_angle(p2, p1) + return (angle30 + angle21) / 2 + + def line_cross_point(self, line1, line2): + """ + line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2 + """ + a1, b1, c1 = line1 + a2, b2, c2 = line2 + d = a1 * b2 - a2 * b1 + + if d == 0: + # print("line1", line1) + # print("line2", line2) + print("Cross point does not exist") + return np.array([0, 0], dtype=np.float32) + else: + x = (b1 * c2 - b2 * c1) / d + y = (a2 * c1 - a1 * c2) / d + + return np.array([x, y], dtype=np.float32) + + def quad2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. (4, 2) + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair + p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair + return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) + + def poly2tcl(self, poly, ratio): + """ + Generate center line by poly clock-wise point. + """ + ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + tcl_poly = np.zeros_like(poly) + point_num = poly.shape[0] + + for idx in range(point_num // 2): + point_pair = ( + poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair + ) + tcl_poly[idx] = point_pair[0] + tcl_poly[point_num - 1 - idx] = point_pair[1] + return tcl_poly + + def gen_quad_tbo(self, quad, tcl_mask, tbo_map): + """ + Generate tbo_map for give quad. + """ + # upper and lower line function: ax + by + c = 0; + up_line = self.line_cross_two_point(quad[0], quad[1]) + lower_line = self.line_cross_two_point(quad[3], quad[2]) + + quad_h = 0.5 * ( + np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) + ) + quad_w = 0.5 * ( + np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]) + ) + + # average angle of left and right line. + angle = self.average_angle(quad) + + xy_in_poly = np.argwhere(tcl_mask == 1) + for y, x in xy_in_poly: + point = (x, y) + line = self.theta_line_cross_point(angle, point) + cross_point_upper = self.line_cross_point(up_line, line) + cross_point_lower = self.line_cross_point(lower_line, line) + ##FIX, offset reverse + upper_offset_x, upper_offset_y = cross_point_upper - point + lower_offset_x, lower_offset_y = cross_point_lower - point + tbo_map[y, x, 0] = upper_offset_y + tbo_map[y, x, 1] = upper_offset_x + tbo_map[y, x, 2] = lower_offset_y + tbo_map[y, x, 3] = lower_offset_x + tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2 + return tbo_map + + def poly2quads(self, poly): + """ + Split poly into quads. + """ + quad_list = [] + point_num = poly.shape[0] + + # point pair + point_pair_list = [] + for idx in range(point_num // 2): + point_pair = [poly[idx], poly[point_num - 1 - idx]] + point_pair_list.append(point_pair) + + quad_num = point_num // 2 - 1 + for idx in range(quad_num): + # reshape and adjust to clock-wise + quad_list.append( + (np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]] + ) + + return np.array(quad_list) + + def __call__(self, data): + im = data["image"] + text_polys = data["polys"] + text_tags = data["ignore_tags"] + if im is None: + return None + if text_polys.shape[0] == 0: + return None + + h, w, _ = im.shape + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w) + ) + + if text_polys.shape[0] == 0: + return None + + # set aspect ratio and keep area fix + asp_scales = np.arange(1.0, 1.55, 0.1) + asp_scale = np.random.choice(asp_scales) + + if np.random.rand() < 0.5: + asp_scale = 1.0 / asp_scale + asp_scale = math.sqrt(asp_scale) + + asp_wx = asp_scale + asp_hy = 1.0 / asp_scale + im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) + text_polys[:, :, 0] *= asp_wx + text_polys[:, :, 1] *= asp_hy + + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags = self.crop_area( + im, text_polys, text_tags, hv_tags, crop_background=False + ) + + if text_polys.shape[0] == 0: + return None + # continue for all ignore case + if np.sum((text_tags * 1.0)) >= text_tags.size: + return None + new_h, new_w, _ = im.shape + if (new_h is None) or (new_w is None): + return None + # resize image + std_ratio = float(self.input_size) / max(new_w, new_h) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + rz_scale = std_ratio * np.random.choice(rand_scales) + im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) + text_polys[:, :, 0] *= rz_scale + text_polys[:, :, 1] *= rz_scale + + # add gaussian blur + if np.random.rand() < 0.1 * 0.5: + ks = np.random.permutation(5)[0] + 1 + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + # add brighter + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 + np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + # add darker + if np.random.rand() < 0.1 * 0.5: + im = im * (1.0 - np.random.rand() * 0.5) + im = np.clip(im, 0.0, 255.0) + + # Padding the im to [input_size, input_size] + new_h, new_w, _ = im.shape + if min(new_w, new_h) < self.input_size * 0.5: + return None + + im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32) + im_padded[:, :, 2] = 0.485 * 255 + im_padded[:, :, 1] = 0.456 * 255 + im_padded[:, :, 0] = 0.406 * 255 + + # Random the start position + del_h = self.input_size - new_h + del_w = self.input_size - new_w + sh, sw = 0, 0 + if del_h > 1: + sh = int(np.random.rand() * del_h) + if del_w > 1: + sw = int(np.random.rand() * del_w) + + # Padding + im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy() + text_polys[:, :, 0] += sw + text_polys[:, :, 1] += sh + + score_map, border_map, training_mask = self.generate_tcl_label( + (self.input_size, self.input_size), text_polys, text_tags, 0.25 + ) + + # SAST head + tvo_map, tco_map = self.generate_tvo_and_tco( + (self.input_size, self.input_size), + text_polys, + text_tags, + tcl_ratio=0.3, + ds_ratio=0.25, + ) + # print("test--------tvo_map shape:", tvo_map.shape) + + im_padded[:, :, 2] -= 0.485 * 255 + im_padded[:, :, 1] -= 0.456 * 255 + im_padded[:, :, 0] -= 0.406 * 255 + im_padded[:, :, 2] /= 255.0 * 0.229 + im_padded[:, :, 1] /= 255.0 * 0.224 + im_padded[:, :, 0] /= 255.0 * 0.225 + im_padded = im_padded.transpose((2, 0, 1)) + + data["image"] = im_padded[::-1, :, :] + data["score_map"] = score_map[np.newaxis, :, :] + data["border_map"] = border_map.transpose((2, 0, 1)) + data["training_mask"] = training_mask[np.newaxis, :, :] + data["tvo_map"] = tvo_map.transpose((2, 0, 1)) + data["tco_map"] = tco_map.transpose((2, 0, 1)) + return data diff --git a/ocr/ppocr/data/imaug/ssl_img_aug.py b/ocr/ppocr/data/imaug/ssl_img_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..60613da770e2e691c771417aadf07d61c7e3e463 --- /dev/null +++ b/ocr/ppocr/data/imaug/ssl_img_aug.py @@ -0,0 +1,38 @@ +import cv2 +import numpy as np + +from .rec_img_aug import resize_norm_img + + +class SSLRotateResize(object): + def __init__( + self, image_shape, padding=False, select_all=True, mode="train", **kwargs + ): + self.image_shape = image_shape + self.padding = padding + self.select_all = select_all + self.mode = mode + + def __call__(self, data): + img = data["image"] + + data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + data["image_r180"] = cv2.rotate(data["image_r90"], cv2.ROTATE_90_CLOCKWISE) + data["image_r270"] = cv2.rotate(data["image_r180"], cv2.ROTATE_90_CLOCKWISE) + + images = [] + for key in ["image", "image_r90", "image_r180", "image_r270"]: + images.append( + resize_norm_img( + data.pop(key), image_shape=self.image_shape, padding=self.padding + )[0] + ) + data["image"] = np.stack(images, axis=0) + data["label"] = np.array(list(range(4))) + if not self.select_all: + data["image"] = data["image"][0::2] # just choose 0 and 180 + data["label"] = data["label"][0:2] # label needs to be continuous + if self.mode == "test": + data["image"] = data["image"][0] + data["label"] = data["label"][0] + return data diff --git a/ocr/ppocr/data/imaug/text_image_aug/__init__.py b/ocr/ppocr/data/imaug/text_image_aug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6932b51ed54f5f3200bf4a669b09a15989aacc26 --- /dev/null +++ b/ocr/ppocr/data/imaug/text_image_aug/__init__.py @@ -0,0 +1,3 @@ +from .augment import tia_distort, tia_perspective, tia_stretch + +__all__ = ["tia_distort", "tia_stretch", "tia_perspective"] diff --git a/ocr/ppocr/data/imaug/text_image_aug/augment.py b/ocr/ppocr/data/imaug/text_image_aug/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..cc53e15bd423f3a6761ca7950a0612de2131403a --- /dev/null +++ b/ocr/ppocr/data/imaug/text_image_aug/augment.py @@ -0,0 +1,106 @@ +import numpy as np + +from .warp_mls import WarpMLS + + +def tia_distort(src, segment=4): + img_h, img_w = src.shape[:2] + + cut = img_w // segment + thresh = cut // 3 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) + dst_pts.append([img_w - np.random.randint(thresh), np.random.randint(thresh)]) + dst_pts.append( + [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)] + ) + dst_pts.append([np.random.randint(thresh), img_h - np.random.randint(thresh)]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append( + [ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + np.random.randint(thresh) - half_thresh, + ] + ) + dst_pts.append( + [ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + img_h + np.random.randint(thresh) - half_thresh, + ] + ) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst + + +def tia_stretch(src, segment=4): + img_h, img_w = src.shape[:2] + + cut = img_w // segment + thresh = cut * 4 // 5 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, 0]) + dst_pts.append([img_w, 0]) + dst_pts.append([img_w, img_h]) + dst_pts.append([0, img_h]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + move = np.random.randint(thresh) - half_thresh + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append([cut * cut_idx + move, 0]) + dst_pts.append([cut * cut_idx + move, img_h]) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst + + +def tia_perspective(src): + img_h, img_w = src.shape[:2] + + thresh = img_h // 2 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, np.random.randint(thresh)]) + dst_pts.append([img_w, np.random.randint(thresh)]) + dst_pts.append([img_w, img_h - np.random.randint(thresh)]) + dst_pts.append([0, img_h - np.random.randint(thresh)]) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst diff --git a/ocr/ppocr/data/imaug/text_image_aug/warp_mls.py b/ocr/ppocr/data/imaug/text_image_aug/warp_mls.py new file mode 100644 index 0000000000000000000000000000000000000000..af93139aa56c2c34f05341d2ea75229415386021 --- /dev/null +++ b/ocr/ppocr/data/imaug/text_image_aug/warp_mls.py @@ -0,0 +1,169 @@ +import numpy as np + + +class WarpMLS: + def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.0): + self.src = src + self.src_pts = src_pts + self.dst_pts = dst_pts + self.pt_count = len(self.dst_pts) + self.dst_w = dst_w + self.dst_h = dst_h + self.trans_ratio = trans_ratio + self.grid_size = 100 + self.rdx = np.zeros((self.dst_h, self.dst_w)) + self.rdy = np.zeros((self.dst_h, self.dst_w)) + + @staticmethod + def __bilinear_interp(x, y, v11, v12, v21, v22): + return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * (1 - y) + v22 * y) * x + + def generate(self): + self.calc_delta() + return self.gen_img() + + def calc_delta(self): + w = np.zeros(self.pt_count, dtype=np.float32) + + if self.pt_count < 2: + return + + i = 0 + while 1: + if self.dst_w <= i < self.dst_w + self.grid_size - 1: + i = self.dst_w - 1 + elif i >= self.dst_w: + break + + j = 0 + while 1: + if self.dst_h <= j < self.dst_h + self.grid_size - 1: + j = self.dst_h - 1 + elif j >= self.dst_h: + break + + sw = 0 + swp = np.zeros(2, dtype=np.float32) + swq = np.zeros(2, dtype=np.float32) + new_pt = np.zeros(2, dtype=np.float32) + cur_pt = np.array([i, j], dtype=np.float32) + + k = 0 + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + break + + w[k] = 1.0 / ( + (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + + (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]) + ) + + sw += w[k] + swp = swp + w[k] * np.array(self.dst_pts[k]) + swq = swq + w[k] * np.array(self.src_pts[k]) + + if k == self.pt_count - 1: + pstar = 1 / sw * swp + qstar = 1 / sw * swq + + miu_s = 0 + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + continue + pt_i = self.dst_pts[k] - pstar + miu_s += w[k] * np.sum(pt_i * pt_i) + + cur_pt -= pstar + cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) + + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + continue + + pt_i = self.dst_pts[k] - pstar + pt_j = np.array([-pt_i[1], pt_i[0]]) + + tmp_pt = np.zeros(2, dtype=np.float32) + tmp_pt[0] = ( + np.sum(pt_i * cur_pt) * self.src_pts[k][0] + - np.sum(pt_j * cur_pt) * self.src_pts[k][1] + ) + tmp_pt[1] = ( + -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + + np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] + ) + tmp_pt *= w[k] / miu_s + new_pt += tmp_pt + + new_pt += qstar + else: + new_pt = self.src_pts[k] + + self.rdx[j, i] = new_pt[0] - i + self.rdy[j, i] = new_pt[1] - j + + j += self.grid_size + i += self.grid_size + + def gen_img(self): + src_h, src_w = self.src.shape[:2] + dst = np.zeros_like(self.src, dtype=np.float32) + + for i in np.arange(0, self.dst_h, self.grid_size): + for j in np.arange(0, self.dst_w, self.grid_size): + ni = i + self.grid_size + nj = j + self.grid_size + w = h = self.grid_size + if ni >= self.dst_h: + ni = self.dst_h - 1 + h = ni - i + 1 + if nj >= self.dst_w: + nj = self.dst_w - 1 + w = nj - j + 1 + + di = np.reshape(np.arange(h), (-1, 1)) + dj = np.reshape(np.arange(w), (1, -1)) + delta_x = self.__bilinear_interp( + di / h, + dj / w, + self.rdx[i, j], + self.rdx[i, nj], + self.rdx[ni, j], + self.rdx[ni, nj], + ) + delta_y = self.__bilinear_interp( + di / h, + dj / w, + self.rdy[i, j], + self.rdy[i, nj], + self.rdy[ni, j], + self.rdy[ni, nj], + ) + nx = j + dj + delta_x * self.trans_ratio + ny = i + di + delta_y * self.trans_ratio + nx = np.clip(nx, 0, src_w - 1) + ny = np.clip(ny, 0, src_h - 1) + nxi = np.array(np.floor(nx), dtype=np.int32) + nyi = np.array(np.floor(ny), dtype=np.int32) + nxi1 = np.array(np.ceil(nx), dtype=np.int32) + nyi1 = np.array(np.ceil(ny), dtype=np.int32) + + if len(self.src.shape) == 3: + x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) + y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) + else: + x = ny - nyi + y = nx - nxi + dst[i : i + h, j : j + w] = self.__bilinear_interp( + x, + y, + self.src[nyi, nxi], + self.src[nyi, nxi1], + self.src[nyi1, nxi], + self.src[nyi1, nxi1], + ) + + dst = np.clip(dst, 0, 255) + dst = np.array(dst, dtype=np.uint8) + + return dst diff --git a/ocr/ppocr/data/imaug/vqa/__init__.py b/ocr/ppocr/data/imaug/vqa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a569a152e262dc3c49c11120124c6a9bf198ce46 --- /dev/null +++ b/ocr/ppocr/data/imaug/vqa/__init__.py @@ -0,0 +1,3 @@ +from .token import VQAReTokenChunk, VQAReTokenRelation, VQASerTokenChunk, VQATokenPad + +__all__ = ["VQATokenPad", "VQASerTokenChunk", "VQAReTokenChunk", "VQAReTokenRelation"] diff --git a/ocr/ppocr/data/imaug/vqa/token/__init__.py b/ocr/ppocr/data/imaug/vqa/token/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fcb04f4b33bd196c86e659ff36f2adf0d1bcca --- /dev/null +++ b/ocr/ppocr/data/imaug/vqa/token/__init__.py @@ -0,0 +1,3 @@ +from .vqa_token_chunk import VQAReTokenChunk, VQASerTokenChunk +from .vqa_token_pad import VQATokenPad +from .vqa_token_relation import VQAReTokenRelation diff --git a/ocr/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/ocr/ppocr/data/imaug/vqa/token/vqa_token_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ce12ce47c3c37eeeb71381d896f7111f83120eb3 --- /dev/null +++ b/ocr/ppocr/data/imaug/vqa/token/vqa_token_chunk.py @@ -0,0 +1,120 @@ +from collections import defaultdict + + +class VQASerTokenChunk(object): + def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): + self.max_seq_len = max_seq_len + self.infer_mode = infer_mode + + def __call__(self, data): + encoded_inputs_all = [] + seq_len = len(data["input_ids"]) + for index in range(0, seq_len, self.max_seq_len): + chunk_beg = index + chunk_end = min(index + self.max_seq_len, seq_len) + encoded_inputs_example = {} + for key in data: + if key in [ + "label", + "input_ids", + "labels", + "token_type_ids", + "bbox", + "attention_mask", + ]: + if self.infer_mode and key == "labels": + encoded_inputs_example[key] = data[key] + else: + encoded_inputs_example[key] = data[key][chunk_beg:chunk_end] + else: + encoded_inputs_example[key] = data[key] + + encoded_inputs_all.append(encoded_inputs_example) + if len(encoded_inputs_all) == 0: + return None + return encoded_inputs_all[0] + + +class VQAReTokenChunk(object): + def __init__( + self, max_seq_len=512, entities_labels=None, infer_mode=False, **kwargs + ): + self.max_seq_len = max_seq_len + self.entities_labels = ( + {"HEADER": 0, "QUESTION": 1, "ANSWER": 2} + if entities_labels is None + else entities_labels + ) + self.infer_mode = infer_mode + + def __call__(self, data): + # prepare data + entities = data.pop("entities") + relations = data.pop("relations") + encoded_inputs_all = [] + for index in range(0, len(data["input_ids"]), self.max_seq_len): + item = {} + for key in data: + if key in [ + "label", + "input_ids", + "labels", + "token_type_ids", + "bbox", + "attention_mask", + ]: + if self.infer_mode and key == "labels": + item[key] = data[key] + else: + item[key] = data[key][index : index + self.max_seq_len] + else: + item[key] = data[key] + # select entity in current chunk + entities_in_this_span = [] + global_to_local_map = {} # + for entity_id, entity in enumerate(entities): + if ( + index <= entity["start"] < index + self.max_seq_len + and index <= entity["end"] < index + self.max_seq_len + ): + entity["start"] = entity["start"] - index + entity["end"] = entity["end"] - index + global_to_local_map[entity_id] = len(entities_in_this_span) + entities_in_this_span.append(entity) + + # select relations in current chunk + relations_in_this_span = [] + for relation in relations: + if ( + index <= relation["start_index"] < index + self.max_seq_len + and index <= relation["end_index"] < index + self.max_seq_len + ): + relations_in_this_span.append( + { + "head": global_to_local_map[relation["head"]], + "tail": global_to_local_map[relation["tail"]], + "start_index": relation["start_index"] - index, + "end_index": relation["end_index"] - index, + } + ) + item.update( + { + "entities": self.reformat(entities_in_this_span), + "relations": self.reformat(relations_in_this_span), + } + ) + if len(item["entities"]) > 0: + item["entities"]["label"] = [ + self.entities_labels[x] for x in item["entities"]["label"] + ] + encoded_inputs_all.append(item) + if len(encoded_inputs_all) == 0: + return None + return encoded_inputs_all[0] + + def reformat(self, data): + new_data = defaultdict(list) + for item in data: + for k, v in item.items(): + new_data[k].append(v) + return new_data diff --git a/ocr/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/ocr/ppocr/data/imaug/vqa/token/vqa_token_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..464cbde443fba56c45fc0aaa2dc61743c028265e --- /dev/null +++ b/ocr/ppocr/data/imaug/vqa/token/vqa_token_pad.py @@ -0,0 +1,104 @@ +import numpy as np +import paddle + + +class VQATokenPad(object): + def __init__( + self, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + truncation_strategy="longest_first", + return_overflowing_tokens=False, + return_special_tokens_mask=False, + infer_mode=False, + **kwargs + ): + self.max_seq_len = max_seq_len + self.pad_to_max_seq_len = max_seq_len + self.return_attention_mask = return_attention_mask + self.return_token_type_ids = return_token_type_ids + self.truncation_strategy = truncation_strategy + self.return_overflowing_tokens = return_overflowing_tokens + self.return_special_tokens_mask = return_special_tokens_mask + self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + self.infer_mode = infer_mode + + def __call__(self, data): + needs_to_be_padded = ( + self.pad_to_max_seq_len and len(data["input_ids"]) < self.max_seq_len + ) + + if needs_to_be_padded: + if "tokenizer_params" in data: + tokenizer_params = data.pop("tokenizer_params") + else: + tokenizer_params = dict( + padding_side="right", pad_token_type_id=0, pad_token_id=1 + ) + + difference = self.max_seq_len - len(data["input_ids"]) + if tokenizer_params["padding_side"] == "right": + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data["input_ids"]) + [ + 0 + ] * difference + if self.return_token_type_ids: + data["token_type_ids"] = ( + data["token_type_ids"] + + [tokenizer_params["pad_token_type_id"]] * difference + ) + if self.return_special_tokens_mask: + data["special_tokens_mask"] = ( + data["special_tokens_mask"] + [1] * difference + ) + data["input_ids"] = ( + data["input_ids"] + [tokenizer_params["pad_token_id"]] * difference + ) + if not self.infer_mode: + data["labels"] = ( + data["labels"] + [self.pad_token_label_id] * difference + ) + data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference + elif tokenizer_params["padding_side"] == "left": + if self.return_attention_mask: + data["attention_mask"] = [0] * difference + [1] * len( + data["input_ids"] + ) + if self.return_token_type_ids: + data["token_type_ids"] = [ + tokenizer_params["pad_token_type_id"] + ] * difference + data["token_type_ids"] + if self.return_special_tokens_mask: + data["special_tokens_mask"] = [1] * difference + data[ + "special_tokens_mask" + ] + data["input_ids"] = [ + tokenizer_params["pad_token_id"] + ] * difference + data["input_ids"] + if not self.infer_mode: + data["labels"] = [self.pad_token_label_id] * difference + data[ + "labels" + ] + data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"] + else: + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data["input_ids"]) + + for key in data: + if key in [ + "input_ids", + "labels", + "token_type_ids", + "bbox", + "attention_mask", + ]: + if self.infer_mode: + if key != "labels": + length = min(len(data[key]), self.max_seq_len) + data[key] = data[key][:length] + else: + continue + data[key] = np.array(data[key], dtype="int64") + return data diff --git a/ocr/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/ocr/ppocr/data/imaug/vqa/token/vqa_token_relation.py new file mode 100644 index 0000000000000000000000000000000000000000..4446687c0aab29ba86d854703bfb56f8cb27f23e --- /dev/null +++ b/ocr/ppocr/data/imaug/vqa/token/vqa_token_relation.py @@ -0,0 +1,61 @@ +class VQAReTokenRelation(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + """ + build relations + """ + entities = data["entities"] + relations = data["relations"] + id2label = data.pop("id2label") + empty_entity = data.pop("empty_entity") + entity_id_to_index_map = data.pop("entity_id_to_index_map") + + relations = list(set(relations)) + relations = [ + rel + for rel in relations + if rel[0] not in empty_entity and rel[1] not in empty_entity + ] + kv_relations = [] + for rel in relations: + pair = [id2label[rel[0]], id2label[rel[1]]] + if pair == ["question", "answer"]: + kv_relations.append( + { + "head": entity_id_to_index_map[rel[0]], + "tail": entity_id_to_index_map[rel[1]], + } + ) + elif pair == ["answer", "question"]: + kv_relations.append( + { + "head": entity_id_to_index_map[rel[1]], + "tail": entity_id_to_index_map[rel[0]], + } + ) + else: + continue + relations = sorted( + [ + { + "head": rel["head"], + "tail": rel["tail"], + "start_index": self.get_relation_span(rel, entities)[0], + "end_index": self.get_relation_span(rel, entities)[1], + } + for rel in kv_relations + ], + key=lambda x: x["head"], + ) + + data["relations"] = relations + return data + + def get_relation_span(self, rel, entities): + bound = [] + for entity_index in [rel["head"], rel["tail"]]: + bound.append(entities[entity_index]["start"]) + bound.append(entities[entity_index]["end"]) + return min(bound), max(bound) diff --git a/ocr/ppocr/data/lmdb_dataset.py b/ocr/ppocr/data/lmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..63009c495973d96ce2765f666aff004876957bcb --- /dev/null +++ b/ocr/ppocr/data/lmdb_dataset.py @@ -0,0 +1,139 @@ +import os + +import cv2 +import lmdb +import numpy as np +from paddle.io import Dataset + +from .imaug import create_operators, transform + + +class LMDBDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(LMDBDataSet, self).__init__() + + global_config = config["Global"] + dataset_config = config[mode]["dataset"] + loader_config = config[mode]["loader"] + batch_size = loader_config["batch_size_per_card"] + data_dir = dataset_config["data_dir"] + self.do_shuffle = loader_config["shuffle"] + + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) + logger.info("Initialize indexs of datasets:%s" % data_dir) + self.data_idx_order_list = self.dataset_traversal() + if self.do_shuffle: + np.random.shuffle(self.data_idx_order_list) + self.ops = create_operators(dataset_config["transforms"], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2) + + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + for dirpath, dirnames, filenames in os.walk(data_dir + "/"): + if not dirnames: + env = lmdb.open( + dirpath, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + txn = env.begin(write=False) + num_samples = int(txn.get("num-samples".encode())) + lmdb_sets[dataset_idx] = { + "dirpath": dirpath, + "env": env, + "txn": txn, + "num_samples": num_samples, + } + dataset_idx += 1 + return lmdb_sets + + def dataset_traversal(self): + lmdb_num = len(self.lmdb_sets) + total_sample_num = 0 + for lno in range(lmdb_num): + total_sample_num += self.lmdb_sets[lno]["num_samples"] + data_idx_order_list = np.zeros((total_sample_num, 2)) + beg_idx = 0 + for lno in range(lmdb_num): + tmp_sample_num = self.lmdb_sets[lno]["num_samples"] + end_idx = beg_idx + tmp_sample_num + data_idx_order_list[beg_idx:end_idx, 0] = lno + data_idx_order_list[beg_idx:end_idx, 1] = list(range(tmp_sample_num)) + data_idx_order_list[beg_idx:end_idx, 1] += 1 + beg_idx = beg_idx + tmp_sample_num + return data_idx_order_list + + def get_img_data(self, value): + """get_img_data""" + if not value: + return None + imgdata = np.frombuffer(value, dtype="uint8") + if imgdata is None: + return None + imgori = cv2.imdecode(imgdata, 1) + if imgori is None: + return None + return imgori + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, "ext_data_num"): + ext_data_num = getattr(op, "ext_data_num") + break + load_data_ops = self.ops[: self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[ + np.random.randint(self.__len__()) + ] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]["txn"], file_idx + ) + if sample_info is None: + continue + img, label = sample_info + data = {"image": img, "label": label} + outs = transform(data, load_data_ops) + ext_data.append(data) + return ext_data + + def get_lmdb_sample_info(self, txn, index): + label_key = "label-%09d".encode() % index + label = txn.get(label_key) + if label is None: + return None + label = label.decode("utf-8") + img_key = "image-%09d".encode() % index + imgbuf = txn.get(img_key) + return imgbuf, label + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]["txn"], file_idx + ) + if sample_info is None: + return self.__getitem__(np.random.randint(self.__len__())) + img, label = sample_info + data = {"image": img, "label": label} + data["ext_data"] = self.get_ext_data() + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return self.data_idx_order_list.shape[0] diff --git a/ocr/ppocr/data/pgnet_dataset.py b/ocr/ppocr/data/pgnet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1bcfe785eedad221aade9f52ea64cdb5de998b30 --- /dev/null +++ b/ocr/ppocr/data/pgnet_dataset.py @@ -0,0 +1,96 @@ +import os +import random + +import numpy as np +from paddle.io import Dataset + +from .imaug import create_operators, transform + + +class PGDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PGDataSet, self).__init__() + + self.logger = logger + self.seed = seed + self.mode = mode + global_config = config["Global"] + dataset_config = config[mode]["dataset"] + loader_config = config[mode]["loader"] + + self.delimiter = dataset_config.get("delimiter", "\t") + label_file_list = dataset_config.pop("label_file_list") + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", [1.0]) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + assert ( + len(ratio_list) == data_source_num + ), "The length of ratio_list should be the same as the file_list." + self.data_dir = dataset_config["data_dir"] + self.do_shuffle = loader_config["shuffle"] + + logger.info("Initialize indexs of datasets:%s" % label_file_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + + self.ops = create_operators(dataset_config["transforms"], global_config) + + self.need_reset = True in [x < 1 for x in ratio_list] + + def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def get_image_info_list(self, file_list, ratio_list): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, file in enumerate(file_list): + with open(file, "rb") as f: + lines = f.readlines() + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) + return data_lines + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_line = self.data_lines[file_idx] + img_id = 0 + try: + data_line = data_line.decode("utf-8") + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + if self.mode.lower() == "eval": + try: + img_id = int(data_line.split(".")[0][7:]) + except: + img_id = 0 + data = {"img_path": img_path, "label": label, "img_id": img_id} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data["img_path"], "rb") as f: + img = f.read() + data["image"] = img + outs = transform(data, self.ops) + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + self.data_idx_order_list[idx], e + ) + ) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ocr/ppocr/data/pubtab_dataset.py b/ocr/ppocr/data/pubtab_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..65916ab8d91a8d5b07f8c4ace4c58b7cab569669 --- /dev/null +++ b/ocr/ppocr/data/pubtab_dataset.py @@ -0,0 +1,98 @@ +import json +import os +import random + +import numpy as np +from paddle.io import Dataset + +from .imaug import create_operators, transform + + +class PubTabDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(PubTabDataSet, self).__init__() + self.logger = logger + + global_config = config["Global"] + dataset_config = config[mode]["dataset"] + loader_config = config[mode]["loader"] + + label_file_path = dataset_config.pop("label_file_path") + + self.data_dir = dataset_config["data_dir"] + self.do_shuffle = loader_config["shuffle"] + self.do_hard_select = False + if "hard_select" in loader_config: + self.do_hard_select = loader_config["hard_select"] + self.hard_prob = loader_config["hard_prob"] + if self.do_hard_select: + self.img_select_prob = self.load_hard_select_prob() + self.table_select_type = None + if "table_select_type" in loader_config: + self.table_select_type = loader_config["table_select_type"] + self.table_select_prob = loader_config["table_select_prob"] + + self.seed = seed + logger.info("Initialize indexs of datasets:%s" % label_file_path) + with open(label_file_path, "rb") as f: + self.data_lines = f.readlines() + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() + self.ops = create_operators(dataset_config["transforms"], global_config) + + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def shuffle_data_random(self): + if self.do_shuffle: + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def __getitem__(self, idx): + try: + data_line = self.data_lines[idx] + data_line = data_line.decode("utf-8").strip("\n") + info = json.loads(data_line) + file_name = info["filename"] + select_flag = True + if self.do_hard_select: + prob = self.img_select_prob[file_name] + if prob < random.uniform(0, 1): + select_flag = False + + if self.table_select_type: + structure = info["html"]["structure"]["tokens"].copy() + structure_str = "".join(structure) + table_type = "simple" + if "colspan" in structure_str or "rowspan" in structure_str: + table_type = "complex" + if table_type == "complex": + if self.table_select_prob < random.uniform(0, 1): + select_flag = False + + if select_flag: + cells = info["html"]["cells"].copy() + structure = info["html"]["structure"].copy() + img_path = os.path.join(self.data_dir, file_name) + data = {"img_path": img_path, "cells": cells, "structure": structure} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data["img_path"], "rb") as f: + img = f.read() + data["image"] = img + outs = transform(data, self.ops) + else: + outs = None + except Exception as e: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format(data_line, e) + ) + outs = None + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ocr/ppocr/data/simple_dataset.py b/ocr/ppocr/data/simple_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9a628c5239eb8ea647ec37ed6460156efb324691 --- /dev/null +++ b/ocr/ppocr/data/simple_dataset.py @@ -0,0 +1,142 @@ +import json +import os +import random +import traceback + +import numpy as np +from paddle.io import Dataset + +from .imaug import create_operators, transform + + +class SimpleDataSet(Dataset): + def __init__(self, config, mode, logger, seed=None): + super(SimpleDataSet, self).__init__() + self.logger = logger + self.mode = mode.lower() + + global_config = config["Global"] + dataset_config = config[mode]["dataset"] + loader_config = config[mode]["loader"] + + self.delimiter = dataset_config.get("delimiter", "\t") + label_file_list = dataset_config.pop("label_file_list") + data_source_num = len(label_file_list) + ratio_list = dataset_config.get("ratio_list", 1.0) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * int(data_source_num) + + assert ( + len(ratio_list) == data_source_num + ), "The length of ratio_list should be the same as the file_list." + self.data_dir = dataset_config["data_dir"] + self.do_shuffle = loader_config["shuffle"] + self.seed = seed + logger.info("Initialize indexs of datasets:%s" % label_file_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_idx_order_list = list(range(len(self.data_lines))) + if self.mode == "train" and self.do_shuffle: + self.shuffle_data_random() + self.ops = create_operators(dataset_config["transforms"], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2) + self.need_reset = True in [x < 1 for x in ratio_list] + + def get_image_info_list(self, file_list, ratio_list): + if isinstance(file_list, str): + file_list = [file_list] + data_lines = [] + for idx, file in enumerate(file_list): + with open(file, "rb") as f: + lines = f.readlines() + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) + return data_lines + + def shuffle_data_random(self): + random.seed(self.seed) + random.shuffle(self.data_lines) + return + + def _try_parse_filename_list(self, file_name): + # multiple images -> one gt label + if len(file_name) > 0 and file_name[0] == "[": + try: + info = json.loads(file_name) + file_name = random.choice(info) + except: + pass + return file_name + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, "ext_data_num"): + ext_data_num = getattr(op, "ext_data_num") + break + load_data_ops = self.ops[: self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + file_idx = self.data_idx_order_list[np.random.randint(self.__len__())] + data_line = self.data_lines[file_idx] + data_line = data_line.decode("utf-8") + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + file_name = self._try_parse_filename_list(file_name) + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {"img_path": img_path, "label": label} + if not os.path.exists(img_path): + continue + with open(data["img_path"], "rb") as f: + img = f.read() + data["image"] = img + data = transform(data, load_data_ops) + + if data is None: + continue + if "polys" in data.keys(): + if data["polys"].shape[1] != 4: + continue + ext_data.append(data) + return ext_data + + def __getitem__(self, idx): + file_idx = self.data_idx_order_list[idx] + data_line = self.data_lines[file_idx] + try: + data_line = data_line.decode("utf-8") + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + file_name = self._try_parse_filename_list(file_name) + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {"img_path": img_path, "label": label} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) + with open(data["img_path"], "rb") as f: + img = f.read() + data["image"] = img + data["ext_data"] = self.get_ext_data() + outs = transform(data, self.ops) + except: + self.logger.error( + "When parsing line {}, error happened with msg: {}".format( + data_line, traceback.format_exc() + ) + ) + outs = None + if outs is None: + # during evaluation, we should fix the idx to get same results for many times of evaluation. + rnd_idx = ( + np.random.randint(self.__len__()) + if self.mode == "train" + else (idx + 1) % self.__len__() + ) + return self.__getitem__(rnd_idx) + return outs + + def __len__(self): + return len(self.data_idx_order_list) diff --git a/ocr/ppocr/ppocr_keys_v1.txt b/ocr/ppocr/ppocr_keys_v1.txt new file mode 100644 index 0000000000000000000000000000000000000000..84b885d8352226e49b1d5d791b8f43a663e246aa --- /dev/null +++ b/ocr/ppocr/ppocr_keys_v1.txt @@ -0,0 +1,6623 @@ +' +疗 +绚 +诚 +娇 +溜 +题 +贿 +者 +廖 +更 +纳 +加 +奉 +公 +一 +就 +汴 +计 +与 +路 +房 +原 +妇 +2 +0 +8 +- +7 +其 +> +: +] +, +, +骑 +刈 +全 +消 +昏 +傈 +安 +久 +钟 +嗅 +不 +影 +处 +驽 +蜿 +资 +关 +椤 +地 +瘸 +专 +问 +忖 +票 +嫉 +炎 +韵 +要 +月 +田 +节 +陂 +鄙 +捌 +备 +拳 +伺 +眼 +网 +盎 +大 +傍 +心 +东 +愉 +汇 +蹿 +科 +每 +业 +里 +航 +晏 +字 +平 +录 +先 +1 +3 +彤 +鲶 +产 +稍 +督 +腴 +有 +象 +岳 +注 +绍 +在 +泺 +文 +定 +核 +名 +水 +过 +理 +让 +偷 +率 +等 +这 +发 +” +为 +含 +肥 +酉 +相 +鄱 +七 +编 +猥 +锛 +日 +镀 +蒂 +掰 +倒 +辆 +栾 +栗 +综 +涩 +州 +雌 +滑 +馀 +了 +机 +块 +司 +宰 +甙 +兴 +矽 +抚 +保 +用 +沧 +秩 +如 +收 +息 +滥 +页 +疑 +埠 +! +! +姥 +异 +橹 +钇 +向 +下 +跄 +的 +椴 +沫 +国 +绥 +獠 +报 +开 +民 +蜇 +何 +分 +凇 +长 +讥 +藏 +掏 +施 +羽 +中 +讲 +派 +嘟 +人 +提 +浼 +间 +世 +而 +古 +多 +倪 +唇 +饯 +控 +庚 +首 +赛 +蜓 +味 +断 +制 +觉 +技 +替 +艰 +溢 +潮 +夕 +钺 +外 +摘 +枋 +动 +双 +单 +啮 +户 +枇 +确 +锦 +曜 +杜 +或 +能 +效 +霜 +盒 +然 +侗 +电 +晁 +放 +步 +鹃 +新 +杖 +蜂 +吒 +濂 +瞬 +评 +总 +隍 +对 +独 +合 +也 +是 +府 +青 +天 +诲 +墙 +组 +滴 +级 +邀 +帘 +示 +已 +时 +骸 +仄 +泅 +和 +遨 +店 +雇 +疫 +持 +巍 +踮 +境 +只 +亨 +目 +鉴 +崤 +闲 +体 +泄 +杂 +作 +般 +轰 +化 +解 +迂 +诿 +蛭 +璀 +腾 +告 +版 +服 +省 +师 +小 +规 +程 +线 +海 +办 +引 +二 +桧 +牌 +砺 +洄 +裴 +修 +图 +痫 +胡 +许 +犊 +事 +郛 +基 +柴 +呼 +食 +研 +奶 +律 +蛋 +因 +葆 +察 +戏 +褒 +戒 +再 +李 +骁 +工 +貂 +油 +鹅 +章 +啄 +休 +场 +给 +睡 +纷 +豆 +器 +捎 +说 +敏 +学 +会 +浒 +设 +诊 +格 +廓 +查 +来 +霓 +室 +溆 +¢ +诡 +寥 +焕 +舜 +柒 +狐 +回 +戟 +砾 +厄 +实 +翩 +尿 +五 +入 +径 +惭 +喹 +股 +宇 +篝 +| +; +美 +期 +云 +九 +祺 +扮 +靠 +锝 +槌 +系 +企 +酰 +阊 +暂 +蚕 +忻 +豁 +本 +羹 +执 +条 +钦 +H +獒 +限 +进 +季 +楦 +于 +芘 +玖 +铋 +茯 +未 +答 +粘 +括 +样 +精 +欠 +矢 +甥 +帷 +嵩 +扣 +令 +仔 +风 +皈 +行 +支 +部 +蓉 +刮 +站 +蜡 +救 +钊 +汗 +松 +嫌 +成 +可 +. +鹤 +院 +从 +交 +政 +怕 +活 +调 +球 +局 +验 +髌 +第 +韫 +谗 +串 +到 +圆 +年 +米 +/ +* +友 +忿 +检 +区 +看 +自 +敢 +刃 +个 +兹 +弄 +流 +留 +同 +没 +齿 +星 +聆 +轼 +湖 +什 +三 +建 +蛔 +儿 +椋 +汕 +震 +颧 +鲤 +跟 +力 +情 +璺 +铨 +陪 +务 +指 +族 +训 +滦 +鄣 +濮 +扒 +商 +箱 +十 +召 +慷 +辗 +所 +莞 +管 +护 +臭 +横 +硒 +嗓 +接 +侦 +六 +露 +党 +馋 +驾 +剖 +高 +侬 +妪 +幂 +猗 +绺 +骐 +央 +酐 +孝 +筝 +课 +徇 +缰 +门 +男 +西 +项 +句 +谙 +瞒 +秃 +篇 +教 +碲 +罚 +声 +呐 +景 +前 +富 +嘴 +鳌 +稀 +免 +朋 +啬 +睐 +去 +赈 +鱼 +住 +肩 +愕 +速 +旁 +波 +厅 +健 +茼 +厥 +鲟 +谅 +投 +攸 +炔 +数 +方 +击 +呋 +谈 +绩 +别 +愫 +僚 +躬 +鹧 +胪 +炳 +招 +喇 +膨 +泵 +蹦 +毛 +结 +5 +4 +谱 +识 +陕 +粽 +婚 +拟 +构 +且 +搜 +任 +潘 +比 +郢 +妨 +醪 +陀 +桔 +碘 +扎 +选 +哈 +骷 +楷 +亿 +明 +缆 +脯 +监 +睫 +逻 +婵 +共 +赴 +淝 +凡 +惦 +及 +达 +揖 +谩 +澹 +减 +焰 +蛹 +番 +祁 +柏 +员 +禄 +怡 +峤 +龙 +白 +叽 +生 +闯 +起 +细 +装 +谕 +竟 +聚 +钙 +上 +导 +渊 +按 +艾 +辘 +挡 +耒 +盹 +饪 +臀 +记 +邮 +蕙 +受 +各 +医 +搂 +普 +滇 +朗 +茸 +带 +翻 +酚 +( +光 +堤 +墟 +蔷 +万 +幻 +〓 +瑙 +辈 +昧 +盏 +亘 +蛀 +吉 +铰 +请 +子 +假 +闻 +税 +井 +诩 +哨 +嫂 +好 +面 +琐 +校 +馊 +鬣 +缂 +营 +访 +炖 +占 +农 +缀 +否 +经 +钚 +棵 +趟 +张 +亟 +吏 +茶 +谨 +捻 +论 +迸 +堂 +玉 +信 +吧 +瞠 +乡 +姬 +寺 +咬 +溏 +苄 +皿 +意 +赉 +宝 +尔 +钰 +艺 +特 +唳 +踉 +都 +荣 +倚 +登 +荐 +丧 +奇 +涵 +批 +炭 +近 +符 +傩 +感 +道 +着 +菊 +虹 +仲 +众 +懈 +濯 +颞 +眺 +南 +释 +北 +缝 +标 +既 +茗 +整 +撼 +迤 +贲 +挎 +耱 +拒 +某 +妍 +卫 +哇 +英 +矶 +藩 +治 +他 +元 +领 +膜 +遮 +穗 +蛾 +飞 +荒 +棺 +劫 +么 +市 +火 +温 +拈 +棚 +洼 +转 +果 +奕 +卸 +迪 +伸 +泳 +斗 +邡 +侄 +涨 +屯 +萋 +胭 +氡 +崮 +枞 +惧 +冒 +彩 +斜 +手 +豚 +随 +旭 +淑 +妞 +形 +菌 +吲 +沱 +争 +驯 +歹 +挟 +兆 +柱 +传 +至 +包 +内 +响 +临 +红 +功 +弩 +衡 +寂 +禁 +老 +棍 +耆 +渍 +织 +害 +氵 +渑 +布 +载 +靥 +嗬 +虽 +苹 +咨 +娄 +库 +雉 +榜 +帜 +嘲 +套 +瑚 +亲 +簸 +欧 +边 +6 +腿 +旮 +抛 +吹 +瞳 +得 +镓 +梗 +厨 +继 +漾 +愣 +憨 +士 +策 +窑 +抑 +躯 +襟 +脏 +参 +贸 +言 +干 +绸 +鳄 +穷 +藜 +音 +折 +详 +) +举 +悍 +甸 +癌 +黎 +谴 +死 +罩 +迁 +寒 +驷 +袖 +媒 +蒋 +掘 +模 +纠 +恣 +观 +祖 +蛆 +碍 +位 +稿 +主 +澧 +跌 +筏 +京 +锏 +帝 +贴 +证 +糠 +才 +黄 +鲸 +略 +炯 +饱 +四 +出 +园 +犀 +牧 +容 +汉 +杆 +浈 +汰 +瑷 +造 +虫 +瘩 +怪 +驴 +济 +应 +花 +沣 +谔 +夙 +旅 +价 +矿 +以 +考 +s +u +呦 +晒 +巡 +茅 +准 +肟 +瓴 +詹 +仟 +褂 +译 +桌 +混 +宁 +怦 +郑 +抿 +些 +余 +鄂 +饴 +攒 +珑 +群 +阖 +岔 +琨 +藓 +预 +环 +洮 +岌 +宀 +杲 +瀵 +最 +常 +囡 +周 +踊 +女 +鼓 +袭 +喉 +简 +范 +薯 +遐 +疏 +粱 +黜 +禧 +法 +箔 +斤 +遥 +汝 +奥 +直 +贞 +撑 +置 +绱 +集 +她 +馅 +逗 +钧 +橱 +魉 +[ +恙 +躁 +唤 +9 +旺 +膘 +待 +脾 +惫 +购 +吗 +依 +盲 +度 +瘿 +蠖 +俾 +之 +镗 +拇 +鲵 +厝 +簧 +续 +款 +展 +啃 +表 +剔 +品 +钻 +腭 +损 +清 +锶 +统 +涌 +寸 +滨 +贪 +链 +吠 +冈 +伎 +迥 +咏 +吁 +览 +防 +迅 +失 +汾 +阔 +逵 +绀 +蔑 +列 +川 +凭 +努 +熨 +揪 +利 +俱 +绉 +抢 +鸨 +我 +即 +责 +膦 +易 +毓 +鹊 +刹 +玷 +岿 +空 +嘞 +绊 +排 +术 +估 +锷 +违 +们 +苟 +铜 +播 +肘 +件 +烫 +审 +鲂 +广 +像 +铌 +惰 +铟 +巳 +胍 +鲍 +康 +憧 +色 +恢 +想 +拷 +尤 +疳 +知 +S +Y +F +D +A +峄 +裕 +帮 +握 +搔 +氐 +氘 +难 +墒 +沮 +雨 +叁 +缥 +悴 +藐 +湫 +娟 +苑 +稠 +颛 +簇 +后 +阕 +闭 +蕤 +缚 +怎 +佞 +码 +嘤 +蔡 +痊 +舱 +螯 +帕 +赫 +昵 +升 +烬 +岫 +、 +疵 +蜻 +髁 +蕨 +隶 +烛 +械 +丑 +盂 +梁 +强 +鲛 +由 +拘 +揉 +劭 +龟 +撤 +钩 +呕 +孛 +费 +妻 +漂 +求 +阑 +崖 +秤 +甘 +通 +深 +补 +赃 +坎 +床 +啪 +承 +吼 +量 +暇 +钼 +烨 +阂 +擎 +脱 +逮 +称 +P +神 +属 +矗 +华 +届 +狍 +葑 +汹 +育 +患 +窒 +蛰 +佼 +静 +槎 +运 +鳗 +庆 +逝 +曼 +疱 +克 +代 +官 +此 +麸 +耧 +蚌 +晟 +例 +础 +榛 +副 +测 +唰 +缢 +迹 +灬 +霁 +身 +岁 +赭 +扛 +又 +菡 +乜 +雾 +板 +读 +陷 +徉 +贯 +郁 +虑 +变 +钓 +菜 +圾 +现 +琢 +式 +乐 +维 +渔 +浜 +左 +吾 +脑 +钡 +警 +T +啵 +拴 +偌 +漱 +湿 +硕 +止 +骼 +魄 +积 +燥 +联 +踢 +玛 +则 +窿 +见 +振 +畿 +送 +班 +钽 +您 +赵 +刨 +印 +讨 +踝 +籍 +谡 +舌 +崧 +汽 +蔽 +沪 +酥 +绒 +怖 +财 +帖 +肱 +私 +莎 +勋 +羔 +霸 +励 +哼 +帐 +将 +帅 +渠 +纪 +婴 +娩 +岭 +厘 +滕 +吻 +伤 +坝 +冠 +戊 +隆 +瘁 +介 +涧 +物 +黍 +并 +姗 +奢 +蹑 +掣 +垸 +锴 +命 +箍 +捉 +病 +辖 +琰 +眭 +迩 +艘 +绌 +繁 +寅 +若 +毋 +思 +诉 +类 +诈 +燮 +轲 +酮 +狂 +重 +反 +职 +筱 +县 +委 +磕 +绣 +奖 +晋 +濉 +志 +徽 +肠 +呈 +獐 +坻 +口 +片 +碰 +几 +村 +柿 +劳 +料 +获 +亩 +惕 +晕 +厌 +号 +罢 +池 +正 +鏖 +煨 +家 +棕 +复 +尝 +懋 +蜥 +锅 +岛 +扰 +队 +坠 +瘾 +钬 +@ +卧 +疣 +镇 +譬 +冰 +彷 +频 +黯 +据 +垄 +采 +八 +缪 +瘫 +型 +熹 +砰 +楠 +襁 +箐 +但 +嘶 +绳 +啤 +拍 +盥 +穆 +傲 +洗 +盯 +塘 +怔 +筛 +丿 +台 +恒 +喂 +葛 +永 +¥ +烟 +酒 +桦 +书 +砂 +蚝 +缉 +态 +瀚 +袄 +圳 +轻 +蛛 +超 +榧 +遛 +姒 +奘 +铮 +右 +荽 +望 +偻 +卡 +丶 +氰 +附 +做 +革 +索 +戚 +坨 +桷 +唁 +垅 +榻 +岐 +偎 +坛 +莨 +山 +殊 +微 +骇 +陈 +爨 +推 +嗝 +驹 +澡 +藁 +呤 +卤 +嘻 +糅 +逛 +侵 +郓 +酌 +德 +摇 +※ +鬃 +被 +慨 +殡 +羸 +昌 +泡 +戛 +鞋 +河 +宪 +沿 +玲 +鲨 +翅 +哽 +源 +铅 +语 +照 +邯 +址 +荃 +佬 +顺 +鸳 +町 +霭 +睾 +瓢 +夸 +椁 +晓 +酿 +痈 +咔 +侏 +券 +噎 +湍 +签 +嚷 +离 +午 +尚 +社 +锤 +背 +孟 +使 +浪 +缦 +潍 +鞅 +军 +姹 +驶 +笑 +鳟 +鲁 +》 +孽 +钜 +绿 +洱 +礴 +焯 +椰 +颖 +囔 +乌 +孔 +巴 +互 +性 +椽 +哞 +聘 +昨 +早 +暮 +胶 +炀 +隧 +低 +彗 +昝 +铁 +呓 +氽 +藉 +喔 +癖 +瑗 +姨 +权 +胱 +韦 +堑 +蜜 +酋 +楝 +砝 +毁 +靓 +歙 +锲 +究 +屋 +喳 +骨 +辨 +碑 +武 +鸠 +宫 +辜 +烊 +适 +坡 +殃 +培 +佩 +供 +走 +蜈 +迟 +翼 +况 +姣 +凛 +浔 +吃 +飘 +债 +犟 +金 +促 +苛 +崇 +坂 +莳 +畔 +绂 +兵 +蠕 +斋 +根 +砍 +亢 +欢 +恬 +崔 +剁 +餐 +榫 +快 +扶 +‖ +濒 +缠 +鳜 +当 +彭 +驭 +浦 +篮 +昀 +锆 +秸 +钳 +弋 +娣 +瞑 +夷 +龛 +苫 +拱 +致 +% +嵊 +障 +隐 +弑 +初 +娓 +抉 +汩 +累 +蓖 +" +唬 +助 +苓 +昙 +押 +毙 +破 +城 +郧 +逢 +嚏 +獭 +瞻 +溱 +婿 +赊 +跨 +恼 +璧 +萃 +姻 +貉 +灵 +炉 +密 +氛 +陶 +砸 +谬 +衔 +点 +琛 +沛 +枳 +层 +岱 +诺 +脍 +榈 +埂 +征 +冷 +裁 +打 +蹴 +素 +瘘 +逞 +蛐 +聊 +激 +腱 +萘 +踵 +飒 +蓟 +吆 +取 +咙 +簋 +涓 +矩 +曝 +挺 +揣 +座 +你 +史 +舵 +焱 +尘 +苏 +笈 +脚 +溉 +榨 +诵 +樊 +邓 +焊 +义 +庶 +儋 +蟋 +蒲 +赦 +呷 +杞 +诠 +豪 +还 +试 +颓 +茉 +太 +除 +紫 +逃 +痴 +草 +充 +鳕 +珉 +祗 +墨 +渭 +烩 +蘸 +慕 +璇 +镶 +穴 +嵘 +恶 +骂 +险 +绋 +幕 +碉 +肺 +戳 +刘 +潞 +秣 +纾 +潜 +銮 +洛 +须 +罘 +销 +瘪 +汞 +兮 +屉 +r +林 +厕 +质 +探 +划 +狸 +殚 +善 +煊 +烹 +〒 +锈 +逯 +宸 +辍 +泱 +柚 +袍 +远 +蹋 +嶙 +绝 +峥 +娥 +缍 +雀 +徵 +认 +镱 +谷 += +贩 +勉 +撩 +鄯 +斐 +洋 +非 +祚 +泾 +诒 +饿 +撬 +威 +晷 +搭 +芍 +锥 +笺 +蓦 +候 +琊 +档 +礁 +沼 +卵 +荠 +忑 +朝 +凹 +瑞 +头 +仪 +弧 +孵 +畏 +铆 +突 +衲 +车 +浩 +气 +茂 +悖 +厢 +枕 +酝 +戴 +湾 +邹 +飚 +攘 +锂 +写 +宵 +翁 +岷 +无 +喜 +丈 +挑 +嗟 +绛 +殉 +议 +槽 +具 +醇 +淞 +笃 +郴 +阅 +饼 +底 +壕 +砚 +弈 +询 +缕 +庹 +翟 +零 +筷 +暨 +舟 +闺 +甯 +撞 +麂 +茌 +蔼 +很 +珲 +捕 +棠 +角 +阉 +媛 +娲 +诽 +剿 +尉 +爵 +睬 +韩 +诰 +匣 +危 +糍 +镯 +立 +浏 +阳 +少 +盆 +舔 +擘 +匪 +申 +尬 +铣 +旯 +抖 +赘 +瓯 +居 +ˇ +哮 +游 +锭 +茏 +歌 +坏 +甚 +秒 +舞 +沙 +仗 +劲 +潺 +阿 +燧 +郭 +嗖 +霏 +忠 +材 +奂 +耐 +跺 +砀 +输 +岖 +媳 +氟 +极 +摆 +灿 +今 +扔 +腻 +枝 +奎 +药 +熄 +吨 +话 +q +额 +慑 +嘌 +协 +喀 +壳 +埭 +视 +著 +於 +愧 +陲 +翌 +峁 +颅 +佛 +腹 +聋 +侯 +咎 +叟 +秀 +颇 +存 +较 +罪 +哄 +岗 +扫 +栏 +钾 +羌 +己 +璨 +枭 +霉 +煌 +涸 +衿 +键 +镝 +益 +岢 +奏 +连 +夯 +睿 +冥 +均 +糖 +狞 +蹊 +稻 +爸 +刿 +胥 +煜 +丽 +肿 +璃 +掸 +跚 +灾 +垂 +樾 +濑 +乎 +莲 +窄 +犹 +撮 +战 +馄 +软 +络 +显 +鸢 +胸 +宾 +妲 +恕 +埔 +蝌 +份 +遇 +巧 +瞟 +粒 +恰 +剥 +桡 +博 +讯 +凯 +堇 +阶 +滤 +卖 +斌 +骚 +彬 +兑 +磺 +樱 +舷 +两 +娱 +福 +仃 +差 +找 +桁 +÷ +净 +把 +阴 +污 +戬 +雷 +碓 +蕲 +楚 +罡 +焖 +抽 +妫 +咒 +仑 +闱 +尽 +邑 +菁 +爱 +贷 +沥 +鞑 +牡 +嗉 +崴 +骤 +塌 +嗦 +订 +拮 +滓 +捡 +锻 +次 +坪 +杩 +臃 +箬 +融 +珂 +鹗 +宗 +枚 +降 +鸬 +妯 +阄 +堰 +盐 +毅 +必 +杨 +崃 +俺 +甬 +状 +莘 +货 +耸 +菱 +腼 +铸 +唏 +痤 +孚 +澳 +懒 +溅 +翘 +疙 +杷 +淼 +缙 +骰 +喊 +悉 +砻 +坷 +艇 +赁 +界 +谤 +纣 +宴 +晃 +茹 +归 +饭 +梢 +铡 +街 +抄 +肼 +鬟 +苯 +颂 +撷 +戈 +炒 +咆 +茭 +瘙 +负 +仰 +客 +琉 +铢 +封 +卑 +珥 +椿 +镧 +窨 +鬲 +寿 +御 +袤 +铃 +萎 +砖 +餮 +脒 +裳 +肪 +孕 +嫣 +馗 +嵇 +恳 +氯 +江 +石 +褶 +冢 +祸 +阻 +狈 +羞 +银 +靳 +透 +咳 +叼 +敷 +芷 +啥 +它 +瓤 +兰 +痘 +懊 +逑 +肌 +往 +捺 +坊 +甩 +呻 +〃 +沦 +忘 +膻 +祟 +菅 +剧 +崆 +智 +坯 +臧 +霍 +墅 +攻 +眯 +倘 +拢 +骠 +铐 +庭 +岙 +瓠 +′ +缺 +泥 +迢 +捶 +? +? +郏 +喙 +掷 +沌 +纯 +秘 +种 +听 +绘 +固 +螨 +团 +香 +盗 +妒 +埚 +蓝 +拖 +旱 +荞 +铀 +血 +遏 +汲 +辰 +叩 +拽 +幅 +硬 +惶 +桀 +漠 +措 +泼 +唑 +齐 +肾 +念 +酱 +虚 +屁 +耶 +旗 +砦 +闵 +婉 +馆 +拭 +绅 +韧 +忏 +窝 +醋 +葺 +顾 +辞 +倜 +堆 +辋 +逆 +玟 +贱 +疾 +董 +惘 +倌 +锕 +淘 +嘀 +莽 +俭 +笏 +绑 +鲷 +杈 +择 +蟀 +粥 +嗯 +驰 +逾 +案 +谪 +褓 +胫 +哩 +昕 +颚 +鲢 +绠 +躺 +鹄 +崂 +儒 +俨 +丝 +尕 +泌 +啊 +萸 +彰 +幺 +吟 +骄 +苣 +弦 +脊 +瑰 +〈 +诛 +镁 +析 +闪 +剪 +侧 +哟 +框 +螃 +守 +嬗 +燕 +狭 +铈 +缮 +概 +迳 +痧 +鲲 +俯 +售 +笼 +痣 +扉 +挖 +满 +咋 +援 +邱 +扇 +歪 +便 +玑 +绦 +峡 +蛇 +叨 +〖 +泽 +胃 +斓 +喋 +怂 +坟 +猪 +该 +蚬 +炕 +弥 +赞 +棣 +晔 +娠 +挲 +狡 +创 +疖 +铕 +镭 +稷 +挫 +弭 +啾 +翔 +粉 +履 +苘 +哦 +楼 +秕 +铂 +土 +锣 +瘟 +挣 +栉 +习 +享 +桢 +袅 +磨 +桂 +谦 +延 +坚 +蔚 +噗 +署 +谟 +猬 +钎 +恐 +嬉 +雒 +倦 +衅 +亏 +璩 +睹 +刻 +殿 +王 +算 +雕 +麻 +丘 +柯 +骆 +丸 +塍 +谚 +添 +鲈 +垓 +桎 +蚯 +芥 +予 +飕 +镦 +谌 +窗 +醚 +菀 +亮 +搪 +莺 +蒿 +羁 +足 +J +真 +轶 +悬 +衷 +靛 +翊 +掩 +哒 +炅 +掐 +冼 +妮 +l +谐 +稚 +荆 +擒 +犯 +陵 +虏 +浓 +崽 +刍 +陌 +傻 +孜 +千 +靖 +演 +矜 +钕 +煽 +杰 +酗 +渗 +伞 +栋 +俗 +泫 +戍 +罕 +沾 +疽 +灏 +煦 +芬 +磴 +叱 +阱 +榉 +湃 +蜀 +叉 +醒 +彪 +租 +郡 +篷 +屎 +良 +垢 +隗 +弱 +陨 +峪 +砷 +掴 +颁 +胎 +雯 +绵 +贬 +沐 +撵 +隘 +篙 +暖 +曹 +陡 +栓 +填 +臼 +彦 +瓶 +琪 +潼 +哪 +鸡 +摩 +啦 +俟 +锋 +域 +耻 +蔫 +疯 +纹 +撇 +毒 +绶 +痛 +酯 +忍 +爪 +赳 +歆 +嘹 +辕 +烈 +册 +朴 +钱 +吮 +毯 +癜 +娃 +谀 +邵 +厮 +炽 +璞 +邃 +丐 +追 +词 +瓒 +忆 +轧 +芫 +谯 +喷 +弟 +半 +冕 +裙 +掖 +墉 +绮 +寝 +苔 +势 +顷 +褥 +切 +衮 +君 +佳 +嫒 +蚩 +霞 +佚 +洙 +逊 +镖 +暹 +唛 +& +殒 +顶 +碗 +獗 +轭 +铺 +蛊 +废 +恹 +汨 +崩 +珍 +那 +杵 +曲 +纺 +夏 +薰 +傀 +闳 +淬 +姘 +舀 +拧 +卷 +楂 +恍 +讪 +厩 +寮 +篪 +赓 +乘 +灭 +盅 +鞣 +沟 +慎 +挂 +饺 +鼾 +杳 +树 +缨 +丛 +絮 +娌 +臻 +嗳 +篡 +侩 +述 +衰 +矛 +圈 +蚜 +匕 +筹 +匿 +濞 +晨 +叶 +骋 +郝 +挚 +蚴 +滞 +增 +侍 +描 +瓣 +吖 +嫦 +蟒 +匾 +圣 +赌 +毡 +癞 +恺 +百 +曳 +需 +篓 +肮 +庖 +帏 +卿 +驿 +遗 +蹬 +鬓 +骡 +歉 +芎 +胳 +屐 +禽 +烦 +晌 +寄 +媾 +狄 +翡 +苒 +船 +廉 +终 +痞 +殇 +々 +畦 +饶 +改 +拆 +悻 +萄 +£ +瓿 +乃 +訾 +桅 +匮 +溧 +拥 +纱 +铍 +骗 +蕃 +龋 +缬 +父 +佐 +疚 +栎 +醍 +掳 +蓄 +x +惆 +颜 +鲆 +榆 +〔 +猎 +敌 +暴 +谥 +鲫 +贾 +罗 +玻 +缄 +扦 +芪 +癣 +落 +徒 +臾 +恿 +猩 +托 +邴 +肄 +牵 +春 +陛 +耀 +刊 +拓 +蓓 +邳 +堕 +寇 +枉 +淌 +啡 +湄 +兽 +酷 +萼 +碚 +濠 +萤 +夹 +旬 +戮 +梭 +琥 +椭 +昔 +勺 +蜊 +绐 +晚 +孺 +僵 +宣 +摄 +冽 +旨 +萌 +忙 +蚤 +眉 +噼 +蟑 +付 +契 +瓜 +悼 +颡 +壁 +曾 +窕 +颢 +澎 +仿 +俑 +浑 +嵌 +浣 +乍 +碌 +褪 +乱 +蔟 +隙 +玩 +剐 +葫 +箫 +纲 +围 +伐 +决 +伙 +漩 +瑟 +刑 +肓 +镳 +缓 +蹭 +氨 +皓 +典 +畲 +坍 +铑 +檐 +塑 +洞 +倬 +储 +胴 +淳 +戾 +吐 +灼 +惺 +妙 +毕 +珐 +缈 +虱 +盖 +羰 +鸿 +磅 +谓 +髅 +娴 +苴 +唷 +蚣 +霹 +抨 +贤 +唠 +犬 +誓 +逍 +庠 +逼 +麓 +籼 +釉 +呜 +碧 +秧 +氩 +摔 +霄 +穸 +纨 +辟 +妈 +映 +完 +牛 +缴 +嗷 +炊 +恩 +荔 +茆 +掉 +紊 +慌 +莓 +羟 +阙 +萁 +磐 +另 +蕹 +辱 +鳐 +湮 +吡 +吩 +唐 +睦 +垠 +舒 +圜 +冗 +瞿 +溺 +芾 +囱 +匠 +僳 +汐 +菩 +饬 +漓 +黑 +霰 +浸 +濡 +窥 +毂 +蒡 +兢 +驻 +鹉 +芮 +诙 +迫 +雳 +厂 +忐 +臆 +猴 +鸣 +蚪 +栈 +箕 +羡 +渐 +莆 +捍 +眈 +哓 +趴 +蹼 +埕 +嚣 +骛 +宏 +淄 +斑 +噜 +严 +瑛 +垃 +椎 +诱 +压 +庾 +绞 +焘 +廿 +抡 +迄 +棘 +夫 +纬 +锹 +眨 +瞌 +侠 +脐 +竞 +瀑 +孳 +骧 +遁 +姜 +颦 +荪 +滚 +萦 +伪 +逸 +粳 +爬 +锁 +矣 +役 +趣 +洒 +颔 +诏 +逐 +奸 +甭 +惠 +攀 +蹄 +泛 +尼 +拼 +阮 +鹰 +亚 +颈 +惑 +勒 +〉 +际 +肛 +爷 +刚 +钨 +丰 +养 +冶 +鲽 +辉 +蔻 +画 +覆 +皴 +妊 +麦 +返 +醉 +皂 +擀 +〗 +酶 +凑 +粹 +悟 +诀 +硖 +港 +卜 +z +杀 +涕 +± +舍 +铠 +抵 +弛 +段 +敝 +镐 +奠 +拂 +轴 +跛 +袱 +e +t +沉 +菇 +俎 +薪 +峦 +秭 +蟹 +历 +盟 +菠 +寡 +液 +肢 +喻 +染 +裱 +悱 +抱 +氙 +赤 +捅 +猛 +跑 +氮 +谣 +仁 +尺 +辊 +窍 +烙 +衍 +架 +擦 +倏 +璐 +瑁 +币 +楞 +胖 +夔 +趸 +邛 +惴 +饕 +虔 +蝎 +§ +哉 +贝 +宽 +辫 +炮 +扩 +饲 +籽 +魏 +菟 +锰 +伍 +猝 +末 +琳 +哚 +蛎 +邂 +呀 +姿 +鄞 +却 +歧 +仙 +恸 +椐 +森 +牒 +寤 +袒 +婆 +虢 +雅 +钉 +朵 +贼 +欲 +苞 +寰 +故 +龚 +坭 +嘘 +咫 +礼 +硷 +兀 +睢 +汶 +’ +铲 +烧 +绕 +诃 +浃 +钿 +哺 +柜 +讼 +颊 +璁 +腔 +洽 +咐 +脲 +簌 +筠 +镣 +玮 +鞠 +谁 +兼 +姆 +挥 +梯 +蝴 +谘 +漕 +刷 +躏 +宦 +弼 +b +垌 +劈 +麟 +莉 +揭 +笙 +渎 +仕 +嗤 +仓 +配 +怏 +抬 +错 +泯 +镊 +孰 +猿 +邪 +仍 +秋 +鼬 +壹 +歇 +吵 +炼 +< +尧 +射 +柬 +廷 +胧 +霾 +凳 +隋 +肚 +浮 +梦 +祥 +株 +堵 +退 +L +鹫 +跎 +凶 +毽 +荟 +炫 +栩 +玳 +甜 +沂 +鹿 +顽 +伯 +爹 +赔 +蛴 +徐 +匡 +欣 +狰 +缸 +雹 +蟆 +疤 +默 +沤 +啜 +痂 +衣 +禅 +w +i +h +辽 +葳 +黝 +钗 +停 +沽 +棒 +馨 +颌 +肉 +吴 +硫 +悯 +劾 +娈 +马 +啧 +吊 +悌 +镑 +峭 +帆 +瀣 +涉 +咸 +疸 +滋 +泣 +翦 +拙 +癸 +钥 +蜒 ++ +尾 +庄 +凝 +泉 +婢 +渴 +谊 +乞 +陆 +锉 +糊 +鸦 +淮 +I +B +N +晦 +弗 +乔 +庥 +葡 +尻 +席 +橡 +傣 +渣 +拿 +惩 +麋 +斛 +缃 +矮 +蛏 +岘 +鸽 +姐 +膏 +催 +奔 +镒 +喱 +蠡 +摧 +钯 +胤 +柠 +拐 +璋 +鸥 +卢 +荡 +倾 +^ +_ +珀 +逄 +萧 +塾 +掇 +贮 +笆 +聂 +圃 +冲 +嵬 +M +滔 +笕 +值 +炙 +偶 +蜱 +搐 +梆 +汪 +蔬 +腑 +鸯 +蹇 +敞 +绯 +仨 +祯 +谆 +梧 +糗 +鑫 +啸 +豺 +囹 +猾 +巢 +柄 +瀛 +筑 +踌 +沭 +暗 +苁 +鱿 +蹉 +脂 +蘖 +牢 +热 +木 +吸 +溃 +宠 +序 +泞 +偿 +拜 +檩 +厚 +朐 +毗 +螳 +吞 +媚 +朽 +担 +蝗 +橘 +畴 +祈 +糟 +盱 +隼 +郜 +惜 +珠 +裨 +铵 +焙 +琚 +唯 +咚 +噪 +骊 +丫 +滢 +勤 +棉 +呸 +咣 +淀 +隔 +蕾 +窈 +饨 +挨 +煅 +短 +匙 +粕 +镜 +赣 +撕 +墩 +酬 +馁 +豌 +颐 +抗 +酣 +氓 +佑 +搁 +哭 +递 +耷 +涡 +桃 +贻 +碣 +截 +瘦 +昭 +镌 +蔓 +氚 +甲 +猕 +蕴 +蓬 +散 +拾 +纛 +狼 +猷 +铎 +埋 +旖 +矾 +讳 +囊 +糜 +迈 +粟 +蚂 +紧 +鲳 +瘢 +栽 +稼 +羊 +锄 +斟 +睁 +桥 +瓮 +蹙 +祉 +醺 +鼻 +昱 +剃 +跳 +篱 +跷 +蒜 +翎 +宅 +晖 +嗑 +壑 +峻 +癫 +屏 +狠 +陋 +袜 +途 +憎 +祀 +莹 +滟 +佶 +溥 +臣 +约 +盛 +峰 +磁 +慵 +婪 +拦 +莅 +朕 +鹦 +粲 +裤 +哎 +疡 +嫖 +琵 +窟 +堪 +谛 +嘉 +儡 +鳝 +斩 +郾 +驸 +酊 +妄 +胜 +贺 +徙 +傅 +噌 +钢 +栅 +庇 +恋 +匝 +巯 +邈 +尸 +锚 +粗 +佟 +蛟 +薹 +纵 +蚊 +郅 +绢 +锐 +苗 +俞 +篆 +淆 +膀 +鲜 +煎 +诶 +秽 +寻 +涮 +刺 +怀 +噶 +巨 +褰 +魅 +灶 +灌 +桉 +藕 +谜 +舸 +薄 +搀 +恽 +借 +牯 +痉 +渥 +愿 +亓 +耘 +杠 +柩 +锔 +蚶 +钣 +珈 +喘 +蹒 +幽 +赐 +稗 +晤 +莱 +泔 +扯 +肯 +菪 +裆 +腩 +豉 +疆 +骜 +腐 +倭 +珏 +唔 +粮 +亡 +润 +慰 +伽 +橄 +玄 +誉 +醐 +胆 +龊 +粼 +塬 +陇 +彼 +削 +嗣 +绾 +芽 +妗 +垭 +瘴 +爽 +薏 +寨 +龈 +泠 +弹 +赢 +漪 +猫 +嘧 +涂 +恤 +圭 +茧 +烽 +屑 +痕 +巾 +赖 +荸 +凰 +腮 +畈 +亵 +蹲 +偃 +苇 +澜 +艮 +换 +骺 +烘 +苕 +梓 +颉 +肇 +哗 +悄 +氤 +涠 +葬 +屠 +鹭 +植 +竺 +佯 +诣 +鲇 +瘀 +鲅 +邦 +移 +滁 +冯 +耕 +癔 +戌 +茬 +沁 +巩 +悠 +湘 +洪 +痹 +锟 +循 +谋 +腕 +鳃 +钠 +捞 +焉 +迎 +碱 +伫 +急 +榷 +奈 +邝 +卯 +辄 +皲 +卟 +醛 +畹 +忧 +稳 +雄 +昼 +缩 +阈 +睑 +扌 +耗 +曦 +涅 +捏 +瞧 +邕 +淖 +漉 +铝 +耦 +禹 +湛 +喽 +莼 +琅 +诸 +苎 +纂 +硅 +始 +嗨 +傥 +燃 +臂 +赅 +嘈 +呆 +贵 +屹 +壮 +肋 +亍 +蚀 +卅 +豹 +腆 +邬 +迭 +浊 +} +童 +螂 +捐 +圩 +勐 +触 +寞 +汊 +壤 +荫 +膺 +渌 +芳 +懿 +遴 +螈 +泰 +蓼 +蛤 +茜 +舅 +枫 +朔 +膝 +眙 +避 +梅 +判 +鹜 +璜 +牍 +缅 +垫 +藻 +黔 +侥 +惚 +懂 +踩 +腰 +腈 +札 +丞 +唾 +慈 +顿 +摹 +荻 +琬 +~ +斧 +沈 +滂 +胁 +胀 +幄 +莜 +Z +匀 +鄄 +掌 +绰 +茎 +焚 +赋 +萱 +谑 +汁 +铒 +瞎 +夺 +蜗 +野 +娆 +冀 +弯 +篁 +懵 +灞 +隽 +芡 +脘 +俐 +辩 +芯 +掺 +喏 +膈 +蝈 +觐 +悚 +踹 +蔗 +熠 +鼠 +呵 +抓 +橼 +峨 +畜 +缔 +禾 +崭 +弃 +熊 +摒 +凸 +拗 +穹 +蒙 +抒 +祛 +劝 +闫 +扳 +阵 +醌 +踪 +喵 +侣 +搬 +仅 +荧 +赎 +蝾 +琦 +买 +婧 +瞄 +寓 +皎 +冻 +赝 +箩 +莫 +瞰 +郊 +笫 +姝 +筒 +枪 +遣 +煸 +袋 +舆 +痱 +涛 +母 +〇 +启 +践 +耙 +绲 +盘 +遂 +昊 +搞 +槿 +诬 +纰 +泓 +惨 +檬 +亻 +越 +C +o +憩 +熵 +祷 +钒 +暧 +塔 +阗 +胰 +咄 +娶 +魔 +琶 +钞 +邻 +扬 +杉 +殴 +咽 +弓 +〆 +髻 +】 +吭 +揽 +霆 +拄 +殖 +脆 +彻 +岩 +芝 +勃 +辣 +剌 +钝 +嘎 +甄 +佘 +皖 +伦 +授 +徕 +憔 +挪 +皇 +庞 +稔 +芜 +踏 +溴 +兖 +卒 +擢 +饥 +鳞 +煲 +‰ +账 +颗 +叻 +斯 +捧 +鳍 +琮 +讹 +蛙 +纽 +谭 +酸 +兔 +莒 +睇 +伟 +觑 +羲 +嗜 +宜 +褐 +旎 +辛 +卦 +诘 +筋 +鎏 +溪 +挛 +熔 +阜 +晰 +鳅 +丢 +奚 +灸 +呱 +献 +陉 +黛 +鸪 +甾 +萨 +疮 +拯 +洲 +疹 +辑 +叙 +恻 +谒 +允 +柔 +烂 +氏 +逅 +漆 +拎 +惋 +扈 +湟 +纭 +啕 +掬 +擞 +哥 +忽 +涤 +鸵 +靡 +郗 +瓷 +扁 +廊 +怨 +雏 +钮 +敦 +E +懦 +憋 +汀 +拚 +啉 +腌 +岸 +f +痼 +瞅 +尊 +咀 +眩 +飙 +忌 +仝 +迦 +熬 +毫 +胯 +篑 +茄 +腺 +凄 +舛 +碴 +锵 +诧 +羯 +後 +漏 +汤 +宓 +仞 +蚁 +壶 +谰 +皑 +铄 +棰 +罔 +辅 +晶 +苦 +牟 +闽 +\ +烃 +饮 +聿 +丙 +蛳 +朱 +煤 +涔 +鳖 +犁 +罐 +荼 +砒 +淦 +妤 +黏 +戎 +孑 +婕 +瑾 +戢 +钵 +枣 +捋 +砥 +衩 +狙 +桠 +稣 +阎 +肃 +梏 +诫 +孪 +昶 +婊 +衫 +嗔 +侃 +塞 +蜃 +樵 +峒 +貌 +屿 +欺 +缫 +阐 +栖 +诟 +珞 +荭 +吝 +萍 +嗽 +恂 +啻 +蜴 +磬 +峋 +俸 +豫 +谎 +徊 +镍 +韬 +魇 +晴 +U +囟 +猜 +蛮 +坐 +囿 +伴 +亭 +肝 +佗 +蝠 +妃 +胞 +滩 +榴 +氖 +垩 +苋 +砣 +扪 +馏 +姓 +轩 +厉 +夥 +侈 +禀 +垒 +岑 +赏 +钛 +辐 +痔 +披 +纸 +碳 +“ +坞 +蠓 +挤 +荥 +沅 +悔 +铧 +帼 +蒌 +蝇 +a +p +y +n +g +哀 +浆 +瑶 +凿 +桶 +馈 +皮 +奴 +苜 +佤 +伶 +晗 +铱 +炬 +优 +弊 +氢 +恃 +甫 +攥 +端 +锌 +灰 +稹 +炝 +曙 +邋 +亥 +眶 +碾 +拉 +萝 +绔 +捷 +浍 +腋 +姑 +菖 +凌 +涞 +麽 +锢 +桨 +潢 +绎 +镰 +殆 +锑 +渝 +铬 +困 +绽 +觎 +匈 +糙 +暑 +裹 +鸟 +盔 +肽 +迷 +綦 +『 +亳 +佝 +俘 +钴 +觇 +骥 +仆 +疝 +跪 +婶 +郯 +瀹 +唉 +脖 +踞 +针 +晾 +忒 +扼 +瞩 +叛 +椒 +疟 +嗡 +邗 +肆 +跆 +玫 +忡 +捣 +咧 +唆 +艄 +蘑 +潦 +笛 +阚 +沸 +泻 +掊 +菽 +贫 +斥 +髂 +孢 +镂 +赂 +麝 +鸾 +屡 +衬 +苷 +恪 +叠 +希 +粤 +爻 +喝 +茫 +惬 +郸 +绻 +庸 +撅 +碟 +宄 +妹 +膛 +叮 +饵 +崛 +嗲 +椅 +冤 +搅 +咕 +敛 +尹 +垦 +闷 +蝉 +霎 +勰 +败 +蓑 +泸 +肤 +鹌 +幌 +焦 +浠 +鞍 +刁 +舰 +乙 +竿 +裔 +。 +茵 +函 +伊 +兄 +丨 +娜 +匍 +謇 +莪 +宥 +似 +蝽 +翳 +酪 +翠 +粑 +薇 +祢 +骏 +赠 +叫 +Q +噤 +噻 +竖 +芗 +莠 +潭 +俊 +羿 +耜 +O +郫 +趁 +嗪 +囚 +蹶 +芒 +洁 +笋 +鹑 +敲 +硝 +啶 +堡 +渲 +揩 +』 +携 +宿 +遒 +颍 +扭 +棱 +割 +萜 +蔸 +葵 +琴 +捂 +饰 +衙 +耿 +掠 +募 +岂 +窖 +涟 +蔺 +瘤 +柞 +瞪 +怜 +匹 +距 +楔 +炜 +哆 +秦 +缎 +幼 +茁 +绪 +痨 +恨 +楸 +娅 +瓦 +桩 +雪 +嬴 +伏 +榔 +妥 +铿 +拌 +眠 +雍 +缇 +‘ +卓 +搓 +哌 +觞 +噩 +屈 +哧 +髓 +咦 +巅 +娑 +侑 +淫 +膳 +祝 +勾 +姊 +莴 +胄 +疃 +薛 +蜷 +胛 +巷 +芙 +芋 +熙 +闰 +勿 +窃 +狱 +剩 +钏 +幢 +陟 +铛 +慧 +靴 +耍 +k +浙 +浇 +飨 +惟 +绗 +祜 +澈 +啼 +咪 +磷 +摞 +诅 +郦 +抹 +跃 +壬 +吕 +肖 +琏 +颤 +尴 +剡 +抠 +凋 +赚 +泊 +津 +宕 +殷 +倔 +氲 +漫 +邺 +涎 +怠 +$ +垮 +荬 +遵 +俏 +叹 +噢 +饽 +蜘 +孙 +筵 +疼 +鞭 +羧 +牦 +箭 +潴 +c +眸 +祭 +髯 +啖 +坳 +愁 +芩 +驮 +倡 +巽 +穰 +沃 +胚 +怒 +凤 +槛 +剂 +趵 +嫁 +v +邢 +灯 +鄢 +桐 +睽 +檗 +锯 +槟 +婷 +嵋 +圻 +诗 +蕈 +颠 +遭 +痢 +芸 +怯 +馥 +竭 +锗 +徜 +恭 +遍 +籁 +剑 +嘱 +苡 +龄 +僧 +桑 +潸 +弘 +澶 +楹 +悲 +讫 +愤 +腥 +悸 +谍 +椹 +呢 +桓 +葭 +攫 +阀 +翰 +躲 +敖 +柑 +郎 +笨 +橇 +呃 +魁 +燎 +脓 +葩 +磋 +垛 +玺 +狮 +沓 +砜 +蕊 +锺 +罹 +蕉 +翱 +虐 +闾 +巫 +旦 +茱 +嬷 +枯 +鹏 +贡 +芹 +汛 +矫 +绁 +拣 +禺 +佃 +讣 +舫 +惯 +乳 +趋 +疲 +挽 +岚 +虾 +衾 +蠹 +蹂 +飓 +氦 +铖 +孩 +稞 +瑜 +壅 +掀 +勘 +妓 +畅 +髋 +W +庐 +牲 +蓿 +榕 +练 +垣 +唱 +邸 +菲 +昆 +婺 +穿 +绡 +麒 +蚱 +掂 +愚 +泷 +涪 +漳 +妩 +娉 +榄 +讷 +觅 +旧 +藤 +煮 +呛 +柳 +腓 +叭 +庵 +烷 +阡 +罂 +蜕 +擂 +猖 +咿 +媲 +脉 +【 +沏 +貅 +黠 +熏 +哲 +烁 +坦 +酵 +兜 +× +潇 +撒 +剽 +珩 +圹 +乾 +摸 +樟 +帽 +嗒 +襄 +魂 +轿 +憬 +锡 +〕 +喃 +皆 +咖 +隅 +脸 +残 +泮 +袂 +鹂 +珊 +囤 +捆 +咤 +误 +徨 +闹 +淙 +芊 +淋 +怆 +囗 +拨 +梳 +渤 +R +G +绨 +蚓 +婀 +幡 +狩 +麾 +谢 +唢 +裸 +旌 +伉 +纶 +裂 +驳 +砼 +咛 +澄 +樨 +蹈 +宙 +澍 +倍 +貔 +操 +勇 +蟠 +摈 +砧 +虬 +够 +缁 +悦 +藿 +撸 +艹 +摁 +淹 +豇 +虎 +榭 +ˉ +吱 +d +° +喧 +荀 +踱 +侮 +奋 +偕 +饷 +犍 +惮 +坑 +璎 +徘 +宛 +妆 +袈 +倩 +窦 +昂 +荏 +乖 +K +怅 +撰 +鳙 +牙 +袁 +酞 +X +痿 +琼 +闸 +雁 +趾 +荚 +虻 +涝 +《 +杏 +韭 +偈 +烤 +绫 +鞘 +卉 +症 +遢 +蓥 +诋 +杭 +荨 +匆 +竣 +簪 +辙 +敕 +虞 +丹 +缭 +咩 +黟 +m +淤 +瑕 +咂 +铉 +硼 +茨 +嶂 +痒 +畸 +敬 +涿 +粪 +窘 +熟 +叔 +嫔 +盾 +忱 +裘 +憾 +梵 +赡 +珙 +咯 +娘 +庙 +溯 +胺 +葱 +痪 +摊 +荷 +卞 +乒 +髦 +寐 +铭 +坩 +胗 +枷 +爆 +溟 +嚼 +羚 +砬 +轨 +惊 +挠 +罄 +竽 +菏 +氧 +浅 +楣 +盼 +枢 +炸 +阆 +杯 +谏 +噬 +淇 +渺 +俪 +秆 +墓 +泪 +跻 +砌 +痰 +垡 +渡 +耽 +釜 +讶 +鳎 +煞 +呗 +韶 +舶 +绷 +鹳 +缜 +旷 +铊 +皱 +龌 +檀 +霖 +奄 +槐 +艳 +蝶 +旋 +哝 +赶 +骞 +蚧 +腊 +盈 +丁 +` +蜚 +矸 +蝙 +睨 +嚓 +僻 +鬼 +醴 +夜 +彝 +磊 +笔 +拔 +栀 +糕 +厦 +邰 +纫 +逭 +纤 +眦 +膊 +馍 +躇 +烯 +蘼 +冬 +诤 +暄 +骶 +哑 +瘠 +」 +臊 +丕 +愈 +咱 +螺 +擅 +跋 +搏 +硪 +谄 +笠 +淡 +嘿 +骅 +谧 +鼎 +皋 +姚 +歼 +蠢 +驼 +耳 +胬 +挝 +涯 +狗 +蒽 +孓 +犷 +凉 +芦 +箴 +铤 +孤 +嘛 +坤 +V +茴 +朦 +挞 +尖 +橙 +诞 +搴 +碇 +洵 +浚 +帚 +蜍 +漯 +柘 +嚎 +讽 +芭 +荤 +咻 +祠 +秉 +跖 +埃 +吓 +糯 +眷 +馒 +惹 +娼 +鲑 +嫩 +讴 +轮 +瞥 +靶 +褚 +乏 +缤 +宋 +帧 +删 +驱 +碎 +扑 +俩 +俄 +偏 +涣 +竹 +噱 +皙 +佰 +渚 +唧 +斡 +# +镉 +刀 +崎 +筐 +佣 +夭 +贰 +肴 +峙 +哔 +艿 +匐 +牺 +镛 +缘 +仡 +嫡 +劣 +枸 +堀 +梨 +簿 +鸭 +蒸 +亦 +稽 +浴 +{ +衢 +束 +槲 +j +阁 +揍 +疥 +棋 +潋 +聪 +窜 +乓 +睛 +插 +冉 +阪 +苍 +搽 +「 +蟾 +螟 +幸 +仇 +樽 +撂 +慢 +跤 +幔 +俚 +淅 +覃 +觊 +溶 +妖 +帛 +侨 +曰 +妾 +泗 +· +: +瀘 +風 +Ë +( +) +∶ +紅 +紗 +瑭 +雲 +頭 +鶏 +財 +許 +• +¥ +樂 +焗 +麗 +— +; +滙 +東 +榮 +繪 +興 +… +門 +業 +π +楊 +國 +顧 +é +盤 +寳 +Λ +龍 +鳳 +島 +誌 +緣 +結 +銭 +萬 +勝 +祎 +璟 +優 +歡 +臨 +時 +購 += +★ +藍 +昇 +鐵 +觀 +勅 +農 +聲 +畫 +兿 +術 +發 +劉 +記 +專 +耑 +園 +書 +壴 +種 +Ο +● +褀 +號 +銀 +匯 +敟 +锘 +葉 +橪 +廣 +進 +蒄 +鑽 +阝 +祙 +貢 +鍋 +豊 +夬 +喆 +團 +閣 +開 +燁 +賓 +館 +酡 +沔 +順 ++ +硚 +劵 +饸 +陽 +車 +湓 +復 +萊 +氣 +軒 +華 +堃 +迮 +纟 +戶 +馬 +學 +裡 +電 +嶽 +獨 +マ +シ +サ +ジ +燘 +袪 +環 +❤ +臺 +灣 +専 +賣 +孖 +聖 +攝 +線 +▪ +α +傢 +俬 +夢 +達 +莊 +喬 +貝 +薩 +劍 +羅 +壓 +棛 +饦 +尃 +璈 +囍 +醫 +G +I +A +# +N +鷄 +髙 +嬰 +啓 +約 +隹 +潔 +賴 +藝 +~ +寶 +籣 +麺 +  +嶺 +√ +義 +網 +峩 +長 +∧ +魚 +機 +構 +② +鳯 +偉 +L +B +㙟 +畵 +鴿 +' +詩 +溝 +嚞 +屌 +藔 +佧 +玥 +蘭 +織 +1 +3 +9 +0 +7 +點 +砭 +鴨 +鋪 +銘 +廳 +弍 +‧ +創 +湯 +坶 +℃ +卩 +骝 +& +烜 +荘 +當 +潤 +扞 +係 +懷 +碶 +钅 +蚨 +讠 +☆ +叢 +爲 +埗 +涫 +塗 +→ +楽 +現 +鯨 +愛 +瑪 +鈺 +忄 +悶 +藥 +飾 +樓 +視 +孬 +ㆍ +燚 +苪 +師 +① +丼 +锽 +│ +韓 +標 +è +兒 +閏 +匋 +張 +漢 +Ü +髪 +會 +閑 +檔 +習 +裝 +の +峯 +菘 +輝 +И +雞 +釣 +億 +浐 +K +O +R +8 +H +E +P +T +W +D +S +C +M +F +姌 +饹 +» +晞 +廰 +ä +嵯 +鷹 +負 +飲 +絲 +冚 +楗 +澤 +綫 +區 +❋ +← +質 +靑 +揚 +③ +滬 +統 +産 +協 +﹑ +乸 +畐 +經 +運 +際 +洺 +岽 +為 +粵 +諾 +崋 +豐 +碁 +ɔ +V +2 +6 +齋 +誠 +訂 +´ +勑 +雙 +陳 +無 +í +泩 +媄 +夌 +刂 +i +c +t +o +r +a +嘢 +耄 +燴 +暃 +壽 +媽 +靈 +抻 +體 +唻 +É +冮 +甹 +鎮 +錦 +ʌ +蜛 +蠄 +尓 +駕 +戀 +飬 +逹 +倫 +貴 +極 +Я +Й +寬 +磚 +嶪 +郎 +職 +| +間 +n +d +剎 +伈 +課 +飛 +橋 +瘊 +№ +譜 +骓 +圗 +滘 +縣 +粿 +咅 +養 +濤 +彳 +® +% +Ⅱ +啰 +㴪 +見 +矞 +薬 +糁 +邨 +鲮 +顔 +罱 +З +選 +話 +贏 +氪 +俵 +競 +瑩 +繡 +枱 +β +綉 +á +獅 +爾 +™ +麵 +戋 +淩 +徳 +個 +劇 +場 +務 +簡 +寵 +h +實 +膠 +轱 +圖 +築 +嘣 +樹 +㸃 +營 +耵 +孫 +饃 +鄺 +飯 +麯 +遠 +輸 +坫 +孃 +乚 +閃 +鏢 +㎡ +題 +廠 +關 +↑ +爺 +將 +軍 +連 +篦 +覌 +參 +箸 +- +窠 +棽 +寕 +夀 +爰 +歐 +呙 +閥 +頡 +熱 +雎 +垟 +裟 +凬 +勁 +帑 +馕 +夆 +疌 +枼 +馮 +貨 +蒤 +樸 +彧 +旸 +靜 +龢 +暢 +㐱 +鳥 +珺 +鏡 +灡 +爭 +堷 +廚 +Ó +騰 +診 +┅ +蘇 +褔 +凱 +頂 +豕 +亞 +帥 +嘬 +⊥ +仺 +桖 +複 +饣 +絡 +穂 +顏 +棟 +納 +▏ +濟 +親 +設 +計 +攵 +埌 +烺 +ò +頤 +燦 +蓮 +撻 +節 +講 +濱 +濃 +娽 +洳 +朿 +燈 +鈴 +護 +膚 +铔 +過 +補 +Z +U +5 +4 +坋 +闿 +䖝 +餘 +缐 +铞 +貿 +铪 +桼 +趙 +鍊 +[ +㐂 +垚 +菓 +揸 +捲 +鐘 +滏 +𣇉 +爍 +輪 +燜 +鴻 +鮮 +動 +鹞 +鷗 +丄 +慶 +鉌 +翥 +飮 +腸 +⇋ +漁 +覺 +來 +熘 +昴 +翏 +鲱 +圧 +鄉 +萭 +頔 +爐 +嫚 +г +貭 +類 +聯 +幛 +輕 +訓 +鑒 +夋 +锨 +芃 +珣 +䝉 +扙 +嵐 +銷 +處 +ㄱ +語 +誘 +苝 +歸 +儀 +燒 +楿 +內 +粢 +葒 +奧 +麥 +礻 +滿 +蠔 +穵 +瞭 +態 +鱬 +榞 +硂 +鄭 +黃 +煙 +祐 +奓 +逺 +* +瑄 +獲 +聞 +薦 +讀 +這 +樣 +決 +問 +啟 +們 +執 +説 +轉 +單 +隨 +唘 +帶 +倉 +庫 +還 +贈 +尙 +皺 +■ +餅 +產 +○ +∈ +報 +狀 +楓 +賠 +琯 +嗮 +禮 +` +傳 +> +≤ +嗞 +Φ +≥ +換 +咭 +∣ +↓ +曬 +ε +応 +寫 +″ +終 +様 +純 +費 +療 +聨 +凍 +壐 +郵 +ü +黒 +∫ +製 +塊 +調 +軽 +確 +撃 +級 +馴 +Ⅲ +涇 +繹 +數 +碼 +證 +狒 +処 +劑 +< +晧 +賀 +衆 +] +櫥 +兩 +陰 +絶 +對 +鯉 +憶 +◎ +p +e +Y +蕒 +煖 +頓 +測 +試 +鼽 +僑 +碩 +妝 +帯 +≈ +鐡 +舖 +權 +喫 +倆 +ˋ +該 +悅 +ā +俫 +. +f +s +b +m +k +g +u +j +貼 +淨 +濕 +針 +適 +備 +l +/ +給 +謢 +強 +觸 +衛 +與 +⊙ +$ +緯 +變 +⑴ +⑵ +⑶ +㎏ +殺 +∩ +幚 +─ +價 +▲ +離 +ú +ó +飄 +烏 +関 +閟 +﹝ +﹞ +邏 +輯 +鍵 +驗 +訣 +導 +歷 +屆 +層 +▼ +儱 +錄 +熳 +ē +艦 +吋 +錶 +辧 +飼 +顯 +④ +禦 +販 +気 +対 +枰 +閩 +紀 +幹 +瞓 +貊 +淚 +△ +眞 +墊 +Ω +獻 +褲 +縫 +緑 +亜 +鉅 +餠 +{ +} +◆ +蘆 +薈 +█ +◇ +溫 +彈 +晳 +粧 +犸 +穩 +訊 +崬 +凖 +熥 +П +舊 +條 +紋 +圍 +Ⅳ +筆 +尷 +難 +雜 +錯 +綁 +識 +頰 +鎖 +艶 +□ +殁 +殼 +⑧ +├ +▕ +鵬 +ǐ +ō +ǒ +糝 +綱 +▎ +μ +盜 +饅 +醬 +籤 +蓋 +釀 +鹽 +據 +à +ɡ +辦 +◥ +彐 +┌ +婦 +獸 +鲩 +伱 +ī +蒟 +蒻 +齊 +袆 +腦 +寧 +凈 +妳 +煥 +詢 +偽 +謹 +啫 +鯽 +騷 +鱸 +損 +傷 +鎻 +髮 +買 +冏 +儥 +両 +﹢ +∞ +載 +喰 +z +羙 +悵 +燙 +曉 +員 +組 +徹 +艷 +痠 +鋼 +鼙 +縮 +細 +嚒 +爯 +≠ +維 +" +鱻 +壇 +厍 +帰 +浥 +犇 +薡 +軎 +² +應 +醜 +刪 +緻 +鶴 +賜 +噁 +軌 +尨 +镔 +鷺 +槗 +彌 +葚 +濛 +請 +溇 +緹 +賢 +訪 +獴 +瑅 +資 +縤 +陣 +蕟 +栢 +韻 +祼 +恁 +伢 +謝 +劃 +涑 +總 +衖 +踺 +砋 +凉 +籃 +駿 +苼 +瘋 +昽 +紡 +驊 +腎 +﹗ +響 +杋 +剛 +嚴 +禪 +歓 +槍 +傘 +檸 +檫 +炣 +勢 +鏜 +鎢 +銑 +尐 +減 +奪 +惡 +θ +僮 +婭 +臘 +ū +ì +殻 +鉄 +∑ +蛲 +焼 +緖 +續 +紹 +懮 \ No newline at end of file diff --git a/ocr/recognizer.py b/ocr/recognizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c6c7d7a7367a6e8f28420ad6dee4c1fb0f8427 --- /dev/null +++ b/ocr/recognizer.py @@ -0,0 +1,434 @@ +import os +import sys + +from PIL import Image + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) + +os.environ["FLAGS_allocator_strategy"] = "auto_growth" + +import math +import time + +import cv2 +import numpy as np + +import utility +from postprocess import build_post_process + + +def _check_image_file(path): + img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif"} + return any([path.lower().endswith(e) for e in img_end]) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif"} + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +def check_and_read_gif(img_path): + if os.path.basename(img_path)[-3:] in ["gif", "GIF"]: + gif = cv2.VideoCapture(img_path) + ret, frame = gif.read() + if not ret: + return None, False + if len(frame.shape) == 2 or frame.shape[-1] == 1: + frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) + imgvalue = frame[:, :, ::-1] + return imgvalue, True + return None, False + + +class TextRecognizer(object): + def __init__(self, args): + self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm + postprocess_params = { + "name": "CTCLabelDecode", + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + } + if self.rec_algorithm == "SRN": + postprocess_params = { + "name": "SRNLabelDecode", + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + } + elif self.rec_algorithm == "RARE": + postprocess_params = { + "name": "AttnLabelDecode", + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + } + elif self.rec_algorithm == "NRTR": + postprocess_params = { + "name": "NRTRLabelDecode", + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + } + elif self.rec_algorithm == "SAR": + postprocess_params = { + "name": "SARLabelDecode", + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char, + } + self.postprocess_op = build_post_process(postprocess_params) + ( + self.predictor, + self.input_tensor, + self.output_tensors, + self.config, + ) = utility.create_predictor(args, "rec") + self.use_onnx = args.use_onnx + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == "NRTR": + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # return padding_im + image_pil = Image.fromarray(np.uint8(img)) + img = image_pil.resize([100, 32], Image.ANTIALIAS) + img = np.array(img) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + return norm_img.astype(np.float32) / 128.0 - 1.0 + + assert imgC == img.shape[2] + imgW = int((imgH * max_wh_ratio)) + if self.use_onnx: + w = self.input_tensor.shape[3:][0] + if w is not None and w > 0: + imgW = w + + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + if self.rec_algorithm == "RARE": + if resized_w > self.rec_image_shape[2]: + resized_w = self.rec_image_shape[2] + imgW = self.rec_image_shape[2] + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype("float32") + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def resize_norm_img_svtr(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype("float32") + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0 : img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = ( + np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype("int64") + ) + gsrm_word_pos = ( + np.array(range(0, max_text_length)) + .reshape((max_text_length, 1)) + .astype("int64") + ) + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length] + ) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]).astype( + "float32" + ) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length] + ) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]).astype( + "float32" + ) * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, + gsrm_word_pos, + gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2, + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [ + encoder_word_pos, + gsrm_word_pos, + gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2, + ] = self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return ( + norm_img, + encoder_word_pos, + gsrm_word_pos, + gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2, + ) + + def resize_norm_img_sar(self, img, image_shape, width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype("float32") + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + rec_res = [["", 0.0]] * img_num + batch_num = self.rec_batch_num + st = time.time() + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + imgC, imgH, imgW = self.rec_image_shape + max_wh_ratio = imgW / imgH + # max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + + if self.rec_algorithm == "SAR": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape + ) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) + elif self.rec_algorithm == "SRN": + norm_img = self.process_image_srn( + img_list[indices[ino]], self.rec_image_shape, 8, 25 + ) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) + elif self.rec_algorithm == "SVTR": + norm_img = self.resize_norm_img_svtr( + img_list[indices[ino]], self.rec_image_shape + ) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.resize_norm_img( + img_list[indices[ino]], max_wh_ratio + ) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + + if self.rec_algorithm == "SRN": + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, + encoder_word_pos_list, + gsrm_word_pos_list, + gsrm_slf_attn_bias1_list, + gsrm_slf_attn_bias2_list, + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = {"predict": outputs[2]} + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = {"predict": outputs[2]} + elif self.rec_algorithm == "SAR": + valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, + valid_ratios, + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = outputs[0] + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + else: + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = outputs[0] + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if len(outputs) != 1: + preds = outputs + else: + preds = outputs[0] + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + return rec_res, time.time() - st + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_recognizer = TextRecognizer(args) + valid_image_file_list = [] + img_list = [] + + # warmup 2 times + if args.warmup: + img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8) + for i in range(2): + res = text_recognizer([img] * int(args.rec_batch_num)) + + for image_file in image_file_list: + img = cv2.imread(image_file) + valid_image_file_list.append(image_file) + img_list.append(img) + + for i in range(10): + t0 = time.time() + rec_res, _ = text_recognizer(img_list) + print((time.time() - t0) * 1000) + + for ino in range(len(img_list)): + print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ino])) + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/ocr/utility.py b/ocr/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..2460d8e2dd9a50192c522840cdc6dc8e4105e1f5 --- /dev/null +++ b/ocr/utility.py @@ -0,0 +1,614 @@ +import argparse +import math +import os +import platform + +import cv2 +import numpy as np +import paddle +from paddle import inference +from PIL import Image, ImageDraw, ImageFont + + +def str2bool(v): + return v.lower() in ("true", "t", "1") + + +def init_args(): + parser = argparse.ArgumentParser() + # params for prediction engine + parser.add_argument("--use_gpu", type=str2bool, default=False) + parser.add_argument("--use_xpu", type=str2bool, default=False) + parser.add_argument("--ir_optim", type=str2bool, default=False) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--min_subgraph_size", type=int, default=15) + parser.add_argument("--precision", type=str, default="fp32") + parser.add_argument("--gpu_mem", type=int, default=500) + + # params for text detector + parser.add_argument("--image_dir", type=str) + parser.add_argument("--det_algorithm", type=str, default="DB") + parser.add_argument("--det_model_dir", type=str, default="./ocr/ch_PP-OCRv3_det_infer/") + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default="max") + + # DB parmas + parser.add_argument("--det_db_thresh", type=float, default=0.1) + parser.add_argument("--det_db_box_thresh", type=float, default=0.1) + parser.add_argument("--det_db_unclip_ratio", type=float, default=1.7) + parser.add_argument("--max_batch_size", type=int, default=10) + parser.add_argument("--use_dilation", type=str2bool, default=True) + parser.add_argument("--det_db_score_mode", type=str, default="fast") + + # EAST parmas + parser.add_argument("--det_east_score_thresh", type=float, default=0.8) + parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) + parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + + # SAST parmas + parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) + parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) + parser.add_argument("--det_sast_polygon", type=str2bool, default=False) + + # PSE parmas + parser.add_argument("--det_pse_thresh", type=float, default=0) + parser.add_argument("--det_pse_box_thresh", type=float, default=0.85) + parser.add_argument("--det_pse_min_area", type=float, default=16) + parser.add_argument("--det_pse_box_type", type=str, default="quad") + parser.add_argument("--det_pse_scale", type=int, default=1) + + # FCE parmas + parser.add_argument("--scales", type=list, default=[8, 16, 32]) + parser.add_argument("--alpha", type=float, default=1.0) + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--fourier_degree", type=int, default=5) + parser.add_argument("--det_fce_box_type", type=str, default="poly") + + # params for text recognizer + parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet") + parser.add_argument("--rec_model_dir", type=str, default="./ocr/ch_PP-OCRv3_rec_infer/") + parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320") + parser.add_argument("--rec_batch_num", type=int, default=6) + parser.add_argument("--max_text_length", type=int, default=25) + parser.add_argument( + "--rec_char_dict_path", type=str, default="./ocr/ppocr/ppocr_keys_v1.txt" + ) + parser.add_argument("--use_space_char", type=str2bool, default=True) + parser.add_argument("--drop_score", type=float, default=0.5) + + # params for text classifier + parser.add_argument("--use_angle_cls", type=str2bool, default=False) + parser.add_argument("--cls_model_dir", type=str) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=["0", "180"]) + parser.add_argument("--cls_batch_num", type=int, default=6) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=str2bool, default=True) + parser.add_argument("--cpu_threads", type=int, default=10) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=False) + + # + parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results") + parser.add_argument("--save_crop_res", type=str2bool, default=False) + parser.add_argument("--crop_res_save_dir", type=str, default="./output") + + # multi-process + parser.add_argument("--use_mp", type=str2bool, default=False) + parser.add_argument("--total_process_num", type=int, default=1) + parser.add_argument("--process_id", type=int, default=0) + + parser.add_argument("--benchmark", type=str2bool, default=False) + parser.add_argument("--save_log_path", type=str, default="./log_output/") + + parser.add_argument("--use_onnx", type=str2bool, default=False) + return parser + + +def parse_args(): + parser = init_args() + return parser.parse_args() + + +def create_predictor(args, mode): + if mode == "det": + model_dir = args.det_model_dir + elif mode == "rec": + model_dir = args.rec_model_dir + + if args.use_onnx: + import onnxruntime as ort + + model_file_path = model_dir + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format(model_file_path)) + sess = ort.InferenceSession(model_file_path) + return sess, sess.get_inputs()[0], None, None + + else: + model_file_path = model_dir + "/inference.pdmodel" + params_file_path = model_dir + "/inference.pdiparams" + if not os.path.exists(model_file_path): + raise ValueError("not find model file path {}".format(model_file_path)) + if not os.path.exists(params_file_path): + raise ValueError("not find params file path {}".format(params_file_path)) + + config = inference.Config(model_file_path, params_file_path) + + if hasattr(args, "precision"): + if args.precision == "fp16" and args.use_tensorrt: + precision = inference.PrecisionType.Half + elif args.precision == "int8": + precision = inference.PrecisionType.Int8 + else: + precision = inference.PrecisionType.Float32 + else: + precision = inference.PrecisionType.Float32 + + if args.use_gpu: + gpu_id = get_infer_gpuid() + config.enable_use_gpu(args.gpu_mem, 0) + if args.use_tensorrt: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=precision, + max_batch_size=args.max_batch_size, + min_subgraph_size=args.min_subgraph_size, + ) + # skip the minmum trt subgraph + use_dynamic_shape = True + if mode == "det": + min_input_shape = { + "x": [1, 3, 50, 50], + "conv2d_92.tmp_0": [1, 120, 20, 20], + "conv2d_91.tmp_0": [1, 24, 10, 10], + "conv2d_59.tmp_0": [1, 96, 20, 20], + "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10], + "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20], + "conv2d_124.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20], + "elementwise_add_7": [1, 56, 2, 2], + "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2], + } + max_input_shape = { + "x": [1, 3, 1536, 1536], + "conv2d_92.tmp_0": [1, 120, 400, 400], + "conv2d_91.tmp_0": [1, 24, 200, 200], + "conv2d_59.tmp_0": [1, 96, 400, 400], + "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200], + "conv2d_124.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400], + "elementwise_add_7": [1, 56, 400, 400], + "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400], + } + opt_input_shape = { + "x": [1, 3, 640, 640], + "conv2d_92.tmp_0": [1, 120, 160, 160], + "conv2d_91.tmp_0": [1, 24, 80, 80], + "conv2d_59.tmp_0": [1, 96, 160, 160], + "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80], + "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160], + "conv2d_124.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160], + "elementwise_add_7": [1, 56, 40, 40], + "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40], + } + min_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20], + "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20], + "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20], + } + max_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400], + "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400], + "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400], + } + opt_pact_shape = { + "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160], + "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160], + "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160], + } + min_input_shape.update(min_pact_shape) + max_input_shape.update(max_pact_shape) + opt_input_shape.update(opt_pact_shape) + elif mode == "rec": + if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]: + use_dynamic_shape = False + imgH = int(args.rec_image_shape.split(",")[-2]) + min_input_shape = {"x": [1, 3, imgH, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]} + opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]} + config.exp_disable_tensorrt_ops(["transpose2"]) + elif mode == "cls": + min_input_shape = {"x": [1, 3, 48, 10]} + max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]} + opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]} + else: + use_dynamic_shape = False + if use_dynamic_shape: + config.set_trt_dynamic_shape_info( + min_input_shape, max_input_shape, opt_input_shape + ) + + elif args.use_xpu: + config.enable_xpu(10 * 1024 * 1024) + else: + config.disable_gpu() + if hasattr(args, "cpu_threads"): + config.set_cpu_math_library_num_threads(args.cpu_threads) + else: + # default cpu threads as 10 + config.set_cpu_math_library_num_threads(10) + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if args.precision == "fp16": + config.enable_mkldnn_bfloat16() + # enable memory optim + config.enable_memory_optim() + config.disable_glog_info() + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + config.delete_pass("matmul_transpose_reshape_fuse_pass") + if mode == "table": + config.delete_pass("fc_fuse_pass") # not supported for table + config.switch_use_feed_fetch_ops(False) + config.switch_ir_optim(True) + + # create predictor + predictor = inference.create_predictor(config) + input_names = predictor.get_input_names() + for name in input_names: + input_tensor = predictor.get_input_handle(name) + output_tensors = get_output_tensors(args, mode, predictor) + return predictor, input_tensor, output_tensors, config + + +def get_output_tensors(args, mode, predictor): + output_names = predictor.get_output_names() + output_tensors = [] + if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]: + output_name = "softmax_0.tmp_0" + if output_name in output_names: + return [predictor.get_output_handle(output_name)] + else: + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + else: + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + return output_tensors + + +def get_infer_gpuid(): + sysstr = platform.system() + if sysstr == "Windows": + return 0 + + if not paddle.fluid.core.is_compiled_with_rocm(): + cmd = "env | grep CUDA_VISIBLE_DEVICES" + else: + cmd = "env | grep HIP_VISIBLE_DEVICES" + env_cuda = os.popen(cmd).readlines() + if len(env_cuda) == 0: + return 0 + else: + gpu_id = env_cuda[0].strip().split("=")[1] + return int(gpu_id[0]) + + +def draw_e2e_res(dt_boxes, strs, img_path): + src_im = cv2.imread(img_path) + for box, str in zip(dt_boxes, strs): + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + cv2.putText( + src_im, + str, + org=(int(box[0, 0, 0]), int(box[0, 0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, + fontScale=0.7, + color=(0, 255, 0), + thickness=1, + ) + return src_im + + +def draw_text_det_res(dt_boxes, img_path): + src_im = cv2.imread(img_path) + for box in dt_boxes: + box = np.array(box).astype(np.int32).reshape(-1, 2) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + return src_im + + +def resize_img(img, input_size=600): + """ + resize img and limit the longest side of the image to input_size + """ + img = np.array(img) + im_shape = img.shape + im_size_max = np.max(im_shape[0:2]) + im_scale = float(input_size) / float(im_size_max) + img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale) + return img + + +def draw_ocr( + image, + boxes, + txts=None, + scores=None, + drop_score=0.5, + font_path="./doc/fonts/simfang.ttf", +): + """ + Visualize the results of OCR detection and recognition + args: + image(Image|array): RGB image + boxes(list): boxes with shape(N, 4, 2) + txts(list): the texts + scores(list): txxs corresponding scores + drop_score(float): only scores greater than drop_threshold will be visualized + font_path: the path of font which is used to draw text + return(array): + the visualized img + """ + if scores is None: + scores = [1] * len(boxes) + box_num = len(boxes) + for i in range(box_num): + if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])): + continue + box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64) + image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) + if txts is not None: + img = np.array(resize_img(image, input_size=600)) + txt_img = text_visual( + txts, + scores, + img_h=img.shape[0], + img_w=600, + threshold=drop_score, + font_path=font_path, + ) + img = np.concatenate([np.array(img), np.array(txt_img)], axis=1) + return img + return image + + +def draw_ocr_box_txt( + image, boxes, txts, scores=None, drop_score=0.5, font_path="./doc/simfang.ttf" +): + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + + import random + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + if scores is not None and scores[idx] < drop_score: + continue + color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + draw_left.polygon(box, fill=color) + draw_right.polygon( + [ + box[0][0], + box[0][1], + box[1][0], + box[1][1], + box[2][0], + box[2][1], + box[3][0], + box[3][1], + ], + outline=color, + ) + box_height = math.sqrt( + (box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2 + ) + box_width = math.sqrt( + (box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2 + ) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + cur_y = box[0][1] + for c in txt: + char_size = font.getsize(c) + draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) + cur_y += char_size[1] + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + return np.array(img_show) + + +def str_count(s): + """ + Count the number of Chinese characters, + a single English character and a single number + equal to half the length of Chinese characters. + args: + s(string): the input of string + return(int): + the number of Chinese characters + """ + import string + + count_zh = count_pu = 0 + s_len = len(s) + en_dg_count = 0 + for c in s: + if c in string.ascii_letters or c.isdigit() or c.isspace(): + en_dg_count += 1 + elif c.isalpha(): + count_zh += 1 + else: + count_pu += 1 + return s_len - math.ceil(en_dg_count / 2) + + +def text_visual( + texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf" +): + """ + create new blank img and draw txt on it + args: + texts(list): the text will be draw + scores(list|None): corresponding score of each txt + img_h(int): the height of blank img + img_w(int): the width of blank img + font_path: the path of font which is used to draw text + return(array): + """ + if scores is not None: + assert len(texts) == len( + scores + ), "The number of txts and corresponding scores must match" + + def create_blank_img(): + blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255 + blank_img[:, img_w - 1 :] = 0 + blank_img = Image.fromarray(blank_img).convert("RGB") + draw_txt = ImageDraw.Draw(blank_img) + return blank_img, draw_txt + + blank_img, draw_txt = create_blank_img() + + font_size = 20 + txt_color = (0, 0, 0) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + + gap = font_size + 5 + txt_img_list = [] + count, index = 1, 0 + for idx, txt in enumerate(texts): + index += 1 + if scores[idx] < threshold or math.isnan(scores[idx]): + index -= 1 + continue + first_line = True + while str_count(txt) >= img_w // font_size - 4: + tmp = txt + txt = tmp[: img_w // font_size - 4] + if first_line: + new_txt = str(index) + ": " + txt + first_line = False + else: + new_txt = " " + txt + draw_txt.text((0, gap * count), new_txt, txt_color, font=font) + txt = tmp[img_w // font_size - 4 :] + if count >= img_h // gap - 1: + txt_img_list.append(np.array(blank_img)) + blank_img, draw_txt = create_blank_img() + count = 0 + count += 1 + if first_line: + new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx]) + else: + new_txt = " " + txt + " " + "%.3f" % (scores[idx]) + draw_txt.text((0, gap * count), new_txt, txt_color, font=font) + # whether add new blank img or not + if count >= img_h // gap - 1 and idx + 1 < len(texts): + txt_img_list.append(np.array(blank_img)) + blank_img, draw_txt = create_blank_img() + count = 0 + count += 1 + txt_img_list.append(np.array(blank_img)) + if len(txt_img_list) == 1: + blank_img = np.array(txt_img_list[0]) + else: + blank_img = np.concatenate(txt_img_list, axis=1) + return np.array(blank_img) + + +def base64_to_cv2(b64str): + import base64 + + data = base64.b64decode(b64str.encode("utf8")) + data = np.frombuffer(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data + + +def draw_boxes(image, boxes, scores=None, drop_score=0.5): + if scores is None: + scores = [1] * len(boxes) + for (box, score) in zip(boxes, scores): + if score < drop_score: + continue + box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64) + image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) + return image + + +def get_rotate_crop_image(img, points): + """ + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + """ + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3]) + ) + ) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2]) + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, + (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC, + ) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def check_gpu(use_gpu): + if use_gpu and not paddle.is_compiled_with_cuda(): + use_gpu = False + return use_gpu diff --git a/requirements.txt b/requirements.txt index 66b1130082bb273a76c261bf6eb536fec748670c..8d6fd18cfb9886e66f43f7f9d9025730913f77c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,13 @@ openai Pillow -easyocr gradio -deta \ No newline at end of file +deta +paddlepaddle +opencv-python +Pillow +numpy==1.23.3 +pandas +imutils +Cython +imgaug +pyclipper \ No newline at end of file