john000z committed
Commit d00a15d · 1 Parent(s): bb264c2

gitignore ipynb txt

Files changed (2)
  1. .gitignore +0 -0
  2. home.ipynb +305 -0
.gitignore ADDED
File without changes
home.ipynb ADDED
@@ -0,0 +1,305 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting utools.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile utools.py\n",
+ "import tflite_runtime.interpreter as tflite\n",
+ "import tflite_runtime\n",
+ "import numpy as np\n",
+ "ROWS_PER_FRAME=543\n",
+ "def load_relevant_data_subset(df):\n",
+ "    data_columns = ['x', 'y', 'z']\n",
+ "    data=df[data_columns]\n",
+ "    n_frames = int(len(data) / ROWS_PER_FRAME)  # total number of frames in this file\n",
+ "    data = data.values.reshape(n_frames, ROWS_PER_FRAME, len(data_columns))\n",
+ "    return data.astype(np.float32)\n",
+ "\n",
+ "def mark_pred(model_path_1,aa):\n",
+ "    interpreter = tflite.Interpreter(model_path_1)\n",
+ "    found_signatures = list(interpreter.get_signature_list().keys())\n",
+ "    prediction_fn = interpreter.get_signature_runner(\"serving_default\")\n",
+ "    output_1 = prediction_fn(inputs=aa)\n",
+ "    return output_1\n",
+ "\n",
+ "def softmax(x, axis=None):\n",
+ "    x_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
+ "    return x_exp / np.sum(x_exp, axis=axis, keepdims=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting model.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile model.py\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import shutil\n",
+ "from datetime import datetime\n",
+ "from timeit import default_timer as timer\n",
+ "from utools import load_relevant_data_subset,mark_pred\n",
+ "from utools import softmax\n",
+ "import mediapipe as mp\n",
+ "import cv2\n",
+ "import json\n",
+ "N=3\n",
+ "\n",
+ "ROWS_PER_FRAME=543\n",
+ "with open('sign_to_prediction_index_map_cn.json', 'r') as f:\n",
+ "    person_dict = json.load(f)\n",
+ "inverse_dict=dict([val,key] for key,val in person_dict.items())\n",
+ "\n",
+ "\n",
+ "def r_holistic(video_path):\n",
+ "    mp_drawing = mp.solutions.drawing_utils\n",
+ "    mp_drawing_styles = mp.solutions.drawing_styles\n",
+ "    mp_holistic = mp.solutions.holistic\n",
+ "    frame_number = 0\n",
+ "    frame = []\n",
+ "    type_ = []\n",
+ "    index = []\n",
+ "    x = []\n",
+ "    y = []\n",
+ "    z = []\n",
+ "    cap=cv2.VideoCapture(video_path)\n",
+ "    frame_width = int(cap.get(3))\n",
+ "    frame_height = int(cap.get(4))\n",
+ "    fps = int(cap.get(cv2.CAP_PROP_FPS))\n",
+ "    frame_size = (frame_width, frame_height)\n",
+ "    fourcc = cv2.VideoWriter_fourcc(*\"VP80\") #cv2.VideoWriter_fourcc('H.264')\n",
+ "    output_video = \"output_recorded_holistic.webm\"\n",
+ "    out = cv2.VideoWriter(output_video, fourcc, int(fps/N), frame_size)\n",
+ "    with mp_holistic.Holistic(min_detection_confidence=0.5,min_tracking_confidence=0.5) as holistic:\n",
+ "        n=0\n",
+ "        while cap.isOpened():\n",
+ "            frame_number+=1\n",
+ "            n+=1\n",
+ "            ret, image = cap.read()\n",
+ "            if not ret:\n",
+ "                break\n",
+ "            if n%N==0:\n",
+ "                image.flags.writeable = False\n",
+ "                image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)\n",
+ "                #mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=RGB_frame)\n",
+ "                results = holistic.process(image)\n",
+ "\n",
+ "                # Draw landmark annotation on the image.\n",
+ "                image.flags.writeable = True\n",
+ "                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n",
+ "                mp_drawing.draw_landmarks(\n",
+ "                    image,\n",
+ "                    results.face_landmarks,\n",
+ "                    mp_holistic.FACEMESH_CONTOURS,\n",
+ "                    landmark_drawing_spec=None,\n",
+ "                    connection_drawing_spec=mp_drawing_styles\n",
+ "                    .get_default_face_mesh_contours_style())\n",
+ "                mp_drawing.draw_landmarks(\n",
+ "                    image,\n",
+ "                    results.pose_landmarks,\n",
+ "                    mp_holistic.POSE_CONNECTIONS,\n",
+ "                    landmark_drawing_spec=mp_drawing_styles\n",
+ "                    .get_default_pose_landmarks_style())\n",
+ "                # Flip the image horizontally for a selfie-view display.\n",
+ "                #if cv2.waitKey(5) & 0xFF == 27:\n",
+ "                out.write(image)\n",
+ "\n",
+ "                if(results.face_landmarks is None):\n",
+ "                    for i in range(468):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"face\")\n",
+ "                        index.append(i)\n",
+ "                        x.append(None)\n",
+ "                        y.append(None)\n",
+ "                        z.append(None)\n",
+ "                else:\n",
+ "                    for ind,val in enumerate(results.face_landmarks.landmark):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"face\")\n",
+ "                        index.append(ind)\n",
+ "                        x.append(val.x)\n",
+ "                        y.append(val.y)\n",
+ "                        z.append(val.z)\n",
+ "                #left hand\n",
+ "                if(results.left_hand_landmarks is None):\n",
+ "                    for i in range(21):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"left_hand\")\n",
+ "                        index.append(i)\n",
+ "                        x.append(None)\n",
+ "                        y.append(None)\n",
+ "                        z.append(None)\n",
+ "                else:\n",
+ "                    for ind,val in enumerate(results.left_hand_landmarks.landmark):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"left_hand\")\n",
+ "                        index.append(ind)\n",
+ "                        x.append(val.x)\n",
+ "                        y.append(val.y)\n",
+ "                        z.append(val.z)\n",
+ "                #pose\n",
+ "                if(results.pose_landmarks is None):\n",
+ "                    for i in range(33):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"pose\")\n",
+ "                        index.append(i)\n",
+ "                        x.append(None)\n",
+ "                        y.append(None)\n",
+ "                        z.append(None)\n",
+ "                else:\n",
+ "                    for ind,val in enumerate(results.pose_landmarks.landmark):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"pose\")\n",
+ "                        index.append(ind)\n",
+ "                        x.append(val.x)\n",
+ "                        y.append(val.y)\n",
+ "                        z.append(val.z)\n",
+ "                #right hand\n",
+ "                if(results.right_hand_landmarks is None):\n",
+ "                    for i in range(21):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"right_hand\")\n",
+ "                        index.append(i)\n",
+ "                        x.append(None)\n",
+ "                        y.append(None)\n",
+ "                        z.append(None)\n",
+ "                else:\n",
+ "                    for ind,val in enumerate(results.right_hand_landmarks.landmark):\n",
+ "                        frame.append(frame_number)\n",
+ "                        type_.append(\"right_hand\")\n",
+ "                        index.append(ind)\n",
+ "                        x.append(val.x)\n",
+ "                        y.append(val.y)\n",
+ "                        z.append(val.z)\n",
+ "                #break\n",
+ "    cap.release()\n",
+ "    out.release()\n",
+ "    cv2.destroyAllWindows()\n",
+ "    df1 = pd.DataFrame({\n",
+ "        \"frame\" : frame,\n",
+ "        \"type\" : type_,\n",
+ "        \"landmark_index\" : index,\n",
+ "        \"x\" : x,\n",
+ "        \"y\" : y,\n",
+ "        \"z\" : z\n",
+ "    })\n",
+ "    aa=load_relevant_data_subset(df1)\n",
+ "    model_path_1='model_1.tflite'\n",
+ "    model_path_2='model_2.tflite'\n",
+ "    model_path_3='model_3.tflite'\n",
+ "    #interpreter = tflite.Interpreter(model_path_1)\n",
+ "    #found_signatures = list(interpreter.get_signature_list().keys())\n",
+ "    #prediction_fn = interpreter.get_signature_runner(\"serving_default\")\n",
+ "    output_1 = mark_pred(model_path_1,aa)\n",
+ "    output_2 = mark_pred(model_path_2,aa)\n",
+ "    output_3 = mark_pred(model_path_3,aa)\n",
+ "    output=softmax(output_1['outputs'])+softmax(output_2['outputs'])+softmax(output_3['outputs'])\n",
+ "    sign = output.argmax()\n",
+ "    lb = inverse_dict.get(sign)\n",
+ "    yield output_video,lb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting app.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile app.py\n",
+ "\n",
+ "import gradio as gr\n",
+ "from model import r_holistic\n",
+ "\n",
+ "title='Sign language gesture classification'\n",
+ "description = \"This classification model recognizes 250 [ASL](https://www.lifeprint.com/) signs\\\n",
+ "    and maps each one to a label; the label list is in [sign_to_prediction_index_map.json](sign_to_prediction_index_map.json). \\\n",
+ "    You can test with the example videos, or download or record the corresponding sign videos from the list and check the output.\\\n",
+ "    \\nWorkflow:\\\n",
+ "    \\n 1. Landmark extraction, done with the [MediaPipe Holistic Solution](https://ai.google.dev/edge/mediapipe/solutions/vision/holistic_landmarker).\\\n",
+ "    \\n 2. Sign recognition from the landmarks, using a model I built and trained myself (a CNN plus Transformer backbone) that reaches over 90% accuracy on the test set.\"\n",
+ "\n",
+ "output_video_file = gr.Video(label=\"Landmark output\")\n",
+ "output_text=gr.Textbox(label=\"Predicted sign\")\n",
+ "slider_1=gr.Slider(0,1,label='detection_confidence')\n",
+ "slider_2=gr.Slider(0,1,label='tracking_confidence')\n",
+ "\n",
+ "iface = gr.Interface(\n",
+ "    fn=r_holistic,\n",
+ "    inputs=[gr.Video(sources=None, label=\"Sign language video clip\")],\n",
+ "    outputs= [output_video_file,output_text],\n",
+ "    title=title,\n",
+ "    description=description,\n",
+ "    examples=['book.mp4','book2.mp4','chair1.mp4','chair2.mp4'],\n",
+ "    #cache_examples=True,\n",
+ "    ) #[\"hand-land-mark-video/01.mp4\",\"hand-land-mark-video/02.mp4\"]\n",
+ "\n",
+ "\n",
+ "iface.launch(share=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "myenv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
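
The notebook's pipeline rests on two invariants: every processed frame appends exactly 543 landmark rows (468 face + 21 left hand + 33 pose + 21 right hand), which is what lets load_relevant_data_subset reshape the collected DataFrame to (n_frames, 543, 3), and the predicted label is the argmax of the sum of the three models' softmaxed outputs. Below is a minimal, self-contained sketch of that arithmetic using random stand-in values in place of real MediaPipe landmarks and TFLite outputs; the stand-in data and logit shapes are illustrative assumptions, while ROWS_PER_FRAME, the reshape, and the softmax-sum ensemble mirror the notebook.

import numpy as np
import pandas as pd

ROWS_PER_FRAME = 543  # 468 face + 21 left hand + 33 pose + 21 right hand

def load_relevant_data_subset(df):
    # Same reshape as in utools.py: one (543, 3) slab of x/y/z per processed frame.
    data = df[['x', 'y', 'z']]
    n_frames = int(len(data) / ROWS_PER_FRAME)
    return data.values.reshape(n_frames, ROWS_PER_FRAME, 3).astype(np.float32)

def softmax(x, axis=None):
    x_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return x_exp / np.sum(x_exp, axis=axis, keepdims=True)

# Stand-in for the rows r_holistic collects over two processed frames.
n_frames = 2
rows = n_frames * ROWS_PER_FRAME
df = pd.DataFrame({'x': np.random.rand(rows),
                   'y': np.random.rand(rows),
                   'z': np.random.rand(rows)})
assert load_relevant_data_subset(df).shape == (n_frames, ROWS_PER_FRAME, 3)

# Stand-in logits for the three TFLite models; the notebook gets these from
# mark_pred(model_path, aa)['outputs'] instead.
num_classes = 250
logits = [np.random.randn(num_classes) for _ in range(3)]
ensemble = sum(softmax(l) for l in logits)  # softmax-sum ensemble, as in r_holistic
sign = int(ensemble.argmax())               # index that r_holistic looks up in inverse_dict
print(sign)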