PyTorch
English
gpt_neox
Haseeb javed commited on
Commit
9dc520a
·
1 Parent(s): daa6ec2
Files changed (3) hide show
  1. app.py +0 -49
  2. chat-bot +1 -0
  3. requirements.txt +0 -4
app.py DELETED
@@ -1,49 +0,0 @@
1
- from flask import Flask, request, jsonify
2
- from flask_cors import CORS
3
- import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- import logging
6
- import os
7
-
8
- MIN_TRANSFORMERS_VERSION = '4.25.1'
9
-
10
- # Check transformers version
11
- import transformers
12
- assert transformers.__version__ >= MIN_TRANSFORMERS_VERSION, f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
13
-
14
- # Initialize tokenizer and model from local directory
15
- model_dir = "./"
16
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
17
- model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
18
-
19
- app = Flask(__name__)
20
- CORS(app) # Enable CORS
21
-
22
- logging.basicConfig(level=logging.DEBUG)
23
-
24
- def generate_response(prompt):
25
- inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
26
- input_length = inputs.input_ids.shape[1]
27
- outputs = model.generate(
28
- **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
29
- )
30
- token = outputs.sequences[0, input_length:]
31
- output_str = tokenizer.decode(token, skip_special_tokens=True)
32
- return output_str
33
-
34
- @app.route('/chat', methods=['POST'])
35
- def chat():
36
- logging.debug("Received a POST request")
37
- data = request.json
38
- logging.debug(f"Request data: {data}")
39
- user_input = data.get("message", "")
40
- prompt = f"<human>: {user_input}\n<bot>:"
41
- response = generate_response(prompt)
42
- logging.debug(f"Generated response: {response}")
43
- return jsonify({"response": response})
44
-
45
- if __name__ == "__main__":
46
- # Get the port from environment variable or default to 5000
47
- port = int(os.getenv("PORT", 5000))
48
- logging.info(f"Starting Flask app on port {port}")
49
- app.run(debug=True, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
chat-bot ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 1f613fd4bb9a83fc7f0f03e34db07691642b77f0
requirements.txt DELETED
@@ -1,4 +0,0 @@
1
- flask
2
- flask-cors
3
- torch
4
- transformers