Verias commited on
Commit
fd63d2f
·
verified ·
1 Parent(s): 9f507f1

Upload 94 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. modules/AutoGPTQ_loader.py +74 -0
  2. modules/GPTQ_loader.py +171 -0
  3. modules/LoRA.py +153 -0
  4. modules/RoPE.py +18 -0
  5. modules/__pycache__/LoRA.cpython-311.pyc +0 -0
  6. modules/__pycache__/RoPE.cpython-311.pyc +0 -0
  7. modules/__pycache__/block_requests.cpython-311.pyc +0 -0
  8. modules/__pycache__/callbacks.cpython-311.pyc +0 -0
  9. modules/__pycache__/chat.cpython-311.pyc +0 -0
  10. modules/__pycache__/evaluate.cpython-311.pyc +0 -0
  11. modules/__pycache__/extensions.cpython-311.pyc +0 -0
  12. modules/__pycache__/github.cpython-311.pyc +0 -0
  13. modules/__pycache__/html_generator.cpython-311.pyc +0 -0
  14. modules/__pycache__/llamacpp_hf.cpython-311.pyc +0 -0
  15. modules/__pycache__/llamacpp_model.cpython-311.pyc +0 -0
  16. modules/__pycache__/loaders.cpython-311.pyc +0 -0
  17. modules/__pycache__/logging_colors.cpython-311.pyc +0 -0
  18. modules/__pycache__/logits.cpython-311.pyc +0 -0
  19. modules/__pycache__/metadata_gguf.cpython-311.pyc +0 -0
  20. modules/__pycache__/models.cpython-311.pyc +0 -0
  21. modules/__pycache__/models_settings.cpython-311.pyc +0 -0
  22. modules/__pycache__/one_click_installer_check.cpython-311.pyc +0 -0
  23. modules/__pycache__/presets.cpython-311.pyc +0 -0
  24. modules/__pycache__/prompts.cpython-311.pyc +0 -0
  25. modules/__pycache__/relative_imports.cpython-311.pyc +0 -0
  26. modules/__pycache__/sampler_hijack.cpython-311.pyc +0 -0
  27. modules/__pycache__/shared.cpython-311.pyc +0 -0
  28. modules/__pycache__/text_generation.cpython-311.pyc +0 -0
  29. modules/__pycache__/training.cpython-311.pyc +0 -0
  30. modules/__pycache__/ui.cpython-311.pyc +0 -0
  31. modules/__pycache__/ui_chat.cpython-311.pyc +0 -0
  32. modules/__pycache__/ui_default.cpython-311.pyc +0 -0
  33. modules/__pycache__/ui_file_saving.cpython-311.pyc +0 -0
  34. modules/__pycache__/ui_model_menu.cpython-311.pyc +0 -0
  35. modules/__pycache__/ui_notebook.cpython-311.pyc +0 -0
  36. modules/__pycache__/ui_parameters.cpython-311.pyc +0 -0
  37. modules/__pycache__/ui_session.cpython-311.pyc +0 -0
  38. modules/__pycache__/utils.cpython-311.pyc +0 -0
  39. modules/block_requests.py +47 -0
  40. modules/callbacks.py +103 -0
  41. modules/chat.py +927 -0
  42. modules/ctransformers_model.py +79 -0
  43. modules/deepspeed_parameters.py +74 -0
  44. modules/evaluate.py +153 -0
  45. modules/exllamav2.py +149 -0
  46. modules/exllamav2_hf.py +170 -0
  47. modules/extensions.py +232 -0
  48. modules/github.py +38 -0
  49. modules/grammar/__pycache__/grammar_utils.cpython-311.pyc +0 -0
  50. modules/grammar/__pycache__/logits_process.cpython-311.pyc +0 -0
