File size: 5,232 Bytes
3f7cfab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import pytest
from tests.utils import wrap_test_forked, get_llama
from enums import DocumentChoices
@wrap_test_forked
def test_cli(monkeypatch):
    """Run one CLI generation with the ``gptj`` base model and check the answer.

    ``builtins.input`` is patched so that any call to ``input()`` returns a
    fixed question. The test expects ``main`` to hand back exactly one
    generation, and that generation must contain the expected Earth sentence.
    """
    question = "What is the Earth?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main

    cli_options = dict(base_model='gptj', cli=True, cli_loop=False, score_model='None')
    responses = main(**cli_options)

    assert len(responses) == 1
    answer = responses[0]
    assert "The Earth is a planet in our solar system" in answer
@wrap_test_forked
def test_cli_langchain(monkeypatch):
    """Run one CLI generation in ``UserData`` langchain mode with ``gptj``.

    The user path comes from ``make_user_path_test()`` and the question is
    injected by patching ``builtins.input``. The single returned generation
    must contain the cat image file name and at least one of the accepted
    descriptions of what the cat is doing.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main
    responses = main(
        base_model='gptj',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    print(responses)

    assert len(responses) == 1
    answer = responses[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Model wording varies; any one of these phrasings counts as a pass.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
    )
    assert any(phrase in answer for phrase in accepted_phrases)
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Run one CLI generation in ``UserData`` langchain mode with ``llama``.

    ``get_llama()`` supplies the prompt type passed to ``main``; the user path
    comes from ``make_user_path_test()``. The question is injected by patching
    ``builtins.input``. The single returned generation must contain the cat
    image file name and at least one of the accepted phrasings.
    """
    llama_prompt_type = get_llama()

    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main
    responses = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        prompt_type=llama_prompt_type,
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    print(responses)

    assert len(responses) == 1
    answer = responses[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Model wording varies; any one of these phrasings counts as a pass.
    # NOTE(review): the first entry is already covered by the fourth (which is
    # a prefix of it); both are kept so the accepted set is unchanged.
    accepted_phrases = (
        "The cat is sitting on a window seat and looking out the window",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "The cat is sitting on a window seat and looking out",
        "cat in the image is",
    )
    assert any(phrase in answer for phrase in accepted_phrases)
@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Run one CLI generation with ``llama`` and langchain mode ``Disabled``.

    ``get_llama()`` supplies the prompt type passed to ``main``; no user path
    and no visible langchain modes are given. The question is injected by
    patching ``builtins.input``. The single returned generation must contain
    at least one of the accepted self-descriptions.
    """
    llama_prompt_type = get_llama()

    question = "Who are you?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main
    responses = main(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='Disabled',
        prompt_type=llama_prompt_type,
        user_path=None,
        visible_langchain_modes=[],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    print(responses)

    assert len(responses) == 1
    answer = responses[0]

    # Model wording varies; any one of these phrasings counts as a pass.
    accepted_phrases = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
    )
    assert any(phrase in answer for phrase in accepted_phrases)
@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Run one CLI generation with ``h2oai/h2ogpt-oig-oasst1-512-6_9b``.

    ``builtins.input`` is patched so that any call to ``input()`` returns a
    fixed question. The single returned generation must contain one of the
    two accepted sentences about the Earth.
    """
    question = "What is the Earth?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main

    cli_options = dict(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
    )
    responses = main(**cli_options)

    assert len(responses) == 1
    answer = responses[0]

    # Model wording varies; either phrasing counts as a pass.
    accepted_phrases = (
        "The Earth is a planet in the Solar System.",
        "The Earth is the third planet",
    )
    assert any(phrase in answer for phrase in accepted_phrases)
@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Run one CLI generation in ``UserData`` langchain mode with h2oGPT.

    Uses base model ``h2oai/h2ogpt-oig-oasst1-512-6_9b``. The user path comes
    from ``make_user_path_test()`` and the question is injected by patching
    ``builtins.input``. The single returned generation must contain the cat
    image file name and one of the two accepted window phrasings.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"
    # The patched input ignores its prompt argument and always yields the question.
    monkeypatch.setattr('builtins.input', lambda _prompt: question)

    # Imported only after input() has been patched, as in the original ordering.
    from generate import main
    responses = main(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    print(responses)

    assert len(responses) == 1
    answer = responses[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Model wording varies; either phrasing counts as a pass.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
    )
    assert any(phrase in answer for phrase in accepted_phrases)
|