add simple confusion matrix
Browse files- draw_confusion.py +25 -0
- test2.py → test_dreamtalk.py +32 -3
draw_confusion.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://www.cnblogs.com/yexionglin/p/11432180.html
|
2 |
+
|
3 |
+
import seaborn as sns
|
4 |
+
from sklearn.metrics import confusion_matrix
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
|
8 |
+
def draw(y_true, y_pred):
|
9 |
+
sns.set()
|
10 |
+
f, ax=plt.subplots()
|
11 |
+
# y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
|
12 |
+
# y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
|
13 |
+
cm = confusion_matrix(y_true, y_pred, labels=[-1] * 13)
|
14 |
+
print(cm) #打印出来看看
|
15 |
+
sns.heatmap(cm, annot=True, ax=ax) #画热力图
|
16 |
+
|
17 |
+
def draw2(y_true, y_pred):
|
18 |
+
min_len = min( len(y_true), len(y_pred) )
|
19 |
+
for i in range(min_len):
|
20 |
+
for j in range(i):
|
21 |
+
print('\t', end='')
|
22 |
+
print(str(y_pred[i]))
|
23 |
+
|
24 |
+
draw2( [1.0] * 13, [0.5] * 13 )
|
25 |
+
|
test2.py → test_dreamtalk.py
RENAMED
@@ -1,14 +1,43 @@
|
|
1 |
import requests
|
|
|
|
|
2 |
|
3 |
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
4 |
headers = {"Authorization": "Bearer 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"}
|
5 |
|
6 |
-
filename = '1.flac'
|
7 |
def query(filename):
|
8 |
with open(filename, "rb") as f:
|
9 |
data = f.read()
|
10 |
response = requests.post(API_URL, headers=headers, data=data)
|
11 |
return response.json()
|
12 |
|
13 |
-
res = query(filename)
|
14 |
-
print(str(res))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
+
import sys
|
3 |
+
from draw_confusion import draw, draw2
|
4 |
|
5 |
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
|
6 |
headers = {"Authorization": "Bearer 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"}
|
7 |
|
8 |
+
# filename = '1.flac'
|
9 |
def query(filename):
|
10 |
with open(filename, "rb") as f:
|
11 |
data = f.read()
|
12 |
response = requests.post(API_URL, headers=headers, data=data)
|
13 |
return response.json()
|
14 |
|
15 |
+
# res = query(filename)
|
16 |
+
# print(str(res))
|
17 |
+
|
18 |
+
# 处理命令行
|
19 |
+
if __name__ == "__main__":
|
20 |
+
# 获取命令行参数
|
21 |
+
if len(sys.argv) < 2:
|
22 |
+
print("用法:python x.py <文件或通配符>")
|
23 |
+
sys.exit(1)
|
24 |
+
|
25 |
+
y_len = 13 #len(sys.argv[1:]:)
|
26 |
+
y_true = [0] * y_len
|
27 |
+
y_pred = [0] * y_len
|
28 |
+
y_idx = 0
|
29 |
+
for input_file in sys.argv[1:]:
|
30 |
+
res = query(input_file)
|
31 |
+
print('%s:' % str(input_file))
|
32 |
+
print('%s' % str(res[:3]))
|
33 |
+
|
34 |
+
first_label = str(res[0]['label'])
|
35 |
+
first_score = res[0]['score']
|
36 |
+
print(str(first_label))
|
37 |
+
print(str(first_score))
|
38 |
+
|
39 |
+
y_true[y_idx] = 1.0
|
40 |
+
y_pred[y_idx] = round(first_score, 1)
|
41 |
+
y_idx = y_idx + 1
|
42 |
+
|
43 |
+
draw2(y_true, y_pred)
|