DataChem committed on
Commit 1f3e16d · verified · 1 parent: 977cc0a

Update app.py

Files changed (1)
  1. app.py +17 -44
app.py CHANGED
@@ -7,49 +7,25 @@ import torch.nn.functional as F
 
 app = FastAPI()
 
-# Retrieve the token from environment variable
-hf_token = os.environ.get("HF_AUTH_TOKEN", None)
-if hf_token is None:
-    print("WARNING: No HF_AUTH_TOKEN found in environment. "
-          "Make sure to set a Hugging Face token if the model is gated.")
-
-
 # -------------------------------------------------------------------------
-# Update this to the Llama 2 Chat model you prefer. This example uses the
-# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
+# Since Falcon 7B Instruct is not gated, you do NOT need an HF token.
+# We omit any 'use_auth_token' parameter.
 # -------------------------------------------------------------------------
-model_name = "meta-llama/Llama-2-7b-chat-hf"
+model_name = "tiiuae/falcon-7b-instruct"
 
-# -------------------------------------------------------------------------
-# If the repo is gated, you may need:
-#     use_auth_token="YOUR_HF_TOKEN",
-#     trust_remote_code=True,
-# or you can set environment variables in your HF Space to authenticate.
-# -------------------------------------------------------------------------
-print(f"Loading model/tokenizer from: {model_name}")
+print(f"Loading tokenizer from: {model_name}")
 tokenizer = AutoTokenizer.from_pretrained(
     model_name,
-    trust_remote_code=True,
-    use_auth_token=hf_token,
+    trust_remote_code=True
 )
 
-# -------------------------------------------------------------------------
-# If you had GPU available, you might do:
-# model = AutoModelForCausalLM.from_pretrained(
-#     model_name,
-#     torch_dtype=torch.float16,
-#     device_map="auto",
-#     trust_remote_code=True
-# )
-# But for CPU, we do a simpler load:
-# -------------------------------------------------------------------------
+print(f"Loading model from: {model_name}")
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    trust_remote_code=True,
-    use_auth_token=hf_token,
+    trust_remote_code=True
 )
 
-# Choose device based on availability
+# Choose device based on availability (CPU or GPU)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 model.to(device)
@@ -57,9 +33,9 @@ model.to(device)
 @app.post("/predict")
 async def predict(request: Request):
     """
-    Endpoint for streaming responses from the Llama 2 chat model.
+    Endpoint for streaming responses from Falcon-7B-Instruct.
     Expects JSON: { "prompt": "<Your prompt>" }
-    Returns a text/event-stream of tokens.
+    Returns a text/event-stream of tokens (SSE).
     """
     data = await request.json()
     prompt = data.get("prompt", "")
@@ -72,19 +48,16 @@ async def predict(request: Request):
     attention_mask = inputs.attention_mask  # same shape
 
     def token_generator():
-        """
-        A generator that yields tokens one by one for SSE streaming.
-        """
         nonlocal input_ids, attention_mask
 
         # Basic generation hyperparameters
         temperature = 0.7
         top_p = 0.9
-        max_new_tokens = 30  # Increase for longer outputs
+        max_new_tokens = 30  # Increase if you want longer outputs
 
         for _ in range(max_new_tokens):
             with torch.no_grad():
-                # 1) Forward pass: compute logits for next token
+                # 1) Forward pass: compute logits for the next token
                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[:, -1, :]
 
@@ -101,7 +74,7 @@ async def predict(request: Request):
                 filtered_probs = sorted_probs[valid_indices]
                 filtered_indices = sorted_indices[valid_indices]
 
-                # 5) If no tokens are valid under top_p, fallback to greedy
+                # 5) If no tokens remain after filtering, fall back to greedy
                 if len(filtered_probs) == 0:
                     next_token_id = torch.argmax(next_token_probs)
                 else:
@@ -115,18 +88,18 @@ async def predict(request: Request):
                 # shape [1] => [1,1]
                 next_token_id = next_token_id.unsqueeze(-1)
 
-                # 7) Append token to input_ids
+                # 7) Append the new token to input_ids
                 input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-                # 8) Update attention_mask for the new token
+                # 8) Update the attention mask
                 new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                 attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
 
-                # 9) Decode and yield
+                # 9) Decode and yield the generated token
                 token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                 yield token + " "
 
-                # 10) Stop if we encounter EOS
+                # 10) Stop if EOS token is generated (if the model uses one)
                 if tokenizer.eos_token_id is not None:
                     if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                         break
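
For reference, a minimal client sketch (not part of the commit) for exercising the updated /predict endpoint. The base URL is a placeholder, and the exact chunk framing depends on how app.py wraps token_generator() in a streaming response, which is outside this diff:

import requests

# Placeholder: replace with the actual Space or server address.
BASE_URL = "http://localhost:7860"

# The endpoint expects JSON of the form {"prompt": "<Your prompt>"}.
payload = {"prompt": "Explain what Falcon-7B-Instruct is in one sentence."}

# stream=True lets us print tokens as the server yields them.
with requests.post(f"{BASE_URL}/predict", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)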