Haseeb javed committed on
Commit 3b1d815 · 1 Parent(s): f13d1c4
Files changed (1)
app.py +0 -74
app.py CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 from flask import Flask, request, jsonify
 from flask_cors import CORS
 import torch
@@ -70,76 +69,3 @@ if __name__ == "__main__":
     port = int(os.getenv("PORT", 5000))
     logging.info(f"Starting Flask app on port {port}")
     app.run(debug=True, host="0.0.0.0", port=port)
-=======
-from flask import Flask, request, jsonify
-from flask_cors import CORS
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import logging
-import os
-
-# Logging setup
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# Hugging Face Model Hub Repository
-MODEL_REPO = "./"  # Replace with your Hugging Face model repo name
-
-# Load tokenizer and model from Hugging Face Model Hub
-try:
-    logging.info("Loading model and tokenizer from Hugging Face Model Hub...")
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
-    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=dtype).to(
-        "cuda" if torch.cuda.is_available() else "cpu"
-    )
-    logging.info("Model loaded successfully.")
-except Exception as e:
-    logging.error("Failed to load the model or tokenizer.", exc_info=True)
-    raise e
-
-# Flask app initialization
-app = Flask(__name__)
-CORS(app)  # Enable CORS
-
-def generate_response(prompt):
-    """Generate a response from the model given a prompt."""
-    try:
-        logging.debug(f"Generating response for prompt: {prompt}")
-        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
-        input_length = inputs.input_ids.shape[1]
-        outputs = model.generate(
-            **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
-        )
-        token = outputs.sequences[0, input_length:]
-        output_str = tokenizer.decode(token, skip_special_tokens=True)
-        logging.debug(f"Generated response: {output_str}")
-        return output_str
-    except Exception as e:
-        logging.error("Error during response generation", exc_info=True)
-        return "Sorry, I encountered an error while generating the response."
-
-@app.route('/chat', methods=['POST'])
-def chat():
-    """Endpoint to handle chat requests."""
-    try:
-        logging.debug("Received a POST request to /chat")
-        data = request.json
-        logging.debug(f"Request data: {data}")
-
-        if not data or "message" not in data:
-            return jsonify({"error": "Invalid request. 'message' field is required."}), 400
-
-        user_input = data.get("message", "")
-        prompt = f"<human>: {user_input}\n<bot>:"
-        response = generate_response(prompt)
-        return jsonify({"response": response}), 200
-    except Exception as e:
-        logging.error("Error in /chat endpoint", exc_info=True)
-        return jsonify({"error": "Internal server error", "message": str(e)}), 500
-
-if __name__ == "__main__":
-    # Get the port from environment variable or default to 5000
-    port = int(os.getenv("PORT", 5000))
-    logging.info(f"Starting Flask app on port {port}")
-    app.run(debug=True, host="0.0.0.0", port=port)
->>>>>>> 6ff0da104f5a2eb5ee298dc0164db0c0b16215e2
 
 
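For reference, a minimal client sketch for the /chat endpoint the resolved app.py exposes. It assumes the server is running locally on the default port 5000; the requests dependency, the URL, and the example message are illustrative, not part of the commit.

# Minimal client sketch (assumes app.py is serving on localhost:5000).
import requests  # third-party: pip install requests

resp = requests.post(
    "http://localhost:5000/chat",
    json={"message": "Hello, who are you?"},  # 'message' is the required field
    timeout=120,  # generation can be slow on CPU
)
resp.raise_for_status()
# The server wraps user input as "<human>: ...\n<bot>:" and returns the
# generated continuation under the "response" key.
print(resp.json()["response"])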