Fred808 committed
Commit 36267e8 · verified · 1 Parent(s): cc53fa3

Update app.py

Files changed (1)
  1. app.py +35 -54
app.py CHANGED
@@ -1,68 +1,49 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
 import os
+import requests
 import logging
-import openai
 
 # Read the NVIDIA API key from environment variables
 api_key = os.getenv("NVIDIA_API_KEY")
 if api_key is None:
     raise ValueError("NVIDIA API key not found in environment variables. Please set the NVIDIA_API_KEY.")
 
-# Initialize FastAPI app
-app = FastAPI()
-
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # NVIDIA API configuration
-openai.api_key = api_key
-openai.base_url = "https://integrate.api.nvidia.com/v1"
-
-# Define request body schema
-class TextGenerationRequest(BaseModel):
-    prompt: str
-    max_new_tokens: int = 1024
-    temperature: float = 0.4
-    top_p: float = 0.7
-    stream: bool = True
-
-# Define API endpoint
-@app.post("/generate-text")
-async def generate_text(request: TextGenerationRequest):
-    try:
-        logger.info("Generating text...")
-
-        # Generate response from NVIDIA API
-        response = openai.ChatCompletion.create(
-            model="meta/llama-3.1-405b-instruct",
-            messages=[{"role": "user", "content": request.prompt}],
-            temperature=request.temperature,
-            top_p=request.top_p,
-            max_tokens=request.max_new_tokens,
-            stream=request.stream,
-        )
-
+base_url = "https://integrate.api.nvidia.com/v1"
+headers = {
+    "Authorization": f"Bearer {api_key}",
+    "Content-Type": "application/json"
+}
+
+# Define request payload
+payload = {
+    "model": "meta/llama-3.1-405b-instruct",  # Model for NVIDIA's text generation
+    "messages": [{"role": "user", "content": "Write a limerick about the wonders of GPU computing."}],
+    "temperature": 0.2,
+    "top_p": 0.7,
+    "max_tokens": 1024,
+    "stream": True
+}
+
+# Call NVIDIA's API for text generation
+try:
+    logger.info("Generating text with NVIDIA API...")
+    response = requests.post(f"{base_url}/chat/completions", headers=headers, json=payload, stream=True)
+
+    if response.status_code == 200:
+        # Stream the response
         response_text = ""
-        if request.stream:
-            for chunk in response:
-                if chunk.choices[0].delta.get("content"):
-                    response_text += chunk.choices[0].delta.content
-        else:
-            response_text = response["choices"][0]["message"]["content"]
-
-        return {"generated_text": response_text}
-    except Exception as e:
-        logger.error(f"Error generating text: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-# Add a root endpoint for health checks
-@app.get("/")
-async def root():
-    return {"message": "Welcome Fred808 GPT"}
-
-# Add a test endpoint
-@app.get("/test")
-async def test():
-    return {"message": "API is running!"}
+        for chunk in response.iter_lines():
+            if chunk:
+                data = chunk.decode("utf-8")
+                # Extract the content from the response (adjust based on actual API response structure)
+                if "content" in data:
+                    response_text += data["choices"][0]["delta"].get("content", "")
+                    print(response_text, end="")  # Print content as it's received
+    else:
+        logger.error(f"Error: {response.status_code} - {response.text}")
+except Exception as e:
+    logger.error(f"Error generating text: {e}")
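
Note on the streaming loop in the new version: each chunk is decoded into a string, so while the substring check "content" in data can match, the dictionary-style lookup data["choices"][0]["delta"] on that string would raise a TypeError and no text is ever accumulated. Below is a minimal sketch (not part of the commit) of a working parse, assuming the endpoint streams OpenAI-style server-sent events (data: {...} JSON lines ending with a data: [DONE] sentinel) and reusing the response object from the script above; the exact framing should be confirmed against NVIDIA's API documentation.

import json

response_text = ""
for raw_line in response.iter_lines():
    if not raw_line:
        continue  # skip keep-alive blank lines
    line = raw_line.decode("utf-8")
    if not line.startswith("data: "):
        continue  # only "data:" lines carry events (assumed SSE framing)
    data = line[len("data: "):].strip()
    if data == "[DONE]":
        break  # assumed end-of-stream marker
    event = json.loads(data)
    delta = event["choices"][0]["delta"].get("content", "")
    response_text += delta
    print(delta, end="", flush=True)  # print only the newly received text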