Stable Diffusion image generation: visualizing intermediate phases

import matplotlib.pyplot as plt
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from PIL import Image

# Load the pipeline. Half precision is kept for CUDA; on a machine without a
# GPU fall back to CPU with float32 instead of crashing on .to("cuda")
# (half-precision CPU inference is slow and not supported by every op in
# older PyTorch releases).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype)
pipe = pipe.to(device)

# Swap in the Euler discrete scheduler, built from the loaded scheduler's config.
# This is a sampler choice only; the step-end callback used below is invoked by
# the pipeline and does not require this swap.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

# PIL snapshots appended by the step-end callback save_intermediate_images
# while the pipeline runs.
intermediate_images = []

def _decode_latents_to_pil(pipeline, latents):
    """Decode latents with the pipeline's VAE and return the first image as PIL.

    The latents are divided by ``pipeline.vae.config.scaling_factor`` before
    ``vae.decode``, then converted with ``pipeline.image_processor.postprocess``.
    """
    with torch.no_grad():
        decoded = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor).sample
        return pipeline.image_processor.postprocess(decoded, output_type="pil")[0]


def save_intermediate_images(pipeline, step, timestep, extra_inputs):
    """
    Step-end callback that records snapshots of the denoising process.

    Registered as ``callback_on_step_end``, so it runs at the end of each
    denoising step with the current latents in ``extra_inputs["latents"]``.
    Snapshots are appended to the module-level ``intermediate_images`` list:

    - step 0: latents after the first denoising step, decoded by the VAE
      (still mostly noise; the untouched initial noise is not visible from a
      step-end hook);
    - step 1: the latents themselves, rendered as an image without the VAE;
    - step 2: latents after the third denoising step, decoded by the VAE.

    Args:
        pipeline: The running pipeline; its ``vae`` and ``image_processor`` are used.
        step: Zero-based index of the step that just finished.
        timestep: Scheduler timestep of that step (unused).
        extra_inputs: Tensors the pipeline passes to the callback; only
            ``"latents"`` is read.

    Returns:
        ``{"latents": latents}`` unchanged, so the denoising loop is not altered.
    """
    latents = extra_inputs["latents"]

    if step in (0, 2):
        intermediate_images.append(_decode_latents_to_pil(pipeline, latents))
    elif step == 1:
        # Visualize latent space directly (no VAE). Stable Diffusion v1 latents are
        # expected to have 4 channels (per the model config, not visible here);
        # postprocess would turn a 4-channel tensor into an RGBA image whose alpha
        # is a latent channel. Keep the first three channels as RGB; the slice is
        # a no-op when there are 3 or fewer channels.
        with torch.no_grad():
            rgb_latents = latents[:, :3]
            processed_image = pipeline.image_processor.postprocess(rgb_latents, output_type="pil")
            intermediate_images.append(processed_image[0])

    return {"latents": latents}

# Run text-to-image generation for 50 denoising steps; the step-end hook
# fills intermediate_images as a side effect.
prompt = "floor map of 2bhk house"
result = pipe(
    prompt,
    num_inference_steps=50,
    callback_on_step_end=save_intermediate_images,
)
image = result.images[0]

# Display the phases of generation: each captured snapshot, then the final output.
# Titles are paired with snapshots in capture order; zip() stops at the shorter
# sequence, so a run that captured fewer snapshots (e.g. fewer than 3 inference
# steps) still plots instead of raising IndexError.
# NOTE(review): the first snapshot is taken at the end of step 0, so it shows the
# decoded latents after one denoising step, not the untouched initial noise.
snapshot_titles = ["Random Latent Image", "Latents Before VAE", "After VAE Decode"]
panels = list(zip(intermediate_images, snapshot_titles))
panels.append((image, "Final Output"))
num_images = len(panels)  # intermediate snapshots + final output

# squeeze=False always returns a 2-D array of Axes, so a single panel needs no
# special-casing.
fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
for ax, (picture, title) in zip(axes[0], panels):
    ax.imshow(picture)
    ax.axis("off")
    ax.set_title(title)

plt.show()

Comments

Popular posts from this blog

CSS-position property

maintext/ react

randomly changing color of tiles