hysts HF staff commited on
Commit
8c029ff
·
1 Parent(s): abc3afc
Files changed (6) hide show
  1. .pre-commit-config.yaml +60 -0
  2. .vscode/settings.json +30 -0
  3. README.md +5 -4
  4. app.py +145 -0
  5. requirements.txt +6 -0
  6. style.css +11 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true,
26
+ "notebook.formatOnSave.enabled": true,
27
+ "notebook.codeActionsOnSave": {
28
+ "source.organizeImports": "explicit"
29
+ }
30
+ }
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Gemma 2 2b It
3
- emoji: 📉
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Gemma 2 2B IT
3
+ emoji: 😻
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Chatbot
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ BitsAndBytesConfig,
11
+ GemmaTokenizerFast,
12
+ TextIteratorStreamer,
13
+ )
14
+
15
+ DESCRIPTION = """\
16
+ # Gemma 2 2B IT
17
+
18
+ Gemma 2 is Google's latest iteration of open LLMs.
19
+ This is a demo of [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it), fine-tuned for instruction following.
20
+ For more details, please check [our post](https://huggingface.co/blog/gemma2).
21
+
22
+ 👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it) and the 9B version in [this Space](https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it).
23
+ """
24
+
25
+ MAX_MAX_NEW_TOKENS = 2048
26
+ DEFAULT_MAX_NEW_TOKENS = 1024
27
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
28
+
29
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
+
31
+ model_id = "gg-hf/gemma-2-2b-it"
32
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ model_id,
35
+ device_map="auto",
36
+ torch_dtype=torch.bfloat16,
37
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True),
38
+ )
39
+ model.config.sliding_window = 4096
40
+ model.eval()
41
+
42
+
43
+ @spaces.GPU(duration=90)
44
+ def generate(
45
+ message: str,
46
+ chat_history: list[tuple[str, str]],
47
+ max_new_tokens: int = 1024,
48
+ temperature: float = 0.6,
49
+ top_p: float = 0.9,
50
+ top_k: int = 50,
51
+ repetition_penalty: float = 1.2,
52
+ ) -> Iterator[str]:
53
+ conversation = []
54
+ for user, assistant in chat_history:
55
+ conversation.extend(
56
+ [
57
+ {"role": "user", "content": user},
58
+ {"role": "assistant", "content": assistant},
59
+ ]
60
+ )
61
+ conversation.append({"role": "user", "content": message})
62
+
63
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
64
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
65
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
66
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
67
+ input_ids = input_ids.to(model.device)
68
+
69
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
70
+ generate_kwargs = dict(
71
+ {"input_ids": input_ids},
72
+ streamer=streamer,
73
+ max_new_tokens=max_new_tokens,
74
+ do_sample=True,
75
+ top_p=top_p,
76
+ top_k=top_k,
77
+ temperature=temperature,
78
+ num_beams=1,
79
+ repetition_penalty=repetition_penalty,
80
+ )
81
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
82
+ t.start()
83
+
84
+ outputs = []
85
+ for text in streamer:
86
+ outputs.append(text)
87
+ yield "".join(outputs)
88
+
89
+
90
+ chat_interface = gr.ChatInterface(
91
+ fn=generate,
92
+ additional_inputs=[
93
+ gr.Slider(
94
+ label="Max new tokens",
95
+ minimum=1,
96
+ maximum=MAX_MAX_NEW_TOKENS,
97
+ step=1,
98
+ value=DEFAULT_MAX_NEW_TOKENS,
99
+ ),
100
+ gr.Slider(
101
+ label="Temperature",
102
+ minimum=0.1,
103
+ maximum=4.0,
104
+ step=0.1,
105
+ value=0.6,
106
+ ),
107
+ gr.Slider(
108
+ label="Top-p (nucleus sampling)",
109
+ minimum=0.05,
110
+ maximum=1.0,
111
+ step=0.05,
112
+ value=0.9,
113
+ ),
114
+ gr.Slider(
115
+ label="Top-k",
116
+ minimum=1,
117
+ maximum=1000,
118
+ step=1,
119
+ value=50,
120
+ ),
121
+ gr.Slider(
122
+ label="Repetition penalty",
123
+ minimum=1.0,
124
+ maximum=2.0,
125
+ step=0.05,
126
+ value=1.2,
127
+ ),
128
+ ],
129
+ stop_btn=None,
130
+ examples=[
131
+ ["Hello there! How are you doing?"],
132
+ ["Can you explain briefly to me what is the Python programming language?"],
133
+ ["Explain the plot of Cinderella in a sentence."],
134
+ ["How many hours does it take a man to eat a Helicopter?"],
135
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
136
+ ],
137
+ )
138
+
139
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
140
+ gr.Markdown(DESCRIPTION)
141
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
142
+ chat_interface.render()
143
+
144
+ if __name__ == "__main__":
145
+ demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate==0.33.0
2
+ bitsandbytes==0.43.2
3
+ gradio==4.39.0
4
+ spaces==0.29.2
5
+ torch==2.2.0
6
+ transformers==4.43.3
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }