# purchasing_api / custom_llm.py
# Author: jonathanjordan21
# Commit e944c71: "Create custom_llm.py"
from typing import Any, List, Mapping, Optional
from langchain_core.language_models.llms import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from typing import Literal
import requests
class CustomLLM(LLM):
    """LangChain ``LLM`` that generates text via the Hugging Face Inference API.

    Each call POSTs the prompt to
    ``https://api-inference.huggingface.co/models/{repo_id}``. It returns the
    ``generated_text`` field of the first item in the JSON response.
    """

    # Hugging Face model id, interpolated into the Inference API URL.
    repo_id: str
    # Sent as a Bearer token in the Authorization header.
    api_token: str
    # For "text-generation" the request sets return_full_text=False so the
    # prompt is not echoed back in the output.
    model_type: Literal["text2text-generation", "text-generation"]
    # Generation parameters forwarded in the request's "parameters" object.
    # Fields left at None are sent as JSON null.
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    # Default stop sequences; any call-time `stop` values are appended.
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        """Return the type identifier LangChain reports for this LLM."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Inference API and return the generated text.

        Args:
            prompt: Input text sent as ``inputs``.
            stop: Extra stop sequences for this call only. They are appended
                to ``self.stop``.
            run_manager: Unused; accepted for LangChain interface compatibility.
            **kwargs: Unused.

        Returns:
            The ``generated_text`` of the first item in the API response.

        Raises:
            ValueError: If the response body is not JSON, or is not a list
                whose first element has a ``generated_text`` key (for example,
                an API error payload).
            requests.RequestException: On network or connection failures.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"

        # Honour stop sequences supplied at call time (e.g. invoke(stop=...))
        # in addition to the configured defaults.
        stop_sequences = [*self.stop, *(stop or [])]

        parameters = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "stop": stop_sequences,
        }
        if self.model_type == "text-generation":
            parameters["return_full_text"] = False

        payload = {
            "inputs": prompt,
            "parameters": parameters,
            # Ask the API to wait for a cold model to load rather than fail.
            "options": {"wait_for_model": True},
        }
        response = requests.post(api_url, headers=headers, json=payload)

        try:
            result = response.json()
        except ValueError as err:
            raise ValueError(
                f"Inference API returned a non-JSON response "
                f"(HTTP {response.status_code}): {response.text[:500]!r}"
            ) from err

        try:
            return result[0]["generated_text"]
        except (KeyError, IndexError, TypeError) as err:
            # Typically an error object such as {"error": "..."}.
            raise ValueError(
                f"Unexpected Inference API response "
                f"(HTTP {response.status_code}): {result!r}"
            ) from err

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "repo_id": self.repo_id,
            "model_type": self.model_type,
            "stop_sequences": self.stop,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
        }