thelou1s commited on
Commit
22c4a51
·
1 Parent(s): 9154c40

add simple confusion matrix

Browse files
Files changed (2) hide show
  1. draw_confusion.py +25 -0
  2. test2.py → test_dreamtalk.py +32 -3
draw_confusion.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.cnblogs.com/yexionglin/p/11432180.html
2
+
3
+ import seaborn as sns
4
+ from sklearn.metrics import confusion_matrix
5
+ import matplotlib.pyplot as plt
6
+
7
+
8
+ def draw(y_true, y_pred):
9
+ sns.set()
10
+ f, ax=plt.subplots()
11
+ # y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
12
+ # y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
13
+ cm = confusion_matrix(y_true, y_pred, labels=[-1] * 13)
14
+ print(cm) #打印出来看看
15
+ sns.heatmap(cm, annot=True, ax=ax) #画热力图
16
+
17
+ def draw2(y_true, y_pred):
18
+ min_len = min( len(y_true), len(y_pred) )
19
+ for i in range(min_len):
20
+ for j in range(i):
21
+ print('\t', end='')
22
+ print(str(y_pred[i]))
23
+
24
+ draw2( [1.0] * 13, [0.5] * 13 )
25
+
test2.py → test_dreamtalk.py RENAMED
@@ -1,14 +1,43 @@
1
  import requests
 
 
2
 
3
  API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
4
  headers = {"Authorization": "Bearer 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"}
5
 
6
- filename = '1.flac'
7
  def query(filename):
8
  with open(filename, "rb") as f:
9
  data = f.read()
10
  response = requests.post(API_URL, headers=headers, data=data)
11
  return response.json()
12
 
13
- res = query(filename)
14
- print(str(res))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
+ import sys
3
+ from draw_confusion import draw, draw2
4
 
5
  API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
6
  headers = {"Authorization": "Bearer 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"}
7
 
8
+ # filename = '1.flac'
9
  def query(filename):
10
  with open(filename, "rb") as f:
11
  data = f.read()
12
  response = requests.post(API_URL, headers=headers, data=data)
13
  return response.json()
14
 
15
+ # res = query(filename)
16
+ # print(str(res))
17
+
18
+ # 处理命令行
19
+ if __name__ == "__main__":
20
+ # 获取命令行参数
21
+ if len(sys.argv) < 2:
22
+ print("用法:python x.py <文件或通配符>")
23
+ sys.exit(1)
24
+
25
+ y_len = 13 #len(sys.argv[1:]:)
26
+ y_true = [0] * y_len
27
+ y_pred = [0] * y_len
28
+ y_idx = 0
29
+ for input_file in sys.argv[1:]:
30
+ res = query(input_file)
31
+ print('%s:' % str(input_file))
32
+ print('%s' % str(res[:3]))
33
+
34
+ first_label = str(res[0]['label'])
35
+ first_score = res[0]['score']
36
+ print(str(first_label))
37
+ print(str(first_score))
38
+
39
+ y_true[y_idx] = 1.0
40
+ y_pred[y_idx] = round(first_score, 1)
41
+ y_idx = y_idx + 1
42
+
43
+ draw2(y_true, y_pred)