John6666 commited on
Commit
7b3925d
·
verified ·
1 Parent(s): b533696

Upload 12 files

Browse files
9em124t2-499968/clip_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d7b0548d12fa649370896982c2af9d03d43285b782bd47639c96e6e0b29473c
3
+ size 1713067838
9em124t2-499968/config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_project: joy-caption-1
2
+ device_batch_size: 2
3
+ batch_size: 256
4
+ learning_rate: 0.0002
5
+ warmup_samples: 18000
6
+ max_samples: 500000
7
+ save_every: 50000
8
+ test_every: 50000
9
+ use_amp: true
10
+ grad_scaler: true
11
+ lr_scheduler_type: cosine
12
+ min_lr_ratio: 0.0
13
+ allow_tf32: true
14
+ seed: 69
15
+ num_workers: 8
16
+ optimizer_type: adamw
17
+ adam_beta1: 0.9
18
+ adam_beta2: 0.999
19
+ adam_eps: 1.0e-08
20
+ adam_weight_decay: 0.0
21
+ clip_grad_norm: 1.0
22
+ dataset: fancyfeast/joy-captioning-20240917a
23
+ clip_model: google/siglip-so400m-patch14-384
24
+ text_model: meta-llama/Meta-Llama-3.1-8B
25
+ resume: null
26
+ gradient_checkpointing: false
27
+ test_size: 2048
28
+ grad_scaler_init: 65536.0
29
+ max_caption_length: 257
30
+ num_image_tokens: 32
31
+ adapter_type: mlp
32
+ text_model_dtype: bfloat16
33
+ pre_test: false
34
+ train_image_model: true
35
+ image_model_lr: null
36
+ train_lora: true
37
+ lora_r: 64
38
+ lora_alpha: 16
39
+ lora_dropout: 0.1
9em124t2-499968/image_adapter.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e53c3bf8df745a3c19ae3c70dbf9bf23cfdc8f3fdb937000a4eafd2a36914661
3
+ size 86067714
9em124t2-499968/text_model/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Meta-Llama-3.1-8B
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.12.0
9em124t2-499968/text_model/adapter_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Meta-Llama-3.1-8B",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layer_replication": null,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "loftq_config": {},
13
+ "lora_alpha": 16,
14
+ "lora_dropout": 0.1,
15
+ "megatron_config": null,
16
+ "megatron_core": "megatron.core",
17
+ "modules_to_save": null,
18
+ "peft_type": "LORA",
19
+ "r": 64,
20
+ "rank_pattern": {},
21
+ "revision": null,
22
+ "target_modules": [
23
+ "q_proj",
24
+ "v_proj"
25
+ ],
26
+ "task_type": "CAUSAL_LM",
27
+ "use_dora": false,
28
+ "use_rslora": false
29
+ }
9em124t2-499968/text_model/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b48221de174ab0db7b46b4833118c5c0a4c2bf0b51b77b4cc4ab04651bd06cca
3
+ size 109069176
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Joy Caption Pre Alpha Mod
3
  emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.43.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: Joy Caption Alpha One Mod
3
  emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -2,20 +2,38 @@ import spaces
2
  import gradio as gr
3
  from joycaption import stream_chat_mod, get_text_model, change_text_model, get_repo_gguf
4
 
5
- JC_TITLE_MD = "<h1><center>JoyCaption Pre-Alpha Mod</center></h1>"
6
- JC_DESC_MD = """This space is mod of [fancyfeast/joy-caption-pre-alpha](https://huggingface.co/spaces/fancyfeast/joy-caption-pre-alpha),
7
  [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha)"""
8
 
9
  css = """
10
- .info {text-align:center; display:inline-flex; align-items:center !important}
11
  """
12
 
13
- with gr.Blocks(delete_cache=(60, 3600)) as demo:
14
  gr.HTML(JC_TITLE_MD)
15
  with gr.Row():
16
  with gr.Column():
17
  with gr.Group():
18
  jc_input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"], height=384)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with gr.Accordion("Advanced", open=False):
20
  with gr.Row():
21
  jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a huggingface model repo_id to want to use.",
@@ -28,8 +46,8 @@ with gr.Blocks(delete_cache=(60, 3600)) as demo:
28
  jc_use_inference_client = gr.Checkbox(label="Use Inference Client", value=False, visible=False)
29
  with gr.Row():
