File size: 5,232 Bytes
3f7cfab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import pytest

from tests.utils import wrap_test_forked, get_llama
from enums import DocumentChoices


@wrap_test_forked
def test_cli(monkeypatch):
    """Run a single CLI turn with the ``gptj`` base model.

    ``builtins.input`` is patched so that any call to ``input()`` returns a
    fixed question; the test then checks that exactly one generation comes
    back and that it contains the expected sentence about the Earth.
    """
    question = "What is the Earth?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(base_model='gptj', cli=True, cli_loop=False, score_model='None')
    outputs = main(**cli_kwargs)

    assert len(outputs) == 1
    answer = outputs[0]
    assert "The Earth is a planet in our solar system" in answer


@wrap_test_forked
def test_cli_langchain(monkeypatch):
    """Run a single CLI turn with ``gptj`` in the ``UserData`` langchain mode.

    The user path comes from ``make_user_path_test()`` and ``builtins.input``
    is patched to return a fixed question. The single generation must contain
    the expected image file name and at least one of the accepted phrases.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(
        base_model='gptj',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    outputs = main(**cli_kwargs)

    print(outputs)
    assert len(outputs) == 1
    answer = outputs[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Any one of these phrasings is accepted as a correct answer.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
        "what the cat is doing",
    )
    assert any(phrase in answer for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_langchain_llamacpp(monkeypatch):
    """Run a single CLI turn with the ``llama`` base model in ``UserData`` mode.

    The prompt type is whatever ``get_llama()`` returns and the user path
    comes from ``make_user_path_test()``. ``builtins.input`` is patched to
    return a fixed question. The single generation must contain the expected
    image file name and at least one of the accepted phrases.
    """
    llama_prompt_type = get_llama()

    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        prompt_type=llama_prompt_type,
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    outputs = main(**cli_kwargs)

    print(outputs)
    assert len(outputs) == 1
    answer = outputs[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Any one of these phrasings is accepted as a correct answer.
    accepted_phrases = (
        "The cat is sitting on a window seat and looking out the window",
        "staring out the window at the city skyline",
        "The cat is likely relaxing and enjoying",
        "The cat is sitting on a window seat and looking out",
        "cat in the image is",
    )
    assert any(phrase in answer for phrase in accepted_phrases)


@pytest.mark.need_tokens
@wrap_test_forked
def test_cli_llamacpp(monkeypatch):
    """Run a single CLI turn with the ``llama`` base model, langchain disabled.

    The prompt type is whatever ``get_llama()`` returns. ``builtins.input``
    is patched to return a fixed question. The single generation must contain
    at least one of the accepted self-description phrases.
    """
    llama_prompt_type = get_llama()

    question = "Who are you?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(
        base_model='llama',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='Disabled',
        prompt_type=llama_prompt_type,
        user_path=None,
        visible_langchain_modes=[],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    outputs = main(**cli_kwargs)

    print(outputs)
    assert len(outputs) == 1
    answer = outputs[0]

    # Any one of these phrasings is accepted as a correct answer.
    accepted_phrases = (
        "I'm a software engineer with a passion for building scalable",
        "how can I assist",
        "am a virtual assistant",
    )
    assert any(phrase in answer for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_h2ogpt(monkeypatch):
    """Run a single CLI turn with ``h2oai/h2ogpt-oig-oasst1-512-6_9b``.

    ``builtins.input`` is patched so that any call to ``input()`` returns a
    fixed question; the single generation must contain one of the accepted
    sentences about the Earth.
    """
    question = "What is the Earth?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
    )
    outputs = main(**cli_kwargs)

    assert len(outputs) == 1
    answer = outputs[0]

    # Any one of these phrasings is accepted as a correct answer.
    accepted_phrases = (
        "The Earth is a planet in the Solar System.",
        "The Earth is the third planet",
    )
    assert any(phrase in answer for phrase in accepted_phrases)


@wrap_test_forked
def test_cli_langchain_h2ogpt(monkeypatch):
    """Run a single CLI turn with the h2oGPT model in ``UserData`` mode.

    The user path comes from ``make_user_path_test()`` and ``builtins.input``
    is patched to return a fixed question. The single generation must contain
    the expected image file name and at least one of the accepted phrases.
    """
    from tests.utils import make_user_path_test
    docs_dir = make_user_path_test()

    question = "What is the cat doing?"

    def fake_input(_prompt):
        # Whatever the prompt text is, answer with the canned question.
        return question

    monkeypatch.setattr('builtins.input', fake_input)

    from generate import main
    cli_kwargs = dict(
        base_model='h2oai/h2ogpt-oig-oasst1-512-6_9b',
        cli=True,
        cli_loop=False,
        score_model='None',
        langchain_mode='UserData',
        user_path=docs_dir,
        visible_langchain_modes=['UserData', 'MyData'],
        document_choice=[DocumentChoices.All_Relevant.name],
        verbose=True,
    )
    outputs = main(**cli_kwargs)

    print(outputs)
    assert len(outputs) == 1
    answer = outputs[0]
    assert "pexels-evg-kowalievska-1170986_small.jpg" in answer

    # Any one of these phrasings is accepted as a correct answer.
    accepted_phrases = (
        "looking out the window",
        "staring out the window at the city skyline",
    )
    assert any(phrase in answer for phrase in accepted_phrases)