|
import ast |
|
import copy |
|
import functools |
|
import inspect |
|
import queue |
|
import sys |
|
import os |
|
import time |
|
import traceback |
|
import typing |
|
import warnings |
|
from datetime import datetime |
|
import requests |
|
from requests import ConnectTimeout, JSONDecodeError |
|
from urllib3.exceptions import ConnectTimeoutError, MaxRetryError, ConnectionError |
|
from requests.exceptions import ConnectionError as ConnectionError2 |
|
from requests.exceptions import ReadTimeout as ReadTimeout2 |
|
|
|
# Ensure the directory containing this file is on the import path so that the
# modules sitting next to it (imported further below) resolve no matter what
# the current working directory is. Appended, so existing entries win.
_this_dir = os.path.dirname(os.path.abspath(__file__))
if _this_dir not in sys.path:
    sys.path.append(_this_dir)
del _this_dir
|
|
|
# Opt out of Hugging Face Hub telemetry and the bitsandbytes welcome banner.
# These are overwritten unconditionally (unlike the thread settings below).
os.environ.update({
    'HF_HUB_DISABLE_TELEMETRY': '1',
    'BITSANDBYTES_NOWELCOME': '1',
})

# Suppress the "TypedStorage is deprecated" UserWarning (presumably emitted by
# torch internals) so it does not clutter the console.
warnings.filterwarnings(action='ignore', message='TypedStorage is deprecated', category=UserWarning)
|
|
|
|
|
# Default thread-pool sizes for numeric / data libraries, applied before numpy
# and friends are imported below (presumably because those libraries read the
# variables at load time). Any value the user already exported is respected;
# otherwise each library gets half the logical CPUs, capped per library.
#
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 in that case instead of crashing with a TypeError at import time.
max_cores = max(1, (os.cpu_count() or 1) // 2)

for _env_name, _cap in (
        ('NUMEXPR_MAX_THREADS', 8),
        ('NUMEXPR_NUM_THREADS', 8),
        ('OMP_NUM_THREADS', 8),
        ('OPENBLAS_NUM_THREADS', 8),
        ('DUCKDB_NUM_THREADS', 4),
        ('RAYON_RS_NUM_CPUS', 8),
        ('RAYON_NUM_THREADS', 8),
):
    # Only fill in unset variables; an explicitly exported value (even '') wins.
    if os.getenv(_env_name) is None:
        os.environ[_env_name] = str(min(_cap, max_cores))
del _env_name, _cap
|
|
|
import numpy as np |
|
from evaluate_params import eval_func_param_names, no_default_param_names, input_args_list |
|
from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, \ |
|
LangChainAction, LangChainAgent, DocumentChoice, LangChainTypes, super_source_prefix, \ |
|
super_source_postfix, t5_type, get_langchain_prompts, gr_to_lg, invalid_key_msg |
|
from loaders import get_loaders |
|
from utils import set_seed, clear_torch_cache, NullContext, wrapped_partial, EThread, get_githash, \ |
|
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, \ |
|
have_langchain, set_openai, cuda_vis_check, H2O_Fire, lg_to_gr, str_to_list, str_to_dict, get_token_count |
|
|
|
# Fixed seed handed to utils.set_seed below (presumably seeds the RNGs so that
# runs are reproducible -- see utils.set_seed).
SEED = 1236

# One-time process setup performed at import, in this order: fault handler,
# matplotlib configuration, then seeding.
start_faulthandler()
import_matplotlib()
set_seed(SEED)
|
|
|
from typing import Union |
|
|
|
import torch |
|
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer |
|
|
|
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt |
|
from stopping import get_stopping |
|
|
|
# String values of every LangChainAction / LangChainAgent member, in enum
# definition order. An Enum class is directly iterable, so no list() wrapper
# is needed. langchain_agents_list also seeds the default of
# main(visible_langchain_agents=...).
langchain_actions = [action.value for action in LangChainAction]

langchain_agents_list = [agent.value for agent in LangChainAgent]
|
|
|
|
|
def main( |
|
load_8bit: bool = False, |
|
load_4bit: bool = False, |
|
low_bit_mode: int = 1, |
|
load_half: bool = None, |
|
load_gptq: str = '', |
|
load_exllama: bool = False, |
|
use_safetensors: bool = False, |
|
revision: str = None, |
|
use_gpu_id: bool = True, |
|
base_model: str = '', |
|
tokenizer_base_model: str = '', |
|
lora_weights: str = "", |
|
gpu_id: int = 0, |
|
compile_model: bool = None, |
|
use_cache: bool = None, |
|
inference_server: str = "", |
|
prompt_type: Union[int, str] = None, |
|
prompt_dict: typing.Dict = None, |
|
system_prompt: str = '', |
|
|
|
|
|
llamacpp_dict: typing.Dict = dict(n_gpu_layers=100, use_mlock=True, n_batch=1024, n_gqa=0), |
|
model_path_llama: str = 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q8_0.bin', |
|
|
|
model_name_gptj: str = 'ggml-gpt4all-j-v1.3-groovy.bin', |
|
model_name_gpt4all_llama: str = 'ggml-wizardLM-7B.q4_2.bin', |
|
model_name_exllama_if_no_config: str = 'TheBloke/Nous-Hermes-Llama2-GPTQ', |
|
|
|
model_lock: typing.List[typing.Dict[str, str]] = None, |
|
model_lock_columns: int = None, |
|
fail_if_cannot_connect: bool = False, |
|
|
|
|
|
temperature: float = None, |
|
top_p: float = None, |
|
top_k: int = None, |
|
num_beams: int = None, |
|
repetition_penalty: float = None, |
|
num_return_sequences: int = None, |
|
do_sample: bool = None, |
|
max_new_tokens: int = None, |
|
min_new_tokens: int = None, |
|
early_stopping: Union[bool, str] = None, |
|
max_time: float = None, |
|
|
|
memory_restriction_level: int = None, |
|
debug: bool = False, |
|
save_dir: str = None, |
|
share: bool = False, |
|
local_files_only: bool = False, |
|
resume_download: bool = True, |
|
use_auth_token: Union[str, bool] = False, |
|
trust_remote_code: Union[str, bool] = True, |
|
rope_scaling: dict = None, |
|
max_seq_len: int = None, |
|
offload_folder: str = "offline_folder", |
|
|
|
src_lang: str = "English", |
|
tgt_lang: str = "Russian", |
|
|
|
prepare_offline_level: int = 0, |
|
cli: bool = False, |
|
cli_loop: bool = True, |
|
gradio: bool = True, |
|
gradio_offline_level: int = 0, |
|
server_name: str = "0.0.0.0", |
|
root_path: str = "", |
|
chat: bool = True, |
|
chat_conversation: typing.List[typing.Tuple[str, str]] = None, |
|
text_context_list: typing.List[str] = None, |
|
stream_output: bool = True, |
|
async_output: bool = True, |
|
num_async: int = 3, |
|
show_examples: bool = None, |
|
verbose: bool = False, |
|
h2ocolors: bool = True, |
|
dark: bool = False, |
|
height: int = 600, |
|
show_lora: bool = True, |
|
show_llama: bool = True, |
|
show_gpt4all: bool = False, |
|
login_mode_if_model0: bool = False, |
|
block_gradio_exit: bool = True, |
|
concurrency_count: int = 1, |
|
api_open: bool = False, |
|
allow_api: bool = True, |
|
input_lines: int = 1, |
|
gradio_size: str = None, |
|
show_copy_button: bool = True, |
|
large_file_count_mode: bool = False, |
|
pre_load_embedding_model: bool = True, |
|
|
|
auth: Union[typing.List[typing.Tuple[str, str]], str] = None, |
|
auth_filename: str = None, |
|
auth_access: str = 'open', |
|
auth_freeze: bool = False, |
|
auth_message: str = None, |
|
guest_name: str = "guest", |
|
enforce_h2ogpt_api_key: bool = None, |
|
h2ogpt_api_keys: Union[list, str] = [], |
|
h2ogpt_key: str = None, |
|
|
|
max_max_time=None, |
|
max_max_new_tokens=None, |
|
|
|
visible_models: list = None, |
|
visible_visible_models: bool = True, |
|
visible_submit_buttons: bool = True, |
|
visible_side_bar: bool = True, |
|
visible_doc_track: bool = True, |
|
visible_chat_tab: bool = True, |
|
visible_doc_selection_tab: bool = True, |
|
visible_doc_view_tab: bool = True, |
|
visible_chat_history_tab: bool = True, |
|
visible_expert_tab: bool = True, |
|
visible_models_tab: bool = True, |
|
visible_system_tab: bool = True, |
|
visible_tos_tab: bool = False, |
|
visible_login_tab: bool = True, |
|
visible_hosts_tab: bool = False, |
|
chat_tables: bool = False, |
|
visible_h2ogpt_header: bool = True, |
|
max_raw_chunks: int = None, |
|
|
|
sanitize_user_prompt: bool = False, |
|
sanitize_bot_response: bool = False, |
|
|
|
extra_model_options: typing.List[str] = [], |
|
extra_lora_options: typing.List[str] = [], |
|
extra_server_options: typing.List[str] = [], |
|
|
|
score_model: str = 'auto', |
|
|
|
eval_filename: str = None, |
|
eval_prompts_only_num: int = 0, |
|
eval_prompts_only_seed: int = 1234, |
|
eval_as_output: bool = False, |
|
|
|
langchain_mode: str = None, |
|
user_path: str = None, |
|
langchain_modes: list = [LangChainMode.USER_DATA.value, LangChainMode.MY_DATA.value, LangChainMode.LLM.value, |
|
LangChainMode.DISABLED.value], |
|
langchain_mode_paths: dict = {LangChainMode.USER_DATA.value: None}, |
|
langchain_mode_types: dict = {LangChainMode.USER_DATA.value: LangChainTypes.SHARED.value}, |
|
detect_user_path_changes_every_query: bool = False, |
|
|
|
langchain_action: str = LangChainAction.QUERY.value, |
|
langchain_agents: list = [], |
|
force_langchain_evaluate: bool = False, |
|
|
|
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value], |
|
visible_langchain_agents: list = langchain_agents_list.copy(), |
|
|
|
document_subset: str = DocumentSubset.Relevant.name, |
|
document_choice: list = [DocumentChoice.ALL.value], |
|
|
|
use_llm_if_no_docs: bool = True, |
|
load_db_if_exists: bool = True, |
|
keep_sources_in_context: bool = False, |
|
db_type: str = 'chroma', |
|
use_openai_embedding: bool = False, |
|
use_openai_model: bool = False, |
|
hf_embedding_model: str = None, |
|
migrate_embedding_model: str = False, |
|
auto_migrate_db: bool = False, |
|
cut_distance: float = 1.64, |
|
answer_with_sources: bool = True, |
|
append_sources_to_answer: bool = True, |
|
show_accordions: bool = True, |
|
top_k_docs_max_show: int = 10, |
|
show_link_in_sources: bool = True, |
|
pre_prompt_query: str = None, |
|
prompt_query: str = None, |
|
pre_prompt_summary: str = None, |
|
prompt_summary: str = None, |
|
add_chat_history_to_context: bool = True, |
|
add_search_to_context: bool = False, |
|
context: str = '', |
|
iinput: str = '', |
|
allow_upload_to_user_data: bool = True, |
|
reload_langchain_state: bool = True, |
|
allow_upload_to_my_data: bool = True, |
|
enable_url_upload: bool = True, |
|
enable_text_upload: bool = True, |
|
enable_sources_list: bool = True, |
|
chunk: bool = True, |
|
chunk_size: int = 512, |
|
top_k_docs: int = None, |
|
docs_ordering_type: str = 'reverse_ucurve_sort', |
|
min_max_new_tokens=256, |
|
auto_reduce_chunks: bool = True, |
|
max_chunks: int = 100, |
|
headsize: int = 50, |
|
n_jobs: int = -1, |
|
|
|
|
|
use_unstructured=True, |
|
use_playwright=False, |
|
use_selenium=False, |
|
|
|
|
|
use_pymupdf='auto', |
|
use_unstructured_pdf='auto', |
|
use_pypdf='auto', |
|
enable_pdf_ocr='auto', |
|
enable_pdf_doctr='auto', |
|
try_pdf_as_html='auto', |
|
|
|
|
|
enable_ocr=False, |
|
enable_doctr=False, |
|
enable_pix2struct=False, |
|
enable_captions=True, |
|
|
|
pre_load_caption_model: bool = False, |
|
caption_gpu: bool = True, |
|
captions_model: str = "Salesforce/blip-image-captioning-base", |
|
doctr_gpu: bool = True, |
|
|
|
|
|
jq_schema='.[]', |
|
|
|
max_quality: bool = False, |
|
|
|
enable_heap_analytics: bool = True, |
|
heap_app_id: str = "1680123994", |
|
): |
|
""" |
|
|
|
:param load_8bit: load model in 8-bit using bitsandbytes |
|
:param load_4bit: load model in 4-bit using bitsandbytes |
|
:param low_bit_mode: 0: no quantization config 1: change compute 2: nf4 3: double quant 4: 2 and 3 |
|
See: https://huggingface.co/docs/transformers/main_classes/quantization |
|
If using older bitsandbytes or transformers, 0 is required |
|
:param load_half: load model in float16 (None means auto, which means True unless t5 based model) |
|
otherwise specify bool |
|
:param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g |
|
:param load_exllama: whether to use exllama (only applicable to LLaMa1/2 models with 16-bit or GPTQ |
|
:param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version) |
|
:param revision: Which HF revision to use |
|
:param use_gpu_id: whether to control devices with gpu_id. If False, then spread across GPUs |
|
:param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab |
|
:param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model. |
|
:param lora_weights: LORA weights path/HF link |
|
:param gpu_id: if use_gpu_id, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1 |
|
:param compile_model Whether to compile the model |
|
:param use_cache: Whether to use caching in model (some models fail when multiple threads use) |
|
:param inference_server: Consume base_model as type of model at this address |
|
Address can be text-generation-server hosting that base_model |
|
e.g. python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b |
|
|
|
Or Address can be "openai_chat" or "openai" for OpenAI API |
|
Or Address can be "openai_azure_chat" or "openai_azure" for Azure OpenAI API |
|
e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo |
|
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003 |
|
e.g. python generate.py --inference_server="openai_azure_chat:<deployment_name>:<baseurl>:<api_version>:<model_version>" --base_model=gpt-3.5-turbo |
|
e.g. python generate.py --inference_server="openai_azure:<deployment_name>:<baseurl>:<api_version>:<model_version>" --base_model=text-davinci-003 |
|
Optionals (Replace with None or just leave empty but keep :) |
|
<deployment_name> of some deployment name |
|
<baseurl>: e.g. "<endpoint>.openai.azure.com" for some <endpoint> without https:// |
|
<api_version> of some api, e.g. 2023-05-15 |
|
<model_version> e.g. 0613 |
|
|
|
Or Address can be for vLLM: |
|
Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint |
|
Note: vllm_chat not supported by vLLM project. |
|
|
|
Or Address can be replicate: |
|
Use: |
|
--inference_server=replicate:<model name string> will use a Replicate server, requiring a Replicate key. |
|
e.g. <model name string> looks like "a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5" |
|
|
|
Or Address can be for AWS SageMaker: |
|
Use: "sagemaker_chat:<endpoint name>" for chat models that AWS sets up as dialog |
|
Use: "sagemaker:<endpoint name>" for foundation models that AWS only text as inputs |
|
|
|
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model |
|
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True) |
|
:param system_prompt: Universal system prompt to use if model supports, like LLaMa2, regardless of prompt_type definition. |
|
Useful for langchain case to control behavior, or OpenAI and Replicate. |
|
If None, 'None', or 'auto', then for LLaMa or other models that internally have system_prompt, will use default for each model |
|
If '', then no system prompt (no empty template given to model either, just no system part added at all) |
|
If some string not in ['None', 'auto'], then use that as system prompt |
|
Default is '', no system_prompt, because often it hurts performance/accuracy |
|
|
|
:param llamacpp_dict: |
|
n_gpu_layers: for llama.cpp based models, number of GPU layers to offload (default is all by using large value) |
|
use_mlock: when using `llama.cpp` based CPU models, for computers with low system RAM or slow CPUs, recommended False |
|
n_batch: Can make smaller to 128 for slower low-memory CPU systems |
|
n_gqa: Required to be 8 for LLaMa 70B |
|
... etc. anything that could be passed to llama.cpp or GPT4All models |
|
e.g. python generate.py --base_model='llama' --prompt_type=llama2 --score_model=None --langchain_mode='UserData' --user_path=user_path --llamacpp_dict="{'n_gpu_layers':25,'n_batch':128}" |
|
:param model_path_llama: model path or URL (for auto-download) |
|
:param model_name_gptj: model path or URL (for auto-download) |
|
:param model_name_gpt4all_llama: model path or URL (for auto-download) |
|
:param model_name_exllama_if_no_config: exllama model's full path for model, tokenizer, generator for use when no HuggingFace config |
|
|
|
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models |
|
Only used if gradio = True |
|
List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict |
|
If all models have same prompt_type, and prompt_dict, can still specify that once in CLI outside model_lock as default for dict |
|
Can specify model_lock instead of those items on CLI |
|
As with CLI itself, base_model can infer prompt_type and prompt_dict if in prompter.py. |
|
Also, tokenizer_base_model and lora_weights are optional. |
|
Also, inference_server is optional if loading model from local system. |
|
All models provided will automatically appear in compare model mode |
|
Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled |
|
:param model_lock_columns: How many columns to show if locking models (and so showing all at once) |
|
If None, then defaults to up to 3 |
|
if -1, then all goes into 1 row |
|
Maximum value is 4 due to non-dynamic gradio rendering elements |
|
:param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail if True. Otherwise ignore. |
|
Useful when many endpoints and want to just see what works, but still have to wait for timeout. |
|
|
|
:param temperature: generation temperature |
|
:param top_p: generation top_p |
|
:param top_k: generation top_k |
|
:param num_beams: generation number of beams |
|
:param repetition_penalty: generation repetition penalty |
|
:param num_return_sequences: generation number of sequences (1 forced for chat) |
|
:param do_sample: generation sample |
|
:param max_new_tokens: generation max new tokens |
|
:param min_new_tokens: generation min tokens |
|
:param early_stopping: generation early stopping |
|
:param max_time: maximum time to allow for generation |
|
:param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on token 2 = HF like restriction 3 = very low memory case |
|
:param debug: enable debug mode |
|
:param save_dir: directory chat data is saved to |
|
:param share: whether to share the gradio app with sharable URL |
|
:param local_files_only: whether to only use local files instead of doing to HF for models |
|
:param resume_download: whether to resume downloads from HF for models |
|
:param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before) |
|
:param trust_remote_code: whether to use trust any code needed for HF model |
|
:param rope_scaling: |
|
For HF transformers model: scaling for rope-based models, e.g. --rope_scaling="{'type':'dynamic', 'factor':4}" |
|
For exllama model: --rope_scaling="{'alpha_value':4}" . This automatically scales max_seq_len for exllama |
|
:param max_seq_len: Manually set maximum sequence length for the LLM |
|
:param offload_folder: path for spilling model onto disk |
|
:param src_lang: source languages to include if doing translation (None = all) |
|
:param tgt_lang: target languages to include if doing translation (None = all) |
|
|
|
:param prepare_offline_level: |
|
Whether to just prepare for offline use, do not go into cli, eval, or gradio run modes |
|
0 : no prep |
|
1: prepare just h2oGPT with exact same setup as passed to CLI and ensure all artifacts for h2oGPT alone added to ~/.cache/ |
|
2: prepare h2oGPT + all inference servers so h2oGPT+inference servers can use the ~/.cache/ |
|
:param cli: whether to use CLI (non-gradio) interface. |
|
:param cli_loop: whether to loop for CLI (False usually only for testing) |
|
:param gradio: whether to enable gradio, or to enable benchmark mode |
|
:param gradio_offline_level: > 0, then change fonts so full offline |
|
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached |
|
== 2 means backend and frontend don't need internet to download any fonts. |
|
Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading. |
|
This option further disables google fonts for downloading, which is less intrusive than uploading, |
|
but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior. |
|
Also set --share=False to avoid sharing a gradio live link. |
|
:param server_name: IP to use. In linux 0.0.0.0 is good choice so exposed to outside host, else for only local use 127.0.0.1. |
|
For windows/MAC 0.0.0.0 or 127.0.0.1 will work, but may need to specify actual LAN IP address for other LAN clients to see. |
|
:param root_path: The root path (or "mount point") of the application, |
|
if it's not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy |
|
that forwards requests to the application. For example, if the application is served at "https://example.com/myapp", |
|
the `root_path` should be set to "/myapp". |
|
:param chat: whether to enable chat mode with chat history |
|
:param chat_conversation: list of tuples of (human, bot) conversation pre-appended to existing chat when using instruct/chat models |
|
Requires also add_chat_history_to_context = True |
|
It does *not* require chat=True, so works with nochat_api etc. |
|
:param text_context_list: List of strings to add to context for non-database version of document Q/A for faster handling via API etc. |
|
Forces LangChain code path and uses as many entries in list as possible given max_seq_len, with first assumed to be most relevant and to go near prompt. |
|
:param stream_output: whether to stream output |
|
:param async_output: Whether to do asyncio handling |
|
For summarization |
|
Applicable to HF TGI server |
|
Only if stream_output=False in CLI, UI, or API |
|
:param num_async: Number of simultaneously allowed asyncio calls to make for async_output |
|
Too many will overload inference server, too few will be too slow |
|
:param show_examples: whether to show clickable examples in gradio |
|
:param verbose: whether to show verbose prints |
|
:param h2ocolors: whether to use H2O.ai theme |
|
:param dark: whether to use dark mode for UI by default (still controlled in UI) |
|
:param height: height of chat window |
|
:param show_lora: whether to show LORA options in UI (expert so can be hard to understand) |
|
:param show_llama: whether to show LLaMa.cpp/GPT4All options in UI (only likely useful if have weak GPUs) |
|
:param show_gpt4all: whether to show GPT4All models in UI (not often useful, llama.cpp models best) |
|
:param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped |
|
:param block_gradio_exit: whether to block gradio exit (used for testing) |
|
:param concurrency_count: gradio concurrency count (1 is optimal for LLMs) |
|
:param api_open: If False, don't let API calls skip gradio queue |
|
:param allow_api: whether to allow API calls at all to gradio server |
|
:param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit) |
|
:param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large". |
|
Small useful for many chatbots in model_lock mode |
|
:param show_copy_button: Whether to show copy button for chatbots |
|
:param large_file_count_mode: Whether to force manual update to UI of drop-downs, good idea if millions of chunks or documents |
|
:param pre_load_embedding_model: Whether to preload embedding model for shared use across DBs and users (multi-thread safe only) |
|
|
|
:param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...] |
|
e.g. --auth=[('jon','password')] with no spaces |
|
e.g. --auth="[('jon', 'password)())(')]" so any special characters can be used |
|
e.g. --auth=auth.json to specify persisted state file with name auth.json (auth_filename then not required) |
|
e.g. --auth='' will use default auth.json as file name for persisted state file (auth_filename then not required) |
|
e.g. --auth=None will use no auth, but still keep track of auth state, just not from logins |
|
:param auth_filename: |
|
Set auth filename, used only if --auth= was passed list of user/passwords |
|
:param auth_access: |
|
'open': Allow new users to be added |
|
'closed': Stick to existing users |
|
:param auth_freeze: whether freeze authentication based upon current file, no longer update file |
|
:param auth_message: Message to show if having users login, fixed if passed, else dynamic internally |
|
:param guest_name: guess name if using auth and have open access. |
|
If '', then no guest allowed even if open access, then all databases for each user always persisted |
|
:param enforce_h2ogpt_api_key: Whether to enforce h2oGPT token usage for API |
|
:param h2ogpt_api_keys: list of tokens allowed for API access or file accessed on demand for json of list of keys |
|
:param h2ogpt_key: E.g. can be set when accessing gradio h2oGPT server from local gradio h2oGPT server that acts as client to that inference server |
|
|
|
:param max_max_time: Maximum max_time for gradio slider |
|
:param max_max_new_tokens: Maximum max_new_tokens for gradio slider |
|
:param min_max_new_tokens: Minimum of max_new_tokens, when auto-scaling down to handle more docs/prompt, but still let generation have some tokens |
|
|
|
:param visible_models: Which models in model_lock list to show by default |
|
Takes integers of position in model_lock (model_states) list or strings of base_model names |
|
Ignored if model_lock not used |
|
For nochat API, this is single item within a list for model by name or by index in model_lock |
|
If None, then just use first model in model_lock list |
|
If model_lock not set, use model selected by CLI --base_model etc. |
|
|
|
:param visible_visible_models: Whether visible models drop-down is visible in UI |
|
:param visible_submit_buttons: whether submit buttons are visible when UI first comes up |
|
:param visible_side_bar: whether left side bar is visible when UI first comes up |
|
:param visible_doc_track: whether left side bar's document tracking is visible when UI first comes up |
|
:param visible_chat_tab: "" for chat tab |
|
:param visible_doc_selection_tab: "" for doc selection tab |
|
:param visible_doc_view_tab: "" for doc view tab |
|
:param visible_chat_history_tab: "" for chat history tab |
|
:param visible_expert_tab: "" for expert tab |
|
:param visible_models_tab: "" for models tab |
|
:param visible_system_tab: "" for system tab |
|
:param visible_tos_tab: "" for ToS tab |
|
:param visible_login_tab: "" for Login tab |
|
:param visible_hosts_tab: "" for hosts tab |
|
:param chat_tables: Just show Chat as block without tab (useful if want only chat view) |
|
:param visible_h2ogpt_header: Whether github stars, URL, logo, and QR code are visible |
|
:param max_raw_chunks: Maximum number of chunks to show in UI when asking for raw DB text from documents/collection |
|
|
|
:param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing) |
|
Requires optional packages: |
|
pip install alt-profanity-check==1.2.2 better-profanity==0.7.0 |
|
:param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow) |
|
:param extra_model_options: extra models to show in list in gradio |
|
:param extra_lora_options: extra LORA to show in list in gradio |
|
:param extra_server_options: extra servers to show in list in gradio |
|
:param score_model: which model to score responses |
|
None: no response scoring |
|
'auto': auto mode, '' (no model) for CPU or 1 GPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for >=2 GPUs, |
|
because on CPU takes too much compute just for scoring response |
|
:param eval_filename: json file to use for evaluation, if None is sharegpt |
|
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples |
|
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling |
|
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself |
|
|
|
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py. |
|
None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled |
|
If not passed, then chosen to be first langchain_modes, else langchain_mode->Disabled is set if no langchain_modes either |
|
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present. |
|
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode. |
|
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources |
|
:param langchain_modes: dbs to generate at launch to be ready for LLM |
|
Apart from additional user-defined collections, can include ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs'] |
|
But wiki_full is expensive and requires preparation |
|
To allow personal space only live in session, add 'MyData' to list |
|
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData'] |
|
If have own user modes, need to add these here or add in UI. |
|
:param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents |
|
E.g. "{'UserData2': 'userpath2'}" |
|
A disk path be None, e.g. --langchain_mode_paths="{'UserData2': None}" even if existing DB, to avoid new documents being added from that path, source links that are on disk still work. |
|
If `--user_path` was passed, that path is used for 'UserData' instead of the value in this dict |
|
:param langchain_mode_types: dict of langchain_mode keys and database types |
|
E.g. python generate.py --base_model=llama --langchain_modes=['TestData'] --langchain_mode_types="{'TestData':'shared'}" |
|
The type is attempted to be inferred if directory already exists, then don't have to pass this |
|
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes). |
|
Expensive for large number of files, so not done by default. By default only detect changes during db loading. |
|
|
|
:param langchain_action: Mode langchain operations in on documents. |
|
Query: Make query of document(s) |
|
Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce |
|
Summarize_all: Summarize document(s) using entire document at once |
|
Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary |
|
:param langchain_agents: Which agents to use |
|
'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env |
|
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing. |
|
|
|
:param visible_langchain_actions: Which actions to allow |
|
:param visible_langchain_agents: Which agents to allow |
|
|
|
:param document_subset: Default document choice when taking subset of collection |
|
:param document_choice: Chosen document(s) by internal name, 'All' means use all docs |
|
|
|
:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom |
|
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db |
|
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually |
|
:param db_type: 'faiss' for in-memory |
|
'chroma' (for chroma >= 0.4) |
|
'chroma_old' (for chroma < 0.4) -- recommended for large collections |
|
'weaviate' for persisted on disk |
|
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db |
|
:param use_openai_model: Whether to use OpenAI model for use with vector db |
|
:param hf_embedding_model: Which HF embedding model to use for vector db |
|
Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs |
|
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2" |
|
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl' |
|
We support automatically changing of embeddings for chroma, with a backup of db made if this is done |
|
:param migrate_embedding_model: whether to use hf_embedding_model embedding even if database already had an embedding set. |
|
used to migrate all embeddings to a new one, but will take time to re-embed. |
|
Default (False) is to use the prior embedding for existing databases, and only use hf_embedding_model for new databases |
|
If had old database without embedding saved, then hf_embedding_model is also used. |
|
:param auto_migrate_db: whether to automatically migrate any chroma<0.4 database from duckdb -> sqlite version |
|
:param cut_distance: Distance to cut off references with larger distances when showing references. |
|
1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references. |
|
For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references. |
|
:param answer_with_sources: Whether to determine (and return) sources |
|
:param append_sources_to_answer: Whether to place source information in chat response (ignored by LLM). Always disabled for API. |
|
:param show_accordions: whether to show accordion for document references in chatbot UI |
|
:param top_k_docs_max_show: Max number of docs to show in UI for sources |
|
If web search is enabled, then this is modified to be max(top_k_docs_max_show, number of links used in search) |
|
:param show_link_in_sources: Whether to show URL link to source document in references |
|
:param pre_prompt_query: prompt before documents to query, if None then use internal defaults |
|
:param prompt_query: prompt after documents to query, if None then use internal defaults |
|
:param pre_prompt_summary: prompt before documents to summarize, if None then use internal defaults |
|
:param prompt_summary: prompt after documents to summarize, if None then use internal defaults |
|
For summarize, normal to have empty query (nothing added in ask anything in UI or empty string in API) |
|
If pass query, template is "Focusing on %s, %s" % (query, prompt_summary) |
|
If pass query and iinput, template is "Focusing on %s, %s, %s" % (query, iinput, prompt_summary) |
|
:param add_chat_history_to_context: Include chat context when performing action |
|
Not supported yet for openai_chat when using document collection instead of LLM |
|
Also not supported when using CLI mode |
|
:param add_search_to_context: Include web search in context as augmented prompt |
|
:param context: Default context to use (for system pre-context in gradio UI) |
|
context comes before chat_conversation and any document Q/A from text_context_list |
|
:param iinput: Default input for instruction-based prompts |
|
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs) |
|
Ensure pass user_path for the files uploaded to be moved to this location for linking. |
|
:param reload_langchain_state: Whether to reload langchain_modes.pkl file that contains any new user collections. |
|
:param allow_upload_to_my_data: Whether to allow file uploads to update personal vector db |
|
:param enable_url_upload: Whether to allow upload from URL |
|
:param enable_text_upload: Whether to allow upload of text |
|
:param enable_sources_list: Whether to allow listing (or downloading, for non-shared db) the sources of the chosen db
|
:param chunk: Whether to chunk data (True unless know data is already optimally chunked) |
|
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length |
|
:param top_k_docs: For langchain_action query: number of chunks to give LLM |
|
-1 : auto-fills context up to max_seq_len |
|
For langchain_action summarize: number of document parts, like pages for PDF. |
|
There's no such thing as chunks for summarization. |
|
-1 : auto-fills context up to max_seq_len |
|
:param docs_ordering_type: |
|
Type of ordering of docs. |
|
'best_first': Order by score so score is worst match near prompt |
|
'best_near_prompt' or 'reverse_sort' : reverse docs order so most relevant is closest to question. |
|
Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too. |
|
But smaller 6_9 models fail to use newest context and can get stuck on old information. |
|
'' or None (i.e. default) or 'reverse_ucurve_sort' : Sort so most relevant is either near start or near end |
|
Best to avoid "lost in middle" as well as avoid hallucinating off starting content that LLM focuses on a lot.
|
:param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt |
|
:param max_chunks: If top_k_docs=-1, maximum number of chunks to allow |
|
:param headsize: Maximum number of characters of the head of a document for the UI to show
|
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default) |
|
|
|
:param use_unstructured: Enable unstructured URL loader |
|
:param use_playwright: Enable PlayWright URL loader |
|
:param use_selenium: Enable Selenium URL loader |
|
|
|
:param use_pymupdf: enable PyMuPDF loader; 'auto' means use it first, and use the other loaders that are set to 'auto' only if it gives no result
|
:param use_unstructured_pdf: enable Unstructured PDF loader, 'auto' means use if pymupdf fails to get doc result |
|
:param use_pypdf: enable PyPDF loader 'auto' means use if unstructured fails to get doc result |
|
:param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text. |
|
if enable_pdf_doctr == 'on' then don't do. |
|
'on' means always do OCR as additional parsing of same documents |
|
'off' means don't do OCR (e.g. because it's slow even if 'auto' only would trigger if nothing else worked) |
|
:param enable_pdf_doctr: Whether to support doctr on pdfs, 'auto' means use it only if no doc result has been obtained so far
|
:param try_pdf_as_html: Try "PDF" as if HTML file, in case web link has .pdf extension but really is just HTML |
|
|
|
:param enable_ocr: Whether to support OCR on images |
|
:param enable_doctr: Whether to support doctr on images (using OCR better than enable_ocr=True) |
|
:param enable_pix2struct: Whether to support pix2struct on images for captions |
|
:param enable_captions: Whether to support captions using BLIP for image files as documents, |
|
then preloads that model if pre_load_caption_model=True |
|
|
|
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader |
|
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context |
|
Recommended if using larger caption model |
|
:param captions_model: Which model to use for captions. |
|
captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable |
|
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state |
|
captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state |
|
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions |
|
Disabled for CPU since BLIP requires CUDA |
|
:param caption_gpu: If support caption, then use GPU if exists |
|
|
|
:param doctr_gpu: If support doctr, then use GPU if exists |
|
|
|
:param jq_schema: control json loader |
|
By default '.[]' ingests everything in brute-force way, but better to match your schema |
|
See: https://python.langchain.com/docs/modules/data_connection/document_loaders/json#using-jsonloader |
|
|
|
:param max_quality: Choose maximum quality ingestion with all available parsers |
|
Pro: Catches document when some default parsers would fail |
|
Pro: Enables DocTR that has much better OCR than Tesseract |
|
Con: Fills DB with results from all parsers, so similarity search gives redundant results |
|
|
|
:param enable_heap_analytics: Toggle telemetry. |
|
:param heap_app_id: App ID for Heap, change to your ID. |
|
:return: |
|
""" |
|
if base_model is None: |
|
base_model = '' |
|
if tokenizer_base_model is None: |
|
tokenizer_base_model = '' |
|
if lora_weights is None: |
|
lora_weights = '' |
|
if inference_server is None: |
|
inference_server = '' |
|
|
|
|
|
model_lock = os.getenv('model_lock', str(model_lock)) |
|
model_lock = ast.literal_eval(model_lock) |
|
|
|
chat_conversation = str_to_list(chat_conversation) |
|
text_context_list = str_to_list(text_context_list) |
|
|
|
llamacpp_dict = str_to_dict(llamacpp_dict) |
|
|
|
llamacpp_dict['model_path_llama'] = model_path_llama |
|
llamacpp_dict['model_name_gptj'] = model_name_gptj |
|
llamacpp_dict['model_name_gpt4all_llama'] = model_name_gpt4all_llama |
|
llamacpp_dict['model_name_exllama_if_no_config'] = model_name_exllama_if_no_config |
|
|
|
if 'n_batch' not in llamacpp_dict: |
|
llamacpp_dict['n_batch'] = 128 |
|
if 'n_gpu_layers' not in llamacpp_dict: |
|
llamacpp_dict['n_gpu_layers'] = 100 |
|
if 'n_gqa' not in llamacpp_dict: |
|
llamacpp_dict['n_gqa'] = 0 |
|
|
|
if os.environ.get('SERPAPI_API_KEY') is None and LangChainAgent.SEARCH.value in visible_langchain_agents: |
|
visible_langchain_agents.remove(LangChainAgent.SEARCH.value) |
|
|
|
if model_lock: |
|
assert gradio, "model_lock only supported for gradio=True" |
|
assert not cli, "model_lock only supported for cli=False" |
|
assert not (not cli and not gradio), "model_lock only supported for eval (cli=gradio=False)" |
|
assert not base_model, "Don't specify model_lock and base_model" |
|
assert not tokenizer_base_model, "Don't specify model_lock and tokenizer_base_model" |
|
assert not lora_weights, "Don't specify model_lock and lora_weights" |
|
assert not inference_server, "Don't specify model_lock and inference_server" |
|
|
|
|
|
|
|
n_jobs = int(os.getenv('n_jobs', str(n_jobs))) |
|
is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0'))) |
|
is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0'))) |
|
is_public = is_hf or is_gpth2oai |
|
if is_public: |
|
visible_tos_tab = visible_hosts_tab = True |
|
if enforce_h2ogpt_api_key is None: |
|
enforce_h2ogpt_api_key = True |
|
else: |
|
if enforce_h2ogpt_api_key is None: |
|
enforce_h2ogpt_api_key = False |
|
if isinstance(h2ogpt_api_keys, str) and not os.path.isfile(h2ogpt_api_keys): |
|
h2ogpt_api_keys = str_to_list(h2ogpt_api_keys) |
|
if memory_restriction_level is None: |
|
memory_restriction_level = 2 if is_hf else 0 |
|
else: |
|
assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level |
|
if n_jobs == -1: |
|
|
|
n_jobs = max(1, os.cpu_count() // 2) |
|
if is_public and os.getenv('n_jobs') is None: |
|
n_jobs = min(n_jobs, max(1, min(os.cpu_count() // 2, 8))) |
|
admin_pass = os.getenv("ADMIN_PASS") |
|
|
|
|
|
raise_generate_gpu_exceptions = True |
|
|
|
rope_scaling = str_to_dict(rope_scaling) |
|
|
|
if isinstance(auth, str): |
|
if auth.strip().startswith('['): |
|
auth = str_to_list(auth) |
|
if isinstance(auth, str) and auth: |
|
auth_filename = auth |
|
if not auth_filename: |
|
auth_filename = "auth.json" |
|
assert isinstance(auth, (str, list, tuple, type(None))), "Unknown type %s for auth=%s" % (type(auth), auth) |
|
|
|
|
|
use_auth_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", use_auth_token) |
|
allow_upload_to_user_data = bool( |
|
int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data))))) |
|
allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data))))) |
|
height = int(os.environ.get("HEIGHT", height)) |
|
h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors))) |
|
|
|
|
|
|
|
langchain_modes = ast.literal_eval(os.environ.get("langchain_modes", str(langchain_modes))) |
|
if not isinstance(langchain_modes, list): |
|
langchain_modes = [] |
|
|
|
if LangChainMode.DISABLED.value not in langchain_modes: |
|
langchain_modes.append(LangChainMode.DISABLED.value) |
|
if not have_langchain: |
|
|
|
langchain_mode = LangChainMode.DISABLED.value |
|
langchain_modes = [langchain_mode] |
|
|
|
|
|
langchain_mode_paths = str_to_dict(langchain_mode_paths) |
|
langchain_mode_types = str_to_dict(langchain_mode_types) |
|
for lmode in [LangChainMode.GITHUB_H2OGPT.value, |
|
LangChainMode.H2O_DAI_DOCS.value, |
|
LangChainMode.WIKI.value, |
|
LangChainMode.WIKI_FULL.value, |
|
]: |
|
if lmode not in langchain_mode_types: |
|
langchain_mode_types[lmode] = 'shared' |
|
if lmode not in langchain_mode_paths: |
|
langchain_mode_types[lmode] = '' |
|
if user_path: |
|
user_path = makedirs(user_path, use_base=True) |
|
langchain_mode_paths['UserData'] = user_path |
|
langchain_mode_paths['UserData'] = LangChainTypes.SHARED.value |
|
|
|
if is_public: |
|
allow_upload_to_user_data = False |
|
if LangChainMode.USER_DATA.value in langchain_modes: |
|
langchain_modes.remove(LangChainMode.USER_DATA.value) |
|
if max_raw_chunks is None: |
|
max_raw_chunks = 30 if is_public else 1000000 |
|
|
|
|
|
if allow_upload_to_user_data: |
|
|
|
if user_path: |
|
langchain_mode_paths['UserData'] = user_path |
|
|
|
assert langchain_action in langchain_actions, "Invalid langchain_action %s not in %s" % ( |
|
langchain_action, langchain_actions) |
|
assert len( |
|
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents |
|
|
|
|
|
langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode) |
|
if have_langchain and langchain_mode is None: |
|
|
|
if LangChainMode.LLM.value in langchain_modes: |
|
langchain_mode = LangChainMode.LLM.value |
|
elif len(langchain_modes) >= 1: |
|
|
|
langchain_mode = langchain_modes[0] |
|
if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']: |
|
if verbose: |
|
print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True) |
|
elif allow_upload_to_my_data: |
|
if verbose: |
|
print("Auto set langchain_mode=%s. Could use MyData instead." |
|
" To allow UserData to pull files from disk," |
|
" set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode, |
|
flush=True) |
|
else: |
|
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes) |
|
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]: |
|
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.") |
|
if langchain_mode is None: |
|
|
|
langchain_mode = LangChainMode.DISABLED.value |
|
print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True) |
|
|
|
if langchain_mode not in langchain_modes: |
|
langchain_modes.append(langchain_mode) |
|
|
|
if is_public: |
|
allow_upload_to_user_data = False |
|
input_lines = 1 |
|
temperature = 0.2 if temperature is None else temperature |
|
top_p = 0.85 if top_p is None else top_p |
|
top_k = 70 if top_k is None else top_k |
|
if is_hf: |
|
do_sample = True if do_sample is None else do_sample |
|
top_k_docs = 3 if top_k_docs is None else top_k_docs |
|
else: |
|
|
|
do_sample = False if do_sample is None else do_sample |
|
top_k_docs = 4 if top_k_docs is None else top_k_docs |
|
|
|
if memory_restriction_level == 2: |
|
if not base_model and not inference_server and not model_lock: |
|
base_model = 'h2oai/h2ogpt-oasst1-512-12b' |
|
|
|
load_8bit = True |
|
load_4bit = False |
|
elif not inference_server: |
|
top_k_docs = 10 if top_k_docs is None else top_k_docs |
|
if memory_restriction_level >= 2: |
|
load_8bit = True |
|
load_4bit = False |
|
if hf_embedding_model is None: |
|
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" |
|
top_k_docs = 3 if top_k_docs is None else top_k_docs |
|
if top_k_docs is None: |
|
top_k_docs = 3 |
|
if is_public: |
|
if not max_time: |
|
max_time = 60 * 2 |
|
if not max_max_time: |
|
max_max_time = max_time |
|
if not max_new_tokens: |
|
max_new_tokens = 256 |
|
if not max_max_new_tokens: |
|
max_max_new_tokens = 512 |
|
else: |
|
if not max_max_time: |
|
max_max_time = 60 * 20 |
|
if not max_max_new_tokens: |
|
max_max_new_tokens = 1024 |
|
if is_hf: |
|
|
|
share = False |
|
if not max_time: |
|
max_time = 60 * 1 |
|
if not max_max_time: |
|
max_max_time = max_time |
|
|
|
save_dir = os.getenv('SAVE_DIR', save_dir) |
|
save_dir = makedirs(save_dir, exist_ok=True, tmp_ok=True, use_base=True) |
|
score_model = os.getenv('SCORE_MODEL', score_model) |
|
if str(score_model) == 'None': |
|
score_model = '' |
|
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count)) |
|
api_open = bool(int(os.getenv('API_OPEN', str(int(api_open))))) |
|
allow_api = bool(int(os.getenv('ALLOW_API', str(int(allow_api))))) |
|
|
|
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 |
|
n_gpus, gpu_ids = cuda_vis_check(n_gpus) |
|
|
|
if load_half is None and t5_type(base_model): |
|
load_half = False |
|
print("load_half=%s auto-set for %s to avoid bad generation" % (load_half, base_model), flush=True) |
|
|
|
if n_gpus == 0 or get_device() == "mps": |
|
|
|
|
|
if get_device() != "mps": |
|
print("No GPUs detected", flush=True) |
|
|
|
enable_captions = False |
|
gpu_id = None |
|
load_8bit = False |
|
load_4bit = False |
|
low_bit_mode = 1 |
|
if load_half is None: |
|
|
|
load_half = False |
|
load_gptq = '' |
|
load_exllama = False |
|
use_gpu_id = False |
|
if get_device() == "cuda": |
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cudnn.enabled = False |
|
torch.set_default_dtype(torch.float32) |
|
if is_public and not inference_server and not model_lock: |
|
|
|
|
|
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model |
|
if hf_embedding_model is None: |
|
|
|
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" |
|
if score_model == 'auto': |
|
score_model = '' |
|
else: |
|
if load_half is None: |
|
load_half = True |
|
|
|
if score_model == 'auto': |
|
if n_gpus >= 2: |
|
|
|
score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2' |
|
else: |
|
score_model = '' |
|
if hf_embedding_model is None: |
|
|
|
hf_embedding_model = 'hkunlp/instructor-large' |
|
|
|
|
|
if base_model: |
|
model_lower = base_model.lower() |
|
elif model_lock: |
|
|
|
assert len(model_lock) > 0 and model_lock[0]['base_model'] |
|
model_lower = model_lock[0]['base_model'].lower() |
|
else: |
|
model_lower = '' |
|
if not gradio: |
|
|
|
stream_output = False |
|
|
|
chat = False |
|
|
|
first_para = False |
|
text_limit = None |
|
|
|
if compile_model is None: |
|
|
|
compile_model = not cli |
|
|
|
if offload_folder: |
|
offload_folder = makedirs(offload_folder, exist_ok=True, tmp_ok=True, use_base=True) |
|
|
|
|
|
caption_loader = None |
|
doctr_loader = None |
|
pix2struct_loader = None |
|
|
|
image_loaders_options0, image_loaders_options, \ |
|
pdf_loaders_options0, pdf_loaders_options, \ |
|
url_loaders_options0, url_loaders_options = lg_to_gr(**locals()) |
|
jq_schema0 = jq_schema |
|
|
|
image_loaders = image_loaders_options0 |
|
pdf_loaders = pdf_loaders_options0 |
|
url_loaders = url_loaders_options0 |
|
|
|
placeholder_instruction, placeholder_input, \ |
|
stream_output, show_examples, \ |
|
prompt_type, prompt_dict, \ |
|
temperature, top_p, top_k, num_beams, \ |
|
max_new_tokens, min_new_tokens, early_stopping, max_time, \ |
|
repetition_penalty, num_return_sequences, \ |
|
do_sample, \ |
|
src_lang, tgt_lang, \ |
|
examples, \ |
|
task_info = \ |
|
get_generate_params(model_lower, |
|
chat, |
|
stream_output, show_examples, |
|
prompt_type, prompt_dict, |
|
system_prompt, |
|
pre_prompt_query, prompt_query, |
|
pre_prompt_summary, prompt_summary, |
|
temperature, top_p, top_k, num_beams, |
|
max_new_tokens, min_new_tokens, early_stopping, max_time, |
|
repetition_penalty, num_return_sequences, |
|
do_sample, |
|
top_k_docs, |
|
chunk, |
|
chunk_size, |
|
image_loaders, |
|
pdf_loaders, |
|
url_loaders, |
|
jq_schema, |
|
docs_ordering_type, |
|
min_max_new_tokens, |
|
verbose, |
|
) |
|
|
|
git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH" |
|
locals_dict = locals() |
|
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()]) |
|
if verbose: |
|
print(f"Generating model with params:\n{locals_print}", flush=True) |
|
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), git_hash), flush=True) |
|
|
|
if langchain_mode != LangChainMode.DISABLED.value: |
|
|
|
from gpt_langchain import prep_langchain, get_some_dbs_from_hf, get_persist_directory |
|
if is_hf: |
|
get_some_dbs_from_hf() |
|
dbs = {} |
|
for langchain_mode1 in langchain_modes: |
|
langchain_type = langchain_mode_types.get(langchain_mode1, LangChainTypes.EITHER.value) |
|
if langchain_type == LangChainTypes.PERSONAL.value: |
|
|
|
continue |
|
persist_directory1, langchain_type = get_persist_directory(langchain_mode1, langchain_type=langchain_type) |
|
langchain_mode_types[langchain_mode1] = langchain_type |
|
if langchain_type == LangChainTypes.PERSONAL.value: |
|
|
|
continue |
|
try: |
|
db = prep_langchain(persist_directory1, |
|
load_db_if_exists, |
|
db_type, use_openai_embedding, |
|
langchain_mode1, langchain_mode_paths, langchain_mode_types, |
|
hf_embedding_model, |
|
migrate_embedding_model, |
|
auto_migrate_db, |
|
kwargs_make_db=locals(), |
|
verbose=verbose) |
|
finally: |
|
|
|
clear_torch_cache() |
|
dbs[langchain_mode1] = db |
|
|
|
dbs = {k: v for k, v in dbs.items() if v is not None} |
|
else: |
|
dbs = {} |
|
|
|
if os.environ.get("TEST_LANGCHAIN_IMPORT"): |
|
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" |
|
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" |
|
|
|
other_model_state_defaults = dict(load_8bit=load_8bit, load_4bit=load_4bit, low_bit_mode=low_bit_mode, |
|
load_half=load_half, |
|
load_gptq=load_gptq, load_exllama=load_exllama, use_safetensors=use_safetensors, |
|
revision=revision, use_gpu_id=use_gpu_id, gpu_id=gpu_id, |
|
compile_model=compile_model, |
|
use_cache=use_cache, |
|
llamacpp_dict=llamacpp_dict, model_path_llama=model_path_llama, |
|
model_name_gptj=model_name_gptj, |
|
model_name_gpt4all_llama=model_name_gpt4all_llama, |
|
model_name_exllama_if_no_config=model_name_exllama_if_no_config, |
|
) |
|
model_state_none = dict(model=None, tokenizer=None, device=None, |
|
base_model=None, tokenizer_base_model=None, lora_weights=None, |
|
inference_server=None, prompt_type=None, prompt_dict=None, |
|
visible_models=None, h2ogpt_key=None, |
|
) |
|
model_state_none.update(other_model_state_defaults) |
|
my_db_state0 = {LangChainMode.MY_DATA.value: [None, None, None]} |
|
selection_docs_state0 = dict(langchain_modes=langchain_modes, |
|
langchain_mode_paths=langchain_mode_paths, |
|
langchain_mode_types=langchain_mode_types) |
|
selection_docs_state = copy.deepcopy(selection_docs_state0) |
|
|
|
if cli or not gradio: |
|
|
|
model_name = base_model |
|
pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary = \ |
|
get_langchain_prompts(pre_prompt_query, prompt_query, |
|
pre_prompt_summary, prompt_summary, |
|
model_name, inference_server, |
|
model_path_llama) |
|
|
|
if cli: |
|
from cli import run_cli |
|
return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals())) |
|
elif not gradio: |
|
from eval import run_eval |
|
return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals())) |
|
elif gradio or prepare_offline_level > 0: |
|
|
|
from gradio_runner import go_gradio |
|
|
|
|
|
model_states = [] |
|
model_list = [dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights, |
|
inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict, |
|
visible_models=None, h2ogpt_key=None)] |
|
model_list[0].update(other_model_state_defaults) |
|
|
|
|
|
|
|
|
|
model_list0 = copy.deepcopy(model_list) |
|
model_state0 = model_state_none.copy() |
|
assert len(model_state_none) == len(model_state0) |
|
if model_lock: |
|
model_list = model_lock |
|
|
|
for model_dict in reversed(model_list): |
|
|
|
|
|
model_dict['base_model'] = model_dict.get('base_model', '') |
|
model_dict['tokenizer_base_model'] = model_dict.get('tokenizer_base_model', '') |
|
model_dict['lora_weights'] = model_dict.get('lora_weights', '') |
|
model_dict['inference_server'] = model_dict.get('inference_server', '') |
|
if prepare_offline_level >= 2: |
|
if 'openai' not in model_dict['inference_server'] and 'replicate' not in model_dict['inference_server']: |
|
|
|
model_dict['inference_server'] = '' |
|
prompt_type_infer = not model_dict.get('prompt_type') |
|
model_dict['prompt_type'] = model_dict.get('prompt_type', |
|
model_list0[0]['prompt_type']) |
|
|
|
for k in model_list0[0]: |
|
if k not in model_dict: |
|
model_dict[k] = model_list0[0][k] |
|
|
|
|
|
|
|
pre_prompt_query1, prompt_query1, pre_prompt_summary1, prompt_summary1 = ( |
|
get_langchain_prompts(pre_prompt_query, prompt_query, |
|
pre_prompt_summary, prompt_summary, |
|
model_dict['base_model'], |
|
model_dict['inference_server'], |
|
model_dict['model_path_llama'])) |
|
|
|
|
|
pre_prompt_query = pre_prompt_query or pre_prompt_query1 |
|
prompt_query = prompt_query or prompt_query1 |
|
pre_prompt_summary = pre_prompt_summary or pre_prompt_summary1 |
|
prompt_summary = prompt_summary or prompt_summary1 |
|
|
|
|
|
if prompt_type_infer: |
|
model_lower1 = model_dict['base_model'].lower() |
|
if model_lower1 in inv_prompt_type_to_model_lower: |
|
model_dict['prompt_type'] = inv_prompt_type_to_model_lower[model_lower1] |
|
model_dict['prompt_dict'], error0 = get_prompt(model_dict['prompt_type'], '', |
|
chat=False, context='', reduced=False, |
|
making_context=False, |
|
return_dict=True, |
|
system_prompt=system_prompt) |
|
else: |
|
model_dict['prompt_dict'] = prompt_dict |
|
else: |
|
model_dict['prompt_dict'] = prompt_dict |
|
model_dict['prompt_dict'] = model_dict.get('prompt_dict', model_dict['prompt_dict']) |
|
|
|
all_kwargs = locals().copy() |
|
all_kwargs.update(model_dict) |
|
if model_dict['base_model'] and not login_mode_if_model0: |
|
model0, tokenizer0, device = get_model(reward_type=False, |
|
**get_kwargs(get_model, exclude_names=['reward_type'], |
|
**all_kwargs)) |
|
else: |
|
|
|
model0, tokenizer0, device = None, None, None |
|
if model0 is None: |
|
if fail_if_cannot_connect: |
|
raise RuntimeError("Could not connect, see logs") |
|
|
|
if isinstance(model_lock, list): |
|
model_lock.remove(model_dict) |
|
continue |
|
model_state_trial = dict(model=model0, tokenizer=tokenizer0, device=device) |
|
model_state_trial.update(model_dict) |
|
diff_keys = set(list(model_state_none.keys())).symmetric_difference(model_state_trial.keys()) |
|
assert len(model_state_none) == len(model_state_trial), diff_keys |
|
print("Model %s" % model_dict, flush=True) |
|
if model_lock: |
|
|
|
model_states.insert(0, model_state_trial) |
|
|
|
model_state0 = model_state_trial.copy() |
|
else: |
|
model_state0 = model_state_trial.copy() |
|
assert len(model_state_none) == len(model_state0) |
|
|
|
visible_models = str_to_list(visible_models, allow_none=True) |
|
all_models = [x.get('base_model', xi) for xi, x in enumerate(model_states)] |
|
visible_models_state0 = [x.get('base_model', xi) for xi, x in enumerate(model_states) if |
|
visible_models is None or |
|
x.get('base_model', xi) in visible_models or |
|
xi in visible_models] |
|
|
|
|
|
|
|
|
|
if len(model_states) >= 1: |
|
max_seq_len = model_states[0]['tokenizer'].model_max_length |
|
|
|
|
|
all_kwargs = locals().copy() |
|
smodel, stokenizer, sdevice = get_score_model(reward_type=True, |
|
**get_kwargs(get_score_model, exclude_names=['reward_type'], |
|
**all_kwargs)) |
|
score_model_state0 = dict(model=smodel, tokenizer=stokenizer, device=sdevice, |
|
base_model=score_model, tokenizer_base_model='', lora_weights='', |
|
inference_server='', prompt_type='', prompt_dict='', |
|
visible_models=None, h2ogpt_key=None) |
|
|
|
if enable_captions: |
|
if pre_load_caption_model: |
|
from image_captions import H2OImageCaptionLoader |
|
caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model() |
|
else: |
|
caption_loader = 'gpu' if n_gpus > 0 and caption_gpu else 'cpu' |
|
else: |
|
caption_loader = False |
|
|
|
if pre_load_embedding_model and \ |
|
langchain_mode != LangChainMode.DISABLED.value and \ |
|
not use_openai_embedding: |
|
from src.gpt_langchain import get_embedding |
|
hf_embedding_model = dict(name=hf_embedding_model, |
|
model=get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model, |
|
preload=True)) |
|
if enable_doctr or enable_pdf_ocr in [True, 'auto', 'on']: |
|
doctr_loader = 'gpu' if n_gpus > 0 and doctr_gpu else 'cpu' |
|
else: |
|
doctr_loader = False |
|
|
|
|
|
go_gradio(**locals()) |
|
|
|
|
|
def get_config(base_model,
               use_auth_token=False,
               trust_remote_code=True,
               offload_folder=None,
               revision=None,
               rope_scaling=None,
               triton_attn=False,
               long_sequence=True,
               return_model=False,
               raise_exception=False,
               max_seq_len=None,
               verbose=False,
               ):
    """
    Load the HF config for ``base_model`` and determine the usable context length.

    :param base_model: model name or local path given to ``AutoConfig.from_pretrained``
    :param use_auth_token: forwarded to ``AutoConfig.from_pretrained``
    :param trust_remote_code: forwarded to ``AutoConfig.from_pretrained`` and ``AutoModel.from_config``
    :param offload_folder: forwarded to ``AutoConfig.from_pretrained``
    :param revision: forwarded to ``AutoConfig.from_pretrained``
    :param rope_scaling: optional dict. It is forwarded to the config loader (as None if falsy).
        Its 'factor' entry (HF transformers) or, failing that, its 'alpha_value' entry (exllama)
        multiplies the returned max_seq_len.
    :param triton_attn: for 'mpt-' models, switch the config's attention implementation to triton
    :param long_sequence: for specific MPT models, raise the config's max_seq_len to larger fixed values
    :param return_model: if True and the config class is registered with AutoModel, also build a
        model from the config inside ``init_empty_weights`` (no real weights allocated)
    :param raise_exception: re-raise any OSError raised while loading the config
    :param max_seq_len: user override of the context length; if None it is derived from the config
    :param verbose: print a note when max_position_embeddings is used for max_seq_len
    :return: tuple (config, model, max_seq_len). model is None unless built above. config and model
        are both None when the config cannot be found (e.g. non-HF model files or models only
        reachable via an inference server); in that case max_seq_len is the passed value, or 2048
        for ``non_hf_types`` when none was passed.
    :raises OSError: when loading fails for any reason other than "model not found / cannot connect",
        or always if ``raise_exception`` is True.
    """
    from accelerate import init_empty_weights
    with init_empty_weights():
        from transformers import AutoConfig
        try:
            config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
                                                trust_remote_code=trust_remote_code,
                                                offload_folder=offload_folder,
                                                revision=revision,
                                                rope_scaling=rope_scaling if rope_scaling else None)
        except OSError as e:
            if raise_exception:
                raise
            err_msg = str(e)
            not_found = ('not a local folder and is not a valid model identifier listed on' in err_msg or
                         '404 Client Error' in err_msg or
                         "couldn't connect" in err_msg)
            if not not_found:
                raise
            # No HF config is available (e.g. llama.cpp/gptj model files, or a model only served by an
            # inference server), so only a context length can be reported.
            if max_seq_len is None and base_model.lower() in non_hf_types:
                print("Could not determine --max_seq_len, setting to 2048. Pass if not correct", flush=True)
                max_seq_len = 2048
            return None, None, max_seq_len

        base_model_lower = base_model.lower()
        if triton_attn and 'mpt-' in base_model_lower:
            config.attn_config['attn_impl'] = 'triton'
        if long_sequence:
            if 'mpt-7b-storywriter' in base_model_lower:
                config.update({"max_seq_len": 83968})
            if 'mosaicml/mpt-7b-chat' in base_model_lower:
                config.update({"max_seq_len": 4096})
            if 'mpt-30b' in base_model_lower:
                config.update({"max_seq_len": 2 * 8192})
        if return_model and \
                issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
            model = AutoModel.from_config(
                config,
                trust_remote_code=trust_remote_code,
            )
        else:
            # config class is not known to AutoModel, so no model can be inferred
            model = None

    if 'falcon' in base_model_lower:
        config.use_cache = False

    if max_seq_len is not None:
        print("Overriding max_seq_len -> %d" % max_seq_len, flush=True)
    else:
        if hasattr(config, 'max_seq_len'):
            max_seq_len = int(config.max_seq_len)
        elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
            # helps automatically limit inputs to generate
            max_seq_len = config.max_position_embeddings
            if verbose:
                print("Used max_position_embeddings=%s as base model (pre-rope) max_seq_len."
                      " If not desired, pass --max_seq_len and set to some integer value." % config.max_position_embeddings,
                      flush=True)
        elif hasattr(config, 'n_ctx'):
            # e.g. gpt2
            max_seq_len = int(config.n_ctx)
        else:
            print("Could not determine --max_seq_len, setting to 2048. Pass if not correct", flush=True)
            max_seq_len = 2048

    if rope_scaling:
        # HF transformers uses 'factor'; exllama uses 'alpha_value'.  Either may be a float, so cast
        # back to int: max_seq_len is a token count and must stay usable as slice/range bounds.
        scale = rope_scaling.get('factor') or rope_scaling.get('alpha_value')
        if scale:
            max_seq_len = int(max_seq_len * scale)
        print("Automatically setting max_seq_len=%d for RoPE scaling" % max_seq_len, flush=True)

    return config, model, max_seq_len