modules/AutoGPTQ_loader.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from accelerate.utils import is_xpu_available
4
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
5
+
6
+ import modules.shared as shared
7
+ from modules.logging_colors import logger
8
+ from modules.models import get_max_memory_dict
9
+
10
+
11
+ def load_quantized(model_name):
12
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
13
+ pt_path = None
14
+
15
+ # Find the model checkpoint
16
+ if shared.args.checkpoint:
17
+ pt_path = Path(shared.args.checkpoint)
18
+ else:
19
+ for ext in ['.safetensors', '.pt', '.bin']:
20
+ found = list(path_to_model.glob(f"*{ext}"))
21
+ if len(found) > 0:
22
+ if len(found) > 1:
23
+ logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
24
+
25
+ pt_path = found[-1]
26
+ break
27
+
28
+ if pt_path is None:
29
+ logger.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
30
+ return
31
+
32
+ use_safetensors = pt_path.suffix == '.safetensors'
33
+ if not (path_to_model / "quantize_config.json").exists():
34
+ quantize_config = BaseQuantizeConfig(
35
+ bits=bits if (bits := shared.args.wbits) > 0 else 4,
36
+ group_size=gs if (gs := shared.args.groupsize) > 0 else -1,
37
+ desc_act=shared.args.desc_act
38
+ )
39
+ else:
40
+ quantize_config = None
41
+
42
+ # Define the params for AutoGPTQForCausalLM.from_quantized
43
+ params = {
44
+ 'model_basename': pt_path.stem,
45
+ 'device': "xpu:0" if is_xpu_available() else "cuda:0" if not shared.args.cpu else "cpu",
46
+ 'use_triton': shared.args.triton,
47
+ 'inject_fused_attention': not shared.args.no_inject_fused_attention,
48
+ 'inject_fused_mlp': not shared.args.no_inject_fused_mlp,
49
+ 'use_safetensors': use_safetensors,
50
+ 'trust_remote_code': shared.args.trust_remote_code,
51
+ 'max_memory': get_max_memory_dict(),
52
+ 'quantize_config': quantize_config,
53
+ 'use_cuda_fp16': not shared.args.no_use_cuda_fp16,
54
+ 'disable_exllama': shared.args.disable_exllama,
55
+ 'disable_exllamav2': shared.args.disable_exllamav2,
56
+ }
57
+
58
+ logger.info(f"The AutoGPTQ params are: {params}")
59
+ model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
60
+
61
+ # These lines fix the multimodal extension when used with AutoGPTQ
62
+ if hasattr(model, 'model'):
63
+ if not hasattr(model, 'dtype'):
64
+ if hasattr(model.model, 'dtype'):
65
+ model.dtype = model.model.dtype
66
+
67
+ if hasattr(model.model, 'model') and hasattr(model.model.model, 'embed_tokens'):
68
+ if not hasattr(model, 'embed_tokens'):
69
+ model.embed_tokens = model.model.model.embed_tokens
70
+
71
+ if not hasattr(model.model, 'embed_tokens'):
72
+ model.model.embed_tokens = model.model.model.embed_tokens
73
+
74
+ return model
modules/GPTQ_loader.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import re
3
+ from pathlib import Path
4
+
5
+ import accelerate
6
+ import torch
7
+ import transformers
8
+ from accelerate.utils import is_xpu_available
9
+ from gptq_for_llama import llama_inference_offload
10
+ from gptq_for_llama.modelutils import find_layers
11
+ from gptq_for_llama.quant import make_quant
12
+ from transformers import AutoConfig, AutoModelForCausalLM
13
+
14
+ import modules.shared as shared
15
+ from modules.logging_colors import logger
16
+
17
+
18
+ # This function is a replacement for the load_quant function in the
19
+ # GPTQ-for_LLaMa repository. It supports more models and branches.
20
+ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=None, kernel_switch_threshold=128, eval=True):
21
+ exclude_layers = exclude_layers or ['lm_head']
22
+
23
+ def noop(*args, **kwargs):
24
+ pass
25
+
26
+ config = AutoConfig.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code)
27
+ torch.nn.init.kaiming_uniform_ = noop
28
+ torch.nn.init.uniform_ = noop
29
+ torch.nn.init.normal_ = noop
30
+
31
+ torch.set_default_dtype(torch.half)
32
+ transformers.modeling_utils._init_weights = False
33
+ torch.set_default_dtype(torch.half)
34
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=shared.args.trust_remote_code)
35
+ torch.set_default_dtype(torch.float)
36
+ if eval:
37
+ model = model.eval()
38
+
39
+ layers = find_layers(model)
40
+ for name in exclude_layers:
41
+ if name in layers:
42
+ del layers[name]
43
+
44
+ gptq_args = inspect.getfullargspec(make_quant).args
45
+
46
+ make_quant_kwargs = {
47
+ 'module': model,
48
+ 'names': layers,
49
+ 'bits': wbits,
50
+ }
51
+ if 'groupsize' in gptq_args:
52
+ make_quant_kwargs['groupsize'] = groupsize
53
+ if 'faster' in gptq_args:
54
+ make_quant_kwargs['faster'] = faster_kernel
55
+ if 'kernel_switch_threshold' in gptq_args:
56
+ make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
57
+
58
+ make_quant(**make_quant_kwargs)
59
+
60
+ del layers
61
+ if checkpoint.endswith('.safetensors'):
62
+ from safetensors.torch import load_file as safe_load
63
+ model.load_state_dict(safe_load(checkpoint), strict=False)
64
+ else:
65
+ model.load_state_dict(torch.load(checkpoint, weights_only=True), strict=False)
66
+
67
+ model.seqlen = 2048
68
+ return model
69
+
70
+
71
+ # Used to locate the .pt/.safetensors quantized file
72
+ def find_quantized_model_file(model_name):
73
+ if shared.args.checkpoint:
74
+ return Path(shared.args.checkpoint)
75
+
76
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
77
+ pt_path = None
78
+ priority_name_list = [
79
+ Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
80
+ for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
81
+ for ext in ['.safetensors', '.pt']
82
+ for hyphen in ['-', f'/{model_name}-', '/']
83
+ ]
84
+
85
+ for path in priority_name_list:
86
+ if path.exists():
87
+ pt_path = path
88
+ break
89
+
90
+ # If the model hasn't been found with a well-behaved name, pick the last .pt
91
+ # or the last .safetensors found in its folder as a last resort
92
+ if not pt_path:
93
+ for ext in ['.pt', '.safetensors']:
94
+ found = list(path_to_model.glob(f"*{ext}"))
95
+ if len(found) > 0:
96
+ if len(found) > 1:
97
+ logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
98
+
99
+ pt_path = found[-1]
100
+ break
101
+
102
+ return pt_path
103
+
104
+
105
+ # The function that loads the model in modules/models.py
106
+ def load_quantized(model_name):
107
+ if shared.args.model_type is None:
108
+ logger.error("The model could not be loaded because its type could not be inferred from its name.")
109
+ logger.error("Please specify the type manually using the --model_type argument.")
110
+ return None
111
+
112
+ # Select the appropriate load_quant function
113
+ model_type = shared.args.model_type.lower()
114
+ if shared.args.pre_layer and model_type == 'llama':
115
+ load_quant = llama_inference_offload.load_quant
116
+ elif model_type in ('llama', 'opt', 'gptj'):
117
+ if shared.args.pre_layer:
118
+ logger.warning("Ignoring --pre_layer because it only works for llama model type.")
119
+
120
+ load_quant = _load_quant
121
+ else:
122
+ logger.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
123
+ exit()
124
+
125
+ # Find the quantized model weights file (.pt/.safetensors)
126
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
127
+ pt_path = find_quantized_model_file(model_name)
128
+ if not pt_path:
129
+ logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
130
+ exit()
131
+ else:
132
+ logger.info(f"Found the following quantized model: {pt_path}")
133
+
134
+ # qwopqwop200's offload
135
+ if model_type == 'llama' and shared.args.pre_layer:
136
+ if len(shared.args.pre_layer) == 1:
137
+ pre_layer = shared.args.pre_layer[0]
138
+ else:
139
+ pre_layer = shared.args.pre_layer
140
+
141
+ model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, pre_layer)
142
+ else:
143
+ threshold = False if model_type == 'gptj' else 128
144
+ model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
145
+
146
+ # accelerate offload (doesn't work properly)
147
+ if shared.args.gpu_memory or torch.cuda.device_count() > 1 or (is_xpu_available() and torch.xpu.device_count() > 1):
148
+ if shared.args.gpu_memory:
149
+ memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
150
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
151
+ max_memory = {}
152
+ for i in range(len(memory_map)):
153
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
154
+
155
+ max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
156
+ else:
157
+ max_memory = accelerate.utils.get_balanced_memory(model)
158
+
159
+ device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
160
+ logger.info("Using the following device map for the quantized model:", device_map)
161
+ # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
162
+ model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
163
+
164
+ # No offload
165
+ elif not shared.args.cpu:
166
+ if is_xpu_available():
167
+ model = model.to(torch.device("xpu:0"))
168
+ else:
169
+ model = model.to(torch.device('cuda:0'))
170
+
171
+ return model
modules/LoRA.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ from peft import PeftModel
5
+ from transformers import is_torch_xpu_available
6
+
7
+ import modules.shared as shared
8
+ from modules.logging_colors import logger
9
+ from modules.models import reload_model
10
+
11
+
12
+ def add_lora_to_model(lora_names):
13
+ if 'GPTQForCausalLM' in shared.model.__class__.__name__ or shared.args.loader == 'AutoGPTQ':
14
+ add_lora_autogptq(lora_names)
15
+ elif shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
16
+ add_lora_exllamav2(lora_names)
17
+ else:
18
+ add_lora_transformers(lora_names)
19
+
20
+
21
+ def get_lora_path(lora_name):
22
+ p = Path(lora_name)
23
+ if p.exists():
24
+ lora_name = p.parts[-1]
25
+
26
+ return Path(f"{shared.args.lora_dir}/{lora_name}")
27
+
28
+
29
+ def add_lora_exllamav2(lora_names):
30
+
31
+ from exllamav2 import ExLlamaV2Lora
32
+
33
+ if isinstance(shared.model.loras, list):
34
+ for lora in shared.model.loras:
35
+ lora.unload()
36
+
37
+ if len(lora_names) > 0:
38
+ logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
39
+ shared.model.loras = []
40
+ for lora_name in lora_names:
41
+ lora_path = get_lora_path(lora_name)
42
+ if shared.model.__class__.__name__ == 'Exllamav2Model':
43
+ lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
44
+ else:
45
+ lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
46
+
47
+ shared.model.loras.append(lora)
48
+
49
+ shared.lora_names = lora_names
50
+ else:
51
+ shared.lora_names = []
52
+ shared.model.loras = None
53
+
54
+
55
+ def add_lora_autogptq(lora_names):
56
+ '''
57
+ Adapted from https://github.com/Ph0rk0z/text-generation-webui-testing
58
+ '''
59
+
60
+ try:
61
+ from auto_gptq import get_gptq_peft_model
62
+ from auto_gptq.utils.peft_utils import GPTQLoraConfig
63
+ except:
64
+ logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
65
+ return
66
+
67
+ if len(lora_names) == 0:
68
+ reload_model()
69
+
70
+ shared.lora_names = []
71
+ return
72
+ else:
73
+ if len(lora_names) > 1:
74
+ logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded.')
75
+ if not shared.args.no_inject_fused_attention:
76
+ logger.warning('Fused Atttention + AutoGPTQ may break Lora loading. Disable it.')
77
+
78
+ peft_config = GPTQLoraConfig(
79
+ inference_mode=True,
80
+ )
81
+
82
+ lora_path = get_lora_path(lora_names[0])
83
+ logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
84
+ shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path)
85
+ shared.lora_names = [lora_names[0]]
86
+ return
87
+
88
+
89
+ def add_lora_transformers(lora_names):
90
+ prior_set = set(shared.lora_names)
91
+ added_set = set(lora_names) - prior_set
92
+ removed_set = prior_set - set(lora_names)
93
+
94
+ # If no LoRA needs to be added or removed, exit
95
+ if len(added_set) == 0 and len(removed_set) == 0:
96
+ return
97
+
98
+ # Add a LoRA when another LoRA is already present
99
+ if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
100
+ logger.info(f"Adding the LoRA(s) named {added_set} to the model")
101
+ for lora in added_set:
102
+ shared.model.load_adapter(get_lora_path(lora), lora)
103
+
104
+ if len(lora_names) > 1:
105
+ merge_loras()
106
+
107
+ shared.lora_names = lora_names
108
+ return
109
+
110
+ # If any LoRA needs to be removed, start over
111
+ if len(removed_set) > 0:
112
+ shared.model = shared.model.unload()
113
+
114
+ if len(lora_names) > 0:
115
+ params = {}
116
+ if not shared.args.cpu:
117
+ if shared.args.load_in_4bit or shared.args.load_in_8bit:
118
+ params['peft_type'] = shared.model.dtype
119
+ else:
120
+ params['dtype'] = shared.model.dtype
121
+ if hasattr(shared.model, "hf_device_map"):
122
+ params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
123
+
124
+ logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
125
+ shared.model = PeftModel.from_pretrained(shared.model, get_lora_path(lora_names[0]), adapter_name=lora_names[0], **params)
126
+ for lora in lora_names[1:]:
127
+ shared.model.load_adapter(get_lora_path(lora), lora)
128
+
129
+ if len(lora_names) > 1:
130
+ merge_loras()
131
+
132
+ if not shared.args.load_in_8bit and not shared.args.cpu:
133
+ shared.model.half()
134
+ if not hasattr(shared.model, "hf_device_map"):
135
+ if torch.backends.mps.is_available():
136
+ device = torch.device('mps')
137
+ shared.model = shared.model.to(device)
138
+ elif is_torch_xpu_available():
139
+ device = torch.device("xpu:0")
140
+ shared.model = shared.model.to(device)
141
+ else:
142
+ shared.model = shared.model.cuda()
143
+
144
+ shared.lora_names = lora_names
145
+
146
+
147
+ def merge_loras():
148
+ if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1:
149
+ logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
150
+ return
151
+
152
+ shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged")
153
+ shared.model.set_adapter("__merged")
modules/RoPE.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_alpha_value(alpha, base):
2
+ '''
3
+ Gets alpha_value from alpha_value and rope_freq_base
4
+ '''
5
+ if base > 0:
6
+ return (base / 10000.) ** (63 / 64.)
7
+ else:
8
+ return alpha
9
+
10
+
11
+ def get_rope_freq_base(alpha, base):
12
+ '''
13
+ Gets rope_freq_base from alpha_value and rope_freq_base
14
+ '''
15
+ if base > 0:
16
+ return base
17
+ else:
18
+ return 10000 * alpha ** (64 / 63.)
modules/__pycache__/LoRA.cpython-311.pyc ADDED
Binary file (9.93 kB). View file
 
modules/__pycache__/RoPE.cpython-311.pyc ADDED
Binary file (733 Bytes). View file
 
modules/__pycache__/block_requests.cpython-311.pyc ADDED
Binary file (3.11 kB). View file
 
modules/__pycache__/callbacks.cpython-311.pyc ADDED
Binary file (5.88 kB). View file
 
modules/__pycache__/chat.cpython-311.pyc ADDED
Binary file (46.9 kB). View file
 
modules/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (8.69 kB). View file
 
modules/__pycache__/extensions.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
modules/__pycache__/github.cpython-311.pyc ADDED
Binary file (2.25 kB). View file
 
modules/__pycache__/html_generator.cpython-311.pyc ADDED
Binary file (15.8 kB). View file
 
modules/__pycache__/llamacpp_hf.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
modules/__pycache__/llamacpp_model.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
modules/__pycache__/loaders.cpython-311.pyc ADDED
Binary file (7.04 kB). View file
 
modules/__pycache__/logging_colors.cpython-311.pyc ADDED
Binary file (5.22 kB). View file
 
modules/__pycache__/logits.cpython-311.pyc ADDED
Binary file (5.23 kB). View file
 
modules/__pycache__/metadata_gguf.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
modules/__pycache__/models.cpython-311.pyc ADDED
Binary file (27 kB). View file
 
modules/__pycache__/models_settings.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
modules/__pycache__/one_click_installer_check.cpython-311.pyc ADDED
Binary file (781 Bytes). View file
 
modules/__pycache__/presets.cpython-311.pyc ADDED
Binary file (6.54 kB). View file
 
modules/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
modules/__pycache__/relative_imports.cpython-311.pyc ADDED
Binary file (1.37 kB). View file
 
modules/__pycache__/sampler_hijack.cpython-311.pyc ADDED
Binary file (24.7 kB). View file
 
modules/__pycache__/shared.cpython-311.pyc ADDED
Binary file (26.6 kB). View file
 
modules/__pycache__/text_generation.cpython-311.pyc ADDED
Binary file (25.2 kB). View file
 
modules/__pycache__/training.cpython-311.pyc ADDED
Binary file (67.5 kB). View file
 
modules/__pycache__/ui.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
modules/__pycache__/ui_chat.cpython-311.pyc ADDED
Binary file (59.1 kB). View file
 
modules/__pycache__/ui_default.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
modules/__pycache__/ui_file_saving.cpython-311.pyc ADDED
Binary file (16.4 kB). View file
 
modules/__pycache__/ui_model_menu.cpython-311.pyc ADDED
Binary file (38.4 kB). View file
 
modules/__pycache__/ui_notebook.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
modules/__pycache__/ui_parameters.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
modules/__pycache__/ui_session.cpython-311.pyc ADDED
Binary file (9.11 kB). View file
 
modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
modules/block_requests.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import builtins
2
+ import io
3
+
4
+ import requests
5
+
6
+ from modules.logging_colors import logger
7
+
8
+ original_open = open
9
+ original_get = requests.get
10
+
11
+
12
+ class RequestBlocker:
13
+
14
+ def __enter__(self):
15
+ requests.get = my_get
16
+
17
+ def __exit__(self, exc_type, exc_value, traceback):
18
+ requests.get = original_get
19
+
20
+
21
+ class OpenMonkeyPatch:
22
+
23
+ def __enter__(self):
24
+ builtins.open = my_open
25
+
26
+ def __exit__(self, exc_type, exc_value, traceback):
27
+ builtins.open = original_open
28
+
29
+
30
+ def my_get(url, **kwargs):
31
+ logger.info('Unwanted HTTP request redirected to localhost :)')
32
+ kwargs.setdefault('allow_redirects', True)
33
+ return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
34
+
35
+
36
+ # Kindly provided by our friend WizardLM-30B
37
+ def my_open(*args, **kwargs):
38
+ filename = str(args[0])
39
+ if filename.endswith('index.html'):
40
+ with original_open(*args, **kwargs) as f:
41
+ file_contents = f.read()
42
+
43
+ file_contents = file_contents.replace(b'\t\t<script\n\t\t\tsrc="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.7/iframeResizer.contentWindow.min.js"\n\t\t\tasync\n\t\t></script>', b'')
44
+ file_contents = file_contents.replace(b'cdnjs.cloudflare.com', b'127.0.0.1')
45
+ return io.BytesIO(file_contents)
46
+ else:
47
+ return original_open(*args, **kwargs)
modules/callbacks.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import traceback
3
+ from queue import Queue
4
+ from threading import Thread
5
+
6
+ import torch
7
+ import transformers
8
+ from transformers import is_torch_xpu_available
9
+
10
+ import modules.shared as shared
11
+
12
+
13
+ class StopNowException(Exception):
14
+ pass
15
+
16
+
17
+ class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
18
+ def __init__(self):
19
+ transformers.StoppingCriteria.__init__(self)
20
+
21
+ def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
22
+ return shared.stop_everything
23
+
24
+
25
+ class Stream(transformers.StoppingCriteria):
26
+ def __init__(self, callback_func=None):
27
+ self.callback_func = callback_func
28
+
29
+ def __call__(self, input_ids, scores) -> bool:
30
+ if self.callback_func is not None:
31
+ self.callback_func(input_ids[0])
32
+
33
+ return False
34
+
35
+
36
+ class Iteratorize:
37
+
38
+ """
39
+ Transforms a function that takes a callback
40
+ into a lazy iterator (generator).
41
+
42
+ Adapted from: https://stackoverflow.com/a/9969000
43
+ """
44
+
45
+ def __init__(self, func, args=None, kwargs=None, callback=None):
46
+ self.mfunc = func
47
+ self.c_callback = callback
48
+ self.q = Queue()
49
+ self.sentinel = object()
50
+ self.args = args or []
51
+ self.kwargs = kwargs or {}
52
+ self.stop_now = False
53
+
54
+ def _callback(val):
55
+ if self.stop_now or shared.stop_everything:
56
+ raise StopNowException
57
+ self.q.put(val)
58
+
59
+ def gentask():
60
+ try:
61
+ ret = self.mfunc(callback=_callback, *args, **self.kwargs)
62
+ except StopNowException:
63
+ pass
64
+ except:
65
+ traceback.print_exc()
66
+ pass
67
+
68
+ clear_torch_cache()
69
+ self.q.put(self.sentinel)
70
+ if self.c_callback:
71
+ self.c_callback(ret)
72
+
73
+ self.thread = Thread(target=gentask)
74
+ self.thread.start()
75
+
76
+ def __iter__(self):
77
+ return self
78
+
79
+ def __next__(self):
80
+ obj = self.q.get(True, None)
81
+ if obj is self.sentinel:
82
+ raise StopIteration
83
+ else:
84
+ return obj
85
+
86
+ def __del__(self):
87
+ clear_torch_cache()
88
+
89
+ def __enter__(self):
90
+ return self
91
+
92
+ def __exit__(self, exc_type, exc_val, exc_tb):
93
+ self.stop_now = True
94
+ clear_torch_cache()
95
+
96
+
97
+ def clear_torch_cache():
98
+ gc.collect()
99
+ if not shared.args.cpu:
100
+ if is_torch_xpu_available():
101
+ torch.xpu.empty_cache()
102
+ else:
103
+ torch.cuda.empty_cache()
modules/chat.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ import functools
4
+ import html
5
+ import json
6
+ import re
7
+ from datetime import datetime
8
+ from functools import partial
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import yaml
13
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
14
+ from PIL import Image
15
+
16
+ import modules.shared as shared
17
+ from modules import utils
18
+ from modules.extensions import apply_extensions
19
+ from modules.html_generator import chat_html_wrapper, make_thumbnail
20
+ from modules.logging_colors import logger
21
+ from modules.text_generation import (
22
+ generate_reply,
23
+ get_encoded_length,
24
+ get_max_prompt_length
25
+ )
26
+ from modules.utils import delete_file, get_available_characters, save_file
27
+
28
+ # Copied from the Transformers library
29
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
30
+
31
+
32
+ def str_presenter(dumper, data):
33
+ """
34
+ Copied from https://github.com/yaml/pyyaml/issues/240
35
+ Makes pyyaml output prettier multiline strings.
36
+ """
37
+
38
+ if data.count('\n') > 0:
39
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
40
+
41
+ return dumper.represent_scalar('tag:yaml.org,2002:str', data)
42
+
43
+
44
+ yaml.add_representer(str, str_presenter)
45
+ yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
46
+
47
+
48
+ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
49
+ '''
50
+ Given a Jinja template, reverse-engineers the prefix and the suffix for
51
+ an assistant message (if impersonate=False) or an user message
52
+ (if impersonate=True)
53
+ '''
54
+
55
+ if impersonate:
56
+ messages = [
57
+ {"role": "user", "content": "<<|user-message-1|>>"},
58
+ {"role": "user", "content": "<<|user-message-2|>>"},
59
+ ]
60
+ else:
61
+ messages = [
62
+ {"role": "assistant", "content": "<<|user-message-1|>>"},
63
+ {"role": "assistant", "content": "<<|user-message-2|>>"},
64
+ ]
65
+
66
+ prompt = renderer(messages=messages)
67
+
68
+ suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
69
+ suffix = prompt.split("<<|user-message-2|>>")[1]
70
+ prefix = suffix_plus_prefix[len(suffix):]
71
+
72
+ if strip_trailing_spaces:
73
+ prefix = prefix.rstrip(' ')
74
+
75
+ return prefix, suffix
76
+
77
+
78
+ def generate_chat_prompt(user_input, state, **kwargs):
79
+ impersonate = kwargs.get('impersonate', False)
80
+ _continue = kwargs.get('_continue', False)
81
+ also_return_rows = kwargs.get('also_return_rows', False)
82
+ history = kwargs.get('history', state['history'])['internal']
83
+
84
+ # Templates
85
+ chat_template = jinja_env.from_string(state['chat_template_str'])
86
+ instruction_template = jinja_env.from_string(state['instruction_template_str'])
87
+ chat_renderer = partial(chat_template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
88
+ instruct_renderer = partial(instruction_template.render, add_generation_prompt=False)
89
+
90
+ messages = []
91
+
92
+ if state['mode'] == 'instruct':
93
+ renderer = instruct_renderer
94
+ if state['custom_system_message'].strip() != '':
95
+ messages.append({"role": "system", "content": state['custom_system_message']})
96
+ else:
97
+ renderer = chat_renderer
98
+ if state['context'].strip() != '':
99
+ context = replace_character_names(state['context'], state['name1'], state['name2'])
100
+ messages.append({"role": "system", "content": context})
101
+
102
+ insert_pos = len(messages)
103
+ for user_msg, assistant_msg in reversed(history):
104
+ user_msg = user_msg.strip()
105
+ assistant_msg = assistant_msg.strip()
106
+
107
+ if assistant_msg:
108
+ messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
109
+
110
+ if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
111
+ messages.insert(insert_pos, {"role": "user", "content": user_msg})
112
+
113
+ user_input = user_input.strip()
114
+ if user_input and not impersonate and not _continue:
115
+ messages.append({"role": "user", "content": user_input})
116
+
117
+ def remove_extra_bos(prompt):
118
+ for bos_token in ['<s>', '<|startoftext|>']:
119
+ while prompt.startswith(bos_token):
120
+ prompt = prompt[len(bos_token):]
121
+
122
+ return prompt
123
+
124
+ def make_prompt(messages):
125
+ if state['mode'] == 'chat-instruct' and _continue:
126
+ prompt = renderer(messages=messages[:-1])
127
+ else:
128
+ prompt = renderer(messages=messages)
129
+
130
+ if state['mode'] == 'chat-instruct':
131
+ outer_messages = []
132
+ if state['custom_system_message'].strip() != '':
133
+ outer_messages.append({"role": "system", "content": state['custom_system_message']})
134
+
135
+ prompt = remove_extra_bos(prompt)
136
+ command = state['chat-instruct_command']
137
+ command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
138
+ command = command.replace('<|prompt|>', prompt)
139
+
140
+ if _continue:
141
+ prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
142
+ prefix += messages[-1]["content"]
143
+ else:
144
+ prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
145
+ if not impersonate:
146
+ prefix = apply_extensions('bot_prefix', prefix, state)
147
+
148
+ outer_messages.append({"role": "user", "content": command})
149
+ outer_messages.append({"role": "assistant", "content": prefix})
150
+
151
+ prompt = instruction_template.render(messages=outer_messages)
152
+ suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
153
+ prompt = prompt[:-len(suffix)]
154
+
155
+ else:
156
+ if _continue:
157
+ suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
158
+ prompt = prompt[:-len(suffix)]
159
+ else:
160
+ prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
161
+ if state['mode'] == 'chat' and not impersonate:
162
+ prefix = apply_extensions('bot_prefix', prefix, state)
163
+
164
+ prompt += prefix
165
+
166
+ prompt = remove_extra_bos(prompt)
167
+ return prompt
168
+
169
+ # Handle truncation
170
+ max_length = get_max_prompt_length(state)
171
+ prompt = make_prompt(messages)
172
+ encoded_length = get_encoded_length(prompt)
173
+
174
+ while len(messages) > 0 and encoded_length > max_length:
175
+
176
+ # Remove old message, save system message
177
+ if len(messages) > 2 and messages[0]['role'] == 'system':
178
+ messages.pop(1)
179
+
180
+ # Remove old message when no system message is present
181
+ elif len(messages) > 1 and messages[0]['role'] != 'system':
182
+ messages.pop(0)
183
+
184
+ # Resort to truncating the user input
185
+ else:
186
+
187
+ user_message = messages[-1]['content']
188
+
189
+ # Bisect the truncation point
190
+ left, right = 0, len(user_message) - 1
191
+
192
+ while right - left > 1:
193
+ mid = (left + right) // 2
194
+
195
+ messages[-1]['content'] = user_message[mid:]
196
+ prompt = make_prompt(messages)
197
+ encoded_length = get_encoded_length(prompt)
198
+
199
+ if encoded_length <= max_length:
200
+ right = mid
201
+ else:
202
+ left = mid
203
+
204
+ messages[-1]['content'] = user_message[right:]
205
+ prompt = make_prompt(messages)
206
+ encoded_length = get_encoded_length(prompt)
207
+ if encoded_length > max_length:
208
+ logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
209
+ raise ValueError
210
+ else:
211
+ logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
212
+ break
213
+
214
+ prompt = make_prompt(messages)
215
+ encoded_length = get_encoded_length(prompt)
216
+
217
+ if also_return_rows:
218
+ return prompt, [message['content'] for message in messages]
219
+ else:
220
+ return prompt
221
+
222
+
223
+ def get_stopping_strings(state):
224
+ stopping_strings = []
225
+ renderers = []
226
+
227
+ if state['mode'] in ['instruct', 'chat-instruct']:
228
+ template = jinja_env.from_string(state['instruction_template_str'])
229
+ renderer = partial(template.render, add_generation_prompt=False)
230
+ renderers.append(renderer)
231
+
232
+ if state['mode'] in ['chat', 'chat-instruct']:
233
+ template = jinja_env.from_string(state['chat_template_str'])
234
+ renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
235
+ renderers.append(renderer)
236
+
237
+ for renderer in renderers:
238
+ prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
239
+ prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
240
+
241
+ stopping_strings += [
242
+ suffix_user + prefix_bot,
243
+ suffix_user + prefix_user,
244
+ suffix_bot + prefix_bot,
245
+ suffix_bot + prefix_user,
246
+ ]
247
+
248
+ if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
249
+ stopping_strings += state.pop('stopping_strings')
250
+
251
+ return list(set(stopping_strings))
252
+
253
+
254
+ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
255
+ history = state['history']
256
+ output = copy.deepcopy(history)
257
+ output = apply_extensions('history', output)
258
+ state = apply_extensions('state', state)
259
+
260
+ visible_text = None
261
+ stopping_strings = get_stopping_strings(state)
262
+ is_stream = state['stream']
263
+
264
+ # Prepare the input
265
+ if not (regenerate or _continue):
266
+ visible_text = html.escape(text)
267
+
268
+ # Apply extensions
269
+ text, visible_text = apply_extensions('chat_input', text, visible_text, state)
270
+ text = apply_extensions('input', text, state, is_chat=True)
271
+
272
+ output['internal'].append([text, ''])
273
+ output['visible'].append([visible_text, ''])
274
+
275
+ # *Is typing...*
276
+ if loading_message:
277
+ yield {
278
+ 'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]],
279
+ 'internal': output['internal']
280
+ }
281
+ else:
282
+ text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
283
+ if regenerate:
284
+ if loading_message:
285
+ yield {
286
+ 'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
287
+ 'internal': output['internal'][:-1] + [[text, '']]
288
+ }
289
+ elif _continue:
290
+ last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
291
+ if loading_message:
292
+ yield {
293
+ 'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']],
294
+ 'internal': output['internal']
295
+ }
296
+
297
+ if shared.model_name == 'None' or shared.model is None:
298
+ raise ValueError("No model is loaded! Select one in the Model tab.")
299
+
300
+ # Generate the prompt
301
+ kwargs = {
302
+ '_continue': _continue,
303
+ 'history': output if _continue else {k: v[:-1] for k, v in output.items()}
304
+ }
305
+ prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
306
+ if prompt is None:
307
+ prompt = generate_chat_prompt(text, state, **kwargs)
308
+
309
+ # Generate
310
+ reply = None
311
+ for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):
312
+
313
+ # Extract the reply
314
+ visible_reply = reply
315
+ if state['mode'] in ['chat', 'chat-instruct']:
316
+ visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
317
+
318
+ visible_reply = html.escape(visible_reply)
319
+
320
+ if shared.stop_everything:
321
+ output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
322
+ yield output
323
+ return
324
+
325
+ if _continue:
326
+ output['internal'][-1] = [text, last_reply[0] + reply]
327
+ output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
328
+ if is_stream:
329
+ yield output
330
+ elif not (j == 0 and visible_reply.strip() == ''):
331
+ output['internal'][-1] = [text, reply.lstrip(' ')]
332
+ output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
333
+ if is_stream:
334
+ yield output
335
+
336
+ output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
337
+ yield output
338
+
339
+
340
+ def impersonate_wrapper(text, state):
341
+
342
+ static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
343
+
344
+ if shared.model_name == 'None' or shared.model is None:
345
+ logger.error("No model is loaded! Select one in the Model tab.")
346
+ yield '', static_output
347
+ return
348
+
349
+ prompt = generate_chat_prompt('', state, impersonate=True)
350
+ stopping_strings = get_stopping_strings(state)
351
+
352
+ yield text + '...', static_output
353
+ reply = None
354
+ for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True):
355
+ yield (text + reply).lstrip(' '), static_output
356
+ if shared.stop_everything:
357
+ return
358
+
359
+
360
+ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
361
+ history = state['history']
362
+ if regenerate or _continue:
363
+ text = ''
364
+ if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
365
+ yield history
366
+ return
367
+
368
+ for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
369
+ yield history
370
+
371
+
372
+ def character_is_loaded(state, raise_exception=False):
373
+ if state['mode'] in ['chat', 'chat-instruct'] and state['name2'] == '':
374
+ logger.error('It looks like no character is loaded. Please load one under Parameters > Character.')
375
+ if raise_exception:
376
+ raise ValueError
377
+
378
+ return False
379
+ else:
380
+ return True
381
+
382
+
383
+ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
384
+ '''
385
+ Same as above but returns HTML for the UI
386
+ '''
387
+
388
+ if not character_is_loaded(state):
389
+ return
390
+
391
+ if state['start_with'] != '' and not _continue:
392
+ if regenerate:
393
+ text, state['history'] = remove_last_message(state['history'])
394
+ regenerate = False
395
+
396
+ _continue = True
397
+ send_dummy_message(text, state)
398
+ send_dummy_reply(state['start_with'], state)
399
+
400
+ for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
401
+ yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
402
+
403
+
404
+ def remove_last_message(history):
405
+ if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
406
+ last = history['visible'].pop()
407
+ history['internal'].pop()
408
+ else:
409
+ last = ['', '']
410
+
411
+ return html.unescape(last[0]), history
412
+
413
+
414
+ def send_last_reply_to_input(history):
415
+ if len(history['visible']) > 0:
416
+ return html.unescape(history['visible'][-1][1])
417
+ else:
418
+ return ''
419
+
420
+
421
+ def replace_last_reply(text, state):
422
+ history = state['history']
423
+
424
+ if len(text.strip()) == 0:
425
+ return history
426
+ elif len(history['visible']) > 0:
427
+ history['visible'][-1][1] = html.escape(text)
428
+ history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)
429
+
430
+ return history
431
+
432
+
433
+ def send_dummy_message(text, state):
434
+ history = state['history']
435
+ history['visible'].append([html.escape(text), ''])
436
+ history['internal'].append([apply_extensions('input', text, state, is_chat=True), ''])
437
+ return history
438
+
439
+
440
+ def send_dummy_reply(text, state):
441
+ history = state['history']
442
+ if len(history['visible']) > 0 and not history['visible'][-1][1] == '':
443
+ history['visible'].append(['', ''])
444
+ history['internal'].append(['', ''])
445
+
446
+ history['visible'][-1][1] = html.escape(text)
447
+ history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)
448
+ return history
449
+
450
+
451
+ def redraw_html(history, name1, name2, mode, style, character, reset_cache=False):
452
+ return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache)
453
+
454
+
455
+ def start_new_chat(state):
456
+ mode = state['mode']
457
+ history = {'internal': [], 'visible': []}
458
+
459
+ if mode != 'instruct':
460
+ greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
461
+ if greeting != '':
462
+ history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
463
+ history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
464
+
465
+ unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
466
+ save_history(history, unique_id, state['character_menu'], state['mode'])
467
+
468
+ return history
469
+
470
+
471
+ def get_history_file_path(unique_id, character, mode):
472
+ if mode == 'instruct':
473
+ p = Path(f'logs/instruct/{unique_id}.json')
474
+ else:
475
+ p = Path(f'logs/chat/{character}/{unique_id}.json')
476
+
477
+ return p
478
+
479
+
480
+ def save_history(history, unique_id, character, mode):
481
+ if shared.args.multi_user:
482
+ return
483
+
484
+ p = get_history_file_path(unique_id, character, mode)
485
+ if not p.parent.is_dir():
486
+ p.parent.mkdir(parents=True)
487
+
488
+ with open(p, 'w', encoding='utf-8') as f:
489
+ f.write(json.dumps(history, indent=4))
490
+
491
+
492
+ def rename_history(old_id, new_id, character, mode):
493
+ if shared.args.multi_user:
494
+ return
495
+
496
+ old_p = get_history_file_path(old_id, character, mode)
497
+ new_p = get_history_file_path(new_id, character, mode)
498
+ if new_p.parent != old_p.parent:
499
+ logger.error(f"The following path is not allowed: {new_p}.")
500
+ elif new_p == old_p:
501
+ logger.info("The provided path is identical to the old one.")
502
+ else:
503
+ logger.info(f"Renaming {old_p} to {new_p}")
504
+ old_p.rename(new_p)
505
+
506
+
507
+ def find_all_histories(state):
508
+ if shared.args.multi_user:
509
+ return ['']
510
+
511
+ if state['mode'] == 'instruct':
512
+ paths = Path('logs/instruct').glob('*.json')
513
+ else:
514
+ character = state['character_menu']
515
+
516
+ # Handle obsolete filenames and paths
517
+ old_p = Path(f'logs/{character}_persistent.json')
518
+ new_p = Path(f'logs/persistent_{character}.json')
519
+ if old_p.exists():
520
+ logger.warning(f"Renaming {old_p} to {new_p}")
521
+ old_p.rename(new_p)
522
+ if new_p.exists():
523
+ unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
524
+ p = get_history_file_path(unique_id, character, state['mode'])
525
+ logger.warning(f"Moving {new_p} to {p}")
526
+ p.parent.mkdir(exist_ok=True)
527
+ new_p.rename(p)
528
+
529
+ paths = Path(f'logs/chat/{character}').glob('*.json')
530
+
531
+ histories = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True)
532
+ histories = [path.stem for path in histories]
533
+
534
+ return histories
535
+
536
+
537
+ def load_latest_history(state):
538
+ '''
539
+ Loads the latest history for the given character in chat or chat-instruct
540
+ mode, or the latest instruct history for instruct mode.
541
+ '''
542
+
543
+ if shared.args.multi_user:
544
+ return start_new_chat(state)
545
+
546
+ histories = find_all_histories(state)
547
+
548
+ if len(histories) > 0:
549
+ history = load_history(histories[0], state['character_menu'], state['mode'])
550
+ else:
551
+ history = start_new_chat(state)
552
+
553
+ return history
554
+
555
+
556
+ def load_history_after_deletion(state, idx):
557
+ '''
558
+ Loads the latest history for the given character in chat or chat-instruct
559
+ mode, or the latest instruct history for instruct mode.
560
+ '''
561
+
562
+ if shared.args.multi_user:
563
+ return start_new_chat(state)
564
+
565
+ histories = find_all_histories(state)
566
+ idx = min(int(idx), len(histories) - 1)
567
+ idx = max(0, idx)
568
+
569
+ if len(histories) > 0:
570
+ history = load_history(histories[idx], state['character_menu'], state['mode'])
571
+ else:
572
+ history = start_new_chat(state)
573
+ histories = find_all_histories(state)
574
+
575
+ return history, gr.update(choices=histories, value=histories[idx])
576
+
577
+
578
+ def update_character_menu_after_deletion(idx):
579
+ characters = utils.get_available_characters()
580
+ idx = min(int(idx), len(characters) - 1)
581
+ idx = max(0, idx)
582
+ return gr.update(choices=characters, value=characters[idx])
583
+
584
+
585
+ def load_history(unique_id, character, mode):
586
+ p = get_history_file_path(unique_id, character, mode)
587
+
588
+ f = json.loads(open(p, 'rb').read())
589
+ if 'internal' in f and 'visible' in f:
590
+ history = f
591
+ else:
592
+ history = {
593
+ 'internal': f['data'],
594
+ 'visible': f['data_visible']
595
+ }
596
+
597
+ return history
598
+
599
+
600
+ def load_history_json(file, history):
601
+ try:
602
+ file = file.decode('utf-8')
603
+ f = json.loads(file)
604
+ if 'internal' in f and 'visible' in f:
605
+ history = f
606
+ else:
607
+ history = {
608
+ 'internal': f['data'],
609
+ 'visible': f['data_visible']
610
+ }
611
+
612
+ return history
613
+ except:
614
+ return history
615
+
616
+
617
+ def delete_history(unique_id, character, mode):
618
+ p = get_history_file_path(unique_id, character, mode)
619
+ delete_file(p)
620
+
621
+
622
+ def replace_character_names(text, name1, name2):
623
+ text = text.replace('{{user}}', name1).replace('{{char}}', name2)
624
+ return text.replace('<USER>', name1).replace('<BOT>', name2)
625
+
626
+
627
+ def generate_pfp_cache(character):
628
+ cache_folder = Path(shared.args.disk_cache_dir)
629
+ if not cache_folder.exists():
630
+ cache_folder.mkdir()
631
+
632
+ for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
633
+ if path.exists():
634
+ original_img = Image.open(path)
635
+ original_img.save(Path(f'{cache_folder}/pfp_character.png'), format='PNG')
636
+
637
+ thumb = make_thumbnail(original_img)
638
+ thumb.save(Path(f'{cache_folder}/pfp_character_thumb.png'), format='PNG')
639
+
640
+ return thumb
641
+
642
+ return None
643
+
644
+
645
+ def load_character(character, name1, name2):
646
+ context = greeting = ""
647
+ greeting_field = 'greeting'
648
+ picture = None
649
+
650
+ filepath = None
651
+ for extension in ["yml", "yaml", "json"]:
652
+ filepath = Path(f'characters/{character}.{extension}')
653
+ if filepath.exists():
654
+ break
655
+
656
+ if filepath is None or not filepath.exists():
657
+ logger.error(f"Could not find the character \"{character}\" inside characters/. No character has been loaded.")
658
+ raise ValueError
659
+
660
+ file_contents = open(filepath, 'r', encoding='utf-8').read()
661
+ data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
662
+ cache_folder = Path(shared.args.disk_cache_dir)
663
+
664
+ for path in [Path(f"{cache_folder}/pfp_character.png"), Path(f"{cache_folder}/pfp_character_thumb.png")]:
665
+ if path.exists():
666
+ path.unlink()
667
+
668
+ picture = generate_pfp_cache(character)
669
+
670
+ # Finding the bot's name
671
+ for k in ['name', 'bot', '<|bot|>', 'char_name']:
672
+ if k in data and data[k] != '':
673
+ name2 = data[k]
674
+ break
675
+
676
+ # Find the user name (if any)
677
+ for k in ['your_name', 'user', '<|user|>']:
678
+ if k in data and data[k] != '':
679
+ name1 = data[k]
680
+ break
681
+
682
+ if 'context' in data:
683
+ context = data['context'].strip()
684
+ elif "char_persona" in data:
685
+ context = build_pygmalion_style_context(data)
686
+ greeting_field = 'char_greeting'
687
+
688
+ greeting = data.get(greeting_field, greeting)
689
+ return name1, name2, picture, greeting, context
690
+
691
+
692
+ def load_instruction_template(template):
693
+ for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
694
+ if filepath.exists():
695
+ break
696
+ else:
697
+ return ''
698
+
699
+ file_contents = open(filepath, 'r', encoding='utf-8').read()
700
+ data = yaml.safe_load(file_contents)
701
+ if 'instruction_template' in data:
702
+ return data['instruction_template']
703
+ else:
704
+ return jinja_template_from_old_format(data)
705
+
706
+
707
+ @functools.cache
708
+ def load_character_memoized(character, name1, name2):
709
+ return load_character(character, name1, name2)
710
+
711
+
712
+ @functools.cache
713
+ def load_instruction_template_memoized(template):
714
+ return load_instruction_template(template)
715
+
716
+
717
+ def upload_character(file, img, tavern=False):
718
+ decoded_file = file if isinstance(file, str) else file.decode('utf-8')
719
+ try:
720
+ data = json.loads(decoded_file)
721
+ except:
722
+ data = yaml.safe_load(decoded_file)
723
+
724
+ if 'char_name' in data:
725
+ name = data['char_name']
726
+ greeting = data['char_greeting']
727
+ context = build_pygmalion_style_context(data)
728
+ yaml_data = generate_character_yaml(name, greeting, context)
729
+ else:
730
+ name = data['name']
731
+ yaml_data = generate_character_yaml(data['name'], data['greeting'], data['context'])
732
+
733
+ outfile_name = name
734
+ i = 1
735
+ while Path(f'characters/{outfile_name}.yaml').exists():
736
+ outfile_name = f'{name}_{i:03d}'
737
+ i += 1
738
+
739
+ with open(Path(f'characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f:
740
+ f.write(yaml_data)
741
+
742
+ if img is not None:
743
+ img.save(Path(f'characters/{outfile_name}.png'))
744
+
745
+ logger.info(f'New character saved to "characters/{outfile_name}.yaml".')
746
+ return gr.update(value=outfile_name, choices=get_available_characters())
747
+
748
+
749
+ def build_pygmalion_style_context(data):
750
+ context = ""
751
+ if 'char_persona' in data and data['char_persona'] != '':
752
+ context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
753
+
754
+ if 'world_scenario' in data and data['world_scenario'] != '':
755
+ context += f"Scenario: {data['world_scenario']}\n"
756
+
757
+ if 'example_dialogue' in data and data['example_dialogue'] != '':
758
+ context += f"{data['example_dialogue'].strip()}\n"
759
+
760
+ context = f"{context.strip()}\n"
761
+ return context
762
+
763
+
764
+ def upload_tavern_character(img, _json):
765
+ _json = {'char_name': _json['name'], 'char_persona': _json['description'], 'char_greeting': _json['first_mes'], 'example_dialogue': _json['mes_example'], 'world_scenario': _json['scenario']}
766
+ return upload_character(json.dumps(_json), img, tavern=True)
767
+
768
+
769
+ def check_tavern_character(img):
770
+ if "chara" not in img.info:
771
+ return "Not a TavernAI card", None, None, gr.update(interactive=False)
772
+
773
+ decoded_string = base64.b64decode(img.info['chara']).replace(b'\\r\\n', b'\\n')
774
+ _json = json.loads(decoded_string)
775
+ if "data" in _json:
776
+ _json = _json["data"]
777
+
778
+ return _json['name'], _json['description'], _json, gr.update(interactive=True)
779
+
780
+
781
+ def upload_your_profile_picture(img):
782
+ cache_folder = Path(shared.args.disk_cache_dir)
783
+ if not cache_folder.exists():
784
+ cache_folder.mkdir()
785
+
786
+ if img is None:
787
+ if Path(f"{cache_folder}/pfp_me.png").exists():
788
+ Path(f"{cache_folder}/pfp_me.png").unlink()
789
+ else:
790
+ img = make_thumbnail(img)
791
+ img.save(Path(f'{cache_folder}/pfp_me.png'))
792
+ logger.info(f'Profile picture saved to "{cache_folder}/pfp_me.png"')
793
+
794
+
795
+ def generate_character_yaml(name, greeting, context):
796
+ data = {
797
+ 'name': name,
798
+ 'greeting': greeting,
799
+ 'context': context,
800
+ }
801
+
802
+ data = {k: v for k, v in data.items() if v} # Strip falsy
803
+ return yaml.dump(data, sort_keys=False, width=float("inf"))
804
+
805
+
806
+ def generate_instruction_template_yaml(instruction_template):
807
+ data = {
808
+ 'instruction_template': instruction_template
809
+ }
810
+
811
+ return my_yaml_output(data)
812
+
813
+
814
+ def save_character(name, greeting, context, picture, filename):
815
+ if filename == "":
816
+ logger.error("The filename is empty, so the character will not be saved.")
817
+ return
818
+
819
+ data = generate_character_yaml(name, greeting, context)
820
+ filepath = Path(f'characters/{filename}.yaml')
821
+ save_file(filepath, data)
822
+ path_to_img = Path(f'characters/{filename}.png')
823
+ if picture is not None:
824
+ picture.save(path_to_img)
825
+ logger.info(f'Saved {path_to_img}.')
826
+
827
+
828
+ def delete_character(name, instruct=False):
829
+ for extension in ["yml", "yaml", "json"]:
830
+ delete_file(Path(f'characters/{name}.{extension}'))
831
+
832
+ delete_file(Path(f'characters/{name}.png'))
833
+
834
+
835
+ def jinja_template_from_old_format(params, verbose=False):
836
+ MASTER_TEMPLATE = """
837
+ {%- set ns = namespace(found=false) -%}
838
+ {%- for message in messages -%}
839
+ {%- if message['role'] == 'system' -%}
840
+ {%- set ns.found = true -%}
841
+ {%- endif -%}
842
+ {%- endfor -%}
843
+ {%- if not ns.found -%}
844
+ {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
845
+ {%- endif %}
846
+ {%- for message in messages %}
847
+ {%- if message['role'] == 'system' -%}
848
+ {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
849
+ {%- else -%}
850
+ {%- if message['role'] == 'user' -%}
851
+ {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
852
+ {%- else -%}
853
+ {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
854
+ {%- endif -%}
855
+ {%- endif -%}
856
+ {%- endfor -%}
857
+ {%- if add_generation_prompt -%}
858
+ {{-'<|PRE-ASSISTANT-GENERATE|>'-}}
859
+ {%- endif -%}
860
+ """
861
+
862
+ if 'context' in params and '<|system-message|>' in params['context']:
863
+ pre_system = params['context'].split('<|system-message|>')[0]
864
+ post_system = params['context'].split('<|system-message|>')[1]
865
+ else:
866
+ pre_system = ''
867
+ post_system = ''
868
+
869
+ pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
870
+ post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
871
+
872
+ pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
873
+ pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
874
+ post_assistant = params['turn_template'].split('<|bot-message|>')[1]
875
+
876
+ def preprocess(string):
877
+ return string.replace('\n', '\\n').replace('\'', '\\\'')
878
+
879
+ pre_system = preprocess(pre_system)
880
+ post_system = preprocess(post_system)
881
+ pre_user = preprocess(pre_user)
882
+ post_user = preprocess(post_user)
883
+ pre_assistant = preprocess(pre_assistant)
884
+ post_assistant = preprocess(post_assistant)
885
+
886
+ if verbose:
887
+ print(
888
+ '\n',
889
+ repr(pre_system) + '\n',
890
+ repr(post_system) + '\n',
891
+ repr(pre_user) + '\n',
892
+ repr(post_user) + '\n',
893
+ repr(pre_assistant) + '\n',
894
+ repr(post_assistant) + '\n',
895
+ )
896
+
897
+ result = MASTER_TEMPLATE
898
+ if 'system_message' in params:
899
+ result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
900
+ else:
901
+ result = result.replace('<|SYSTEM-MESSAGE|>', '')
902
+
903
+ result = result.replace('<|PRE-SYSTEM|>', pre_system)
904
+ result = result.replace('<|POST-SYSTEM|>', post_system)
905
+ result = result.replace('<|PRE-USER|>', pre_user)
906
+ result = result.replace('<|POST-USER|>', post_user)
907
+ result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
908
+ result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
909
+ result = result.replace('<|POST-ASSISTANT|>', post_assistant)
910
+
911
+ result = result.strip()
912
+
913
+ return result
914
+
915
+
916
+ def my_yaml_output(data):
917
+ '''
918
+ pyyaml is very inconsistent with multiline strings.
919
+ for simple instruction template outputs, this is enough.
920
+ '''
921
+ result = ""
922
+ for k in data:
923
+ result += k + ": |-\n"
924
+ for line in data[k].splitlines():
925
+ result += " " + line.rstrip(' ') + "\n"
926
+
927
+ return result
modules/ctransformers_model.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ctransformers import AutoConfig, AutoModelForCausalLM
2
+
3
+ from modules import shared
4
+ from modules.callbacks import Iteratorize
5
+ from modules.logging_colors import logger
6
+
7
+
8
+ class CtransformersModel:
9
+ def __init__(self):
10
+ pass
11
+
12
+ @classmethod
13
+ def from_pretrained(cls, path):
14
+ result = cls()
15
+
16
+ config = AutoConfig.from_pretrained(
17
+ str(path),
18
+ threads=shared.args.threads if shared.args.threads != 0 else -1,
19
+ gpu_layers=shared.args.n_gpu_layers,
20
+ batch_size=shared.args.n_batch,
21
+ context_length=shared.args.n_ctx,
22
+ stream=True,
23
+ mmap=not shared.args.no_mmap,
24
+ mlock=shared.args.mlock
25
+ )
26
+
27
+ result.model = AutoModelForCausalLM.from_pretrained(
28
+ str(result.model_dir(path) if result.model_type_is_auto() else path),
29
+ model_type=(None if result.model_type_is_auto() else shared.args.model_type),
30
+ config=config
31
+ )
32
+
33
+ logger.info(f'Using ctransformers model_type: {result.model.model_type} for {result.model.model_path}')
34
+ return result, result
35
+
36
+ def model_type_is_auto(self):
37
+ return shared.args.model_type is None or shared.args.model_type == "Auto" or shared.args.model_type == "None"
38
+
39
+ def model_dir(self, path):
40
+ if path.is_file():
41
+ return path.parent
42
+
43
+ return path
44
+
45
+ def encode(self, string, **kwargs):
46
+ return self.model.tokenize(string)
47
+
48
+ def decode(self, ids):
49
+ return self.model.detokenize(ids)
50
+
51
+ def generate(self, prompt, state, callback=None):
52
+ prompt = prompt if type(prompt) is str else prompt.decode()
53
+ # ctransformers uses -1 for random seed
54
+ generator = self.model(
55
+ prompt=prompt,
56
+ max_new_tokens=state['max_new_tokens'],
57
+ temperature=state['temperature'],
58
+ top_p=state['top_p'],
59
+ top_k=state['top_k'],
60
+ repetition_penalty=state['repetition_penalty'],
61
+ last_n_tokens=state['repetition_penalty_range'],
62
+ seed=int(state['seed'])
63
+ )
64
+
65
+ output = ""
66
+ for token in generator:
67
+ if callback:
68
+ callback(token)
69
+
70
+ output += token
71
+
72
+ return output
73
+
74
+ def generate_with_streaming(self, *args, **kwargs):
75
+ with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
76
+ reply = ''
77
+ for token in generator:
78
+ reply += token
79
+ yield reply
modules/deepspeed_parameters.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
2
+ '''
3
+ DeepSpeed configuration
4
+ https://huggingface.co/docs/transformers/main_classes/deepspeed
5
+ '''
6
+
7
+ if nvme_offload_dir:
8
+ ds_config = {
9
+ "fp16": {
10
+ "enabled": not ds_bf16,
11
+ },
12
+ "bf16": {
13
+ "enabled": ds_bf16,
14
+ },
15
+ "zero_optimization": {
16
+ "stage": 3,
17
+ "offload_param": {
18
+ "device": "nvme",
19
+ "nvme_path": nvme_offload_dir,
20
+ "pin_memory": True,
21
+ "buffer_count": 5,
22
+ "buffer_size": 1e9,
23
+ "max_in_cpu": 1e9
24
+ },
25
+ "overlap_comm": True,
26
+ "reduce_bucket_size": "auto",
27
+ "contiguous_gradients": True,
28
+ "sub_group_size": 1e8,
29
+ "stage3_prefetch_bucket_size": "auto",
30
+ "stage3_param_persistence_threshold": "auto",
31
+ "stage3_max_live_parameters": "auto",
32
+ "stage3_max_reuse_distance": "auto",
33
+ },
34
+ "aio": {
35
+ "block_size": 262144,
36
+ "queue_depth": 32,
37
+ "thread_count": 1,
38
+ "single_submit": False,
39
+ "overlap_events": True
40
+ },
41
+ "steps_per_print": 2000,
42
+ "train_batch_size": train_batch_size,
43
+ "train_micro_batch_size_per_gpu": 1,
44
+ "wall_clock_breakdown": False
45
+ }
46
+ else:
47
+ ds_config = {
48
+ "fp16": {
49
+ "enabled": not ds_bf16,
50
+ },
51
+ "bf16": {
52
+ "enabled": ds_bf16,
53
+ },
54
+ "zero_optimization": {
55
+ "stage": 3,
56
+ "offload_param": {
57
+ "device": "cpu",
58
+ "pin_memory": True
59
+ },
60
+ "overlap_comm": True,
61
+ "contiguous_gradients": True,
62
+ "reduce_bucket_size": "auto",
63
+ "stage3_prefetch_bucket_size": "auto",
64
+ "stage3_param_persistence_threshold": "auto",
65
+ "stage3_max_live_parameters": "auto",
66
+ "stage3_max_reuse_distance": "auto",
67
+ },
68
+ "steps_per_print": 2000,
69
+ "train_batch_size": train_batch_size,
70
+ "train_micro_batch_size_per_gpu": 1,
71
+ "wall_clock_breakdown": False
72
+ }
73
+
74
+ return ds_config
modules/evaluate.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ import torch
6
+ from datasets import load_dataset
7
+ from tqdm import tqdm
8
+
9
+ from modules import shared
10
+ from modules.logging_colors import logger
11
+ from modules.models import clear_torch_cache, load_model, unload_model
12
+ from modules.models_settings import get_model_metadata, update_model_parameters
13
+ from modules.text_generation import encode
14
+
15
+
16
+ def load_past_evaluations():
17
+ if Path('logs/evaluations.csv').exists():
18
+ df = pd.read_csv(Path('logs/evaluations.csv'), dtype=str)
19
+ df['Perplexity'] = pd.to_numeric(df['Perplexity'])
20
+ return df
21
+ else:
22
+ return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
23
+
24
+
25
+ past_evaluations = load_past_evaluations()
26
+
27
+
28
+ def save_past_evaluations(df):
29
+ global past_evaluations
30
+ past_evaluations = df
31
+ filepath = Path('logs/evaluations.csv')
32
+ filepath.parent.mkdir(parents=True, exist_ok=True)
33
+ df.to_csv(filepath, index=False)
34
+
35
+
36
+ def calculate_perplexity(models, input_dataset, stride, _max_length):
37
+ '''
38
+ Based on:
39
+ https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models
40
+ '''
41
+
42
+ if not shared.args.no_use_fast:
43
+ logger.warning("--no_use_fast is not being used. If tokenizing the input dataset takes a long time, consider loading the model with that option checked.")
44
+
45
+ global past_evaluations
46
+ cumulative_log = ''
47
+ cumulative_log += "Loading the input dataset...\n\n"
48
+ yield cumulative_log
49
+
50
+ # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py
51
+ if input_dataset == 'wikitext':
52
+ data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
53
+ text = "\n\n".join(data['text'])
54
+ elif input_dataset == 'ptb':
55
+ data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
56
+ text = "\n\n".join(data['sentence'])
57
+ elif input_dataset == 'ptb_new':
58
+ data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
59
+ text = " ".join(data['sentence'])
60
+ else:
61
+ with open(Path(f'training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
62
+ text = f.read()
63
+
64
+ for model in models:
65
+ if is_in_past_evaluations(model, input_dataset, stride, _max_length):
66
+ cumulative_log += f"`{model}` has already been tested. Ignoring.\n\n"
67
+ yield cumulative_log
68
+ continue
69
+
70
+ if model != 'current model':
71
+ try:
72
+ yield cumulative_log + f"Loading `{model}`...\n\n"
73
+ model_settings = get_model_metadata(model)
74
+ shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults
75
+ update_model_parameters(model_settings) # hijacking the command-line arguments
76
+ unload_model()
77
+ shared.model, shared.tokenizer = load_model(model)
78
+ except:
79
+ cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
80
+ yield cumulative_log
81
+ continue
82
+
83
+ cumulative_log += f"Processing `{shared.model_name}`...\n\n"
84
+ yield cumulative_log + "Tokenizing the input dataset...\n\n"
85
+ encodings = encode(text, add_special_tokens=False)
86
+ seq_len = encodings.shape[1]
87
+ if _max_length:
88
+ max_length = _max_length
89
+ elif hasattr(shared.model.config, 'max_position_embeddings'):
90
+ max_length = shared.model.config.max_position_embeddings
91
+ else:
92
+ max_length = 2048
93
+
94
+ nlls = []
95
+ prev_end_loc = 0
96
+ for begin_loc in tqdm(range(0, seq_len, stride)):
97
+ yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%"
98
+ end_loc = min(begin_loc + max_length, seq_len)
99
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
100
+ input_ids = encodings[:, begin_loc:end_loc]
101
+ target_ids = input_ids.clone()
102
+ target_ids[:, :-trg_len] = -100
103
+ clear_torch_cache()
104
+ with torch.no_grad():
105
+ outputs = shared.model(input_ids=input_ids, labels=target_ids)
106
+
107
+ # loss is calculated using CrossEntropyLoss which averages over valid labels
108
+ # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
109
+ # to the left by 1.
110
+ neg_log_likelihood = outputs.loss
111
+
112
+ nlls.append(neg_log_likelihood)
113
+ prev_end_loc = end_loc
114
+ if end_loc == seq_len:
115
+ break
116
+
117
+ ppl = torch.exp(torch.stack(nlls).mean())
118
+ add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
119
+ save_past_evaluations(past_evaluations)
120
+ cumulative_log += f"The perplexity for `{shared.model_name}` is: {float(ppl)}\n\n"
121
+ yield cumulative_log
122
+
123
+
124
+ def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length):
125
+ global past_evaluations
126
+ entry = {
127
+ 'Model': model,
128
+ 'LoRAs': ', '.join(shared.lora_names) or '-',
129
+ 'Dataset': dataset,
130
+ 'Perplexity': perplexity,
131
+ 'stride': str(stride),
132
+ 'max_length': str(max_length),
133
+ 'Date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
134
+ 'Comment': ''
135
+ }
136
+ past_evaluations = pd.concat([past_evaluations, pd.DataFrame([entry])], ignore_index=True)
137
+
138
+
139
+ def is_in_past_evaluations(model, dataset, stride, max_length):
140
+ entries = past_evaluations[(past_evaluations['Model'] == model) &
141
+ (past_evaluations['Dataset'] == dataset) &
142
+ (past_evaluations['max_length'] == str(max_length)) &
143
+ (past_evaluations['stride'] == str(stride))]
144
+
145
+ if entries.shape[0] > 0:
146
+ return True
147
+ else:
148
+ return False
149
+
150
+
151
+ def generate_markdown_table():
152
+ sorted_df = past_evaluations.sort_values(by=['Dataset', 'stride', 'Perplexity', 'Date'])
153
+ return sorted_df
modules/exllamav2.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from exllamav2 import (
6
+ ExLlamaV2,
7
+ ExLlamaV2Cache,
8
+ ExLlamaV2Cache_8bit,
9
+ ExLlamaV2Config,
10
+ ExLlamaV2Tokenizer
11
+ )
12
+ from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
13
+
14
+ from modules import shared
15
+ from modules.logging_colors import logger
16
+ from modules.text_generation import get_max_prompt_length
17
+
18
+ try:
19
+ import flash_attn
20
+ except ModuleNotFoundError:
21
+ logger.warning(
22
+ 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
23
+ 'to be a lot higher than it could be.\n'
24
+ 'Try installing flash-attention following the instructions here: '
25
+ 'https://github.com/Dao-AILab/flash-attention#installation-and-features'
26
+ )
27
+ pass
28
+ except Exception:
29
+ logger.warning('Failed to load flash-attention due to the following error:\n')
30
+ traceback.print_exc()
31
+
32
+
33
+ class Exllamav2Model:
34
+ def __init__(self):
35
+ pass
36
+
37
+ @classmethod
38
+ def from_pretrained(self, path_to_model):
39
+
40
+ path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
41
+
42
+ config = ExLlamaV2Config()
43
+ config.model_dir = str(path_to_model)
44
+ config.prepare()
45
+
46
+ config.max_seq_len = shared.args.max_seq_len
47
+ config.scale_pos_emb = shared.args.compress_pos_emb
48
+ config.scale_alpha_value = shared.args.alpha_value
49
+ config.no_flash_attn = shared.args.no_flash_attn
50
+ config.num_experts_per_token = int(shared.args.num_experts_per_token)
51
+
52
+ model = ExLlamaV2(config)
53
+
54
+ split = None
55
+ if shared.args.gpu_split:
56
+ split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
57
+
58
+ model.load(split)
59
+
60
+ tokenizer = ExLlamaV2Tokenizer(config)
61
+ if shared.args.cache_8bit:
62
+ cache = ExLlamaV2Cache_8bit(model)
63
+ else:
64
+ cache = ExLlamaV2Cache(model)
65
+
66
+ generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
67
+
68
+ result = self()
69
+ result.model = model
70
+ result.cache = cache
71
+ result.tokenizer = tokenizer
72
+ result.generator = generator
73
+ result.loras = None
74
+ return result, result
75
+
76
+ def encode(self, string, **kwargs):
77
+ return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
78
+
79
+ def decode(self, ids, **kwargs):
80
+ if isinstance(ids, list):
81
+ ids = torch.tensor([ids])
82
+ elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
83
+ ids = ids.view(1, -1)
84
+
85
+ return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
86
+
87
+ def get_logits(self, token_ids, **kwargs):
88
+ self.cache.current_seq_len = 0
89
+ if token_ids.shape[-1] > 1:
90
+ self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
91
+
92
+ return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
93
+
94
+ def generate_with_streaming(self, prompt, state):
95
+ settings = ExLlamaV2Sampler.Settings()
96
+
97
+ settings.token_repetition_penalty = state['repetition_penalty']
98
+ settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
99
+
100
+ settings.token_frequency_penalty = state['frequency_penalty']
101
+ settings.token_presence_penalty = state['presence_penalty']
102
+
103
+ settings.temperature = state['temperature']
104
+ settings.top_k = state['top_k']
105
+ settings.top_p = state['top_p']
106
+ settings.top_a = state['top_a']
107
+ settings.min_p = state['min_p']
108
+ settings.tfs = state['tfs']
109
+ settings.typical = state['typical_p']
110
+
111
+ settings.temperature_last = state['temperature_last']
112
+
113
+ settings.mirostat = state['mirostat_mode'] == 2
114
+ settings.mirostat_tau = state['mirostat_tau']
115
+ settings.mirostat_eta = state['mirostat_eta']
116
+
117
+ if state['ban_eos_token']:
118
+ settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
119
+
120
+ if state['custom_token_bans']:
121
+ to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
122
+ if len(to_ban) > 0:
123
+ settings.disallow_tokens(self.tokenizer, to_ban)
124
+
125
+ ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
126
+ ids = ids[:, -get_max_prompt_length(state):]
127
+
128
+ if state['auto_max_new_tokens']:
129
+ max_new_tokens = state['truncation_length'] - ids.shape[-1]
130
+ else:
131
+ max_new_tokens = state['max_new_tokens']
132
+
133
+ self.generator.begin_stream(ids, settings, loras=self.loras)
134
+
135
+ decoded_text = ''
136
+ for i in range(max_new_tokens):
137
+ chunk, eos, _ = self.generator.stream()
138
+ if eos or shared.stop_everything:
139
+ break
140
+
141
+ decoded_text += chunk
142
+ yield decoded_text
143
+
144
+ def generate(self, prompt, state):
145
+ output = ''
146
+ for output in self.generate_with_streaming(prompt, state):
147
+ pass
148
+
149
+ return output
modules/exllamav2_hf.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional, Union
5
+
6
+ import torch
7
+ from exllamav2 import (
8
+ ExLlamaV2,
9
+ ExLlamaV2Cache,
10
+ ExLlamaV2Cache_8bit,
11
+ ExLlamaV2Config
12
+ )
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+
17
+ from modules import shared
18
+ from modules.logging_colors import logger
19
+
20
+ try:
21
+ import flash_attn
22
+ except ModuleNotFoundError:
23
+ logger.warning(
24
+ 'You are running ExLlamaV2 without flash-attention. This will cause the VRAM usage '
25
+ 'to be a lot higher than it could be.\n'
26
+ 'Try installing flash-attention following the instructions here: '
27
+ 'https://github.com/Dao-AILab/flash-attention#installation-and-features'
28
+ )
29
+ pass
30
+ except Exception:
31
+ logger.warning('Failed to load flash-attention due to the following error:\n')
32
+ traceback.print_exc()
33
+
34
+
35
+ class Exllamav2HF(PreTrainedModel):
36
+ def __init__(self, config: ExLlamaV2Config):
37
+ super().__init__(PretrainedConfig())
38
+ self.ex_config = config
39
+ self.ex_model = ExLlamaV2(config)
40
+ split = None
41
+ if shared.args.gpu_split:
42
+ split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
43
+
44
+ self.ex_model.load(split)
45
+ self.generation_config = GenerationConfig()
46
+ self.loras = None
47
+
48
+ if shared.args.cache_8bit:
49
+ self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
50
+ else:
51
+ self.ex_cache = ExLlamaV2Cache(self.ex_model)
52
+
53
+ self.past_seq = None
54
+ if shared.args.cfg_cache:
55
+ if shared.args.cache_8bit:
56
+ self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
57
+ else:
58
+ self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
59
+
60
+ self.past_seq_negative = None
61
+
62
+ def _validate_model_class(self):
63
+ pass
64
+
65
+ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
66
+ pass
67
+
68
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
69
+ return {'input_ids': input_ids, **kwargs}
70
+
71
+ @property
72
+ def device(self) -> torch.device:
73
+ return torch.device(0)
74
+
75
+ def __call__(self, *args, **kwargs):
76
+ use_cache = kwargs.get('use_cache', True)
77
+ labels = kwargs.get('labels', None)
78
+ past_key_values = kwargs.get('past_key_values', None)
79
+
80
+ if len(args) > 0:
81
+ if not shared.args.cfg_cache:
82
+ logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
83
+ return
84
+
85
+ input_ids = args[0]
86
+ is_negative = True
87
+ past_seq = self.past_seq_negative
88
+ ex_cache = self.ex_cache_negative
89
+ else:
90
+ input_ids = kwargs['input_ids']
91
+ is_negative = False
92
+ past_seq = self.past_seq
93
+ ex_cache = self.ex_cache
94
+
95
+ seq = input_ids[0].tolist()
96
+ if is_negative and past_key_values is not None:
97
+ seq = past_key_values + seq
98
+
99
+ seq_tensor = torch.tensor(seq)
100
+ reset = True
101
+
102
+ # Make the forward call
103
+ if labels is None:
104
+ if past_seq is not None:
105
+ min_length = min(past_seq.shape[0], seq_tensor.shape[0])
106
+ indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
107
+ if len(indices) > 0:
108
+ longest_prefix = indices[0].item()
109
+ else:
110
+ longest_prefix = min_length
111
+
112
+ if longest_prefix > 0:
113
+ reset = False
114
+ ex_cache.current_seq_len = longest_prefix
115
+ if len(seq_tensor) - longest_prefix > 1:
116
+ self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
117
+ elif len(seq_tensor) == longest_prefix:
118
+ # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
119
+ # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
120
+ ex_cache.current_seq_len -= 1
121
+
122
+ if reset:
123
+ ex_cache.current_seq_len = 0
124
+ if len(seq_tensor) > 1:
125
+ self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
126
+
127
+ logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
128
+ else:
129
+ ex_cache.current_seq_len = 0
130
+ logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
131
+
132
+ if is_negative:
133
+ self.past_seq_negative = seq_tensor
134
+ else:
135
+ self.past_seq = seq_tensor
136
+
137
+ loss = None
138
+ if labels is not None:
139
+ # Shift so that tokens < n predict n
140
+ shift_logits = logits[..., :-1, :].contiguous()
141
+ shift_labels = labels[..., 1:].contiguous()
142
+ # Flatten the tokens
143
+ loss_fct = CrossEntropyLoss()
144
+ shift_logits = shift_logits.view(-1, logits.shape[-1])
145
+ shift_labels = shift_labels.view(-1)
146
+ # Enable model parallelism
147
+ shift_labels = shift_labels.to(shift_logits.device)
148
+ loss = loss_fct(shift_logits, shift_labels)
149
+
150
+ return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
154
+ assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
155
+ if isinstance(pretrained_model_name_or_path, str):
156
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
157
+
158
+ pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
159
+
160
+ config = ExLlamaV2Config()
161
+ config.model_dir = str(pretrained_model_name_or_path)
162
+ config.prepare()
163
+
164
+ config.max_seq_len = shared.args.max_seq_len
165
+ config.scale_pos_emb = shared.args.compress_pos_emb
166
+ config.scale_alpha_value = shared.args.alpha_value
167
+ config.no_flash_attn = shared.args.no_flash_attn
168
+ config.num_experts_per_token = int(shared.args.num_experts_per_token)
169
+
170
+ return Exllamav2HF(config)
modules/extensions.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from functools import partial
3
+ from inspect import signature
4
+
5
+ import gradio as gr
6
+
7
+ import extensions
8
+ import modules.shared as shared
9
+ from modules.logging_colors import logger
10
+
11
+ state = {}
12
+ available_extensions = []
13
+ setup_called = set()
14
+
15
+
16
+ def apply_settings(extension, name):
17
+ if not hasattr(extension, 'params'):
18
+ return
19
+
20
+ for param in extension.params:
21
+ _id = f"{name}-{param}"
22
+ shared.default_settings[_id] = extension.params[param]
23
+ if _id in shared.settings:
24
+ extension.params[param] = shared.settings[_id]
25
+
26
+
27
+ def load_extensions():
28
+ global state, setup_called
29
+ state = {}
30
+ for i, name in enumerate(shared.args.extensions):
31
+ if name in available_extensions:
32
+ if name != 'api':
33
+ logger.info(f'Loading the extension "{name}"')
34
+ try:
35
+ try:
36
+ exec(f"import extensions.{name}.script")
37
+ except ModuleNotFoundError:
38
+ logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
39
+ raise
40
+
41
+ extension = getattr(extensions, name).script
42
+
43
+ # Only run setup() and apply settings from settings.yaml once
44
+ if extension not in setup_called:
45
+ apply_settings(extension, name)
46
+ if hasattr(extension, "setup"):
47
+ extension.setup()
48
+
49
+ setup_called.add(extension)
50
+
51
+ state[name] = [True, i]
52
+ except:
53
+ logger.error(f'Failed to load the extension "{name}".')
54
+ traceback.print_exc()
55
+
56
+
57
+ # This iterator returns the extensions in the order specified in the command-line
58
+ def iterator():
59
+ for name in sorted(state, key=lambda x: state[x][1]):
60
+ if state[name][0]:
61
+ yield getattr(extensions, name).script, name
62
+
63
+
64
+ # Extension functions that map string -> string
65
+ def _apply_string_extensions(function_name, text, state, is_chat=False):
66
+ for extension, _ in iterator():
67
+ if hasattr(extension, function_name):
68
+ func = getattr(extension, function_name)
69
+
70
+ # Handle old extensions without the 'state' arg or
71
+ # the 'is_chat' kwarg
72
+ count = 0
73
+ has_chat = False
74
+ for k in signature(func).parameters:
75
+ if k == 'is_chat':
76
+ has_chat = True
77
+ else:
78
+ count += 1
79
+
80
+ if count == 2:
81
+ args = [text, state]
82
+ else:
83
+ args = [text]
84
+
85
+ if has_chat:
86
+ kwargs = {'is_chat': is_chat}
87
+ else:
88
+ kwargs = {}
89
+
90
+ text = func(*args, **kwargs)
91
+
92
+ return text
93
+
94
+
95
+ # Extension functions that map string -> string
96
+ def _apply_chat_input_extensions(text, visible_text, state):
97
+ for extension, _ in iterator():
98
+ if hasattr(extension, 'chat_input_modifier'):
99
+ text, visible_text = extension.chat_input_modifier(text, visible_text, state)
100
+
101
+ return text, visible_text
102
+
103
+
104
+ # custom_generate_chat_prompt handling - currently only the first one will work
105
+ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
106
+ for extension, _ in iterator():
107
+ if hasattr(extension, 'custom_generate_chat_prompt'):
108
+ return extension.custom_generate_chat_prompt(text, state, **kwargs)
109
+
110
+ return None
111
+
112
+
113
+ # Extension that modifies the input parameters before they are used
114
+ def _apply_state_modifier_extensions(state):
115
+ for extension, _ in iterator():
116
+ if hasattr(extension, "state_modifier"):
117
+ state = getattr(extension, "state_modifier")(state)
118
+
119
+ return state
120
+
121
+
122
+ # Extension that modifies the chat history before it is used
123
+ def _apply_history_modifier_extensions(history):
124
+ for extension, _ in iterator():
125
+ if hasattr(extension, "history_modifier"):
126
+ history = getattr(extension, "history_modifier")(history)
127
+
128
+ return history
129
+
130
+
131
+ # Extension functions that override the default tokenizer output - The order of execution is not defined
132
+ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
133
+ for extension, _ in iterator():
134
+ if hasattr(extension, function_name):
135
+ prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
136
+
137
+ return prompt, input_ids, input_embeds
138
+
139
+
140
+ # Allow extensions to add their own logits processors to the stack being run.
141
+ # Each extension would call `processor_list.append({their LogitsProcessor}())`.
142
+ def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
143
+ for extension, _ in iterator():
144
+ if hasattr(extension, function_name):
145
+ result = getattr(extension, function_name)(processor_list, input_ids)
146
+ if type(result) is list:
147
+ processor_list = result
148
+
149
+ return processor_list
150
+
151
+
152
+ # Get prompt length in tokens after applying extension functions which override the default tokenizer output
153
+ # currently only the first one will work
154
+ def _apply_custom_tokenized_length(prompt):
155
+ for extension, _ in iterator():
156
+ if hasattr(extension, 'custom_tokenized_length'):
157
+ return getattr(extension, 'custom_tokenized_length')(prompt)
158
+
159
+ return None
160
+
161
+
162
+ # Custom generate reply handling - currently only the first one will work
163
+ def _apply_custom_generate_reply():
164
+ for extension, _ in iterator():
165
+ if hasattr(extension, 'custom_generate_reply'):
166
+ return getattr(extension, 'custom_generate_reply')
167
+
168
+ return None
169
+
170
+
171
+ def _apply_custom_css():
172
+ all_css = ''
173
+ for extension, _ in iterator():
174
+ if hasattr(extension, 'custom_css'):
175
+ all_css += getattr(extension, 'custom_css')()
176
+
177
+ return all_css
178
+
179
+
180
+ def _apply_custom_js():
181
+ all_js = ''
182
+ for extension, _ in iterator():
183
+ if hasattr(extension, 'custom_js'):
184
+ all_js += getattr(extension, 'custom_js')()
185
+
186
+ return all_js
187
+
188
+
189
+ def create_extensions_block():
190
+ to_display = []
191
+ for extension, name in iterator():
192
+ if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
193
+ to_display.append((extension, name))
194
+
195
+ # Creating the extension ui elements
196
+ if len(to_display) > 0:
197
+ with gr.Column(elem_id="extensions"):
198
+ for row in to_display:
199
+ extension, _ = row
200
+ extension.ui()
201
+
202
+
203
+ def create_extensions_tabs():
204
+ for extension, name in iterator():
205
+ if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
206
+ display_name = getattr(extension, 'params', {}).get('display_name', name)
207
+ with gr.Tab(display_name, elem_classes="extension-tab"):
208
+ extension.ui()
209
+
210
+
211
+ EXTENSION_MAP = {
212
+ "input": partial(_apply_string_extensions, "input_modifier"),
213
+ "output": partial(_apply_string_extensions, "output_modifier"),
214
+ "chat_input": _apply_chat_input_extensions,
215
+ "state": _apply_state_modifier_extensions,
216
+ "history": _apply_history_modifier_extensions,
217
+ "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
218
+ "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
219
+ 'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
220
+ "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
221
+ "custom_generate_reply": _apply_custom_generate_reply,
222
+ "tokenized_length": _apply_custom_tokenized_length,
223
+ "css": _apply_custom_css,
224
+ "js": _apply_custom_js
225
+ }
226
+
227
+
228
+ def apply_extensions(typ, *args, **kwargs):
229
+ if typ not in EXTENSION_MAP:
230
+ raise ValueError(f"Invalid extension type {typ}")
231
+
232
+ return EXTENSION_MAP[typ](*args, **kwargs)
modules/github.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ from pathlib import Path
3
+
4
+ new_extensions = set()
5
+
6
+
7
+ def clone_or_pull_repository(github_url):
8
+ global new_extensions
9
+
10
+ repository_folder = Path("extensions")
11
+ repo_name = github_url.rstrip("/").split("/")[-1].split(".")[0]
12
+
13
+ # Check if the repository folder exists
14
+ if not repository_folder.exists():
15
+ repository_folder.mkdir(parents=True)
16
+
17
+ repo_path = repository_folder / repo_name
18
+
19
+ # Check if the repository is already cloned
20
+ if repo_path.exists():
21
+ yield f"Updating {github_url}..."
22
+ # Perform a 'git pull' to update the repository
23
+ try:
24
+ pull_output = subprocess.check_output(["git", "-C", repo_path, "pull"], stderr=subprocess.STDOUT)
25
+ yield "Done."
26
+ return pull_output.decode()
27
+ except subprocess.CalledProcessError as e:
28
+ return str(e)
29
+
30
+ # Clone the repository
31
+ try:
32
+ yield f"Cloning {github_url}..."
33
+ clone_output = subprocess.check_output(["git", "clone", github_url, repo_path], stderr=subprocess.STDOUT)
34
+ new_extensions.add(repo_name)
35
+ yield f"The extension `{repo_name}` has been downloaded.\n\nPlease close the the web UI completely and launch it again to be able to load it."
36
+ return clone_output.decode()
37
+ except subprocess.CalledProcessError as e:
38
+ return str(e)
modules/grammar/__pycache__/grammar_utils.cpython-311.pyc ADDED
Binary file (33 kB). View file
 
modules/grammar/__pycache__/logits_process.cpython-311.pyc ADDED
Binary file (5.46 kB). View file