diff --git "a/src/gradio_runner.py" "b/src/gradio_runner.py"
new file mode 100644--- /dev/null
+++ "b/src/gradio_runner.py"
@@ -0,0 +1,4601 @@
+import ast
+import copy
+import functools
+import inspect
+import itertools
+import json
+import os
+import pprint
+import random
+import shutil
+import sys
+import time
+import traceback
+import uuid
+import filelock
+import numpy as np
+import pandas as pd
+import requests
+from iterators import TimeoutIterator
+
+from gradio_utils.css import get_css
+from gradio_utils.prompt_form import make_chatbots
+from src.db_utils import set_userid, get_username_direct
+
+# This is a hack to prevent Gradio from phoning home when it gets imported
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+def my_get(url, **kwargs):
+ print('Gradio HTTP request redirected to localhost :)', flush=True)
+ kwargs.setdefault('allow_redirects', True)
+ return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
+
+
+original_get = requests.get
+requests.get = my_get
+import gradio as gr
+
+requests.get = original_get
+
+
+def fix_pydantic_duplicate_validators_error():
+ try:
+ from pydantic import class_validators
+
+ class_validators.in_ipython = lambda: True # type: ignore[attr-defined]
+ except ImportError:
+ pass
+
+
+fix_pydantic_duplicate_validators_error()
+
+from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
+ DocumentChoice, langchain_modes_intrinsic, LangChainTypes, langchain_modes_non_db, gr_to_lg, invalid_key_msg, \
+ LangChainAgent, docs_ordering_types
+from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, \
+ get_dark_js, get_heap_js, wrap_js_to_lambda, \
+ spacing_xsm, radius_xsm, text_xsm
+from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
+ get_prompt
+from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
+ ping, makedirs, get_kwargs, system_info, ping_gpu, get_url, get_local_ip, \
+ save_generate_output, url_alive, remove, dict_to_html, text_to_html, lg_to_gr, str_to_dict, have_serpapi
+from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, \
+ get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
+ evaluate_fake, merge_chat_conversation_history
+from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
+ input_args_list, key_overrides
+
+from apscheduler.schedulers.background import BackgroundScheduler
+
+
+def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True):
+ if fix_latex_dollars:
+ ts = text.split('```')
+ for parti, part in enumerate(ts):
+ inside = parti % 2 == 1
+ if not inside:
+ ts[parti] = ts[parti].replace('$', '﹩')
+ text = '```'.join(ts)
+
+ if fix_new_lines:
+ # let Gradio handle code, since got improved recently
+ ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is.
+ # ensure good visually, else markdown ignores multiple \n
+ # handle code blocks
+ ts = text.split('```')
+ for parti, part in enumerate(ts):
+ inside = parti % 2 == 1
+ if not inside:
+ ts[parti] = ts[parti].replace('\n', '
')
+ text = '```'.join(ts)
+ return text
+
+
+def is_valid_key(enforce_h2ogpt_api_key, h2ogpt_api_keys, h2ogpt_key1, requests_state1=None):
+ valid_key = False
+ if not enforce_h2ogpt_api_key:
+ # no token barrier
+ valid_key = 'not enforced'
+ else:
+ if isinstance(h2ogpt_api_keys, list) and h2ogpt_key1 in h2ogpt_api_keys:
+ # passed token barrier
+ valid_key = True
+ elif isinstance(h2ogpt_api_keys, str) and os.path.isfile(h2ogpt_api_keys):
+ with filelock.FileLock(h2ogpt_api_keys + '.lock'):
+ with open(h2ogpt_api_keys, 'rt') as f:
+ h2ogpt_api_keys = json.load(f)
+ if h2ogpt_key1 in h2ogpt_api_keys:
+ valid_key = True
+ if isinstance(requests_state1, dict) and 'username' in requests_state1 and requests_state1['username']:
+ # no UI limit currently
+ valid_key = True
+ return valid_key
+
+
+def go_gradio(**kwargs):
+ allow_api = kwargs['allow_api']
+ is_public = kwargs['is_public']
+ is_hf = kwargs['is_hf']
+ memory_restriction_level = kwargs['memory_restriction_level']
+ n_gpus = kwargs['n_gpus']
+ admin_pass = kwargs['admin_pass']
+ model_states = kwargs['model_states']
+ dbs = kwargs['dbs']
+ db_type = kwargs['db_type']
+ visible_langchain_actions = kwargs['visible_langchain_actions']
+ visible_langchain_agents = kwargs['visible_langchain_agents']
+ allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
+ allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
+ enable_sources_list = kwargs['enable_sources_list']
+ enable_url_upload = kwargs['enable_url_upload']
+ enable_text_upload = kwargs['enable_text_upload']
+ use_openai_embedding = kwargs['use_openai_embedding']
+ hf_embedding_model = kwargs['hf_embedding_model']
+ load_db_if_exists = kwargs['load_db_if_exists']
+ migrate_embedding_model = kwargs['migrate_embedding_model']
+ auto_migrate_db = kwargs['auto_migrate_db']
+ captions_model = kwargs['captions_model']
+ caption_loader = kwargs['caption_loader']
+ doctr_loader = kwargs['doctr_loader']
+
+ n_jobs = kwargs['n_jobs']
+ verbose = kwargs['verbose']
+
+ # for dynamic state per user session in gradio
+ model_state0 = kwargs['model_state0']
+ score_model_state0 = kwargs['score_model_state0']
+ my_db_state0 = kwargs['my_db_state0']
+ selection_docs_state0 = kwargs['selection_docs_state0']
+ visible_models_state0 = kwargs['visible_models_state0']
+ # For Heap analytics
+ is_heap_analytics_enabled = kwargs['enable_heap_analytics']
+ heap_app_id = kwargs['heap_app_id']
+
+ # easy update of kwargs needed for evaluate() etc.
+ queue = True
+ allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
+ allow_upload_api = allow_api and allow_upload
+
+ kwargs.update(locals())
+
+ # import control
+ if kwargs['langchain_mode'] != 'Disabled':
+ from gpt_langchain import file_types, have_arxiv
+ else:
+ have_arxiv = False
+ file_types = []
+
+ if 'mbart-' in kwargs['model_lower']:
+ instruction_label_nochat = "Text to translate"
+ else:
+ instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
+ " use Enter for multiple input lines)"
+
+ title = 'h2oGPT'
+ if kwargs['visible_h2ogpt_header']:
+ description = """h2oGPT LLM Leaderboard LLM Studio
CodeLlama
🤗 Models"""
+ else:
+ description = None
+ description_bottom = "If this host is busy, try
[Multi-Model](https://gpt.h2o.ai)
[CodeLlama](https://codellama.h2o.ai)
[Llama2 70B](https://llama.h2o.ai)
[Falcon 40B](https://falcon.h2o.ai)
[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)
[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)
"
+ if is_hf:
+ description_bottom += ''''''
+ task_info_md = ''
+ css_code = get_css(kwargs)
+
+ if kwargs['gradio_offline_level'] >= 0:
+ # avoid GoogleFont that pulls from internet
+ if kwargs['gradio_offline_level'] == 1:
+ # front end would still have to download fonts or have cached it at some point
+ base_font = 'Source Sans Pro'
+ else:
+ base_font = 'Helvetica'
+ theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
+ font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
+ else:
+ theme_kwargs = dict()
+ if kwargs['gradio_size'] == 'xsmall':
+ theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
+ elif kwargs['gradio_size'] in [None, 'small']:
+ theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
+ radius_size=gr.themes.sizes.spacing_sm))
+ elif kwargs['gradio_size'] == 'large':
+ theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg),
+ radius_size=gr.themes.sizes.spacing_lg)
+ elif kwargs['gradio_size'] == 'medium':
+ theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md,
+ radius_size=gr.themes.sizes.spacing_md))
+
+ theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
+ demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
+ callback = gr.CSVLogger()
+
+ model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
+ if kwargs['base_model'].strip() not in model_options0:
+ model_options0 = [kwargs['base_model'].strip()] + model_options0
+ lora_options = kwargs['extra_lora_options']
+ if kwargs['lora_weights'].strip() not in lora_options:
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
+ server_options = kwargs['extra_server_options']
+ if kwargs['inference_server'].strip() not in server_options:
+ server_options = [kwargs['inference_server'].strip()] + server_options
+ if os.getenv('OPENAI_API_KEY'):
+ if 'openai_chat' not in server_options:
+ server_options += ['openai_chat']
+ if 'openai' not in server_options:
+ server_options += ['openai']
+
+ # always add in no lora case
+ # add fake space so doesn't go away in gradio dropdown
+ model_options0 = [no_model_str] + sorted(model_options0)
+ lora_options = [no_lora_str] + sorted(lora_options)
+ server_options = [no_server_str] + sorted(server_options)
+ # always add in no model case so can free memory
+ # add fake space so doesn't go away in gradio dropdown
+
+ # transcribe, will be detranscribed before use by evaluate()
+ if not kwargs['base_model'].strip():
+ kwargs['base_model'] = no_model_str
+
+ if not kwargs['lora_weights'].strip():
+ kwargs['lora_weights'] = no_lora_str
+
+ if not kwargs['inference_server'].strip():
+ kwargs['inference_server'] = no_server_str
+
+ # transcribe for gradio
+ kwargs['gpu_id'] = str(kwargs['gpu_id'])
+
+ no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
+ 'base_model') else no_model_msg
+ output_label0_model2 = no_model_msg
+
+ def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
+ if not prompt_type1 or which_model != 0:
+ # keep prompt_type and prompt_dict in sync if possible
+ prompt_type1 = kwargs.get('prompt_type', prompt_type1)
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
+ # prefer model specific prompt type instead of global one
+ if not prompt_type1 or which_model != 0:
+ prompt_type1 = model_state1.get('prompt_type', prompt_type1)
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
+
+ if not prompt_dict1 or which_model != 0:
+ # if still not defined, try to get
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
+ if not prompt_dict1 or which_model != 0:
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
+ return prompt_type1, prompt_dict1
+
+ def visible_models_to_model_choice(visible_models1):
+ if isinstance(visible_models1, list):
+ assert len(
+ visible_models1) >= 1, "Invalid visible_models1=%s, can only be single entry" % visible_models1
+ # just take first
+ model_active_choice1 = visible_models1[0]
+ elif isinstance(visible_models1, (str, int)):
+ model_active_choice1 = visible_models1
+ else:
+ assert isinstance(visible_models1, type(None)), "Invalid visible_models1=%s" % visible_models1
+ model_active_choice1 = visible_models1
+ if model_active_choice1 is not None:
+ if isinstance(model_active_choice1, str):
+ base_model_list = [x['base_model'] for x in model_states]
+ if model_active_choice1 in base_model_list:
+ # if dups, will just be first one
+ model_active_choice1 = base_model_list.index(model_active_choice1)
+ else:
+ # NOTE: Could raise, but sometimes raising in certain places fails too hard and requires UI restart
+ model_active_choice1 = 0
+ else:
+ model_active_choice1 = 0
+ return model_active_choice1
+
+ default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
+ # ensure prompt_type consistent with prep_bot(), so nochat API works same way
+ default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
+ update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
+ model_state1=model_state0,
+ which_model=visible_models_to_model_choice(kwargs['visible_models']))
+ for k in no_default_param_names:
+ default_kwargs[k] = ''
+
+ def dummy_fun(x):
+ # need dummy function to block new input from being sent until output is done,
+ # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
+ return x
+
+ def update_auth_selection(auth_user, selection_docs_state1, save=False):
+ # in-place update of both
+ if 'selection_docs_state' not in auth_user:
+ auth_user['selection_docs_state'] = selection_docs_state0
+ for k, v in auth_user['selection_docs_state'].items():
+ if isinstance(selection_docs_state1[k], dict):
+ if save:
+ auth_user['selection_docs_state'][k].clear()
+ auth_user['selection_docs_state'][k].update(selection_docs_state1[k])
+ else:
+ selection_docs_state1[k].clear()
+ selection_docs_state1[k].update(auth_user['selection_docs_state'][k])
+ elif isinstance(selection_docs_state1[k], list):
+ if save:
+ auth_user['selection_docs_state'][k].clear()
+ auth_user['selection_docs_state'][k].extend(selection_docs_state1[k])
+ else:
+ selection_docs_state1[k].clear()
+ selection_docs_state1[k].extend(auth_user['selection_docs_state'][k])
+ else:
+ raise RuntimeError("Bad type: %s" % selection_docs_state1[k])
+
+ # BEGIN AUTH THINGS
+ def auth_func(username1, password1, auth_pairs=None, auth_filename=None,
+ auth_access=None,
+ auth_freeze=None,
+ guest_name=None,
+ selection_docs_state1=None,
+ selection_docs_state00=None,
+ **kwargs):
+ assert auth_freeze is not None
+ if selection_docs_state1 is None:
+ selection_docs_state1 = selection_docs_state00
+ assert selection_docs_state1 is not None
+ assert auth_filename and isinstance(auth_filename, str), "Auth file must be a non-empty string, got: %s" % str(
+ auth_filename)
+ if auth_access == 'open' and username1 == guest_name:
+ return True
+ if username1 == '':
+ # some issue with login
+ return False
+ with filelock.FileLock(auth_filename + '.lock'):
+ auth_dict = {}
+ if os.path.isfile(auth_filename):
+ try:
+ with open(auth_filename, 'rt') as f:
+ auth_dict = json.load(f)
+ except json.decoder.JSONDecodeError as e:
+ print("Auth exception: %s" % str(e), flush=True)
+ shutil.move(auth_filename, auth_filename + '.bak' + str(uuid.uuid4()))
+ auth_dict = {}
+ if username1 in auth_dict and username1 in auth_pairs:
+ if password1 == auth_dict[username1]['password'] and password1 == auth_pairs[username1]:
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ return False
+ elif username1 in auth_dict:
+ if password1 == auth_dict[username1]['password']:
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ return False
+ elif username1 in auth_pairs:
+ # copy over CLI auth to file so only one state to manage
+ auth_dict[username1] = dict(password=auth_pairs[username1], userid=str(uuid.uuid4()))
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ return True
+ else:
+ if auth_access == 'closed':
+ return False
+ # open access
+ auth_dict[username1] = dict(password=password1, userid=str(uuid.uuid4()))
+ auth_user = auth_dict[username1]
+ update_auth_selection(auth_user, selection_docs_state1)
+ save_auth_dict(auth_dict, auth_filename)
+ if auth_access == 'open':
+ return True
+ else:
+ raise RuntimeError("Invalid auth_access: %s" % auth_access)
+
+ def auth_func_open(*args, **kwargs):
+ return True
+
+ def get_username(requests_state1):
+ username1 = None
+ if 'username' in requests_state1:
+ username1 = requests_state1['username']
+ return username1
+
+ def get_userid_auth_func(requests_state1, auth_filename=None, auth_access=None, guest_name=None, **kwargs):
+ if auth_filename and isinstance(auth_filename, str):
+ username1 = get_username(requests_state1)
+ if username1:
+ if username1 == guest_name:
+ return str(uuid.uuid4())
+ with filelock.FileLock(auth_filename + '.lock'):
+ if os.path.isfile(auth_filename):
+ with open(auth_filename, 'rt') as f:
+ auth_dict = json.load(f)
+ if username1 in auth_dict:
+ return auth_dict[username1]['userid']
+ # if here, then not persistently associated with username1,
+ # but should only be one-time asked if going to persist within a single session!
+ return str(uuid.uuid4())
+
+ get_userid_auth = functools.partial(get_userid_auth_func,
+ auth_filename=kwargs['auth_filename'],
+ auth_access=kwargs['auth_access'],
+ guest_name=kwargs['guest_name'],
+ )
+ if kwargs['auth_access'] == 'closed':
+ auth_message1 = "Closed access"
+ else:
+ auth_message1 = "WELCOME! Open access" \
+ " (%s/%s or any unique user/pass)" % (kwargs['guest_name'], kwargs['guest_name'])
+
+ if kwargs['auth_message'] is not None:
+ auth_message = kwargs['auth_message']
+ else:
+ auth_message = auth_message1
+
+ # always use same callable
+ auth_pairs0 = {}
+ if isinstance(kwargs['auth'], list):
+ for k, v in kwargs['auth']:
+ auth_pairs0[k] = v
+ authf = functools.partial(auth_func,
+ auth_pairs=auth_pairs0,
+ auth_filename=kwargs['auth_filename'],
+ auth_access=kwargs['auth_access'],
+ auth_freeze=kwargs['auth_freeze'],
+ guest_name=kwargs['guest_name'],
+ selection_docs_state00=copy.deepcopy(selection_docs_state0))
+
+ def get_request_state(requests_state1, request, db1s):
+ # if need to get state, do it now
+ if not requests_state1:
+ requests_state1 = requests_state0.copy()
+ if requests:
+ if not requests_state1.get('headers', '') and hasattr(request, 'headers'):
+ requests_state1.update(request.headers)
+ if not requests_state1.get('host', '') and hasattr(request, 'host'):
+ requests_state1.update(dict(host=request.host))
+ if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
+ requests_state1.update(dict(host2=request.client.host))
+ if not requests_state1.get('username', '') and hasattr(request, 'username'):
+ # use already-defined username instead of keep changing to new uuid
+ # should be same as in requests_state1
+ db_username = get_username_direct(db1s)
+ requests_state1.update(dict(username=request.username or db_username or str(uuid.uuid4())))
+ requests_state1 = {str(k): str(v) for k, v in requests_state1.items()}
+ return requests_state1
+
+ def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
+ requests_state1 = get_request_state(requests_state1, request, db1s)
+ set_userid(db1s, requests_state1, get_userid_auth)
+ args_list = [db1s, requests_state1] + list(args)
+ return tuple(args_list)
+
+ # END AUTH THINGS
+
+ def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
+ allow = False
+ allow |= langchain_action1 not in LangChainAction.QUERY.value
+ allow |= document_subset1 in DocumentSubset.TopKSources.name
+ if langchain_mode1 in [LangChainMode.LLM.value]:
+ allow = False
+ return allow
+
+ image_loaders_options0, image_loaders_options, \
+ pdf_loaders_options0, pdf_loaders_options, \
+ url_loaders_options0, url_loaders_options = lg_to_gr(**kwargs)
+ jq_schema0 = '.[]'
+
+ with demo:
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
+ # https://github.com/gradio-app/gradio/issues/3558
+ model_state = gr.State(
+ dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
+ base_model=kwargs['base_model'],
+ tokenizer_base_model=kwargs['tokenizer_base_model'],
+ lora_weights=kwargs['lora_weights'],
+ inference_server=kwargs['inference_server'],
+ prompt_type=kwargs['prompt_type'],
+ prompt_dict=kwargs['prompt_dict'],
+ visible_models=kwargs['visible_models'],
+ h2ogpt_key=kwargs['h2ogpt_key'],
+ )
+ )
+
+ def update_langchain_mode_paths(selection_docs_state1):
+ dup = selection_docs_state1['langchain_mode_paths'].copy()
+ for k, v in dup.items():
+ if k not in selection_docs_state1['langchain_modes']:
+ selection_docs_state1['langchain_mode_paths'].pop(k)
+ for k in selection_docs_state1['langchain_modes']:
+ if k not in selection_docs_state1['langchain_mode_types']:
+ # if didn't specify shared, then assume scratch if didn't login or personal if logged in
+ selection_docs_state1['langchain_mode_types'][k] = LangChainTypes.PERSONAL.value
+ return selection_docs_state1
+
+ # Setup some gradio states for per-user dynamic state
+ model_state2 = gr.State(kwargs['model_state_none'].copy())
+ model_options_state = gr.State([model_options0])
+ lora_options_state = gr.State([lora_options])
+ server_options_state = gr.State([server_options])
+ my_db_state = gr.State(my_db_state0)
+ chat_state = gr.State({})
+ docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
+ docs_state0 = []
+ [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
+ docs_state = gr.State(docs_state0)
+ viewable_docs_state0 = []
+ viewable_docs_state = gr.State(viewable_docs_state0)
+ selection_docs_state0 = update_langchain_mode_paths(selection_docs_state0)
+ selection_docs_state = gr.State(selection_docs_state0)
+ requests_state0 = dict(headers='', host='', username='')
+ requests_state = gr.State(requests_state0)
+
+ if description is not None:
+ gr.Markdown(f"""
+ {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
+ """)
+
+ # go button visible if
+ base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
+ go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
+
+ nas = ' '.join(['NA'] * len(kwargs['model_states']))
+ res_value = "Response Score: NA" if not kwargs[
+ 'model_lock'] else "Response Scores: %s" % nas
+
+ user_can_do_sum = kwargs['langchain_mode'] != LangChainMode.DISABLED.value and \
+ (kwargs['visible_side_bar'] or kwargs['visible_system_tab'])
+ if user_can_do_sum:
+ extra_prompt_form = ". For summarization, no query required, just click submit"
+ else:
+ extra_prompt_form = ""
+ if kwargs['input_lines'] > 1:
+ instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
+ else:
+ instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
+
+ def get_langchain_choices(selection_docs_state1):
+ langchain_modes = selection_docs_state1['langchain_modes']
+
+ if is_hf:
+ # don't show 'wiki' since only usually useful for internal testing at moment
+ no_show_modes = ['Disabled', 'wiki']
+ else:
+ no_show_modes = ['Disabled']
+ allowed_modes = langchain_modes.copy()
+ # allowed_modes = [x for x in allowed_modes if x in dbs]
+ allowed_modes += ['LLM']
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
+ allowed_modes += ['MyData']
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
+ allowed_modes += ['UserData']
+ choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
+ return choices
+
+ def get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=None):
+ langchain_choices1 = get_langchain_choices(selection_docs_state1)
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
+ langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if k in langchain_choices1}
+ if langchain_mode_paths:
+ langchain_mode_paths = langchain_mode_paths.copy()
+ for langchain_mode1 in langchain_modes_non_db:
+ langchain_mode_paths.pop(langchain_mode1, None)
+ df1 = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
+ df1.columns = ['Collection', 'Path']
+ df1 = df1.set_index('Collection')
+ else:
+ df1 = pd.DataFrame(None)
+ langchain_mode_types = selection_docs_state1['langchain_mode_types']
+ langchain_mode_types = {k: v for k, v in langchain_mode_types.items() if k in langchain_choices1}
+ if langchain_mode_types:
+ langchain_mode_types = langchain_mode_types.copy()
+ for langchain_mode1 in langchain_modes_non_db:
+ langchain_mode_types.pop(langchain_mode1, None)
+
+ df2 = pd.DataFrame.from_dict(langchain_mode_types.items(), orient='columns')
+ df2.columns = ['Collection', 'Type']
+ df2 = df2.set_index('Collection')
+
+ from src.gpt_langchain import get_persist_directory, load_embed
+ persist_directory_dict = {}
+ embed_dict = {}
+ chroma_version_dict = {}
+ for langchain_mode3 in langchain_mode_types:
+ langchain_type3 = langchain_mode_types.get(langchain_mode3, LangChainTypes.EITHER.value)
+ persist_directory3, langchain_type3 = get_persist_directory(langchain_mode3,
+ langchain_type=langchain_type3,
+ db1s=db1s, dbs=dbs1)
+ got_embedding3, use_openai_embedding3, hf_embedding_model3 = load_embed(
+ persist_directory=persist_directory3)
+ persist_directory_dict[langchain_mode3] = persist_directory3
+ embed_dict[langchain_mode3] = 'OpenAI' if not hf_embedding_model3 else hf_embedding_model3
+
+ if os.path.isfile(os.path.join(persist_directory3, 'chroma.sqlite3')):
+ chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4'
+ elif os.path.isdir(os.path.join(persist_directory3, 'index')):
+ chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4'
+ elif not os.listdir(persist_directory3):
+ if db_type == 'chroma':
+ chroma_version_dict[langchain_mode3] = 'ChromaDB>=0.4' # will be
+ elif db_type == 'chroma_old':
+ chroma_version_dict[langchain_mode3] = 'ChromaDB<0.4' # will be
+ else:
+ chroma_version_dict[langchain_mode3] = 'Weaviate' # will be
+ if isinstance(hf_embedding_model, dict):
+ hf_embedding_model3 = hf_embedding_model['name']
+ else:
+ hf_embedding_model3 = hf_embedding_model
+ assert isinstance(hf_embedding_model3, str)
+ embed_dict[langchain_mode3] = hf_embedding_model3 # will be
+ else:
+ chroma_version_dict[langchain_mode3] = 'Weaviate'
+
+ df3 = pd.DataFrame.from_dict(persist_directory_dict.items(), orient='columns')
+ df3.columns = ['Collection', 'Directory']
+ df3 = df3.set_index('Collection')
+
+ df4 = pd.DataFrame.from_dict(embed_dict.items(), orient='columns')
+ df4.columns = ['Collection', 'Embedding']
+ df4 = df4.set_index('Collection')
+
+ df5 = pd.DataFrame.from_dict(chroma_version_dict.items(), orient='columns')
+ df5.columns = ['Collection', 'DB']
+ df5 = df5.set_index('Collection')
+ else:
+ df2 = pd.DataFrame(None)
+ df3 = pd.DataFrame(None)
+ df4 = pd.DataFrame(None)
+ df5 = pd.DataFrame(None)
+ df_list = [df2, df1, df3, df4, df5]
+ df_list = [x for x in df_list if x.shape[1] > 0]
+ if len(df_list) > 1:
+ df = df_list[0].join(df_list[1:]).replace(np.nan, '').reset_index()
+ elif len(df_list) == 0:
+ df = df_list[0].replace(np.nan, '').reset_index()
+ else:
+ df = pd.DataFrame(None)
+ return df
+
+ normal_block = gr.Row(visible=not base_wanted, equal_height=False, elem_id="col_container")
+ with normal_block:
+ side_bar = gr.Column(elem_id="sidebar", scale=1, min_width=100, visible=kwargs['visible_side_bar'])
+ with side_bar:
+ with gr.Accordion("Chats", open=False, visible=True):
+ radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
+ visible=True, interactive=True,
+ type='value')
+ upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
+ with gr.Accordion("Upload", open=False, visible=upload_visible):
+ with gr.Column():
+ with gr.Row(equal_height=False):
+ fileup_output = gr.File(show_label=False,
+ file_types=['.' + x for x in file_types],
+ # file_types=['*', '*.*'], # for iPhone etc. needs to be unconstrained else doesn't work with extension-based restrictions
+ file_count="multiple",
+ scale=1,
+ min_width=0,
+ elem_id="warning", elem_classes="feedback",
+ )
+ fileup_output_text = gr.Textbox(visible=False)
+ max_quality = gr.Checkbox(label="Maximum Ingest Quality", value=kwargs['max_quality'],
+ visible=not is_public)
+ url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
+ url_label = 'URL/ArXiv' if have_arxiv else 'URL'
+ url_text = gr.Textbox(label=url_label,
+ # placeholder="Enter Submits",
+ max_lines=1,
+ interactive=True)
+ text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
+ user_text_text = gr.Textbox(label='Paste Text',
+ # placeholder="Enter Submits",
+ interactive=True,
+ visible=text_visible)
+ github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
+ database_visible = kwargs['langchain_mode'] != 'Disabled'
+ with gr.Accordion("Resources", open=False, visible=database_visible):
+ langchain_choices0 = get_langchain_choices(selection_docs_state0)
+ langchain_mode = gr.Radio(
+ langchain_choices0,
+ value=kwargs['langchain_mode'],
+ label="Collections",
+ show_label=True,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ min_width=100)
+ add_chat_history_to_context = gr.Checkbox(label="Chat History",
+ value=kwargs['add_chat_history_to_context'])
+ add_search_to_context = gr.Checkbox(label="Web Search",
+ value=kwargs['add_search_to_context'],
+ visible=os.environ.get('SERPAPI_API_KEY') is not None \
+ and have_serpapi)
+ document_subset = gr.Radio([x.name for x in DocumentSubset],
+ label="Subset",
+ value=DocumentSubset.Relevant.name,
+ interactive=True,
+ )
+ allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
+ langchain_action = gr.Radio(
+ allowed_actions,
+ value=allowed_actions[0] if len(allowed_actions) > 0 else None,
+ label="Action",
+ visible=True)
+ allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
+ if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.JSON.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.JSON.value)
+ if os.getenv('OPENAI_API_KEY') is None and LangChainAgent.PYTHON.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.PYTHON.value)
+ if LangChainAgent.PANDAS.value in allowed_agents:
+ allowed_agents.remove(LangChainAgent.PANDAS.value)
+ langchain_agents = gr.Dropdown(
+ allowed_agents,
+ value=None,
+ label="Agents",
+ multiselect=True,
+ interactive=True,
+ visible=True,
+ elem_id="langchain_agents",
+ filterable=False)
+ visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
+ 'large_file_count_mode']
+ row_doc_track = gr.Row(visible=visible_doc_track)
+ with row_doc_track:
+ if kwargs['langchain_mode'] in langchain_modes_non_db:
+ doc_counts_str = "Pure LLM Mode"
+ else:
+ doc_counts_str = "Name: %s\nDocs: Unset\nChunks: Unset" % kwargs['langchain_mode']
+ text_doc_count = gr.Textbox(lines=3, label="Doc Counts", value=doc_counts_str,
+ visible=visible_doc_track)
+ text_file_last = gr.Textbox(lines=1, label="Newest Doc", value=None, visible=visible_doc_track)
+ text_viewable_doc_count = gr.Textbox(lines=2, label=None, visible=False)
+ col_tabs = gr.Column(elem_id="col-tabs", scale=10)
+ with col_tabs, gr.Tabs():
+ if kwargs['chat_tables']:
+ chat_tab = gr.Row(visible=True)
+ else:
+ chat_tab = gr.TabItem("Chat") \
+ if kwargs['visible_chat_tab'] else gr.Row(visible=False)
+ with chat_tab:
+ if kwargs['langchain_mode'] == 'Disabled':
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True,
+ visible=not kwargs['chat'])
+ else:
+ # text looks a bit worse, but HTML links work
+ text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat'])
+ with gr.Row():
+ # NOCHAT
+ instruction_nochat = gr.Textbox(
+ lines=kwargs['input_lines'],
+ label=instruction_label_nochat,
+ placeholder=kwargs['placeholder_instruction'],
+ visible=not kwargs['chat'],
+ )
+ iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
+ placeholder=kwargs['placeholder_input'],
+ value=kwargs['iinput'],
+ visible=not kwargs['chat'])
+ submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat'])
+ flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat'])
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False,
+ visible=not kwargs['chat'])
+ submit_nochat_api = gr.Button("Submit nochat API", visible=False)
+ submit_nochat_api_plain = gr.Button("Submit nochat API Plain", visible=False)
+ inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
+ text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
+ show_copy_button=True)
+
+ visible_upload = (allow_upload_to_user_data or
+ allow_upload_to_my_data) and \
+ kwargs['langchain_mode'] != 'Disabled'
+ # CHAT
+ col_chat = gr.Column(visible=kwargs['chat'])
+ with col_chat:
+ with gr.Row():
+ with gr.Column(scale=50):
+ with gr.Row(elem_id="prompt-form-row"):
+ label_instruction = 'Ask anything'
+ instruction = gr.Textbox(
+ lines=kwargs['input_lines'],
+ label=label_instruction,
+ placeholder=instruction_label,
+ info=None,
+ elem_id='prompt-form',
+ container=True,
+ )
+ attach_button = gr.UploadButton(
+ elem_id="attach-button" if visible_upload else None,
+ value="",
+ label="Upload File(s)",
+ size="sm",
+ min_width=24,
+ file_types=['.' + x for x in file_types],
+ file_count="multiple",
+ visible=visible_upload)
+
+ submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
+ with submit_buttons:
+ mw1 = 50
+ mw2 = 50
+ with gr.Column(min_width=mw1):
+ submit = gr.Button(value='Submit', variant='primary', size='sm',
+ min_width=mw1)
+ stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
+ min_width=mw1)
+ save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
+ with gr.Column(min_width=mw2):
+ retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
+ undo = gr.Button("Undo", size='sm', min_width=mw2)
+ clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)
+
+ visible_model_choice = bool(kwargs['model_lock']) and \
+ len(model_states) > 1 and \
+ kwargs['visible_visible_models']
+ with gr.Row(visible=visible_model_choice):
+ visible_models = gr.Dropdown(kwargs['all_models'],
+ label="Visible Models",
+ value=visible_models_state0,
+ interactive=True,
+ multiselect=True,
+ visible=visible_model_choice,
+ elem_id="visible-models",
+ filterable=False,
+ )
+
+ text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
+ **kwargs)
+
+ with gr.Row():
+ with gr.Column(visible=kwargs['score_model']):
+ score_text = gr.Textbox(res_value,
+ show_label=False,
+ visible=True)
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
+ visible=False and not kwargs['model_lock'])
+
+ doc_selection_tab = gr.TabItem("Document Selection") \
+ if kwargs['visible_doc_selection_tab'] else gr.Row(visible=False)
+ with doc_selection_tab:
+ if kwargs['langchain_mode'] in langchain_modes_non_db:
+ dlabel1 = 'Choose Resources->Collections and Pick Collection'
+ active_collection = gr.Markdown(value="#### Not Chatting with Any Collection\n%s" % dlabel1)
+ else:
+ dlabel1 = 'Select Subset of Document(s) for Chat with Collection: %s' % kwargs['langchain_mode']
+ active_collection = gr.Markdown(
+ value="#### Chatting with Collection: %s" % kwargs['langchain_mode'])
+ document_choice = gr.Dropdown(docs_state0,
+ label=dlabel1,
+ value=[DocumentChoice.ALL.value],
+ interactive=True,
+ multiselect=True,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ )
+ sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
+ with gr.Row():
+ with gr.Column(scale=1):
+ get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
+ visible=sources_visible and kwargs['large_file_count_mode'])
+ # handle API get sources
+ get_sources_api_btn = gr.Button(visible=False)
+ get_sources_api_text = gr.Textbox(visible=False)
+
+ get_document_api_btn = gr.Button(visible=False)
+ get_document_api_text = gr.Textbox(visible=False)
+
+ show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
+ visible=sources_visible and kwargs['large_file_count_mode'])
+ delete_sources_btn = gr.Button(value="Delete Selected Sources from DB", scale=0, size='sm',
+ visible=sources_visible)
+ refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
+ size='sm',
+ visible=sources_visible and allow_upload_to_user_data)
+ with gr.Column(scale=4):
+ pass
+ visible_add_remove_collection = visible_upload
+ with gr.Row():
+ with gr.Column(scale=1):
+ add_placeholder = "e.g. UserData2, shared, user_path2" \
+ if not is_public else "e.g. MyData2, personal (optional)"
+ remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
+ new_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Add Collection',
+ placeholder=add_placeholder,
+ interactive=True)
+ remove_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Remove Collection from UI',
+ placeholder=remove_placeholder,
+ interactive=True)
+ purge_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
+ label='Purge Collection (UI, DB, & source files)',
+ placeholder=remove_placeholder,
+ interactive=True)
+ sync_sources_btn = gr.Button(
+ value="Synchronize DB and UI [only required if did not login and have shared docs]",
+ scale=0, size='sm',
+ visible=sources_visible and allow_upload_to_user_data and not kwargs[
+ 'large_file_count_mode'])
+ load_langchain = gr.Button(
+ value="Load Collections State [only required if logged in another user ", scale=0,
+ size='sm',
+ visible=False and allow_upload_to_user_data and
+ kwargs['langchain_mode'] != 'Disabled')
+ with gr.Column(scale=5):
+ if kwargs['langchain_mode'] != 'Disabled' and visible_add_remove_collection:
+ df0 = get_df_langchain_mode_paths(selection_docs_state0, None, dbs1=dbs)
+ else:
+ df0 = pd.DataFrame(None)
+ langchain_mode_path_text = gr.Dataframe(value=df0,
+ visible=visible_add_remove_collection,
+ label='LangChain Mode-Path',
+ show_label=False,
+ interactive=False)
+
+ sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
+ equal_height=False)
+ with sources_row:
+ with gr.Column(scale=1):
+ file_source = gr.File(interactive=False,
+ label="Download File w/Sources")
+ with gr.Column(scale=2):
+ sources_text = gr.HTML(label='Sources Added', interactive=False)
+
+ doc_exception_text = gr.Textbox(value="", label='Document Exceptions',
+ interactive=False,
+ visible=kwargs['langchain_mode'] != 'Disabled')
+ file_types_str = ' '.join(file_types) + ' URL ArXiv TEXT'
+ gr.Textbox(value=file_types_str, label='Document Types Supported',
+ lines=2,
+ interactive=False,
+ visible=kwargs['langchain_mode'] != 'Disabled')
+
+ doc_view_tab = gr.TabItem("Document Viewer") \
+ if kwargs['visible_doc_view_tab'] else gr.Row(visible=False)
+ with doc_view_tab:
+ with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled'):
+ with gr.Column(scale=2):
+ get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0,
+ size='sm',
+ visible=sources_visible and kwargs[
+ 'large_file_count_mode'])
+ view_document_choice = gr.Dropdown(viewable_docs_state0,
+ label="Select Single Document to View",
+ value=None,
+ interactive=True,
+ multiselect=False,
+ visible=True,
+ )
+ info_view_raw = "Raw text shown if render of original doc fails"
+ if is_public:
+ info_view_raw += " (Up to %s chunks in public portal)" % kwargs['max_raw_chunks']
+ view_raw_text_checkbox = gr.Checkbox(label="View Database Text", value=False,
+ info=info_view_raw,
+ visible=kwargs['db_type'] in ['chroma', 'chroma_old'])
+ with gr.Column(scale=4):
+ pass
+ doc_view = gr.HTML(visible=False)
+ doc_view2 = gr.Dataframe(visible=False)
+ doc_view3 = gr.JSON(visible=False)
+ doc_view4 = gr.Markdown(visible=False)
+ doc_view5 = gr.HTML(visible=False)
+
+ chat_tab = gr.TabItem("Chat History") \
+ if kwargs['visible_chat_history_tab'] else gr.Row(visible=False)
+ with chat_tab:
+ with gr.Row():
+ with gr.Column(scale=1):
+ remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm')
+ flag_btn = gr.Button("Flag Current Chat", size='sm')
+ export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
+ with gr.Column(scale=4):
+ pass
+ with gr.Row():
+ chats_file = gr.File(interactive=False, label="Download Exported Chats")
+ chatsup_output = gr.File(label="Upload Chat File(s)",
+ file_types=['.json'],
+ file_count='multiple',
+ elem_id="warning", elem_classes="feedback")
+ with gr.Row():
+ if 'mbart-' in kwargs['model_lower']:
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['src_lang'],
+ label="Input Language")
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['tgt_lang'],
+ label="Output Language")
+
+ chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions',
+ interactive=False)
+ expert_tab = gr.TabItem("Expert") \
+ if kwargs['visible_expert_tab'] else gr.Row(visible=False)
+ with expert_tab:
+ with gr.Row():
+ with gr.Column():
+ prompt_type = gr.Dropdown(prompt_types_strings,
+ value=kwargs['prompt_type'], label="Prompt Type",
+ visible=not kwargs['model_lock'],
+ interactive=not is_public,
+ )
+ prompt_type2 = gr.Dropdown(prompt_types_strings,
+ value=kwargs['prompt_type'], label="Prompt Type Model 2",
+ visible=False and not kwargs['model_lock'],
+ interactive=not is_public)
+ system_prompt = gr.Textbox(label="System Prompt",
+ info="If 'auto', then uses model's system prompt,"
+ " else use this message."
+ " If empty, no system message is used",
+ value=kwargs['system_prompt'])
+ context = gr.Textbox(lines=2, label="System Pre-Context",
+ info="Directly pre-appended without prompt processing (before Pre-Conversation)",
+ value=kwargs['context'])
+ chat_conversation = gr.Textbox(lines=2, label="Pre-Conversation",
+ info="Pre-append conversation for instruct/chat models as List of tuple of (human, bot)",
+ value=kwargs['chat_conversation'])
+ text_context_list = gr.Textbox(lines=2, label="Text Doc Q/A",
+ info="List of strings, for document Q/A, for bypassing database (i.e. also works in LLM Mode)",
+ value=kwargs['chat_conversation'],
+ visible=not is_public, # primarily meant for API
+ )
+ iinput = gr.Textbox(lines=2, label="Input for Instruct prompt types",
+ info="If given for document query, added after query",
+ value=kwargs['iinput'],
+ placeholder=kwargs['placeholder_input'],
+ interactive=not is_public)
+ with gr.Column():
+ pre_prompt_query = gr.Textbox(label="Query Pre-Prompt",
+ info="Added before documents",
+ value=kwargs['pre_prompt_query'] or '')
+ prompt_query = gr.Textbox(label="Query Prompt",
+ info="Added after documents",
+ value=kwargs['prompt_query'] or '')
+ pre_prompt_summary = gr.Textbox(label="Summary Pre-Prompt",
+ info="Added before documents",
+ value=kwargs['pre_prompt_summary'] or '')
+ prompt_summary = gr.Textbox(label="Summary Prompt",
+ info="Added after documents (if query given, 'Focusing on {query}, ' is pre-appended)",
+ value=kwargs['prompt_summary'] or '')
+ with gr.Row(visible=not is_public):
+ image_loaders = gr.CheckboxGroup(image_loaders_options,
+ label="Force Image Reader",
+ value=image_loaders_options0)
+ pdf_loaders = gr.CheckboxGroup(pdf_loaders_options,
+ label="Force PDF Reader",
+ value=pdf_loaders_options0)
+ url_loaders = gr.CheckboxGroup(url_loaders_options,
+ label="Force URL Reader", value=url_loaders_options0)
+ jq_schema = gr.Textbox(label="JSON jq_schema", value=jq_schema0)
+
+ min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
+ top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1,
+ value=kwargs['top_k_docs'],
+ label=label_top_k_docs,
+ # info="For LangChain",
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ chunk_size = gr.Number(value=kwargs['chunk_size'],
+ label="Chunk size for document chunking",
+ info="For LangChain (ignored if chunk=False)",
+ minimum=128,
+ maximum=2048,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public,
+ precision=0)
+ docs_ordering_type = gr.Radio(
+ docs_ordering_types,
+ value=kwargs['docs_ordering_type'],
+ label="Document Sorting in LLM Context",
+ visible=True)
+ chunk = gr.components.Checkbox(value=kwargs['chunk'],
+ label="Whether to chunk documents",
+ info="For LangChain",
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ embed = gr.components.Checkbox(value=True,
+ label="Whether to embed text",
+ info="For LangChain",
+ visible=False)
+ with gr.Row():
+ stream_output = gr.components.Checkbox(label="Stream output",
+ value=kwargs['stream_output'])
+ do_sample = gr.Checkbox(label="Sample",
+ info="Enable sampler (required for use of temperature, top_p, top_k)",
+ value=kwargs['do_sample'])
+ max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1,
+ value=min(kwargs['max_max_time'],
+ kwargs['max_time']), label="Max. time",
+ info="Max. time to search optimal output.")
+ temperature = gr.Slider(minimum=0.01, maximum=2,
+ value=kwargs['temperature'],
+ label="Temperature",
+ info="Lower is deterministic, higher more creative")
+ top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3,
+ value=kwargs['top_p'], label="Top p",
+ info="Cumulative probability of tokens to sample from")
+ top_k = gr.Slider(
+ minimum=1, maximum=100, step=1,
+ value=kwargs['top_k'], label="Top k",
+ info='Num. tokens to sample from'
+ )
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/106
+ if os.getenv('TESTINGFAIL'):
+ max_beams = 8 if not (memory_restriction_level or is_public) else 1
+ else:
+ max_beams = 1
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
+ info="Number of searches for optimal overall probability. "
+ "Uses more GPU memory/compute",
+ interactive=False, visible=max_beams > 1)
+ max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs)
+ max_new_tokens = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
+ )
+ min_new_tokens = gr.Slider(
+ minimum=0, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
+ )
+ max_new_tokens2 = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
+ min_new_tokens2 = gr.Slider(
+ minimum=0, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
+ min_max_new_tokens = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
+ label="Min. of Max output length",
+ )
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
+ value=kwargs['early_stopping'], visible=max_beams > 1)
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
+ value=kwargs['repetition_penalty'],
+ label="Repetition Penalty")
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
+ value=kwargs['num_return_sequences'],
+ label="Number Returns", info="Must be <= num_beams",
+ interactive=not is_public, visible=max_beams > 1)
+ chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
+ visible=False, # no longer support nochat in UI
+ interactive=not is_public,
+ )
+ with gr.Row():
+ count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=not is_public, size='sm')
+ chat_token_count = gr.Textbox(label="Chat Token Count Result", value=None,
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=False)
+
+ models_tab = gr.TabItem("Models") \
+ if kwargs['visible_models_tab'] and not bool(kwargs['model_lock']) else gr.Row(visible=False)
+ with models_tab:
+ load_msg = "Download/Load Model" if not is_public \
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
+ if kwargs['base_model'] not in ['', None, no_model_str]:
+ load_msg += ' [WARNING: Avoid --base_model on CLI for memory efficient Load-Unload]'
+ load_msg2 = load_msg + "(Model 2)"
+ variant_load_msg = 'primary' if not is_public else 'secondary'
+ with gr.Row():
+ n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
+ with gr.Column():
+ with gr.Row():
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
+ load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Base Model",
+ value=kwargs['base_model'])
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
+ value=kwargs['lora_weights'], visible=kwargs['show_lora'])
+ server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server",
+ value=kwargs['inference_server'], visible=not is_public)
+ max_seq_len = gr.Number(value=kwargs['max_seq_len'] or 2048,
+ minimum=128,
+ maximum=2 ** 18,
+ info="If standard LLaMa-2, choose up to 4096",
+ label="max_seq_len")
+ rope_scaling = gr.Textbox(value=str(kwargs['rope_scaling'] or {}),
+ label="rope_scaling")
+ row_llama = gr.Row(visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama')
+ with row_llama:
+ model_path_llama = gr.Textbox(value=kwargs['llamacpp_dict']['model_path_llama'],
+ lines=4,
+ label="Choose LLaMa.cpp Model Path/URL (for Base Model: llama)",
+ visible=kwargs['show_llama'])
+ n_gpu_layers = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'],
+ minimum=0, maximum=100,
+ label="LLaMa.cpp Num. GPU Layers Offloaded",
+ visible=kwargs['show_llama'])
+ n_batch = gr.Number(value=kwargs['llamacpp_dict']['n_batch'],
+ minimum=0, maximum=2048,
+ label="LLaMa.cpp Batch Size",
+ visible=kwargs['show_llama'])
+ n_gqa = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'],
+ minimum=0, maximum=32,
+ label="LLaMa.cpp Num. Group Query Attention (8 for 70B LLaMa2)",
+ visible=kwargs['show_llama'])
+ llamacpp_dict_more = gr.Textbox(value="{}",
+ lines=4,
+ label="Dict for other LLaMa.cpp/GPT4All options",
+ visible=kwargs['show_llama'])
+ row_gpt4all = gr.Row(
+ visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj',
+ 'gpt4all_llama'])
+ with row_gpt4all:
+ model_name_gptj = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'],
+ label="Choose GPT4All GPTJ Model Path/URL (for Base Model: gptj)",
+ visible=kwargs['show_gpt4all'])
+ model_name_gpt4all_llama = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'],
+ label="Choose GPT4All LLaMa Model Path/URL (for Base Model: gpt4all_llama)",
+ visible=kwargs['show_gpt4all'])
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ model_load8bit_checkbox = gr.components.Checkbox(
+ label="Load 8-bit [requires support]",
+ value=kwargs['load_8bit'], interactive=not is_public)
+ model_load4bit_checkbox = gr.components.Checkbox(
+ label="Load 4-bit [requires support]",
+ value=kwargs['load_4bit'], interactive=not is_public)
+ model_low_bit_mode = gr.Slider(value=kwargs['low_bit_mode'],
+ minimum=0, maximum=4, step=1,
+ label="low_bit_mode")
+ model_load_gptq = gr.Textbox(label="gptq", value=kwargs['load_gptq'],
+ interactive=not is_public)
+ model_load_exllama_checkbox = gr.components.Checkbox(
+ label="Load load_exllama [requires support]",
+ value=kwargs['load_exllama'], interactive=not is_public)
+ model_safetensors_checkbox = gr.components.Checkbox(
+ label="Safetensors [requires support]",
+ value=kwargs['use_safetensors'], interactive=not is_public)
+ model_revision = gr.Textbox(label="revision", value=kwargs['revision'],
+ interactive=not is_public)
+ model_use_gpu_id_checkbox = gr.components.Checkbox(
+ label="Choose Devices [If not Checked, use all GPUs]",
+ value=kwargs['use_gpu_id'], interactive=not is_public,
+ visible=n_gpus != 0)
+ model_gpu = gr.Dropdown(n_gpus_list,
+ label="GPU ID [-1 = all GPUs, if Choose is enabled]",
+ value=kwargs['gpu_id'], interactive=not is_public,
+ visible=n_gpus != 0)
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
+ interactive=False)
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
+ visible=kwargs['show_lora'], interactive=False)
+ server_used = gr.Textbox(label="Current Server",
+ value=kwargs['inference_server'],
+ visible=bool(kwargs['inference_server']) and not is_public,
+ interactive=False)
+ prompt_dict = gr.Textbox(label="Prompt (or Custom)",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
+ col_model2 = gr.Column(visible=False)
+ with col_model2:
+ with gr.Row():
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
+ load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
+ model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
+ value=no_model_str)
+ lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
+ value=no_lora_str,
+ visible=kwargs['show_lora'])
+ server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2",
+ value=no_server_str,
+ visible=not is_public)
+ max_seq_len2 = gr.Number(value=kwargs['max_seq_len'] or 2048,
+ minimum=128,
+ maximum=2 ** 18,
+ info="If standard LLaMa-2, choose up to 4096",
+ label="max_seq_len Model 2")
+ rope_scaling2 = gr.Textbox(value=str(kwargs['rope_scaling'] or {}),
+ label="rope_scaling Model 2")
+
+ row_llama2 = gr.Row(
+ visible=kwargs['show_llama'] and kwargs['base_model'] == 'llama')
+ with row_llama2:
+ model_path_llama2 = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_path_llama'],
+ label="Choose LLaMa.cpp Model 2 Path/URL (for Base Model: llama)",
+ lines=4,
+ visible=kwargs['show_llama'])
+ n_gpu_layers2 = gr.Number(value=kwargs['llamacpp_dict']['n_gpu_layers'],
+ minimum=0, maximum=100,
+ label="LLaMa.cpp Num. GPU 2 Layers Offloaded",
+ visible=kwargs['show_llama'])
+ n_batch2 = gr.Number(value=kwargs['llamacpp_dict']['n_batch'],
+ minimum=0, maximum=2048,
+ label="LLaMa.cpp Model 2 Batch Size",
+ visible=kwargs['show_llama'])
+ n_gqa2 = gr.Number(value=kwargs['llamacpp_dict']['n_gqa'],
+ minimum=0, maximum=32,
+ label="LLaMa.cpp Model 2 Num. Group Query Attention (8 for 70B LLaMa2)",
+ visible=kwargs['show_llama'])
+ llamacpp_dict_more2 = gr.Textbox(value="{}",
+ lines=4,
+ label="Model 2 Dict for other LLaMa.cpp/GPT4All options",
+ visible=kwargs['show_llama'])
+ row_gpt4all2 = gr.Row(
+ visible=kwargs['show_gpt4all'] and kwargs['base_model'] in ['gptj',
+ 'gpt4all_llama'])
+ with row_gpt4all2:
+ model_name_gptj2 = gr.Textbox(value=kwargs['llamacpp_dict']['model_name_gptj'],
+ label="Choose GPT4All GPTJ Model 2 Path/URL (for Base Model: gptj)",
+ visible=kwargs['show_gpt4all'])
+ model_name_gpt4all_llama2 = gr.Textbox(
+ value=kwargs['llamacpp_dict']['model_name_gpt4all_llama'],
+ label="Choose GPT4All LLaMa Model 2 Path/URL (for Base Model: gpt4all_llama)",
+ visible=kwargs['show_gpt4all'])
+
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ model_load8bit_checkbox2 = gr.components.Checkbox(
+ label="Load 8-bit (Model 2) [requires support]",
+ value=kwargs['load_8bit'], interactive=not is_public)
+ model_load4bit_checkbox2 = gr.components.Checkbox(
+ label="Load 4-bit (Model 2) [requires support]",
+ value=kwargs['load_4bit'], interactive=not is_public)
+ model_low_bit_mode2 = gr.Slider(value=kwargs['low_bit_mode'],
+ # ok that same as Model 1
+ minimum=0, maximum=4, step=1,
+ label="low_bit_mode (Model 2)")
+ model_load_gptq2 = gr.Textbox(label="gptq (Model 2)", value='',
+ interactive=not is_public)
+ model_load_exllama_checkbox2 = gr.components.Checkbox(
+ label="Load load_exllama (Model 2) [requires support]",
+ value=False, interactive=not is_public)
+ model_safetensors_checkbox2 = gr.components.Checkbox(
+ label="Safetensors (Model 2) [requires support]",
+ value=False, interactive=not is_public)
+ model_revision2 = gr.Textbox(label="revision (Model 2)", value='',
+ interactive=not is_public)
+ model_use_gpu_id_checkbox2 = gr.components.Checkbox(
+ label="Choose Devices (Model 2) [If not Checked, use all GPUs]",
+ value=kwargs[
+ 'use_gpu_id'], interactive=not is_public)
+ model_gpu2 = gr.Dropdown(n_gpus_list,
+ label="GPU ID (Model 2) [-1 = all GPUs, if choose is enabled]",
+ value=kwargs['gpu_id'], interactive=not is_public)
+ # no model/lora loaded ever in model2 by default
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str,
+ interactive=False)
+ lora_used2 = gr.Textbox(label="Current LORA (Model 2)", value=no_lora_str,
+ visible=kwargs['show_lora'], interactive=False)
+ server_used2 = gr.Textbox(label="Current Server (Model 2)", value=no_server_str,
+ interactive=False,
+ visible=not is_public)
+ prompt_dict2 = gr.Textbox(label="Prompt (or Custom) (Model 2)",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
+ compare_checkbox = gr.components.Checkbox(label="Compare Two Models",
+ value=kwargs['model_lock'],
+ visible=not is_public and not kwargs['model_lock'])
+ with gr.Row(visible=not kwargs['model_lock']):
+ with gr.Column(scale=50):
+ new_model = gr.Textbox(label="New Model name/path/URL", interactive=not is_public)
+ with gr.Column(scale=50):
+ new_lora = gr.Textbox(label="New LORA name/path/URL", visible=kwargs['show_lora'],
+ interactive=not is_public)
+ with gr.Column(scale=50):
+ new_server = gr.Textbox(label="New Server url:port", interactive=not is_public)
+ with gr.Row():
+ add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
+ variant=variant_load_msg,
+ size='sm', interactive=not is_public)
+ system_tab = gr.TabItem("System") \
+ if kwargs['visible_system_tab'] else gr.Row(visible=False)
+ with system_tab:
+ with gr.Row():
+ with gr.Column(scale=1):
+ side_bar_text = gr.Textbox('on' if kwargs['visible_side_bar'] else 'off',
+ visible=False, interactive=False)
+ doc_count_text = gr.Textbox('on' if kwargs['visible_doc_track'] else 'off',
+ visible=False, interactive=False)
+ submit_buttons_text = gr.Textbox('on' if kwargs['visible_submit_buttons'] else 'off',
+ visible=False, interactive=False)
+ visible_models_text = gr.Textbox('on' if kwargs['visible_visible_models'] else 'off',
+ visible=False, interactive=False)
+
+ side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
+ doc_count_btn = gr.Button("Toggle SideBar Document Count/Show Newest", variant="secondary",
+ size="sm")
+ submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
+ visible_model_btn = gr.Button("Toggle Visible Models", variant="secondary", size="sm")
+ col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
+ text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
+ step=50, label='Chat Height')
+ dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
+ with gr.Column(scale=4):
+ pass
+ system_visible0 = not is_public and not admin_pass
+ admin_row = gr.Row()
+ with admin_row:
+ with gr.Column(scale=1):
+ admin_pass_textbox = gr.Textbox(label="Admin Password",
+ type='password',
+ visible=not system_visible0)
+ with gr.Column(scale=4):
+ pass
+ system_row = gr.Row(visible=system_visible0)
+ with system_row:
+ with gr.Column():
+ with gr.Row():
+ system_btn = gr.Button(value='Get System Info', size='sm')
+ system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
+ with gr.Row():
+ system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
+ visible=not is_public)
+ system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm')
+ system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
+ visible=not is_public, show_copy_button=True)
+ with gr.Row():
+ system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
+ system_text3 = gr.Textbox(label='Hash', interactive=False,
+ visible=not is_public, show_copy_button=True)
+ system_btn4 = gr.Button(value='Get Model Names', visible=not is_public, size='sm')
+ system_text4 = gr.Textbox(label='Model Names', interactive=False,
+ visible=not is_public, show_copy_button=True)
+
+ with gr.Row():
+ zip_btn = gr.Button("Zip", size='sm')
+ zip_text = gr.Textbox(label="Zip file name", interactive=False)
+ file_output = gr.File(interactive=False, label="Zip file to Download")
+ with gr.Row():
+ s3up_btn = gr.Button("S3UP", size='sm')
+ s3up_text = gr.Textbox(label='S3UP result', interactive=False)
+
+ tos_tab = gr.TabItem("Terms of Service") \
+ if kwargs['visible_tos_tab'] else gr.Row(visible=False)
+ with tos_tab:
+ description = ""
+ description += """
DISCLAIMERS:
etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation + is_same = True + # length of conversation has to be same + if len(x) != len(y): + return False + if len(x) != len(y): + return False + for stepx, stepy in zip(x, y): + if len(stepx) != len(stepy): + # something off with a conversation + return False + for stepxx, stepyy in zip(stepx, stepy): + if len(stepxx) != len(stepyy): + # something off with a conversation + return False + if len(stepxx) != 2: + # something off + return False + if len(stepyy) != 2: + # something off + return False + questionx = stepxx[0].replace('
', '').replace('
', '') if stepxx[0] is not None else None + answerx = stepxx[1].replace('', '').replace('
', '') if stepxx[1] is not None else None + + questiony = stepyy[0].replace('', '').replace('
', '') if stepyy[0] is not None else None + answery = stepyy[1].replace('', '').replace('
', '') if stepyy[1] is not None else None + + if questionx != questiony or answerx != answery: + return False + return is_same + + def save_chat(*args, chat_is_list=False, auth_filename=None, auth_freeze=None): + args_list = list(args) + db1s = args_list[0] + requests_state1 = args_list[1] + args_list = args_list[2:] + if not chat_is_list: + # list of chatbot histories, + # can't pass in list with list of chatbot histories and state due to gradio limits + chat_list = args_list[:-1] + else: + assert len(args_list) == 2 + chat_list = args_list[0] + # if old chat file with single chatbot, get into shape + if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len( + chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str): + chat_list = [chat_list] + # remove None histories + chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] + chat_list_none = [x for x in chat_list if x not in chat_list_not_none] + if len(chat_list_none) > 0 and len(chat_list_not_none) == 0: + raise ValueError("Invalid chat file") + # dict with keys of short chat names, values of list of list of chatbot histories + chat_state1 = args_list[-1] + short_chats = list(chat_state1.keys()) + if len(chat_list_not_none) > 0: + # make short_chat key from only first history, based upon question that is same anyways + chat_first = chat_list_not_none[0] + short_chat = get_short_chat(chat_first, short_chats) + if short_chat: + old_chat_lists = list(chat_state1.values()) + already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) + if not already_exists: + chat_state1[short_chat] = chat_list.copy() + + # reverse so newest at top + choices = list(chat_state1.keys()).copy() + choices.reverse() + + # save saved chats and chatbots to auth file + text_output1 = chat_list[0] + text_output21 = chat_list[1] + text_outputs1 = chat_list[2:] + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1, + text_output1=text_output1, text_output21=text_output21, text_outputs1=text_outputs1) + + return chat_state1, gr.update(choices=choices, value=None) + + def switch_chat(chat_key, chat_state1, num_model_lock=0): + chosen_chat = chat_state1[chat_key] + # deal with possible different size of chat list vs. current list + ret_chat = [None] * (2 + num_model_lock) + for chati in range(0, 2 + num_model_lock): + ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] + return tuple(ret_chat) + + def clear_texts(*args): + return tuple([gr.Textbox.update(value='')] * len(args)) + + def clear_scores(): + return gr.Textbox.update(value=res_value), \ + gr.Textbox.update(value='Response Score: NA'), \ + gr.Textbox.update(value='Response Score: NA') + + switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) + radio_chats.input(switch_chat_fun, + inputs=[radio_chats, chat_state], + outputs=[text_output, text_output2] + text_outputs) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + def remove_chat(chat_key, chat_state1): + if isinstance(chat_key, str): + chat_state1.pop(chat_key, None) + return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1 + + remove_chat_event = remove_chat_btn.click(remove_chat, + inputs=[radio_chats, chat_state], + outputs=[radio_chats, chat_state], + queue=False, api_name='remove_chat') + + def get_chats1(chat_state1): + base = 'chats' + base = makedirs(base, exist_ok=True, tmp_ok=True, use_base=True) + filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4())) + with open(filename, "wt") as f: + f.write(json.dumps(chat_state1, indent=2)) + return filename + + export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, + api_name='export_chats' if allow_api else None) + + def add_chats_from_file(db1s, requests_state1, file, chat_state1, radio_chats1, chat_exception_text1, + auth_filename=None, auth_freeze=None): + if not file: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + if isinstance(file, str): + files = [file] + else: + files = file + if not files: + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + chat_exception_list = [] + for file1 in files: + try: + if hasattr(file1, 'name'): + file1 = file1.name + with open(file1, "rt") as f: + new_chats = json.loads(f.read()) + for chat1_k, chat1_v in new_chats.items(): + # ignore chat1_k, regenerate and de-dup to avoid loss + chat_state1, _ = save_chat(db1s, requests_state1, chat1_v, chat_state1, chat_is_list=True) + except BaseException as e: + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + ex_str = "File %s exception: %s" % (file1, str(e)) + print(ex_str, flush=True) + chat_exception_list.append(ex_str) + chat_exception_text1 = '\n'.join(chat_exception_list) + # save chat to auth file + save_auth(requests_state1, auth_filename, auth_freeze, chat_state1=chat_state1) + return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 + + # note for update_user_db_func output is ignored for db + chatup_change_eventa = chatsup_output.change(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + add_chats_from_file_func = functools.partial(add_chats_from_file, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + chatup_change_event = chatup_change_eventa.then(add_chats_from_file_func, + inputs=[my_db_state, requests_state] + + [chatsup_output, chat_state, radio_chats, + chat_exception_text], + outputs=[chatsup_output, chat_state, radio_chats, + chat_exception_text], + queue=False, + api_name='add_to_chats' if allow_api else None) + + clear_chat_event = clear_chat_btn.click(fn=clear_texts, + inputs=[text_output, text_output2] + text_outputs, + outputs=[text_output, text_output2] + text_outputs, + queue=False, api_name='clear' if allow_api else None) \ + .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + clear_eventa = save_chat_btn.click(user_state_setup, + inputs=[my_db_state, requests_state, langchain_mode], + outputs=[my_db_state, requests_state, langchain_mode], + show_progress='minimal') + save_chat_func = functools.partial(save_chat, + auth_filename=kwargs['auth_filename'], + auth_freeze=kwargs['auth_freeze'], + ) + clear_event = clear_eventa.then(save_chat_func, + inputs=[my_db_state, requests_state] + + [text_output, text_output2] + text_outputs + + [chat_state], + outputs=[chat_state, radio_chats], + api_name='save_chat' if allow_api else None) + if kwargs['score_model']: + clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + + # NOTE: clear of instruction/iinput for nochat has to come after score, + # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() + no_chat_args = dict(fn=fun, + inputs=[model_state, my_db_state, selection_docs_state, requests_state] + inputs_list, + outputs=text_output_nochat, + queue=queue, + ) + submit_event_nochat = submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + # copy of above with text box submission + submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) + + submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, + inputs=[model_state, my_db_state, selection_docs_state, + requests_state, + inputs_dict_str], + outputs=text_output_nochat_api, + queue=True, # required for generator + api_name='submit_nochat_api' if allow_api else None) + + submit_event_nochat_api_plain = submit_nochat_api_plain.click(fun_with_dict_str_plain, + inputs=inputs_dict_str, + outputs=text_output_nochat_api, + queue=False, + api_name='submit_nochat_plain_api' if allow_api else None) + + def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, + load_8bit, load_4bit, low_bit_mode, + load_gptq, load_exllama, use_safetensors, revision, + use_gpu_id, gpu_id, max_seq_len1, rope_scaling1, + model_path_llama1, model_name_gptj1, model_name_gpt4all_llama1, + n_gpu_layers1, n_batch1, n_gqa1, llamacpp_dict_more1, + system_prompt1): + try: + llamacpp_dict = ast.literal_eval(llamacpp_dict_more1) + except: + print("Failed to use user input for llamacpp_dict_more1 dict", flush=True) + llamacpp_dict = {} + llamacpp_dict.update(dict(model_path_llama=model_path_llama1, + model_name_gptj=model_name_gptj1, + model_name_gpt4all_llama=model_name_gpt4all_llama1, + n_gpu_layers=n_gpu_layers1, + n_batch=n_batch1, + n_gqa=n_gqa1, + )) + + # ensure no API calls reach here + if is_public: + raise RuntimeError("Illegal access for %s" % model_name) + # ensure old model removed from GPU memory + if kwargs['debug']: + print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) + + model0 = model_state0['model'] + if isinstance(model_state_old['model'], str) and \ + model0 is not None and \ + hasattr(model0, 'cpu'): + # best can do, move model loaded at first to CPU + model0.cpu() + + if model_state_old['model'] is not None and \ + not isinstance(model_state_old['model'], str): + if hasattr(model_state_old['model'], 'cpu'): + try: + model_state_old['model'].cpu() + except Exception as e: + # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! + print("Unable to put model on CPU: %s" % str(e), flush=True) + del model_state_old['model'] + model_state_old['model'] = None + + if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): + del model_state_old['tokenizer'] + model_state_old['tokenizer'] = None + + clear_torch_cache() + if kwargs['debug']: + print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True) + if not model_name: + model_name = no_model_str + if model_name == no_model_str: + # no-op if no model, just free memory + # no detranscribe needed for model, never go into evaluate + lora_weights = no_lora_str + server_name = no_server_str + return kwargs['model_state_none'].copy(), \ + model_name, lora_weights, server_name, prompt_type_old, \ + gr.Slider.update(maximum=256), \ + gr.Slider.update(maximum=256) + + # don't deepcopy, can contain model itself + all_kwargs1 = all_kwargs.copy() + all_kwargs1['base_model'] = model_name.strip() + all_kwargs1['load_8bit'] = load_8bit + all_kwargs1['load_4bit'] = load_4bit + all_kwargs1['low_bit_mode'] = low_bit_mode + all_kwargs1['load_gptq'] = load_gptq + all_kwargs1['load_exllama'] = load_exllama + all_kwargs1['use_safetensors'] = use_safetensors + all_kwargs1['revision'] = None if not revision else revision # transcribe, don't pass '' + all_kwargs1['use_gpu_id'] = use_gpu_id + all_kwargs1['gpu_id'] = int(gpu_id) if gpu_id not in [None, 'None'] else None # detranscribe + all_kwargs1['llamacpp_dict'] = llamacpp_dict + all_kwargs1['max_seq_len'] = max_seq_len1 + try: + all_kwargs1['rope_scaling'] = str_to_dict(rope_scaling1) # transcribe + except: + print("Failed to use user input for rope_scaling dict", flush=True) + all_kwargs1['rope_scaling'] = {} + model_lower = model_name.strip().lower() + if model_lower in inv_prompt_type_to_model_lower: + prompt_type1 = inv_prompt_type_to_model_lower[model_lower] + else: + prompt_type1 = prompt_type_old + + # detranscribe + if lora_weights == no_lora_str: + lora_weights = '' + all_kwargs1['lora_weights'] = lora_weights.strip() + if server_name == no_server_str: + server_name = '' + all_kwargs1['inference_server'] = server_name.strip() + + model1, tokenizer1, device1 = get_model(reward_type=False, + **get_kwargs(get_model, exclude_names=['reward_type'], + **all_kwargs1)) + clear_torch_cache() + + tokenizer_base_model = model_name + prompt_dict1, error0 = get_prompt(prompt_type1, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True, system_prompt=system_prompt1) + model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, + base_model=model_name, tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, inference_server=server_name, + prompt_type=prompt_type1, prompt_dict=prompt_dict1, + # FIXME: not typically required, unless want to expose adding h2ogpt endpoint in UI + visible_models=None, h2ogpt_key=None, + ) + + max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) + + if kwargs['debug']: + print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) + return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ + gr.Slider.update(maximum=max_max_new_tokens1), \ + gr.Slider.update(maximum=max_max_new_tokens1) + + def get_prompt_str(prompt_type1, prompt_dict1, system_prompt1, which=0): + if prompt_type1 in ['', None]: + print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) + return str({}) + prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', + reduced=False, making_context=False, return_dict=True, + system_prompt=system_prompt1) + if prompt_dict_error: + return str(prompt_dict_error) + else: + # return so user can manipulate if want and use as custom + return str(prompt_dict1) + + get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) + get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) + prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict, system_prompt], + outputs=prompt_dict, queue=False) + prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2, system_prompt], + outputs=prompt_dict2, + queue=False) + + def dropdown_prompt_type_list(x): + return gr.Dropdown.update(value=x) + + def chatbot_list(x, model_used_in): + return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') + + load_model_args = dict(fn=load_model, + inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, + model_load8bit_checkbox, model_load4bit_checkbox, model_low_bit_mode, + model_load_gptq, model_load_exllama_checkbox, + model_safetensors_checkbox, model_revision, + model_use_gpu_id_checkbox, model_gpu, + max_seq_len, rope_scaling, + model_path_llama, model_name_gptj, model_name_gpt4all_llama, + n_gpu_layers, n_batch, n_gqa, llamacpp_dict_more, + system_prompt], + outputs=[model_state, model_used, lora_used, server_used, + # if prompt_type changes, prompt_dict will change via change rule + prompt_type, max_new_tokens, min_new_tokens, + ]) + prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) + chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) + nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) + load_model_event = load_model_button.click(**load_model_args, + api_name='load_model' if allow_api and not is_public else None) \ + .then(**prompt_update_args) \ + .then(**chatbot_update_args) \ + .then(**nochat_update_args) \ + .then(clear_torch_cache) + + load_model_args2 = dict(fn=load_model, + inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, + model_load8bit_checkbox2, model_load4bit_checkbox2, model_low_bit_mode2, + model_load_gptq2, model_load_exllama_checkbox2, + model_safetensors_checkbox2, model_revision2, + model_use_gpu_id_checkbox2, model_gpu2, + max_seq_len2, rope_scaling2, + model_path_llama2, model_name_gptj2, model_name_gpt4all_llama2, + n_gpu_layers2, n_batch2, n_gqa2, llamacpp_dict_more2, + system_prompt], + outputs=[model_state2, model_used2, lora_used2, server_used2, + # if prompt_type2 changes, prompt_dict2 will change via change rule + prompt_type2, max_new_tokens2, min_new_tokens2 + ]) + prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2) + chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) + load_model_event2 = load_model_button2.click(**load_model_args2, + api_name='load_model2' if allow_api and not is_public else None) \ + .then(**prompt_update_args2) \ + .then(**chatbot_update_args2) \ + .then(clear_torch_cache) + + def dropdown_model_lora_server_list(model_list0, model_x, + lora_list0, lora_x, + server_list0, server_x, + model_used1, lora_used1, server_used1, + model_used2, lora_used2, server_used2, + ): + model_new_state = [model_list0[0] + [model_x]] + model_new_options = [*model_new_state[0]] + if no_model_str in model_new_options: + model_new_options.remove(no_model_str) + model_new_options = [no_model_str] + sorted(model_new_options) + x1 = model_x if model_used1 == no_model_str else model_used1 + x2 = model_x if model_used2 == no_model_str else model_used2 + ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), + gr.Dropdown.update(value=x2, choices=model_new_options), + '', model_new_state] + + lora_new_state = [lora_list0[0] + [lora_x]] + lora_new_options = [*lora_new_state[0]] + if no_lora_str in lora_new_options: + lora_new_options.remove(no_lora_str) + lora_new_options = [no_lora_str] + sorted(lora_new_options) + # don't switch drop-down to added lora if already have model loaded + x1 = lora_x if model_used1 == no_model_str else lora_used1 + x2 = lora_x if model_used2 == no_model_str else lora_used2 + ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), + gr.Dropdown.update(value=x2, choices=lora_new_options), + '', lora_new_state] + + server_new_state = [server_list0[0] + [server_x]] + server_new_options = [*server_new_state[0]] + if no_server_str in server_new_options: + server_new_options.remove(no_server_str) + server_new_options = [no_server_str] + sorted(server_new_options) + # don't switch drop-down to added server if already have model loaded + x1 = server_x if model_used1 == no_model_str else server_used1 + x2 = server_x if model_used2 == no_model_str else server_used2 + ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), + gr.Dropdown.update(value=x2, choices=server_new_options), + '', server_new_state] + + return tuple(ret1 + ret2 + ret3) + + add_model_lora_server_event = \ + add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, + inputs=[model_options_state, new_model] + + [lora_options_state, new_lora] + + [server_options_state, new_server] + + [model_used, lora_used, server_used] + + [model_used2, lora_used2, server_used2], + outputs=[model_choice, model_choice2, new_model, model_options_state] + + [lora_choice, lora_choice2, new_lora, lora_options_state] + + [server_choice, server_choice2, new_server, + server_options_state], + queue=False) + + go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, + queue=False) \ + .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \ + .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False) + + def compare_textbox_fun(x): + return gr.Textbox.update(visible=x) + + def compare_column_fun(x): + return gr.Column.update(visible=x) + + def compare_prompt_fun(x): + return gr.Dropdown.update(visible=x) + + def slider_fun(x): + return gr.Slider.update(visible=x) + + compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, + api_name="compare_checkbox" if allow_api else None) \ + .then(compare_column_fun, compare_checkbox, col_model2) \ + .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ + .then(compare_textbox_fun, compare_checkbox, score_text2) \ + .then(slider_fun, compare_checkbox, max_new_tokens2) \ + .then(slider_fun, compare_checkbox, min_new_tokens2) + # FIXME: add score_res2 in condition, but do better + + # callback for logging flagged input/output + callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") + flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, + None, + preprocess=False, + api_name='flag' if allow_api else None, queue=False) + flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, + preprocess=False, + api_name='flag_nochat' if allow_api else None, queue=False) + + def get_system_info(): + if is_public: + time.sleep(10) # delay to avoid spam since queue=False + return gr.Textbox.update(value=system_info_print()) + + system_event = system_btn.click(get_system_info, outputs=system_text, + api_name='system_info' if allow_api else None, queue=False) + + def get_system_info_dict(system_input1, **kwargs1): + if system_input1 != os.getenv("ADMIN_PASS", ""): + return json.dumps({}) + exclude_list = ['admin_pass', 'examples'] + sys_dict = {k: v for k, v in kwargs1.items() if + isinstance(v, (str, int, bool, float)) and k not in exclude_list} + try: + sys_dict.update(system_info()) + except Exception as e: + # protection + print("Exception: %s" % str(e), flush=True) + return json.dumps(sys_dict) + + system_kwargs = all_kwargs.copy() + system_kwargs.update(dict(command=str(' '.join(sys.argv)))) + get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) + + system_dict_event = system_btn2.click(get_system_info_dict_func, + inputs=system_input, + outputs=system_text2, + api_name='system_info_dict' if allow_api else None, + queue=False, # queue to avoid spam + ) + + def get_hash(): + return kwargs['git_hash'] + + system_event = system_btn3.click(get_hash, + outputs=system_text3, + api_name='system_hash' if allow_api else None, + queue=False, + ) + + def get_model_names(): + key_list = ['base_model', 'prompt_type', 'prompt_dict'] + list(kwargs['other_model_state_defaults'].keys()) + # don't want to expose backend inference server IP etc. + # key_list += ['inference_server'] + return [{k: x[k] for k in key_list if k in x} for x in model_states] + + models_list_event = system_btn4.click(get_model_names, + outputs=system_text4, + api_name='model_names' if allow_api else None, + queue=False, + ) + + def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, + system_prompt1, chat_conversation1, + memory_restriction_level1=0, + keep_sources_in_context1=False, + ): + if model_state1 and not isinstance(model_state1['tokenizer'], str): + tokenizer = model_state1['tokenizer'] + elif model_state0 and not isinstance(model_state0['tokenizer'], str): + tokenizer = model_state0['tokenizer'] + else: + tokenizer = None + if tokenizer is not None: + langchain_mode1 = 'LLM' + add_chat_history_to_context1 = True + # fake user message to mimic bot() + chat1 = copy.deepcopy(chat1) + chat1 = chat1 + [['user_message1', None]] + model_max_length1 = tokenizer.model_max_length + context1 = history_to_context(chat1, + langchain_mode=langchain_mode1, + add_chat_history_to_context=add_chat_history_to_context1, + prompt_type=prompt_type1, + prompt_dict=prompt_dict1, + chat=True, + model_max_length=model_max_length1, + memory_restriction_level=memory_restriction_level1, + keep_sources_in_context=keep_sources_in_context1, + system_prompt=system_prompt1, + chat_conversation=chat_conversation1) + tokens = tokenizer(context1, return_tensors="pt")['input_ids'] + if len(tokens.shape) == 1: + return str(tokens.shape[0]) + elif len(tokens.shape) == 2: + return str(tokens.shape[1]) + else: + return "N/A" + else: + return "N/A" + + count_chat_tokens_func = functools.partial(count_chat_tokens, + memory_restriction_level1=memory_restriction_level, + keep_sources_in_context1=kwargs['keep_sources_in_context']) + count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens_func, + inputs=[model_state, text_output, prompt_type, prompt_dict, + system_prompt, chat_conversation], + outputs=chat_token_count, + api_name='count_tokens' if allow_api else None) + + # don't pass text_output, don't want to clear output, just stop it + # cancel only stops outer generation, not inner generation or non-generation + stop_btn.click(lambda: None, None, None, + cancels=submits1 + submits2 + submits3 + submits4 + + [submit_event_nochat, submit_event_nochat2] + + [eventdb1, eventdb2, eventdb3] + + [eventdb7a, eventdb7, eventdb8a, eventdb8, eventdb9a, eventdb9, eventdb12a, eventdb12] + + db_events + + [eventdbloadla, eventdbloadlb] + + [clear_event] + + [submit_event_nochat_api, submit_event_nochat] + + [load_model_event, load_model_event2] + + [count_tokens_event] + , + queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) + + if kwargs['auth'] is not None: + auth = authf + load_func = user_state_setup + load_inputs = [my_db_state, requests_state, login_btn, login_btn] + load_outputs = [my_db_state, requests_state, login_btn] + else: + auth = None + load_func, load_inputs, load_outputs = None, None, None + + app_js = wrap_js_to_lambda( + len(load_inputs) if load_inputs else 0, + get_dark_js() if kwargs['dark'] else None, + get_heap_js(heap_app_id) if is_heap_analytics_enabled else None) + + load_event = demo.load(fn=load_func, inputs=load_inputs, outputs=load_outputs, _js=app_js) + + if load_func: + load_event2 = load_event.then(load_login_func, + inputs=login_inputs, + outputs=login_outputs) + if not kwargs['large_file_count_mode']: + load_event3 = load_event2.then(**get_sources_kwargs) + load_event4 = load_event3.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) + load_event5 = load_event4.then(**show_sources_kwargs) + load_event6 = load_event5.then(**get_viewable_sources_args) + load_event7 = load_event6.then(**viewable_kwargs) + + demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) + favicon_file = "h2o-logo.svg" + favicon_path = favicon_file + if not os.path.isfile(favicon_file): + print("favicon_path1=%s not found" % favicon_file, flush=True) + alt_path = os.path.dirname(os.path.abspath(__file__)) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path2: %s not found in %s" % (favicon_file, alt_path), flush=True) + alt_path = os.path.dirname(alt_path) + favicon_path = os.path.join(alt_path, favicon_file) + if not os.path.isfile(favicon_path): + print("favicon_path3: %s not found in %s" % (favicon_file, alt_path), flush=True) + favicon_path = None + + if kwargs['prepare_offline_level'] > 0: + from src.prepare_offline import go_prepare_offline + go_prepare_offline(**locals()) + return + + scheduler = BackgroundScheduler() + scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) + if is_public and \ + kwargs['base_model'] not in non_hf_types: + # FIXME: disable for gptj, langchain or gpt4all modify print itself + # FIXME: and any multi-threaded/async print will enter model output! + scheduler.add_job(func=ping, trigger="interval", seconds=60) + if is_public or os.getenv('PING_GPU'): + scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) + scheduler.start() + + # import control + if kwargs['langchain_mode'] == 'Disabled' and \ + os.environ.get("TEST_LANGCHAIN_IMPORT") and \ + kwargs['base_model'] not in non_hf_types: + assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + + # set port in case GRADIO_SERVER_PORT was already set in prior main() call, + # gradio does not listen if change after import + # Keep None if not set so can find an open port above used ports + server_port = os.getenv('GRADIO_SERVER_PORT') + if server_port is not None: + server_port = int(server_port) + + demo.launch(share=kwargs['share'], + server_name=kwargs['server_name'], + show_error=True, + server_port=server_port, + favicon_path=favicon_path, + prevent_thread_lock=True, + auth=auth, + auth_message=auth_message, + root_path=kwargs['root_path']) + if kwargs['verbose'] or not (kwargs['base_model'] in ['gptj', 'gpt4all_llama']): + print("Started Gradio Server and/or GUI: server_name: %s port: %s" % (kwargs['server_name'], server_port), + flush=True) + if kwargs['block_gradio_exit']: + demo.block_thread() + + +def show_doc(db1s, selection_docs_state1, requests_state1, + langchain_mode1, + single_document_choice1, + view_raw_text_checkbox1, + text_context_list1, + dbs1=None, + load_db_if_exists1=None, + db_type1=None, + use_openai_embedding1=None, + hf_embedding_model1=None, + migrate_embedding_model_or_db1=None, + auto_migrate_db1=None, + verbose1=False, + get_userid_auth1=None, + max_raw_chunks=1000000, + api=False, + n_jobs=-1): + file = single_document_choice1 + document_choice1 = [single_document_choice1] + content = None + db_documents = [] + db_metadatas = [] + if db_type1 in ['chroma', 'chroma_old']: + assert langchain_mode1 is not None + langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] + langchain_mode_types = selection_docs_state1['langchain_mode_types'] + from src.gpt_langchain import set_userid, get_any_db, get_docs_and_meta + set_userid(db1s, requests_state1, get_userid_auth1) + top_k_docs = -1 + db = get_any_db(db1s, langchain_mode1, langchain_mode_paths, langchain_mode_types, + dbs=dbs1, + load_db_if_exists=load_db_if_exists1, + db_type=db_type1, + use_openai_embedding=use_openai_embedding1, + hf_embedding_model=hf_embedding_model1, + migrate_embedding_model=migrate_embedding_model_or_db1, + auto_migrate_db=auto_migrate_db1, + for_sources_list=True, + verbose=verbose1, + n_jobs=n_jobs, + ) + query_action = False # long chunks like would be used for summarize + # the below is as or filter, so will show doc or by chunk, unrestricted + from langchain.vectorstores import Chroma + if isinstance(db, Chroma): + # chroma >= 0.4 + if view_raw_text_checkbox1: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$gte": -1}} + for x in document_choice1][0] + else: + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + else: + # migration for chroma < 0.4 + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice1][0] + if view_raw_text_checkbox1: + # like or, full raw all chunk types + filter_kwargs = dict(filter=one_filter) + else: + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, + text_context_list=text_context_list1) + # order documents + from langchain.docstore.document import Document + docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) + for result in zip(db_documents, db_metadatas)] + doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas] + doc_page_ids = [x.get('page', 0) for x in db_metadatas] + doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] + docs_with_score = [x for hx, px, cx, x in + sorted(zip(doc_hashes, doc_page_ids, doc_chunk_ids, docs_with_score), + key=lambda x: (x[0], x[1], x[2])) + # if cx == -1 + ] + db_metadatas = [x[0].metadata for x in docs_with_score][:max_raw_chunks] + db_documents = [x[0].page_content for x in docs_with_score][:max_raw_chunks] + # done reordering + if view_raw_text_checkbox1: + content = [dict_to_html(x) + '\n' + text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + else: + content = [text_to_html(y) for x, y in zip(db_metadatas, db_documents)] + content = '\n'.join(content) + content = f""" + + +