|
|
|
|
|
def get_non_lora_model(base_model, model_loader, load_half,
                       load_gptq,
                       load_exllama,
                       use_safetensors,
                       revision,
                       model_kwargs, reward_type,
                       config, model,
                       gpu_id=0,
                       ):
    """
    Load a non-LoRA model with an explicit device placement.

    ``model_kwargs`` is mutated in place: ``device_map``, ``use_safetensors`` and ``revision``
    are set, and on hosts with no visible GPU the 8/4-bit flags are forced off. Falsy
    quantization flags are then dropped by ``pop_unused_model_kwargs``.

    :param model: optional model instance; when given, accelerate's ``infer_auto_device_map``
        (on the model and, if present, its ``.model`` attribute) builds the initial device
        map, otherwise the initial map is ``"auto"``. That initial map is kept only when a GPU
        is visible and ``gpu_id`` is below -1.
    :param gpu_id: with GPUs visible, ``>= 0`` puts the whole model on one GPU (clamped to the
        last visible one; reward models always use the last visible GPU), ``-1`` maps it to
        ``'cuda'``.
    :return: the loaded model, or ``model_loader`` itself when ``load_exllama`` is set.
    """
    if model is None:
        placement = "auto"
    else:
        # max_memory / no_split_module_classes could be passed here to control sharding
        from accelerate import infer_auto_device_map
        weight_dtype = torch.float16 if load_half else torch.float32
        placement = infer_auto_device_map(model, dtype=weight_dtype)
        if hasattr(model, 'model'):
            placement.update(infer_auto_device_map(model.model, dtype=weight_dtype))

    visible = torch.cuda.device_count() if torch.cuda.is_available() else 0
    visible, _visible_ids = cuda_vis_check(visible)

    if visible <= 0:
        # CPU only: bitsandbytes quantization is switched off
        placement = {'': 'cpu'}
        model_kwargs['load_in_8bit'] = False
        model_kwargs['load_in_4bit'] = False
    elif gpu_id == -1:
        placement = {'': 'cuda'}
    elif gpu_id >= 0:
        # whole model on a single GPU; scoring (reward) models go on the last visible GPU
        target_gpu = visible - 1 if reward_type else min(visible - 1, gpu_id)
        placement = {'': target_gpu}
    print('device_map: %s' % placement, flush=True)

    # read the quantization flags before pop_unused_model_kwargs may remove them
    quant_8bit = model_kwargs.get('load_in_8bit', False)
    quant_4bit = model_kwargs.get('load_in_4bit', False)
    model_kwargs['device_map'] = placement
    model_kwargs['use_safetensors'] = use_safetensors
    model_kwargs['revision'] = revision
    pop_unused_model_kwargs(model_kwargs)

    if load_exllama:
        return model_loader

    if load_gptq:
        if 'Llama-2-70B-chat-GPTQ' in base_model:
            model_kwargs['inject_fused_attention'] = False
        # the GPTQ loader is called without torch_dtype / device_map
        model_kwargs.pop('torch_dtype', None)
        model_kwargs.pop('device_map')
        return model_loader(model_name_or_path=base_model,
                            model_basename=load_gptq,
                            **model_kwargs)

    loaded = model_loader(base_model, config=config, **model_kwargs)
    wants_half = load_half and not (quant_8bit or quant_4bit)
    if wants_half and not getattr(loaded, "is_quantized", False):
        loaded = loaded.half()
    return loaded
|
|
|
|
|
def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False):
    """
    Build a client for a remote inference server.

    A Gradio client is tried first, but only when ``get_hf_server`` returned no headers. The
    server is pinged with ``requests.get`` before the client is constructed. If no Gradio
    client results, a ``text_generation`` client is probed with a 1-token generation and then
    recreated with a 300 s timeout.

    :param inference_server: server spec; normalized by ``get_hf_server``
    :param base_model: model name, used only in log messages
    :param raise_connection_exception: re-raise connection-type errors after logging them
    :return: ``(inference_server, gr_client, hf_client)``; at most one client is non-None
    """
    inference_server, headers = get_hf_server(inference_server)

    from gradio_utils.grclient import GradioClient

    connection_errors = (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
                         JSONDecodeError, ReadTimeout2, KeyError)
    gr_client = None
    hf_client = None

    if headers is None:
        try:
            print("GR Client Begin: %s %s" % (inference_server, base_model), flush=True)
            # cheap liveness check before the (slower) Gradio client setup
            requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
            gr_client = GradioClient(inference_server)
            print("GR Client End: %s" % inference_server, flush=True)
        except (OSError, ValueError) as e:
            # Not a Gradio endpoint (or unreachable): log and fall through to the HF client.
            # NOTE(review): requests' ConnectionError/ConnectTimeout/ReadTimeout subclass OSError
            # and its JSONDecodeError subclasses ValueError, so those land here and are never
            # re-raised in this stage -- confirm the fallthrough is intended.
            gr_client = None
            print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(e)), flush=True)
        except connection_errors:
            details = ''.join(traceback.format_exception(*sys.exc_info()))
            print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(details)), flush=True)
            if raise_connection_exception:
                raise

    if gr_client is None:
        res = None
        from text_generation import Client as HFClient
        print("HF Client Begin: %s %s" % (inference_server, base_model))
        try:
            probe = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
            # one-token generation confirms this is a working text-generation endpoint
            res = probe.generate('What?', max_new_tokens=1)
            hf_client = HFClient(inference_server, headers=headers, timeout=300)
        except connection_errors:
            hf_client = None
            details = ''.join(traceback.format_exception(*sys.exc_info()))
            print("HF Client Failed %s %s: %s" % (inference_server, base_model, str(details)))
            if raise_connection_exception:
                raise
        print("HF Client End: %s %s : %s" % (inference_server, base_model, res))

    return inference_server, gr_client, hf_client
