Stable Diffusion image phases: visualizing intermediate generation steps
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import matplotlib.pyplot as plt
from PIL import Image
# Load the pipeline. The float16 weights are intended for a CUDA GPU; without one,
# fall back to float32 on the CPU, where half-precision inference is often
# unsupported or very slow in PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype)
pipe = pipe.to(device)
# Replace the default scheduler with EulerDiscreteScheduler, built from the current
# scheduler's config. This is independent of the step callback, which is passed
# to the pipeline call itself (callback_on_step_end).
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# PIL images appended by save_intermediate_images during generation, in capture order.
intermediate_images = []
def _decode_latents_to_pil(pipeline, latents):
    """Decode *latents* with the pipeline's VAE and return the first image as PIL.

    The latents are divided by ``pipeline.vae.config.scaling_factor`` before being
    passed to ``vae.decode``; the decoder output is converted with the pipeline's
    own ``image_processor``.
    """
    with torch.no_grad():
        decoded = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor).sample
    return pipeline.image_processor.postprocess(decoded, output_type="pil")[0]


def _latents_to_pil(latents):
    """Render the first sample's raw latents (no VAE decode) as an RGB PIL image.

    Assumes ``latents`` is laid out as (batch, channels, height, width) with at
    least three channels -- TODO confirm for models other than SD 1.x. Each of the
    first three channels is min-max scaled to [0, 1] on its own. Passing the raw
    latents through ``image_processor.postprocess`` instead would clamp values
    outside [-1, 1] (its denormalize step) and, with four channels, produce an
    RGBA image whose alpha mask is the fourth latent channel.
    """
    channels = latents[0, :3].float()  # upcast: latents may be float16
    low = channels.amin(dim=(1, 2), keepdim=True)
    high = channels.amax(dim=(1, 2), keepdim=True)
    # clamp_min guards against division by zero when a channel is constant.
    scaled = (channels - low) / (high - low).clamp_min(1e-8)
    pixels = (scaled * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy()
    return Image.fromarray(pixels)


def save_intermediate_images(pipeline, step, timestep, extra_inputs):
    """``callback_on_step_end`` hook that snapshots the first three denoising steps.

    Diffusers calls this at the end of every denoising step. Images are appended
    to the module-level ``intermediate_images`` list:

    * step 0 -- VAE-decoded latents after the first denoising step (still mostly noise);
    * step 1 -- the raw latents themselves, rendered without the VAE;
    * step 2 -- VAE-decoded latents after the third denoising step.

    Later steps are ignored.

    Args:
        pipeline: the running Stable Diffusion pipeline.
        step: zero-based index of the denoising step that just finished.
        timestep: scheduler timestep for this step (unused).
        extra_inputs: callback tensors; ``"latents"`` holds the current latents.

    Returns:
        ``{"latents": latents}`` so the pipeline continues with unmodified latents.
    """
    latents = extra_inputs["latents"]
    if step in (0, 2):
        intermediate_images.append(_decode_latents_to_pil(pipeline, latents))
    elif step == 1:
        intermediate_images.append(_latents_to_pil(latents))
    return {"latents": latents}
# Generate the image with the callback. Clear images left over from any previous
# run so the panels below show only this generation.
intermediate_images.clear()
prompt = "floor map of 2bhk house"
image = pipe(prompt, callback_on_step_end=save_intermediate_images, num_inference_steps=50).images[0]
# Display the phases of generation: each captured intermediate image followed by
# the final output. Titles follow the order in which save_intermediate_images
# appends images (denoising steps 0, 1 and 2); zip() drops titles for images that
# were never captured, e.g. when fewer than three inference steps are run.
phase_titles = ["Random Latent Image", "Latents Before VAE", "After VAE Decode"]
panels = list(zip(intermediate_images, phase_titles)) + [(image, "Final Output")]
num_images = len(panels)
# squeeze=False always yields a 2-D array of axes, even for a single panel.
fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
for ax, (panel, title) in zip(axes[0], panels):
    ax.imshow(panel)
    ax.axis("off")
    ax.set_title(title)
plt.show()
Comments
Post a Comment