Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
import json | |
import sys | |
# pip install websocket-client | |
import websocket | |
class ModelClient(object):
    """Small client for a JSON-over-WebSocket text generation endpoint.

    A session is opened with an ``open_inference_session`` message and then
    used for one or more ``generate`` requests whose output is streamed back
    as a sequence of JSON messages.
    """

    def __init__(self, endpoint_url):
        """Remember the endpoint URL; no connection is made until open_session()."""
        self.endpoint_url = endpoint_url
        self.ws = None
        self.model = None

    def open_session(self, model, max_length):
        """Connect to the endpoint and open an inference session.

        Args:
            model: Model identifier sent as the ``model`` field.
            max_length: Value sent as the ``max_length`` field.

        Raises:
            RuntimeError: If the server does not acknowledge the session
                with a truthy ``ok`` field.
        """
        # Do not leak a previously opened socket when re-opening.
        self.close_session()
        self.ws = websocket.create_connection(self.endpoint_url, enable_multithread=True)
        self.model = model
        payload = {
            "type": "open_inference_session",
            "model": self.model,
            "max_length": max_length,
        }
        try:
            self.ws.send(json.dumps(payload))
            reply = json.loads(self.ws.recv())
            # Explicit check instead of `assert`, which is stripped under -O.
            if not reply.get("ok"):
                raise RuntimeError(
                    "Failed to open inference session: %s"
                    % reply.get("traceback", reply)
                )
        except Exception:
            # A failed handshake must not leave a half-open session behind.
            self.close_session()
            raise

    def is_session(self):
        """Return True while a WebSocket connection is held."""
        return self.ws is not None

    def close_session(self):
        """Close the WebSocket connection if one is open; otherwise do nothing."""
        if self.ws is not None:
            try:
                self.ws.close()
            finally:
                # Drop the reference even if close() itself fails.
                self.ws = None

    def generate(self, prompt, **kwargs):
        """Stream generated text for ``prompt``, closing the session on error.

        Extra keyword arguments override the default request fields built by
        ``_generate``. Returns a generator yielding each ``outputs`` value
        received from the server.
        """
        try:
            # `yield from` keeps the iteration inside this try block; merely
            # returning the inner generator would leave the handler dead,
            # because generator bodies only run once they are iterated.
            yield from self._generate(prompt, **kwargs)
        except Exception:
            self.close_session()
            raise

    def _generate(self, prompt, **kwargs):
        """Send one ``generate`` request and yield outputs until ``stop`` is set.

        Raises:
            RuntimeError: If no session is open.
            Exception: Carrying the server-provided traceback when a response
                has a falsy ``ok`` field.
        """
        if self.ws is None:
            raise RuntimeError("No open session; call open_session() first")
        payload = {
            "type": "generate",
            "inputs": prompt,
            "max_new_tokens": 1,
            "do_sample": 0,
            "temperature": 1,
            # Models with "bloomz" in their name stop on "</s>"; others on a blank line.
            "stop_sequence": "</s>" if "bloomz" in (self.model or "") else "\n\n",
        }
        # Caller-supplied options take precedence over the defaults above.
        payload = {**payload, **kwargs}
        self.ws.send(json.dumps(payload))
        while True:
            data = json.loads(self.ws.recv())
            if not data['ok']:
                raise Exception(data.get('traceback', 'generation failed'))
            yield data['outputs']
            if data['stop']:
                break
def main():
    """Open a session, stream a completion for the prompt to stdout, then close.

    The prompt is taken from the first command-line argument when given;
    otherwise a built-in SQL example prompt is used.
    """
    # For a locally hosted server use "ws://localhost:8000/api/v2/generate".
    client = ModelClient("wss://chat.petals.dev/api/v2/generate")
    try:
        client.open_session("stabilityai/StableBeluga2", 128)

        if len(sys.argv) > 1:
            prompt = sys.argv[1]
            # Bloomz variant uses </s> instead of \n\n as an eos token
            if not prompt.endswith("\n\n"):
                prompt += "\n\n"
        else:
            prompt = "The SQL command to extract all the users whose name starts with A is: \n\n"
        print(f"Prompt: {prompt}")

        # NOTE(review): the original author flagged
        # petals.client.routing.sequence_manager.MissingBlocksError at this
        # point, presumably a failure that can surface during generation.
        for out in client.generate(prompt,
                                   do_sample=True,
                                   temperature=0.75,
                                   top_p=0.9):
            print(out, end="", flush=True)
    finally:
        # Release the connection even when opening or streaming fails.
        client.close_session()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()