"""Replicate completion handler for litellm.

Starts predictions against the Replicate HTTP API and polls them until they
finish, with optional streaming of incremental output.
"""
import os, types
import json
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt


class ReplicateError(Exception):
    """Raised when the Replicate API returns an error or a prediction fails."""

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.replicate.com/v1/deployments"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)


class ReplicateConfig:
    """
    Reference: https://replicate.com/meta/llama-2-70b-chat/api

    - `prompt` (string): The prompt to send to the model.

    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.

    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.

    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.

    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.

    - `top_p` (number): During text decoding, samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.

    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.

    - `stop_sequences` (string): A comma-separated list of sequences at which to stop generation. For example, passing '<end>,<stop>' stops generation at the first occurrence of either '<end>' or '<stop>'.

    - `seed` (integer): The seed for the random generator. Leave it blank to randomize the seed.

    - `debug` (boolean): If set to `True`, provides debugging output in logs.

    Note that Replicate maps these parameters inconsistently across models, so not all of them are available for every model. (A brief usage sketch follows this class definition.)
    """

    system_prompt: Optional[str] = None
    max_new_tokens: Optional[int] = None
    min_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[str] = None
    seed: Optional[int] = None
    debug: Optional[bool] = None

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        seed: Optional[int] = None,
        debug: Optional[bool] = None,
    ) -> None:
        # Any non-None argument is stored on the class itself so that
        # get_config() can later merge it into a request's optional params.
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Return all user-set, non-callable class attributes as a plain dict.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
            and v is not None
        }
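
# Illustrative usage sketch (an assumption for documentation purposes, not executed
# as part of this module): values set through ReplicateConfig become class-level
# defaults that completion() below merges into `optional_params` whenever the caller
# has not overridden them. The model string here is a placeholder, not a real version.
#
#   import litellm
#   litellm.ReplicateConfig(max_new_tokens=256, temperature=0.75)
#   response = litellm.completion(
#       model="replicate/<owner>/<model>:<version-id>",  # placeholder identifier
#       messages=[{"role": "user", "content": "Hello"}],
#   )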


def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
    """Start a Replicate prediction and return the URL used to poll its status."""
    base_url = api_base
    if "deployments" in version_id:
        # Requests to a custom deployment use the deployment-specific base URL.
        print_verbose("\nLiteLLM: Request to custom replicate deployment")
        version_id = version_id.replace("deployments/", "")
        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
        print_verbose(f"Deployment base URL: {base_url}\n")

    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json",
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
    }

    # Log the outgoing request before calling the API.
    logging_obj.pre_call(
        input=input_data["prompt"],
        api_key="",
        additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
    )

    response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        raise ReplicateError(response.status_code, f"Failed to start prediction: {response.text}")


def handle_prediction_response(prediction_url, api_token, print_verbose):
    """Poll a prediction URL until it finishes and return (output, logs)."""
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json",
    }

    status = ""
    logs = ""
    while status not in ["succeeded", "failed", "canceled"]:
        print_verbose(f"replicate: polling endpoint: {prediction_url}")
        time.sleep(0.5)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data["output"])
                print_verbose(f"Non-streamed output: {output_string}")
            status = response_data.get("status", None)
            logs = response_data.get("logs", "")
            if status == "failed":
                replicate_error = response_data.get("error", "")
                raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs: {logs}")
        else:
            # Non-200 responses are logged and the poll is retried.
            print_verbose("Replicate: Failed to fetch prediction status and output.")
    return output_string, logs


def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
    """Poll a prediction URL and yield newly generated output chunks as they appear."""
    previous_output = ""
    output_string = ""

    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json",
    }
    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        time.sleep(0.5)
        print_verbose(f"replicate: polling endpoint: {prediction_url}")
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            status = response_data["status"]
            if "output" in response_data:
                # Yield only the portion of the output we have not seen yet.
                output_string = "".join(response_data["output"])
                new_output = output_string[len(previous_output):]
                print_verbose(f"New chunk: {new_output}")
                yield {"output": new_output, "status": status}
                previous_output = output_string
            if status == "failed":
                replicate_error = response_data.get("error", "")
                raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
        else:
            # Non-200 responses are logged and the poll is retried.
            print_verbose(f"Replicate: Failed to fetch prediction status and output. {response.status_code} {response.text}")


def model_to_version_id(model):
    """Extract the version id from a `model:version` string; return the model as-is otherwise."""
    if ":" in model:
        split_model = model.split(":")
        return split_model[1]
    return model


def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    api_key,
    encoding,
    custom_prompt_dict={},
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    """Make a Replicate completion call: build the prompt, start a prediction, and
    poll it (streaming or non-streaming) to populate model_response."""
    version_id = model_to_version_id(model)
    if optional_params is None:
        optional_params = {}

    # Merge ReplicateConfig defaults into params the caller did not set explicitly.
    config = litellm.ReplicateConfig.get_config()
    for k, v in config.items():
        if k not in optional_params:
            optional_params[k] = v

    system_prompt = None
    if "supports_system_prompt" in optional_params:
        supports_sys_prompt = optional_params.pop("supports_system_prompt")
    else:
        supports_sys_prompt = False

    if supports_sys_prompt:
        # Pull the first system message out of the conversation so it can be sent
        # through Replicate's dedicated `system_prompt` input.
        for i in range(len(messages)):
            if messages[i]["role"] == "system":
                first_sys_message = messages.pop(i)
                system_prompt = first_sys_message["content"]
                break

    if model in custom_prompt_dict:
        # Use a registered custom prompt template for this model.
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details.get("roles", {}),
            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
            bos_token=model_prompt_details.get("bos_token", ""),
            eos_token=model_prompt_details.get("eos_token", ""),
            messages=messages,
        )
    else:
        prompt = prompt_factory(model=model, messages=messages)
    if system_prompt is not None:
        input_data = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            **optional_params,
        }
    # Otherwise, pass the prompt with the remaining optional params as-is.
    else:
        input_data = {
            "prompt": prompt,
            **optional_params,
        }

    # Record the request start time, then kick off the prediction.
    model_response["created"] = int(time.time())
    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
    print_verbose(prediction_url)

    # Handle the prediction response (streaming or non-streaming).
    if "stream" in optional_params and optional_params["stream"] == True:
        print_verbose("streaming request")
        return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
    else:
        result, logs = handle_prediction_response(prediction_url, api_key, print_verbose)
        model_response["ended"] = time.time()

        # Log the completed call.
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=result,
            additional_args={"complete_input_dict": input_data, "logs": logs, "api_base": prediction_url},
        )

        print_verbose(f"raw model_response: {result}")

        if len(result) == 0:  # edge case: Replicate returned an empty output
            result = " "

        model_response["choices"][0]["message"]["content"] = result

        # Calculate token usage from the prompt and the returned content.
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", "")))
        model_response["model"] = "replicate/" + model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        model_response.usage = usage
        return model_response