potatopizza commited on
Commit
f69a255
·
verified ·
1 Parent(s): 1d1a873

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +65 -0
inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+ import os
4
+
5
+ def inference(
6
+ model_path: str,
7
+ prompt: str,
8
+ output_image_path: str = "generated_image.png",
9
+ guidance_scale: float = 7.5,
10
+ num_inference_steps: int = 50,
11
+ use_fp16: bool = True,
12
+ ):
13
+ """
14
+ Generates an image using a fine-tuned Stable Diffusion model.
15
+
16
+ Args:
17
+ model_path (str): The path (or repo ID) to the fine-tuned Stable Diffusion model.
18
+ prompt (str): The text prompt used to generate an image.
19
+ output_image_path (str): The path to save the generated image.
20
+ guidance_scale (float): Classifier-Free Guidance scale.
21
+ num_inference_steps (int): Number of inference steps in the diffusion process.
22
+ use_fp16 (bool): Whether to use fp16 (half precision). Recommended if a GPU is available.
23
+ """
24
+
25
+ # Load the model (use half precision if GPU is available)
26
+ # If you only have CPU, consider using torch.float32 or omitting torch_dtype
27
+ pipe = StableDiffusionPipeline.from_pretrained(
28
+ model_path,
29
+ torch_dtype=torch.float16 if use_fp16 and torch.cuda.is_available() else torch.float32
30
+ )
31
+
32
+ # Set the device (use GPU if available)
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ pipe.to(device)
35
+
36
+ # Generate the image
37
+ with torch.autocast(device) if (use_fp16 and device == "cuda") else torch.no_grad():
38
+ result = pipe(
39
+ prompt=prompt,
40
+ guidance_scale=guidance_scale,
41
+ num_inference_steps=num_inference_steps
42
+ )
43
+
44
+ # Save the result
45
+ image = result.images[0]
46
+ image.save(output_image_path)
47
+ print(f"Saved Image: {os.path.abspath(output_image_path)}")
48
+
49
+
50
+ if __name__ == "__main__":
51
+ # Example prompt
52
+ sample_prompt = "A painting of the Eiffel Tower in the style of Eric Fischl."
53
+
54
+ # Assuming the fine-tuned model was saved with pipe.save_pretrained("stable-diffusion-wikiart-final")
55
+ finetuned_model_path = "/home/work/daehyun/Style-Portrait-Generator/scripts/stable-diffusion-wikiart-final"
56
+
57
+ # Run inference
58
+ inference(
59
+ model_path=finetuned_model_path,
60
+ prompt=sample_prompt,
61
+ output_image_path="wikiart_inference_result.png",
62
+ guidance_scale=7.5,
63
+ num_inference_steps=50,
64
+ use_fp16=True
65
+ )