Faizan Azizahmed Shaikh committed on
Commit
3bad857
1 Parent(s): 3c8df48

Initial Commit

Browse files
Files changed (3) hide show
  1. Chart_QnA.ipynb +110 -0
  2. app.py +63 -0
  3. requirements.txt +3 -0
Chart_QnA.ipynb ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "a5831e0b-d99b-4f34-a65e-97f5d09f00ec",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# import required libraries\n",
11
+ "from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration\n",
12
+ "import gradio as gr"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "6881b277-9511-4460-a0aa-19b8d9e61fdf",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# pipeline function with default values\n",
23
+ "def query(image, user_question):\n",
24
+ " \"\"\"\n",
25
+ " image: single image or batch of images;\n",
26
+ " question: user prompt question;\n",
27
+ " \"\"\"\n",
28
+ " # select model from hugging face\n",
29
+ " model_name = \"google/deplot\"\n",
30
+ " # set preprocessor for current model\n",
31
+ " processor = Pix2StructProcessor.from_pretrained(model_name)\n",
32
+ " # load pre-trained model\n",
33
+ " model = Pix2StructForConditionalGeneration.from_pretrained(model_name)\n",
34
+ " # process the inputs for prediction\n",
35
+ " inputs = processor(images=image, text=user_question, return_tensors=\"pt\")\n",
36
+ " # save the results\n",
37
+ " predictions = model.generate(**inputs, max_new_tokens=512)\n",
38
+ " # save output\n",
39
+ " result = processor.decode(predictions[0], skip_special_tokens=True)\n",
40
+ " # process the results for output table\n",
41
+ " outs = [x.strip() for x in result.split(\"<0x0A>\")]\n",
42
+ " # create an empty list\n",
43
+ " nested = list()\n",
44
+ " # loop for splitting the data\n",
45
+ " for data in outs:\n",
46
+ " if \"|\" in data:\n",
47
+ " nested.append([x.strip() for x in data.split(\"|\")])\n",
48
+ " else:\n",
49
+ " nested.append(data)\n",
50
+ " # return the converted output\n",
51
+ " return nested"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "04526adc-1ce4-48c6-b635-13bf506ed862",
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "name": "stdout",
62
+ "output_type": "stream",
63
+ "text": [
64
+ "Using cache from 'C:\\Users\\faiza\\huggingface\\Group Project\\gradio_cached_examples\\14' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n",
65
+ "\n",
66
+ "Running on local URL: http://127.0.0.1:7860\n",
67
+ "\n",
68
+ "To create a public link, set `share=True` in `launch()`.\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "# Interface framework to customize the io page \n",
74
+ "ui = gr.Interface(title=\"Chart Q/A\",\n",
75
+ " fn=query,\n",
76
+ " inputs=[gr.Image(label=\"Upload Here\", type=\"pil\"), gr.Textbox(label=\"Question?\")],\n",
77
+ " outputs=\"list\",\n",
78
+ " examples=[[\"./samples/sample1.png\", \"Generate underlying data table of the figure\"], \n",
79
+ " [\"./samples/sample2.png\", \"Is the sum of all 4 places greater than Laos?\"]],\n",
80
+ " # [\"./samples/sample3.webp\", \"What are the 2020 net sales?\"]],\n",
81
+ " cache_examples=True,\n",
82
+ " allow_flagging='never')\n",
83
+ "\n",
84
+ "ui.queue(api_open=False)\n",
85
+ "ui.launch(inline=False, share=False, debug=True)"
86
+ ]
87
+ }
88
+ ],
89
+ "metadata": {
90
+ "kernelspec": {
91
+ "display_name": "Python 3 (ipykernel)",
92
+ "language": "python",
93
+ "name": "python3"
94
+ },
95
+ "language_info": {
96
+ "codemirror_mode": {
97
+ "name": "ipython",
98
+ "version": 3
99
+ },
100
+ "file_extension": ".py",
101
+ "mimetype": "text/x-python",
102
+ "name": "python",
103
+ "nbconvert_exporter": "python",
104
+ "pygments_lexer": "ipython3",
105
+ "version": "3.11.4"
106
+ }
107
+ },
108
+ "nbformat": 4,
109
+ "nbformat_minor": 5
110
+ }
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # import required libraries
8
+ from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
9
+ import gradio as gr
10
+
11
+
12
+ # In[2]:
13
+
14
+
15
+ # pipeline function with default values
16
+ def query(image, user_question):
17
+ """
18
+ image: single image or batch of images;
19
+ question: user prompt question;
20
+ """
21
+ # select model from hugging face
22
+ model_name = "google/deplot"
23
+ # set preprocessor for current model
24
+ processor = Pix2StructProcessor.from_pretrained(model_name)
25
+ # load pre-trained model
26
+ model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
27
+ # process the inputs for prediction
28
+ inputs = processor(images=image, text=user_question, return_tensors="pt")
29
+ # save the results
30
+ predictions = model.generate(**inputs, max_new_tokens=512)
31
+ # save output
32
+ result = processor.decode(predictions[0], skip_special_tokens=True)
33
+ # process the results for output table
34
+ outs = [x.strip() for x in result.split("<0x0A>")]
35
+ # create an empty list
36
+ nested = list()
37
+ # loop for splitting the data
38
+ for data in outs:
39
+ if "|" in data:
40
+ nested.append([x.strip() for x in data.split("|")])
41
+ else:
42
+ nested.append(data)
43
+ # return the converted output
44
+ return nested
45
+
46
+
47
+ # In[ ]:
48
+
49
+
50
+ # Interface framework to customize the io page
51
+ ui = gr.Interface(title="Chart Q/A",
52
+ fn=query,
53
+ inputs=[gr.Image(label="Upload Here", type="pil"), gr.Textbox(label="Question?")],
54
+ outputs="list",
55
+ examples=[["./samples/sample1.png", "Generate underlying data table of the figure"],
56
+ ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"]],
57
+ # ["./samples/sample3.webp", "What are the 2020 net sales?"]],
58
+ cache_examples=True,
59
+ allow_flagging='never')
60
+
61
+ ui.queue(api_open=False)
62
+ ui.launch(inline=False, share=False, debug=True)
63
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==3.40.1
2
+ transformers==4.31.0
3
+ accelerate==0.21.0