|
import glob
import os
import sys

import requests
from tqdm import tqdm

from draw_confusion import draw, draw2
|
|
|
|
|
# When True, batch_request_api and the __main__ entry point print debug traces.
DEBUG = True

# Hugging Face Inference API endpoint for the
# MIT/ast-finetuned-audioset-10-10-0.4593 model.
API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"

# The access token is read from the environment instead of being hard-coded:
# a token committed to source control is effectively public.
# NOTE(review): the token previously hard-coded here should be revoked.
# If HF_TOKEN is unset, an empty bearer token is sent, and the server will
# presumably reject the request.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# HTTP headers sent with every inference request.
headers = {"Authorization": "Bearer " + HF_TOKEN}
|
|
|
|
|
|
|
|
|
def request_api(filename, timeout=120):
    """POST the raw bytes of ``filename`` to ``API_URL`` and return the decoded JSON.

    Args:
        filename: Path of the file whose contents are sent as the request body.
        timeout: Seconds to wait for the server, passed through to
            ``requests.post``. Without a timeout, ``requests`` waits
            indefinitely, so a single stalled call would hang the whole batch.

    Returns:
        The parsed JSON body of the response.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data, timeout=timeout)
    # Fail loudly on HTTP errors. Otherwise an error payload would be handed
    # to callers that index the result as a list of predictions.
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
|
def batch_request_api(file_uris):
    """Send each file to the inference API and collect its first prediction.

    Args:
        file_uris: Iterable of file paths; each one is passed to ``request_api``.

    Returns:
        A tuple ``(y_true, y_pred)`` of two lists, aligned with ``file_uris``:
        ``y_true[i]`` is ``str(label)`` and ``y_pred[i]`` is ``score`` rounded
        to one decimal, both taken from the first entry of the response for
        the i-th file (presumably the highest-scoring prediction).
        NOTE(review): despite the names, both values come from the model's
        output; neither is ground truth. Confirm this is what ``draw2``
        expects.

    Raises:
        RuntimeError: if a response is not a non-empty list of predictions.
    """
    if DEBUG:
        print('batch_request_api')

    y_true = []
    y_pred = []
    for input_file in tqdm(file_uris):
        res = request_api(input_file)
        # The code below expects a list of {'label', 'score'} dicts. Anything
        # else (presumably an error object from the service) would otherwise
        # surface as an opaque KeyError/IndexError on res[0].
        if not isinstance(res, list) or not res:
            raise RuntimeError(
                'unexpected API response for %s: %r' % (input_file, res))
        first = res[0]
        y_true.append(str(first['label']))
        y_pred.append(round(first['score'], 1))

    return y_true, y_pred
|
|
|
|
|
|
|
def expand_file_args(args):
    """Expand command-line file arguments, resolving wildcard patterns.

    The usage message advertises wildcards, but a shell that does not glob
    (e.g. Windows cmd) passes patterns like ``*.wav`` through unexpanded.

    - An argument naming an existing path is kept verbatim, so file names
      that contain glob metacharacters such as ``[`` still work.
    - Any other argument is expanded with ``glob.glob``, sorted for a
      stable order.
    - A pattern that matches nothing is kept as-is, so the later ``open``
      reports the missing file.

    Args:
        args: Iterable of path or pattern strings.

    Returns:
        A list of file paths, in argument order.
    """
    files = []
    for arg in args:
        if os.path.exists(arg):
            files.append(arg)
        else:
            files.extend(sorted(glob.glob(arg)) or [arg])
    return files


def main(argv=None):
    """Command-line entry point: classify the given files and plot the results.

    Args:
        argv: Argument vector; ``argv[0]`` is the program name. Defaults to
            ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    if DEBUG:
        print('main, ' + str(argv[1:]))
        print('main, ' + str(len(argv)))

    if len(argv) < 2:
        # Usage: python x.py <file or wildcard>
        print("用法:python x.py <文件或通配符>", file=sys.stderr)
        sys.exit(1)

    files = expand_file_args(argv[1:])

    if DEBUG:
        print('main, batch_request_api')
    y_true, y_pred = batch_request_api(files)
    if DEBUG:
        print('y_true = %s, y_pred = %s' % (str(y_true), str(y_pred)))

    draw2(y_true, y_pred)


if __name__ == "__main__":
    main()
|
|