import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()
# -------------------------------------------------------------------------
# The model below is assumed to be publicly accessible (not gated), so no
# HF token is needed and any 'use_auth_token' parameter is omitted.
# -------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"

print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
# Choose device based on availability (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
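
# Note: from_pretrained() already returns the model in eval mode, and the
# generation loop below runs under torch.no_grad(), so no gradients are
# tracked while serving requests.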

@app.post("/predict")  # route path chosen here; adjust to match your client
async def predict(request: Request):
    """
    Endpoint for streaming responses from the loaded model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns a plain-text stream of tokens.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        # Returned as a 200 JSON body; raise an HTTPException for a real error status
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape

    def token_generator():
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs
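
        # Note: each step below re-runs a full forward pass over the growing
        # sequence (no KV cache / past_key_values), which is simple but scales
        # poorly with output length; acceptable for ~30 new tokens.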

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]
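
                # Because the mask uses a strict 'cumulative <= top_p', the token
                # that first crosses the threshold is excluded; if the single most
                # likely token alone exceeds top_p, nothing survives the filter.
                # (Canonical top-p keeps at least that top token; here the empty
                # case is handled by the greedy fallback below instead.)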

                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)

            # 7) Append the new token to input_ids
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # 8) Update the attention mask
            new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

            # 9) Decode and yield the generated token (a trailing space is added
            #    as a simple delimiter between streamed pieces)
            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
            yield token + " "

            # 10) Stop if EOS token is generated (if the model uses one)
            if tokenizer.eos_token_id is not None:
                if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                    break

    # Stream plain text back to the client as tokens are produced
    return StreamingResponse(token_generator(), media_type="text/plain")
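
# -------------------------------------------------------------------------
# Local entry point (sketch): on a managed Space the server is normally
# launched for you, so this guard only matters when running the file
# directly. Port 7860 (the conventional Spaces port) is an assumption.
# -------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the /predict route above (assumes the server is
# reachable at localhost:7860); curl's -N flag disables output buffering so
# streamed tokens appear as they arrive:
#
#   curl -N -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a short greeting."}'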