import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
app = FastAPI()
# -------------------------------------------------------------------------
# Sao10K/L3-8B-Stheno-v3.2 is not a gated model, so you do NOT need an HF token.
# We omit any 'use_auth_token' parameter.
# -------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"
print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
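# ---------------------------------------------------------------------------
# Sketch only: if you ever swap in a gated checkpoint, you would pass your
# Hugging Face token explicitly. The environment variable name "HF_TOKEN"
# below is an assumption for illustration; use whatever secret your Space
# actually defines.
#
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     trust_remote_code=True,
#     token=os.environ.get("HF_TOKEN"),  # 'use_auth_token' in older transformers
# )
# ---------------------------------------------------------------------------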
# Choose device based on availability (GPU if present, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
model.eval()  # inference only: disable dropout and other training-time behavior
@app.post("/predict")
async def predict(request: Request):
"""
Endpoint for streaming responses from Falcon-7B-Instruct.
Expects JSON: { "prompt": "<Your prompt>" }
Returns a text/event-stream of tokens (SSE).
"""
data = await request.json()
prompt = data.get("prompt", "")
if not prompt:
return {"error": "Prompt is required"}
# Tokenize the input prompt
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids # shape: [batch_size, seq_len], typically [1, seq_len]
attention_mask = inputs.attention_mask # same shape
def token_generator():
nonlocal input_ids, attention_mask
# Basic generation hyperparameters
temperature = 0.7
top_p = 0.9
max_new_tokens = 30 # Increase if you want longer outputs
for _ in range(max_new_tokens):
with torch.no_grad():
# 1) Forward pass: compute logits for the next token
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
next_token_logits = outputs.logits[:, -1, :]
# 2) Apply temperature scaling
next_token_logits = next_token_logits / temperature
# 3) Convert logits -> probabilities
next_token_probs = F.softmax(next_token_logits, dim=-1)
# 4) Nucleus (top-p) sampling
sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
valid_indices = cumulative_probs <= top_p
filtered_probs = sorted_probs[valid_indices]
filtered_indices = sorted_indices[valid_indices]
# 5) If no tokens remain after filtering, fall back to greedy
if len(filtered_probs) == 0:
next_token_id = torch.argmax(next_token_probs)
else:
sampled_id = torch.multinomial(filtered_probs, 1)
next_token_id = filtered_indices[sampled_id]
# 6) Ensure next_token_id has shape [batch_size, 1]
if next_token_id.dim() == 0:
# shape [] => [1]
next_token_id = next_token_id.unsqueeze(0)
# shape [1] => [1,1]
next_token_id = next_token_id.unsqueeze(-1)
# 7) Append the new token to input_ids
input_ids = torch.cat([input_ids, next_token_id], dim=-1)
# 8) Update the attention mask
new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
# 9) Decode and yield the generated token
token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
yield token + " "
# 10) Stop if EOS token is generated (if the model uses one)
if tokenizer.eos_token_id is not None:
if next_token_id.squeeze().item() == tokenizer.eos_token_id:
break
# Return a StreamingResponse for SSE
return StreamingResponse(token_generator(), media_type="text/plain")
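
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this file is app.py and the server listens on
# port 7860, the Hugging Face Spaces default; adjust host/port to your setup):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -N -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Tell me a short story about a lighthouse."}'
#
# The -N flag disables curl's output buffering so tokens appear as they stream.
# ---------------------------------------------------------------------------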