加载调度器和模型
[[open-in-colab]]
扩散管道是一系列可互换的调度器和模型的集合,可以混合和匹配以针对特定用例定制管道。调度器封装了整个去噪过程,例如去噪步骤的数量和用于找到去噪样本的算法。调度器不进行参数化或训练,因此占用的内存非常少。模型通常只关注从噪声输入到较少噪声样本的前向传递过程。
本指南将向你展示如何加载调度器和模型以自定义管道。你将在整个指南中使用 stable-diffusion-v1-5/stable-diffusion-v1-5 检查点,因此我们先加载它。
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
你可以通过 pipeline.scheduler
属性查看此管道使用的是哪个调度器。
pipeline.scheduler
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.21.4",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"timestep_spacing": "leading",
"trained_betas": null
}
加载调度器
调度器由一个配置文件定义,该文件可以被多种调度器使用。使用 [SchedulerMixin.from_pretrained
] 方法加载调度器,并指定 subfolder
参数将配置文件加载到管道仓库的正确子文件夹中。
例如,要加载 [DDIMScheduler
]:
from diffusers import DDIMScheduler, DiffusionPipeline
ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
然后你可以将新加载的调度器传递给管道。
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
比较调度器
调度器各有其独特的优缺点,这使得很难定量比较哪种调度器最适合某个管道。你通常需要在去噪速度和去噪质量之间做出权衡。我们建议尝试不同的调度器,以找到最适合你使用场景的调度器。调用 pipeline.scheduler.compatibles
属性,查看与管道兼容的调度器。
让我们比较以下提示和种子下的 [LMSDiscreteScheduler
]、[EulerDiscreteScheduler
]、[EulerAncestralDiscreteScheduler
] 和 [DPMSolverMultistepScheduler
]。
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
generator = torch.Generator(device="cuda").manual_seed(8)
要更改管道调度器,请使用 [~ConfigMixin.from_config
] 方法将不同调度器的 pipeline.scheduler.config
加载到管道中。




大多数图像看起来非常相似,质量也相当。再次强调,这通常取决于你的具体使用场景,因此一个好的方法是运行多个不同的调度器并比较结果。
Flax 调度器
要比较 Flax 调度器,你需要将调度器状态加载到模型参数中。例如,让我们将 [FlaxStableDiffusionPipeline
] 中的默认调度器更改为使用超级快速的 [FlaxDPMSolverMultistepScheduler
]。
WARNING
[FlaxLMSDiscreteScheduler
] 和 [FlaxDDPMScheduler
] 尚未与 [FlaxStableDiffusionPipeline
] 兼容。
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
scheduler=scheduler,
variant="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state
然后你可以利用 Flax 与 TPU 的兼容性来并行生成多个图像。你需要为每个可用设备复制一份模型参数,然后将输入拆分到这些设备上,以生成你所需数量的图像。
# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
模型
模型通过 [ModelMixin.from_pretrained
] 方法加载,该方法会下载并缓存模型的最新权重和配置。如果最新文件已存在于本地缓存中,[~ModelMixin.from_pretrained
] 会重用缓存中的文件,而不会重新下载。
模型可以从子文件夹中加载,使用 subfolder
参数。例如,stable-diffusion-v1-5/stable-diffusion-v1-5 的模型权重存储在 unet 子文件夹中。
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
它们也可以直接从 仓库 加载。
from diffusers import UNet2DModel
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
要加载和保存模型变体,请在 [ModelMixin.from_pretrained
] 和 [ModelMixin.save_pretrained
] 中指定 variant
参数。
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
)
unet.save_pretrained("./local-unet", variant="non_ema")