# potatopizza's picture
# Update inference.py
# 10ab55c verified
import torch
from diffusers import StableDiffusionPipeline
import os
def inference(
    model_path: str,
    prompt: str,
    output_image_path: str = "generated_image.png",
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    use_fp16: bool = True,
) -> None:
    """Generate one image with a fine-tuned Stable Diffusion model and save it.

    The pipeline is loaded from ``model_path``, moved to the GPU when CUDA is
    available (CPU otherwise), run once on ``prompt``, and the first image of
    the result is written to ``output_image_path``. The absolute path of the
    saved file is printed.

    Args:
        model_path: The path (or repo ID) to the fine-tuned Stable Diffusion
            model.
        prompt: The text prompt used to generate an image.
        output_image_path: The path to save the generated image. Missing
            parent directories are created.
        guidance_scale: Classifier-Free Guidance scale.
        num_inference_steps: Number of inference steps in the diffusion
            process.
        use_fp16: Whether to load the weights in fp16 (half precision). Only
            honoured when CUDA is available; otherwise float32 is used.
    """
    # Query CUDA once so the dtype and the device can never disagree.
    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    dtype = torch.float16 if (use_fp16 and cuda_available) else torch.float32

    # Make sure the destination directory exists before the expensive
    # generation step, so a bad output path fails early instead of after
    # the image has already been computed.
    absolute_output_path = os.path.abspath(output_image_path)
    os.makedirs(os.path.dirname(absolute_output_path), exist_ok=True)

    # Load the model directly in the chosen precision and move it to the device.
    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
    pipe.to(device)

    # The weights are already in the target dtype, so no torch.autocast is
    # used here: the Diffusers documentation advises against autocast in
    # pipelines (slower than pure fp16 and can yield black images).
    # Gradients are never needed for inference on either device.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )

    # Save the first generated image and report where it went.
    image = result.images[0]
    image.save(output_image_path)
    print(f"Saved Image: {absolute_output_path}")
if __name__ == "__main__":
    # Demo run: all settings are collected in one mapping and unpacked into
    # the call, so each value sits next to the parameter name it feeds.
    demo_settings = {
        # Assumes the fine-tuned model was saved with
        # pipe.save_pretrained("stable-diffusion-wikiart-final").
        "model_path": "stable-diffusion-wikiart-final",
        # Example prompt.
        "prompt": "A paint of Eiffel tower, in the style of eric fischl.",
        "output_image_path": "wikiart_inference_result.png",
        "guidance_scale": 7.5,
        "num_inference_steps": 50,
        "use_fp16": True,
    }

    # Run inference.
    inference(**demo_settings)