dammy committed on
Commit
822a50d
·
1 Parent(s): a4185e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -10
app.py CHANGED
@@ -10,11 +10,26 @@ import uuid
10
  from sentence_transformers import SentenceTransformer
11
  import os
12
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- model_name = 'google/flan-t5-base'
15
- model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- print('flan read')
 
 
 
 
 
18
 
19
 
20
  ST_name = 'sentence-transformers/sentence-t5-base'
@@ -34,17 +49,37 @@ def get_context(query_text):
34
  return context
35
 
36
  def local_query(query, context):
37
- t5query = """Using the available context, please answer the question.
 
 
 
 
 
 
 
 
 
 
 
 
38
  If you aren't sure please say i don't know.
39
  Context: {}
40
  Question: {}
41
  """.format(context, query)
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- inputs = tokenizer(t5query, return_tensors="pt")
44
-
45
- outputs = model.generate(**inputs, max_new_tokens=20)
46
-
47
- return tokenizer.batch_decode(outputs, skip_special_tokens=True)
48
 
49
 
50
 
 
10
  from sentence_transformers import SentenceTransformer
11
  import os
12
 
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ import transformers
15
+ import torch
16
+
17
# Hugging Face Hub id of the causal language model used to answer queries.
# (This replaces the earlier 'google/flan-t5-base' seq2seq setup.)
model_name = "tiiuae/falcon-40b-instruct"

# Load the tokenizer from the Hub id. The previous code passed the name
# `model`, which is no longer assigned once the T5 loading lines are
# commented out, so it must be `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Module-level text-generation pipeline used by local_query().
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,  # Hub id; the pipeline loads the weights itself
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # permits custom model code from the Hub repo to run
    device_map="auto",
)
33
 
34
 
35
  ST_name = 'sentence-transformers/sentence-t5-base'
 
49
  return context
50
 
51
  def local_query(query, context):
52
+ # t5query = """Using the available context, please answer the question.
53
+ # If you aren't sure please say i don't know.
54
+ # Context: {}
55
+ # Question: {}
56
+ # """.format(context, query)
57
+
58
+ # inputs = tokenizer(t5query, return_tensors="pt")
59
+
60
+ # outputs = model.generate(**inputs, max_new_tokens=20)
61
+
62
+ # return tokenizer.batch_decode(outputs, skip_special_tokens=True)
63
+
64
+ context_query = """Using the available context, please answer the question.
65
  If you aren't sure please say i don't know.
66
  Context: {}
67
  Question: {}
68
  """.format(context, query)
69
+
70
+ sequences = pipeline(
71
+ context_query,
72
+ max_length=200,
73
+ do_sample=True,
74
+ top_k=10,
75
+ num_return_sequences=1,
76
+ eos_token_id=tokenizer.eos_token_id,
77
+ )
78
+
79
+ # for seq in sequences:
80
+ # print(f"Result: {seq['generated_text']}")
81
 
82
+ return seq['generated_text']
 
 
 
 
83
 
84
 
85