|
from typing import Callable |
|
|
|
import hypothesis |
|
import pytest |
|
from hypothesis import strategies as st |
|
|
|
import tiktoken |
|
|
|
from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES |
|
|
|
|
|
def _common_prefix_len(a, b): |
|
i = 0 |
|
while i < len(a) and i < len(b) and a[i] == b[i]: |
|
i += 1 |
|
return i |
|
|
|
|
|
def _token_offsets_reference(enc, tokens):
    """Compute one offset per token by brute force, for comparison in tests.

    The full token list is decoded once with ``errors="strict"``. Then, for
    every index ``end`` in ``range(len(tokens))``, the prefix ``tokens[:end]``
    is decoded with ``errors="ignore"`` and the offset recorded for that index
    is the length of the common prefix between the full text and that
    partial decode.

    Args:
        enc: Object exposing ``decode(tokens, errors=...)``.
        tokens: Sequence of token ids; must support slicing and ``len()``.

    Returns:
        A list with one integer offset per token.
    """
    full_text = enc.decode(tokens, errors="strict")
    return [
        _common_prefix_len(full_text, enc.decode(tokens[:end], errors="ignore"))
        for end in range(len(tokens))
    ]
|
|
|
|
|
@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
@hypothesis.given(data=st.data())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
    """Property test: ``decode_with_offsets`` agrees with the reference offsets.

    Draws a list of 1 to 20 token ids that exist in the encoding, normalises
    them through a decode/encode round trip, and checks that the offsets
    returned by ``enc.decode_with_offsets`` equal those computed by
    ``_token_offsets_reference``.

    Args:
        make_enc: Zero-argument factory returning the encoding under test.
        data: Hypothesis data object used to draw the token list interactively.
    """
    enc = make_enc()

    # Collect the known token ids once. Testing membership directly against
    # the two dict ``.values()`` views would linearly scan the vocabulary for
    # every candidate integer the filter sees.
    valid_tokens = set(enc._special_tokens.values())
    valid_tokens.update(enc._mergeable_ranks.values())

    tokens_st = st.lists(
        st.integers(0, enc.n_vocab - 1).filter(lambda x: x in valid_tokens),
        min_size=1,
        max_size=20,
    )
    tokens = data.draw(tokens_st)

    # Round-trip through a lenient decode and a re-encode. Presumably this
    # ensures the final token sequence decodes cleanly, since the reference
    # implementation decodes the full sequence with errors="strict".
    tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
    assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)
|
|
|
|
|
def test_basic_offsets():
    """Check ``decode_with_offsets`` on fixed prompts with the cl100k_base encoding.

    Each case encodes a prompt, decodes the tokens back with offsets, and
    asserts that the decoded text equals the prompt and that the offsets
    match the hand-written expectation.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    # Each entry: (prompt, extra keyword arguments for enc.encode, expected offsets).
    cases = [
        ("hello world", {}, [0, 5]),
        (
            "hello world<|endoftext|> green cow",
            # The prompt contains a special token, so special tokens are allowed.
            {"allowed_special": "all"},
            [0, 5, 11, 24, 30],
        ),
        # In the non-ASCII cases below the expected offsets contain repeated
        # values, i.e. consecutive tokens are expected to share an offset.
        (
            "我非常渴望与人工智能一起工作",
            {},
            [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13],
        ),
        (
            "நடிகர் சூர்யா",
            {},
            [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12],
        ),
        (" Ġ除", {}, [0, 1]),
    ]

    for prompt, encode_kwargs, expected_offsets in cases:
        decoded, offsets = enc.decode_with_offsets(enc.encode(prompt, **encode_kwargs))
        assert decoded == prompt
        assert offsets == expected_offsets
|
|