from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from packaging import version  # for numeric version comparison
import logging
import os

MIN_TRANSFORMERS_VERSION = '4.25.1'

# Compare versions numerically; a plain string comparison misorders
# multi-digit components (e.g. '4.9.0' sorts after '4.25.1' lexically,
# so an older release would wrongly pass the check).
assert version.parse(transformers.__version__) >= version.parse(MIN_TRANSFORMERS_VERSION), \
    f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'

# Load the tokenizer and model weights from the current directory.
model_dir = "./"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
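
# Optional extension (a sketch, not in the original script): bfloat16
# inference on CPU is slow, so move the model to a GPU when one is
# available. Assumes a CUDA device with bfloat16 support.
if torch.cuda.is_available():
    model = model.to("cuda")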

app = Flask(__name__)
# CORS(app) allows all origins by default; flask_cors accepts an
# `origins` argument to restrict this in production.
CORS(app)

logging.basicConfig(level=logging.DEBUG)


def generate_response(prompt):
    # Tokenize the prompt and move the input tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    input_length = inputs.input_ids.shape[1]
    outputs = model.generate(
        **inputs, max_new_tokens=128, do_sample=True, temperature=0.7,
        top_p=0.7, top_k=50, return_dict_in_generate=True,
    )
    # Slice off the prompt so only the newly generated tokens are decoded.
    tokens = outputs.sequences[0, input_length:]
    output_str = tokenizer.decode(tokens, skip_special_tokens=True)
    return output_str
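
# Illustrative usage (the reply is hypothetical; sampling makes outputs
# non-deterministic):
#   generate_response("<human>: Hello!\n<bot>:")
#   -> " Hello! How can I help you today?"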


@app.route('/chat', methods=['POST'])
def chat():
    logging.debug("Received a POST request")
    # get_json(silent=True) returns None for a missing or malformed JSON
    # body instead of raising, so bad requests don't become 500 errors.
    data = request.get_json(silent=True) or {}
    logging.debug(f"Request data: {data}")
    user_input = data.get("message", "")
    # Wrap the message in the <human>/<bot> turn format the model expects.
    prompt = f"<human>: {user_input}\n<bot>:"
    response = generate_response(prompt)
    logging.debug(f"Generated response: {response}")
    return jsonify({"response": response})
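
# A possible hardening step (sketch only; the 400 status and error shape
# are assumptions, not part of the original API): reject empty messages
# inside chat() before invoking the model:
#
#   if not user_input:
#       return jsonify({"error": "message is required"}), 400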


if __name__ == "__main__":
    # The listening port can be overridden via the PORT environment variable.
    port = int(os.getenv("PORT", 5000))
    logging.info(f"Starting Flask app on port {port}")
    # debug=True enables Flask's reloader and interactive debugger; avoid
    # combining it with host="0.0.0.0" outside local development.
    app.run(debug=True, host="0.0.0.0", port=port)
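
# Example request once the server is running (assumes the default port
# 5000 and a hypothetical payload):
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'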