loubnabnl HF staff commited on
Commit
963bc16
·
verified ·
1 Parent(s): ba519e2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +66 -1
README.md CHANGED
@@ -192,8 +192,73 @@ def parse_response(text: str) -> str | dict[str, any]:
192
  if matches:
193
  return json.loads(matches[0])
194
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  ```
196
- More details can be found [here](https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B-Instruct/blob/main/instructions_function_calling.md)
197
 
198
  ## Limitations
199
 
 
192
  if matches:
193
  return json.loads(matches[0])
194
  return text
195
+
196
+
197
+ model_name_smollm = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
198
+ model = AutoModelForCausalLM.from_pretrained(model_name_smollm, device_map="auto", torch_dtype="auto", trust_remote_code=True)
199
+ tokenizer = AutoTokenizer.from_pretrained(model_name_smollm)
200
+
201
+ from datetime import datetime
202
+ import random
203
+
204
+ def get_current_time() -> str:
205
+ """Returns the current time in 24-hour format.
206
+
207
+ Returns:
208
+ str: Current time in HH:MM:SS format.
209
+ """
210
+ return datetime.now().strftime("%H:%M:%S")
211
+
212
+
213
+ def get_random_number_between(min: int, max: int) -> int:
214
+ """
215
+ Gets a random number between min and max.
216
+
217
+ Args:
218
+ min: The minimum number.
219
+ max: The maximum number.
220
+
221
+ Returns:
222
+ A random number between min and max.
223
+ """
224
+ return random.randint(min, max)
225
+
226
+
227
+ tools = [get_json_schema(get_random_number_between), get_json_schema(get_current_time)]
228
+
229
+ toolbox = {"get_random_number_between": get_random_number_between, "get_current_time": get_current_time}
230
+
231
+ query = "Give me a number between 1 and 300"
232
+
233
+ messages = prepare_messages(query, tools=tools)
234
+
235
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
236
+ outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
237
+ result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
238
+
239
+ tool_calls = parse_response(result)
240
+ # [{'name': 'get_random_number_between', 'arguments': {'min': 1, 'max': 300}}
241
+
242
+ # Get tool responses
243
+ tool_responses = [toolbox.get(tc["name"])(*tc["arguments"].values()) for tc in tool_calls]
244
+ # [63]
245
+
246
+ # For the second turn, rebuild the history of messages:
247
+ history = messages.copy()
248
+ # Add the "parsed response"
249
+ history.append({"role": "assistant", "content": result})
250
+ query = "Can you give me the hour?"
251
+ history.append({"role": "user", "content": query})
252
+
253
+ inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, return_tensors="pt").to(model.device)
254
+ outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
255
+ result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
256
+
257
+ tool_calls = parse_response(result)
258
+ tool_responses = [toolbox.get(tc["name"])(*tc["arguments"].values()) for tc in tool_calls]
259
+ # ['07:57:25']
260
  ```
261
+ More details such as parallel function calls and tools not available can be found [here](https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B-Instruct/blob/main/instructions_function_calling.md)
262
 
263
  ## Limitations
264