|
|
|
|
|
def get_model(
        load_8bit: bool = False,
        load_4bit: bool = False,
        low_bit_mode: int = 1,
        load_half: bool = True,
        load_gptq: str = '',
        load_exllama: bool = False,
        use_safetensors: bool = False,
        revision: str = None,
        use_gpu_id: bool = True,
        base_model: str = '',
        inference_server: str = "",
        tokenizer_base_model: str = '',
        lora_weights: str = "",
        gpu_id: int = 0,
        n_jobs=None,

        reward_type: bool = None,
        local_files_only: bool = False,
        resume_download: bool = True,
        use_auth_token: Union[str, bool] = False,
        trust_remote_code: bool = True,
        offload_folder: str = None,
        rope_scaling: dict = None,
        max_seq_len: int = None,
        compile_model: bool = True,
        llamacpp_dict=None,

        verbose: bool = False,
):
    """
    Load a model and tokenizer, or build a client for a remote inference server.

    :param load_8bit: load model in 8-bit, not supported by all models
    :param load_4bit: load model in 4-bit, not supported by all models
    :param low_bit_mode: See gen.py
    :param load_half: load model in 16-bit
    :param load_gptq: GPTQ model_basename
    :param load_exllama: whether to use exllama
    :param use_safetensors: use safetensors file
    :param revision: model revision, forwarded to config, tokenizer and model loading
    :param use_gpu_id: Use torch infer of optimal placement of layers on devices (for non-lora case)
           For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
           So it is not the default
    :param base_model: name/path of base model
    :param inference_server: whether base_model is hosted locally ('') or via http (url),
           or one of the 'openai*', 'vllm*', 'replicate*', 'sagemaker*' server specs
    :param tokenizer_base_model: name/path of tokenizer (defaults to base_model)
    :param lora_weights: name/path
    :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
    :param n_jobs: number of cores to use (e.g. for llama CPU model)
    :param reward_type: reward type model for sequence classification
    :param local_files_only: use local files instead of from HF
    :param resume_download: resume downloads from HF
    :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
    :param trust_remote_code: trust code needed by model
    :param offload_folder: offload folder
    :param rope_scaling: scaling for rope-based models, e.g. "{'type':'dynamic', 'factor':4}"
    :param max_seq_len: override for the model's maximum sequence length; if None it is
           derived from the model config (see get_config)
    :param compile_model: whether to compile torch model
    :param llamacpp_dict: dict of llama.cpp and GPT4All model options
    :param verbose: print extra diagnostics
    :return: 3-tuple whose meaning depends on the backend:
             - http inference server: (gradio-or-HF client, tokenizer, 'http')
             - openai/vllm/replicate/sagemaker: (inference_server, tokenizer, inference_server)
             - non-HF model types: result of get_model_tokenizer_gpt4all
             - exllama: (model_loader, tokenizer, 'cuda')
             - otherwise: result of get_hf_model, i.e. (model, tokenizer, device)
    """
    print("Starting get_model: %s %s" % (base_model, inference_server), flush=True)

    triton_attn = False
    long_sequence = True
    # Built before get_config() rebinds max_seq_len, so get_hf_model() sees the caller's override.
    config_kwargs = dict(use_auth_token=use_auth_token,
                         trust_remote_code=trust_remote_code,
                         offload_folder=offload_folder,
                         rope_scaling=rope_scaling,
                         triton_attn=triton_attn,
                         long_sequence=long_sequence,
                         revision=revision,
                         max_seq_len=max_seq_len,
                         verbose=verbose)
    config, _, max_seq_len = get_config(base_model, **config_kwargs, raise_exception=False)

    if base_model in non_hf_types:
        assert config is None, "Expected config None for %s" % base_model

    llama_type_from_config = 'llama' in str(config).lower()
    llama_type_from_name = "llama" in base_model.lower()
    llama_type = llama_type_from_config or llama_type_from_name
    if "xgen" in base_model.lower() or 'llama2' in base_model.lower() or 'llama-2' in base_model.lower():
        llama_type = False
    if llama_type and verbose:
        print("Detected as llama type from"
              " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)

    model_name_exllama_if_no_config = '' if not llamacpp_dict else llamacpp_dict.get('model_name_exllama_if_no_config',
                                                                                      '')
    model_loader, tokenizer_loader, conditional_type = (
        get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
                    load_gptq=load_gptq, load_exllama=load_exllama, config=config,
                    rope_scaling=rope_scaling, max_seq_len=max_seq_len,
                    model_name_exllama_if_no_config=model_name_exllama_if_no_config))

    tokenizer_kwargs = dict(local_files_only=local_files_only,
                            resume_download=resume_download,
                            use_auth_token=use_auth_token,
                            trust_remote_code=trust_remote_code,
                            offload_folder=offload_folder,
                            revision=revision,
                            padding_side='left',
                            config=config,
                            )
    if not tokenizer_base_model:
        tokenizer_base_model = base_model

    if load_exllama:
        tokenizer = tokenizer_loader
    elif config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
        tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs)
        # set the raw limit, then keep a 50-token cushion below it
        set_model_max_len(max_seq_len, tokenizer, verbose=False)
        tokenizer.model_max_length = tokenizer.model_max_length - 50
    else:
        tokenizer = None

    if isinstance(inference_server, str) and inference_server.startswith("http"):
        inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server,
                                                                                   base_model=base_model)
        client = gr_client or hf_client
        # never hand back a None tokenizer for remote servers
        if tokenizer is None:
            if os.getenv("HARD_ASSERTS") and base_model not in non_hf_types:
                raise RuntimeError("Unexpected tokenizer=None")
            tokenizer = FakeTokenizer()
        return client, tokenizer, 'http'

    if isinstance(inference_server, str) and (
            inference_server.startswith('openai') or
            inference_server.startswith('vllm') or
            inference_server.startswith('replicate') or
            inference_server.startswith('sagemaker')
    ):
        if inference_server.startswith('openai'):
            assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
            # context length comes from the known OpenAI model table
            max_seq_len = model_token_mapping[base_model]
        if inference_server.startswith('replicate'):
            assert len(inference_server.split(':')) >= 3, "Expected replicate:model string, got %s" % inference_server
            assert os.getenv('REPLICATE_API_TOKEN'), "Set environment for REPLICATE_API_TOKEN"
            assert max_seq_len is not None, "Please pass --max_seq_len=<max_seq_len> for replicate models."
            try:
                import replicate as replicate_python  # noqa: F401  (availability check only)
            except ImportError:
                raise ImportError(
                    "Could not import replicate python package. "
                    "Please install it with `pip install replicate`."
                )
        if inference_server.startswith('sagemaker'):
            assert len(
                inference_server.split(
                    ':')) >= 3, "Expected sagemaker_chat:<endpoint name>:<region>, got %s" % inference_server
            assert os.getenv('AWS_ACCESS_KEY_ID'), "Set environment for AWS_ACCESS_KEY_ID"
            assert os.getenv('AWS_SECRET_ACCESS_KEY'), "Set environment for AWS_SECRET_ACCESS_KEY"
        # OpenAI always uses the fake tokenizer; others keep a real tokenizer when one was loaded
        if inference_server.startswith('openai') or tokenizer is None:
            tokenizer = FakeTokenizer(model_max_length=max_seq_len - 50)
        return inference_server, tokenizer, inference_server

    assert not inference_server, "Malformed inference_server=%s" % inference_server

    if base_model in non_hf_types:
        from gpt4all_llm import get_model_tokenizer_gpt4all
        model, tokenizer, device = get_model_tokenizer_gpt4all(base_model, n_jobs=n_jobs,
                                                               max_seq_len=max_seq_len,
                                                               llamacpp_dict=llamacpp_dict)
        return model, tokenizer, device

    if load_exllama:
        return model_loader, tokenizer, 'cuda'

    # local torch/Hugging Face model
    return get_hf_model(load_8bit=load_8bit,
                        load_4bit=load_4bit,
                        low_bit_mode=low_bit_mode,
                        load_half=load_half,
                        load_gptq=load_gptq,
                        use_safetensors=use_safetensors,
                        revision=revision,
                        use_gpu_id=use_gpu_id,
                        base_model=base_model,
                        tokenizer_base_model=tokenizer_base_model,
                        lora_weights=lora_weights,
                        gpu_id=gpu_id,

                        reward_type=reward_type,
                        local_files_only=local_files_only,
                        resume_download=resume_download,
                        use_auth_token=use_auth_token,
                        trust_remote_code=trust_remote_code,
                        offload_folder=offload_folder,
                        rope_scaling=rope_scaling,
                        compile_model=compile_model,

                        llama_type=llama_type,
                        config_kwargs=config_kwargs,
                        tokenizer_kwargs=tokenizer_kwargs,

                        verbose=verbose)
