oussama commited on
Commit
44e3a91
·
1 Parent(s): 721bf42

Create new file

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from transformers import AutoModelForTokenClassification
7
+ from datasets.features import ClassLabel
8
+ from transformers import AutoProcessor
9
+ from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D
10
+ import torch
11
+ from datasets import load_metric
12
+ from transformers import LayoutLMv3ForTokenClassification
13
+ from transformers.data.data_collator import default_data_collator
14
+
15
+
16
+ from transformers import AutoModelForTokenClassification
17
+ from datasets import load_dataset
18
+ from PIL import Image, ImageDraw, ImageFont
19
+
20
+
21
+ processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True)
22
+ model = AutoModelForTokenClassification.from_pretrained("oussama/Layoutlm_Form_information_extraction")
23
+
24
+
25
+
26
+ # load image example
27
+ dataset = load_dataset("darentang/generated", split="test")
28
+ Image.open(dataset[2]["image_path"]).convert("RGB").save("example1.png")
29
+ Image.open(dataset[1]["image_path"]).convert("RGB").save("example2.png")
30
+ Image.open(dataset[0]["image_path"]).convert("RGB").save("example3.png")
31
+ # define id2label, label2color
32
+ labels = dataset.features['ner_tags'].feature.names
33
+ id2label = {v: k for v, k in enumerate(labels)}
34
+ label2color = {'question':'blue', 'answer':'green', 'header':'orange', 'other':'violet'}
35
+
36
+
37
+ def unnormalize_box(bbox, width, height):
38
+ return [
39
+ width * (bbox[0] / 1000),
40
+ height * (bbox[1] / 1000),
41
+ width * (bbox[2] / 1000),
42
+ height * (bbox[3] / 1000),
43
+ ]
44
+
45
+
46
+ def iob_to_label(label):
47
+ return label
48
+
49
+
50
+
51
+ def process_image(image):
52
+
53
+ print(type(image))
54
+ width, height = image.size
55
+
56
+ # encode
57
+ encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
58
+ offset_mapping = encoding.pop('offset_mapping')
59
+
60
+ # forward pass
61
+ outputs = model(**encoding)
62
+
63
+ # get predictions
64
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
65
+ token_boxes = encoding.bbox.squeeze().tolist()
66
+
67
+ # only keep non-subword predictions
68
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:,0] != 0
69
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
70
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
71
+
72
+ # draw predictions over the image
73
+ draw = ImageDraw.Draw(image)
74
+ font = ImageFont.load_default()
75
+ for prediction, box in zip(true_predictions, true_boxes):
76
+ predicted_label = iob_to_label(prediction)
77
+ draw.rectangle(box, outline=label2color[predicted_label])
78
+ draw.text((box[0]+10, box[1]-10), text=predicted_label, fill=label2color[predicted_label], font=font)
79
+
80
+ return image
81
+
82
+
83
+ title = "Extraction d'informations de factures en utilisant le modèle LayoutLMv3"
84
+ description = "J'utilise LayoutLMv3 de Microsoft formé sur un ensemble de données de factures pour prédire le nom de l'émetteur de factures, l'adresse de l'émetteur de factures, le code postal de l'émetteur de factures, la date d'échéance, la TPS, la date de facturation, le numéro de facture, le sous-total et le total. Pour l'utiliser, il suffit de télécharger une image ou d'utiliser l'exemple d'image ci-dessous. Les résultats seront affichés en quelques secondes."
85
+
86
+ article="<b>References</b><br>[1] Y. Xu et al., “LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking.” 2022. <a href='https://arxiv.org/abs/2204.08387'>Paper Link</a><br>[2] <a href='https://github.com/NielsRogge/Transformers-Tutorials/tree/master/LayoutLMv3'>LayoutLMv3 training and inference</a>"
87
+
88
+ examples =[['example1.png'],['example2.png'],['example3.png']]
89
+
90
+ css = """.output_image, .input_image {height: 600px !important}"""
91
+
92
+ iface = gr.Interface(fn=process_image,
93
+ inputs=gr.inputs.Image(type="pil"),
94
+ outputs=gr.outputs.Image(type="pil", label="annotated image"),
95
+ title=title,
96
+ description=description,
97
+ article=article,
98
+ examples=examples,
99
+ css=css,
100
+ analytics_enabled = True, enable_queue=True)
101
+
102
+ iface.launch(inline=False, share=False, debug=False)