from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
import os

# Logging setup
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face Model Hub repository (or a local path containing the model files)
MODEL_REPO = "./"  # Replace with your Hugging Face model repo name

# Load tokenizer and model from the Hugging Face Model Hub
try:
    logging.info("Loading model and tokenizer from Hugging Face Model Hub...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    # Use bfloat16 on GPU to cut memory use; fall back to float32 on CPU
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=dtype).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    logging.info("Model loaded successfully.")
except Exception:
    logging.error("Failed to load the model or tokenizer.", exc_info=True)
    raise

# Flask app initialization
app = Flask(__name__)
CORS(app)  # Enable CORS


def generate_response(prompt):
    """Generate a response from the model given a prompt."""
    try:
        logging.debug(f"Generating response for prompt: {prompt}")
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        input_length = inputs.input_ids.shape[1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.7,
            top_k=50,
            return_dict_in_generate=True,
        )
        # Decode only the newly generated tokens, skipping the echoed prompt
        tokens = outputs.sequences[0, input_length:]
        output_str = tokenizer.decode(tokens, skip_special_tokens=True)
        logging.debug(f"Generated response: {output_str}")
        return output_str
    except Exception:
        logging.error("Error during response generation", exc_info=True)
        return "Sorry, I encountered an error while generating the response."


@app.route('/chat', methods=['POST'])
def chat():
    """Endpoint to handle chat requests."""
    try:
        logging.debug("Received a POST request to /chat")
        # silent=True returns None instead of raising on malformed JSON,
        # so bad payloads fall through to the 400 response below
        data = request.get_json(silent=True)
        logging.debug(f"Request data: {data}")
        if not data or "message" not in data:
            return jsonify({"error": "Invalid request. 'message' field is required."}), 400
        user_input = data.get("message", "")
        # Chat-style prompt template. The role tags were garbled in the
        # original source; "<human>:" / "<bot>:" is assumed here -- adjust
        # to match the template your model was fine-tuned on.
        prompt = f"<human>: {user_input}\n<bot>:"
        response = generate_response(prompt)
        return jsonify({"response": response}), 200
    except Exception as e:
        logging.error("Error in /chat endpoint", exc_info=True)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500


if __name__ == "__main__":
    # Read the port from an environment variable, defaulting to 5000
    port = int(os.getenv("PORT", 5000))
    logging.info(f"Starting Flask app on port {port}")
    # debug=True is convenient during development; disable it in production
    app.run(debug=True, host="0.0.0.0", port=port)
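
# Example usage (a sketch, assuming the server is running locally on the
# default port 5000; the message text is illustrative):
#
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello, how are you?"}'
#
# A successful call returns a JSON body of the form {"response": "..."};
# a payload missing the "message" field returns a 400 with an error message.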