mod draw2 with color
Browse files- draw_confusion.py +20 -3
- requirements.txt +1 -0
- test_dreamtalk.py +34 -19
draw_confusion.py
CHANGED
@@ -3,23 +3,40 @@
|
|
3 |
import seaborn as sns
|
4 |
from sklearn.metrics import confusion_matrix
|
5 |
import matplotlib.pyplot as plt
|
|
|
6 |
|
7 |
|
8 |
def draw(y_true, y_pred):
|
|
|
|
|
9 |
sns.set()
|
10 |
f, ax=plt.subplots()
|
11 |
# y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
|
12 |
# y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
|
13 |
-
cm = confusion_matrix(y_true, y_pred, labels=[-1] *
|
14 |
print(cm) #打印出来看看
|
15 |
sns.heatmap(cm, annot=True, ax=ax) #画热力图
|
16 |
|
17 |
def draw2(y_true, y_pred):
|
18 |
min_len = min( len(y_true), len(y_pred) )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
for i in range(min_len):
|
|
|
20 |
for j in range(i):
|
21 |
print('\t', end='')
|
22 |
-
print(str(y_pred[i]))
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
3 |
import seaborn as sns
|
4 |
from sklearn.metrics import confusion_matrix
|
5 |
import matplotlib.pyplot as plt
|
6 |
+
from colorama import Fore,Back,Style
|
7 |
|
8 |
|
9 |
def draw(y_true, y_pred):
|
10 |
+
min_len = min( len(y_true), len(y_pred) )
|
11 |
+
|
12 |
sns.set()
|
13 |
f, ax=plt.subplots()
|
14 |
# y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
|
15 |
# y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
|
16 |
+
cm = confusion_matrix(y_true, y_pred, labels=[-1] * min_len)
|
17 |
print(cm) #打印出来看看
|
18 |
sns.heatmap(cm, annot=True, ax=ax) #画热力图
|
19 |
|
20 |
def draw2(y_true, y_pred):
|
21 |
min_len = min( len(y_true), len(y_pred) )
|
22 |
+
|
23 |
+
print('\t', end='')
|
24 |
+
for i in range(min_len):
|
25 |
+
y_true_format = str(y_true[i])[:3]
|
26 |
+
print('%s\t' % y_true_format, end='')
|
27 |
+
print('')
|
28 |
+
|
29 |
for i in range(min_len):
|
30 |
+
print(Fore.RESET + '%s\t' % str(i + 1), end='')
|
31 |
for j in range(i):
|
32 |
print('\t', end='')
|
|
|
33 |
|
34 |
+
# print with color
|
35 |
+
if y_pred[i] > 0.5:
|
36 |
+
print(Fore.GREEN + str(y_pred[i]))
|
37 |
+
else:
|
38 |
+
print(Fore.RED + str(y_pred[i]))
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
draw2( [1.0] * 13, [0.5] * 13 )
|
42 |
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
colorama
|
test_dreamtalk.py
CHANGED
@@ -1,43 +1,58 @@
|
|
1 |
import requests
|
2 |
import sys
|
3 |
from draw_confusion import draw, draw2
|
|
|
4 |
|
|
|
|
|
5 |
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
6 |
headers = {"Authorization": "Bearer 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"}
|
7 |
|
|
|
|
|
8 |
# filename = '1.flac'
|
9 |
-
def
|
10 |
with open(filename, "rb") as f:
|
11 |
data = f.read()
|
12 |
response = requests.post(API_URL, headers=headers, data=data)
|
13 |
return response.json()
|
14 |
|
15 |
-
# res = query(filename)
|
16 |
-
# print(str(res))
|
17 |
-
|
18 |
-
# 处理命令行
|
19 |
-
if __name__ == "__main__":
|
20 |
-
# 获取命令行参数
|
21 |
-
if len(sys.argv) < 2:
|
22 |
-
print("用法:python x.py <文件或通配符>")
|
23 |
-
sys.exit(1)
|
24 |
|
25 |
-
|
|
|
|
|
|
|
26 |
y_true = [0] * y_len
|
27 |
y_pred = [0] * y_len
|
28 |
y_idx = 0
|
29 |
-
for input_file in
|
30 |
-
res =
|
31 |
-
print('%s:' % str(input_file))
|
32 |
-
print('%s' % str(res[:3]))
|
33 |
|
34 |
first_label = str(res[0]['label'])
|
35 |
first_score = res[0]['score']
|
36 |
-
print(str(first_label))
|
37 |
-
print(str(first_score))
|
38 |
|
39 |
-
y_true[y_idx] =
|
40 |
y_pred[y_idx] = round(first_score, 1)
|
41 |
y_idx = y_idx + 1
|
42 |
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
import sys
|
3 |
from draw_confusion import draw, draw2
|
4 |
+
from tqdm import tqdm
|
5 |
|
6 |
+
|
7 |
+
DEBUG = True #False
|
8 |
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
9 |
headers = {"Authorization": "Bearer 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"}
|
10 |
|
11 |
+
|
12 |
+
# 处理请求
|
13 |
# filename = '1.flac'
|
14 |
+
def request_api(filename):
|
15 |
with open(filename, "rb") as f:
|
16 |
data = f.read()
|
17 |
response = requests.post(API_URL, headers=headers, data=data)
|
18 |
return response.json()
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
# 批量处理
|
22 |
+
def batch_request_api(file_uris):
|
23 |
+
if DEBUG: print('batch_request_api')
|
24 |
+
y_len = len(file_uris)
|
25 |
y_true = [0] * y_len
|
26 |
y_pred = [0] * y_len
|
27 |
y_idx = 0
|
28 |
+
for input_file in tqdm(file_uris):
|
29 |
+
res = request_api(input_file)
|
30 |
+
# print('%s %s:' % (str(y_idx), str(input_file)) )
|
31 |
+
# print('%s' % str(res[:3]))
|
32 |
|
33 |
first_label = str(res[0]['label'])
|
34 |
first_score = res[0]['score']
|
35 |
+
# print(str(first_label))
|
36 |
+
# print(str(first_score))
|
37 |
|
38 |
+
y_true[y_idx] = first_label
|
39 |
y_pred[y_idx] = round(first_score, 1)
|
40 |
y_idx = y_idx + 1
|
41 |
|
42 |
+
return y_true, y_pred
|
43 |
+
|
44 |
+
|
45 |
+
# 处理命令行
|
46 |
+
if __name__ == "__main__":
|
47 |
+
if DEBUG: print('main, ' + str(sys.argv[1:]))
|
48 |
+
if DEBUG: print('main, ' + str(len(sys.argv)))
|
49 |
+
|
50 |
+
# 获取命令行参数
|
51 |
+
if len(sys.argv) < 2:
|
52 |
+
print("用法:python x.py <文件或通配符>")
|
53 |
+
sys.exit(1)
|
54 |
+
|
55 |
+
if DEBUG: print('main, batch_request_api')
|
56 |
+
y_true, y_pred = batch_request_api(sys.argv[1:])
|
57 |
+
if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
|
58 |
+
draw2(y_true, y_pred)
|