import os

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# ---------------------------------------------------------------------------
# The model below is publicly available (not gated), so no Hugging Face token
# is required; we omit any 'use_auth_token' parameter.
# ---------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"

print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)

# Choose device based on availability (GPU if present, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)


@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the loaded model.
    Expects JSON: { "prompt": "<your prompt>" }
    Streams generated tokens back as plain text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape

    def token_generator():
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                #    (Note: this recomputes the full sequence every step; for longer
                #    outputs, consider caching attention with past_key_values.)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update the attention mask
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Decode and yield the generated token
                #    (The trailing space is a crude delimiter, so sub-word pieces
                #    will appear space-separated in the streamed output.)
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "

                # 10) Stop if EOS token is generated (if the model uses one)
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break

    # Return a StreamingResponse that streams plain-text tokens
    return StreamingResponse(token_generator(), media_type="text/plain")
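

# ---------------------------------------------------------------------------
# A minimal sketch of how to run and query this server. The module name
# ("app.py"), host, port, and the example prompt below are assumptions --
# adjust them to your setup:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -N -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea."}'
#
# curl's -N flag disables output buffering so tokens print as they stream in.
# Alternatively, start the server directly when this file is executed:
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app defined above; host/port are illustrative defaults.
    uvicorn.run(app, host="0.0.0.0", port=8000)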