|
import os |
|
import subprocess |
|
import time |
|
from datetime import datetime |
|
|
|
import pytest |
|
|
|
from client_test import run_client_many |
|
from enums import PromptType |
|
from tests.test_langchain_units import have_openai_key |
|
from tests.utils import wrap_test_forked |
|
|
|
|
|
@pytest.mark.parametrize("base_model",
                         ['h2oai/h2ogpt-oig-oasst1-512-6_9b',
                          'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
                          'llama', 'gptj']
                         )
@pytest.mark.parametrize("force_langchain_evaluate", [False, True])
@pytest.mark.parametrize("do_langchain", [False, True])
@wrap_test_forked
def test_gradio_inference_server(base_model, force_langchain_evaluate, do_langchain,
                                 prompt='Who are you?', stream_output=False, max_new_tokens=256,
                                 langchain_mode='Disabled', user_path=None,
                                 visible_langchain_modes=None,
                                 reverse_docs=True):
    """Chain two h2oGPT gradio servers and query the front one.

    Server 1 is started on port 7860 with ``base_model`` and no ``inference_server``.
    Server 2 is started on port 7861 with the same settings plus
    ``inference_server`` pointing at server 1. The ``client_test`` helpers are then
    run against server 2 and the responses are checked for model-specific phrases.

    :param base_model: model name; selects the prompt type and the expected phrases.
        Unknown names raise ``NotImplementedError``.
    :param force_langchain_evaluate: if True, use ``langchain_mode='MyData'``; also
        forwarded to ``main``.
    :param do_langchain: if True, use ``langchain_mode='UserData'`` with a test
        ``user_path``. Applied after ``force_langchain_evaluate``, so it wins when
        both are set.
    :param visible_langchain_modes: defaults to ``['UserData', 'MyData']`` when None.
    """
    if visible_langchain_modes is None:
        # Build a fresh list per call instead of sharing a mutable default argument.
        visible_langchain_modes = ['UserData', 'MyData']

    if force_langchain_evaluate:
        langchain_mode = 'MyData'
    if do_langchain:
        langchain_mode = 'UserData'
        from tests.utils import make_user_path_test
        user_path = make_user_path_test()

    if base_model in ['h2oai/h2ogpt-oig-oasst1-512-6_9b', 'h2oai/h2ogpt-oasst1-512-12b']:
        prompt_type = PromptType.human_bot.name
    elif base_model in ['h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2']:
        prompt_type = PromptType.prompt_answer.name
    elif base_model in ['llama']:
        prompt_type = PromptType.wizard2.name
    elif base_model in ['gptj']:
        prompt_type = PromptType.gptj.name
    else:
        raise NotImplementedError(base_model)

    main_kwargs = dict(base_model=base_model, prompt_type=prompt_type, chat=True,
                       stream_output=stream_output, gradio=True, num_beams=1, block_gradio_exit=False,
                       max_new_tokens=max_new_tokens,
                       langchain_mode=langchain_mode, user_path=user_path,
                       visible_langchain_modes=visible_langchain_modes,
                       reverse_docs=reverse_docs,
                       force_langchain_evaluate=force_langchain_evaluate)

    # Server 1: generates with base_model itself.
    # The env var is set before the first import of `generate`, as in the original ordering.
    inf_port = os.environ['GRADIO_SERVER_PORT'] = "7860"
    from generate import main
    main(**main_kwargs)

    # Server 2: same settings, but pointed at server 1 through inference_server.
    client_port = os.environ['GRADIO_SERVER_PORT'] = "7861"
    main(**main_kwargs, inference_server='http://127.0.0.1:%s' % inf_port)

    # Talk to server 2 only.
    from client_test import run_client_chat
    os.environ['HOST'] = "http://127.0.0.1:%s" % client_port
    res_dict, _ = run_client_chat(prompt=prompt, prompt_type=prompt_type, stream_output=stream_output,
                                  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
    assert res_dict['prompt'] == prompt
    assert res_dict['iinput'] == ''

    rets = run_client_many(prompt_type=None)
    ret1, ret2, ret3, ret4, ret5, ret6, ret7 = rets

    # Per model: (phrases accepted for ret1 and ret4-ret7, phrases accepted for ret2/ret3).
    # A response passes when it contains at least one of its accepted phrases.
    falcon_who = ('I am a language model trained',
                  'I am an AI language model developed by',
                  'I am a chatbot.',
                  'a chat-based assistant that can answer questions',
                  'I am an AI language model',
                  'I am an AI assistant.')
    expected_phrases = {
        'h2oai/h2ogpt-oig-oasst1-512-6_9b': (('h2oGPT',), ('Birds',)),
        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2': (falcon_who, ('Once upon a time',)),
        'llama': (('I am a bot.', 'can I assist you today?'),
                  ('Birds', 'Once upon a time')),
        'gptj': (('I am a bot.', 'can I assist you today?', 'a student at', 'am a person who',
                  'I am', "I'm a student at"),
                 ('Birds', 'Once upon a time')),
    }
    # Models without an entry (e.g. h2oai/h2ogpt-oasst1-512-12b) get no response checks.
    if base_model in expected_phrases:
        who, story = expected_phrases[base_model]
        # Same check order as ret1..ret7.
        for ret, phrases in zip((ret1, ret2, ret3, ret4, ret5, ret6, ret7),
                                (who, story, story, who, who, who, who)):
            response = ret['response']
            assert any(phrase in response for phrase in phrases), response
    print("DONE", flush=True)
|
|
|
|
|
def run_docker(inf_port, base_model):
    """Start a HuggingFace text-generation-inference (TGI) container with ``docker run``.

    The container uses GPU ``device=0`` and maps host port ``inf_port`` to container
    port 80. Two host paths are mounted into the container:

    - ``~/.cache`` at ``/.cache/``
    - ``~/.cache/huggingface/hub/`` at ``/data``

    The call returns right away; it does not wait for the server to be ready.

    :param inf_port: host port (str or int) mapped to the container's port 80
    :param base_model: model id passed to TGI as ``--model-id``
    :return: pid of the local ``docker run`` client process. This is not a pid
        inside the container; stopping the container may also need ``docker stop``.
    """
    datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
    msg = "Starting HF inference %s..." % datetime_str
    print(msg, flush=True)
    home_dir = os.path.expanduser('~')
    data_dir = '%s/.cache/huggingface/hub/' % home_dir
    cmd = ['docker', 'run',
           '--gpus', 'device=0',
           '--shm-size', '1g',
           # argv list, no shell: quoting the value would put literal '"' characters
           # into the environment variable inside the container.
           '-e', 'TRANSFORMERS_CACHE=/.cache/',
           '-p', '%s:80' % inf_port,
           '-v', '%s/.cache:/.cache/' % home_dir,
           '-v', '%s:/data' % data_dir,
           'ghcr.io/huggingface/text-generation-inference:0.8.2',
           '--model-id', base_model,
           '--max-input-length', '2048',
           '--max-total-tokens', '4096',
           '--max-stop-sequences', '6',
           ]
    print(cmd, flush=True)
    # stdout is inherited (None) and stderr is merged into it, so the docker client's
    # output appears in the test's own output.
    p = subprocess.Popen(cmd, stdout=None, stderr=subprocess.STDOUT)
    print("Done starting HF inference server", flush=True)
    return p.pid
|
|
|
|
|
@pytest.mark.parametrize("base_model",
                         ['h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2']
                         )
@pytest.mark.parametrize("force_langchain_evaluate", [False, True])
@pytest.mark.parametrize("do_langchain", [False, True])
@pytest.mark.parametrize("pass_prompt_type", [False, True, 'custom'])
@pytest.mark.parametrize("do_model_lock", [False, True])
@wrap_test_forked
def test_hf_inference_server(base_model, force_langchain_evaluate, do_langchain, pass_prompt_type, do_model_lock,
                             prompt='Who are you?', stream_output=False, max_new_tokens=256,
                             langchain_mode='Disabled', user_path=None,
                             visible_langchain_modes=None,
                             reverse_docs=True):
    """Query an h2oGPT gradio server (port 7861) that uses a TGI docker container (port 6112).

    The container is started by ``run_docker``. The gradio server reaches it in one
    of two ways:

    - directly through ``inference_server``, or
    - through ``model_lock`` when ``do_model_lock`` is True.

    The container is stopped in ``finally`` however the test ends.

    :param pass_prompt_type: controls the prompt settings.

        - False: ``prompt_type=None``
        - True: prompt type derived from ``base_model``
        - ``'custom'``: ``prompt_type='custom'`` with an explicit, mostly-empty ``prompt_dict``

    :param do_model_lock: pass the model through ``model_lock`` instead of
        ``base_model``/``inference_server``.
    :param visible_langchain_modes: defaults to ``['UserData', 'MyData']`` when None.
    """
    import signal

    if visible_langchain_modes is None:
        # Build a fresh list per call instead of sharing a mutable default argument.
        visible_langchain_modes = ['UserData', 'MyData']

    inf_port = "6112"
    inference_server = 'http://127.0.0.1:%s' % inf_port
    inf_pid = run_docker(inf_port, base_model)
    # Everything after the container launch (including the wait and the setup) runs
    # under try/finally, so the container is torn down even if the test fails early.
    try:
        time.sleep(60)  # fixed wait, presumably for TGI to load the model

        if force_langchain_evaluate:
            langchain_mode = 'MyData'
        if do_langchain:
            langchain_mode = 'UserData'
            from tests.utils import make_user_path_test
            user_path = make_user_path_test()

        if base_model in ['h2oai/h2ogpt-oig-oasst1-512-6_9b', 'h2oai/h2ogpt-oasst1-512-12b']:
            prompt_type = PromptType.human_bot.name
        else:
            prompt_type = PromptType.prompt_answer.name
        if isinstance(pass_prompt_type, str):
            prompt_type = 'custom'
            prompt_dict = """{'promptA': None, 'promptB': None, 'PreInstruct': None, 'PreInput': None, 'PreResponse': None, 'terminate_response': [], 'chat_sep': '', 'chat_turn_sep': '', 'humanstr': None, 'botstr': None, 'generates_leading_space': False}"""
        else:
            prompt_dict = None
            if not pass_prompt_type:
                prompt_type = None

        # Keep base_model intact: it also selects the expected answers below.
        # Only the values handed to main() are cleared when model_lock is used.
        if do_model_lock:
            model_lock = [{'inference_server': inference_server, 'base_model': base_model}]
            main_base_model = None
            main_inference_server = None
        else:
            model_lock = None
            main_base_model = base_model
            main_inference_server = inference_server

        main_kwargs = dict(base_model=main_base_model,
                           prompt_type=prompt_type,
                           prompt_dict=prompt_dict,
                           chat=True,
                           stream_output=stream_output, gradio=True, num_beams=1, block_gradio_exit=False,
                           max_new_tokens=max_new_tokens,
                           langchain_mode=langchain_mode, user_path=user_path,
                           visible_langchain_modes=visible_langchain_modes,
                           reverse_docs=reverse_docs,
                           force_langchain_evaluate=force_langchain_evaluate,
                           inference_server=main_inference_server,
                           model_lock=model_lock)

        client_port = os.environ['GRADIO_SERVER_PORT'] = "7861"
        from generate import main
        main(**main_kwargs)

        from client_test import run_client_chat
        os.environ['HOST'] = "http://127.0.0.1:%s" % client_port
        res_dict, _ = run_client_chat(prompt=prompt, prompt_type=prompt_type,
                                      stream_output=stream_output,
                                      max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
                                      prompt_dict=prompt_dict)
        assert res_dict['prompt'] == prompt
        assert res_dict['iinput'] == ''

        ret1, ret2, ret3, ret4, ret5, ret6, ret7 = run_client_many(prompt_type=None)

        # who: phrases accepted for ret1 and ret4-ret7; story: phrases accepted for ret2/ret3.
        # A response passes when it contains at least one accepted phrase.
        if isinstance(pass_prompt_type, str):
            who, story = ('year old student from the',), ('bird',)
        elif base_model == 'h2oai/h2ogpt-oig-oasst1-512-6_9b':
            who, story = ('h2oGPT',), ('Birds',)
        else:
            who = ('I am a language model trained',
                   'I am an AI language model developed by',
                   'a chat-based assistant',
                   'am a student')
            story = ('Once upon a time',)
        for ret, phrases in zip((ret1, ret2, ret3, ret4, ret5, ret6, ret7),
                                (who, story, story, who, who, who, who)):
            response = ret['response']
            assert any(phrase in response for phrase in phrases), response
        print("DONE", flush=True)
    finally:
        try:
            os.kill(inf_pid, signal.SIGTERM)
            os.kill(inf_pid, signal.SIGKILL)
        except OSError:
            # Best-effort: the docker client process may already have exited.
            pass
        # Killing the docker client may not stop the container itself, so stop TGI
        # containers explicitly.
        # NOTE(review): this stops *every* running text-generation-inference container
        # on the host, not only the one this test started.
        os.system("docker ps | grep text-generation-inference | awk '{print $1}' | xargs docker stop ")
|
|
|
|
|
@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@pytest.mark.parametrize("force_langchain_evaluate", [False, True])
@wrap_test_forked
def test_openai_inference_server(force_langchain_evaluate,
                                 prompt='Who are you?', stream_output=False, max_new_tokens=256,
                                 base_model='gpt-3.5-turbo',
                                 langchain_mode='Disabled', user_path=None,
                                 visible_langchain_modes=None,
                                 reverse_docs=True):
    """Query an h2oGPT gradio server (port 7861) started with ``inference_server='openai_chat'``.

    Skipped when no OpenAI key is available (``have_openai_key``).

    :param force_langchain_evaluate: if True, use ``langchain_mode='MyData'``.
    :param visible_langchain_modes: defaults to ``['UserData', 'MyData']`` when None.
    """
    if visible_langchain_modes is None:
        # Build a fresh list per call instead of sharing a mutable default argument.
        visible_langchain_modes = ['UserData', 'MyData']
    if force_langchain_evaluate:
        langchain_mode = 'MyData'

    # NOTE(review): unlike the sibling tests, force_langchain_evaluate is not forwarded
    # to main() here. Confirm whether that is intentional.
    main_kwargs = dict(base_model=base_model, chat=True,
                       stream_output=stream_output, gradio=True, num_beams=1, block_gradio_exit=False,
                       max_new_tokens=max_new_tokens,
                       langchain_mode=langchain_mode, user_path=user_path,
                       visible_langchain_modes=visible_langchain_modes,
                       reverse_docs=reverse_docs)

    client_port = os.environ['GRADIO_SERVER_PORT'] = "7861"
    from generate import main
    main(**main_kwargs, inference_server='openai_chat')

    from client_test import run_client_chat
    os.environ['HOST'] = "http://127.0.0.1:%s" % client_port
    res_dict, _ = run_client_chat(prompt=prompt, prompt_type='openai_chat', stream_output=stream_output,
                                  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
    assert res_dict['prompt'] == prompt
    assert res_dict['iinput'] == ''

    ret1, ret2, ret3, ret4, ret5, ret6, ret7 = run_client_many(prompt_type=None)
    # who: phrases accepted for ret1 and ret4-ret7; story: phrases accepted for ret2/ret3.
    who = ('I am an AI language model',)
    story = ('Once upon a time',)
    for ret, phrases in zip((ret1, ret2, ret3, ret4, ret5, ret6, ret7),
                            (who, story, story, who, who, who, who)):
        response = ret['response']
        assert any(phrase in response for phrase in phrases), response
    print("DONE", flush=True)
|
|