import subprocess
import sys

# Install vLLM from its prebuilt CUDA 11.8 / Python 3.10 wheel when the
# inference container starts, before vllm is imported below.
subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.6.1.post1/vllm-0.6.1.post1+cu118-cp310-cp310-manylinux1_x86_64.whl"
])

import json
import logging
import os

from vllm import LLM, SamplingParams

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def model_fn(model_dir, context=None):
    """Load the model with vLLM when the endpoint container starts."""
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=4,  # shard across 4 GPUs; must match the instance's GPU count
    )
    return model


def predict_fn(data, model, context=None):
    """Generate a completion for a single request."""
    try:
        input_text = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Wrap the user prompt in the ChatML template the model expects.
        chat_template = (
            "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
            f"<|im_start|>user\n{input_text}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        # vLLM samples whenever temperature > 0, so there is no do_sample flag;
        # SamplingParams uses max_tokens and stop rather than the
        # transformers-style max_new_tokens and stop_tokens names.
        sampling_params = SamplingParams(
            temperature=parameters.get("temperature", 0.7),
            top_p=parameters.get("top_p", 0.9),
            max_tokens=parameters.get("max_new_tokens", 512),
            stop=["<|im_end|>", "<|im_start|>"],
        )

        outputs = model.generate(chat_template, sampling_params)
        generated_text = outputs[0].outputs[0].text

        # Strip any stop token left at the end of the completion.
        for stop_token in ["<|im_end|>", "<|im_start|>"]:
            if generated_text.endswith(stop_token):
                generated_text = generated_text[:-len(stop_token)].strip()

        return {"generated_text": generated_text}
    except Exception as e:
        logger.error(f"Exception during prediction: {e}")
        return {"error": str(e)}
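

# Illustrative request/response shapes for this handler (example values,
# not taken from a real deployment):
#   request:  {"inputs": "Hello", "parameters": {"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 256}}
#   response: {"generated_text": "..."}, or {"error": "..."} on failure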
def input_fn(request_body, request_content_type, context=None):
    """Deserialize the incoming request body."""
    if request_content_type == "application/json":
        return json.loads(request_body)
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")


def output_fn(prediction, accept, context=None):
    """Serialize the prediction for the response."""
    return json.dumps(prediction)
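

# --- Local smoke test (not used by SageMaker) --------------------------------
# A minimal sketch of how the handler functions chain together, assuming a GPU
# host with the wheel above installed and model weights under /opt/ml/model
# (the path SageMaker mounts the model artifact to). The hosting container
# calls model_fn/input_fn/predict_fn/output_fn itself, so this block is only
# for ad-hoc local debugging.
if __name__ == "__main__":
    model = model_fn("/opt/ml/model")
    body = json.dumps({"inputs": "What is vLLM?", "parameters": {"max_new_tokens": 64}})
    data = input_fn(body, "application/json")
    prediction = predict_fn(data, model)
    print(output_fn(prediction, "application/json"))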