|
|
|
|
|
def get_hf_model(load_8bit: bool = False,
                 load_4bit: bool = False,
                 low_bit_mode: int = 1,
                 load_half: bool = True,
                 load_gptq: str = '',
                 use_safetensors: bool = False,
                 revision: str = None,
                 use_gpu_id: bool = True,
                 base_model: str = '',
                 tokenizer_base_model: str = '',
                 lora_weights: str = "",
                 gpu_id: int = 0,

                 reward_type: bool = None,
                 local_files_only: bool = False,
                 resume_download: bool = True,
                 use_auth_token: Union[str, bool] = False,
                 trust_remote_code: bool = True,
                 offload_folder: str = None,
                 rope_scaling: dict = None,
                 compile_model: bool = True,

                 llama_type: bool = False,
                 config_kwargs=None,
                 tokenizer_kwargs=None,

                 verbose: bool = False,
                 ):
    """
    Load a local Hugging Face model and tokenizer, optionally applying LoRA weights.

    :param low_bit_mode: 1-4 select BitsAndBytesConfig options (only when a GPU is visible)
    :param use_gpu_id: for non-LoRA loads, place via get_config(return_model=True) + get_non_lora_model
    :param llama_type: apply the llama pad/bos/eos token-id fixups
    :param config_kwargs: kwargs for get_config (required)
    :param tokenizer_kwargs: kwargs for the tokenizer's from_pretrained (required)
    :return: (model, tokenizer, device). When get_loaders returns a task string as the
             tokenizer loader, ``model`` is built by calling ``model_loader(task, ...)``
             (a pipeline) and ``tokenizer`` is that task string.
    """
    assert config_kwargs is not None
    assert tokenizer_kwargs is not None

    # get_model() returns before reaching here when exllama is requested
    load_exllama = False

    if lora_weights is not None and lora_weights.strip():
        if verbose:
            print("Get %s lora weights" % lora_weights, flush=True)
    device = get_device()

    if 'gpt2' in base_model.lower():
        # low-bit loading is disabled for gpt2 models
        load_8bit = False
        load_4bit = False

    assert base_model.strip(), (
        "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
    )

    model_loader, tokenizer_loader, conditional_type = (
        get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
                    load_gptq=load_gptq, load_exllama=load_exllama))

    config, _, max_seq_len = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)

    if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
        tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
                                                     **tokenizer_kwargs)
    else:
        tokenizer = tokenizer_loader

    # A str "tokenizer" is a pipeline task name; it cannot carry tokenizer attributes.
    is_pipeline = isinstance(tokenizer, str)

    if is_pipeline:
        model = model_loader(tokenizer,
                             model=base_model,
                             device=0 if device == "cuda" else -1,
                             torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
    else:
        assert device in ["cuda", "cpu", "mps"], "Unsupported device %s" % device
        torch_dtype = torch.float16 if device == 'cuda' else torch.float32
        base_model_lower = base_model.lower()
        model_kwargs = dict(local_files_only=local_files_only,
                            torch_dtype=torch_dtype,
                            resume_download=resume_download,
                            use_auth_token=use_auth_token,
                            trust_remote_code=trust_remote_code,
                            offload_folder=offload_folder,
                            revision=revision,
                            # rope_scaling is applied only through the config
                            )
        if 'mbart-' not in base_model_lower and 'mpt-' not in base_model_lower:
            if use_gpu_id and gpu_id is not None and gpu_id >= 0 and device == 'cuda':
                device_map = {"": gpu_id}
            else:
                device_map = "auto"
            model_kwargs.update(dict(load_in_8bit=load_8bit,
                                     load_in_4bit=load_4bit,
                                     device_map=device_map,
                                     ))
        if 'mpt-' in base_model_lower and gpu_id is not None and gpu_id >= 0:
            # MPT is placed on a single device instead of being spread over GPUs
            model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))

        if 'OpenAssistant/reward-model'.lower() in base_model_lower:
            model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
            model_kwargs.pop('torch_dtype', None)
        pop_unused_model_kwargs(model_kwargs)

        n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        n_gpus, gpu_ids = cuda_vis_check(n_gpus)
        # low_bit_mode -> extra BitsAndBytesConfig options
        bnb_options_by_mode = {
            1: dict(bnb_4bit_compute_dtype=torch.bfloat16),
            2: dict(bnb_4bit_quant_type="nf4"),
            3: dict(bnb_4bit_use_double_quant=True),
            4: dict(bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"),
        }
        if n_gpus != 0 and low_bit_mode in bnb_options_by_mode:
            from transformers import BitsAndBytesConfig
            model_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_4bit=load_4bit,
                                                                     load_in_8bit=load_8bit,
                                                                     **bnb_options_by_mode[low_bit_mode])

        # options shared by both LoRA adapter loads below
        peft_kwargs = dict(torch_dtype=torch_dtype,
                           local_files_only=local_files_only,
                           resume_download=resume_download,
                           use_auth_token=use_auth_token,
                           trust_remote_code=trust_remote_code,
                           offload_folder=offload_folder,
                           rope_scaling=rope_scaling)

        if not lora_weights:
            # GPTQ loading is done without a torch.device context
            context = NullContext if load_gptq else torch.device
            with context(device):
                if use_gpu_id:
                    config, model, max_seq_len = get_config(base_model,
                                                            return_model=True, raise_exception=True, **config_kwargs)
                    model = get_non_lora_model(base_model, model_loader, load_half, load_gptq,
                                               load_exllama,
                                               use_safetensors,
                                               revision,
                                               model_kwargs, reward_type,
                                               config, model,
                                               gpu_id=gpu_id,
                                               )
                else:
                    config, _, max_seq_len = get_config(base_model, **config_kwargs)
                    model = model_loader(
                        base_model,
                        config=config,
                        **model_kwargs)
                    if load_half and not (load_8bit or load_4bit or load_gptq):
                        if not getattr(model, "is_quantized", False):
                            model = model.half()
        elif load_8bit or load_4bit:
            config, _, max_seq_len = get_config(base_model, **config_kwargs)
            model = model_loader(
                base_model,
                config=config,
                **model_kwargs
            )
            from peft import PeftModel  # imported lazily, only when LoRA weights are used
            model = PeftModel.from_pretrained(
                model,
                lora_weights,
                revision=revision,
                device_map={"": 0} if device == 'cuda' else {"": 'cpu'},
                **peft_kwargs,
            )
        else:
            with torch.device(device):
                config, _, max_seq_len = get_config(base_model, raise_exception=True, **config_kwargs)
                model = model_loader(
                    base_model,
                    config=config,
                    **model_kwargs
                )
                from peft import PeftModel  # imported lazily, only when LoRA weights are used
                model = PeftModel.from_pretrained(
                    model,
                    lora_weights,
                    device_map="auto",
                    **peft_kwargs,
                )
                if load_half and not load_gptq:
                    if not getattr(model, "is_quantized", False):
                        model = model.half()

    if not is_pipeline:
        # Tokenizer-dependent fixups; a str task name would raise AttributeError on each of these.
        if llama_type:
            model.config.pad_token_id = tokenizer.pad_token_id = 0
            model.config.bos_token_id = 1
            model.config.eos_token_id = 2
        if 'gpt2' in base_model.lower():
            # give gpt2 distinct special tokens
            tokenizer.add_special_tokens({'bos_token': '<bos>',
                                          'eos_token': '<eos>',
                                          'pad_token': '<pad>'})

        model.eval()
        if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
            model = torch.compile(model)

        set_model_max_len(max_seq_len, tokenizer, verbose=False, reward_type=reward_type)
        tokenizer.conditional_type = conditional_type

    # tell downstream whether this is a conditional (encoder-decoder style) model
    model.conditional_type = conditional_type

    return model, tokenizer, device
|
|
|
|
|
def set_model_max_len(max_seq_len, tokenizer, verbose=False, reward_type=False):
    """
    Set ``tokenizer.model_max_length`` in place.

    :param max_seq_len: desired maximum length; converted with ``int()``
    :param tokenizer: any object whose ``model_max_length`` attribute is overwritten
    :param verbose: print the final value that was set
    :param reward_type: when truthy, the limit is pinned to 512 and ``max_seq_len`` is ignored
    :return: None
    """
    if reward_type:
        # reward/scoring models are capped at 512 tokens
        tokenizer.model_max_length = 512
        return

    tokenizer.model_max_length = int(max_seq_len)
    # implausibly large values (e.g. "unlimited" sentinels) fall back to 2048
    if tokenizer.model_max_length > 100000000:
        tokenizer.model_max_length = 2048

    # report only after clamping, so the message reflects the value actually in effect
    if verbose:
        print("model_max_length=%s" % tokenizer.model_max_length, flush=True)
|
|
|
|
|
def pop_unused_model_kwargs(model_kwargs):
    """
    Remove falsy ``load_in_8bit`` / ``load_in_4bit`` entries from ``model_kwargs`` in place.

    A False flag matches the default anyway. Omitting it avoids passing kwargs that older or
    newer dependency versions may not accept.

    :param model_kwargs: dict of keyword arguments destined for a model loader
    :return: None
    """
    falsy_flags = [name for name in ('load_in_8bit', 'load_in_4bit')
                   if name in model_kwargs and not model_kwargs[name]]
    for name in falsy_flags:
        del model_kwargs[name]
