from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os
MIN_TRANSFORMERS_VERSION = '4.25.1'

# Check the installed transformers version. Comparing raw version strings
# lexicographically is unreliable (e.g. "4.9" sorts after "4.25.1"), so
# parse them with packaging, which ships as a transformers dependency.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse(MIN_TRANSFORMERS_VERSION), (
    f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
)
# Initialize the tokenizer and model from the local directory; from_pretrained
# expects the usual Hugging Face files (config.json, tokenizer files, and the
# model weights) to sit alongside this script.
model_dir = "./"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
app = Flask(__name__)
CORS(app) # Enable CORS
logging.basicConfig(level=logging.DEBUG)
def generate_response(prompt):
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    input_length = inputs.input_ids.shape[1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        top_p=0.7,
        top_k=50,
        return_dict_in_generate=True,
    )
    # Decode only the newly generated tokens, dropping the prompt portion.
    tokens = outputs.sequences[0, input_length:]
    return tokenizer.decode(tokens, skip_special_tokens=True)
@app.route('/chat', methods=['POST'])
def chat():
    logging.debug("Received a POST request")
    # Tolerate a missing or non-JSON body instead of raising an error.
    data = request.get_json(silent=True) or {}
    logging.debug(f"Request data: {data}")
    user_input = data.get("message", "")
    # Wrap the user message in the <human>/<bot> chat format the model expects.
    prompt = f"<human>: {user_input}\n<bot>:"
    response = generate_response(prompt)
    logging.debug(f"Generated response: {response}")
    return jsonify({"response": response})
if __name__ == "__main__":
    # Get the port from the environment variable or default to 5000
    port = int(os.getenv("PORT", 5000))
    logging.info(f"Starting Flask app on port {port}")
    app.run(debug=True, host="0.0.0.0", port=port)
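
# A minimal sketch of how a client might call the /chat endpoint, assuming
# the server is running locally on the default port 5000 and the `requests`
# package is installed (both are assumptions, not requirements of this file):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:5000/chat",
#         json={"message": "Hello, who are you?"},
#     )
#     print(resp.json()["response"])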