winglian commited on
Commit
35af017
·
unverified ·
2 Parent(s): cb18856 a653392

Merge pull request #87 from OpenAccess-AI-Collective/add_prompter_tests

Browse files
.github/workflows/tests.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: PyTest
2
+ on: push
3
+
4
+ jobs:
5
+ test:
6
+ runs-on: ubuntu-latest
7
+ timeout-minutes: 10
8
+
9
+ steps:
10
+ - name: Check out repository code
11
+ uses: actions/checkout@v2
12
+
13
+ - name: Setup Python
14
+ uses: actions/setup-python@v2
15
+ with:
16
+ python-version: "3.9"
17
+ cache: 'pip' # caching pip dependencies
18
+
19
+ - name: Install dependencies
20
+ run: |
21
+ pip install -e .
22
+ pip install -r requirements-tests.txt
23
+
24
+ - name: Run tests
25
+ run: |
26
+ pytest tests/
requirements-tests.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pytest
tests/test_prompters.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
4
+
5
+
6
+ class AlpacaPrompterTest(unittest.TestCase):
7
+ def test_prompt_style_w_none(self):
8
+ prompter = AlpacaPrompter(prompt_style=None)
9
+ res = next(prompter.build_prompt("tell me a joke"))
10
+ # just testing that it uses instruct style
11
+ assert "### Instruction:" in res
12
+
13
+ def test_prompt_style_w_instruct(self):
14
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
15
+ res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
16
+ assert "Below is an instruction" in res
17
+ assert "### Instruction:" in res
18
+ assert "### Input:" in res
19
+ assert "alpacas" in res
20
+ assert "### Response:" in res
21
+ assert "USER:" not in res
22
+ assert "ASSISTANT:" not in res
23
+ res = next(prompter.build_prompt("tell me a joke about the following"))
24
+ assert "Below is an instruction" in res
25
+ assert "### Instruction:" in res
26
+ assert "### Input:" not in res
27
+ assert "### Response:" in res
28
+ assert "USER:" not in res
29
+ assert "ASSISTANT:" not in res
30
+
31
+ def test_prompt_style_w_chat(self):
32
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
33
+ res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
34
+ assert "Below is an instruction" in res
35
+ assert "### Instruction:" not in res
36
+ assert "### Input:" not in res
37
+ assert "alpacas" in res
38
+ assert "### Response:" not in res
39
+ assert "USER:" in res
40
+ assert "ASSISTANT:" in res
41
+ res = next(prompter.build_prompt("tell me a joke about the following"))
42
+ assert "Below is an instruction" in res
43
+ assert "### Instruction:" not in res
44
+ assert "### Input:" not in res
45
+ assert "### Response:" not in res
46
+ assert "USER:" in res
47
+ assert "ASSISTANT:" in res
48
+
49
+