|
|
|
|
|
def get_score_model(score_model: str = None,
                    load_8bit: bool = False,
                    load_4bit: bool = False,
                    low_bit_mode=1,
                    load_half: bool = True,
                    load_gptq: str = '',
                    load_exllama: bool = False,
                    use_gpu_id: bool = True,
                    base_model: str = '',
                    inference_server: str = '',
                    tokenizer_base_model: str = '',
                    lora_weights: str = "",
                    gpu_id: int = 0,
                    n_jobs=None,

                    reward_type: bool = None,
                    local_files_only: bool = False,
                    resume_download: bool = True,
                    use_auth_token: Union[str, bool] = False,
                    trust_remote_code: bool = True,
                    offload_folder: str = None,
                    rope_scaling: dict = None,
                    compile_model: bool = True,
                    llamacpp_dict: typing.Dict = None,

                    verbose: bool = False,
                    ):
    """
    Load the reward/scoring model named by ``score_model`` via ``get_model(reward_type=True)``.

    The main model's loading options are overridden with scoring-model settings: no
    quantization, full precision, no LoRA, no inference server, no torch.compile, no RoPE
    scaling. Everything is forwarded as ``get_kwargs(get_model, ..., **locals())``.
    NOTE(review): that presumably selects locals by get_model's parameter names, so do not
    introduce extra local variables in the loading branch without checking get_kwargs.

    :return: (smodel, stokenizer, sdevice), or (None, None, None) when score_model is None/blank
    """
    if score_model is not None and score_model.strip():
        load_8bit = False
        load_4bit = False
        low_bit_mode = 1
        load_half = False
        load_gptq = ''
        load_exllama = False
        use_safetensors = False
        revision = None
        base_model = score_model.strip()
        tokenizer_base_model = ''
        lora_weights = ''
        inference_server = ''
        llama_type = False
        max_seq_len = None
        # RoPE scaling configured for the main LLM must not leak into the scoring model's config
        rope_scaling = None
        compile_model = False
        llamacpp_dict = {}
        smodel, stokenizer, sdevice = get_model(reward_type=True,
                                                **get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
    else:
        smodel, stokenizer, sdevice = None, None, None
    return smodel, stokenizer, sdevice
|
|
|
|
|
def evaluate_fake(*args, **kwargs):
    """
    Stand-in for ``evaluate`` that ignores its arguments.

    It yields one response carrying ``invalid_key_msg`` with empty sources, then stops.
    """
    yield {'response': invalid_key_msg, 'sources': ''}
|
|
|
|
|
def evaluate( |
|
model_state, |
|
my_db_state, |
|
selection_docs_state, |
|
requests_state, |
|
|
|
instruction, |
|
iinput, |
|
context, |
|
stream_output, |
|
prompt_type, |
|
prompt_dict, |
|
temperature, |
|
top_p, |
|
top_k, |
|
num_beams, |
|
max_new_tokens, |
|
min_new_tokens, |
|
early_stopping, |
|
max_time, |
|
repetition_penalty, |
|
num_return_sequences, |
|
do_sample, |
|
chat, |
|
instruction_nochat, |
|
iinput_nochat, |
|
langchain_mode, |
|
add_chat_history_to_context, |
|
langchain_action, |
|
langchain_agents, |
|
top_k_docs, |
|
chunk, |
|
chunk_size, |
|
document_subset, |
|
document_choice, |
|
pre_prompt_query, |
|
prompt_query, |
|
pre_prompt_summary, |
|
prompt_summary, |
|
system_prompt, |
|
|
|
image_loaders, |
|
pdf_loaders, |
|
url_loaders, |
|
jq_schema, |
|
visible_models, |
|
h2ogpt_key, |
|
add_search_to_context, |
|
chat_conversation, |
|
text_context_list, |
|
docs_ordering_type, |
|
min_max_new_tokens, |
|
|
|
|
|
captions_model=None, |
|
caption_loader=None, |
|
doctr_loader=None, |
|
pix2struct_loader=None, |
|
async_output=None, |
|
num_async=None, |
|
src_lang=None, |
|
tgt_lang=None, |
|
debug=False, |
|
concurrency_count=None, |
|
save_dir=None, |
|
sanitize_bot_response=False, |
|
model_state0=None, |
|
memory_restriction_level=None, |
|
max_max_new_tokens=None, |
|
is_public=None, |
|
max_max_time=None, |
|
raise_generate_gpu_exceptions=None, |
|
lora_weights=None, |
|
use_llm_if_no_docs=True, |
|
load_db_if_exists=True, |
|
dbs=None, |
|
detect_user_path_changes_every_query=None, |
|
use_openai_embedding=None, |
|
use_openai_model=None, |
|
hf_embedding_model=None, |
|
migrate_embedding_model=None, |
|
auto_migrate_db=None, |
|
cut_distance=None, |
|
db_type=None, |
|
n_jobs=None, |
|
first_para=None, |
|
text_limit=None, |
|
show_accordions=None, |
|
top_k_docs_max_show=None, |
|
show_link_in_sources=None, |
|
verbose=False, |
|
cli=False, |
|
use_cache=None, |
|
auto_reduce_chunks=None, |
|
max_chunks=None, |
|
headsize=None, |
|
model_lock=None, |
|
force_langchain_evaluate=None, |
|
model_state_none=None, |
|
load_exllama=None, |
|
answer_with_sources=None, |
|
append_sources_to_answer=None, |
|
image_loaders_options0=None, |
|
pdf_loaders_options0=None, |
|
url_loaders_options0=None, |
|
jq_schema0=None, |
|
keep_sources_in_context=None, |
|
): |
|
|
|
assert concurrency_count is not None |
|
assert memory_restriction_level is not None |
|
assert raise_generate_gpu_exceptions is not None |
|
assert use_openai_embedding is not None |
|
assert use_openai_model is not None |
|
assert hf_embedding_model is not None |
|
assert migrate_embedding_model is not None |
|
assert auto_migrate_db is not None |
|
assert db_type is not None |
|
assert top_k_docs is not None and isinstance(top_k_docs, int) |
|
assert chunk is not None and isinstance(chunk, bool) |
|
assert chunk_size is not None and isinstance(chunk_size, int) |
|
assert n_jobs is not None |
|
assert first_para is not None |
|
assert isinstance(add_chat_history_to_context, bool) |
|
assert isinstance(add_search_to_context, bool) |
|
assert load_exllama is not None |
|
|
|
if image_loaders is None: |
|
image_loaders = image_loaders_options0 |
|
if pdf_loaders is None: |
|
pdf_loaders = pdf_loaders_options0 |
|
if url_loaders is None: |
|
url_loaders = url_loaders_options0 |
|
if jq_schema is None: |
|
jq_schema = jq_schema0 |
|
if isinstance(langchain_agents, str): |
|
if langchain_agents.strip().startswith('['): |
|
|
|
langchain_agents = str_to_list(langchain_agents) |
|
else: |
|
|
|
langchain_agents = [langchain_agents] |
|
chat_conversation = str_to_list(chat_conversation) |
|
text_context_list = str_to_list(text_context_list) |
|
|
|
langchain_modes = selection_docs_state['langchain_modes'] |
|
langchain_mode_paths = selection_docs_state['langchain_mode_paths'] |
|
langchain_mode_types = selection_docs_state['langchain_mode_types'] |
|
|
|
if debug: |
|
locals_dict = locals().copy() |
|
locals_dict.pop('model_state', None) |
|
locals_dict.pop('model_state0', None) |
|
locals_dict.pop('model_states', None) |
|
print(locals_dict) |
|
|
|
no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" \ |
|
"Then start New Conversation" |
|
|
|
if model_state is None: |
|
model_state = model_state_none.copy() |
|
if model_state0 is None: |
|
|
|
model_state0 = model_state_none.copy() |
|
|
|
|
|
|
|
have_model_lock = model_lock is not None |
|
have_fresh_model = model_state['model'] not in [None, 'model', no_model_str] |
|
|
|
|
|
|
|
|
|
have_cli_model = model_state0['model'] not in [None, 'model', no_model_str] |
|
|
|
if have_fresh_model: |
|
|
|
if not have_model_lock: |
|
|
|
|
|
if model_state0['model'] and hasattr(model_state0['model'], 'cpu'): |
|
model_state0['model'].cpu() |
|
model_state0['model'] = None |
|
|
|
if model_state0['tokenizer']: |
|
model_state0['tokenizer'] = None |
|
clear_torch_cache() |
|
chosen_model_state = model_state |
|
elif have_cli_model: |
|
|
|
assert isinstance(model_state['model'], (type(None), str)) |
|
chosen_model_state = model_state0 |
|
else: |
|
raise AssertionError(no_model_msg) |
|
|
|
model = chosen_model_state['model'] |
|
tokenizer = chosen_model_state['tokenizer'] |
|
device = chosen_model_state['device'] |
|
base_model = chosen_model_state['base_model'] |
|
tokenizer_base_model = chosen_model_state['tokenizer_base_model'] |
|
lora_weights = chosen_model_state['lora_weights'] |
|
inference_server = chosen_model_state['inference_server'] |
|
visible_models = chosen_model_state['visible_models'] |
|
|
|
if chosen_model_state['h2ogpt_key'] is not None: |
|
h2ogpt_key = chosen_model_state['h2ogpt_key'] |
|
|
|
prompt_type = prompt_type or chosen_model_state['prompt_type'] |
|
prompt_dict = prompt_dict or chosen_model_state['prompt_dict'] |
|
|
|
if base_model is None: |
|
raise AssertionError(no_model_msg) |
|
|
|
assert base_model.strip(), no_model_msg |
|
assert model, "Model is missing" |
|
assert tokenizer, "Tokenizer is missing" |
|
|
|
|
|
if not chat: |
|
instruction = instruction_nochat |
|
iinput = iinput_nochat |
|
|
|
|
|
model_lower = base_model.lower() |
|
if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': |
|
prompt_type = inv_prompt_type_to_model_lower[model_lower] |
|
if verbose: |
|
print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True) |
|
assert prompt_type is not None, "prompt_type was None" |
|
|
|
|
|
|
|
|
|
|
|
top_p = min(max(1e-3, top_p), 1.0 - 1e-3) |
|
top_k = min(max(1, int(top_k)), 100) |
|
temperature = min(max(0.01, temperature), 2.0) |
|
|
|
num_beams = 1 if stream_output else num_beams |
|
max_max_new_tokens = get_max_max_new_tokens(chosen_model_state, |
|
memory_restriction_level=memory_restriction_level, |
|
max_new_tokens=max_new_tokens, |
|
max_max_new_tokens=max_max_new_tokens) |
|
if min_max_new_tokens is None: |
|
|
|
min_max_new_tokens = 256 |
|
if docs_ordering_type is None: |
|
docs_ordering_type = 'reverse_ucurve_sort' |
|
model_max_length = get_model_max_length(chosen_model_state) |
|
max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens) |
|
min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens) |
|
max_time = min(max(0, max_time), max_max_time) |
|
repetition_penalty = min(max(0.01, repetition_penalty), 3.0) |
|
num_return_sequences = 1 if chat else min(max(1, int(num_return_sequences)), 10) |
|
min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) |
|
|
|
if is_public: |
|
total_tokens_for_docs = min(2 * model_max_length, 16384) |
|
else: |
|
total_tokens_for_docs = None |
|
top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs) |
|
chunk_size = min(max(128, int(chunk_size)), 2048) |
|
if not context: |
|
context = '' |
|
|
|
|
|
prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output, |
|
system_prompt=system_prompt) |
|
|
|
|
|
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s not in %s" % (langchain_mode, langchain_modes) |
|
assert langchain_action in langchain_actions, "Invalid langchain_action %s not in %s" % ( |
|
langchain_action, langchain_actions) |
|
assert len( |
|
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents |
|
|
|
|
|
if langchain_mode != LangChainMode.DISABLED.value: |
|
from src.gpt_langchain import get_any_db |
|
db = get_any_db(my_db_state, langchain_mode, langchain_mode_paths, langchain_mode_types, |
|
dbs=dbs, |
|
load_db_if_exists=load_db_if_exists, |
|
db_type=db_type, |
|
use_openai_embedding=use_openai_embedding, |
|
hf_embedding_model=hf_embedding_model, |
|
migrate_embedding_model=migrate_embedding_model, |
|
auto_migrate_db=auto_migrate_db, |
|
for_sources_list=True, |
|
verbose=verbose, |
|
n_jobs=n_jobs, |
|
) |
|
else: |
|
db = None |
|
|
|
t_generate = time.time() |
|
langchain_only_model = base_model in non_hf_types or \ |
|
load_exllama or \ |
|
inference_server.startswith('replicate') or \ |
|
inference_server.startswith('sagemaker') or \ |
|
inference_server.startswith('openai_azure_chat') or \ |
|
inference_server.startswith('openai_azure') |
|
do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \ |
|
langchain_only_model or \ |
|
force_langchain_evaluate or \ |
|
len(text_context_list) > 0 |
|
|
|
if len(langchain_agents) > 0: |
|
do_langchain_path = True |
|
if add_search_to_context: |
|
|
|
do_langchain_path = True |
|
|
|
if do_langchain_path: |
|
text = '' |
|
sources = '' |
|
response = '' |
|
|
|
from gpt_langchain import run_qa_db |
|
gen_hyper_langchain = dict(do_sample=do_sample, |
|
temperature=temperature, |
|
repetition_penalty=repetition_penalty, |
|
top_k=top_k, |
|
top_p=top_p, |
|
num_beams=num_beams, |
|
min_new_tokens=min_new_tokens, |
|
max_new_tokens=max_new_tokens, |
|
early_stopping=early_stopping, |
|
max_time=max_time, |
|
num_return_sequences=num_return_sequences, |
|
) |
|
loaders_dict, captions_model = gr_to_lg(image_loaders, |
|
pdf_loaders, |
|
url_loaders, |
|
captions_model=captions_model, |
|
) |
|
loaders_dict.update(dict(captions_model=captions_model, |
|
caption_loader=caption_loader, |
|
doctr_loader=doctr_loader, |
|
pix2struct_loader=pix2struct_loader, |
|
jq_schema=jq_schema, |
|
)) |
|
data_point = dict(context=context, instruction=instruction, input=iinput) |
|
|
|
prompt_basic = prompter.generate_prompt(data_point, context_from_history=False) |
|
prompt = prompt_basic |
|
num_prompt_tokens = 0 |
|
for r in run_qa_db( |
|
inference_server=inference_server, |
|
model_name=base_model, model=model, tokenizer=tokenizer, |
|
langchain_only_model=langchain_only_model, |
|
async_output=async_output, |
|
num_async=num_async, |
|
prompter=prompter, |
|
use_llm_if_no_docs=use_llm_if_no_docs, |
|
load_db_if_exists=load_db_if_exists, |
|
db=db, |
|
langchain_mode_paths=langchain_mode_paths, |
|
langchain_mode_types=langchain_mode_types, |
|
detect_user_path_changes_every_query=detect_user_path_changes_every_query, |
|
cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance, |
|
answer_with_sources=answer_with_sources, |
|
append_sources_to_answer=append_sources_to_answer, |
|
add_chat_history_to_context=add_chat_history_to_context, |
|
add_search_to_context=add_search_to_context, |
|
keep_sources_in_context=keep_sources_in_context, |
|
memory_restriction_level=memory_restriction_level, |
|
system_prompt=system_prompt, |
|
use_openai_embedding=use_openai_embedding, |
|
use_openai_model=use_openai_model, |
|
hf_embedding_model=hf_embedding_model, |
|
migrate_embedding_model=migrate_embedding_model, |
|
auto_migrate_db=auto_migrate_db, |
|
first_para=first_para, |
|
text_limit=text_limit, |
|
show_accordions=show_accordions, |
|
top_k_docs_max_show=top_k_docs_max_show, |
|
show_link_in_sources=show_link_in_sources, |
|
|
|
|
|
query=instruction, |
|
iinput=iinput, |
|
context=context, |
|
stream_output=stream_output, |
|
chunk=chunk, |
|
chunk_size=chunk_size, |
|
|
|
**loaders_dict, |
|
|
|
langchain_mode=langchain_mode, |
|
langchain_action=langchain_action, |
|
langchain_agents=langchain_agents, |
|
document_subset=document_subset, |
|
document_choice=document_choice, |
|
top_k_docs=top_k_docs, |
|
prompt_type=prompt_type, |
|
prompt_dict=prompt_dict, |
|
pre_prompt_query=pre_prompt_query, |
|
prompt_query=prompt_query, |
|
pre_prompt_summary=pre_prompt_summary, |
|
prompt_summary=prompt_summary, |
|
text_context_list=text_context_list, |
|
chat_conversation=chat_conversation, |
|
visible_models=visible_models, |
|
h2ogpt_key=h2ogpt_key, |
|
docs_ordering_type=docs_ordering_type, |
|
min_max_new_tokens=min_max_new_tokens, |
|
|
|
**gen_hyper_langchain, |
|
|
|
db_type=db_type, |
|
n_jobs=n_jobs, |
|
verbose=verbose, |
|
cli=cli, |
|
sanitize_bot_response=sanitize_bot_response, |
|
|
|
lora_weights=lora_weights, |
|
|
|
auto_reduce_chunks=auto_reduce_chunks, |
|
max_chunks=max_chunks, |
|
total_tokens_for_docs=total_tokens_for_docs, |
|
headsize=headsize, |
|
): |
|
|
|
response = r['response'] |
|
sources = r['sources'] |
|
prompt = r['prompt'] |
|
num_prompt_tokens = r['num_prompt_tokens'] |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
if save_dir: |
|
|
|
extra_dict = gen_hyper_langchain.copy() |
|
extra_dict.update(prompt_type=prompt_type, |
|
inference_server=inference_server, |
|
langchain_mode=langchain_mode, |
|
langchain_action=langchain_action, |
|
langchain_agents=langchain_agents, |
|
document_subset=document_subset, |
|
document_choice=document_choice, |
|
chat_conversation=chat_conversation, |
|
add_search_to_context=add_search_to_context, |
|
num_prompt_tokens=num_prompt_tokens, |
|
instruction=instruction, |
|
iinput=iinput, |
|
context=context, |
|
t_generate=time.time() - t_generate, |
|
ntokens=None, |
|
tokens_persecond=None, |
|
) |
|
save_dict = dict(prompt=prompt, |
|
output=response, base_model=base_model, save_dir=save_dir, |
|
where_from='run_qa_db', |
|
extra_dict=extra_dict) |
|
yield dict(response=response, sources=sources, save_dict=save_dict) |
|
if verbose: |
|
print( |
|
'Post-Generate Langchain: %s decoded_output: %s' % |
|
(str(datetime.now()), len(response) if response else -1), |
|
flush=True) |
|
if response or sources or langchain_only_model: |
|
|
|
|
|
|
|
|
|
return |
|
|
|
|
|
|
|
prompt, \ |
|
instruction, iinput, context, \ |
|
num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \ |
|
chat_index, top_k_docs_trial, one_doc_size = \ |
|
get_limited_prompt(instruction, |
|
iinput, |
|
tokenizer, |
|
prompter=prompter, |
|
inference_server=inference_server, |
|
|
|
|
|
|
|
max_new_tokens=max_new_tokens, |
|
|
|
context=context, |
|
chat_conversation=chat_conversation, |
|
keep_sources_in_context=keep_sources_in_context, |
|
model_max_length=model_max_length, |
|
memory_restriction_level=memory_restriction_level, |
|
langchain_mode=langchain_mode, |
|
add_chat_history_to_context=add_chat_history_to_context, |
|
min_max_new_tokens=min_max_new_tokens, |
|
) |
|
|
|
if inference_server.startswith('vllm') or \ |
|
inference_server.startswith('openai') or \ |
|
inference_server.startswith('http'): |
|
if inference_server.startswith('vllm') or inference_server.startswith('openai'): |
|
assert not inference_server.startswith('openai_azure_chat'), "Not fo Azure, use langchain path" |
|
assert not inference_server.startswith('openai_azure'), "Not for Azure, use langchain path" |
|
openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server) |
|
where_from = inf_type |
|
|
|
terminate_response = prompter.terminate_response or [] |
|
stop_sequences = list(set(terminate_response + [prompter.PreResponse])) |
|
stop_sequences = [x for x in stop_sequences if x] |
|
|
|
max_new_tokens_openai = min(max_new_tokens, model_max_length - num_prompt_tokens) |
|
gen_server_kwargs = dict(temperature=temperature if do_sample else 0, |
|
max_tokens=max_new_tokens_openai, |
|
top_p=top_p if do_sample else 1, |
|
frequency_penalty=0, |
|
n=num_return_sequences, |
|
presence_penalty=1.07 - repetition_penalty + 0.6, |
|
) |
|
if inf_type == 'vllm' or inference_server == 'openai': |
|
responses = openai.Completion.create( |
|
model=base_model, |
|
prompt=prompt, |
|
**gen_server_kwargs, |
|
stop=stop_sequences, |
|
stream=stream_output, |
|
) |
|
text = '' |
|
sources = '' |
|
response = '' |
|
if not stream_output: |
|
text = responses['choices'][0]['text'] |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
collected_events = [] |
|
for event in responses: |
|
collected_events.append(event) |
|
event_text = event['choices'][0]['text'] |
|
text += event_text |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
elif inf_type == 'vllm_chat' or inference_server == 'openai_chat': |
|
if inf_type == 'vllm_chat': |
|
raise NotImplementedError('%s not supported by vLLM' % inf_type) |
|
if system_prompt in [None, 'None', 'auto']: |
|
openai_system_prompt = "You are a helpful assistant." |
|
else: |
|
openai_system_prompt = system_prompt |
|
messages0 = [] |
|
if openai_system_prompt: |
|
messages0.append({"role": "system", "content": openai_system_prompt}) |
|
messages0.append({'role': 'user', 'content': prompt}) |
|
responses = openai.ChatCompletion.create( |
|
model=base_model, |
|
messages=messages0, |
|
stream=stream_output, |
|
**gen_server_kwargs, |
|
) |
|
text = "" |
|
sources = '' |
|
response = "" |
|
if not stream_output: |
|
text = responses["choices"][0]["message"]["content"] |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
for chunk in responses: |
|
delta = chunk["choices"][0]["delta"] |
|
if 'content' in delta: |
|
text += delta['content'] |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
raise RuntimeError("No such OpenAI mode: %s" % inference_server) |
|
elif inference_server.startswith('http'): |
|
inference_server, headers = get_hf_server(inference_server) |
|
from gradio_utils.grclient import GradioClient |
|
from text_generation import Client as HFClient |
|
if isinstance(model, GradioClient): |
|
gr_client = model |
|
hf_client = None |
|
elif isinstance(model, HFClient): |
|
gr_client = None |
|
hf_client = model |
|
else: |
|
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server, |
|
base_model=base_model) |
|
|
|
|
|
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) |
|
|
|
if gr_client is not None: |
|
|
|
|
|
|
|
chat_client = False |
|
where_from = "gr_client" |
|
client_langchain_mode = 'Disabled' |
|
client_add_chat_history_to_context = True |
|
client_add_search_to_context = False |
|
client_langchain_action = LangChainAction.QUERY.value |
|
client_langchain_agents = [] |
|
gen_server_kwargs = dict(temperature=temperature, |
|
top_p=top_p, |
|
top_k=top_k, |
|
num_beams=num_beams, |
|
max_new_tokens=max_new_tokens, |
|
min_new_tokens=min_new_tokens, |
|
early_stopping=early_stopping, |
|
max_time=max_time, |
|
repetition_penalty=repetition_penalty, |
|
num_return_sequences=num_return_sequences, |
|
do_sample=do_sample, |
|
chat=chat_client, |
|
) |
|
|
|
if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value, |
|
str(PromptType.plain.value)]: |
|
|
|
|
|
gr_prompt_type = '' |
|
gr_prompt_dict = '' |
|
gr_prompt = prompt |
|
gr_context = '' |
|
gr_iinput = '' |
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gr_context = context |
|
gr_prompt = instruction |
|
gr_iinput = iinput |
|
gr_prompt_type = prompt_type |
|
gr_prompt_dict = prompt_dict |
|
client_kwargs = dict(instruction=gr_prompt if chat_client else '', |
|
iinput=gr_iinput, |
|
context=gr_context, |
|
|
|
|
|
stream_output=stream_output, |
|
|
|
**gen_server_kwargs, |
|
|
|
prompt_type=gr_prompt_type, |
|
prompt_dict=gr_prompt_dict, |
|
|
|
instruction_nochat=gr_prompt if not chat_client else '', |
|
iinput_nochat=gr_iinput, |
|
langchain_mode=client_langchain_mode, |
|
add_chat_history_to_context=client_add_chat_history_to_context, |
|
langchain_action=client_langchain_action, |
|
langchain_agents=client_langchain_agents, |
|
top_k_docs=top_k_docs, |
|
chunk=chunk, |
|
chunk_size=chunk_size, |
|
document_subset=DocumentSubset.Relevant.name, |
|
document_choice=[DocumentChoice.ALL.value], |
|
pre_prompt_query=pre_prompt_query, |
|
prompt_query=prompt_query, |
|
pre_prompt_summary=pre_prompt_summary, |
|
prompt_summary=prompt_summary, |
|
system_prompt=system_prompt, |
|
image_loaders=image_loaders, |
|
pdf_loaders=pdf_loaders, |
|
url_loaders=url_loaders, |
|
jq_schema=jq_schema, |
|
visible_models=visible_models, |
|
h2ogpt_key=h2ogpt_key, |
|
add_search_to_context=client_add_search_to_context, |
|
docs_ordering_type=None, |
|
min_max_new_tokens=min_max_new_tokens, |
|
) |
|
api_name = '/submit_nochat_api' |
|
response = '' |
|
text = '' |
|
sources = '' |
|
if not stream_output: |
|
res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) |
|
res_dict = ast.literal_eval(res) |
|
text = res_dict['response'] |
|
sources = res_dict['sources'] |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) |
|
res_dict = dict(response=text, sources=sources, save_dict=dict()) |
|
text0 = '' |
|
while not job.done(): |
|
if job.communicator.job.latest_status.code.name == 'FINISHED': |
|
break |
|
e = job.future._exception |
|
if e is not None: |
|
break |
|
outputs_list = job.communicator.job.outputs |
|
if outputs_list: |
|
res = job.communicator.job.outputs[-1] |
|
res_dict = ast.literal_eval(res) |
|
text = res_dict['response'] |
|
sources = res_dict['sources'] |
|
if gr_prompt_type == 'plain': |
|
|
|
prompt_and_text = text |
|
else: |
|
prompt_and_text = prompt + text |
|
response = prompter.get_response(prompt_and_text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
text_chunk = response[len(text0):] |
|
if not text_chunk: |
|
continue |
|
|
|
text0 = response |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
time.sleep(0.01) |
|
|
|
res_all = job.outputs() |
|
if len(res_all) > 0: |
|
res = res_all[-1] |
|
res_dict = ast.literal_eval(res) |
|
text = res_dict['response'] |
|
sources = res_dict['sources'] |
|
else: |
|
|
|
e = job.future._exception |
|
if e is not None: |
|
stre = str(e) |
|
strex = ''.join(traceback.format_tb(e.__traceback__)) |
|
else: |
|
stre = '' |
|
strex = '' |
|
|
|
print("Bad final response: %s %s %s %s %s: %s %s" % (base_model, inference_server, |
|
res_all, prompt, text, stre, strex), |
|
flush=True) |
|
if gr_prompt_type == 'plain': |
|
|
|
prompt_and_text = text |
|
else: |
|
prompt_and_text = prompt + text |
|
response = prompter.get_response(prompt_and_text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
elif hf_client: |
|
|
|
where_from = "hf_client" |
|
response = '' |
|
extra = '' |
|
sources = '' |
|
|
|
|
|
|
|
terminate_response = prompter.terminate_response or [] |
|
stop_sequences = list(set(terminate_response + [prompter.PreResponse])) |
|
stop_sequences = [x for x in stop_sequences if x] |
|
gen_server_kwargs = dict(do_sample=do_sample, |
|
max_new_tokens=max_new_tokens, |
|
|
|
repetition_penalty=repetition_penalty, |
|
return_full_text=False, |
|
seed=SEED, |
|
stop_sequences=stop_sequences, |
|
temperature=temperature, |
|
top_k=top_k, |
|
top_p=top_p, |
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
hf_client.timeout = max(300, max_time) |
|
if not stream_output: |
|
text = hf_client.generate(prompt, **gen_server_kwargs).generated_text |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
text = "" |
|
for responses in hf_client.generate_stream(prompt, **gen_server_kwargs): |
|
if not responses.token.special: |
|
|
|
text_chunk = responses.token.text |
|
text += text_chunk |
|
response = prompter.get_response(prompt + text, prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response) |
|
sources = '' |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
else: |
|
raise RuntimeError("Failed to get client: %s" % inference_server) |
|
else: |
|
raise RuntimeError("No such inference_server %s" % inference_server) |
|
|
|
if save_dir and text: |
|
|
|
extra_dict = gen_server_kwargs.copy() |
|
extra_dict.update(dict(inference_server=inference_server, num_prompt_tokens=num_prompt_tokens, |
|
t_generate=time.time() - t_generate, |
|
ntokens=None, |
|
tokens_persecond=None, |
|
)) |
|
save_dict = dict(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir, |
|
where_from=where_from, extra_dict=extra_dict) |
|
yield dict(response=response, sources=sources, save_dict=save_dict) |
|
return |
|
else: |
|
assert not inference_server, "inference_server=%s not supported" % inference_server |
|
|
|
if isinstance(tokenizer, str): |
|
|
|
if tokenizer == "summarization": |
|
key = 'summary_text' |
|
else: |
|
raise RuntimeError("No such task type %s" % tokenizer) |
|
|
|
sources = '' |
|
yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources=sources, save_dict=dict()) |
|
|
|
if 'mbart-' in base_model.lower(): |
|
assert src_lang is not None |
|
tokenizer.src_lang = languages_covered()[src_lang] |
|
|
|
stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model, |
|
model_max_length=model_max_length, |
|
prompter=prompter) |
|
|
|
inputs = tokenizer(prompt, return_tensors="pt") |
|
if debug and len(inputs["input_ids"]) > 0: |
|
print('input_ids length', len(inputs["input_ids"][0]), flush=True) |
|
input_ids = inputs["input_ids"].to(device) |
|
|
|
max_max_tokens = tokenizer.model_max_length |
|
max_input_tokens = max(0, int(max_max_tokens - min_new_tokens)) |
|
|
|
assert isinstance(max_input_tokens, int), "Bad type for max_input_tokens=%s %s" % ( |
|
max_input_tokens, type(max_input_tokens)) |
|
input_ids = input_ids[:, -max_input_tokens:] |
|
|
|
if use_cache is None: |
|
use_cache = False if 'falcon' in base_model else True |
|
gen_config_kwargs = dict(num_beams=num_beams, |
|
do_sample=do_sample, |
|
repetition_penalty=float(repetition_penalty), |
|
num_return_sequences=num_return_sequences, |
|
renormalize_logits=True, |
|
remove_invalid_values=True, |
|
use_cache=use_cache, |
|
) |
|
if do_sample: |
|
gen_config_kwargs.update(dict(temperature=float(temperature), |
|
top_p=float(top_p), |
|
top_k=top_k)) |
|
if True: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
token_ids = ['eos_token_id', 'pad_token_id', 'bos_token_id', 'cls_token_id', 'sep_token_id'] |
|
for token_id in token_ids: |
|
if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: |
|
gen_config_kwargs.update({token_id: getattr(tokenizer, token_id)}) |
|
generation_config = GenerationConfig(**gen_config_kwargs) |
|
|
|
gen_kwargs = dict(input_ids=input_ids, |
|
generation_config=generation_config, |
|
return_dict_in_generate=True, |
|
output_scores=True, |
|
max_new_tokens=max_new_tokens, |
|
min_new_tokens=min_new_tokens, |
|
early_stopping=early_stopping, |
|
max_time=max_time, |
|
stopping_criteria=stopping_criteria, |
|
) |
|
if 'gpt2' in base_model.lower(): |
|
gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id)) |
|
elif 'mbart-' in base_model.lower(): |
|
assert tgt_lang is not None |
|
tgt_lang = languages_covered()[tgt_lang] |
|
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])) |
|
else: |
|
token_ids = ['eos_token_id', 'bos_token_id', 'pad_token_id'] |
|
for token_id in token_ids: |
|
if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: |
|
gen_kwargs.update({token_id: getattr(tokenizer, token_id)}) |
|
|
|
decoder_kwargs = dict(skip_special_tokens=True, |
|
clean_up_tokenization_spaces=True) |
|
|
|
decoder = functools.partial(tokenizer.decode, |
|
**decoder_kwargs |
|
) |
|
with torch.no_grad(): |
|
have_lora_weights = lora_weights not in [no_lora_str, '', None] |
|
context_class_cast = NullContext if device == 'cpu' or have_lora_weights or device == 'mps' else torch.autocast |
|
if t5_type(base_model): |
|
|
|
context_class_cast = NullContext |
|
with context_class_cast(device): |
|
|
|
|
|
|
|
|
|
context_class = NullContext |
|
if verbose: |
|
print('Pre-Generate: %s' % str(datetime.now()), flush=True) |
|
decoded_output = None |
|
response = '' |
|
with context_class("generate.lock"): |
|
if verbose: |
|
print('Generate: %s' % str(datetime.now()), flush=True) |
|
always_use_streaming_method = True |
|
if stream_output or always_use_streaming_method: |
|
skip_prompt = True |
|
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, |
|
**decoder_kwargs) |
|
gen_kwargs.update(dict(streamer=streamer)) |
|
target = wrapped_partial(generate_with_exceptions, model.generate, |
|
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions, |
|
**gen_kwargs) |
|
bucket = queue.Queue() |
|
thread = EThread(target=target, streamer=streamer, bucket=bucket) |
|
thread.start() |
|
ret = dict(response='', sources='', save_dict=dict()) |
|
outputs = "" |
|
sources = '' |
|
try: |
|
for new_text in streamer: |
|
if bucket.qsize() > 0 or thread.exc: |
|
thread.join() |
|
outputs += new_text |
|
response = prompter.get_response(outputs, prompt=None, |
|
only_new_text=True, |
|
sanitize_bot_response=sanitize_bot_response) |
|
ret = dict(response=response, sources=sources, save_dict=dict()) |
|
if stream_output: |
|
yield ret |
|
if not stream_output: |
|
yield ret |
|
except BaseException: |
|
|
|
if thread.exc: |
|
raise thread.exc |
|
raise |
|
finally: |
|
|
|
|
|
if not thread.exc: |
|
thread.join() |
|
|
|
if thread.exc: |
|
raise thread.exc |
|
decoded_output = outputs |
|
ntokens = len(outputs) // 4 |
|
else: |
|
|
|
input_ids_len = gen_kwargs['input_ids'][0].shape[0] |
|
try: |
|
outputs = model.generate(**gen_kwargs) |
|
finally: |
|
pass |
|
|
|
|
|
ntokens = sum([len(s) - input_ids_len for s in outputs.sequences]) if save_dir else -1 |
|
outputs = [decoder(s[input_ids_len:]) for s in outputs.sequences] |
|
sources = '' |
|
response = prompter.get_response(outputs, prompt=None, |
|
only_new_text=True, |
|
sanitize_bot_response=sanitize_bot_response) |
|
yield dict(response=response, sources=sources, save_dict=dict()) |
|
if outputs and len(outputs) >= 1: |
|
decoded_output = prompt + outputs[0] |
|
if save_dir and decoded_output: |
|
extra_dict = gen_config_kwargs.copy() |
|
extra_dict.update(dict(num_prompt_tokens=num_prompt_tokens, |
|
t_generate=time.time() - t_generate, |
|
ntokens=ntokens, |
|
tokens_persecond=ntokens / (time.time() - t_generate), |
|
)) |
|
save_dict = dict(prompt=prompt, output=decoded_output, base_model=base_model, save_dir=save_dir, |
|
where_from="evaluate_%s" % str(stream_output), |
|
extra_dict=extra_dict) |
|
yield dict(response=response, sources=sources, save_dict=save_dict) |
|
if verbose: |
|
print('Post-Generate: %s decoded_output: %s' % ( |
|
str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True) |
|
|
|
|
|
# All parameter names accepted by evaluate(), in signature order.
inputs_list_names = list(inspect.signature(evaluate).parameters)
# Work on a copy so later changes to state_names cannot alter input_args_list.
state_names = input_args_list.copy()
# evaluate() parameters that are neither eval_func_param_names nor state_names
# (presumably supplied as extra keyword arguments -- confirm at call sites).
inputs_kwargs_list = [
    param_name
    for param_name in inputs_list_names
    if param_name not in eval_func_param_names + state_names
]
|
|
|
|
|
def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
    """
    Compute length limits used when tokenizing and truncating prompts.

    :param memory_restriction_level: when > 0, a fixed small budget is used
        (levels 1-2 get a larger one than 3+); otherwise the budget comes from
        model_max_length
    :param for_context: if True, shrink max_prompt_length to 80% of its value,
        but never below 64
    :param model_max_length: model context size, used only when
        memory_restriction_level is not > 0
    :return: tuple (cutoff_len, output_smallest, max_length_tokenize, max_prompt_length)
    """
    # 256 is always taken off the selected budget.
    reserved_tokens = 256
    if memory_restriction_level > 0:
        token_budget = 768 if memory_restriction_level <= 2 else 512
    else:
        token_budget = model_max_length
    max_length_tokenize = token_budget - reserved_tokens

    # Scale factor applied to token counts; presumably a rough
    # characters-per-token estimate -- TODO confirm against callers.
    multiplier = 4
    cutoff_len = max_length_tokenize * multiplier
    output_smallest = 30 * multiplier
    max_prompt_length = cutoff_len - output_smallest

    if for_context:
        # Leave 20% head-room for context, with a floor of 64.
        max_prompt_length = max(64, int(max_prompt_length * 0.8))

    return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length
|
|
|
|
|
class H2OTextIteratorStreamer(TextIteratorStreamer):
    """
    TextIteratorStreamer variant used while generating in a background thread.

    Differences from the base class:
    - ``__next__`` catches ``queue.Empty`` and retries after a 10 ms sleep, so with
      ``block=False`` the consumer polls instead of needing a ``timeout`` to get
      control back (e.g. to notice exceptions raised by the generating thread).
    - Setting ``do_stop = True`` makes the next ``__next__`` call drain the queue
      and end iteration.
    - ``put`` emits newly decoded text immediately instead of holding it back until
      the last space (the base-class ``text.rfind(" ")`` heuristic that ruins
      LLaMa2 output); the hold-back is only applied when the decoded text ends in
      the replacement character.
    """

    def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
                 block=True, **decode_kwargs):
        """
        :param tokenizer: tokenizer used to decode the streamed token ids
        :param skip_prompt: if True, the first ``put`` call after start/end is not streamed
        :param timeout: timeout passed to the text queue's ``put``/``get``; None means none
        :param block: whether ``__next__`` blocks on the queue; if False it polls
        :param decode_kwargs: extra keyword arguments forwarded to ``tokenizer.decode``
        """
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = queue.Queue()
        # End-of-stream sentinel; generated chunks are str, so it is compared by identity.
        self.stop_signal = None
        # Presumably set by the consumer to request an early stop -- only read/reset here.
        self.do_stop = False
        self.timeout = timeout
        self.block = block

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next queued text chunk.

        :raises StopIteration: when ``do_stop`` is set or the stop signal is dequeued;
            in both cases the queue is cleared and ``do_stop`` reset
        """
        while True:
            try:
                if self.do_stop:
                    print("hit stop", flush=True)
                    # Drop any pending text so a later iteration starts clean.
                    self.clear_queue()
                    self.do_stop = False
                    raise StopIteration()
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                # Nothing available yet (non-blocking get or timeout): poll again shortly.
                time.sleep(0.01)
        # Identity check: the sentinel is the exact object enqueued by on_finalized_text.
        if value is self.stop_signal:
            self.clear_queue()
            self.do_stop = False
            raise StopIteration()
        return value

    def clear_queue(self):
        """Discard every pending item in the text queue."""
        # queue.Queue has no public clear(); empty its underlying container under its lock.
        with self.text_queue.mutex:
            self.text_queue.queue.clear()

    def put(self, value):
        """
        Receive new token ids, decode the accumulated cache, and queue the newly printable text.

        Same as the base class except the ``text.rfind(" ")`` hold-back (which ruins LLaMa2
        output) is only used when the decoded text ends in the replacement character.

        :raises ValueError: if ``value`` contains more than one sequence (batch size > 1)
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            # With skip_prompt, the first put() (presumably the prompt ids) is dropped.
            self.next_tokens_are_prompt = False
            return

        # Decode the whole cache so characters/words spanning several tokens render correctly.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        if text.endswith("\n"):
            # Newline: flush everything and start a fresh cache.
            printable_text = text[self.print_len:]
            self.token_cache = []
            self.print_len = 0
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            # CJK character: emit immediately.
            printable_text = text[self.print_len:]
            self.print_len += len(printable_text)
        elif len(text) > 0 and text[-1] == '�':
            # Ends in the replacement character (typically an incomplete multi-byte
            # sequence): only emit up to the last space, as the base class does.
            printable_text = text[self.print_len: text.rfind(" ") + 1]
            self.print_len += len(printable_text)
        else:
            # Unlike the base class, emit all new text right away.
            printable_text = text[self.print_len:]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)
