Jae-Won Chung commited on
Commit
94f739f
·
1 Parent(s): 6bdaf0a

Make print_metrics.py recursive

Browse files
Files changed (1) hide show
  1. scripts/print_results.py +27 -11
scripts/print_results.py CHANGED
@@ -1,21 +1,37 @@
1
  import os
2
  import json
3
- from contextlib import suppress
4
 
5
  import tyro
6
 
7
 
8
- def main(data_dir: str) -> None:
9
- """Summarize the results collected for all models in the given directory."""
10
- model_names = os.listdir(data_dir)
11
- print(len(model_names), "models found")
 
 
 
 
 
 
 
12
 
13
- for i, model_name in enumerate(model_names):
14
- try:
15
- benchmark = json.load(open(f"{data_dir}/{model_name}/benchmark.json"))
16
- print(f"{i:2d} {len(benchmark):5d} results found for", model_name)
17
- except json.JSONDecodeError:
18
- print(f"{i:2d} [ERR] results found for {model_name}")
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  if __name__ == "__main__":
 
1
  import os
2
  import json
 
3
 
4
  import tyro
5
 
6
 
7
+ def main(data_dir: str, depth: int = 1) -> None:
8
+ """Summarize the results collected for all models in the given directory.
9
+
10
+ Args:
11
+ data_dir: The directory containing the results.
12
+ depth: The depth of the directory tree to search. When it's 1, the
13
+ script expects to fine model directories directly under `data_dir`.
14
+ (Default: 1)
15
+ """
16
+ if depth < 1:
17
+ raise ValueError("depth must be >= 1")
18
 
19
+ if depth == 1:
20
+ model_names = os.listdir(data_dir)
21
+ print(len(model_names), "models found in", data_dir)
22
+
23
+ for i, model_name in enumerate(model_names):
24
+ if not os.path.isdir(f"{data_dir}/{model_name}"):
25
+ continue
26
+ try:
27
+ benchmark = json.load(open(f"{data_dir}/{model_name}/benchmark.json"))
28
+ print(f"{i:2d} {len(benchmark):5d} results found for", model_name)
29
+ except json.JSONDecodeError:
30
+ print(f"{i:2d} [ERR] results found for {model_name}")
31
+
32
+ else:
33
+ for dir in os.listdir(data_dir):
34
+ main(f"{data_dir}/{dir}", depth - 1)
35
 
36
 
37
  if __name__ == "__main__":