# Provenance (copied from the hosting site's file view):
# author: chansung · commit 6aaddfa "Update llama2.py" · 2.56 kB
import asyncio
import functools
import json
import os

import requests
import sseclient
from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt
class LLaMA2ChatPromptFmt(PromptFmt):
    """Prompt formatter emitting LLaMA-2 chat style ``<<SYS>>`` / ``[INST]`` markup.

    NOTE(review): ``ctx`` emits the ``<<SYS>>`` block on its own, not nested
    inside the first ``[INST]``. ``prompt`` adds no ``</s><s>`` separator
    between turns. Meta's reference Llama 2 chat template does both; confirm
    the simplified layout is intended.
    """

    @classmethod
    def ctx(cls, context):
        """Return *context* wrapped in ``<<SYS>>`` tags.

        An empty string is returned when *context* is ``None`` or ``""``.
        """
        if context is None or context == "":
            return ""
        # Same text as "<<SYS>>", newline, context, newline, "<</SYS>>", newline.
        return f"<<SYS>>\n{context}\n<</SYS>>\n"

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Render one exchange as ``[INST] <user> [/INST] <reply>``.

        Both messages are sliced with ``[:truncate_size]`` (``None`` keeps the
        full text). A missing reply (``pong is None``) renders as an empty
        string, so the result then ends with ``"[/INST] "``.
        """
        user_text = pingpong.ping[:truncate_size]
        if pingpong.pong is None:
            reply_text = ""
        else:
            reply_text = pingpong.pong[:truncate_size]
        return f"[INST] {user_text} [/INST] {reply_text}"
class LLaMA2ChatPPManager(PPManager):
    """Conversation manager that renders its exchanges as a LLaMA-2 chat prompt."""

    def build_prompts(
        self,
        from_idx: int = 0,
        to_idx: int = -1,
        fmt: PromptFmt = LLaMA2ChatPromptFmt,
        truncate_size: int = None,
    ):
        """Render the context plus a slice of ``self.pingpongs`` into one string.

        Args:
            from_idx: Index of the first exchange to include.
            to_idx: Exclusive end index. ``-1``, or any value at or past the
                number of stored exchanges, means "through the last exchange".
            fmt: Formatter whose ``ctx`` and ``prompt`` class methods produce
                the text.
            truncate_size: Forwarded to ``fmt.prompt``. With the default
                formatter, each message is sliced to this many characters, and
                ``None`` keeps the full text.

        Returns:
            ``fmt.ctx(self.ctx)`` followed by ``fmt.prompt(...)`` for every
            selected exchange, concatenated with no separator.
        """
        # -1 is a sentinel for "to the end". Using it directly as a slice bound
        # would silently drop the most recent exchange.
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # The context block always leads, regardless of from_idx.
        header = fmt.ctx(self.ctx)
        turns = "".join(
            fmt.prompt(pingpong, truncate_size=truncate_size)
            for pingpong in self.pingpongs[from_idx:to_idx]
        )
        return header + turns
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA-2 chat manager that can also render its history for a Gradio chat UI."""

    def build_uis(self, from_idx: int = 0, to_idx: int = -1, fmt: UIFmt = GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for each exchange in the selected slice.

        ``from_idx`` is the first exchange to include. ``to_idx`` is the
        exclusive end. ``-1``, or any value at or past the number of stored
        exchanges, means "through the last exchange".
        """
        total = len(self.pingpongs)
        end = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:end]]
async def gen_text(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
    parameters=None,
    timeout=None,
):
    """Stream generated text for *prompt* from the Hugging Face Inference API.

    This is an async generator. It POSTs a streaming request and yields, for
    each server-sent event, the ``token.text`` field of its JSON payload.

    Args:
        prompt: Fully formatted prompt string, sent as ``inputs``.
        hf_model: Model repository id appended to the Inference API URL.
        hf_token: Hugging Face access token, sent as a Bearer token. Required.
        parameters: Generation parameters sent verbatim. When ``None``, a
            sampling preset is used: max_new_tokens=512, do_sample=True,
            return_full_text=False, temperature=1.0, top_k=50,
            repetition_penalty=1.2.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, matching the previous behavior.

    Yields:
        str: the text of each streamed token.

    Raises:
        ValueError: if ``hf_token`` is ``None``.
        requests.HTTPError: if the API replies with a non-2xx status.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': True,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    loop = asyncio.get_running_loop()
    # requests is blocking. Run the POST and every SSE read in the default
    # executor so the event loop can serve other tasks while we wait on I/O.
    r = await loop.run_in_executor(
        None,
        functools.partial(
            requests.post,
            url,
            headers=headers,
            data=json.dumps(data),
            stream=True,
            timeout=timeout,
        ),
    )
    try:
        # An error reply is not an SSE stream. Previously it produced an empty
        # stream silently; surface it to the caller instead.
        r.raise_for_status()
        events = sseclient.SSEClient(r).events()
        # Sentinel default for next(): a StopIteration cannot be propagated
        # out of an executor future, so end-of-stream is signalled this way.
        exhausted = object()
        while True:
            event = await loop.run_in_executor(None, next, events, exhausted)
            if event is exhausted:
                break
            yield json.loads(event.data)['token']['text']
    finally:
        # stream=True holds the connection open until the body is consumed.
        # Release it even if the consumer stops iterating early.
        r.close()