thelou1s commited on
Commit
49a9586
·
1 Parent(s): 22c4a51

mod draw2 with color

Browse files
Files changed (3) hide show
  1. draw_confusion.py +20 -3
  2. requirements.txt +1 -0
  3. test_dreamtalk.py +34 -19
draw_confusion.py CHANGED
@@ -3,23 +3,40 @@
3
  import seaborn as sns
4
  from sklearn.metrics import confusion_matrix
5
  import matplotlib.pyplot as plt
 
6
 
7
 
8
  def draw(y_true, y_pred):
 
 
9
  sns.set()
10
  f, ax=plt.subplots()
11
  # y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
12
  # y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
13
- cm = confusion_matrix(y_true, y_pred, labels=[-1] * 13)
14
  print(cm) #打印出来看看
15
  sns.heatmap(cm, annot=True, ax=ax) #画热力图
16
 
17
  def draw2(y_true, y_pred):
18
  min_len = min( len(y_true), len(y_pred) )
 
 
 
 
 
 
 
19
  for i in range(min_len):
 
20
  for j in range(i):
21
  print('\t', end='')
22
- print(str(y_pred[i]))
23
 
24
- draw2( [1.0] * 13, [0.5] * 13 )
 
 
 
 
 
 
 
25
 
 
3
  import seaborn as sns
4
  from sklearn.metrics import confusion_matrix
5
  import matplotlib.pyplot as plt
6
+ from colorama import Fore,Back,Style
7
 
8
 
9
  def draw(y_true, y_pred):
10
+ min_len = min( len(y_true), len(y_pred) )
11
+
12
  sns.set()
13
  f, ax=plt.subplots()
14
  # y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
15
  # y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
16
+ cm = confusion_matrix(y_true, y_pred, labels=[-1] * min_len)
17
  print(cm) #打印出来看看
18
  sns.heatmap(cm, annot=True, ax=ax) #画热力图
19
 
20
  def draw2(y_true, y_pred):
21
  min_len = min( len(y_true), len(y_pred) )
22
+
23
+ print('\t', end='')
24
+ for i in range(min_len):
25
+ y_true_format = str(y_true[i])[:3]
26
+ print('%s\t' % y_true_format, end='')
27
+ print('')
28
+
29
  for i in range(min_len):
30
+ print(Fore.RESET + '%s\t' % str(i + 1), end='')
31
  for j in range(i):
32
  print('\t', end='')
 
33
 
34
+ # print with color
35
+ if y_pred[i] > 0.5:
36
+ print(Fore.GREEN + str(y_pred[i]))
37
+ else:
38
+ print(Fore.RED + str(y_pred[i]))
39
+
40
+ if __name__ == '__main__':
41
+ draw2( [1.0] * 13, [0.5] * 13 )
42
 
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ colorama
test_dreamtalk.py CHANGED
@@ -1,43 +1,58 @@
1
  import requests
2
  import sys
3
  from draw_confusion import draw, draw2
 
4
 
 
 
5
  API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
6
  headers = {"Authorization": "Bearer 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"}
7
 
 
 
8
  # filename = '1.flac'
9
- def query(filename):
10
  with open(filename, "rb") as f:
11
  data = f.read()
12
  response = requests.post(API_URL, headers=headers, data=data)
13
  return response.json()
14
 
15
- # res = query(filename)
16
- # print(str(res))
17
-
18
- # 处理命令行
19
- if __name__ == "__main__":
20
- # 获取命令行参数
21
- if len(sys.argv) < 2:
22
- print("用法:python x.py <文件或通配符>")
23
- sys.exit(1)
24
 
25
- y_len = 13 #len(sys.argv[1:]:)
 
 
 
26
  y_true = [0] * y_len
27
  y_pred = [0] * y_len
28
  y_idx = 0
29
- for input_file in sys.argv[1:]:
30
- res = query(input_file)
31
- print('%s:' % str(input_file))
32
- print('%s' % str(res[:3]))
33
 
34
  first_label = str(res[0]['label'])
35
  first_score = res[0]['score']
36
- print(str(first_label))
37
- print(str(first_score))
38
 
39
- y_true[y_idx] = 1.0
40
  y_pred[y_idx] = round(first_score, 1)
41
  y_idx = y_idx + 1
42
 
43
- draw2(y_true, y_pred)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  import sys
3
  from draw_confusion import draw, draw2
4
+ from tqdm import tqdm
5
 
6
+
7
+ DEBUG = True #False
8
  API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
9
  headers = {"Authorization": "Bearer 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"}
10
 
11
+
12
+ # 处理请求
13
  # filename = '1.flac'
14
+ def request_api(filename):
15
  with open(filename, "rb") as f:
16
  data = f.read()
17
  response = requests.post(API_URL, headers=headers, data=data)
18
  return response.json()
19
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # 批量处理
22
+ def batch_request_api(file_uris):
23
+ if DEBUG: print('batch_request_api')
24
+ y_len = len(file_uris)
25
  y_true = [0] * y_len
26
  y_pred = [0] * y_len
27
  y_idx = 0
28
+ for input_file in tqdm(file_uris):
29
+ res = request_api(input_file)
30
+ # print('%s %s:' % (str(y_idx), str(input_file)) )
31
+ # print('%s' % str(res[:3]))
32
 
33
  first_label = str(res[0]['label'])
34
  first_score = res[0]['score']
35
+ # print(str(first_label))
36
+ # print(str(first_score))
37
 
38
+ y_true[y_idx] = first_label
39
  y_pred[y_idx] = round(first_score, 1)
40
  y_idx = y_idx + 1
41
 
42
+ return y_true, y_pred
43
+
44
+
45
+ # 处理命令行
46
+ if __name__ == "__main__":
47
+ if DEBUG: print('main, ' + str(sys.argv[1:]))
48
+ if DEBUG: print('main, ' + str(len(sys.argv)))
49
+
50
+ # 获取命令行参数
51
+ if len(sys.argv) < 2:
52
+ print("用法:python x.py <文件或通配符>")
53
+ sys.exit(1)
54
+
55
+ if DEBUG: print('main, batch_request_api')
56
+ y_true, y_pred = batch_request_api(sys.argv[1:])
57
+ if DEBUG: print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))
58
+ draw2(y_true, y_pred)