# Testing / model.py
# Last commit: 768bc7d by Masrkai, "Update model.py" (verified)
# File size: 1.32 kB
import torch
from diffusers import DiffusionPipeline
import trimesh
import numpy as np
from PIL import Image
from io import BytesIO
def load_pipeline(ckpt_id="stabilityai/stable-zero123", device="cpu", torch_dtype=torch.float32):
    """Load a diffusers pipeline from the Hugging Face Hub and move it to ``device``.

    The defaults reproduce the original behaviour: ``stabilityai/stable-zero123``
    loaded in float32 on the CPU.

    Args:
        ckpt_id: Hub repository id (or local directory) passed to
            ``DiffusionPipeline.from_pretrained``.
        device: Target device handed to ``.to()``, e.g. ``"cpu"`` or ``"cuda"``.
        torch_dtype: Weight dtype passed to ``from_pretrained``.

    Returns:
        The loaded pipeline, already moved to ``device``.

    NOTE(review): confirm the default repository ships a diffusers-format
    layout (``model_index.json``); ``DiffusionPipeline.from_pretrained``
    cannot load a bare checkpoint file.
    """
    pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch_dtype)
    return pipe.to(device)
def _mesh_array(outputs, key):
    """Return the first batch entry of ``outputs[key]`` as a NumPy array.

    Raises:
        KeyError: if ``outputs`` has no ``key`` field; the message names the
            missing field instead of just echoing the key.
    """
    try:
        batch = outputs[key]
    except KeyError as err:
        raise KeyError(
            f"pipeline output has no {key!r} field; this pipeline does not "
            "appear to return mesh data"
        ) from err
    first = batch[0]
    # Tensors must be detached from the autograd graph and copied to host
    # memory before NumPy conversion; arrays and nested lists convert directly.
    if isinstance(first, torch.Tensor):
        return first.detach().cpu().numpy()
    return np.asarray(first)


def generate_3d_model(pipe, prompt, output_path="output.obj", guidance_scale=7.5, num_inference_steps=32):
    """Run ``pipe`` on ``prompt`` and save the resulting mesh to ``output_path``.

    trimesh picks the export format from the file extension, so the default
    ``.obj`` path produces a file Blender can import.

    Args:
        pipe: Callable pipeline. It is called with ``prompt``,
            ``guidance_scale`` and ``num_inference_steps`` keywords and must
            return a mapping with ``"vertices"`` and ``"faces"`` entries; the
            first batch item of each is used (tensor or array-like).
        prompt: Prompt forwarded to ``pipe``.
        output_path: Destination file for the exported mesh.
        guidance_scale: Forwarded to ``pipe``.
        num_inference_steps: Forwarded to ``pipe``.

    Returns:
        ``output_path``.

    Raises:
        KeyError: if the pipeline output lacks ``"vertices"`` or ``"faces"``.

    NOTE(review): image-diffusion pipelines (such as stable-zero123) presumably
    return images rather than mesh fields; confirm ``pipe`` really produces
    vertices/faces before relying on this function.
    """
    outputs = pipe(prompt=prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps)
    vertices = _mesh_array(outputs, "vertices")
    faces = _mesh_array(outputs, "faces")
    # process=True lets trimesh merge duplicate vertices and drop NaN/inf
    # values before export.
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=True)
    mesh.export(output_path)
    return output_path
def convert_to_gif(images, gif_path="output.gif", duration=100, loop=0):
    """Write ``images`` as an animated GIF at ``gif_path``.

    Args:
        images: Iterable of frames. PIL images are used as-is; any other frame
            is converted with ``Image.fromarray`` (typically a uint8 array of
            shape (H, W) or (H, W, C)).
        gif_path: Destination file path.
        duration: Display time of each frame, in milliseconds.
        loop: GIF loop count; 0 means loop forever.

    Returns:
        ``gif_path``.

    Raises:
        ValueError: if ``images`` yields no frames.
    """
    frames = [
        frame if isinstance(frame, Image.Image) else Image.fromarray(np.asarray(frame))
        for frame in images
    ]
    if not frames:
        raise ValueError("convert_to_gif() needs at least one frame")
    first, *rest = frames
    # The first frame owns the file; the remaining frames are appended to it.
    first.save(gif_path, save_all=True, append_images=rest, loop=loop, duration=duration)
    return gif_path