30
  jc_tokens = gr.Slider(minimum=1, maximum=4096, value=300, step=1, label="Max tokens")
31
- jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.5, step=0.1, label="Temperature")
32
- jc_topk = gr.Slider(minimum=0, maximum=100, value=40, step=10, label="Top-k")
33
  jc_run_button = gr.Button("Caption", variant="primary")
34
 
35
  with gr.Column():
@@ -38,11 +56,11 @@ with gr.Blocks(delete_cache=(60, 3600)) as demo:
38
  gr.LoginButton()
39
  gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
40
 
41
- jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_tokens, jc_topk, jc_temperature], outputs=[jc_output_caption])
42
  jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
43
  #jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
44
  jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
45
 
46
  if __name__ == "__main__":
47
- demo.queue()
48
  demo.launch()
 
2
  import gradio as gr
3
  from joycaption import stream_chat_mod, get_text_model, change_text_model, get_repo_gguf
4
 
5
+ JC_TITLE_MD = "<h1><center>JoyCaption Alpha One Mod</center></h1>"
6
+ JC_DESC_MD = """This space is mod of [fancyfeast/joy-caption-alpha-one](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one),
7
  [Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha)"""
8
 
9
  css = """
10
+ .info {text-align:center; !important}
11
  """
12
 
13
+ with gr.Blocks(fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
14
  gr.HTML(JC_TITLE_MD)
15
  with gr.Row():
16
  with gr.Column():
17
  with gr.Group():
18
  jc_input_image = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"], height=384)
19
+ with gr.Row():
20
+ jc_caption_type = gr.Dropdown(
21
+ choices=["descriptive", "training_prompt", "rng-tags"],
22
+ label="Caption Type",
23
+ value="descriptive",
24
+ )
25
+ jc_caption_tone = gr.Dropdown(
26
+ choices=["formal", "informal"],
27
+ label="Caption Tone",
28
+ value="formal",
29
+ )
30
+ jc_caption_length = gr.Dropdown(
31
+ choices=["any", "very short", "short", "medium-length", "long", "very long"] +
32
+ [str(i) for i in range(20, 261, 10)],
33
+ label="Caption Length",
34
+ value="any",
35
+ )
36
+ gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.", elem_classes="info")
37
  with gr.Accordion("Advanced", open=False):
38
  with gr.Row():
39
  jc_text_model = gr.Dropdown(label="LLM Model", info="You can enter a huggingface model repo_id to want to use.",
 
46
  jc_use_inference_client = gr.Checkbox(label="Use Inference Client", value=False, visible=False)
47
  with gr.Row():
48
  jc_tokens = gr.Slider(minimum=1, maximum=4096, value=300, step=1, label="Max tokens")
49
+ jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
50
+ jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
51
  jc_run_button = gr.Button("Caption", variant="primary")
52
 
53
  with gr.Column():
 
56
  gr.LoginButton()
57
  gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
58
 
59
+ jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length, jc_tokens, jc_topp, jc_temperature], outputs=[jc_output_caption])
60
  jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
61
  #jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
62
  jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
63
 
64
  if __name__ == "__main__":
65
+ #demo.queue()
66
  demo.launch()
joycaption.py CHANGED
@@ -8,7 +8,9 @@ import torch
8
  import torch.amp.autocast_mode
9
  from PIL import Image
10
  import os
 
11
  import gc
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -24,24 +26,89 @@ llm_models = {
24
  }
25
 
26
  CLIP_PATH = "google/siglip-so400m-patch14-384"
27
- VLM_PROMPT = "A descriptive caption for this image:\n"
28
  MODEL_PATH = list(llm_models.keys())[0]
29
- CHECKPOINT_PATH = Path("wpkklhc6")
30
- TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  class ImageAdapter(nn.Module):
33
- def __init__(self, input_features: int, output_features: int):
34
  super().__init__()
 
 
 
 
 
35
  self.linear1 = nn.Linear(input_features, output_features)
36
  self.activation = nn.GELU()
37
  self.linear2 = nn.Linear(output_features, output_features)
38
-
 
 
 
 
 
 
 
 
 
 
39
  def forward(self, vision_outputs: torch.Tensor):
40
- x = self.linear1(vision_outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  x = self.activation(x)
42
  x = self.linear2(x)
 
 
 
 
 
 
 
 
 
 
 
43
  return x
44
 
 
 
 
 
45
  # https://huggingface.co/docs/transformers/v4.44.2/gguf
46
  # https://github.com/city96/ComfyUI-GGUF/issues/7
47
  # https://github.com/THUDM/ChatGLM-6B/issues/18
@@ -50,14 +117,18 @@ class ImageAdapter(nn.Module):
50
  # https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
51
  # https://huggingface.co/google/flan-ul2/discussions/8
52
  # https://huggingface.co/blog/4bit-transformers-bitsandbytes
 
 
53
  tokenizer = None
54
  text_model_client = None
55
  text_model = None
56
  image_adapter = None
 
57
  def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
58
  global tokenizer
59
  global text_model
60
  global image_adapter
 
61
  global text_model_client #
62
  global use_inference_client #
63
  try:
@@ -77,8 +148,14 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_n
77
  if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
78
  elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
79
  else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
 
 
 
 
 
 
80
  print("Loading image adapter")
81
- image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size).eval().to("cpu")
82
  image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
83
  image_adapter.eval().to(device)
84
  except Exception as e:
@@ -93,57 +170,95 @@ load_text_model.zerogpu = True
93
  # Load CLIP
94
  print("Loading CLIP")
95
  clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
96
- clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model.eval().requires_grad_(False).to(device)
 
 
 
 
 
 
 
 
 
97
 
98
  # Tokenizer
99
  # LLM
100
  # Image Adapter
101
  load_text_model()
102
 
 
103
  @spaces.GPU()
104
  @torch.no_grad()
105
- def stream_chat(input_image: Image.Image):
106
  torch.cuda.empty_cache()
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Preprocess image
109
- image = clip_processor(images=input_image, return_tensors='pt').pixel_values
110
- image = image.to(device)
 
 
 
111
 
112
  # Tokenize the prompt
113
- prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
114
 
115
  # Embed image
116
- with torch.amp.autocast_mode.autocast(device, enabled=True):
117
- vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
118
- image_features = vision_outputs.hidden_states[-2]
119
  embedded_images = image_adapter(image_features)
120
- embedded_images = embedded_images.to(device)
121
 
122
  # Embed prompt
123
- prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
124
  assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
125
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
 
126
 
127
  # Construct prompts
128
  inputs_embeds = torch.cat([
129
  embedded_bos.expand(embedded_images.shape[0], -1, -1),
130
  embedded_images.to(dtype=embedded_bos.dtype),
131
  prompt_embeds.expand(embedded_images.shape[0], -1, -1),
 
132
  ], dim=1)
133
 
134
  input_ids = torch.cat([
135
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
136
  torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
137
  prompt,
138
- ], dim=1).to(device)
 
139
  attention_mask = torch.ones_like(input_ids)
140
 
141
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
142
- generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
 
143
 
144
  # Trim off the prompt
145
  generate_ids = generate_ids[:, input_ids.shape[1]:]
146
- if generate_ids[0][-1] == tokenizer.eos_token_id:
147
  generate_ids = generate_ids[:, :-1]
148
 
149
  caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
@@ -153,23 +268,47 @@ def stream_chat(input_image: Image.Image):
153
 
154
  @spaces.GPU()
155
  @torch.no_grad()
156
- def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: int=10, temperature: float=0.5, progress=gr.Progress(track_tqdm=True)):
157
  global use_inference_client
158
  global text_model
159
  torch.cuda.empty_cache()
160
  gc.collect()
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  # Preprocess image
163
- image = clip_processor(images=input_image, return_tensors='pt').pixel_values
164
- image = image.to(device)
 
 
 
165
 
166
  # Tokenize the prompt
167
- prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
168
 
169
  # Embed image
170
  with torch.amp.autocast_mode.autocast(device, enabled=True):
171
- vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
172
- image_features = vision_outputs.hidden_states[-2]
173
  embedded_images = image_adapter(image_features)
174
  embedded_images = embedded_images.to(device)
175
 
@@ -177,34 +316,34 @@ def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: in
177
  prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
178
  assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
179
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
 
180
 
181
  # Construct prompts
182
  inputs_embeds = torch.cat([
183
  embedded_bos.expand(embedded_images.shape[0], -1, -1),
184
  embedded_images.to(dtype=embedded_bos.dtype),
185
  prompt_embeds.expand(embedded_images.shape[0], -1, -1),
 
186
  ], dim=1)