|
|
|
|
|
def generate_with_exceptions(func, *args, raise_generate_gpu_exceptions=True, **kwargs):
    """
    Call ``func(*args, **kwargs)`` (e.g. ``model.generate``) and handle failures.

    - ``torch.cuda.OutOfMemoryError``: logged with traceback, torch cache cleared,
      and None returned. It is never re-raised, regardless of the flag.
    - Errors whose message matches a known GPU/numerical failure: logged with
      traceback and the cache cleared. They are re-raised if
      ``raise_generate_gpu_exceptions``, else None is returned.
    - Any other ``Exception``: the cache is cleared. It is re-raised if
      ``raise_generate_gpu_exceptions``; otherwise it is logged with traceback
      and swallowed (it is no longer dropped silently).

    The return value of ``func`` is discarded; this function always returns None.

    :param func: callable to run
    :param raise_generate_gpu_exceptions: whether to re-raise non-OOM exceptions
    """
    gpu_error_markers = (
        'Expected all tensors to be on the same device',
        'expected scalar type Half but found Float',
        'probability tensor contains either',
        'cublasLt ran into an error!',
        'mat1 and mat2 shapes cannot be multiplied',
    )
    try:
        func(*args, **kwargs)
    except torch.cuda.OutOfMemoryError as e:
        print("GPU OOM 2: exception: %s" % str(e),
              flush=True)
        # Drop this frame's reference to the input ids (other references may remain).
        # Tensor.cpu() returns a copy rather than moving in place, so calling it here
        # would only allocate host memory without freeing anything on the GPU.
        if 'input_ids' in kwargs:
            kwargs['input_ids'] = None
        traceback.print_exc()
        clear_torch_cache()
        return
    except Exception as e:  # RuntimeError is already a subclass of Exception
        msg = str(e)
        if any(marker in msg for marker in gpu_error_markers):
            print(
                "GPU Error: exception: %s" % msg,
                flush=True)
            traceback.print_exc()
            clear_torch_cache()
            if raise_generate_gpu_exceptions:
                raise
            return
        clear_torch_cache()
        if raise_generate_gpu_exceptions:
            raise
        # Best-effort mode: still swallow, but never silently.
        print("Generate exception: %s" % msg, flush=True)
        traceback.print_exc()
|
|
|
|
|
def get_generate_params(model_lower, |
|
chat, |
|
stream_output, show_examples, |
|
prompt_type, prompt_dict, |
|
system_prompt, |
|
pre_prompt_query, prompt_query, |
|
pre_prompt_summary, prompt_summary, |
|
temperature, top_p, top_k, num_beams, |
|
max_new_tokens, min_new_tokens, early_stopping, max_time, |
|
repetition_penalty, num_return_sequences, |
|
do_sample, |
|
top_k_docs, chunk, chunk_size, |
|
image_loaders, |
|
pdf_loaders, |
|
url_loaders, |
|
jq_schema, |
|
docs_ordering_type, |
|
min_max_new_tokens, |
|
verbose, |
|
): |
|
use_defaults = False |
|
use_default_examples = True |
|
examples = [] |
|
task_info = 'LLM' |
|
if model_lower: |
|
print(f"Using Model {model_lower}", flush=True) |
|
else: |
|
if verbose: |
|
print("No model defined yet", flush=True) |
|
|
|
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0 |
|
early_stopping = early_stopping if early_stopping is not None else False |
|
max_time_defaults = 60 * 3 |
|
max_time = max_time if max_time is not None else max_time_defaults |
|
|
|
if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': |
|
prompt_type = inv_prompt_type_to_model_lower[model_lower] |
|
if verbose: |
|
print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True) |
|
|
|
|
|
if show_examples is None: |
|
if chat: |
|
show_examples = False |
|
else: |
|
show_examples = True |
|
|
|
summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker? |
|
Philipp: Sure you can use the new Hugging Face Deep Learning Container. |
|
Jeff: ok. |
|
Jeff: and how can I get started? |
|
Jeff: where can I find documentation? |
|
Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face""" |
|
|
|
use_placeholder_instruction_as_example = False |
|
if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower: |
|
placeholder_instruction = summarize_example1 |
|
placeholder_input = "" |
|
use_defaults = True |
|
use_default_examples = False |
|
use_placeholder_instruction_as_example = True |
|
task_info = "Summarization" |
|
elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower: |
|
placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?" |
|
placeholder_input = "" |
|
use_defaults = True |
|
use_default_examples = True |
|
task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)" |
|
elif 'mbart-' in model_lower: |
|
placeholder_instruction = "The girl has long hair." |
|
placeholder_input = "" |
|
use_defaults = True |
|
use_default_examples = False |
|
use_placeholder_instruction_as_example = True |
|
elif 'gpt2' in model_lower: |
|
placeholder_instruction = "The sky is" |
|
placeholder_input = "" |
|
prompt_type = prompt_type or 'plain' |
|
use_default_examples = True |
|
use_placeholder_instruction_as_example = True |
|
task_info = "Auto-complete phrase, code, etc." |
|
use_defaults = True |
|
else: |
|
if chat: |
|
placeholder_instruction = "" |
|
else: |
|
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter." |
|
placeholder_input = "" |
|
if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': |
|
prompt_type = inv_prompt_type_to_model_lower[model_lower] |
|
elif model_lower: |
|
|
|
prompt_type = prompt_type or 'plain' |
|
else: |
|
prompt_type = '' |
|
task_info = "No task" |
|
if prompt_type == 'instruct': |
|
task_info = "Answer question or follow imperative as instruction with optionally input." |
|
elif prompt_type == 'plain': |
|
task_info = "Auto-complete phrase, code, etc." |
|
elif prompt_type == 'human_bot': |
|
if chat: |
|
task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)" |
|
else: |
|
task_info = "Ask question/imperative (input concatenated with instruction)" |
|
|
|
|
|
prompt_type = prompt_type or 'plain' |
|
if use_defaults: |
|
temperature = 1.0 if temperature is None else temperature |
|
top_p = 1.0 if top_p is None else top_p |
|
top_k = 40 if top_k is None else top_k |
|
num_beams = num_beams or 1 |
|
max_new_tokens = max_new_tokens or 512 |
|
repetition_penalty = repetition_penalty or 1.07 |
|
num_return_sequences = min(num_beams, num_return_sequences or 1) |
|
do_sample = False if do_sample is None else do_sample |
|
else: |
|
temperature = 0.1 if temperature is None else temperature |
|
top_p = 0.75 if top_p is None else top_p |
|
top_k = 40 if top_k is None else top_k |
|
num_beams = num_beams or 1 |
|
max_new_tokens = max_new_tokens or 1024 |
|
repetition_penalty = repetition_penalty or 1.07 |
|
num_return_sequences = min(num_beams, num_return_sequences or 1) |
|
do_sample = False if do_sample is None else do_sample |
|
|
|
params_list = ["", |
|
stream_output, |
|
prompt_type, prompt_dict, |
|
temperature, top_p, top_k, num_beams, |
|
max_new_tokens, min_new_tokens, |
|
early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample] |
|
|
|
if use_placeholder_instruction_as_example: |
|
examples += [[placeholder_instruction, ''] + params_list] |
|
|
|
if use_default_examples: |
|
examples += [ |
|
["Translate English to French", "Good morning"] + params_list, |
|
["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list, |
|
["Explain in detailed list, all the best practices for coding in python.", ''] + params_list, |
|
[ |
|
"Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.", |
|
''] + params_list, |
|
['Translate to German: My name is Arthur', ''] + params_list, |
|
["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list, |
|
['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.', |
|
''] + params_list, |
|
['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list, |
|
['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list, |
|
["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list, |
|
[ |
|
"Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?", |
|
''] + params_list, |
|
['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list, |
|
[ |
|
'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?', |
|
''] + params_list, |
|
["""def area_of_rectangle(a: float, b: float): |
|
\"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list, |
|
["""# a function in native python: |
|
def mean(a): |
|
return sum(a)/len(a) |
|
|
|
# the same function using numpy: |
|
import numpy as np |
|
def mean(a):""", ''] + params_list, |
|
["""X = np.random.randn(100, 100) |
|
y = np.random.randint(0, 1, 100) |
|
|
|
# fit random forest classifier with 20 estimators""", ''] + params_list, |
|
] |
|
|
|
examples += [ |
|
[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list] |
|
|
|
src_lang = "English" |
|
tgt_lang = "Russian" |
|
|
|
|
|
for example in examples: |
|
example += [chat, '', '', LangChainMode.DISABLED.value, True, |
|
LangChainAction.QUERY.value, [], |
|
top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, [], |
|
pre_prompt_query, prompt_query, |
|
pre_prompt_summary, prompt_summary, |
|
system_prompt, |
|
image_loaders, |
|
pdf_loaders, |
|
url_loaders, |
|
jq_schema, |
|
None, |
|
None, |
|
False, |
|
None, |
|
None, |
|
docs_ordering_type, |
|
min_max_new_tokens, |
|
] |
|
|
|
if not chat: |
|
example[eval_func_param_names.index('instruction_nochat')] = example[ |
|
eval_func_param_names.index('instruction')] |
|
example[eval_func_param_names.index('instruction')] = '' |
|
|
|
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')] |
|
example[eval_func_param_names.index('iinput')] = '' |
|
assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % ( |
|
len(example), len(eval_func_param_names)) |
|
|
|
if prompt_type == PromptType.custom.name and not prompt_dict: |
|
raise ValueError("Unexpected to get non-empty prompt_dict=%s for prompt_type=%s" % (prompt_dict, prompt_type)) |
|
|
|
|
|
prompt_dict, error0 = get_prompt(prompt_type, prompt_dict, |
|
chat=False, context='', reduced=False, making_context=False, return_dict=True, |
|
system_prompt=system_prompt) |
|
if error0: |
|
raise RuntimeError("Prompt wrong: %s" % error0) |
|
|
|
return placeholder_instruction, placeholder_input, \ |
|
stream_output, show_examples, \ |
|
prompt_type, prompt_dict, \ |
|
temperature, top_p, top_k, num_beams, \ |
|
max_new_tokens, min_new_tokens, early_stopping, max_time, \ |
|
repetition_penalty, num_return_sequences, \ |
|
do_sample, \ |
|
src_lang, tgt_lang, \ |
|
examples, \ |
|
task_info |
|
|
|
|
|
def languages_covered():
    """
    Return a dict mapping language name to language code, e.g. ``'French' -> 'fr_XX'``.

    The codes use the ``xx_YY`` form (they look like the mBART-50 language codes -- TODO confirm
    against the translation model actually used). Insertion order follows the listing below.
    """
    spec = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
    mapping = {}
    for entry in spec.split(', '):
        # each entry looks like "Name (code)"
        parts = entry.split(' ')
        mapping[parts[0]] = parts[1].replace(')', '').replace('(', '')
    return mapping
|
|
|
|
|
def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
    """
    Score a (question, answer) pair with a reward/scoring model.

    :param smodel: scoring model; called with the tokenizer outputs and must expose ``.device``
                   and return an object with ``.logits``
    :param stokenizer: tokenizer for ``smodel``; called on the (question, answer) pair
    :param max_length_tokenize: token-level truncation limit passed to the tokenizer
    :param question: question text; only its last ``cutoff_len`` characters are used
    :param answer: answer text; only its last ``cutoff_len`` characters are used
    :param cutoff_len: character-level tail length applied before tokenization
    :return: sigmoid of the first logit as a numpy scalar on success, or the string
             ``'Response Score: GPU OOM'`` / ``'Response Score: GPU Error'`` on recoverable GPU failures
    :raises: any other exception raised by the model is re-raised unchanged
    """
    # Substrings of exception messages that indicate a recoverable GPU/device failure.
    recoverable_markers = (
        'Expected all tensors to be on the same device',
        'expected scalar type Half but found Float',
        'probability tensor contains either',
        'cublasLt ran into an error!',
        'device-side assert triggered',
    )

    # keep only the tail of each text (character-based cutoff)
    question = question[-cutoff_len:]
    answer = answer[-cutoff_len:]

    # inputs are moved to the model's device once here; no need to move them again below
    inputs = stokenizer(question, answer,
                        return_tensors="pt",
                        truncation=True,
                        max_length=max_length_tokenize).to(smodel.device)
    try:
        logits = smodel(**inputs).logits
        score = torch.sigmoid(logits[0].float()).cpu().detach().numpy()[0]
    except torch.cuda.OutOfMemoryError as e:
        print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
        del inputs
        traceback.print_exc()
        clear_torch_cache()
        return 'Response Score: GPU OOM'
    except Exception as e:  # RuntimeError is a subclass of Exception
        err = str(e)
        if not any(marker in err for marker in recoverable_markers):
            raise
        print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, err),
              flush=True)
        # free the input tensors before clearing the cache, as on the OOM path
        del inputs
        traceback.print_exc()
        clear_torch_cache()
        return 'Response Score: GPU Error'
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    return score
|
|
|
|
|
def check_locals(**kwargs):
    """
    Assert that every expected parameter name was passed in ``kwargs``.

    Names are checked from three sources, in order: ``eval_func_param_names``,
    ``inputs_kwargs_list`` and the parameters of ``get_model``. Names in
    ``no_default_param_names`` plus ``'reward_type'`` are exempt.

    :raises AssertionError: with message ``"Missing <name>"`` for the first absent name
    """
    skippable = no_default_param_names + ['reward_type']

    def _require(names):
        # assert presence of each non-exempt name, stopping at the first missing one
        for name in names:
            if name not in skippable:
                assert name in kwargs, "Missing %s" % name

    # sources are evaluated lazily, one after the other, as the checks proceed
    _require(eval_func_param_names)
    _require(inputs_kwargs_list)
    _require(list(inspect.signature(get_model).parameters))
