Fred808 commited on
Commit
64c0b0e
·
verified ·
1 Parent(s): 49b4625

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
+ import torch
5
+
6
+ # Initialize FastAPI app
7
+ app = FastAPI()
8
+
9
+ # Load the Falcon-7B model with 8-bit quantization (if CUDA is available)
10
+ model_id = "tiiuae/falcon-7b-instruct"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+
13
+ # Check if CUDA is available
14
+ if torch.cuda.is_available():
15
+ # Load the model with 8-bit quantization for GPU
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ model_id,
18
+ load_in_8bit=True,
19
+ device_map="auto",
20
+ trust_remote_code=True
21
+ )
22
+ else:
23
+ # Fallback to CPU or full precision
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ model_id,
26
+ device_map="auto",
27
+ trust_remote_code=True
28
+ )
29
+
30
+ # Create a text generation pipeline
31
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
32
+
33
+ # Define request body schema
34
+ class TextGenerationRequest(BaseModel):
35
+ prompt: str
36
+ max_new_tokens: int = 50
37
+ temperature: float = 0.7
38
+ top_k: int = 50
39
+ top_p: float = 0.9
40
+ do_sample: bool = True
41
+
42
+ # Define API endpoint
43
+ @app.post("/generate-text")
44
+ async def generate_text(request: TextGenerationRequest):
45
+ try:
46
+ # Generate text using the pipeline
47
+ outputs = pipe(
48
+ request.prompt,
49
+ max_new_tokens=request.max_new_tokens,
50
+ temperature=request.temperature,
51
+ top_k=request.top_k,
52
+ top_p=request.top_p,
53
+ do_sample=request.do_sample
54
+ )
55
+ return {"generated_text": outputs[0]["generated_text"]}
56
+ except Exception as e:
57
+ raise HTTPException(status_code=500, detail=str(e))
58
+
59
+ # Add a root endpoint for health checks
60
+ @app.get("/test")
61
+ async def root():
62
+ return {"message": "API is running!"}