187
 
188
  input_ids = torch.cat([
189
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
190
  torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
191
  prompt,
 
192
  ], dim=1).to(device)
193
  attention_mask = torch.ones_like(input_ids)
194
 
195
- # https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
196
- # https://github.com/huggingface/transformers/issues/6535
197
- # https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
198
- # https://github.com/ggerganov/llama.cpp/discussions/7712
199
- # https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
200
- # https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
201
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
202
- generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
203
- max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
 
 
204
 
205
  # Trim off the prompt
206
  generate_ids = generate_ids[:, input_ids.shape[1]:]
207
- if generate_ids[0][-1] == tokenizer.eos_token_id:
208
  generate_ids = generate_ids[:, :-1]
209
 
210
  caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
@@ -212,6 +351,14 @@ def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: in
212
  return caption.strip()
213
 
214
 
 
 
 
 
 
 
 
 
215
  def is_repo_name(s):
216
  import re
217
  return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
@@ -290,16 +437,39 @@ def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_f
290
  # original UI
291
  with gr.Blocks() as demo:
292
  gr.HTML(TITLE)
 
293
  with gr.Row():
294
  with gr.Column():
295
  input_image = gr.Image(type="pil", label="Input Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  run_button = gr.Button("Caption")
297
 
298
  with gr.Column():
299
  output_caption = gr.Textbox(label="Caption")
300
 
301
- run_button.click(fn=stream_chat, inputs=[input_image], outputs=[output_caption])
302
 
303
 
304
  if __name__ == "__main__":
305
- demo.launch()
 
8
  import torch.amp.autocast_mode
9
  from PIL import Image
10
  import os
11
+ import torchvision.transforms.functional as TVF
12
  import gc
13
+ from peft import PeftConfig
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
26
  }
27
 
28
  CLIP_PATH = "google/siglip-so400m-patch14-384"
 
29
  MODEL_PATH = list(llm_models.keys())[0]
30
+ CHECKPOINT_PATH = Path("9em124t2-499968")
31
+ LORA_PATH = CHECKPOINT_PATH / "text_model"
32
+ TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
33
+ CAPTION_TYPE_MAP = {
34
+ ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
35
+ ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
36
+ ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
37
+ ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
38
+ ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
39
+ ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
40
+
41
+ ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
42
+ ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
43
+ ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
44
+
45
+ ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
46
+ ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
47
+ ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
48
+ }
49
 
50
  class ImageAdapter(nn.Module):
51
+ def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
52
  super().__init__()
53
+ self.deep_extract = deep_extract
54
+
55
+ if self.deep_extract:
56
+ input_features = input_features * 5
57
+
58
  self.linear1 = nn.Linear(input_features, output_features)
59
  self.activation = nn.GELU()
60
  self.linear2 = nn.Linear(output_features, output_features)
61
+ self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
62
+ self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
63
+
64
+ # Mode token
65
+ #self.mode_token = nn.Embedding(n_modes, output_features)
66
+ #self.mode_token.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
67
+
68
+ # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
69
+ self.other_tokens = nn.Embedding(3, output_features)
70
+ self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
71
+
72
  def forward(self, vision_outputs: torch.Tensor):
73
+ if self.deep_extract:
74
+ x = torch.concat((
75
+ vision_outputs[-2],
76
+ vision_outputs[3],
77
+ vision_outputs[7],
78
+ vision_outputs[13],
79
+ vision_outputs[20],
80
+ ), dim=-1)
81
+ assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" # batch, tokens, features
82
+ assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
83
+ else:
84
+ x = vision_outputs[-2]
85
+
86
+ x = self.ln1(x)
87
+
88
+ if self.pos_emb is not None:
89
+ assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
90
+ x = x + self.pos_emb
91
+
92
+ x = self.linear1(x)
93
  x = self.activation(x)
94
  x = self.linear2(x)
95
+
96
+ # Mode token
97
+ #mode_token = self.mode_token(mode)
98
+ #assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
99
+ #x = torch.cat((x, mode_token), dim=1)
100
+
101
+ # <|image_start|>, IMAGE, <|image_end|>
102
+ other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
103
+ assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
104
+ x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
105
+
106
  return x
107
 
108
+ def get_eot_embedding(self):
109
+ return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
110
+
111
+
112
  # https://huggingface.co/docs/transformers/v4.44.2/gguf
