from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# -------------------------------------------------------------------------
# Sao10K/L3-8B-Stheno-v3.2 is loaded here without an HF auth token (the repo
# is not gated), so the 'token' (formerly 'use_auth_token') parameter is
# omitted.
# -------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"

print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
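
# Optional (an assumption, not part of the original setup): on a CUDA machine
# the weights could be loaded in half precision to roughly halve memory use:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       torch_dtype=torch.float16,
#       trust_remote_code=True,
#   )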

# Choose device based on availability (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
model.eval()  # inference only: disable dropout and other training-time behavior

@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint that streams a completion from the loaded model (L3-8B-Stheno-v3.2).
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns a streamed plain-text response, emitted token by token.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids             # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask   # same shape

    def token_generator():
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature
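                #    (dividing by a temperature below 1 sharpens the distribution,
                #    above 1 flattens it; 0.7 therefore biases sampling toward the
                #    more likely tokens)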

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]
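                # Worked example: with sorted probs [0.5, 0.3, 0.15, 0.05] and
                # top_p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0],
                # so only the first two tokens pass the filter and are sampled from.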

                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
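                    # torch.multinomial normalizes its weights internally, so
                    # the filtered probabilities do not need to be re-normalized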
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1,1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update the attention mask
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Decode and yield the generated token
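                #    (Caveat: decoding one token at a time and appending a space
                #    can distort spacing across subword pieces; decoding the full
                #    accumulated sequence and yielding only the new text would be
                #    more faithful if exact formatting matters.)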
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "

                # 10) Stop if EOS token is generated (if the model uses one)
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break

    # Return a StreamingResponse that yields the generated tokens as plain text
    return StreamingResponse(token_generator(), media_type="text/plain")
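
# ---------------------------------------------------------------------------
# Example usage sketch (assumptions: the app is served with uvicorn as
# `main:app` on the default port 8000; the module name, host, and port are
# placeholders, adjust them to your deployment):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -N -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, how are you?"}'
#
# Or stream the plain-text body from Python with `requests`:
#
#   import requests
#   with requests.post("http://localhost:8000/predict",
#                      json={"prompt": "Hello, how are you?"},
#                      stream=True) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
# ---------------------------------------------------------------------------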