|
|
|
|
|
def get_model_max_length(model_state):
    """
    Return the context length of the loaded tokenizer.

    :param model_state: dict with a ``'tokenizer'`` entry; a string or None means no real
                        tokenizer is loaded
    :return: ``tokenizer.model_max_length``, or 2048 when there is no tokenizer object
    """
    tokenizer = model_state['tokenizer']
    if tokenizer is None or isinstance(tokenizer, str):
        return 2048
    return tokenizer.model_max_length
|
|
|
|
|
def get_max_max_new_tokens(model_state, **kwargs):
    """
    Decide the upper bound for ``max_new_tokens``.

    If ``kwargs['max_max_new_tokens']`` is set, it is used, capped by the tokenizer's
    ``model_max_length`` when a real tokenizer is loaded. Otherwise the value depends on
    ``kwargs['memory_restriction_level']``: 1 -> 768, 2 -> 512, >=3 -> 256, else 2048.

    :param model_state: dict with a ``'tokenizer'`` entry (string/None means no tokenizer)
    :param kwargs: must contain ``max_max_new_tokens`` and, when that is None,
                   ``memory_restriction_level``
    """
    tokenizer = model_state['tokenizer']
    if isinstance(tokenizer, (str, type(None))):
        tokenizer_limit = None
    else:
        tokenizer_limit = tokenizer.model_max_length

    requested = kwargs['max_max_new_tokens']
    if requested is not None:
        if tokenizer_limit is None:
            return requested
        return min(tokenizer_limit, requested)

    level = kwargs['memory_restriction_level']
    if level == 1:
        return 768
    if level == 2:
        return 512
    if level >= 3:
        return 256
    return 2048
|
|
|
|
|
def get_minmax_top_k_docs(is_public):
    """
    Return ``(min, max, label)`` for the top-k document-chunks control.

    Public deployments are limited to 1..8 chunks. Otherwise the range is -1..100, where
    -1 is labelled as "auto fill model context".
    """
    if not is_public:
        return -1, 100, "Number of document chunks (-1 = auto fill model context)"
    return 1, 8, "Number of document chunks"
|
|
|
|
|
def merge_chat_conversation_history(chat_conversation1, history):
    """
    Place extra conversation turns in front of the chat history.

    :param chat_conversation1: optional extra turns; when truthy it is passed through
                               ``str_to_list`` (presumably accepting a string or a list -- confirm),
                               and every entry must be a 2-element list/tuple
    :param history: list of turns, or anything else (treated as no history)
    :return: ``chat_conversation1 + history`` when both exist; ``history`` itself when there
             are no extra turns; ``chat_conversation1`` or ``[]`` when history is not a list
    """
    if chat_conversation1:
        chat_conversation1 = str_to_list(chat_conversation1)
        for pair in chat_conversation1:
            assert isinstance(pair, (list, tuple))
            assert len(pair) == 2

    if not isinstance(history, list):
        return chat_conversation1 if chat_conversation1 else []
    if not chat_conversation1:
        return history
    # concatenation builds a new list, so the caller's history is not modified
    return chat_conversation1 + history.copy()
|
|
|
|
|
def history_to_context(history, langchain_mode=None,
                       add_chat_history_to_context=None,
                       prompt_type=None, prompt_dict=None, chat=None, model_max_length=None,
                       memory_restriction_level=None, keep_sources_in_context=None,
                       system_prompt=None, chat_conversation=None):
    """
    Render previous chat turns into a single context string using the prompt template.

    Consumes all history up to (but not including) the latest history item when that item is
    presumed to be an [instruction, None] pair (i.e. its output is falsy: the current,
    not-yet-answered instruction).

    :param history: list of [instruction, output] turns (indexed [0]/[1] below)
    :param langchain_mode: when not 'Disabled', source sections may be stripped from past turns
    :param add_chat_history_to_context: if falsy, no history is rendered and '' is returned
    :param prompt_type: forwarded to generate_prompt
    :param prompt_dict: forwarded to generate_prompt
    :param chat: forwarded to generate_prompt
    :param model_max_length: forwarded to get_cutoffs to obtain max_prompt_length
    :param memory_restriction_level: forwarded to get_cutoffs to obtain max_prompt_length
    :param keep_sources_in_context: if falsy, text between super_source_prefix and
                                    super_source_postfix is removed from past turns
    :param system_prompt: forwarded to generate_prompt
    :param chat_conversation: extra turns merged in front of history via merge_chat_conversation_history
    :return: context string ('' if nothing was added); a non-empty result ends with chat_turn_sep
    """
    history = merge_chat_conversation_history(chat_conversation, history)

    # skip the trailing [instruction, <empty>] item: it is the current turn, not history
    if len(history) >= 1 and len(history[-1]) >= 2 and not history[-1][1]:
        len_history = len(history) - 1
    else:
        # all items are completed turns
        len_history = len(history)

    _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level,
                                             for_context=True, model_max_length=model_max_length)
    context1 = ''
    if max_prompt_length is not None and add_chat_history_to_context:
        context1 = ''
        # NOTE(review): iterates oldest-first and stops at the first turn that would overflow,
        # so when history is too long it is the later turns that get dropped. get_limited_prompt
        # compensates by re-calling with history[chat_index:].
        for histi in range(0, len_history):
            data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
            prompt, pre_response, terminate_response, chat_sep, chat_turn_sep = \
                generate_prompt(data_point,
                                prompt_type,
                                prompt_dict,
                                chat,
                                reduced=True,
                                making_context=True,
                                system_prompt=system_prompt,
                                histi=histi)

            # strip embedded source listings from past answers unless asked to keep them
            if not keep_sources_in_context and langchain_mode != 'Disabled' and prompt.find(super_source_prefix) >= 0:
                import re
                prompt = re.sub(f'{re.escape(super_source_prefix)}.*?{re.escape(super_source_postfix)}', '', prompt,
                                flags=re.DOTALL)
                if prompt.endswith('\n<p>'):
                    prompt = prompt[:-4]
            # convert HTML line breaks back to the template's turn separator
            prompt = prompt.replace('<br>', chat_turn_sep)
            if not prompt.endswith(chat_turn_sep):
                prompt += chat_turn_sep

            # NOTE(review): compares character length against max_prompt_length; whether
            # get_cutoffs returns a character or token budget is not visible here -- confirm.
            if len(prompt + context1) > max_prompt_length:
                break
            context1 += prompt

        # histi=-1 call only serves to obtain chat_turn_sep for the final check
        _, pre_response, terminate_response, chat_sep, chat_turn_sep = \
            generate_prompt({}, prompt_type, prompt_dict,
                            chat, reduced=True,
                            making_context=True,
                            system_prompt=system_prompt,
                            histi=-1)
        if context1 and not context1.endswith(chat_turn_sep):
            context1 += chat_turn_sep
    return context1
|
|
|
|
|
def get_limited_prompt(instruction,
                       iinput,
                       tokenizer,
                       prompter=None,
                       inference_server=None,
                       prompt_type=None, prompt_dict=None, chat=False, max_new_tokens=None,
                       system_prompt='',
                       context='', chat_conversation=None, text_context_list=None,
                       keep_sources_in_context=False,
                       model_max_length=None, memory_restriction_level=0,
                       langchain_mode=None, add_chat_history_to_context=True,
                       verbose=False,
                       doc_importance=0.5,
                       min_max_new_tokens=256,
                       ):
    """
    Assemble the final prompt, shrinking its parts so prompt + generation fit in model_max_length.

    The token budget is split between document chunks (``text_context_list``) and everything
    else (instruction, caller ``context``, chat history, ``iinput``), weighted by ``doc_importance``.
    When the non-document part exceeds its budget it is reduced in this order:
      1) drop ``iinput``;
      2) drop leading chat-history turns (only ``history[chat_index:]`` is kept);
      3) drop chat history entirely and truncate ``context``;
      4) drop all context and truncate ``instruction``.

    :param instruction: user instruction text
    :param iinput: optional extra input text
    :param tokenizer: tokenizer used for all token counting
    :param prompter: Prompter; if truthy, its prompt_type/prompt_dict/chat/system_prompt override
                     the corresponding arguments. If None, one is built from the arguments.
    :param inference_server: when it starts with 'http', max_new_tokens is not clipped
    :param prompt_type: prompt template type (ignored when prompter is given)
    :param prompt_dict: prompt template dict (ignored when prompter is given)
    :param chat: chat mode flag (ignored when prompter is given)
    :param max_new_tokens: requested generation length; must not be None
    :param system_prompt: system prompt (ignored when prompter is given)
    :param context: extra caller-supplied context placed before the chat-history context
    :param chat_conversation: prior turns rendered into context via history_to_context
    :param text_context_list: document chunks; only their token count is budgeted here
    :param keep_sources_in_context: forwarded to history_to_context
    :param model_max_length: model context length in tokens; must not be None
    :param memory_restriction_level: forwarded to history_to_context
    :param langchain_mode: forwarded to history_to_context
    :param add_chat_history_to_context: forwarded to history_to_context
    :param verbose: print how much chat history was kept
    :param doc_importance: minimum fraction of model_max_length reserved for documents
    :param min_max_new_tokens: generation room to preserve when trimming (capped at max_new_tokens)
    :return: (prompt, instruction, iinput, context, num_prompt_tokens, max_new_tokens,
              num_prompt_tokens0, num_prompt_tokens_actual, chat_index, top_k_docs, one_doc_size)
    """
    if prompter:
        prompt_type = prompter.prompt_type
        prompt_dict = prompter.prompt_dict
        chat = prompter.chat
        system_prompt = prompter.system_prompt
    if prompter is None:
        # The prompter must exist before the first prompter.generate_prompt() call below.
        prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=False,
                            system_prompt=system_prompt)

    history = merge_chat_conversation_history(chat_conversation, [])
    history_to_context_func = functools.partial(history_to_context,
                                                langchain_mode=langchain_mode,
                                                add_chat_history_to_context=add_chat_history_to_context,
                                                prompt_type=prompt_type,
                                                prompt_dict=prompt_dict,
                                                chat=chat,
                                                model_max_length=model_max_length,
                                                memory_restriction_level=memory_restriction_level,
                                                keep_sources_in_context=keep_sources_in_context,
                                                system_prompt=system_prompt)
    context2 = history_to_context_func(history)
    context1 = context if context is not None else ''

    from h2oai_pipeline import H2OTextGenerationPipeline

    def _instruction_prompt_tokens(instr):
        # token count of the instruction once wrapped by the prompt template
        data_point_just_instruction = dict(context='', instruction=instr, input='')
        return get_token_count(prompter.generate_prompt(data_point_just_instruction), tokenizer)

    # count with the template applied; the bare-instruction count from limit_prompt is not needed
    num_instruction_tokens = _instruction_prompt_tokens(instruction)
    instruction, _ = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer)

    context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer)
    context2, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer)
    iinput, num_iinput_tokens = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer)
    if text_context_list is None:
        text_context_list = []
    num_doc_tokens = sum(get_token_count(x + '\n\n', tokenizer) for x in text_context_list)

    num_prompt_tokens0 = (num_instruction_tokens or 0) + \
                         (num_context1_tokens or 0) + \
                         (num_context2_tokens or 0) + \
                         (num_iinput_tokens or 0) + \
                         (num_doc_tokens or 0)

    # never reserve more generation room than was actually requested
    min_max_new_tokens = min(min_max_new_tokens, max_new_tokens)
    # index of the first history turn kept; 0 means all history is used
    chat_index = 0

    num_non_doc_tokens = num_prompt_tokens0 - num_doc_tokens
    # docs get what non-doc content leaves free, but at least doc_importance of the context
    doc_max_length = max(model_max_length - num_non_doc_tokens, doc_importance * model_max_length)
    top_k_docs, one_doc_size, num_doc_tokens = get_docs_tokens(tokenizer, text_context_list=text_context_list,
                                                               max_input_tokens=doc_max_length)
    # non-doc content gets what docs leave free, but at least (1 - doc_importance) of the context
    non_doc_max_length = max(model_max_length - num_doc_tokens, (1.0 - doc_importance) * model_max_length)

    if num_non_doc_tokens > non_doc_max_length:
        diff1 = non_doc_max_length - (
                num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
        diff2 = non_doc_max_length - (num_instruction_tokens + num_context1_tokens + min_max_new_tokens)
        diff3 = non_doc_max_length - (num_instruction_tokens + min_max_new_tokens)
        diff4 = non_doc_max_length - min_max_new_tokens
        # Token counts are non-negative, so diff1 <= diff2 <= diff3 <= diff4; try the mildest
        # reduction first. Plain "> 0" tests keep exact-zero cases from falling through to the
        # harshest branch, as chained checks like "diff2 > 0 > diff1" would.
        if diff1 > 0:
            # 1) dropping iinput is enough
            iinput = ''
            num_iinput_tokens = 0
        elif diff2 > 0:
            # 1) + 2) drop iinput and as many leading history turns as needed
            iinput = ''
            num_iinput_tokens = 0
            chat_index_final = len(history)
            for start in range(len(history)):
                context2 = history_to_context_func(history[start:])
                num_context2_tokens = get_token_count(context2, tokenizer)
                diff1 = non_doc_max_length - (
                        num_instruction_tokens + num_context1_tokens + num_context2_tokens + min_max_new_tokens)
                if diff1 > 0:
                    chat_index_final = start
                    if verbose:
                        print("chat_conversation used %d out of %d" % (start, len(history)), flush=True)
                    break
            if chat_index_final == len(history):
                # No suffix of the history fits. Drop chat context entirely instead of keeping
                # the oversized context left over from the last loop iteration.
                context2 = ''
                num_context2_tokens = 0
            chat_index = chat_index_final
        elif diff3 > 0:
            # 1) + 2) + 3) drop iinput and all history, then truncate context1
            iinput = ''
            num_iinput_tokens = 0
            context2 = ''
            num_context2_tokens = 0
            context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer,
                                                                                   max_prompt_length=diff3)
            if num_context1_tokens > diff3:
                print("failed to reduce", flush=True)
        else:
            # 1) - 4) drop all context and truncate the instruction itself
            iinput = ''
            num_iinput_tokens = 0
            context2 = ''
            num_context2_tokens = 0
            context1 = ''
            num_context1_tokens = 0
            instruction, _ = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
                                                                    max_prompt_length=diff4)
            # re-count with the prompt template applied to the truncated instruction
            num_instruction_tokens = _instruction_prompt_tokens(instruction)

    context = context1 + context2

    num_prompt_tokens = (num_instruction_tokens or 0) + \
                        (num_context1_tokens or 0) + \
                        (num_context2_tokens or 0) + \
                        (num_iinput_tokens or 0) + \
                        (num_doc_tokens or 0)

    if inference_server and inference_server.startswith('http'):
        # remote server: leave max_new_tokens as requested
        pass
    else:
        # local model: keep prompt + generated tokens within the model context
        max_new_tokens = min(max_new_tokens, model_max_length - num_prompt_tokens)

    data_point = dict(context=context, instruction=instruction, input=iinput)
    # tell the prompter whether the context came (at least partly) from chat history
    context_from_history = len(history) > 0 and len(context1) > 0
    prompt = prompter.generate_prompt(data_point, context_from_history=context_from_history)
    num_prompt_tokens_actual = get_token_count(prompt, tokenizer)

    return prompt, \
        instruction, iinput, context, \
        num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
        chat_index, top_k_docs, one_doc_size
|
|
|
|
|
def get_docs_tokens(tokenizer, text_context_list=None, max_input_tokens=None):
    """
    Decide how many leading document chunks fit in a token budget.

    Each chunk is counted as ``chunk + '\\n\\n'``. The largest prefix whose cumulative token count
    stays strictly below ``max_input_tokens`` is selected.

    :param tokenizer: tokenizer used for counting; its ``model_max_length`` is the default budget
    :param text_context_list: document chunks (None or empty -> nothing to add)
    :param max_input_tokens: token budget for documents
    :return: ``(top_k_docs, one_doc_size, num_doc_tokens)``. ``one_doc_size`` is None when whole
             chunks fit. When even the first chunk is too large, ``top_k_docs`` is 1 and
             ``one_doc_size`` is the character length of that chunk after truncating it to the
             budget (the truncated text itself is not returned).
    """
    if text_context_list is None or len(text_context_list) == 0:
        return 0, None, 0
    if max_input_tokens is None:
        max_input_tokens = tokenizer.model_max_length
    tokens = [get_token_count(x + '\n\n', tokenizer) for x in text_context_list]
    tokens_cumsum = np.cumsum(tokens)
    where_res = np.where(tokens_cumsum < max_input_tokens)[0]

    if where_res.shape[0] > 0:
        # return plain Python scalars rather than numpy ones
        top_k_docs = 1 + where_res[-1].item()
        one_doc_size = None
        num_doc_tokens = tokens_cumsum[top_k_docs - 1].item()
    else:
        # even the first chunk exceeds the budget: keep it anyway, truncated to fit
        top_k_docs = 1
        # same import path as get_limited_prompt (this file's directory is put on sys.path)
        from h2oai_pipeline import H2OTextGenerationPipeline
        doc_content, new_tokens0 = H2OTextGenerationPipeline.limit_prompt(text_context_list[0],
                                                                          tokenizer,
                                                                          max_prompt_length=max_input_tokens)
        one_doc_size = len(doc_content)
        num_doc_tokens = get_token_count(doc_content + '\n\n', tokenizer)
        print("Unexpected large chunks and can't add to context, will add 1 anyways. Tokens %s -> %s" % (
            tokens[0], new_tokens0), flush=True)
    return top_k_docs, one_doc_size, num_doc_tokens
|
|
|
|
|
def entrypoint_main():
    """
    Command-line entry point: passes ``main`` to ``H2O_Fire`` (defined elsewhere), which
    presumably turns main's keyword arguments into the ``--flag=value`` options shown below.

    Examples:

    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
    python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
    python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'

    # generate without lora weights, no prompt
    python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
    python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'

    python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
    # OpenChatKit settings:
    python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0

    python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
    python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
    python generate.py --base_model='philschmid/bart-large-cnn-samsum'
    python generate.py --base_model='philschmid/flan-t5-base-samsum'
    python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'

    python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'

    Must have 4*48GB GPUs and run without 8-bit in order for sharding to work with use_gpu_id=False.
    Can also pass --prompt_type='human_bot'; the model can somewhat handle instructions without being instruct-tuned.
    python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --use_gpu_id=False --prompt_type='human_bot'

    python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
    """
    H2O_Fire(main)
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script, not on import.
if __name__ == "__main__":
    entrypoint_main()
|
|