113
  # https://github.com/city96/ComfyUI-GGUF/issues/7
114
  # https://github.com/THUDM/ChatGLM-6B/issues/18
 
117
  # https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
118
  # https://huggingface.co/google/flan-ul2/discussions/8
119
  # https://huggingface.co/blog/4bit-transformers-bitsandbytes
120
+ # https://huggingface.co/docs/transformers/main/en/peft
121
+ # https://huggingface.co/docs/transformers/main/en/peft#enable-and-disable-adapters
122
  tokenizer = None
123
  text_model_client = None
124
  text_model = None
125
  image_adapter = None
126
+ peft_config = None
127
  def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
128
  global tokenizer
129
  global text_model
130
  global image_adapter
131
+ global peft_config
132
  global text_model_client #
133
  global use_inference_client #
134
  try:
 
148
  if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
149
  elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
150
  else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
151
+ if LORA_PATH.exists():
152
+ print("Loading VLM's custom text model")
153
+ if is_nf4: peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device, quantization_config=nf4_config)
154
+ else: peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device)
155
+ text_model.add_adapter(peft_config)
156
+ text_model.enable_adapters()
157
  print("Loading image adapter")
158
+ image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
159
  image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
160
  image_adapter.eval().to(device)
161
  except Exception as e:
 
170
  # Load CLIP
171
  print("Loading CLIP")
172
  clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
173
+ clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
174
+
175
+ if (CHECKPOINT_PATH / "clip_model.pt").exists():
176
+ print("Loading VLM's custom vision model")
177
+ checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
178
+ checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
179
+ clip_model.load_state_dict(checkpoint)
180
+ del checkpoint
181
+
182
+ clip_model.eval().requires_grad_(False).to(device)
183
 
184
  # Tokenizer
185
  # LLM
186
  # Image Adapter
187
  load_text_model()
188
 
189
+
190
  @spaces.GPU()
191
  @torch.no_grad()
192
+ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
193
  torch.cuda.empty_cache()
194
 
195
+ # 'any' means no length specified
196
+ length = None if caption_length == "any" else caption_length
197
+
198
+ if isinstance(length, str):
199
+ try:
200
+ length = int(length)
201
+ except ValueError:
202
+ pass
203
+
204
+ # 'rng-tags' and 'training_prompt' don't have formal/informal tones
205
+ if caption_type == "rng-tags" or caption_type == "training_prompt":
206
+ caption_tone = "formal"
207
+
208
+ # Build prompt
209
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
210
+ if prompt_key not in CAPTION_TYPE_MAP:
211
+ raise ValueError(f"Invalid caption type: {prompt_key}")
212
+
213
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
214
+ print(f"Prompt: {prompt_str}")
215
+
216
  # Preprocess image
217
+ #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
218
+ image = input_image.resize((384, 384), Image.LANCZOS)
219
+ pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
220
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
221
+ pixel_values = pixel_values.to('cuda')
222
 
223
  # Tokenize the prompt
224
+ prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
225
 
226
  # Embed image
227
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
228
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
229
+ image_features = vision_outputs.hidden_states
230
  embedded_images = image_adapter(image_features)
231
+ embedded_images = embedded_images.to('cuda')
232
 
233
  # Embed prompt
234
+ prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
235
  assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
236
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
237
+ eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
238
 
239
  # Construct prompts
240
  inputs_embeds = torch.cat([
241
  embedded_bos.expand(embedded_images.shape[0], -1, -1),
242
  embedded_images.to(dtype=embedded_bos.dtype),
243
  prompt_embeds.expand(embedded_images.shape[0], -1, -1),
244
+ eot_embed.expand(embedded_images.shape[0], -1, -1),
245
  ], dim=1)
246
 
247
  input_ids = torch.cat([
248
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
249
  torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
250
  prompt,
251
+ torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
252
+ ], dim=1).to('cuda')
253
  attention_mask = torch.ones_like(input_ids)
254
 
255
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
256
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
257
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
258
 
259
  # Trim off the prompt
260
  generate_ids = generate_ids[:, input_ids.shape[1]:]
261
+ if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
262
  generate_ids = generate_ids[:, :-1]
263
 
264
  caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
 
268
 
269
  @spaces.GPU()
270
  @torch.no_grad()
271
+ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
272
  global use_inference_client
