PyTorch · English · gpt_neox

Haseeb javed committed
Commit e29d967 · 1 Parent(s): b0e34a1
app.py CHANGED
@@ -5,64 +5,42 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import os
 
-# Logging setup
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# Hugging Face Model Hub Repository
-MODEL_REPO = "./" # Replace with your Hugging Face model repo name
-
-# Load tokenizer and model from Hugging Face Model Hub
-try:
-    logging.info("Loading model and tokenizer from Hugging Face Model Hub...")
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
-    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=dtype).to(
-        "cuda" if torch.cuda.is_available() else "cpu"
-    )
-    logging.info("Model loaded successfully.")
-except Exception as e:
-    logging.error("Failed to load the model or tokenizer.", exc_info=True)
-    raise e
-
-# Flask app initialization
+MIN_TRANSFORMERS_VERSION = '4.25.1'
+
+# Check transformers version
+import transformers
+assert transformers.__version__ >= MIN_TRANSFORMERS_VERSION, f'Please upgrade transformers to version {MIN_TRANSFORMERS_VERSION} or higher.'
+
+# Initialize tokenizer and model from local directory
+model_dir = "./"
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
+
 app = Flask(__name__)
 CORS(app) # Enable CORS
 
+logging.basicConfig(level=logging.DEBUG)
+
 def generate_response(prompt):
-    """Generate a response from the model given a prompt."""
-    try:
-        logging.debug(f"Generating response for prompt: {prompt}")
-        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
-        input_length = inputs.input_ids.shape[1]
-        outputs = model.generate(
-            **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
-        )
-        token = outputs.sequences[0, input_length:]
-        output_str = tokenizer.decode(token, skip_special_tokens=True)
-        logging.debug(f"Generated response: {output_str}")
-        return output_str
-    except Exception as e:
-        logging.error("Error during response generation", exc_info=True)
-        return "Sorry, I encountered an error while generating the response."
+    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
+    input_length = inputs.input_ids.shape[1]
+    outputs = model.generate(
+        **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.7, top_k=50, return_dict_in_generate=True
+    )
+    token = outputs.sequences[0, input_length:]
+    output_str = tokenizer.decode(token, skip_special_tokens=True)
+    return output_str
 
 @app.route('/chat', methods=['POST'])
 def chat():
-    """Endpoint to handle chat requests."""
-    try:
-        logging.debug("Received a POST request to /chat")
-        data = request.json
-        logging.debug(f"Request data: {data}")
-
-        if not data or "message" not in data:
-            return jsonify({"error": "Invalid request. 'message' field is required."}), 400
-
-        user_input = data.get("message", "")
-        prompt = f"<human>: {user_input}\n<bot>:"
-        response = generate_response(prompt)
-        return jsonify({"response": response}), 200
-    except Exception as e:
-        logging.error("Error in /chat endpoint", exc_info=True)
-        return jsonify({"error": "Internal server error", "message": str(e)}), 500
+    logging.debug("Received a POST request")
+    data = request.json
+    logging.debug(f"Request data: {data}")
+    user_input = data.get("message", "")
+    prompt = f"<human>: {user_input}\n<bot>:"
+    response = generate_response(prompt)
+    logging.debug(f"Generated response: {response}")
+    return jsonify({"response": response})
 
 if __name__ == "__main__":
     # Get the port from environment variable or default to 5000
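For reference, a minimal client-side sketch of how the new /chat endpoint can be exercised once app.py is running. It assumes the server is reachable on localhost with the default port 5000 and that the `requests` package is installed; adjust the URL if the PORT environment variable is set.

```python
# Hypothetical client for the /chat endpoint above (host and port are assumptions).
import requests

resp = requests.post(
    "http://localhost:5000/chat",             # default port per app.py's __main__ block
    json={"message": "Hello, who are you?"},  # chat() reads the 'message' field
    timeout=120,                              # generating 128 new tokens can be slow on CPU
)
resp.raise_for_status()
print(resp.json()["response"])                # the text produced by generate_response()
```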
config.json CHANGED
@@ -22,4 +22,4 @@
   "use_cache": true,
   "use_parallel_residual": false,
   "vocab_size": 50432
-}
+}
generation_config.json CHANGED
@@ -3,4 +3,4 @@
   "bos_token_id": 0,
   "eos_token_id": 0,
   "transformers_version": "4.28.1"
-}
+}
special_tokens_map.json CHANGED
@@ -2,4 +2,4 @@
   "bos_token": "<|endoftext|>",
   "eos_token": "<|endoftext|>",
   "unk_token": "<|endoftext|>"
-}
+}
tokenizer_config.json CHANGED
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 {
   "add_prefix_space": false,
   "bos_token": "<|endoftext|>",
@@ -8,14 +7,3 @@
   "tokenizer_class": "GPTNeoXTokenizer",
   "unk_token": "<|endoftext|>"
 }
-=======
-{
-  "add_prefix_space": false,
-  "bos_token": "<|endoftext|>",
-  "clean_up_tokenization_spaces": true,
-  "eos_token": "<|endoftext|>",
-  "model_max_length": 2048,
-  "tokenizer_class": "GPTNeoXTokenizer",
-  "unk_token": "<|endoftext|>"
-}
->>>>>>> 6ff0da104f5a2eb5ee298dc0164db0c0b16215e2
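The tokenizer_config.json change removes a leftover Git merge-conflict block, so the file parses as valid JSON again. A quick sanity check is to load the tokenizer from the repository root (the same "./" directory app.py uses); a minimal sketch, assuming the tokenizer files are present locally:

```python
# Hypothetical check: confirm the tokenizer loads and its special tokens match
# special_tokens_map.json (<|endoftext|> for bos/eos/unk) and tokenizer_config.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./")
print(type(tokenizer).__name__)                                       # a GPT-NeoX tokenizer class
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.unk_token)  # all <|endoftext|>
print(tokenizer.model_max_length)                                     # 2048 per tokenizer_config.json
```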