# custom-api/app.py
import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
app = FastAPI()
# -------------------------------------------------------------------------
# Sao10K/L3-8B-Stheno-v3.2 is a public (non-gated) checkpoint, so you do
# NOT need an HF token. We omit any 'use_auth_token' parameter.
# -------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"
print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
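# Note: this loads the full-precision weights. On a GPU you can optionally pass
# torch_dtype=torch.float16 (or torch.bfloat16) to from_pretrained above to
# reduce memory usage for an 8B-parameter model.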
# Choose device based on availability (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
@app.post("/predict")
async def predict(request: Request):
"""
Endpoint for streaming responses from Falcon-7B-Instruct.
Expects JSON: { "prompt": "<Your prompt>" }
Returns a text/event-stream of tokens (SSE).
"""
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape
    def token_generator():
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]
            # 6) Ensure next_token_id has shape [batch_size, 1]
            if next_token_id.dim() == 0:
                # shape [] => [1]
                next_token_id = next_token_id.unsqueeze(0)
            # shape [1] => [1, 1]
            next_token_id = next_token_id.unsqueeze(-1)

            # 7) Append the new token to input_ids
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # 8) Update the attention mask
            new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

            # 9) Decode and yield the generated token
            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
            yield token + " "

            # 10) Stop if EOS token is generated (if the model uses one)
            if tokenizer.eos_token_id is not None:
                if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                    break
    # Stream the generated tokens back to the client as plain text
    return StreamingResponse(token_generator(), media_type="text/plain")
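
# ---------------------------------------------------------------------------
# Example request (assumes the app is served locally on port 7860, the usual
# port for Hugging Face Spaces; adjust host/port to your deployment):
#
#   curl -N -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea."}'
#
# The -N flag disables curl's output buffering so tokens appear as they stream.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Optional local entry point; in a container the server is typically
    # started externally, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)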