# Varco Arena CLI entry point: loads the dataset, estimates the judge-LLM
# evaluation cost, asks for confirmation, then runs the evaluation matches.
import argparse | |
import asyncio | |
import os | |
import platform | |
from typing import Literal | |
from calc_cost import calculate | |
from varco_arena_core.data_utils import load_all_data | |
from varco_arena_core.manager import Manager | |
def _install_fast_loop_policy() -> None:
    """Install the fastest available asyncio event-loop policy for this OS.

    winloop (Windows) and uvloop (Linux) are optional accelerators; when they
    are not installed we fall back to the stock asyncio policies.
    """
    os_name = platform.system()
    if os_name == "Windows":
        try:
            import winloop

            asyncio.set_event_loop_policy(winloop.EventLoopPolicy())
        except ImportError:
            # winloop absent: use the selector-based loop, the safe default
            # for asyncio networking on Windows.
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    elif os_name == "Linux":
        try:
            import uvloop

            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
        except ImportError:
            pass  # optional speedup only; plain asyncio works fine


_install_fast_loop_policy()
def main(
    input_str,
    output_dir,
    evaluation_model,
    matching_method,
    n_jobs: int = 8,
    evalprompt: Literal[
        "llmbar_brief",
        "llmbar",  # general assistant eval
        "translation_pair",  # translation eval
        "rag_pair_kr",  # rag knowledge reflection eval
        # "contextual_pair",  # contextual visual-language instruction eval
    ] = "llmbar",
    calc_price_run=None,
):
    """Estimate evaluation cost, confirm with the user, then run the matches.

    Args:
        input_str: input data specifier, forwarded to ``load_all_data``.
        output_dir: results directory. If a *file* path is given, its parent
            directory is used instead; otherwise the directory is created.
        evaluation_model: judge-LLM specifier (used for cost calc and Manager).
        matching_method: match scheduling method (e.g. "tournament").
        n_jobs: asyncio concurrency level passed to the Manager.
        evalprompt: evaluation prompt preset.
        calc_price_run: if True, print the cost estimate and return without
            running anything. Defaults to the CLI ``--calc_price_run`` flag
            when this module is executed as a script, else False.
            (Fix: the original read the module-global ``args`` here, which
            raised NameError when ``main`` was imported as a library.)
    """
    dataset_df = load_all_data(input_str)

    # Normalize the output directory: fall back to the parent dir when a file
    # path was passed; otherwise make sure the directory exists.
    if os.path.isfile(output_dir):
        _output_dir = os.path.abspath(os.path.dirname(output_dir))
        print(
            f"output directory '{output_dir}' is not a directory. we'll use '{_output_dir}' instead."
        )
        output_dir = _output_dir
    else:
        os.makedirs(output_dir, exist_ok=True)

    # Cost estimation before committing to any API spend.
    total_matches, total_toks_in, total_toks_out, total_costs = calculate(
        dataset_df=dataset_df,
        model_name=evaluation_model,
        matching_method=matching_method,
        evalprompt=evalprompt,
    )
    _doubleline = "=" * 50
    # NOTE(review): the Korean labels below appear mis-encoded in this copy of
    # the file — reproduced byte-for-byte here; verify encoding upstream.
    message = f"""---------------------------------------
Judge LLM: {evaluation_model}
νκ° ν둬ννΈ: {evalprompt}
νκ° λ¦¬κ·Έ λ°©λ²: {matching_method}
μμ νκ° νμ : {total_matches:,}
μμ μ λ ₯ ν ν° : {total_toks_in:,}
μμ μΆλ ₯ ν ν° : {total_toks_out:,} (with x1.01 additional room)
---------------------------------------
μμ λ°μ λΉμ© : ${total_costs:.3f}
{_doubleline}"""
    print(message)

    if calc_price_run is None:
        # Backward compatibility with the script entry point, which exposes
        # the flag via a module-level `args`; library callers (no `args`
        # global) get False instead of a NameError.
        calc_price_run = bool(getattr(globals().get("args"), "calc_price_run", False))
    if calc_price_run:
        return

    # Prompt the user whether to continue with the (possibly costly) run.
    flag = input("[*] Run Varco Arena? (y/n) : ")
    if flag.lower() not in ("y", "yes"):
        print("[-] Varco Arena Stopped")
        return

    manager = Manager(
        dataset_df,
        output_dir,
        evaluation_model,
        matching_method,
        n_jobs=n_jobs,
        evalprompt=evalprompt,
    )
    # asyncio.run replaces the deprecated get_event_loop()/run_until_complete
    # pattern and guarantees loop cleanup.
    asyncio.run(manager.async_run())
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", help="input file")
    # fix: main() treats this as a directory, not a file — help text corrected
    parser.add_argument("-o", "--output_dir", help="output directory")
    parser.add_argument(
        "-e", "--evaluation_model", default="debug", help="evaluation model specifier"
    )
    parser.add_argument(
        "-c",
        "--calc_price_run",
        action="store_true",
        help="print out price calculations",
    )
    parser.add_argument(
        "-m",
        "--matching_method",
        default="tournament",
        choices=["tournament"],  # , "league"],
        help="matching method specifier",
    )
    parser.add_argument(
        "-k",
        "--openai_api_key",
        default=None,
        help='openai key to use / default: OpenAI API Key from your env variable "OPENAI_API_KEY"',
    )
    parser.add_argument(
        "-u",
        "--openai_url",
        default="https://api.openai.com/v1",
        help="OpenAI BASE URL",
    )
    # advanced options
    parser.add_argument(
        "-j",
        "--n_jobs",
        default=32,
        type=int,
        help="number of concurrency for asyncio (passed to async.semaphore @ manager.py)\nIf your job does not proceed, consider lowering this.",
    )
    parser.add_argument(
        "-p",
        "--evalprompt",
        default="llmbar_brief",
        choices=[
            "llmbar_brief",
            "llmbar",
            "translation_pair",
            "rag_pair_kr",
            # "contextual_pair",
        ],
    )
    parser.add_argument(
        "-lr",
        "--limit_requests",
        default=7_680,
        type=int,
        help="limit number of requests per minute when using vLLM Server",
    )
    parser.add_argument(
        "-lt",
        "--limit_tokens",
        default=15_728_640,
        type=int,
        help="limit number of tokens per minute when using vLLM Server",
    )
    args = parser.parse_args()

    # The OpenAI client is configured through environment variables which are
    # read by the imported modules, so set them here before calling main().
    # Base URL: accept bare hosts by defaulting the scheme and the /v1 suffix.
    if not args.openai_url.startswith(("https://", "http://")):
        args.openai_url = "http://" + args.openai_url
    if not args.openai_url.endswith("/v1"):
        args.openai_url += "/v1"
    os.environ["OPENAI_BASE_URL"] = args.openai_url

    # API key: an explicit CLI key overrides the environment; otherwise the
    # OPENAI_API_KEY environment variable must already be set.
    if args.openai_api_key is None:
        if os.getenv("OPENAI_API_KEY") is None:
            raise ValueError(
                "`--openai_api_key` or environment variable `OPENAI_API_KEY` is required"
            )
    else:
        os.environ["OPENAI_API_KEY"] = args.openai_api_key

    # Per-minute rate limits (consumed downstream when using a vLLM server).
    os.environ["LIMIT_REQUESTS"] = str(args.limit_requests)
    os.environ["LIMIT_TOKENS"] = str(args.limit_tokens)

    main(
        args.input,
        args.output_dir,
        args.evaluation_model,
        args.matching_method,
        n_jobs=args.n_jobs,
        evalprompt=args.evalprompt,
    )