Haseeb javed committed
Commit e29d967 · 1 Parent(s): b0e34a1
finalized

Files changed:
- app.py +26 -48
- config.json +1 -1
- generation_config.json +1 -1
- special_tokens_map.json +1 -1
- tokenizer_config.json +0 -12
app.py
CHANGED
@@ -5,64 +5,42 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import os
 
-
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
-
-#
-
-
-#
-
-
-
-dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=dtype).to(
-    "cuda" if torch.cuda.is_available() else "cpu"
-)
-logging.info("Model loaded successfully.")
-except Exception as e:
-    logging.error("Failed to load the model or tokenizer.", exc_info=True)
-    raise e
-
-# Flask app initialization
+MIN_TRANSFORMERS_VERSION = '4.25.1'
+
+# Check transformers version
+import transformers
+assert transformers.__version__ >= MIN_TRANSFORMERS_VERSION, f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
+
+# Initialize tokenizer and model from local directory
+model_dir = "./"
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
+
 app = Flask(__name__)
 CORS(app)  # Enable CORS
 
+logging.basicConfig(level=logging.DEBUG)
+
 def generate_response(prompt):
-
-
-
-    inputs =
-
-
-
-
-    token = outputs.sequences[0, input_length:]
-    output_str = tokenizer.decode(token, skip_special_tokens=True)
-    logging.debug(f"Generated response: {output_str}")
-    return output_str
-    except Exception as e:
-        logging.error("Error during response generation", exc_info=True)
-        return "Sorry, I encountered an error while generating the response."
+    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
+    input_length = inputs.input_ids.shape[1]
+    outputs = model.generate(
+        **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
+    )
+    token = outputs.sequences[0, input_length:]
+    output_str = tokenizer.decode(token, skip_special_tokens=True)
+    return output_str
 
 @app.route('/chat', methods=['POST'])
 def chat():
-    "
-
-
-
-
-
-
-
-    user_input = data.get("message", "")
-    prompt = f"<human>: {user_input}\n<bot>:"
-    response = generate_response(prompt)
-    return jsonify({"response": response}), 200
-    except Exception as e:
-        logging.error("Error in /chat endpoint", exc_info=True)
-        return jsonify({"error": "Internal server error", "message": str(e)}), 500
+    logging.debug("Received a POST request")
+    data = request.json
+    logging.debug(f"Request data: {data}")
+    user_input = data.get("message", "")
+    prompt = f"<human>: {user_input}\n<bot>:"
+    response = generate_response(prompt)
+    logging.debug(f"Generated response: {response}")
+    return jsonify({"response": response})
 
 if __name__ == "__main__":
     # Get the port from environment variable or default to 5000
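For a quick sanity check of the finalized /chat endpoint, a minimal client sketch follows. It assumes the app.py above is running locally on the default port 5000 mentioned in the comment; the host, port, and example message are illustrative placeholders, not part of the commit.

# Minimal sketch (assumption: app.py is running locally on port 5000).
# Sends the JSON shape chat() expects and prints the model's reply.
import requests

resp = requests.post(
    "http://localhost:5000/chat",             # assumed local dev URL
    json={"message": "Hello, who are you?"},  # read by chat() via data.get("message", "")
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])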
config.json
CHANGED
@@ -22,4 +22,4 @@
   "use_cache": true,
   "use_parallel_residual": false,
   "vocab_size": 50432
-}
+}
generation_config.json
CHANGED
@@ -3,4 +3,4 @@
   "bos_token_id": 0,
   "eos_token_id": 0,
   "transformers_version": "4.28.1"
-}
+}
special_tokens_map.json
CHANGED
@@ -2,4 +2,4 @@
   "bos_token": "<|endoftext|>",
   "eos_token": "<|endoftext|>",
   "unk_token": "<|endoftext|>"
-}
+}
tokenizer_config.json
CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 {
   "add_prefix_space": false,
   "bos_token": "<|endoftext|>",
@@ -8,14 +7,3 @@
   "tokenizer_class": "GPTNeoXTokenizer",
   "unk_token": "<|endoftext|>"
 }
-=======
-{
-  "add_prefix_space": false,
-  "bos_token": "<|endoftext|>",
-  "clean_up_tokenization_spaces": true,
-  "eos_token": "<|endoftext|>",
-  "model_max_length": 2048,
-  "tokenizer_class": "GPTNeoXTokenizer",
-  "unk_token": "<|endoftext|>"
-}
->>>>>>> 6ff0da104f5a2eb5ee298dc0164db0c0b16215e2