273
  global text_model
274
  torch.cuda.empty_cache()
275
  gc.collect()
276
 
277
+ # 'any' means no length specified
278
+ length = None if caption_length == "any" else caption_length
279
+
280
+ if isinstance(length, str):
281
+ try:
282
+ length = int(length)
283
+ except ValueError:
284
+ pass
285
+
286
+ # 'rng-tags' and 'training_prompt' don't have formal/informal tones
287
+ if caption_type == "rng-tags" or caption_type == "training_prompt":
288
+ caption_tone = "formal"
289
+
290
+ # Build prompt
291
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
292
+ if prompt_key not in CAPTION_TYPE_MAP:
293
+ raise ValueError(f"Invalid caption type: {prompt_key}")
294
+
295
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
296
+ print(f"Prompt: {prompt_str}")
297
+
298
  # Preprocess image
299
+ #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
300
+ image = input_image.resize((384, 384), Image.LANCZOS)
301
+ pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
302
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
303
+ pixel_values = pixel_values.to(device)
304
 
305
  # Tokenize the prompt
306
+ prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
307
 
308
  # Embed image
309
  with torch.amp.autocast_mode.autocast(device, enabled=True):
310
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
311
+ image_features = vision_outputs.hidden_states
312
  embedded_images = image_adapter(image_features)
313
  embedded_images = embedded_images.to(device)
314
 
 
316
  prompt_embeds = text_model.model.embed_tokens(prompt.to(device))
317
  assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
318
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
319
+ eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
320
 
321
  # Construct prompts
322
  inputs_embeds = torch.cat([
323
  embedded_bos.expand(embedded_images.shape[0], -1, -1),
324
  embedded_images.to(dtype=embedded_bos.dtype),
325
  prompt_embeds.expand(embedded_images.shape[0], -1, -1),
326
+ eot_embed.expand(embedded_images.shape[0], -1, -1),
327
  ], dim=1)
328
 
329
  input_ids = torch.cat([
330
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
331
  torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
332
  prompt,
333
+ torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
334
  ], dim=1).to(device)
335
  attention_mask = torch.ones_like(input_ids)
336
 
337
+ text_model.to(device)
 
 
 
 
 
338
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
339
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
340
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
341
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
342
+ do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
343
 
344
  # Trim off the prompt
345
  generate_ids = generate_ids[:, input_ids.shape[1]:]
346
+ if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
347
  generate_ids = generate_ids[:, :-1]
348
 
349
  caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
 
351
  return caption.strip()
352
 
353
 
354
+ # https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
355
+ # https://github.com/huggingface/transformers/issues/6535
356
+ # https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
357
+ # https://github.com/ggerganov/llama.cpp/discussions/7712
358
+ # https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
359
+ # https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
360
+
361
+
362
  def is_repo_name(s):
363
  import re
364
  return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
 
437
  # original UI
438
  with gr.Blocks() as demo:
439
  gr.HTML(TITLE)
440
+
441
  with gr.Row():
442
  with gr.Column():
443
  input_image = gr.Image(type="pil", label="Input Image")
444
+
445
+ caption_type = gr.Dropdown(
446
+ choices=["descriptive", "training_prompt", "rng-tags"],
447
+ label="Caption Type",
448
+ value="descriptive",
449
+ )
450
+
451
+ caption_tone = gr.Dropdown(
452
+ choices=["formal", "informal"],
453
+ label="Caption Tone",
454
+ value="formal",
455
+ )
456
+
457
+ caption_length = gr.Dropdown(
458
+ choices=["any", "very short", "short", "medium-length", "long", "very long"] +
459
+ [str(i) for i in range(20, 261, 10)],
460
+ label="Caption Length",
461
+ value="any",
462
+ )
463
+
464
+ gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
465
+
466
  run_button = gr.Button("Caption")
467
 
468
  with gr.Column():
469
  output_caption = gr.Textbox(label="Caption")
470
 
471
+ run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])
472
 
473
 
474
  if __name__ == "__main__":
475
+ demo.launch()
requirements.txt CHANGED
@@ -7,4 +7,6 @@ bitsandbytes
7
  Pillow
8
  protobuf
9
  gguf
10
- numpy<2.0.0
 
 
 
7
  Pillow
8
  protobuf
9
  gguf
10
+ numpy<2.0.0
11
+ peft
12
+ torchvision