|
import io

import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
|
|
|
|
|
# FastAPI application instance; route handlers below are registered on it.
app = FastAPI()


# Load the Molmo-7B-D vision-language checkpoint from the Hugging Face Hub.
# trust_remote_code=True is required: Molmo ships its model/processor code
# inside the repository rather than as built-in transformers classes.
# NOTE(review): both downloads happen at import time, so app startup blocks
# until the (multi-GB) weights are available locally.
processor = AutoProcessor.from_pretrained('allenai/Molmo-7B-D-0924', trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('allenai/Molmo-7B-D-0924', trust_remote_code=True)


# Prefer GPU when available; weights are moved to the device once at startup
# and every request's inputs are moved to the same device in generate_text.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for POST /generate/."""

    # URL of the image to condition generation on; fetched server-side.
    image_url: str

    # Text prompt passed to the model alongside the image.
    text_input: str
|
|
|
|
|
@app.get("/")
def root():
    """Health-check endpoint confirming the service is alive."""
    status_payload = {"message": "Molmo-7B-D API is up and running!"}
    return status_payload
|
|
|
|
|
@app.post("/generate/")
def generate_text(request: GenerateRequest):
    """Generate text conditioned on a remote image and a text prompt.

    Downloads the image at ``request.image_url``, runs it together with
    ``request.text_input`` through the Molmo processor/model pair, and
    returns only the newly generated tokens.

    Returns:
        dict with a single key ``generated_text``.

    Raises:
        HTTPException 400: the image could not be fetched or decoded.
        HTTPException 500: unexpected model/processing failure.
    """
    try:
        # Fetch up front with a timeout so a dead host cannot hang the worker,
        # and raise_for_status so an HTML error page is not fed to Image.open.
        response = requests.get(request.image_url, timeout=30)
        response.raise_for_status()
        # BytesIO over .content avoids the requests `.raw` pitfall where
        # gzip/deflate-encoded bodies are not decoded. Convert to RGB so
        # palette/alpha images don't trip the vision preprocessor.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as e:
        # Client-supplied URL problems are the caller's fault -> 400, not 500.
        raise HTTPException(status_code=400, detail=f"Could not load image: {e}")

    try:
        inputs = processor(images=[image], text=request.text_input, return_tensors="pt").to(device)

        # Pass *all* processor outputs (input_ids plus the image tensors).
        # Passing only input_ids would silently ignore the image entirely.
        with torch.inference_mode():
            output_ids = model.generate(**inputs, max_new_tokens=200)

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].size(1)
        generated_text = processor.tokenizer.decode(
            output_ids[0, prompt_len:], skip_special_tokens=True
        )

        return {"generated_text": generated_text}

    except Exception as e:
        # Top-level boundary: surface unexpected inference failures as 500.
        raise HTTPException(status_code=500, detail=str(e))
|
|