diff --git "a/src/gradio_runner.py" "b/src/gradio_runner.py" new file mode 100644--- /dev/null +++ "b/src/gradio_runner.py" @@ -0,0 +1,4594 @@ +import ast +import copy +import functools +import inspect +import itertools +import json +import os +import pprint +import random +import shutil +import sys +import time +import traceback +import uuid +import filelock +import numpy as np +import pandas as pd +import requests +from iterators import TimeoutIterator + +from gradio_utils.css import get_css +from gradio_utils.prompt_form import make_chatbots + +# This is a hack to prevent Gradio from phoning home when it gets imported +os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' + + +def my_get(url, **kwargs): + print('Gradio HTTP request redirected to localhost :)', flush=True) + kwargs.setdefault('allow_redirects', True) + return requests.api.request('get', 'http://127.0.0.1/', **kwargs) + + +original_get = requests.get +requests.get = my_get +import gradio as gr + +requests.get = original_get + + +def fix_pydantic_duplicate_validators_error(): + try: + from pydantic import class_validators + + class_validators.in_ipython = lambda: True # type: ignore[attr-defined] + except ImportError: + pass + + +fix_pydantic_duplicate_validators_error() + +from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \ + DocumentChoice, langchain_modes_intrinsic, LangChainTypes, langchain_modes_non_db, gr_to_lg, invalid_key_msg, \ + LangChainAgent, docs_ordering_types +from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, \ + get_dark_js, get_heap_js, wrap_js_to_lambda, \ + spacing_xsm, radius_xsm, text_xsm +from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \ + get_prompt +from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \ + ping, makedirs, get_kwargs, system_info, ping_gpu, get_url, get_local_ip, \ + save_generate_output, url_alive, remove, dict_to_html, text_to_html, lg_to_gr, str_to_dict, have_serpapi +from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, \ + get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \ + evaluate_fake, merge_chat_conversation_history +from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \ + input_args_list, key_overrides + +from apscheduler.schedulers.background import BackgroundScheduler + + +def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True): + if fix_latex_dollars: + ts = text.split('```') + for parti, part in enumerate(ts): + inside = parti % 2 == 1 + if not inside: + ts[parti] = ts[parti].replace('$', '﹩') + text = '```'.join(ts) + + if fix_new_lines: + # let Gradio handle code, since got improved recently + ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is. + # ensure good visually, else markdown ignores multiple \n + # handle code blocks + ts = text.split('```') + for parti, part in enumerate(ts): + inside = parti % 2 == 1 + if not inside: + ts[parti] = ts[parti].replace('\n', '
') + text = '```'.join(ts) + return text + + +def is_valid_key(enforce_h2ogpt_api_key, h2ogpt_api_keys, h2ogpt_key1, requests_state1=None): + valid_key = False + if not enforce_h2ogpt_api_key: + # no token barrier + valid_key = 'not enforced' + else: + if isinstance(h2ogpt_api_keys, list) and h2ogpt_key1 in h2ogpt_api_keys: + # passed token barrier + valid_key = True + elif isinstance(h2ogpt_api_keys, str) and os.path.isfile(h2ogpt_api_keys): + with filelock.FileLock(h2ogpt_api_keys + '.lock'): + with open(h2ogpt_api_keys, 'rt') as f: + h2ogpt_api_keys = json.load(f) + if h2ogpt_key1 in h2ogpt_api_keys: + valid_key = True + if isinstance(requests_state1, dict) and 'username' in requests_state1 and requests_state1['username']: + # no UI limit currently + valid_key = True + return valid_key + + +def go_gradio(**kwargs): + allow_api = kwargs['allow_api'] + is_public = kwargs['is_public'] + is_hf = kwargs['is_hf'] + memory_restriction_level = kwargs['memory_restriction_level'] + n_gpus = kwargs['n_gpus'] + admin_pass = kwargs['admin_pass'] + model_states = kwargs['model_states'] + dbs = kwargs['dbs'] + db_type = kwargs['db_type'] + visible_langchain_actions = kwargs['visible_langchain_actions'] + visible_langchain_agents = kwargs['visible_langchain_agents'] + allow_upload_to_user_data = kwargs['allow_upload_to_user_data'] + allow_upload_to_my_data = kwargs['allow_upload_to_my_data'] + enable_sources_list = kwargs['enable_sources_list'] + enable_url_upload = kwargs['enable_url_upload'] + enable_text_upload = kwargs['enable_text_upload'] + use_openai_embedding = kwargs['use_openai_embedding'] + hf_embedding_model = kwargs['hf_embedding_model'] + load_db_if_exists = kwargs['load_db_if_exists'] + migrate_embedding_model = kwargs['migrate_embedding_model'] + auto_migrate_db = kwargs['auto_migrate_db'] + captions_model = kwargs['captions_model'] + caption_loader = kwargs['caption_loader'] + doctr_loader = kwargs['doctr_loader'] + + n_jobs = kwargs['n_jobs'] + verbose = kwargs['verbose'] + + # for dynamic state per user session in gradio + model_state0 = kwargs['model_state0'] + score_model_state0 = kwargs['score_model_state0'] + my_db_state0 = kwargs['my_db_state0'] + selection_docs_state0 = kwargs['selection_docs_state0'] + visible_models_state0 = kwargs['visible_models_state0'] + # For Heap analytics + is_heap_analytics_enabled = kwargs['enable_heap_analytics'] + heap_app_id = kwargs['heap_app_id'] + + # easy update of kwargs needed for evaluate() etc. + queue = True + allow_upload = allow_upload_to_user_data or allow_upload_to_my_data + allow_upload_api = allow_api and allow_upload + + kwargs.update(locals()) + + # import control + if kwargs['langchain_mode'] != 'Disabled': + from gpt_langchain import file_types, have_arxiv + else: + have_arxiv = False + file_types = [] + + if 'mbart-' in kwargs['model_lower']: + instruction_label_nochat = "Text to translate" + else: + instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \ + " use Enter for multiple input lines)" + + title = 'h2oGPT' + if kwargs['visible_h2ogpt_header']: + description = """h2oGPT LLM Leaderboard LLM Studio
CodeLlama
🤗 Models""" + else: + description = None + description_bottom = "If this host is busy, try
[Multi-Model](https://gpt.h2o.ai)
[CodeLlama](https://codellama.h2o.ai)
[Llama2 70B](https://llama.h2o.ai)
[Falcon 40B](https://falcon.h2o.ai)
[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)
[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)
" + if is_hf: + description_bottom += '''Duplicate Space''' + task_info_md = '' + css_code = get_css(kwargs) + + if kwargs['gradio_offline_level'] >= 0: + # avoid GoogleFont that pulls from internet + if kwargs['gradio_offline_level'] == 1: + # front end would still have to download fonts or have cached it at some point + base_font = 'Source Sans Pro' + else: + base_font = 'Helvetica' + theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'), + font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace')) + else: + theme_kwargs = dict() + if kwargs['gradio_size'] == 'xsmall': + theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm)) + elif kwargs['gradio_size'] in [None, 'small']: + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm, + radius_size=gr.themes.sizes.spacing_sm)) + elif kwargs['gradio_size'] == 'large': + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg), + radius_size=gr.themes.sizes.spacing_lg) + elif kwargs['gradio_size'] == 'medium': + theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md, + radius_size=gr.themes.sizes.spacing_md)) + + theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs) + demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False) + callback = gr.CSVLogger() + + model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options'] + if kwargs['base_model'].strip() not in model_options0: + model_options0 = [kwargs['base_model'].strip()] + model_options0 + lora_options = kwargs['extra_lora_options'] + if kwargs['lora_weights'].strip() not in lora_options: + lora_options = [kwargs['lora_weights'].strip()] + lora_options + server_options = kwargs['extra_server_options'] + if kwargs['inference_server'].strip() not in server_options: + server_options = [kwargs['inference_server'].strip()] + server_options + if os.getenv('OPENAI_API_KEY'): + if 'openai_chat' not in server_options: + server_options += ['openai_chat'] + if 'openai' not in server_options: + server_options += ['openai'] + + # always add in no lora case + # add fake space so doesn't go away in gradio dropdown + model_options0 = [no_model_str] + sorted(model_options0) + lora_options = [no_lora_str] + sorted(lora_options) + server_options = [no_server_str] + sorted(server_options) + # always add in no model case so can free memory + # add fake space so doesn't go away in gradio dropdown + + # transcribe, will be detranscribed before use by evaluate() + if not kwargs['base_model'].strip(): + kwargs['base_model'] = no_model_str + + if not kwargs['lora_weights'].strip(): + kwargs['lora_weights'] = no_lora_str + + if not kwargs['inference_server'].strip(): + kwargs['inference_server'] = no_server_str + + # transcribe for gradio + kwargs['gpu_id'] = str(kwargs['gpu_id']) + + no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]' + output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get( + 'base_model') else no_model_msg + output_label0_model2 = no_model_msg + + def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0): + if not prompt_type1 or which_model != 0: + # keep prompt_type and prompt_dict in sync if possible + prompt_type1 = kwargs.get('prompt_type', prompt_type1) + prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1) + # prefer model specific prompt type instead of global one + if not prompt_type1 or which_model != 0: + prompt_type1 = model_state1.get('prompt_type', prompt_type1) + prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1) + + if not prompt_dict1 or which_model != 0: + # if still not defined, try to get + prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1) + if not prompt_dict1 or which_model != 0: + prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1) + return prompt_type1, prompt_dict1 + + def visible_models_to_model_choice(visible_models1): + if isinstance(visible_models1, list): + assert len( + visible_models1) >= 1, "Invalid visible_models1=%s, can only be single entry" % visible_models1 + # just take first + model_active_choice1 = visible_models1[0] + elif isinstance(visible_models1, (str, int)): + model_active_choice1 = visible_models1 + else: + assert isinstance(visible_models1, type(None)), "Invalid visible_models1=%s" % visible_models1 + model_active_choice1 = visible_models1 + if model_active_choice1 is not None: + if isinstance(model_active_choice1, str): + base_model_list = [x['base_model'] for x in model_states] + if model_active_choice1 in base_model_list: + # if dups, will just be first one + model_active_choice1 = base_model_list.index(model_active_choice1) + else: + # NOTE: Could raise, but sometimes raising in certain places fails too hard and requires UI restart + model_active_choice1 = 0 + else: + model_active_choice1 = 0 + return model_active_choice1 + + default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults} + # ensure prompt_type consistent with prep_bot(), so nochat API works same way + default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \ + update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'], + model_state1=model_state0, + which_model=visible_models_to_model_choice(kwargs['visible_models'])) + for k in no_default_param_names: + default_kwargs[k] = '' + + def dummy_fun(x): + # need dummy function to block new input from being sent until output is done, + # else gets input_list at time of submit that is old, and shows up as truncated in chatbot + return x + + def update_auth_selection(auth_user, selection_docs_state1, save=False): + # in-place update of both + if 'selection_docs_state' not in auth_user: + auth_user['selection_docs_state'] = selection_docs_state0 + for k, v in auth_user['selection_docs_state'].items(): + if isinstance(selection_docs_state1[k], dict): + if save: + auth_user['selection_docs_state'][k].clear() + auth_user['selection_docs_state'][k].update(selection_docs_state1[k]) + else: + selection_docs_state1[k].clear() + selection_docs_state1[k].update(auth_user['selection_docs_state'][k]) + elif isinstance(selection_docs_state1[k], list): + if save: + auth_user['selection_docs_state'][k].clear() + auth_user['selection_docs_state'][k].extend(selection_docs_state1[k]) + else: + selection_docs_state1[k].clear() + selection_docs_state1[k].extend(auth_user['selection_docs_state'][k]) + else: + raise RuntimeError("Bad type: %s" % selection_docs_state1[k]) + + # BEGIN AUTH THINGS + def auth_func(username1, password1, auth_pairs=None, auth_filename=None, + auth_access=None, + auth_freeze=None, + guest_name=None, + selection_docs_state1=None, + selection_docs_state00=None, + **kwargs): + assert auth_freeze is not None + if selection_docs_state1 is None: + selection_docs_state1 = selection_docs_state00 + assert selection_docs_state1 is not None + assert auth_filename and isinstance(auth_filename, str), "Auth file must be a non-empty string, got: %s" % str( + auth_filename) + if auth_access == 'open' and username1 == guest_name: + return True + if username1 == '': + # some issue with login + return False + with filelock.FileLock(auth_filename + '.lock'): + auth_dict = {} + if os.path.isfile(auth_filename): + try: + with open(auth_filename, 'rt') as f: + auth_dict = json.load(f) + except json.decoder.JSONDecodeError as e: + print("Auth exception: %s" % str(e), flush=True) + shutil.move(auth_filename, auth_filename + '.bak' + str(uuid.uuid4())) + auth_dict = {} + if username1 in auth_dict and username1 in auth_pairs: + if password1 == auth_dict[username1]['password'] and password1 == auth_pairs[username1]: + auth_user = auth_dict[username1] + update_auth_selection(auth_user, selection_docs_state1) + save_auth_dict(auth_dict, auth_filename) + return True + else: + return False + elif username1 in auth_dict: + if password1 == auth_dict[username1]['password']: + auth_user = auth_dict[username1] + update_auth_selection(auth_user, selection_docs_state1) + save_auth_dict(auth_dict, auth_filename) + return True + else: + return False + elif username1 in auth_pairs: + # copy over CLI auth to file so only one state to manage + auth_dict[username1] = dict(password=auth_pairs[username1], userid=str(uuid.uuid4())) + auth_user = auth_dict[username1] + update_auth_selection(auth_user, selection_docs_state1) + save_auth_dict(auth_dict, auth_filename) + return True + else: + if auth_access == 'closed': + return False + # open access + auth_dict[username1] = dict(password=password1, userid=str(uuid.uuid4())) + auth_user = auth_dict[username1] + update_auth_selection(auth_user, selection_docs_state1) + save_auth_dict(auth_dict, auth_filename) + if auth_access == 'open': + return True + else: + raise RuntimeError("Invalid auth_access: %s" % auth_access) + + def auth_func_open(*args, **kwargs): + return True + + def get_username(requests_state1): + username1 = None + if 'username' in requests_state1: + username1 = requests_state1['username'] + return username1 + + def get_userid_auth_func(requests_state1, auth_filename=None, auth_access=None, guest_name=None, **kwargs): + if auth_filename and isinstance(auth_filename, str): + username1 = get_username(requests_state1) + if username1: + if username1 == guest_name: + return str(uuid.uuid4()) + with filelock.FileLock(auth_filename + '.lock'): + if os.path.isfile(auth_filename): + with open(auth_filename, 'rt') as f: + auth_dict = json.load(f) + if username1 in auth_dict: + return auth_dict[username1]['userid'] + # if here, then not persistently associated with username1, + # but should only be one-time asked if going to persist within a single session! + return str(uuid.uuid4()) + + get_userid_auth = functools.partial(get_userid_auth_func, + auth_filename=kwargs['auth_filename'], + auth_access=kwargs['auth_access'], + guest_name=kwargs['guest_name'], + ) + if kwargs['auth_access'] == 'closed': + auth_message1 = "Closed access" + else: + auth_message1 = "WELCOME! Open access" \ + " (%s/%s or any unique user/pass)" % (kwargs['guest_name'], kwargs['guest_name']) + + if kwargs['auth_message'] is not None: + auth_message = kwargs['auth_message'] + else: + auth_message = auth_message1 + + # always use same callable + auth_pairs0 = {} + if isinstance(kwargs['auth'], list): + for k, v in kwargs['auth']: + auth_pairs0[k] = v + authf = functools.partial(auth_func, + auth_pairs=auth_pairs0, + auth_filename=kwargs['auth_filename'], + auth_access=kwargs['auth_access'], + auth_freeze=kwargs['auth_freeze'], + guest_name=kwargs['guest_name'], + selection_docs_state00=copy.deepcopy(selection_docs_state0)) + + def get_request_state(requests_state1, request, db1s): + # if need to get state, do it now + if not requests_state1: + requests_state1 = requests_state0.copy() + if requests: + if not requests_state1.get('headers', '') and hasattr(request, 'headers'): + requests_state1.update(request.headers) + if not requests_state1.get('host', '') and hasattr(request, 'host'): + requests_state1.update(dict(host=request.host)) + if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'): + requests_state1.update(dict(host2=request.client.host)) + if not requests_state1.get('username', '') and hasattr(request, 'username'): + from src.gpt_langchain import get_username_direct + # use already-defined username instead of keep changing to new uuid + # should be same as in requests_state1 + db_username = get_username_direct(db1s) + requests_state1.update(dict(username=request.username or db_username or str(uuid.uuid4()))) + requests_state1 = {str(k): str(v) for k, v in requests_state1.items()} + return requests_state1 + + def user_state_setup(db1s, requests_state1, request: gr.Request, *args): + requests_state1 = get_request_state(requests_state1, request, db1s) + from src.gpt_langchain import set_userid + set_userid(db1s, requests_state1, get_userid_auth) + args_list = [db1s, requests_state1] + list(args) + return tuple(args_list) + + # END AUTH THINGS + + def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): + allow = False + allow |= langchain_action1 not in LangChainAction.QUERY.value + allow |= document_subset1 in DocumentSubset.TopKSources.name + if langchain_mode1 in [LangChainMode.LLM.value]: + allow = False + return allow + + image_loaders_options0, image_loaders_options, \ + pdf_loaders_options0, pdf_loaders_options, \ + url_loaders_options0, url_loaders_options = lg_to_gr(**kwargs) + jq_schema0 = '.[]' + + with demo: + # avoid actual model/tokenizer here or anything that would be bad to deepcopy + # https://github.com/gradio-app/gradio/issues/3558 + model_state = gr.State( + dict(model='model', tokenizer='tokenizer', device=kwargs['device'], + base_model=kwargs['base_model'], + tokenizer_base_model=kwargs['tokenizer_base_model'], + lora_weights=kwargs['lora_weights'], + inference_server=kwargs['inference_server'], + prompt_type=kwargs['prompt_type'], + prompt_dict=kwargs['prompt_dict'], + ) + ) + + def update_langchain_mode_paths(selection_docs_state1): + dup = selection_docs_state1['langchain_mode_paths'].copy() + for k, v in dup.items(): + if k not in selection_docs_state1['langchain_modes']: + selection_docs_state1['langchain_mode_paths'].pop(k) + for k in selection_docs_state1['langchain_modes']: + if k not in selection_docs_state1['langchain_mode_types']: + # if didn't specify shared, then assume scratch if didn't login or personal if logged in + selection_docs_state1['langchain_mode_types'][k] = LangChainTypes.PERSONAL.value + return selection_docs_state1 + + # Setup some gradio states for per-user dynamic state + model_state2 = gr.State(kwargs['model_state_none'].copy()) + model_options_state = gr.State([model_options0]) + lora_options_state = gr.State([lora_options]) + server_options_state = gr.State([server_options]) + my_db_state = gr.State(my_db_state0) + chat_state = gr.State({}) + docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value] + docs_state0 = [] + [docs_state0.append(x) for x in docs_state00 if x not in docs_state0] + docs_state = gr.State(docs_state0) + viewable_docs_state0 = [] + viewable_docs_state = gr.State(viewable_docs_state0) + selection_docs_state0 = update_langchain_mode_paths(selection_docs_state0) + selection_docs_state = gr.State(selection_docs_state0) + requests_state0 = dict(headers='', host='', username='') + requests_state = gr.State(requests_state0) + + if description is not None: + gr.Markdown(f""" + {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)} + """) + + # go button visible if + base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0'] + go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary") + + nas = ' '.join(['NA'] * len(kwargs['model_states'])) + res_value = "Response Score: NA" if not kwargs[ + 'model_lock'] else "Response Scores: %s" % nas + + user_can_do_sum = kwargs['langchain_mode'] != LangChainMode.DISABLED.value and \ + (kwargs['visible_side_bar'] or kwargs['visible_system_tab']) + if user_can_do_sum: + extra_prompt_form = ". For summarization, no query required, just click submit" + else: + extra_prompt_form = "" + if kwargs['input_lines'] > 1: + instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form + else: + instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form + + def get_langchain_choices(selection_docs_state1): + langchain_modes = selection_docs_state1['langchain_modes'] + + if is_hf: + # don't show 'wiki' since only usually useful for internal testing at moment + no_show_modes = ['Disabled', 'wiki'] + else: + no_show_modes = ['Disabled'] + allowed_modes = langchain_modes.copy() + # allowed_modes = [x for x in allowed_modes if x in dbs] + allowed_modes += ['LLM'] + if allow_upload_to_my_data and 'MyData' not in allowed_modes: + allowed_modes += ['MyData'] + if allow_upload_to_user_data and 'UserData' not in allowed_modes: + allowed_modes += ['UserData'] + choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes] + return choices + + def get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=None): + langchain_choices1 = get_langchain_choices(selection_docs_state1) + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if k in langchain_choices1} + if langchain_mode_paths: + langchain_mode_paths = langchain_mode_paths.copy() + for langchain_mode1 in langchain_modes_non_db: + langchain_mode_paths.pop(langchain_mode1, None) + df1 = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns') + df1.columns = ['Collection', 'Path'] + df1 = df1.set_index('Collection') + else: + df1 = pd.DataFrame(None) + langchain_mode_types = selection_docs_state1['langchain_mode_types'] + langchain_mode_types = {k: v for k, v in langchain_mode_types.items() if k in langchain_choices1} + if langchain_mode_types: + langchain_mode_types = langchain_mode_types.copy() + for langchain_mode1 in langchain_modes_non_db: + langchain_mode_types.pop(langchain_mode1, None) + + df2 = pd.DataFrame.from_dict(langchain_mode_types.items(), orient='columns') + df2.columns = ['Collection', 'Type'] + df2 = df2.set_index('Collection') + + from src.gpt_langchain import get_persist_directory, load_embed + persist_directory_dict = {} + embed_dict = {} + chroma_version_dict = {} + for langchain_mode3 in langchain_mode_types: + langchain_type3 = langchain_mode_types.get(langchain_mode3, LangChainTypes.EITHER.value) + persist_directory3, langchain_type3 = get_persist_directory(langchain_mode3, + langchain_type=langchain_type3, + db1s=db1s, dbs=dbs1) + got_embedding3, use_openai_embedding3, hf_embedding_model3 = load_embed( + persist_directory=persist_directory3) + persist_directory_dict[langchain_mode3] = persist_directory3 + embed_dict[langchain_mode3] = 'OpenAI' if not hf_embedding_model3 else hf_embedding_model3 + + if os.path.isfile(os.path.join(persist_directory3, 'chroma.sqlite3')): + chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4' + elif os.path.isdir(os.path.join(persist_directory3, 'index')): + chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4' + elif not os.listdir(persist_directory3): + if db_type == 'chroma': + chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4' # will be + elif db_type == 'chroma_old': + chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4' # will be + else: + chroma_version_dict[langchain_mode3] = 'Weaviate' # will be + if isinstance(hf_embedding_model, dict): + hf_embedding_model3 = hf_embedding_model['name'] + else: + hf_embedding_model3 = hf_embedding_model + assert isinstance(hf_embedding_model3, str) + embed_dict[langchain_mode3] = hf_embedding_model3 # will be + else: + chroma_version_dict[langchain_mode3] = 'Weaviate' + + df3 = pd.DataFrame.from_dict(persist_directory_dict.items(), orient='columns') + df3.columns = ['Collection', 'Directory'] + df3 = df3.set_index('Collection') + + df4 = pd.DataFrame.from_dict(embed_dict.items(), orient='columns') + df4.columns = ['Collection', 'Embedding'] + df4 = df4.set_index('Collection') + + df5 = pd.DataFrame.from_dict(chroma_version_dict.items(), orient='columns') + df5.columns = ['Collection', 'DB'] + df5 = df5.set_index('Collection') + else: + df2 = pd.DataFrame(None) + df3 = pd.DataFrame(None) + df4 = pd.DataFrame(None) + df5 = pd.DataFrame(None) + df_list = [df2, df1, df3, df4, df5] + df_list = [x for x in df_list if x.shape[1] > 0] + if len(df_list) > 1: + df = df_list[0].join(df_list[1:]).replace(np.nan, '').reset_index() + elif len(df_list) == 0: + df = df_list[0].replace(np.nan, '').reset_index() + else: + df = pd.DataFrame(None) + return df + + normal_block = gr.Row(visible=not base_wanted, equal_height=False, elem_id="col_container") + with normal_block: + side_bar = gr.Column(elem_id="sidebar", scale=1, min_width=100, visible=kwargs['visible_side_bar']) + with side_bar: + with gr.Accordion("Chats", open=False, visible=True): + radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False, + visible=True, interactive=True, + type='value') + upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload + with gr.Accordion("Upload", open=False, visible=upload_visible): + with gr.Column(): + with gr.Row(equal_height=False): + fileup_output = gr.File(show_label=False, + file_types=['.' + x for x in file_types], + # file_types=['*', '*.*'], # for iPhone etc. needs to be unconstrained else doesn't work with extension-based restrictions + file_count="multiple", + scale=1, + min_width=0, + elem_id="warning", elem_classes="feedback", + ) + fileup_output_text = gr.Textbox(visible=False) + max_quality = gr.Checkbox(label="Maximum Ingest Quality", value=kwargs['max_quality'], + visible=not is_public) + url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload + url_label = 'URL/ArXiv' if have_arxiv else 'URL' + url_text = gr.Textbox(label=url_label, + # placeholder="Enter Submits", + max_lines=1, + interactive=True) + text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload + user_text_text = gr.Textbox(label='Paste Text', + # placeholder="Enter Submits", + interactive=True, + visible=text_visible) + github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP + database_visible = kwargs['langchain_mode'] != 'Disabled' + with gr.Accordion("Resources", open=False, visible=database_visible): + langchain_choices0 = get_langchain_choices(selection_docs_state0) + langchain_mode = gr.Radio( + langchain_choices0, + value=kwargs['langchain_mode'], + label="Collections", + show_label=True, + visible=kwargs['langchain_mode'] != 'Disabled', + min_width=100) + add_chat_history_to_context = gr.Checkbox(label="Chat History", + value=kwargs['add_chat_history_to_context']) + add_search_to_context = gr.Checkbox(label="Web Search", + value=kwargs['add_search_to_context'], + visible=os.environ.get('SERPAPI_API_KEY') is not None \ + and have_serpapi) + document_subset = gr.Radio([x.name for x in DocumentSubset], + label="Subset", + value=DocumentSubset.Relevant.name, + interactive=True, + ) + allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions] + langchain_action = gr.Radio( + allowed_actions, + value=allowed_actions[0] if len(allowed_actions) > 0 else None, + label="Action", + visible=True) + allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents] + if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.JSON.value in allowed_agents: + allowed_agents.remove(LangChainAgent.JSON.value) + if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.PYTHON.value in allowed_agents: + allowed_agents.remove(LangChainAgent.PYTHON.value) + if LangChainAgent.PANDAS.value in allowed_agents: + allowed_agents.remove(LangChainAgent.PANDAS.value) + langchain_agents = gr.Dropdown( + allowed_agents, + value=None, + label="Agents", + multiselect=True, + interactive=True, + visible=True, + elem_id="langchain_agents", + filterable=False) + visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs['large_file_count_mode'] + row_doc_track = gr.Row(visible=visible_doc_track) + with row_doc_track: + if kwargs['langchain_mode'] in langchain_modes_non_db: + doc_counts_str = "Pure LLM Mode" + else: + doc_counts_str = "Name: %s\nDocs: Unset\nChunks: Unset" % kwargs['langchain_mode'] + text_doc_count = gr.Textbox(lines=3, label="Doc Counts", value=doc_counts_str, + visible=visible_doc_track) + text_file_last = gr.Textbox(lines=1, label="Newest Doc", value=None, visible=visible_doc_track) + text_viewable_doc_count = gr.Textbox(lines=2, label=None, visible=False) + col_tabs = gr.Column(elem_id="col-tabs", scale=10) + with col_tabs, gr.Tabs(): + if kwargs['chat_tables']: + chat_tab = gr.Row(visible=True) + else: + chat_tab = gr.TabItem("Chat") \ + if kwargs['visible_chat_tab'] else gr.Row(visible=False) + with chat_tab: + if kwargs['langchain_mode'] == 'Disabled': + text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True, + visible=not kwargs['chat']) + else: + # text looks a bit worse, but HTML links work + text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat']) + with gr.Row(): + # NOCHAT + instruction_nochat = gr.Textbox( + lines=kwargs['input_lines'], + label=instruction_label_nochat, + placeholder=kwargs['placeholder_instruction'], + visible=not kwargs['chat'], + ) + iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction", + placeholder=kwargs['placeholder_input'], + value=kwargs['iinput'], + visible=not kwargs['chat']) + submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat']) + flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat']) + score_text_nochat = gr.Textbox("Response Score: NA", show_label=False, + visible=not kwargs['chat']) + submit_nochat_api = gr.Button("Submit nochat API", visible=False) + submit_nochat_api_plain = gr.Button("Submit nochat API Plain", visible=False) + inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False) + text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False, + show_copy_button=True) + + # CHAT + col_chat = gr.Column(visible=kwargs['chat']) + with col_chat: + with gr.Row(): + with gr.Column(scale=50): + with gr.Row(elem_id="prompt-form-row"): + label_instruction = 'Ask anything' + instruction = gr.Textbox( + lines=kwargs['input_lines'], + label=label_instruction, + placeholder=instruction_label, + info=None, + elem_id='prompt-form', + container=True, + ) + attach_button = gr.UploadButton( + elem_id="attach-button", + value="", + label="Upload File(s)", + size="sm", + min_width=24, + file_types=['.' + x for x in file_types], + file_count="multiple") + + submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons']) + with submit_buttons: + mw1 = 50 + mw2 = 50 + with gr.Column(min_width=mw1): + submit = gr.Button(value='Submit', variant='primary', size='sm', + min_width=mw1) + stop_btn = gr.Button(value="Stop", variant='secondary', size='sm', + min_width=mw1) + save_chat_btn = gr.Button("Save", size='sm', min_width=mw1) + with gr.Column(min_width=mw2): + retry_btn = gr.Button("Redo", size='sm', min_width=mw2) + undo = gr.Button("Undo", size='sm', min_width=mw2) + clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2) + + visible_model_choice = bool(kwargs['model_lock']) and \ + len(model_states) > 1 and \ + kwargs['visible_visible_models'] + with gr.Row(visible=visible_model_choice): + visible_models = gr.Dropdown(kwargs['all_models'], + label="Visible Models", + value=visible_models_state0, + interactive=True, + multiselect=True, + visible=visible_model_choice, + elem_id="visible-models", + filterable=False, + ) + + text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2, + **kwargs) + + with gr.Row(): + with gr.Column(visible=kwargs['score_model']): + score_text = gr.Textbox(res_value, + show_label=False, + visible=True) + score_text2 = gr.Textbox("Response Score2: NA", show_label=False, + visible=False and not kwargs['model_lock']) + + doc_selection_tab = gr.TabItem("Document Selection") \ + if kwargs['visible_doc_selection_tab'] else gr.Row(visible=False) + with doc_selection_tab: + if kwargs['langchain_mode'] in langchain_modes_non_db: + dlabel1 = 'Choose Resources->Collections and Pick Collection' + active_collection = gr.Markdown(value="#### Not Chatting with Any Collection\n%s" % dlabel1) + else: + dlabel1 = 'Select Subset of Document(s) for Chat with Collection: %s' % kwargs['langchain_mode'] + active_collection = gr.Markdown( + value="#### Chatting with Collection: %s" % kwargs['langchain_mode']) + document_choice = gr.Dropdown(docs_state0, + label=dlabel1, + value=[DocumentChoice.ALL.value], + interactive=True, + multiselect=True, + visible=kwargs['langchain_mode'] != 'Disabled', + ) + sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list + with gr.Row(): + with gr.Column(scale=1): + get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm', + visible=sources_visible and kwargs['large_file_count_mode']) + # handle API get sources + get_sources_api_btn = gr.Button(visible=False) + get_sources_api_text = gr.Textbox(visible=False) + + get_document_api_btn = gr.Button(visible=False) + get_document_api_text = gr.Textbox(visible=False) + + show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm', + visible=sources_visible and kwargs['large_file_count_mode']) + delete_sources_btn = gr.Button(value="Delete Selected Sources from DB", scale=0, size='sm', + visible=sources_visible) + refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0, + size='sm', + visible=sources_visible and allow_upload_to_user_data) + with gr.Column(scale=4): + pass + with gr.Row(): + with gr.Column(scale=1): + visible_add_remove_collection = (allow_upload_to_user_data or + allow_upload_to_my_data) and \ + kwargs['langchain_mode'] != 'Disabled' + add_placeholder = "e.g. UserData2, shared, user_path2" \ + if not is_public else "e.g. MyData2, personal (optional)" + remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2" + new_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection, + label='Add Collection', + placeholder=add_placeholder, + interactive=True) + remove_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection, + label='Remove Collection from UI', + placeholder=remove_placeholder, + interactive=True) + purge_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection, + label='Purge Collection (UI, DB, & source files)', + placeholder=remove_placeholder, + interactive=True) + sync_sources_btn = gr.Button( + value="Synchronize DB and UI [only required if did not login and have shared docs]", + scale=0, size='sm', + visible=sources_visible and allow_upload_to_user_data and not kwargs[ + 'large_file_count_mode']) + load_langchain = gr.Button( + value="Load Collections State [only required if logged in another user ", scale=0, + size='sm', + visible=False and allow_upload_to_user_data and + kwargs['langchain_mode'] != 'Disabled') + with gr.Column(scale=5): + if kwargs['langchain_mode'] != 'Disabled' and visible_add_remove_collection: + df0 = get_df_langchain_mode_paths(selection_docs_state0, None, dbs1=dbs) + else: + df0 = pd.DataFrame(None) + langchain_mode_path_text = gr.Dataframe(value=df0, + visible=visible_add_remove_collection, + label='LangChain Mode-Path', + show_label=False, + interactive=False) + + sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list, + equal_height=False) + with sources_row: + with gr.Column(scale=1): + file_source = gr.File(interactive=False, + label="Download File w/Sources") + with gr.Column(scale=2): + sources_text = gr.HTML(label='Sources Added', interactive=False) + + doc_exception_text = gr.Textbox(value="", label='Document Exceptions', + interactive=False, + visible=kwargs['langchain_mode'] != 'Disabled') + file_types_str = ' '.join(file_types) + ' URL ArXiv TEXT' + gr.Textbox(value=file_types_str, label='Document Types Supported', + lines=2, + interactive=False, + visible=kwargs['langchain_mode'] != 'Disabled') + + doc_view_tab = gr.TabItem("Document Viewer") \ + if kwargs['visible_doc_view_tab'] else gr.Row(visible=False) + with doc_view_tab: + with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled'): + with gr.Column(scale=2): + get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, + size='sm', + visible=sources_visible and kwargs[ + 'large_file_count_mode']) + view_document_choice = gr.Dropdown(viewable_docs_state0, + label="Select Single Document to View", + value=None, + interactive=True, + multiselect=False, + visible=True, + ) + info_view_raw = "Raw text shown if render of original doc fails" + if is_public: + info_view_raw += " (Up to %s chunks in public portal)" % kwargs['max_raw_chunks'] + view_raw_text_checkbox = gr.Checkbox(label="View Database Text", value=False, + info=info_view_raw, + visible=kwargs['db_type'] in ['chroma', 'chroma_old']) + with gr.Column(scale=4): + pass + doc_view = gr.HTML(visible=False) + doc_view2 = gr.Dataframe(visible=False) + doc_view3 = gr.JSON(visible=False) + doc_view4 = gr.Markdown(visible=False) + doc_view5 = gr.HTML(visible=False) + + chat_tab = gr.TabItem("Chat History") \ + if kwargs['visible_chat_history_tab'] else gr.Row(visible=False) + with chat_tab: + with gr.Row(): + with gr.Column(scale=1): + remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm') + flag_btn = gr.Button("Flag Current Chat", size='sm') + export_chats_btn = gr.Button(value="Export Chats to Download", size='sm') + with gr.Column(scale=4): + pass + with gr.Row(): + chats_file = gr.File(interactive=False, label="Download Exported Chats") + chatsup_output = gr.File(label="Upload Chat File(s)", + file_types=['.json'], + file_count='multiple', + elem_id="warning", elem_classes="feedback") + with gr.Row(): + if 'mbart-' in kwargs['model_lower']: + src_lang = gr.Dropdown(list(languages_covered().keys()), + value=kwargs['src_lang'], + label="Input Language") + tgt_lang = gr.Dropdown(list(languages_covered().keys()), + value=kwargs['tgt_lang'], + label="Output Language") + + chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions', + interactive=False) + expert_tab = gr.TabItem("Expert") \ + if kwargs['visible_expert_tab'] else gr.Row(visible=False) + with expert_tab: + with gr.Row(): + with gr.Column(): + prompt_type = gr.Dropdown(prompt_types_strings, + value=kwargs['prompt_type'], label="Prompt Type", + visible=not kwargs['model_lock'], + interactive=not is_public, + ) + prompt_type2 = gr.Dropdown(prompt_types_strings, + value=kwargs['prompt_type'], label="Prompt Type Model 2", + visible=False and not kwargs['model_lock'], + interactive=not is_public) + system_prompt = gr.Textbox(label="System Prompt", + info="If 'auto', then uses model's system prompt," + " else use this message." + " If empty, no system message is used", + value=kwargs['system_prompt']) + context = gr.Textbox(lines=2, label="System Pre-Context", + info="Directly pre-appended without prompt processing (before Pre-Conversation)", + value=kwargs['context']) + chat_conversation = gr.Textbox(lines=2, label="Pre-Conversation", + info="Pre-append conversation for instruct/chat models as List of tuple of (human, bot)", + value=kwargs['chat_conversation']) + text_context_list = gr.Textbox(lines=2, label="Text Doc Q/A", + info="List of strings, for document Q/A, for bypassing database (i.e. also works in LLM Mode)", + value=kwargs['chat_conversation'], + visible=not is_public, # primarily meant for API + ) + iinput = gr.Textbox(lines=2, label="Input for Instruct prompt types", + info="If given for document query, added after query", + value=kwargs['iinput'], + placeholder=kwargs['placeholder_input'], + interactive=not is_public) + with gr.Column(): + pre_prompt_query = gr.Textbox(label="Query Pre-Prompt", + info="Added before documents", + value=kwargs['pre_prompt_query'] or '') + prompt_query = gr.Textbox(label="Query Prompt", + info="Added after documents", + value=kwargs['prompt_query'] or '') + pre_prompt_summary = gr.Textbox(label="Summary Pre-Prompt", + info="Added before documents", + value=kwargs['pre_prompt_summary'] or '') + prompt_summary = gr.Textbox(label="Summary Prompt", + info="Added after documents (if query given, 'Focusing on {query}, ' is pre-appended)", + value=kwargs['prompt_summary'] or '') + with gr.Row(visible=not is_public): + image_loaders = gr.CheckboxGroup(image_loaders_options, + label="Force Image Reader", + value=image_loaders_options0) + pdf_loaders = gr.CheckboxGroup(pdf_loaders_options, + label="Force PDF Reader", + value=pdf_loaders_options0) + url_loaders = gr.CheckboxGroup(url_loaders_options, + label="Force URL Reader", value=url_loaders_options0) + jq_schema = gr.Textbox(label="JSON jq_schema", value=jq_schema0) + + min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) + top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1, + value=kwargs['top_k_docs'], + label=label_top_k_docs, + # info="For LangChain", + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public) + chunk_size = gr.Number(value=kwargs['chunk_size'], + label="Chunk size for document chunking", + info="For LangChain (ignored if chunk=False)", + minimum=128, + maximum=2048, + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public, + precision=0) + docs_ordering_type = gr.Radio( + docs_ordering_types, + value=kwargs['docs_ordering_type'], + label="Document Sorting in LLM Context", + visible=True) + chunk = gr.components.Checkbox(value=kwargs['chunk'], + label="Whether to chunk documents", + info="For LangChain", + visible=kwargs['langchain_mode'] != 'Disabled', + interactive=not is_public) + embed = gr.components.Checkbox(value=True, + label="Whether to embed text", + info="For LangChain", + visible=False) + with gr.Row(): + stream_output = gr.components.Checkbox(label="Stream output", + value=kwargs['stream_output']) + do_sample = gr.Checkbox(label="Sample", + info="Enable sampler (required for use of temperature, top_p, top_k)", + value=kwargs['do_sample']) + max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1, + value=min(kwargs['max_max_time'], + kwargs['max_time']), label="Max. time", + info="Max. time to search optimal output.") + temperature = gr.Slider(minimum=0.01, maximum=2, + value=kwargs['temperature'], + label="Temperature", + info="Lower is deterministic, higher more creative") + top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3, + value=kwargs['top_p'], label="Top p", + info="Cumulative probability of tokens to sample from") + top_k = gr.Slider( + minimum=1, maximum=100, step=1, + value=kwargs['top_k'], label="Top k", + info='Num. tokens to sample from' + ) + # FIXME: https://github.com/h2oai/h2ogpt/issues/106 + if os.getenv('TESTINGFAIL'): + max_beams = 8 if not (memory_restriction_level or is_public) else 1 + else: + max_beams = 1 + num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1, + value=min(max_beams, kwargs['num_beams']), label="Beams", + info="Number of searches for optimal overall probability. " + "Uses more GPU memory/compute", + interactive=False, visible=max_beams > 1) + max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs) + max_new_tokens = gr.Slider( + minimum=1, maximum=max_max_new_tokens, step=1, + value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length", + ) + min_new_tokens = gr.Slider( + minimum=0, maximum=max_max_new_tokens, step=1, + value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length", + ) + max_new_tokens2 = gr.Slider( + minimum=1, maximum=max_max_new_tokens, step=1, + value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2", + visible=False and not kwargs['model_lock'], + ) + min_new_tokens2 = gr.Slider( + minimum=0, maximum=max_max_new_tokens, step=1, + value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2", + visible=False and not kwargs['model_lock'], + ) + min_max_new_tokens = gr.Slider( + minimum=1, maximum=max_max_new_tokens, step=1, + value=min(max_max_new_tokens, kwargs['min_max_new_tokens']), label="Min. of Max output length", + ) + early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search", + value=kwargs['early_stopping'], visible=max_beams > 1) + repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0, + value=kwargs['repetition_penalty'], + label="Repetition Penalty") + num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1, + value=kwargs['num_return_sequences'], + label="Number Returns", info="Must be <= num_beams", + interactive=not is_public, visible=max_beams > 1) + chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'], + visible=False, # no longer support nochat in UI + interactive=not is_public, + ) + with gr.Row(): + count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", + visible=not is_public and not kwargs['model_lock'], + interactive=not is_public, size='sm') + chat_token_count = gr.Textbox(label="Chat Token Count Result", value=None, + visible=not is_public and not kwargs['model_lock'], + interactive=False) + + models_tab = gr.TabItem("Models") \ + if kwargs['visible_models_tab'] and not bool(kwargs['model_lock']) else gr.Row(visible=False) + with models_tab: + load_msg = "Download/Load Model" if not is_public \ + else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO" + if kwargs['base_model'] not in ['', None, no_model_str]: + load_msg += ' [WARNING: Avoid --base_model on CLI for memory efficient Load-Unload]' + load_msg2 = load_msg + "(Model 2)" + variant_load_msg = 'primary' if not is_public else 'secondary' + with gr.Row(): + n_gpus_list = [str(x) for x in list(range(-1, n_gpus))] + with gr.Column(): + with gr.Row(): + with gr.Column(scale=20, visible=not kwargs['model_lock']): + load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0, + size='sm', interactive=not is_public) + model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Base Model", + value=kwargs['base_model']) + lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", + value=kwargs['lora_weights'], visible=kwargs['show_lora']) + server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server", + value=kwargs['inference_server'], visible=not is_public) + max_seq_len = gr.Number(value=kwargs['max_seq_len'] or 2048, + minimum=128, + maximum=2 ** 18, + info="If standard LLaMa-2, choose up to 4096", + label="max_seq_len") + rope_scaling = gr.Textbox(value=str(kwargs['rope_scaling'] or {}), + label="rope_scaling") + row_llama = gr.Row(visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama') + with row_llama: + model_path_llama = gr.Textbox(value=kwargs['llamacpp_dict']['model_path_llama'], + lines=4, + label="Choose LLaMa.cpp Model Path/URL (for Base Model: llama)", + visible=kwargs['show_llama']) + n_gpu_layers = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'], + minimum=0, maximum=100, + label="LLaMa.cpp Num. GPU Layers Offloaded", + visible=kwargs['show_llama']) + n_batch = gr.Number(value=kwargs['llamacpp_dict']['n_batch'], + minimum=0, maximum=2048, + label="LLaMa.cpp Batch Size", + visible=kwargs['show_llama']) + n_gqa = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'], + minimum=0, maximum=32, + label="LLaMa.cpp Num. Group Query Attention (8 for 70B LLaMa2)", + visible=kwargs['show_llama']) + llamacpp_dict_more = gr.Textbox(value="{}", + lines=4, + label="Dict for other LLaMa.cpp/GPT4All options", + visible=kwargs['show_llama']) + row_gpt4all = gr.Row( + visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj', + 'gpt4all_llama']) + with row_gpt4all: + model_name_gptj = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'], + label="Choose GPT4All GPTJ Model Path/URL (for Base Model: gptj)", + visible=kwargs['show_gpt4all']) + model_name_gpt4all_llama = gr.Textbox( + value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'], + label="Choose GPT4All LLaMa Model Path/URL (for Base Model: gpt4all_llama)", + visible=kwargs['show_gpt4all']) + with gr.Column(scale=1, visible=not kwargs['model_lock']): + model_load8bit_checkbox = gr.components.Checkbox( + label="Load 8-bit [requires support]", + value=kwargs['load_8bit'], interactive=not is_public) + model_load4bit_checkbox = gr.components.Checkbox( + label="Load 4-bit [requires support]", + value=kwargs['load_4bit'], interactive=not is_public) + model_low_bit_mode = gr.Slider(value=kwargs['low_bit_mode'], + minimum=0, maximum=4, step=1, + label="low_bit_mode") + model_load_gptq = gr.Textbox(label="gptq", value=kwargs['load_gptq'], + interactive=not is_public) + model_load_exllama_checkbox = gr.components.Checkbox( + label="Load load_exllama [requires support]", + value=kwargs['load_exllama'], interactive=not is_public) + model_safetensors_checkbox = gr.components.Checkbox( + label="Safetensors [requires support]", + value=kwargs['use_safetensors'], interactive=not is_public) + model_revision = gr.Textbox(label="revision", value=kwargs['revision'], + interactive=not is_public) + model_use_gpu_id_checkbox = gr.components.Checkbox( + label="Choose Devices [If not Checked, use all GPUs]", + value=kwargs['use_gpu_id'], interactive=not is_public, + visible=n_gpus != 0) + model_gpu = gr.Dropdown(n_gpus_list, + label="GPU ID [-1 = all GPUs, if Choose is enabled]", + value=kwargs['gpu_id'], interactive=not is_public, + visible=n_gpus != 0) + model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'], + interactive=False) + lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], + visible=kwargs['show_lora'], interactive=False) + server_used = gr.Textbox(label="Current Server", + value=kwargs['inference_server'], + visible=bool(kwargs['inference_server']) and not is_public, + interactive=False) + prompt_dict = gr.Textbox(label="Prompt (or Custom)", + value=pprint.pformat(kwargs['prompt_dict'], indent=4), + interactive=not is_public, lines=4) + col_model2 = gr.Column(visible=False) + with col_model2: + with gr.Row(): + with gr.Column(scale=20, visible=not kwargs['model_lock']): + load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0, + size='sm', interactive=not is_public) + model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2", + value=no_model_str) + lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2", + value=no_lora_str, + visible=kwargs['show_lora']) + server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2", + value=no_server_str, + visible=not is_public) + max_seq_len2 = gr.Number(value=kwargs['max_seq_len'] or 2048, + minimum=128, + maximum=2 ** 18, + info="If standard LLaMa-2, choose up to 4096", + label="max_seq_len Model 2") + rope_scaling2 = gr.Textbox(value=str(kwargs['rope_scaling'] or {}), + label="rope_scaling Model 2") + + row_llama2 = gr.Row( + visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama') + with row_llama2: + model_path_llama2 = gr.Textbox( + value=kwargs['llamacpp_dict']['model_path_llama'], + label="Choose LLaMa.cpp Model 2 Path/URL (for Base Model: llama)", + lines=4, + visible=kwargs['show_llama']) + n_gpu_layers2 = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'], + minimum=0, maximum=100, + label="LLaMa.cpp Num. GPU 2 Layers Offloaded", + visible=kwargs['show_llama']) + n_batch2 = gr.Number(value=kwargs['llamacpp_dict']['n_batch'], + minimum=0, maximum=2048, + label="LLaMa.cpp Model 2 Batch Size", + visible=kwargs['show_llama']) + n_gqa2 = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'], + minimum=0, maximum=32, + label="LLaMa.cpp Model 2 Num. Group Query Attention (8 for 70B LLaMa2)", + visible=kwargs['show_llama']) + llamacpp_dict_more2 = gr.Textbox(value="{}", + lines=4, + label="Model 2 Dict for other LLaMa.cpp/GPT4All options", + visible=kwargs['show_llama']) + row_gpt4all2 = gr.Row( + visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj', + 'gpt4all_llama']) + with row_gpt4all2: + model_name_gptj2 = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'], + label="Choose GPT4All GPTJ Model 2 Path/URL (for Base Model: gptj)", + visible=kwargs['show_gpt4all']) + model_name_gpt4all_llama2 = gr.Textbox( + value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'], + label="Choose GPT4All LLaMa Model 2 Path/URL (for Base Model: gpt4all_llama)", + visible=kwargs['show_gpt4all']) + + with gr.Column(scale=1, visible=not kwargs['model_lock']): + model_load8bit_checkbox2 = gr.components.Checkbox( + label="Load 8-bit (Model 2) [requires support]", + value=kwargs['load_8bit'], interactive=not is_public) + model_load4bit_checkbox2 = gr.components.Checkbox( + label="Load 4-bit (Model 2) [requires support]", + value=kwargs['load_4bit'], interactive=not is_public) + model_low_bit_mode2 = gr.Slider(value=kwargs['low_bit_mode'], + # ok that same as Model 1 + minimum=0, maximum=4, step=1, + label="low_bit_mode (Model 2)") + model_load_gptq2 = gr.Textbox(label="gptq (Model 2)", value='', + interactive=not is_public) + model_load_exllama_checkbox2 = gr.components.Checkbox( + label="Load load_exllama (Model 2) [requires support]", + value=False, interactive=not is_public) + model_safetensors_checkbox2 = gr.components.Checkbox( + label="Safetensors (Model 2) [requires support]", + value=False, interactive=not is_public) + model_revision2 = gr.Textbox(label="revision (Model 2)", value='', + interactive=not is_public) + model_use_gpu_id_checkbox2 = gr.components.Checkbox( + label="Choose Devices (Model 2) [If not Checked, use all GPUs]", + value=kwargs[ + 'use_gpu_id'], interactive=not is_public) + model_gpu2 = gr.Dropdown(n_gpus_list, + label="GPU ID (Model 2) [-1 = all GPUs, if choose is enabled]", + value=kwargs['gpu_id'], interactive=not is_public) + # no model/lora loaded ever in model2 by default + model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str, + interactive=False) + lora_used2 = gr.Textbox(label="Current LORA (Model 2)", value=no_lora_str, + visible=kwargs['show_lora'], interactive=False) + server_used2 = gr.Textbox(label="Current Server (Model 2)", value=no_server_str, + interactive=False, + visible=not is_public) + prompt_dict2 = gr.Textbox(label="Prompt (or Custom) (Model 2)", + value=pprint.pformat(kwargs['prompt_dict'], indent=4), + interactive=not is_public, lines=4) + compare_checkbox = gr.components.Checkbox(label="Compare Two Models", + value=kwargs['model_lock'], + visible=not is_public and not kwargs['model_lock']) + with gr.Row(visible=not kwargs['model_lock']): + with gr.Column(scale=50): + new_model = gr.Textbox(label="New Model name/path/URL", interactive=not is_public) + with gr.Column(scale=50): + new_lora = gr.Textbox(label="New LORA name/path/URL", visible=kwargs['show_lora'], + interactive=not is_public) + with gr.Column(scale=50): + new_server = gr.Textbox(label="New Server url:port", interactive=not is_public) + with gr.Row(): + add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0, + variant=variant_load_msg, + size='sm', interactive=not is_public) + system_tab = gr.TabItem("System") \ + if kwargs['visible_system_tab'] else gr.Row(visible=False) + with system_tab: + with gr.Row(): + with gr.Column(scale=1): + side_bar_text = gr.Textbox('on' if kwargs['visible_side_bar'] else 'off', + visible=False, interactive=False) + doc_count_text = gr.Textbox('on' if kwargs['visible_doc_track'] else 'off', + visible=False, interactive=False) + submit_buttons_text = gr.Textbox('on' if kwargs['visible_submit_buttons'] else 'off', + visible=False, interactive=False) + visible_models_text = gr.Textbox('on' if kwargs['visible_visible_models'] else 'off', + visible=False, interactive=False) + + side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm") + doc_count_btn = gr.Button("Toggle SideBar Document Count/Show Newest", variant="secondary", + size="sm") + submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm") + visible_model_btn = gr.Button("Toggle Visible Models", variant="secondary", size="sm") + col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size') + text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400, + step=50, label='Chat Height') + dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm") + with gr.Column(scale=4): + pass + system_visible0 = not is_public and not admin_pass + admin_row = gr.Row() + with admin_row: + with gr.Column(scale=1): + admin_pass_textbox = gr.Textbox(label="Admin Password", + type='password', + visible=not system_visible0) + with gr.Column(scale=4): + pass + system_row = gr.Row(visible=system_visible0) + with system_row: + with gr.Column(): + with gr.Row(): + system_btn = gr.Button(value='Get System Info', size='sm') + system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True) + with gr.Row(): + system_input = gr.Textbox(label='System Info Dict Password', interactive=True, + visible=not is_public) + system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm') + system_text2 = gr.Textbox(label='System Info Dict', interactive=False, + visible=not is_public, show_copy_button=True) + with gr.Row(): + system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm') + system_text3 = gr.Textbox(label='Hash', interactive=False, + visible=not is_public, show_copy_button=True) + system_btn4 = gr.Button(value='Get Model Names', visible=not is_public, size='sm') + system_text4 = gr.Textbox(label='Model Names', interactive=False, + visible=not is_public, show_copy_button=True) + + with gr.Row(): + zip_btn = gr.Button("Zip", size='sm') + zip_text = gr.Textbox(label="Zip file name", interactive=False) + file_output = gr.File(interactive=False, label="Zip file to Download") + with gr.Row(): + s3up_btn = gr.Button("S3UP", size='sm') + s3up_text = gr.Textbox(label='S3UP result', interactive=False) + + tos_tab = gr.TabItem("Terms of Service") \ + if kwargs['visible_tos_tab'] else gr.Row(visible=False) + with tos_tab: + description = "" + description += """

DISCLAIMERS:

""" + gr.Markdown(value=description, show_label=False, interactive=False) + + login_tab = gr.TabItem("Login") \ + if kwargs['visible_login_tab'] else gr.Row(visible=False) + with login_tab: + gr.Markdown( + value="#### Login page to persist your state (database, documents, chat, chat history)\nDaily maintenance at midnight PST will not allow reconnection to state otherwise.") + username_text = gr.Textbox(label="Username") + password_text = gr.Textbox(label="Password", type='password', visible=True) + login_msg = "Login (pick unique user/pass to persist your state)" if kwargs[ + 'auth_access'] == 'open' else "Login (closed access)" + login_btn = gr.Button(value=login_msg) + login_result_text = gr.Text(label="Login Result", interactive=False) + h2ogpt_key = gr.Text(value=kwargs['h2ogpt_key'], label="h2oGPT Token for API access", + type='password', visible=False) + + hosts_tab = gr.TabItem("Hosts") \ + if kwargs['visible_hosts_tab'] else gr.Row(visible=False) + with hosts_tab: + gr.Markdown(f""" + {description_bottom} + {task_info_md} + """) + + # Get flagged data + zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']]) + zip_event = zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False, + api_name='zip_data' if allow_api else None) + s3up_event = s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False, + api_name='s3up_data' if allow_api else None) + + def clear_file_list(): + return None + + def set_loaders(max_quality1, + image_loaders_options1=None, + pdf_loaders_options1=None, + url_loaders_options1=None, + image_loaders_options01=None, + pdf_loaders_options01=None, + url_loaders_options01=None, + ): + if not max_quality1: + return image_loaders_options01, pdf_loaders_options01, url_loaders_options01 + else: + return image_loaders_options1, pdf_loaders_options1, url_loaders_options1 + + set_loaders_func = functools.partial(set_loaders, + image_loaders_options1=image_loaders_options, + pdf_loaders_options1=pdf_loaders_options, + url_loaders_options1=url_loaders_options, + image_loaders_options01=image_loaders_options0, + pdf_loaders_options01=pdf_loaders_options0, + url_loaders_options01=url_loaders_options0, + ) + + max_quality.change(fn=set_loaders_func, + inputs=max_quality, + outputs=[image_loaders, pdf_loaders, url_loaders]) + + def get_model_lock_visible_list(visible_models1, all_models): + visible_list = [] + for modeli, model in enumerate(all_models): + if visible_models1 is None or model in visible_models1 or modeli in visible_models1: + visible_list.append(True) + else: + visible_list.append(False) + return visible_list + + def set_visible_models(visible_models1, num_model_lock=0, all_models=None): + if num_model_lock == 0: + num_model_lock = 3 # 2 + 1 (which is dup of first) + ret_list = [gr.update(visible=True)] * num_model_lock + else: + assert isinstance(all_models, list) + assert num_model_lock == len(all_models) + visible_list = [False, False] + get_model_lock_visible_list(visible_models1, all_models) + ret_list = [gr.update(visible=x) for x in visible_list] + return tuple(ret_list) + + visible_models_func = functools.partial(set_visible_models, + num_model_lock=len(text_outputs), + all_models=kwargs['all_models']) + visible_models.change(fn=visible_models_func, + inputs=visible_models, + outputs=[text_output, text_output2] + text_outputs, + ) + + # Add to UserData or custom user db + update_db_func = functools.partial(update_user_db_gr, + dbs=dbs, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + captions_model=captions_model, + caption_loader=caption_loader, + doctr_loader=doctr_loader, + verbose=kwargs['verbose'], + n_jobs=kwargs['n_jobs'], + get_userid_auth=get_userid_auth, + image_loaders_options0=image_loaders_options0, + pdf_loaders_options0=pdf_loaders_options0, + url_loaders_options0=url_loaders_options0, + jq_schema0=jq_schema0, + enforce_h2ogpt_api_key=kwargs['enforce_h2ogpt_api_key'], + h2ogpt_api_keys=kwargs['h2ogpt_api_keys'], + ) + add_file_outputs = [fileup_output, langchain_mode] + add_file_kwargs = dict(fn=update_db_func, + inputs=[fileup_output, my_db_state, selection_docs_state, requests_state, + langchain_mode, chunk, chunk_size, embed, + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + h2ogpt_key, + ], + outputs=add_file_outputs + [sources_text, doc_exception_text, text_file_last], + queue=queue, + api_name='add_file' if allow_upload_api else None) + + # then no need for add buttons, only single changeable db + user_state_kwargs = dict(fn=user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + eventdb1a = fileup_output.upload(**user_state_kwargs) + eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='full') + + event_attach1 = attach_button.upload(**user_state_kwargs) + attach_file_kwargs = add_file_kwargs.copy() + attach_file_kwargs['inputs'][0] = attach_button + attach_file_kwargs['outputs'][0] = attach_button + attach_file_kwargs['api_name'] = 'attach_file' + event_attach2 = event_attach1.then(**attach_file_kwargs, show_progress='full') + + sync1 = sync_sources_btn.click(**user_state_kwargs) + + # deal with challenge to have fileup_output itself as input + add_file_kwargs2 = dict(fn=update_db_func, + inputs=[fileup_output_text, my_db_state, selection_docs_state, requests_state, + langchain_mode, chunk, chunk_size, embed, + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + h2ogpt_key, + ], + outputs=add_file_outputs + [sources_text, doc_exception_text, text_file_last], + queue=queue, + api_name='add_file_api' if allow_upload_api else None) + eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full') + + # note for update_user_db_func output is ignored for db + + def clear_textbox(): + return gr.Textbox.update(value='') + + update_user_db_url_func = functools.partial(update_db_func, is_url=True) + + add_url_outputs = [url_text, langchain_mode] + add_url_kwargs = dict(fn=update_user_db_url_func, + inputs=[url_text, my_db_state, selection_docs_state, requests_state, + langchain_mode, chunk, chunk_size, embed, + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + h2ogpt_key, + ], + outputs=add_url_outputs + [sources_text, doc_exception_text, text_file_last], + queue=queue, + api_name='add_url' if allow_upload_api else None) + + eventdb2a = url_text.submit(fn=user_state_setup, + inputs=[my_db_state, requests_state, url_text, url_text], + outputs=[my_db_state, requests_state, url_text], + queue=queue, + show_progress='minimal') + # work around https://github.com/gradio-app/gradio/issues/4733 + eventdb2 = eventdb2a.then(**add_url_kwargs, show_progress='full') + + update_user_db_txt_func = functools.partial(update_db_func, is_txt=True) + add_text_outputs = [user_text_text, langchain_mode] + add_text_kwargs = dict(fn=update_user_db_txt_func, + inputs=[user_text_text, my_db_state, selection_docs_state, requests_state, + langchain_mode, chunk, chunk_size, embed, + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + h2ogpt_key, + ], + outputs=add_text_outputs + [sources_text, doc_exception_text, text_file_last], + queue=queue, + api_name='add_text' if allow_upload_api else None + ) + eventdb3a = user_text_text.submit(fn=user_state_setup, + inputs=[my_db_state, requests_state, user_text_text, user_text_text], + outputs=[my_db_state, requests_state, user_text_text], + queue=queue, + show_progress='minimal') + eventdb3 = eventdb3a.then(**add_text_kwargs, show_progress='full') + + db_events = [eventdb1a, eventdb1, eventdb1_api, + eventdb2a, eventdb2, + eventdb3a, eventdb3] + db_events.extend([event_attach1, event_attach2]) + + get_sources1 = functools.partial(get_sources_gr, dbs=dbs, docs_state0=docs_state0, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + n_jobs=n_jobs, + ) + + # if change collection source, must clear doc selections from it to avoid inconsistency + def clear_doc_choice(langchain_mode1): + if langchain_mode1 in langchain_modes_non_db: + label1 = 'Choose Resources->Collections and Pick Collection' + active_collection1 = "#### Not Chatting with Any Collection\n%s" % label1 + else: + label1 = 'Select Subset of Document(s) for Chat with Collection: %s' % langchain_mode1 + active_collection1 = "#### Chatting with Collection: %s" % langchain_mode1 + return gr.Dropdown.update(choices=docs_state0, value=DocumentChoice.ALL.value, + label=label1), gr.Markdown.update(value=active_collection1) + + lg_change_event = langchain_mode.change(clear_doc_choice, inputs=langchain_mode, + outputs=[document_choice, active_collection], + queue=not kwargs['large_file_count_mode']) + + def change_visible_llama(x): + if x == 'llama': + return gr.update(visible=True), \ + gr.update(visible=True), \ + gr.update(visible=False), \ + gr.update(visible=False) + elif x in ['gptj', 'gpt4all_llama']: + return gr.update(visible=False), \ + gr.update(visible=False), \ + gr.update(visible=True), \ + gr.update(visible=True) + else: + return gr.update(visible=False), \ + gr.update(visible=False), \ + gr.update(visible=False), \ + gr.update(visible=False) + + model_choice.change(change_visible_llama, + inputs=model_choice, + outputs=[row_llama, row_llama2, row_gpt4all, row_gpt4all2]) + + def resize_col_tabs(x): + return gr.Dropdown.update(scale=x) + + col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False) + + def resize_chatbots(x, num_model_lock=0): + if num_model_lock == 0: + num_model_lock = 3 # 2 + 1 (which is dup of first) + else: + num_model_lock = 2 + num_model_lock + return tuple([gr.update(height=x)] * num_model_lock) + + resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs)) + text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height, + outputs=[text_output, text_output2] + text_outputs, queue=False) + + def update_dropdown(x): + if DocumentChoice.ALL.value in x: + x.remove(DocumentChoice.ALL.value) + source_list = [DocumentChoice.ALL.value] + x + return gr.Dropdown.update(choices=source_list, value=[DocumentChoice.ALL.value]) + + get_sources_kwargs = dict(fn=get_sources1, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode], + outputs=[file_source, docs_state, text_doc_count], + queue=queue) + + eventdb7a = get_sources_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, get_sources_btn, get_sources_btn], + outputs=[my_db_state, requests_state, get_sources_btn], + show_progress='minimal') + eventdb7 = eventdb7a.then(**get_sources_kwargs, + api_name='get_sources' if allow_api else None) \ + .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + + get_sources_api_args = dict(fn=functools.partial(get_sources1, api=True), + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode], + outputs=get_sources_api_text, + queue=queue) + get_sources_api_btn.click(**get_sources_api_args, + api_name='get_sources_api' if allow_api else None) + + # show button, else only show when add. + # Could add to above get_sources for download/dropdown, but bit much maybe + show_sources1 = functools.partial(get_source_files_given_langchain_mode_gr, + dbs=dbs, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + n_jobs=n_jobs) + eventdb8a = show_sources_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, show_sources_btn, show_sources_btn], + outputs=[my_db_state, requests_state, show_sources_btn], + show_progress='minimal') + show_sources_kwargs = dict(fn=show_sources1, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode], + outputs=sources_text) + eventdb8 = eventdb8a.then(**show_sources_kwargs, + api_name='show_sources' if allow_api else None) + + def update_viewable_dropdown(x): + return gr.Dropdown.update(choices=x, + value=viewable_docs_state0[0] if len(viewable_docs_state0) > 0 else None) + + get_viewable_sources1 = functools.partial(get_sources_gr, dbs=dbs, docs_state0=viewable_docs_state0, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=kwargs['verbose'], + get_userid_auth=get_userid_auth, + n_jobs=n_jobs) + get_viewable_sources_args = dict(fn=get_viewable_sources1, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode], + outputs=[file_source, viewable_docs_state, text_viewable_doc_count], + queue=queue) + eventdb12a = get_viewable_sources_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, + get_viewable_sources_btn, get_viewable_sources_btn], + outputs=[my_db_state, requests_state, get_viewable_sources_btn], + show_progress='minimal') + viewable_kwargs = dict(fn=update_viewable_dropdown, inputs=viewable_docs_state, outputs=view_document_choice) + eventdb12 = eventdb12a.then(**get_viewable_sources_args, + api_name='get_viewable_sources' if allow_api else None) \ + .then(**viewable_kwargs) + + eventdb_viewa = view_document_choice.select(user_state_setup, + inputs=[my_db_state, requests_state, + view_document_choice, view_document_choice], + outputs=[my_db_state, requests_state, view_document_choice], + show_progress='minimal') + show_doc_func = functools.partial(show_doc, + dbs1=dbs, + load_db_if_exists1=load_db_if_exists, + db_type1=db_type, + use_openai_embedding1=use_openai_embedding, + hf_embedding_model1=hf_embedding_model, + migrate_embedding_model_or_db1=migrate_embedding_model, + auto_migrate_db1=auto_migrate_db, + verbose1=verbose, + get_userid_auth1=get_userid_auth, + max_raw_chunks=kwargs['max_raw_chunks'], + api=False, + n_jobs=n_jobs, + ) + # Note: Not really useful for API, so no api_name + eventdb_viewa.then(fn=show_doc_func, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode, + view_document_choice, view_raw_text_checkbox, + text_context_list], + outputs=[doc_view, doc_view2, doc_view3, doc_view4, doc_view5]) + + show_doc_func_api = functools.partial(show_doc_func, api=True) + get_document_api_btn.click(fn=show_doc_func_api, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode, + view_document_choice, view_raw_text_checkbox, + text_context_list], + outputs=get_document_api_text, api_name='get_document_api') + + # Get inputs to evaluate() and make_db() + # don't deepcopy, can contain model itself + all_kwargs = kwargs.copy() + all_kwargs.update(locals()) + + refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode_gr, + captions_model=captions_model, + caption_loader=caption_loader, + doctr_loader=doctr_loader, + dbs=dbs, + first_para=kwargs['first_para'], + hf_embedding_model=hf_embedding_model, + use_openai_embedding=use_openai_embedding, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + text_limit=kwargs['text_limit'], + db_type=db_type, + load_db_if_exists=load_db_if_exists, + n_jobs=n_jobs, verbose=verbose, + get_userid_auth=get_userid_auth, + image_loaders_options0=image_loaders_options0, + pdf_loaders_options0=pdf_loaders_options0, + url_loaders_options0=url_loaders_options0, + jq_schema0=jq_schema0, + ) + eventdb9a = refresh_sources_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, + refresh_sources_btn, refresh_sources_btn], + outputs=[my_db_state, requests_state, refresh_sources_btn], + show_progress='minimal') + eventdb9 = eventdb9a.then(fn=refresh_sources1, + inputs=[my_db_state, selection_docs_state, requests_state, + langchain_mode, chunk, chunk_size, + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + ], + outputs=sources_text, + api_name='refresh_sources' if allow_api else None) + + delete_sources1 = functools.partial(del_source_files_given_langchain_mode_gr, + dbs=dbs, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + n_jobs=n_jobs) + eventdb90a = delete_sources_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, + delete_sources_btn, delete_sources_btn], + outputs=[my_db_state, requests_state, delete_sources_btn], + show_progress='minimal') + eventdb90 = eventdb90a.then(fn=delete_sources1, + inputs=[my_db_state, selection_docs_state, requests_state, document_choice, + langchain_mode], + outputs=sources_text, + api_name='delete_sources' if allow_api else None) + db_events.extend([eventdb90a, eventdb90]) + + def check_admin_pass(x): + return gr.update(visible=x == admin_pass) + + def close_admin(x): + return gr.update(visible=not (x == admin_pass)) + + eventdb_logina = login_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, login_btn, login_btn], + outputs=[my_db_state, requests_state, login_btn], + show_progress='minimal') + + def login(db1s, selection_docs_state1, requests_state1, chat_state1, langchain_mode1, + username1, password1, + text_output1, text_output21, *text_outputs1, + auth_filename=None, num_model_lock=0, pre_authorized=False): + # use full auth login to allow new users if open access etc. + if pre_authorized: + username1 = requests_state1['username'] + password1 = None + authorized1 = True + else: + authorized1 = authf(username1, password1, selection_docs_state1=selection_docs_state1) + if authorized1: + set_userid_gr(db1s, requests_state1, get_userid_auth) + username2 = get_username(requests_state1) + text_outputs1 = list(text_outputs1) + + success1, text_result, text_output1, text_output21, text_outputs1, langchain_mode1 = \ + load_auth(db1s, requests_state1, auth_filename, selection_docs_state1=selection_docs_state1, + chat_state1=chat_state1, langchain_mode1=langchain_mode1, + text_output1=text_output1, text_output21=text_output21, text_outputs1=text_outputs1, + username_override=username1, password_to_check=password1) + else: + success1 = False + text_result = "Wrong password for user %s" % username1 + df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=dbs) + if success1: + requests_state1['username'] = username1 + label_instruction1 = 'Ask anything, %s' % requests_state1['username'] + return db1s, selection_docs_state1, requests_state1, chat_state1, \ + text_result, \ + gr.update(label=label_instruction1), \ + df_langchain_mode_paths1, \ + gr.update(choices=list(chat_state1.keys()), value=None), \ + gr.update(choices=get_langchain_choices(selection_docs_state1), + value=langchain_mode1), \ + text_output1, text_output21, *tuple(text_outputs1) + + login_func = functools.partial(login, + auth_filename=kwargs['auth_filename'], + num_model_lock=len(text_outputs), + pre_authorized=False, + ) + load_login_func = functools.partial(login, + auth_filename=kwargs['auth_filename'], + num_model_lock=len(text_outputs), + pre_authorized=True, + ) + login_inputs = [my_db_state, selection_docs_state, requests_state, chat_state, + langchain_mode, + username_text, password_text, + text_output, text_output2] + text_outputs + login_outputs = [my_db_state, selection_docs_state, requests_state, chat_state, + login_result_text, + instruction, + langchain_mode_path_text, + radio_chats, + langchain_mode, + text_output, text_output2] + text_outputs + eventdb_logina.then(login_func, + inputs=login_inputs, + outputs=login_outputs, + queue=False) + + admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \ + .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False) + + def load_auth(db1s, requests_state1, auth_filename=None, selection_docs_state1=None, + chat_state1=None, langchain_mode1=None, + text_output1=None, text_output21=None, text_outputs1=None, + username_override=None, password_to_check=None): + # in-place assignment + if not auth_filename: + return False, "No auth file", text_output1, text_output21, text_outputs1 + # if first time here, need to set userID + set_userid_gr(db1s, requests_state1, get_userid_auth) + if username_override: + username1 = username_override + else: + username1 = get_username(requests_state1) + success1 = False + with filelock.FileLock(auth_filename + '.lock'): + if os.path.isfile(auth_filename): + with open(auth_filename, 'rt') as f: + auth_dict = json.load(f) + if username1 in auth_dict: + auth_user = auth_dict[username1] + if password_to_check: + if auth_user['password'] != password_to_check: + return False, [], [], [], "Invalid password for user %s" % username1 + if username_override: + # then use original user id + set_userid_direct_gr(db1s, auth_dict[username1]['userid'], username1) + if 'selection_docs_state' in auth_user: + update_auth_selection(auth_user, selection_docs_state1) + if 'chat_state' in auth_user: + chat_state1.update(auth_user['chat_state']) + if 'text_output' in auth_user: + text_output1 = auth_user['text_output'] + if 'text_output2' in auth_user: + text_output21 = auth_user['text_output2'] + if 'text_outputs' in auth_user: + text_outputs1 = auth_user['text_outputs'] + if 'langchain_mode' in auth_user: + langchain_mode1 = auth_user['langchain_mode'] + text_result = "Successful login for %s" % username1 + success1 = True + else: + text_result = "No user %s" % username1 + else: + text_result = "No auth file" + return success1, text_result, text_output1, text_output21, text_outputs1, langchain_mode1 + + def save_auth_dict(auth_dict, auth_filename): + backup_file = auth_filename + '.bak' + str(uuid.uuid4()) + if os.path.isfile(auth_filename): + shutil.copy(auth_filename, backup_file) + try: + with open(auth_filename, 'wt') as f: + f.write(json.dumps(auth_dict, indent=2)) + except BaseException as e: + print("Failure to save auth %s, restored backup: %s: %s" % (auth_filename, backup_file, str(e)), + flush=True) + shutil.copy(backup_file, auth_dict) + if os.getenv('HARD_ASSERTS'): + # unexpected in testing or normally + raise + + def save_auth(requests_state1, auth_filename, auth_freeze, + selection_docs_state1=None, chat_state1=None, langchain_mode1=None, + text_output1=None, text_output21=None, text_outputs1=None): + if auth_freeze: + return + if not auth_filename: + return + # save to auth file + username1 = get_username(requests_state1) + with filelock.FileLock(auth_filename + '.lock'): + if os.path.isfile(auth_filename): + with open(auth_filename, 'rt') as f: + auth_dict = json.load(f) + if username1 in auth_dict: + auth_user = auth_dict[username1] + if selection_docs_state1: + update_auth_selection(auth_user, selection_docs_state1, save=True) + if chat_state1: + # overwrite + auth_user['chat_state'] = chat_state1 + if text_output1: + auth_user['text_output'] = text_output1 + if text_output21: + auth_user['text_output2'] = text_output21 + if text_outputs1: + auth_user['text_outputs'] = text_outputs1 + if langchain_mode1: + auth_user['langchain_mode'] = langchain_mode1 + save_auth_dict(auth_dict, auth_filename) + + def add_langchain_mode(db1s, selection_docs_state1, requests_state1, langchain_mode1, y, + auth_filename=None, auth_freeze=None, guest_name=None): + assert auth_filename is not None + assert auth_freeze is not None + + set_userid_gr(db1s, requests_state1, get_userid_auth) + username1 = get_username(requests_state1) + for k in db1s: + set_dbid_gr(db1s[k]) + langchain_modes = selection_docs_state1['langchain_modes'] + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + langchain_mode_types = selection_docs_state1['langchain_mode_types'] + + user_path = None + valid = True + y2 = y.strip().replace(' ', '').split(',') + if len(y2) >= 1: + langchain_mode2 = y2[0] + if len(langchain_mode2) >= 3 and langchain_mode2.isalnum(): + # real restriction is: + # ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me + # but just make simpler + # assume personal if don't have user_path + langchain_mode_type = y2[1] if len(y2) > 1 else LangChainTypes.PERSONAL.value + user_path = y2[2] if len(y2) > 2 else None # assume None if don't have user_path + if user_path in ['', "''"]: + # transcribe UI input + user_path = None + if langchain_mode_type not in [x.value for x in list(LangChainTypes)]: + textbox = "Invalid type %s" % langchain_mode_type + valid = False + langchain_mode2 = langchain_mode1 + elif langchain_mode_type == LangChainTypes.SHARED.value and username1 == guest_name: + textbox = "Guests cannot add shared collections" + valid = False + langchain_mode2 = langchain_mode1 + elif user_path is not None and langchain_mode_type == LangChainTypes.PERSONAL.value: + textbox = "Do not pass user_path for personal/scratch types" + valid = False + langchain_mode2 = langchain_mode1 + elif user_path is not None and username1 == guest_name: + textbox = "Guests cannot add collections with path" + valid = False + langchain_mode2 = langchain_mode1 + elif langchain_mode2 in langchain_modes_intrinsic: + user_path = None + textbox = "Invalid access to use internal name: %s" % langchain_mode2 + valid = False + langchain_mode2 = langchain_mode1 + elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data: + if user_path: + user_path = makedirs(user_path, exist_ok=True, use_base=True) + langchain_mode_paths.update({langchain_mode2: user_path}) + langchain_mode_types.update({langchain_mode2: langchain_mode_type}) + if langchain_mode2 not in langchain_modes: + langchain_modes.append(langchain_mode2) + textbox = '' + else: + valid = False + langchain_mode2 = langchain_mode1 + textbox = "Invalid access. user allowed: %s " \ + "personal/scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data) + else: + valid = False + langchain_mode2 = langchain_mode1 + textbox = "Invalid, collection must be >=3 characters and alphanumeric" + else: + valid = False + langchain_mode2 = langchain_mode1 + textbox = "Invalid, must be like UserData2, user_path2" + selection_docs_state1 = update_langchain_mode_paths(selection_docs_state1) + df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=dbs) + choices = get_langchain_choices(selection_docs_state1) + + if valid and not user_path: + # needs to have key for it to make it known different from userdata case in _update_user_db() + from src.gpt_langchain import length_db1 + db1s[langchain_mode2] = [None] * length_db1() + if valid: + save_auth(requests_state1, auth_filename, auth_freeze, selection_docs_state1=selection_docs_state1, + langchain_mode1=langchain_mode2) + + return db1s, selection_docs_state1, gr.update(choices=choices, + value=langchain_mode2), textbox, df_langchain_mode_paths1 + + def remove_langchain_mode(db1s, selection_docs_state1, requests_state1, + langchain_mode1, langchain_mode2, dbsu=None, auth_filename=None, auth_freeze=None, + guest_name=None, + purge=False): + assert auth_filename is not None + assert auth_freeze is not None + + set_userid_gr(db1s, requests_state1, get_userid_auth) + for k in db1s: + set_dbid_gr(db1s[k]) + assert dbsu is not None + langchain_modes = selection_docs_state1['langchain_modes'] + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + langchain_mode_types = selection_docs_state1['langchain_mode_types'] + langchain_type2 = langchain_mode_types.get(langchain_mode2, LangChainTypes.EITHER.value) + + changed_state = False + textbox = "Invalid access, cannot remove %s" % langchain_mode2 + in_scratch_db = langchain_mode2 in db1s + in_user_db = dbsu is not None and langchain_mode2 in dbsu + if in_scratch_db and not allow_upload_to_my_data or \ + in_user_db and not allow_upload_to_user_data or \ + langchain_mode2 in langchain_modes_intrinsic: + can_remove = False + can_purge = False + if langchain_mode2 in langchain_modes_intrinsic: + can_purge = True + else: + can_remove = True + can_purge = True + + # change global variables + if langchain_mode2 in langchain_modes or langchain_mode2 in langchain_mode_paths or langchain_mode2 in db1s: + if can_purge and purge: + # remove source files + from src.gpt_langchain import get_sources, del_from_db + sources_file, source_list, num_chunks, db = \ + get_sources(db1s, selection_docs_state1, + requests_state1, langchain_mode2, dbs=dbsu, + docs_state0=docs_state0, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + n_jobs=n_jobs) + del_from_db(db, source_list, db_type=db_type) + for fil in source_list: + if os.path.isfile(fil): + print("Purged %s" % fil, flush=True) + remove(fil) + # remove db directory + from src.gpt_langchain import get_persist_directory + persist_directory, langchain_type2 = \ + get_persist_directory(langchain_mode2, langchain_type=langchain_type2, + db1s=db1s, dbs=dbsu) + print("removed persist_directory %s" % persist_directory, flush=True) + remove(persist_directory) + textbox = "Purged, but did not remove %s" % langchain_mode2 + if can_remove: + if langchain_mode2 in langchain_modes: + langchain_modes.remove(langchain_mode2) + if langchain_mode2 in langchain_mode_paths: + langchain_mode_paths.pop(langchain_mode2) + if langchain_mode2 in langchain_mode_types: + langchain_mode_types.pop(langchain_mode2) + if langchain_mode2 in db1s and langchain_mode2 != LangChainMode.MY_DATA.value: + # don't remove last MyData, used as user hash + db1s.pop(langchain_mode2) + textbox = "" + changed_state = True + else: + textbox = "%s is not visible" % langchain_mode2 + + # update + selection_docs_state1 = update_langchain_mode_paths(selection_docs_state1) + df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=dbs) + + if changed_state: + save_auth(requests_state1, auth_filename, auth_freeze, selection_docs_state1=selection_docs_state1, + langchain_mode1=langchain_mode2) + + return db1s, selection_docs_state1, \ + gr.update(choices=get_langchain_choices(selection_docs_state1), + value=langchain_mode2), textbox, df_langchain_mode_paths1 + + eventdb20a = new_langchain_mode_text.submit(user_state_setup, + inputs=[my_db_state, requests_state, + new_langchain_mode_text, new_langchain_mode_text], + outputs=[my_db_state, requests_state, new_langchain_mode_text], + show_progress='minimal') + add_langchain_mode_func = functools.partial(add_langchain_mode, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + guest_name=kwargs['guest_name'], + ) + eventdb20b = eventdb20a.then(fn=add_langchain_mode_func, + inputs=[my_db_state, selection_docs_state, requests_state, + langchain_mode, + new_langchain_mode_text], + outputs=[my_db_state, selection_docs_state, langchain_mode, + new_langchain_mode_text, + langchain_mode_path_text], + api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None) + db_events.extend([eventdb20a, eventdb20b]) + + remove_langchain_mode_func = functools.partial(remove_langchain_mode, + dbsu=dbs, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + guest_name=kwargs['guest_name'], + ) + eventdb21a = remove_langchain_mode_text.submit(user_state_setup, + inputs=[my_db_state, + requests_state, + remove_langchain_mode_text, remove_langchain_mode_text], + outputs=[my_db_state, + requests_state, remove_langchain_mode_text], + show_progress='minimal') + remove_langchain_mode_kwargs = dict(fn=remove_langchain_mode_func, + inputs=[my_db_state, selection_docs_state, requests_state, + langchain_mode, + remove_langchain_mode_text], + outputs=[my_db_state, selection_docs_state, langchain_mode, + remove_langchain_mode_text, + langchain_mode_path_text]) + eventdb21b = eventdb21a.then(**remove_langchain_mode_kwargs, + api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None) + db_events.extend([eventdb21a, eventdb21b]) + + eventdb22a = purge_langchain_mode_text.submit(user_state_setup, + inputs=[my_db_state, + requests_state, + purge_langchain_mode_text, purge_langchain_mode_text], + outputs=[my_db_state, + requests_state, purge_langchain_mode_text], + show_progress='minimal') + purge_langchain_mode_func = functools.partial(remove_langchain_mode_func, purge=True) + purge_langchain_mode_kwargs = dict(fn=purge_langchain_mode_func, + inputs=[my_db_state, selection_docs_state, requests_state, + langchain_mode, + purge_langchain_mode_text], + outputs=[my_db_state, selection_docs_state, langchain_mode, + purge_langchain_mode_text, + langchain_mode_path_text]) + # purge_langchain_mode_kwargs = remove_langchain_mode_kwargs.copy() + # purge_langchain_mode_kwargs['fn'] = functools.partial(remove_langchain_mode_kwargs['fn'], purge=True) + eventdb22b = eventdb22a.then(**purge_langchain_mode_kwargs, + api_name='purge_langchain_mode_text' if allow_api and allow_upload_to_user_data else None) + db_events.extend([eventdb22a, eventdb22b]) + + def load_langchain_gr(db1s, selection_docs_state1, requests_state1, langchain_mode1, auth_filename=None): + load_auth(db1s, requests_state1, auth_filename, selection_docs_state1=selection_docs_state1) + + selection_docs_state1 = update_langchain_mode_paths(selection_docs_state1) + df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=dbs) + return selection_docs_state1, \ + gr.update(choices=get_langchain_choices(selection_docs_state1), + value=langchain_mode1), df_langchain_mode_paths1 + + eventdbloadla = load_langchain.click(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + load_langchain_gr_func = functools.partial(load_langchain_gr, + auth_filename=kwargs['auth_filename']) + eventdbloadlb = eventdbloadla.then(fn=load_langchain_gr_func, + inputs=[my_db_state, selection_docs_state, requests_state, langchain_mode], + outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text], + api_name='load_langchain' if allow_api and allow_upload_to_user_data else None) + + if not kwargs['large_file_count_mode']: + # FIXME: Could add all these functions, inputs, outputs into single function for snappier GUI + # all update events when not doing large file count mode + # Note: Login touches langchain_mode, which triggers all these + lg_change_event2 = lg_change_event.then(**get_sources_kwargs) + lg_change_event3 = lg_change_event2.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + lg_change_event4 = lg_change_event3.then(**show_sources_kwargs) + lg_change_event5 = lg_change_event4.then(**get_viewable_sources_args) + lg_change_event6 = lg_change_event5.then(**viewable_kwargs) + + eventdb2c = eventdb2.then(**get_sources_kwargs) + eventdb2d = eventdb2c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb2e = eventdb2d.then(**show_sources_kwargs) + eventdb2f = eventdb2e.then(**get_viewable_sources_args) + eventdb2g = eventdb2f.then(**viewable_kwargs) + + eventdb1c = eventdb1.then(**get_sources_kwargs) + eventdb1d = eventdb1c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb1e = eventdb1d.then(**show_sources_kwargs) + eventdb1f = eventdb1e.then(**get_viewable_sources_args) + eventdb1g = eventdb1f.then(**viewable_kwargs) + + eventdb3c = eventdb3.then(**get_sources_kwargs) + eventdb3d = eventdb3c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb3e = eventdb3d.then(**show_sources_kwargs) + eventdb3f = eventdb3e.then(**get_viewable_sources_args) + eventdb3g = eventdb3f.then(**viewable_kwargs) + + eventdb90ua = eventdb90.then(**get_sources_kwargs) + eventdb90ub = eventdb90ua.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb90uc = eventdb90ub.then(**show_sources_kwargs) + eventdb90ud = eventdb90uc.then(**get_viewable_sources_args) + eventdb90ue = eventdb90ud.then(**viewable_kwargs) + + eventdb20c = eventdb20b.then(**get_sources_kwargs) + eventdb20d = eventdb20c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb20e = eventdb20d.then(**show_sources_kwargs) + eventdb20f = eventdb20e.then(**get_viewable_sources_args) + eventdb20g = eventdb20f.then(**viewable_kwargs) + + eventdb21c = eventdb21b.then(**get_sources_kwargs) + eventdb21d = eventdb21c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb21e = eventdb21d.then(**show_sources_kwargs) + eventdb21f = eventdb21e.then(**get_viewable_sources_args) + eventdb21g = eventdb21f.then(**viewable_kwargs) + + eventdb22c = eventdb22b.then(**get_sources_kwargs) + eventdb22d = eventdb22c.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb22e = eventdb22d.then(**show_sources_kwargs) + eventdb22f = eventdb22e.then(**get_viewable_sources_args) + eventdb22g = eventdb22f.then(**viewable_kwargs) + + event_attach3 = event_attach2.then(**get_sources_kwargs) + event_attach4 = event_attach3.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + event_attach5 = event_attach4.then(**show_sources_kwargs) + event_attach6 = event_attach5.then(**get_viewable_sources_args) + event_attach7 = event_attach6.then(**viewable_kwargs) + + sync2 = sync1.then(**get_sources_kwargs) + sync3 = sync2.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + sync4 = sync3.then(**show_sources_kwargs) + sync5 = sync4.then(**get_viewable_sources_args) + sync6 = sync5.then(**viewable_kwargs) + + eventdb_loginb = eventdb_logina.then(**get_sources_kwargs) + eventdb_loginc = eventdb_loginb.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + eventdb_logind = eventdb_loginc.then(**show_sources_kwargs) + eventdb_logine = eventdb_logind.then(**get_viewable_sources_args) + eventdb_loginf = eventdb_logine.then(**viewable_kwargs) + + db_events.extend([lg_change_event, lg_change_event2, lg_change_event3, lg_change_event4, lg_change_event5, + lg_change_event6] + + [eventdb2c, eventdb2d, eventdb2e, eventdb2f, eventdb2g] + + [eventdb1c, eventdb1d, eventdb1e, eventdb1f, eventdb1g] + + [eventdb3c, eventdb3d, eventdb3e, eventdb3f, eventdb3g] + + [eventdb90ua, eventdb90ub, eventdb90uc, eventdb90ud, eventdb90ue] + + [eventdb20c, eventdb20d, eventdb20e, eventdb20f, eventdb20g] + + [eventdb21c, eventdb21d, eventdb21e, eventdb21f, eventdb21g] + + [eventdb22c, eventdb22d, eventdb22e, eventdb22f, eventdb22g] + + [event_attach3, event_attach4, event_attach5, event_attach6, event_attach7] + + [sync1, sync2, sync3, sync4, sync5, sync6] + + [eventdb_logina, eventdb_loginb, eventdb_loginc, eventdb_logind, eventdb_logine, + eventdb_loginf] + , + ) + + inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1) + inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2) + from functools import partial + kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list} + # ensure present + for k in inputs_kwargs_list: + assert k in kwargs_evaluate, "Missing %s" % k + + def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, plain_api=False, **kwargs1): + args_list = list(args1) + if str_api: + if plain_api: + # i.e. not fresh model, tells evaluate to use model_state0 + args_list.insert(0, kwargs['model_state_none'].copy()) + args_list.insert(1, my_db_state0.copy()) + args_list.insert(2, selection_docs_state0.copy()) + args_list.insert(3, requests_state0.copy()) + user_kwargs = args_list[len(input_args_list)] + assert isinstance(user_kwargs, str) + user_kwargs = ast.literal_eval(user_kwargs) + else: + assert not plain_api + user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])} + # control kwargs1 for evaluate + kwargs1['answer_with_sources'] = -1 # just text chunk, not URL etc. + kwargs1['show_accordions'] = False + kwargs1['append_sources_to_answer'] = False + kwargs1['show_link_in_sources'] = False + kwargs1['top_k_docs_max_show'] = 30 + + # only used for submit_nochat_api + user_kwargs['chat'] = False + if 'stream_output' not in user_kwargs: + user_kwargs['stream_output'] = False + if plain_api: + user_kwargs['stream_output'] = False + if 'langchain_mode' not in user_kwargs: + # if user doesn't specify, then assume disabled, not use default + if LangChainMode.LLM.value in kwargs['langchain_modes']: + user_kwargs['langchain_mode'] = LangChainMode.LLM.value + elif len(kwargs['langchain_modes']) >= 1: + user_kwargs['langchain_mode'] = kwargs['langchain_modes'][0] + else: + # disabled should always be allowed + user_kwargs['langchain_mode'] = LangChainMode.DISABLED.value + if 'langchain_action' not in user_kwargs: + user_kwargs['langchain_action'] = LangChainAction.QUERY.value + if 'langchain_agents' not in user_kwargs: + user_kwargs['langchain_agents'] = [] + # be flexible + if 'instruction' in user_kwargs and 'instruction_nochat' not in user_kwargs: + user_kwargs['instruction_nochat'] = user_kwargs['instruction'] + if 'iinput' in user_kwargs and 'iinput_nochat' not in user_kwargs: + user_kwargs['iinput_nochat'] = user_kwargs['iinput'] + if 'visible_models' not in user_kwargs: + if kwargs['visible_models']: + if isinstance(kwargs['visible_models'], int): + user_kwargs['visible_models'] = [kwargs['visible_models']] + elif isinstance(kwargs['visible_models'], list): + # only take first one + user_kwargs['visible_models'] = [kwargs['visible_models'][0]] + else: + user_kwargs['visible_models'] = [0] + else: + # if no user version or default version, then just take first + user_kwargs['visible_models'] = [0] + + if 'h2ogpt_key' not in user_kwargs: + user_kwargs['h2ogpt_key'] = None + if 'system_prompt' in user_kwargs and user_kwargs['system_prompt'] is None: + # avoid worrying about below default_kwargs -> args_list that checks if None + user_kwargs['system_prompt'] = 'None' + + set1 = set(list(default_kwargs1.keys())) + set2 = set(eval_func_param_names) + assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2)) + # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get() + model_state1 = args_list[0] + my_db_state1 = args_list[1] + selection_docs_state1 = args_list[2] + requests_state1 = args_list[3] + args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k + in eval_func_param_names] + assert len(args_list) == len(eval_func_param_names) + stream_output1 = args_list[eval_func_param_names.index('stream_output')] + if len(model_states) > 1: + visible_models1 = args_list[eval_func_param_names.index('visible_models')] + model_active_choice1 = visible_models_to_model_choice(visible_models1) + model_state1 = model_states[model_active_choice1 % len(model_states)] + for key in key_overrides: + if user_kwargs.get(key) is None and model_state1.get(key) is not None: + args_list[eval_func_param_names.index(key)] = model_state1[key] + if hasattr(model_state1['tokenizer'], 'model_max_length'): + # ensure listen to limit, with some buffer + # buffer = 50 + buffer = 0 + args_list[eval_func_param_names.index('max_new_tokens')] = min( + args_list[eval_func_param_names.index('max_new_tokens')], + model_state1['tokenizer'].model_max_length - buffer) + + # override overall visible_models and h2ogpt_key if have model_specific one + # NOTE: only applicable if len(model_states) > 1 at moment + # else controlled by evaluate() + if 'visible_models' in model_state1 and model_state1['visible_models'] is not None: + assert isinstance(model_state1['visible_models'], int) + args_list[eval_func_param_names.index('visible_models')] = model_state1['visible_models'] + if 'h2ogpt_key' in model_state1 and model_state1['h2ogpt_key'] is not None: + # remote server key if present + # i.e. may be '' and used to override overall local key + assert isinstance(model_state1['h2ogpt_key'], str) + args_list[eval_func_param_names.index('h2ogpt_key')] = model_state1['h2ogpt_key'] + + # local key, not for remote server unless same, will be passed through + h2ogpt_key1 = args_list[eval_func_param_names.index('h2ogpt_key')] + + # final full evaluate args list + args_list = [model_state1, my_db_state1, selection_docs_state1, requests_state1] + args_list + + # NOTE: Don't allow UI-like access, in case modify state via API + valid_key = is_valid_key(kwargs['enforce_h2ogpt_api_key'], kwargs['h2ogpt_api_keys'], h2ogpt_key1, + requests_state1=None) + evaluate_local = evaluate if valid_key else evaluate_fake + + save_dict = dict() + ret = {} + try: + for res_dict in evaluate_local(*tuple(args_list), **kwargs1): + error = res_dict.get('error', '') + extra = res_dict.get('extra', '') + save_dict = res_dict.get('save_dict', {}) + + # update save_dict + save_dict['error'] = error + save_dict['extra'] = extra + save_dict['valid_key'] = valid_key + save_dict['h2ogpt_key'] = h2ogpt_key1 + if str_api and plain_api: + save_dict['which_api'] = 'str_plain_api' + elif str_api: + save_dict['which_api'] = 'str_api' + elif plain_api: + save_dict['which_api'] = 'plain_api' + else: + save_dict['which_api'] = 'nochat_api' + if 'extra_dict' not in save_dict: + save_dict['extra_dict'] = {} + if requests_state1: + save_dict['extra_dict'].update(requests_state1) + else: + save_dict['extra_dict'].update(dict(username='NO_REQUEST')) + + if is_public: + # don't want to share actual endpoints + if 'save_dict' in res_dict and isinstance(res_dict['save_dict'], dict): + res_dict['save_dict'].pop('inference_server', None) + if 'extra_dict' in res_dict['save_dict'] and isinstance(res_dict['save_dict']['extra_dict'], + dict): + res_dict['save_dict']['extra_dict'].pop('inference_server', None) + + # get response + if str_api: + # full return of dict + ret = res_dict + elif kwargs['langchain_mode'] == 'Disabled': + ret = fix_text_for_gradio(res_dict['response']) + else: + ret = '
' + fix_text_for_gradio(res_dict['response']) + if stream_output1: + # yield as it goes, else need to wait since predict only returns first yield + yield ret + finally: + clear_torch_cache() + clear_embeddings(user_kwargs['langchain_mode'], my_db_state1) + save_generate_output(**save_dict) + if not stream_output1: + # return back last ret + yield ret + + kwargs_evaluate_nochat = kwargs_evaluate.copy() + # nominally never want sources appended for API calls, which is what nochat used for primarily + kwargs_evaluate_nochat.update(dict(append_sources_to_answer=False)) + fun = partial(evaluate_nochat, + default_kwargs1=default_kwargs, + str_api=False, + **kwargs_evaluate_nochat) + fun_with_dict_str = partial(evaluate_nochat, + default_kwargs1=default_kwargs, + str_api=True, + **kwargs_evaluate_nochat + ) + + fun_with_dict_str_plain = partial(evaluate_nochat, + default_kwargs1=default_kwargs, + str_api=True, + plain_api=True, + **kwargs_evaluate_nochat + ) + + dark_mode_btn.click( + None, + None, + None, + _js=wrap_js_to_lambda(0, get_dark_js()), + api_name="dark" if allow_api else None, + queue=False, + ) + + # Handle uploads from API + upload_api_btn = gr.UploadButton("Upload File Results", visible=False) + file_upload_api = gr.File(visible=False) + file_upload_text = gr.Textbox(visible=False) + + def upload_file(files): + if isinstance(files, list): + file_paths = [file.name for file in files] + else: + file_paths = files.name + return file_paths, file_paths + + upload_api_btn.upload(fn=upload_file, + inputs=upload_api_btn, + outputs=[file_upload_api, file_upload_text], + api_name='upload_api' if allow_upload_api else None) + + def visible_toggle(x): + x = 'off' if x == 'on' else 'on' + return x, gr.Column.update(visible=True if x == 'on' else False) + + side_bar_btn.click(fn=visible_toggle, + inputs=side_bar_text, + outputs=[side_bar_text, side_bar], + queue=False) + + doc_count_btn.click(fn=visible_toggle, + inputs=doc_count_text, + outputs=[doc_count_text, row_doc_track], + queue=False) + + submit_buttons_btn.click(fn=visible_toggle, + inputs=submit_buttons_text, + outputs=[submit_buttons_text, submit_buttons], + queue=False) + + visible_model_btn.click(fn=visible_toggle, + inputs=visible_models_text, + outputs=[visible_models_text, visible_models], + queue=False) + + # examples after submit or any other buttons for chat or no chat + if kwargs['examples'] is not None and kwargs['show_examples']: + gr.Examples(examples=kwargs['examples'], inputs=inputs_list) + + # Score + def score_last_response(*args, nochat=False, num_model_lock=0): + try: + if num_model_lock > 0: + # then lock way + args_list = list(args).copy() + outputs = args_list[-num_model_lock:] + score_texts1 = [] + for output in outputs: + # same input, put into form good for _score_last_response() + args_list[-1] = output + score_texts1.append( + _score_last_response(*tuple(args_list), nochat=nochat, + num_model_lock=num_model_lock, prefix='')) + if len(score_texts1) > 1: + return "Response Scores: %s" % ' '.join(score_texts1) + else: + return "Response Scores: %s" % score_texts1[0] + else: + return _score_last_response(*args, nochat=nochat, num_model_lock=num_model_lock) + finally: + clear_torch_cache() + + def _score_last_response(*args, nochat=False, num_model_lock=0, prefix='Response Score: '): + """ Similar to user() """ + args_list = list(args) + smodel = score_model_state0['model'] + stokenizer = score_model_state0['tokenizer'] + sdevice = score_model_state0['device'] + + if memory_restriction_level > 0: + max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256 + elif hasattr(stokenizer, 'model_max_length'): + max_length_tokenize = stokenizer.model_max_length + else: + # limit to 1024, not worth OOMing on reward score + max_length_tokenize = 2048 - 1024 + cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM + + if not nochat: + history = args_list[-1] + if history is None: + history = [] + if smodel is not None and \ + stokenizer is not None and \ + sdevice is not None and \ + history is not None and len(history) > 0 and \ + history[-1] is not None and \ + len(history[-1]) >= 2: + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + question = history[-1][0] + + answer = history[-1][1] + else: + return '%sNA' % prefix + else: + answer = args_list[-1] + instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat') + question = args_list[instruction_nochat_arg_id] + + if question is None: + return '%sBad Question' % prefix + if answer is None: + return '%sBad Answer' % prefix + try: + score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len) + finally: + clear_torch_cache() + if isinstance(score, str): + return '%sNA' % prefix + return '{}{:.1%}'.format(prefix, score) + + def noop_score_last_response(*args, **kwargs): + return "Response Score: Disabled" + + if kwargs['score_model']: + score_fun = score_last_response + else: + score_fun = noop_score_last_response + + score_args = dict(fn=score_fun, + inputs=inputs_list + [text_output], + outputs=[score_text], + ) + score_args2 = dict(fn=partial(score_fun), + inputs=inputs_list2 + [text_output2], + outputs=[score_text2], + ) + score_fun_func = functools.partial(score_fun, num_model_lock=len(text_outputs)) + all_score_args = dict(fn=score_fun_func, + inputs=inputs_list + text_outputs, + outputs=score_text, + ) + + score_args_nochat = dict(fn=partial(score_fun, nochat=True), + inputs=inputs_list + [text_output_nochat], + outputs=[score_text_nochat], + ) + + def update_history(*args, undo=False, retry=False, sanitize_user_prompt=False): + """ + User that fills history for bot + :param args: + :param undo: + :param retry: + :param sanitize_user_prompt: + :return: + """ + args_list = list(args) + user_message = args_list[eval_func_param_names.index('instruction')] # chat only + input1 = args_list[eval_func_param_names.index('iinput')] # chat only + prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] + langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] + langchain_action1 = args_list[eval_func_param_names.index('langchain_action')] + langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')] + document_subset1 = args_list[eval_func_param_names.index('document_subset')] + document_choice1 = args_list[eval_func_param_names.index('document_choice')] + if not prompt_type1: + # shouldn't have to specify if CLI launched model + prompt_type1 = kwargs['prompt_type'] + # apply back + args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 + if input1 and not user_message.endswith(':'): + user_message1 = user_message + ":" + input1 + elif input1: + user_message1 = user_message + input1 + else: + user_message1 = user_message + if sanitize_user_prompt: + pass + # requirements.txt has comment that need to re-enable the below 2 lines + # from better_profanity import profanity + # user_message1 = profanity.censor(user_message1) + + history = args_list[-1] + if history is None: + # bad history + history = [] + history = history.copy() + + if undo: + if len(history) > 0: + history.pop() + return history + if retry: + if history: + history[-1][1] = None + return history + if user_message1 in ['', None, '\n']: + if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): + # reject non-retry submit/enter + return history + user_message1 = fix_text_for_gradio(user_message1) + return history + [[user_message1, None]] + + def user(*args, undo=False, retry=False, sanitize_user_prompt=False): + return update_history(*args, undo=undo, retry=retry, sanitize_user_prompt=sanitize_user_prompt) + + def all_user(*args, undo=False, retry=False, sanitize_user_prompt=False, num_model_lock=0, all_models=None): + args_list = list(args) + + visible_models1 = args_list[eval_func_param_names.index('visible_models')] + assert isinstance(all_models, list) + visible_list = get_model_lock_visible_list(visible_models1, all_models) + + history_list = args_list[-num_model_lock:] + assert len(all_models) == len(history_list) + assert len(history_list) > 0, "Bad history list: %s" % history_list + for hi, history in enumerate(history_list): + if not visible_list[hi]: + continue + if num_model_lock > 0: + hargs = args_list[:-num_model_lock].copy() + else: + hargs = args_list.copy() + hargs += [history] + history_list[hi] = update_history(*hargs, undo=undo, retry=retry, + sanitize_user_prompt=sanitize_user_prompt) + if len(history_list) > 1: + return tuple(history_list) + else: + return history_list[0] + + def get_model_max_length(model_state1): + if model_state1 and not isinstance(model_state1["tokenizer"], str): + tokenizer = model_state1["tokenizer"] + elif model_state0 and not isinstance(model_state0["tokenizer"], str): + tokenizer = model_state0["tokenizer"] + else: + tokenizer = None + if tokenizer is not None: + return tokenizer.model_max_length + else: + return 2000 + + def prep_bot(*args, retry=False, which_model=0): + """ + + :param args: + :param retry: + :param which_model: identifies which model if doing model_lock + API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list + :return: last element is True if should run bot, False if should just yield history + """ + isize = len(input_args_list) + 1 # states + chat history + # don't deepcopy, can contain model itself + args_list = list(args).copy() + model_state1 = args_list[-isize] + my_db_state1 = args_list[-isize + 1] + selection_docs_state1 = args_list[-isize + 2] + requests_state1 = args_list[-isize + 3] + history = args_list[-1] + if not history: + history = [] + prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] + prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')] + langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] + langchain_action1 = args_list[eval_func_param_names.index('langchain_action')] + document_subset1 = args_list[eval_func_param_names.index('document_subset')] + h2ogpt_key1 = args_list[eval_func_param_names.index('h2ogpt_key')] + chat_conversation1 = args_list[eval_func_param_names.index('chat_conversation')] + valid_key = is_valid_key(kwargs['enforce_h2ogpt_api_key'], kwargs['h2ogpt_api_keys'], h2ogpt_key1, + requests_state1=requests_state1) + + dummy_return = history, None, langchain_mode1, my_db_state1, requests_state1, valid_key, h2ogpt_key1 + + if model_state1['model'] is None or model_state1['model'] == no_model_str: + return dummy_return + + args_list = args_list[:-isize] # only keep rest needed for evaluate() + if not history: + print("No history", flush=True) + return dummy_return + instruction1 = history[-1][0] + if retry and history: + # if retry, pop history and move onto bot stuff + instruction1 = history[-1][0] + history[-1][1] = None + elif not instruction1: + if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): + # if not retrying, then reject empty query + return dummy_return + elif len(history) > 0 and history[-1][1] not in [None, '']: + # reject submit button if already filled and not retrying + # None when not filling with '' to keep client happy + return dummy_return + + evaluate_local = evaluate if valid_key else evaluate_fake + + # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it + prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1, + which_model=which_model) + # apply back to args_list for evaluate() + args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 + args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 + context1 = args_list[eval_func_param_names.index('context')] + + chat_conversation1 = merge_chat_conversation_history(chat_conversation1, history) + args_list[eval_func_param_names.index('chat_conversation')] = chat_conversation1 + + if 'visible_models' in model_state1 and model_state1['visible_models'] is not None: + assert isinstance(model_state1['visible_models'], int) + args_list[eval_func_param_names.index('visible_models')] = model_state1['visible_models'] + if 'h2ogpt_key' in model_state1 and model_state1['h2ogpt_key'] is not None: + # i.e. may be '' and used to override overall local key + assert isinstance(model_state1['h2ogpt_key'], str) + args_list[eval_func_param_names.index('h2ogpt_key')] = model_state1['h2ogpt_key'] + + args_list[0] = instruction1 # override original instruction with history from user + args_list[2] = context1 + + fun1 = partial(evaluate_local, + model_state1, + my_db_state1, + selection_docs_state1, + requests_state1, + *tuple(args_list), + **kwargs_evaluate) + + return history, fun1, langchain_mode1, my_db_state1, requests_state1, valid_key, h2ogpt_key1 + + def gen1_fake(fun1, history): + error = '' + extra = '' + save_dict = dict() + yield history, error, extra, save_dict + return + + def get_response(fun1, history): + """ + bot that consumes history for user input + instruction (from input_list) itself is not consumed by bot + :return: + """ + error = '' + extra = '' + save_dict = dict() + if not fun1: + yield history, error, extra, save_dict + return + try: + for output_fun in fun1(): + output = output_fun['response'] + extra = output_fun['sources'] # FIXME: can show sources in separate text box etc. + save_dict = output_fun.get('save_dict', {}) + # ensure good visually, else markdown ignores multiple \n + bot_message = fix_text_for_gradio(output) + history[-1][1] = bot_message + yield history, error, extra, save_dict + except StopIteration: + yield history, error, extra, save_dict + except RuntimeError as e: + if "generator raised StopIteration" in str(e): + # assume last entry was bad, undo + history.pop() + yield history, error, extra, save_dict + else: + if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None: + history[-1][1] = '' + yield history, str(e), extra, save_dict + raise + except Exception as e: + # put error into user input + ex = "Exception: %s" % str(e) + if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None: + history[-1][1] = '' + yield history, ex, extra, save_dict + raise + finally: + # clear_torch_cache() + # don't clear torch cache here, too early and stalls generation if used for all_bot() + pass + return + + def clear_embeddings(langchain_mode1, db1s): + # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache + if db_type in ['chroma', 'chroma_old'] and langchain_mode1 not in ['LLM', 'Disabled', None, '']: + from gpt_langchain import clear_embedding, length_db1 + db = dbs.get('langchain_mode1') + if db is not None and not isinstance(db, str): + clear_embedding(db) + if db1s is not None and langchain_mode1 in db1s: + db1 = db1s[langchain_mode1] + if len(db1) == length_db1(): + clear_embedding(db1[0]) + + def bot(*args, retry=False): + history, fun1, langchain_mode1, db1, requests_state1, valid_key, h2ogpt_key1 = prep_bot(*args, retry=retry) + save_dict = dict() + error = '' + extra = '' + try: + for res in get_response(fun1, history): + history, error, extra, save_dict = res + # pass back to gradio only these, rest are consumed in this function + yield history, error + finally: + clear_torch_cache() + clear_embeddings(langchain_mode1, db1) + if 'extra_dict' not in save_dict: + save_dict['extra_dict'] = {} + save_dict['valid_key'] = valid_key + save_dict['h2ogpt_key'] = h2ogpt_key1 + if requests_state1: + save_dict['extra_dict'].update(requests_state1) + else: + save_dict['extra_dict'].update(dict(username='NO_REQUEST')) + save_dict['error'] = error + save_dict['extra'] = extra + save_dict['which_api'] = 'bot' + save_generate_output(**save_dict) + + def all_bot(*args, retry=False, model_states1=None, all_models=None): + args_list = list(args).copy() + chatbots = args_list[-len(model_states1):] + args_list0 = args_list[:-len(model_states1)] # same for all models + exceptions = [] + stream_output1 = args_list[eval_func_param_names.index('stream_output')] + max_time1 = args_list[eval_func_param_names.index('max_time')] + langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] + + visible_models1 = args_list[eval_func_param_names.index('visible_models')] + assert isinstance(all_models, list) + assert len(all_models) == len(model_states1) + visible_list = get_model_lock_visible_list(visible_models1, all_models) + + isize = len(input_args_list) + 1 # states + chat history + db1s = None + requests_state1 = None + valid_key = False + h2ogpt_key1 = '' + extras = [] + exceptions = [] + save_dicts = [] + try: + gen_list = [] + for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)): + args_list1 = args_list0.copy() + args_list1.insert(-isize + 2, + model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4 + # if at start, have None in response still, replace with '' so client etc. acts like normal + # assumes other parts of code treat '' and None as if no response yet from bot + # can't do this later in bot code as racy with threaded generators + if len(chatbot1) > 0 and len(chatbot1[-1]) == 2 and chatbot1[-1][1] is None: + chatbot1[-1][1] = '' + args_list1.append(chatbot1) + # so consistent with prep_bot() + # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1 + # langchain_mode1 and my_db_state1 and requests_state1 should be same for every bot + history, fun1, langchain_mode1, db1s, requests_state1, valid_key, h2ogpt_key1, = \ + prep_bot(*tuple(args_list1), retry=retry, + which_model=chatboti) + if visible_list[chatboti]: + gen1 = get_response(fun1, history) + if stream_output1: + gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False) + # else timeout will truncate output for non-streaming case + else: + gen1 = gen1_fake(fun1, history) + gen_list.append(gen1) + + bots_old = chatbots.copy() + exceptions_old = [''] * len(bots_old) + extras_old = [''] * len(bots_old) + save_dicts_old = [{}] * len(bots_old) + tgen0 = time.time() + for res1 in itertools.zip_longest(*gen_list): + if time.time() - tgen0 > max_time1: + print("Took too long: %s" % max_time1, flush=True) + break + + bots = [x[0] if x is not None and not isinstance(x, BaseException) else y + for x, y in zip(res1, bots_old)] + bots_old = bots.copy() + + def larger_str(x, y): + return x if len(x) > len(y) else y + + exceptions = [x[1] if x is not None and not isinstance(x, BaseException) else larger_str(str(x), y) + for x, y in zip(res1, exceptions_old)] + exceptions_old = exceptions.copy() + + extras = [x[2] if x is not None and not isinstance(x, BaseException) else y + for x, y in zip(res1, extras_old)] + extras_old = extras.copy() + + save_dicts = [x[3] if x is not None and not isinstance(x, BaseException) else y + for x, y in zip(res1, save_dicts_old)] + save_dicts_old = save_dicts.copy() + + def choose_exc(x): + # don't expose ports etc. to exceptions window + if is_public: + return "Endpoint unavailable or failed" + else: + return x + + exceptions_str = '\n'.join( + ['Model %s: %s' % (iix, choose_exc(x)) for iix, x in enumerate(exceptions) if + x not in [None, '', 'None']]) + # yield back to gradio only is bots + exceptions, rest are consumed locally + if len(bots) > 1: + yield tuple(bots + [exceptions_str]) + else: + yield bots[0], exceptions_str + if exceptions: + exceptions_reduced = [x for x in exceptions if x not in ['', None, 'None']] + if exceptions_reduced: + print("Generate exceptions: %s" % exceptions_reduced, flush=True) + finally: + clear_torch_cache() + clear_embeddings(langchain_mode1, db1s) + for extra, error, save_dict, model_name in zip(extras, exceptions, save_dicts, all_models): + if 'extra_dict' not in save_dict: + save_dict['extra_dict'] = {} + if requests_state1: + save_dict['extra_dict'].update(requests_state1) + else: + save_dict['extra_dict'].update(dict(username='NO_REQUEST')) + save_dict['error'] = error + save_dict['extra'] = extra + save_dict['which_api'] = 'all_bot_%s' % model_name + save_dict['valid_key'] = valid_key + save_dict['h2ogpt_key'] = h2ogpt_key1 + save_generate_output(**save_dict) + + # NORMAL MODEL + user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), + inputs=inputs_list + [text_output], + outputs=text_output, + ) + bot_args = dict(fn=bot, + inputs=inputs_list + [model_state, my_db_state, selection_docs_state, requests_state] + [ + text_output], + outputs=[text_output, chat_exception_text], + ) + retry_bot_args = dict(fn=functools.partial(bot, retry=True), + inputs=inputs_list + [model_state, my_db_state, selection_docs_state, requests_state] + [ + text_output], + outputs=[text_output, chat_exception_text], + ) + retry_user_args = dict(fn=functools.partial(user, retry=True), + inputs=inputs_list + [text_output], + outputs=text_output, + ) + undo_user_args = dict(fn=functools.partial(user, undo=True), + inputs=inputs_list + [text_output], + outputs=text_output, + ) + + # MODEL2 + user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), + inputs=inputs_list2 + [text_output2], + outputs=text_output2, + ) + bot_args2 = dict(fn=bot, + inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state, requests_state] + [ + text_output2], + outputs=[text_output2, chat_exception_text], + ) + retry_bot_args2 = dict(fn=functools.partial(bot, retry=True), + inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state, + requests_state] + [text_output2], + outputs=[text_output2, chat_exception_text], + ) + retry_user_args2 = dict(fn=functools.partial(user, retry=True), + inputs=inputs_list2 + [text_output2], + outputs=text_output2, + ) + undo_user_args2 = dict(fn=functools.partial(user, undo=True), + inputs=inputs_list2 + [text_output2], + outputs=text_output2, + ) + + # MODEL N + all_user_args = dict(fn=functools.partial(all_user, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + all_models=kwargs['all_models'] + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, + all_models=kwargs['all_models']), + inputs=inputs_list + [my_db_state, selection_docs_state, requests_state] + + text_outputs, + outputs=text_outputs + [chat_exception_text], + ) + all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, + all_models=kwargs['all_models'], + retry=True), + inputs=inputs_list + [my_db_state, selection_docs_state, requests_state] + + text_outputs, + outputs=text_outputs + [chat_exception_text], + ) + all_retry_user_args = dict(fn=functools.partial(all_user, retry=True, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + all_models=kwargs['all_models'] + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + all_undo_user_args = dict(fn=functools.partial(all_user, undo=True, + sanitize_user_prompt=kwargs['sanitize_user_prompt'], + num_model_lock=len(text_outputs), + all_models=kwargs['all_models'] + ), + inputs=inputs_list + text_outputs, + outputs=text_outputs, + ) + + def clear_instruct(): + return gr.Textbox.update(value='') + + def deselect_radio_chats(): + return gr.update(value=None) + + def clear_all(): + return gr.Textbox.update(value=''), gr.Textbox.update(value=''), gr.update(value=None), \ + gr.Textbox.update(value=''), gr.Textbox.update(value='') + + if kwargs['model_states']: + submits1 = submits2 = submits3 = [] + submits4 = [] + + triggers = [instruction, submit, retry_btn] + fun_source = [instruction.submit, submit.click, retry_btn.click] + fun_name = ['instruction', 'submit', 'retry'] + user_args = [all_user_args, all_user_args, all_retry_user_args] + bot_args = [all_bot_args, all_bot_args, all_retry_bot_args] + for userargs1, botarg1, funn1, funs1, trigger1, in zip(user_args, bot_args, fun_name, fun_source, triggers): + submit_event11 = funs1(fn=user_state_setup, + inputs=[my_db_state, requests_state, trigger1, trigger1], + outputs=[my_db_state, requests_state, trigger1], + queue=queue) + submit_event1a = submit_event11.then(**userargs1, queue=queue, + api_name='%s' % funn1 if allow_api else None) + # if hit enter on new instruction for submitting new query, no longer the saved chat + submit_event1b = submit_event1a.then(clear_all, inputs=None, + outputs=[instruction, iinput, radio_chats, score_text, + score_text2], + queue=queue) + submit_event1c = submit_event1b.then(**botarg1, + api_name='%s_bot' % funn1 if allow_api else None, + queue=queue) + submit_event1d = submit_event1c.then(**all_score_args, + api_name='%s_bot_score' % funn1 if allow_api else None, + queue=queue) + + submits1.extend([submit_event1a, submit_event1b, submit_event1c, submit_event1d]) + + # if undo, no longer the saved chat + submit_event4 = undo.click(fn=user_state_setup, + inputs=[my_db_state, requests_state, undo, undo], + outputs=[my_db_state, requests_state, undo], + queue=queue) \ + .then(**all_undo_user_args, api_name='undo' if allow_api else None) \ + .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, + score_text2], queue=queue) \ + .then(**all_score_args, api_name='undo_score' if allow_api else None) + submits4 = [submit_event4] + + else: + # in case 2nd model, consume instruction first, so can clear quickly + # bot doesn't consume instruction itself, just history from user, so why works + submit_event11 = instruction.submit(fn=user_state_setup, + inputs=[my_db_state, requests_state, instruction, instruction], + outputs=[my_db_state, requests_state, instruction], + queue=queue) + submit_event1a = submit_event11.then(**user_args, queue=queue, + api_name='instruction' if allow_api else None) + # if hit enter on new instruction for submitting new query, no longer the saved chat + submit_event1a2 = submit_event1a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event1b = submit_event1a2.then(**user_args2, api_name='instruction2' if allow_api else None) + submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \ + .then(clear_instruct, None, iinput) + submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None, + queue=queue) + submit_event1e = submit_event1d.then(**score_args, + api_name='instruction_bot_score' if allow_api else None, + queue=queue) + submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, + queue=queue) + submit_event1g = submit_event1f.then(**score_args2, + api_name='instruction_bot_score2' if allow_api else None, queue=queue) + + submits1 = [submit_event1a, submit_event1a2, submit_event1b, submit_event1c, submit_event1d, + submit_event1e, + submit_event1f, submit_event1g] + + submit_event21 = submit.click(fn=user_state_setup, + inputs=[my_db_state, requests_state, submit, submit], + outputs=[my_db_state, requests_state, submit], + queue=queue) + submit_event2a = submit_event21.then(**user_args, api_name='submit' if allow_api else None) + # if submit new query, no longer the saved chat + submit_event2a2 = submit_event2a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event2b = submit_event2a2.then(**user_args2, api_name='submit2' if allow_api else None) + submit_event2c = submit_event2b.then(clear_all, inputs=None, + outputs=[instruction, iinput, radio_chats, score_text, score_text2], + queue=queue) + submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) + submit_event2e = submit_event2d.then(**score_args, + api_name='submit_bot_score' if allow_api else None, + queue=queue) + submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, + queue=queue) + submit_event2g = submit_event2f.then(**score_args2, + api_name='submit_bot_score2' if allow_api else None, + queue=queue) + + submits2 = [submit_event2a, submit_event2a2, submit_event2b, submit_event2c, submit_event2d, + submit_event2e, + submit_event2f, submit_event2g] + + submit_event31 = retry_btn.click(fn=user_state_setup, + inputs=[my_db_state, requests_state, retry_btn, retry_btn], + outputs=[my_db_state, requests_state, retry_btn], + queue=queue) + submit_event3a = submit_event31.then(**user_args, api_name='retry' if allow_api else None) + # if retry, no longer the saved chat + submit_event3a2 = submit_event3a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) + submit_event3b = submit_event3a2.then(**user_args2, api_name='retry2' if allow_api else None) + submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \ + .then(clear_instruct, None, iinput) + submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, + queue=queue) + submit_event3e = submit_event3d.then(**score_args, + api_name='retry_bot_score' if allow_api else None, + queue=queue) + submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, + queue=queue) + submit_event3g = submit_event3f.then(**score_args2, + api_name='retry_bot_score2' if allow_api else None, + queue=queue) + + submits3 = [submit_event3a, submit_event3a2, submit_event3b, submit_event3c, submit_event3d, + submit_event3e, + submit_event3f, submit_event3g] + + # if undo, no longer the saved chat + submit_event4 = undo.click(fn=user_state_setup, + inputs=[my_db_state, requests_state, undo, undo], + outputs=[my_db_state, requests_state, undo], + queue=queue) \ + .then(**undo_user_args, api_name='undo' if allow_api else None) \ + .then(**undo_user_args2, api_name='undo2' if allow_api else None) \ + .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, + score_text2], queue=queue) \ + .then(**score_args, api_name='undo_score' if allow_api else None) \ + .then(**score_args2, api_name='undo_score2' if allow_api else None) + submits4 = [submit_event4] + + # MANAGE CHATS + def dedup(short_chat, short_chats): + if short_chat not in short_chats: + return short_chat + for i in range(1, 1000): + short_chat_try = short_chat + "_" + str(i) + if short_chat_try not in short_chats: + return short_chat_try + # fallback and hope for best + short_chat = short_chat + "_" + str(random.random()) + return short_chat + + def get_short_chat(x, short_chats, short_len=20, words=4): + if x and len(x[0]) == 2 and x[0][0] is not None: + short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip() + if not short_chat: + # e.g.summarization, try using answer + short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip() + if not short_chat: + short_chat = 'Unk' + short_chat = dedup(short_chat, short_chats) + else: + short_chat = None + return short_chat + + def is_chat_same(x, y): + #

etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation + is_same = True + # length of conversation has to be same + if len(x) != len(y): + return False + if len(x) != len(y): + return False + for stepx, stepy in zip(x, y): + if len(stepx) != len(stepy): + # something off with a conversation + return False + for stepxx, stepyy in zip(stepx, stepy): + if len(stepxx) != len(stepyy): + # something off with a conversation + return False + if len(stepxx) != 2: + # something off + return False + if len(stepyy) != 2: + # something off + return False + questionx = stepxx[0].replace('

', '').replace('

', '') if stepxx[0] is not None else None + answerx = stepxx[1].replace('

', '').replace('

', '') if stepxx[1] is not None else None + + questiony = stepyy[0].replace('

', '').replace('

', '') if stepyy[0] is not None else None + answery = stepyy[1].replace('

', '').replace('

', '') if stepyy[1] is not None else None + + if questionx != questiony or answerx != answery: + return False + return is_same + + def save_chat(*args, chat_is_list=False, auth_filename=None, auth_freeze=None): + args_list = list(args) + db1s = args_list[0] + requests_state1 = args_list[1] + args_list = args_list[2:] + if not chat_is_list: + # list of chatbot histories, + # can't pass in list with list of chatbot histories and state due to gradio limits + chat_list = args_list[:-1] + else: + assert len(args_list) == 2 + chat_list = args_list[0] + # if old chat file with single chatbot, get into shape + if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len( + chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str): + chat_list = [chat_list] + # remove None histories + chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] + chat_list_none = [x for x in chat_list if x not in chat_list_not_none] + if len(chat_list_none) > 0 and len(chat_list_not_none) == 0: + raise ValueError("Invalid chat file") + # dict with keys of short chat names, values of list of list of chatbot histories + chat_state1 = args_list[-1] + short_chats = list(chat_state1.keys()) + if len(chat_list_not_none) > 0: + # make short_chat key from only first history, based upon question that is same anyways + chat_first = chat_list_not_none[0] + short_chat = get_short_chat(chat_first, short_chats) + if short_chat: + old_chat_lists = list(chat_state1.values()) + already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) + if not already_exists: + chat_state1[short_chat] = chat_list.copy() + + # reverse so newest at top + choices = list(chat_state1.keys()).copy() + choices.reverse() + + # save saved chats and chatbots to auth file + text_output1 = chat_list[0] + text_output21 = chat_list[1] + text_outputs1 = chat_list[2:] + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1, + text_output1=text_output1, text_output21=text_output21, text_outputs1=text_outputs1) + + return chat_state1, gr.update(choices=choices, value=None) + + def switch_chat(chat_key, chat_state1, num_model_lock=0): + chosen_chat = chat_state1[chat_key] + # deal with possible different size of chat list vs. current list + ret_chat = [None] * (2 + num_model_lock) + for chati in range(0, 2 + num_model_lock): + ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] + return tuple(ret_chat) + + def clear_texts(*args): + return tuple([gr.Textbox.update(value='')] * len(args)) + + def clear_scores(): + return gr.Textbox.update(value=res_value), \ + gr.Textbox.update(value='Response Score: NA'), \ + gr.Textbox.update(value='Response Score: NA') + + switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) + radio_chats.input(switch_chat_fun, + inputs=[radio_chats, chat_state], + outputs=[text_output, text_output2] + text_outputs) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + def remove_chat(chat_key, chat_state1): + if isinstance(chat_key, str): + chat_state1.pop(chat_key, None) + return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1 + + remove_chat_event = remove_chat_btn.click(remove_chat, + inputs=[radio_chats, chat_state], + outputs=[radio_chats, chat_state], + queue=False, api_name='remove_chat') + + def get_chats1(chat_state1): + base = 'chats' + base = makedirs(base, exist_ok=True, tmp_ok=True, use_base=True) + filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4())) + with open(filename, "wt") as f: + f.write(json.dumps(chat_state1, indent=2)) + return filename + + export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, + api_name='export_chats' if allow_api else None) + + def add_chats_from_file(db1s, requests_state1, file, chat_state1, radio_chats1, chat_exception_text1, + auth_filename=None, auth_freeze=None): + if not file: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + if isinstance(file, str): + files = [file] + else: + files = file + if not files: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + chat_exception_list = [] + for file1 in files: + try: + if hasattr(file1, 'name'): + file1 = file1.name + with open(file1, "rt") as f: + new_chats = json.loads(f.read()) + for chat1_k, chat1_v in new_chats.items(): + # ignore chat1_k, regenerate and de-dup to avoid loss + chat_state1, _ = save_chat(db1s, requests_state1, chat1_v, chat_state1, chat_is_list=True) + except BaseException as e: + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + ex_str = "File %s exception: %s" % (file1, str(e)) + print(ex_str, flush=True) + chat_exception_list.append(ex_str) + chat_exception_text1 = '\n'.join(chat_exception_list) + # save chat to auth file + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1) + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + + # note for update_user_db_func output is ignored for db + chatup_change_eventa = chatsup_output.change(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + add_chats_from_file_func = functools.partial(add_chats_from_file, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + chatup_change_event = chatup_change_eventa.then(add_chats_from_file_func, + inputs=[my_db_state, requests_state] + + [chatsup_output, chat_state, radio_chats, + chat_exception_text], + outputs=[chatsup_output, chat_state, radio_chats, + chat_exception_text], + queue=False, + api_name='add_to_chats' if allow_api else None) + + clear_chat_event = clear_chat_btn.click(fn=clear_texts, + inputs=[text_output, text_output2] + text_outputs, + outputs=[text_output, text_output2] + text_outputs, + queue=False, api_name='clear' if allow_api else None) \ + .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + clear_eventa = save_chat_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + save_chat_func = functools.partial(save_chat, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + clear_event = clear_eventa.then(save_chat_func, + inputs=[my_db_state, requests_state] + + [text_output, text_output2] + text_outputs + + [chat_state], + outputs=[chat_state, radio_chats], + api_name='save_chat' if allow_api else None) + if kwargs['score_model']: + clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + # NOTE: clear of instruction/iinput for nochat has to come after score, + # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() + no_chat_args = dict(fn=fun, + inputs=[model_state, my_db_state, selection_docs_state, requests_state] + inputs_list, + outputs=text_output_nochat, + queue=queue, + ) + submit_event_nochat = submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + # copy of above with text box submission + submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + + submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, + inputs=[model_state, my_db_state, selection_docs_state, + requests_state, + inputs_dict_str], + outputs=text_output_nochat_api, + queue=True, # required for generator + api_name='submit_nochat_api' if allow_api else None) + + submit_event_nochat_api_plain = submit_nochat_api_plain.click(fun_with_dict_str_plain, + inputs=inputs_dict_str, + outputs=text_output_nochat_api, + queue=False, + api_name='submit_nochat_plain_api' if allow_api else None) + + def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, + load_8bit, load_4bit, low_bit_mode, + load_gptq, load_exllama, use_safetensors, revision, + use_gpu_id, gpu_id, max_seq_len1, rope_scaling1, + model_path_llama1, model_name_gptj1, model_name_gpt4all_llama1, + n_gpu_layers1, n_batch1, n_gqa1, llamacpp_dict_more1, + system_prompt1): + try: + llamacpp_dict = ast.literal_eval(llamacpp_dict_more1) + except: + print("Failed to use user input for llamacpp_dict_more1 dict", flush=True) + llamacpp_dict = {} + llamacpp_dict.update(dict(model_path_llama=model_path_llama1, + model_name_gptj=model_name_gptj1, + model_name_gpt4all_llama=model_name_gpt4all_llama1, + n_gpu_layers=n_gpu_layers1, + n_batch=n_batch1, + n_gqa=n_gqa1, + )) + + # ensure no API calls reach here + if is_public: + raise RuntimeError("Illegal access for %s" % model_name) + # ensure old model removed from GPU memory + if kwargs['debug']: + print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) + + model0 = model_state0['model'] + if isinstance(model_state_old['model'], str) and \ + model0 is not None and \ + hasattr(model0, 'cpu'): + # best can do, move model loaded at first to CPU + model0.cpu() + + if model_state_old['model'] is not None and \ + not isinstance(model_state_old['model'], str): + if hasattr(model_state_old['model'], 'cpu'): + try: + model_state_old['model'].cpu() + except Exception as e: + # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! + print("Unable to put model on CPU: %s" % str(e), flush=True) + del model_state_old['model'] + model_state_old['model'] = None + + if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): + del model_state_old['tokenizer'] + model_state_old['tokenizer'] = None + + clear_torch_cache() + if kwargs['debug']: + print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True) + if not model_name: + model_name = no_model_str + if model_name == no_model_str: + # no-op if no model, just free memory + # no detranscribe needed for model, never go into evaluate + lora_weights = no_lora_str + server_name = no_server_str + return kwargs['model_state_none'].copy(), \ + model_name, lora_weights, server_name, prompt_type_old, \ + gr.Slider.update(maximum=256), \ + gr.Slider.update(maximum=256) + + # don't deepcopy, can contain model itself + all_kwargs1 = all_kwargs.copy() + all_kwargs1['base_model'] = model_name.strip() + all_kwargs1['load_8bit'] = load_8bit + all_kwargs1['load_4bit'] = load_4bit + all_kwargs1['low_bit_mode'] = low_bit_mode + all_kwargs1['load_gptq'] = load_gptq + all_kwargs1['load_exllama'] = load_exllama + all_kwargs1['use_safetensors'] = use_safetensors + all_kwargs1['revision'] = None if not revision else revision # transcribe, don't pass '' + all_kwargs1['use_gpu_id'] = use_gpu_id + all_kwargs1['gpu_id'] = int(gpu_id) if gpu_id not in [None, 'None'] else None # detranscribe + all_kwargs1['llamacpp_dict'] = llamacpp_dict + all_kwargs1['max_seq_len'] = max_seq_len1 + try: + all_kwargs1['rope_scaling'] = str_to_dict(rope_scaling1) # transcribe + except: + print("Failed to use user input for rope_scaling dict", flush=True) + all_kwargs1['rope_scaling'] = {} + model_lower = model_name.strip().lower() + if model_lower in inv_prompt_type_to_model_lower: + prompt_type1 = inv_prompt_type_to_model_lower[model_lower] + else: + prompt_type1 = prompt_type_old + + # detranscribe + if lora_weights == no_lora_str: + lora_weights = '' + all_kwargs1['lora_weights'] = lora_weights.strip() + if server_name == no_server_str: + server_name = '' + all_kwargs1['inference_server'] = server_name.strip() + + model1, tokenizer1, device1 = get_model(reward_type=False, + **get_kwargs(get_model, exclude_names=['reward_type'], + **all_kwargs1)) + clear_torch_cache() + + tokenizer_base_model = model_name + prompt_dict1, error0 = get_prompt(prompt_type1, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True, system_prompt=system_prompt1) + model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, + base_model=model_name, tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, inference_server=server_name, + prompt_type=prompt_type1, prompt_dict=prompt_dict1, + ) + + max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) + + if kwargs['debug']: + print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) + return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ + gr.Slider.update(maximum=max_max_new_tokens1), \ + gr.Slider.update(maximum=max_max_new_tokens1) + + def get_prompt_str(prompt_type1, prompt_dict1, system_prompt1, which=0): + if prompt_type1 in ['', None]: + print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) + return str({}) + prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', + reduced=False, making_context=False, return_dict=True, + system_prompt=system_prompt1) + if prompt_dict_error: + return str(prompt_dict_error) + else: + # return so user can manipulate if want and use as custom + return str(prompt_dict1) + + get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) + get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) + prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict, system_prompt], + outputs=prompt_dict, queue=False) + prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2, system_prompt], + outputs=prompt_dict2, + queue=False) + + def dropdown_prompt_type_list(x): + return gr.Dropdown.update(value=x) + + def chatbot_list(x, model_used_in): + return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') + + load_model_args = dict(fn=load_model, + inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, + model_load8bit_checkbox, model_load4bit_checkbox, model_low_bit_mode, + model_load_gptq, model_load_exllama_checkbox, + model_safetensors_checkbox, model_revision, + model_use_gpu_id_checkbox, model_gpu, + max_seq_len, rope_scaling, + model_path_llama, model_name_gptj, model_name_gpt4all_llama, + n_gpu_layers, n_batch, n_gqa, llamacpp_dict_more, + system_prompt], + outputs=[model_state, model_used, lora_used, server_used, + # if prompt_type changes, prompt_dict will change via change rule + prompt_type, max_new_tokens, min_new_tokens, + ]) + prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) + chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) + nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) + load_model_event = load_model_button.click(**load_model_args, + api_name='load_model' if allow_api and not is_public else None) \ + .then(**prompt_update_args) \ + .then(**chatbot_update_args) \ + .then(**nochat_update_args) \ + .then(clear_torch_cache) + + load_model_args2 = dict(fn=load_model, + inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, + model_load8bit_checkbox2, model_load4bit_checkbox2, model_low_bit_mode2, + model_load_gptq2, model_load_exllama_checkbox2, + model_safetensors_checkbox2, model_revision2, + model_use_gpu_id_checkbox2, model_gpu2, + max_seq_len2, rope_scaling2, + model_path_llama2, model_name_gptj2, model_name_gpt4all_llama2, + n_gpu_layers2, n_batch2, n_gqa2, llamacpp_dict_more2, + system_prompt], + outputs=[model_state2, model_used2, lora_used2, server_used2, + # if prompt_type2 changes, prompt_dict2 will change via change rule + prompt_type2, max_new_tokens2, min_new_tokens2 + ]) + prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2) + chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) + load_model_event2 = load_model_button2.click(**load_model_args2, + api_name='load_model2' if allow_api and not is_public else None) \ + .then(**prompt_update_args2) \ + .then(**chatbot_update_args2) \ + .then(clear_torch_cache) + + def dropdown_model_lora_server_list(model_list0, model_x, + lora_list0, lora_x, + server_list0, server_x, + model_used1, lora_used1, server_used1, + model_used2, lora_used2, server_used2, + ): + model_new_state = [model_list0[0] + [model_x]] + model_new_options = [*model_new_state[0]] + if no_model_str in model_new_options: + model_new_options.remove(no_model_str) + model_new_options = [no_model_str] + sorted(model_new_options) + x1 = model_x if model_used1 == no_model_str else model_used1 + x2 = model_x if model_used2 == no_model_str else model_used2 + ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), + gr.Dropdown.update(value=x2, choices=model_new_options), + '', model_new_state] + + lora_new_state = [lora_list0[0] + [lora_x]] + lora_new_options = [*lora_new_state[0]] + if no_lora_str in lora_new_options: + lora_new_options.remove(no_lora_str) + lora_new_options = [no_lora_str] + sorted(lora_new_options) + # don't switch drop-down to added lora if already have model loaded + x1 = lora_x if model_used1 == no_model_str else lora_used1 + x2 = lora_x if model_used2 == no_model_str else lora_used2 + ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), + gr.Dropdown.update(value=x2, choices=lora_new_options), + '', lora_new_state] + + server_new_state = [server_list0[0] + [server_x]] + server_new_options = [*server_new_state[0]] + if no_server_str in server_new_options: + server_new_options.remove(no_server_str) + server_new_options = [no_server_str] + sorted(server_new_options) + # don't switch drop-down to added server if already have model loaded + x1 = server_x if model_used1 == no_model_str else server_used1 + x2 = server_x if model_used2 == no_model_str else server_used2 + ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), + gr.Dropdown.update(value=x2, choices=server_new_options), + '', server_new_state] + + return tuple(ret1 + ret2 + ret3) + + add_model_lora_server_event = \ + add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, + inputs=[model_options_state, new_model] + + [lora_options_state, new_lora] + + [server_options_state, new_server] + + [model_used, lora_used, server_used] + + [model_used2, lora_used2, server_used2], + outputs=[model_choice, model_choice2, new_model, model_options_state] + + [lora_choice, lora_choice2, new_lora, lora_options_state] + + [server_choice, server_choice2, new_server, + server_options_state], + queue=False) + + go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, + queue=False) \ + .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \ + .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False) + + def compare_textbox_fun(x): + return gr.Textbox.update(visible=x) + + def compare_column_fun(x): + return gr.Column.update(visible=x) + + def compare_prompt_fun(x): + return gr.Dropdown.update(visible=x) + + def slider_fun(x): + return gr.Slider.update(visible=x) + + compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, + api_name="compare_checkbox" if allow_api else None) \ + .then(compare_column_fun, compare_checkbox, col_model2) \ + .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ + .then(compare_textbox_fun, compare_checkbox, score_text2) \ + .then(slider_fun, compare_checkbox, max_new_tokens2) \ + .then(slider_fun, compare_checkbox, min_new_tokens2) + # FIXME: add score_res2 in condition, but do better + + # callback for logging flagged input/output + callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") + flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, + None, + preprocess=False, + api_name='flag' if allow_api else None, queue=False) + flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, + preprocess=False, + api_name='flag_nochat' if allow_api else None, queue=False) + + def get_system_info(): + if is_public: + time.sleep(10) # delay to avoid spam since queue=False + return gr.Textbox.update(value=system_info_print()) + + system_event = system_btn.click(get_system_info, outputs=system_text, + api_name='system_info' if allow_api else None, queue=False) + + def get_system_info_dict(system_input1, **kwargs1): + if system_input1 != os.getenv("ADMIN_PASS", ""): + return json.dumps({}) + exclude_list = ['admin_pass', 'examples'] + sys_dict = {k: v for k, v in kwargs1.items() if + isinstance(v, (str, int, bool, float)) and k not in exclude_list} + try: + sys_dict.update(system_info()) + except Exception as e: + # protection + print("Exception: %s" % str(e), flush=True) + return json.dumps(sys_dict) + + system_kwargs = all_kwargs.copy() + system_kwargs.update(dict(command=str(' '.join(sys.argv)))) + get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) + + system_dict_event = system_btn2.click(get_system_info_dict_func, + inputs=system_input, + outputs=system_text2, + api_name='system_info_dict' if allow_api else None, + queue=False, # queue to avoid spam + ) + + def get_hash(): + return kwargs['git_hash'] + + system_event = system_btn3.click(get_hash, + outputs=system_text3, + api_name='system_hash' if allow_api else None, + queue=False, + ) + + def get_model_names(): + key_list = ['base_model', 'prompt_type', 'prompt_dict'] + list(kwargs['other_model_state_defaults'].keys()) + # don't want to expose backend inference server IP etc. + # key_list += ['inference_server'] + return [{k: x[k] for k in key_list if k in x} for x in model_states] + + models_list_event = system_btn4.click(get_model_names, + outputs=system_text4, + api_name='model_names' if allow_api else None, + queue=False, + ) + + def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, + system_prompt1, chat_conversation1, + memory_restriction_level1=0, + keep_sources_in_context1=False, + ): + if model_state1 and not isinstance(model_state1['tokenizer'], str): + tokenizer = model_state1['tokenizer'] + elif model_state0 and not isinstance(model_state0['tokenizer'], str): + tokenizer = model_state0['tokenizer'] + else: + tokenizer = None + if tokenizer is not None: + langchain_mode1 = 'LLM' + add_chat_history_to_context1 = True + # fake user message to mimic bot() + chat1 = copy.deepcopy(chat1) + chat1 = chat1 + [['user_message1', None]] + model_max_length1 = tokenizer.model_max_length + context1 = history_to_context(chat1, + langchain_mode=langchain_mode1, + add_chat_history_to_context=add_chat_history_to_context1, + prompt_type=prompt_type1, + prompt_dict=prompt_dict1, + chat=True, + model_max_length=model_max_length1, + memory_restriction_level=memory_restriction_level1, + keep_sources_in_context=keep_sources_in_context1, + system_prompt=system_prompt1, + chat_conversation=chat_conversation1) + tokens = tokenizer(context1, return_tensors="pt")['input_ids'] + if len(tokens.shape) == 1: + return str(tokens.shape[0]) + elif len(tokens.shape) == 2: + return str(tokens.shape[1]) + else: + return "N/A" + else: + return "N/A" + + count_chat_tokens_func = functools.partial(count_chat_tokens, + memory_restriction_level1=memory_restriction_level, + keep_sources_in_context1=kwargs['keep_sources_in_context']) + count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens_func, + inputs=[model_state, text_output, prompt_type, prompt_dict, + system_prompt, chat_conversation], + outputs=chat_token_count, + api_name='count_tokens' if allow_api else None) + + # don't pass text_output, don't want to clear output, just stop it + # cancel only stops outer generation, not inner generation or non-generation + stop_btn.click(lambda: None, None, None, + cancels=submits1 + submits2 + submits3 + submits4 + + [submit_event_nochat, submit_event_nochat2] + + [eventdb1, eventdb2, eventdb3] + + [eventdb7a, eventdb7, eventdb8a, eventdb8, eventdb9a, eventdb9, eventdb12a, eventdb12] + + db_events + + [eventdbloadla, eventdbloadlb] + + [clear_event] + + [submit_event_nochat_api, submit_event_nochat] + + [load_model_event, load_model_event2] + + [count_tokens_event] + , + queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) + + if kwargs['auth'] is not None: + auth = authf + load_func = user_state_setup + load_inputs = [my_db_state, requests_state, login_btn, login_btn] + load_outputs = [my_db_state, requests_state, login_btn] + else: + auth = None + load_func, load_inputs, load_outputs = None, None, None + + app_js = wrap_js_to_lambda( + len(load_inputs) if load_inputs else 0, + get_dark_js() if kwargs['dark'] else None, + get_heap_js(heap_app_id) if is_heap_analytics_enabled else None) + + load_event = demo.load(fn=load_func, inputs=load_inputs, outputs=load_outputs, _js=app_js) + + if load_func: + load_event2 = load_event.then(load_login_func, + inputs=login_inputs, + outputs=login_outputs) + if not kwargs['large_file_count_mode']: + load_event3 = load_event2.then(**get_sources_kwargs) + load_event4 = load_event3.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + load_event5 = load_event4.then(**show_sources_kwargs) + load_event6 = load_event5.then(**get_viewable_sources_args) + load_event7 = load_event6.then(**viewable_kwargs) + + demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) + favicon_file = "h2o-logo.svg" + favicon_path = favicon_file + if not os.path.isfile(favicon_file): + print("favicon_path1=%s not found" % favicon_file, flush=True) + alt_path = os.path.dirname(os.path.abspath(__file__)) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path2: %s not found in %s" % (favicon_file, alt_path), flush=True) + alt_path = os.path.dirname(alt_path) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path3: %s not found in %s" % (favicon_file, alt_path), flush=True) + favicon_path = None + + if kwargs['prepare_offline_level'] > 0: + from src.prepare_offline import go_prepare_offline + go_prepare_offline(**locals()) + return + + scheduler = BackgroundScheduler() + scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) + if is_public and \ + kwargs['base_model'] not in non_hf_types: + # FIXME: disable for gptj, langchain or gpt4all modify print itself + # FIXME: and any multi-threaded/async print will enter model output! + scheduler.add_job(func=ping, trigger="interval", seconds=60) + if is_public or os.getenv('PING_GPU'): + scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) + scheduler.start() + + # import control + if kwargs['langchain_mode'] == 'Disabled' and \ + os.environ.get("TEST_LANGCHAIN_IMPORT") and \ + kwargs['base_model'] not in non_hf_types: + assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + + # set port in case GRADIO_SERVER_PORT was already set in prior main() call, + # gradio does not listen if change after import + # Keep None if not set so can find an open port above used ports + server_port = os.getenv('GRADIO_SERVER_PORT') + if server_port is not None: + server_port = int(server_port) + + demo.launch(share=kwargs['share'], + server_name=kwargs['server_name'], + show_error=True, + server_port=server_port, + favicon_path=favicon_path, + prevent_thread_lock=True, + auth=auth, + auth_message=auth_message, + root_path=kwargs['root_path']) + if kwargs['verbose'] or not (kwargs['base_model'] in ['gptj', 'gpt4all_llama']): + print("Started Gradio Server and/or GUI: server_name: %s port: %s" % (kwargs['server_name'], server_port), + flush=True) + if kwargs['block_gradio_exit']: + demo.block_thread() + + +def show_doc(db1s, selection_docs_state1, requests_state1, + langchain_mode1, + single_document_choice1, + view_raw_text_checkbox1, + text_context_list1, + dbs1=None, + load_db_if_exists1=None, + db_type1=None, + use_openai_embedding1=None, + hf_embedding_model1=None, + migrate_embedding_model_or_db1=None, + auto_migrate_db1=None, + verbose1=False, + get_userid_auth1=None, + max_raw_chunks=1000000, + api=False, + n_jobs=-1): + file = single_document_choice1 + document_choice1 = [single_document_choice1] + content = None + db_documents = [] + db_metadatas = [] + if db_type1 in ['chroma', 'chroma_old']: + assert langchain_mode1 is not None + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + langchain_mode_types = selection_docs_state1['langchain_mode_types'] + from src.gpt_langchain import set_userid, get_any_db, get_docs_and_meta + set_userid(db1s, requests_state1, get_userid_auth1) + top_k_docs = -1 + db = get_any_db(db1s, langchain_mode1, langchain_mode_paths, langchain_mode_types, + dbs=dbs1, + load_db_if_exists=load_db_if_exists1, + db_type=db_type1, + use_openai_embedding=use_openai_embedding1, + hf_embedding_model=hf_embedding_model1, + migrate_embedding_model=migrate_embedding_model_or_db1, + auto_migrate_db=auto_migrate_db1, + for_sources_list=True, + verbose=verbose1, + n_jobs=n_jobs, + ) + query_action = False # long chunks like would be used for summarize + # the below is as or filter, so will show doc or by chunk, unrestricted + from langchain.vectorstores import Chroma + if isinstance(db, Chroma): + # chroma >= 0.4 + if view_raw_text_checkbox1: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$gte": -1}} + for x in document_choice1][0] + else: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + else: + # migration for chroma < 0.4 + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + if view_raw_text_checkbox1: + # like or, full raw all chunk types + filter_kwargs = dict(filter=one_filter) + else: + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, + text_context_list=text_context_list1) + # order documents + from langchain.docstore.document import Document + docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) + for result in zip(db_documents, db_metadatas)] + doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas] + doc_page_ids = [x.get('page', 0) for x in db_metadatas] + doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] + docs_with_score = [x for hx, px, cx, x in + sorted(zip(doc_hashes, doc_page_ids, doc_chunk_ids, docs_with_score), + key=lambda x: (x[0], x[1], x[2])) + # if cx == -1 + ] + db_metadatas = [x[0].metadata for x in docs_with_score][:max_raw_chunks] + db_documents = [x[0].page_content for x in docs_with_score][:max_raw_chunks] + # done reordering + if view_raw_text_checkbox1: + content = [dict_to_html(x) + '\n' + text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + else: + content = [text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + content = '\n'.join(content) + content = f""" + + +{file} + + +{content} + +""" + if api: + if view_raw_text_checkbox1: + return dict(contents=db_documents, metadatas=db_metadatas) + else: + contents = [text_to_html(y, api=api) for y in db_documents] + metadatas = [dict_to_html(x, api=api) for x in db_metadatas] + return dict(contents=contents, metadatas=metadatas) + else: + assert not api, "API mode for get_document only supported for chroma" + + dummy1 = gr.update(visible=False, value=None) + # backup is text dump of db version + if content: + dummy_ret = dummy1, dummy1, dummy1, dummy1, gr.update(visible=True, value=content) + if view_raw_text_checkbox1: + return dummy_ret + else: + dummy_ret = dummy1, dummy1, dummy1, dummy1, dummy1 + + if not isinstance(file, str): + return dummy_ret + + if file.lower().endswith('.html') or file.lower().endswith('.mhtml') or file.lower().endswith('.htm') or \ + file.lower().endswith('.xml'): + try: + with open(file, 'rt') as f: + content = f.read() + return gr.update(visible=True, value=content), dummy1, dummy1, dummy1, dummy1 + except: + return dummy_ret + + if file.lower().endswith('.md'): + try: + with open(file, 'rt') as f: + content = f.read() + return dummy1, dummy1, dummy1, gr.update(visible=True, value=content), dummy1 + except: + return dummy_ret + + if file.lower().endswith('.py'): + try: + with open(file, 'rt') as f: + content = f.read() + content = f"```python\n{content}\n```" + return dummy1, dummy1, dummy1, gr.update(visible=True, value=content), dummy1 + except: + return dummy_ret + + if file.lower().endswith('.txt') or file.lower().endswith('.rst') or file.lower().endswith( + '.rtf') or file.lower().endswith('.toml'): + try: + with open(file, 'rt') as f: + content = f.read() + content = f"```text\n{content}\n```" + return dummy1, dummy1, dummy1, gr.update(visible=True, value=content), dummy1 + except: + return dummy_ret + + func = None + if file.lower().endswith(".csv"): + func = pd.read_csv + elif file.lower().endswith(".pickle"): + func = pd.read_pickle + elif file.lower().endswith(".xls") or file.lower().endswith("xlsx"): + func = pd.read_excel + elif file.lower().endswith('.json'): + func = pd.read_json + # pandas doesn't show full thing, even if html view shows broken things still better + # elif file.lower().endswith('.xml'): + # func = pd.read_xml + if func is not None: + try: + df = func(file).head(100) + except: + return dummy_ret + return dummy1, gr.update(visible=True, value=df), dummy1, dummy1, dummy1 + port = int(os.getenv('GRADIO_SERVER_PORT', '7860')) + import pathlib + absolute_path_string = os.path.abspath(file) + url_path = pathlib.Path(absolute_path_string).as_uri() + url = get_url(absolute_path_string, from_str=True) + img_url = url.replace(""" + +"""), dummy1, dummy1, dummy1, dummy1 + else: + # FIXME: This doesn't work yet, just return dummy result for now + if False: + ip = get_local_ip() + document1 = url_path.replace('file://', f'http://{ip}:{port}/') + # document1 = url + return gr.update(visible=True, value=f""" + +"""), dummy1, dummy1, dummy1, dummy1 + else: + return dummy_ret + else: + return dummy_ret + + +def get_inputs_list(inputs_dict, model_lower, model_id=1): + """ + map gradio objects in locals() to inputs for evaluate(). + :param inputs_dict: + :param model_lower: + :param model_id: Which model (1 or 2) of 2 + :return: + """ + inputs_list_names = list(inspect.signature(evaluate).parameters) + inputs_list = [] + inputs_dict_out = {} + for k in inputs_list_names: + if k == 'kwargs': + continue + if k in input_args_list + inputs_kwargs_list: + # these are added at use time for args or partial for kwargs, not taken as input + continue + if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']: + continue + if model_id == 2: + if k == 'prompt_type': + k = 'prompt_type2' + if k == 'prompt_used': + k = 'prompt_used2' + if k == 'max_new_tokens': + k = 'max_new_tokens2' + if k == 'min_new_tokens': + k = 'min_new_tokens2' + inputs_list.append(inputs_dict[k]) + inputs_dict_out[k] = inputs_dict[k] + return inputs_list, inputs_dict_out + + +def update_user_db_gr(file, db1s, selection_docs_state1, requests_state1, + langchain_mode, chunk, chunk_size, embed, + + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + h2ogpt_key, + + captions_model=None, + caption_loader=None, + doctr_loader=None, + + dbs=None, + get_userid_auth=None, + **kwargs): + valid_key = is_valid_key(kwargs.pop('enforce_h2ogpt_api_key', None), + kwargs.pop('h2ogpt_api_keys', []), h2ogpt_key, + requests_state1=requests_state1) + if not valid_key: + raise ValueError(invalid_key_msg) + loaders_dict, captions_model = gr_to_lg(image_loaders, + pdf_loaders, + url_loaders, + captions_model=captions_model, + **kwargs, + ) + if jq_schema is None: + jq_schema = kwargs['jq_schema0'] + loaders_dict.update(dict(captions_model=captions_model, + caption_loader=caption_loader, + doctr_loader=doctr_loader, + jq_schema=jq_schema, + )) + kwargs.pop('image_loaders_options0', None) + kwargs.pop('pdf_loaders_options0', None) + kwargs.pop('url_loaders_options0', None) + kwargs.pop('jq_schema0', None) + if not embed: + kwargs['use_openai_embedding'] = False + kwargs['hf_embedding_model'] = 'fake' + kwargs['migrate_embedding_model'] = False + + from src.gpt_langchain import update_user_db + return update_user_db(file, db1s, selection_docs_state1, requests_state1, + langchain_mode=langchain_mode, chunk=chunk, chunk_size=chunk_size, + **loaders_dict, + dbs=dbs, + get_userid_auth=get_userid_auth, + **kwargs) + + +def get_sources_gr(db1s, selection_docs_state1, requests_state1, langchain_mode, dbs=None, docs_state0=None, + load_db_if_exists=None, + db_type=None, + use_openai_embedding=None, + hf_embedding_model=None, + migrate_embedding_model=None, + auto_migrate_db=None, + verbose=False, + get_userid_auth=None, + api=False, + n_jobs=-1): + from src.gpt_langchain import get_sources + sources_file, source_list, num_chunks, db = \ + get_sources(db1s, selection_docs_state1, requests_state1, langchain_mode, + dbs=dbs, docs_state0=docs_state0, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + n_jobs=n_jobs, + ) + if api: + return source_list + if langchain_mode in langchain_modes_non_db: + doc_counts_str = "LLM Mode\nNo Collection" + else: + doc_counts_str = "Collection: %s\nDocs: %d\nChunks: %d" % (langchain_mode, len(source_list), num_chunks) + return sources_file, source_list, doc_counts_str + + +def get_source_files_given_langchain_mode_gr(db1s, selection_docs_state1, requests_state1, + langchain_mode, + dbs=None, + load_db_if_exists=None, + db_type=None, + use_openai_embedding=None, + hf_embedding_model=None, + migrate_embedding_model=None, + auto_migrate_db=None, + verbose=False, + get_userid_auth=None, + n_jobs=-1): + from src.gpt_langchain import get_source_files_given_langchain_mode + return get_source_files_given_langchain_mode(db1s, selection_docs_state1, requests_state1, None, + langchain_mode, + dbs=dbs, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + delete_sources=False, + n_jobs=n_jobs) + + +def del_source_files_given_langchain_mode_gr(db1s, selection_docs_state1, requests_state1, document_choice1, + langchain_mode, + dbs=None, + load_db_if_exists=None, + db_type=None, + use_openai_embedding=None, + hf_embedding_model=None, + migrate_embedding_model=None, + auto_migrate_db=None, + verbose=False, + get_userid_auth=None, + n_jobs=-1): + from src.gpt_langchain import get_source_files_given_langchain_mode + return get_source_files_given_langchain_mode(db1s, selection_docs_state1, requests_state1, document_choice1, + langchain_mode, + dbs=dbs, + load_db_if_exists=load_db_if_exists, + db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + verbose=verbose, + get_userid_auth=get_userid_auth, + delete_sources=True, + n_jobs=n_jobs) + + +def update_and_get_source_files_given_langchain_mode_gr(db1s, + selection_docs_state, + requests_state, + langchain_mode, chunk, chunk_size, + + image_loaders, + pdf_loaders, + url_loaders, + jq_schema, + + captions_model=None, + caption_loader=None, + doctr_loader=None, + + dbs=None, first_para=None, + hf_embedding_model=None, + use_openai_embedding=None, + migrate_embedding_model=None, + auto_migrate_db=None, + text_limit=None, + db_type=None, load_db_if_exists=None, + n_jobs=None, verbose=None, get_userid_auth=None, + image_loaders_options0=None, + pdf_loaders_options0=None, + url_loaders_options0=None, + jq_schema0=None): + from src.gpt_langchain import update_and_get_source_files_given_langchain_mode + + loaders_dict, captions_model = gr_to_lg(image_loaders, + pdf_loaders, + url_loaders, + image_loaders_options0=image_loaders_options0, + pdf_loaders_options0=pdf_loaders_options0, + url_loaders_options0=url_loaders_options0, + captions_model=captions_model, + ) + if jq_schema is None: + jq_schema = jq_schema0 + loaders_dict.update(dict(captions_model=captions_model, + caption_loader=caption_loader, + doctr_loader=doctr_loader, + jq_schema=jq_schema, + )) + + return update_and_get_source_files_given_langchain_mode(db1s, + selection_docs_state, + requests_state, + langchain_mode, chunk, chunk_size, + **loaders_dict, + dbs=dbs, first_para=first_para, + hf_embedding_model=hf_embedding_model, + use_openai_embedding=use_openai_embedding, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + text_limit=text_limit, + db_type=db_type, load_db_if_exists=load_db_if_exists, + n_jobs=n_jobs, verbose=verbose, + get_userid_auth=get_userid_auth) + + +def set_userid_gr(db1s, requests_state1, get_userid_auth): + from src.gpt_langchain import set_userid + return set_userid(db1s, requests_state1, get_userid_auth) + + +def set_dbid_gr(db1): + from src.gpt_langchain import set_dbid + return set_dbid(db1) + + +def set_userid_direct_gr(db1s, userid, username): + from src.gpt_langchain import set_userid_direct + return set_userid_direct(db1s, userid, username)