DataChem committed
Commit 3895f1c · verified · 1 Parent(s): d638752

Update app.py

Files changed (1):
  app.py  +65 -27
app.py CHANGED
@@ -6,84 +6,122 @@ import torch.nn.functional as F
 
 app = FastAPI()
 
-# Load the model and tokenizer
-model_name = "EleutherAI/gpt-neo-1.3B"  # Replace with your desired model
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-
+# -------------------------------------------------------------------------
+# Update this to the Llama 2 Chat model you prefer. This example uses the
+# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
+# -------------------------------------------------------------------------
+model_name = "meta-llama/Llama-2-7b-chat-hf"
+
+# -------------------------------------------------------------------------
+# If the repo is gated, you may need:
+#     use_auth_token="YOUR_HF_TOKEN",
+#     trust_remote_code=True,
+# or you can set environment variables in your HF Space to authenticate.
+# -------------------------------------------------------------------------
+print(f"Loading model/tokenizer from: {model_name}")
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    trust_remote_code=True
+    # use_auth_token="YOUR_HF_TOKEN",  # If needed for private/gated model
+)
+
+# -------------------------------------------------------------------------
+# If you had GPU available, you might do:
+#     model = AutoModelForCausalLM.from_pretrained(
+#         model_name,
+#         torch_dtype=torch.float16,
+#         device_map="auto",
+#         trust_remote_code=True
+#     )
+# But for CPU, we do a simpler load:
+# -------------------------------------------------------------------------
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    trust_remote_code=True
+    # use_auth_token="YOUR_HF_TOKEN",  # If needed
+)
+
+# Choose device based on availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 model.to(device)
 
 @app.post("/predict")
 async def predict(request: Request):
+    """
+    Endpoint for streaming responses from the Llama 2 chat model.
+    Expects JSON: { "prompt": "<Your prompt>" }
+    Returns a text/event-stream of tokens.
+    """
     data = await request.json()
     prompt = data.get("prompt", "")
     if not prompt:
         return {"error": "Prompt is required"}
 
-    # Initial tokenization on the prompt
+    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
-    input_ids = inputs.input_ids  # Shape: [batch_size, seq_len], often [1, seq_len]
-    attention_mask = inputs.attention_mask  # Same shape as input_ids
+    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
+    attention_mask = inputs.attention_mask  # same shape
 
     def token_generator():
+        """
+        A generator that yields tokens one by one for SSE streaming.
+        """
         nonlocal input_ids, attention_mask
 
-        # Generation hyperparameters
+        # Basic generation hyperparameters
         temperature = 0.7
         top_p = 0.9
-        max_new_tokens = 30
+        max_new_tokens = 30  # Increase for longer outputs
 
         for _ in range(max_new_tokens):
             with torch.no_grad():
-                # Forward pass: compute logits for the last token
+                # 1) Forward pass: compute logits for next token
                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[:, -1, :]
 
-                # Apply temperature
+                # 2) Apply temperature scaling
                 next_token_logits = next_token_logits / temperature
 
-                # Convert logits -> probabilities
+                # 3) Convert logits -> probabilities
                 next_token_probs = F.softmax(next_token_logits, dim=-1)
 
-                # Apply top-p (nucleus) sampling
+                # 4) Nucleus (top-p) sampling
                 sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                 valid_indices = cumulative_probs <= top_p
                 filtered_probs = sorted_probs[valid_indices]
                 filtered_indices = sorted_indices[valid_indices]
 
+                # 5) If no tokens are valid under top_p, fallback to greedy
                 if len(filtered_probs) == 0:
-                    # Fallback to greedy if nothing meets top_p
                     next_token_id = torch.argmax(next_token_probs)
                 else:
-                    # Sample a token from the filtered distribution
                     sampled_id = torch.multinomial(filtered_probs, 1)
                     next_token_id = filtered_indices[sampled_id]
 
-                # At this point, next_token_id might be shape [] (scalar) or [1].
-                # We need [batch_size, 1], so if it's just a scalar, unsqueeze(0).
+                # 6) Ensure next_token_id has shape [batch_size, 1]
                 if next_token_id.dim() == 0:
-                    next_token_id = next_token_id.unsqueeze(0)  # shape [1]
-                next_token_id = next_token_id.unsqueeze(-1)  # shape [1,1]
+                    # shape [] => [1]
+                    next_token_id = next_token_id.unsqueeze(0)
+                # shape [1] => [1,1]
+                next_token_id = next_token_id.unsqueeze(-1)
 
-                # Append the new token to input_ids
-                # input_ids: [1, seq_len], next_token_id: [1,1] => final shape [1, seq_len+1]
+                # 7) Append token to input_ids
                 input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-                # Also update the attention mask so the model attends to the new token
-                # shape: [1, seq_len+1]
+                # 8) Update attention_mask for the new token
                 new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                 attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
 
-                # Decode and yield the token for streaming
+                # 9) Decode and yield
                 token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                 yield token + " "
 
-                # Stop if we hit the EOS token
+                # 10) Stop if we encounter EOS
                 if tokenizer.eos_token_id is not None:
                     if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                         break
 
-    # Return the streaming response
+    # Return a StreamingResponse for SSE
     return StreamingResponse(token_generator(), media_type="text/plain")
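
For reference, a standalone sketch (not part of the commit) of the top-p filtering step that token_generator() performs on each iteration, using a toy probability vector; the numbers are illustrative only:

# Illustration of the top-p filter used in token_generator(): keep the
# highest-probability tokens whose cumulative mass stays within top_p,
# then sample among them.
import torch

probs = torch.tensor([0.55, 0.25, 0.12, 0.05, 0.03])  # toy next-token distribution
top_p = 0.9

sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # [0.55, 0.80, 0.92, 0.97, 1.00]
valid_indices = cumulative_probs <= top_p              # [True, True, False, False, False]
filtered_probs = sorted_probs[valid_indices]           # [0.55, 0.25]
filtered_indices = sorted_indices[valid_indices]       # [0, 1]

# torch.multinomial treats its input as unnormalized weights, so the filtered
# probabilities do not need to sum to 1 before sampling.
next_token_id = filtered_indices[torch.multinomial(filtered_probs, 1)]
print(next_token_id.item())  # 0 or 1, with 0 roughly twice as likely

If the single most probable token already exceeds top_p, the mask is empty, which is why the endpoint falls back to a greedy argmax in that case.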
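
And a minimal client sketch for exercising the updated /predict endpoint. It assumes the app is served locally, e.g. with "uvicorn app:app --port 8000"; the URL and helper name are hypothetical, not part of the commit:

# Hypothetical client for the /predict endpoint above; assumes the FastAPI app
# is reachable at http://localhost:8000 (adjust the URL for your deployment).
import requests

def stream_prediction(prompt: str, url: str = "http://localhost:8000/predict") -> None:
    # Post the prompt as JSON and keep the connection open to read the stream
    with requests.post(url, json={"prompt": prompt}, stream=True) as resp:
        resp.raise_for_status()
        # The endpoint streams plain-text chunks (tokens separated by spaces)
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    stream_prediction("Explain what nucleus sampling does in two sentences.")

Because the response is returned with media_type="text/plain", the chunks can simply be printed (or forwarded to a UI) as they arrive.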