减少内存使用
使用扩散模型的一个障碍是需要大量的内存。为了克服这一挑战,你可以使用一些减少内存的技术,即使是在免费层或消费级 GPU 上运行一些最大的模型。其中一些技术甚至可以结合使用,以进一步减少内存使用。
以下结果来自使用 50 个 DDIM 步从提示“火星上骑马的宇航员的照片”生成单个 512x512 图像,在 Nvidia Titan RTX 上进行,展示了由于减少内存消耗而可以预期的加速。
latency | speed-up | |
---|---|---|
original | 9.50s | x1 |
fp16 | 3.61s | x2.63 |
channels last | 3.30s | x2.88 |
traced UNet | 3.21s | x2.96 |
memory-efficient attention | 2.63s | x3.61 |
切片 VAE
切片 VAE 允许通过一次解码一个图像的潜在变量批次来解码大量图像,而无需大量 VRAM 或 32 张或更多图像的批次。如果你安装了 xFormers,你可能希望将它与 [~ModelMixin.enable_xformers_memory_efficient_attention
] 结合使用,以进一步减少内存使用。
要使用切片 VAE,请在推理之前在你的管道上调用 [~StableDiffusionPipeline.enable_vae_slicing
]:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
#pipe.enable_xformers_memory_efficient_attention()
images = pipe([prompt] * 32).images
你可能会在多图像批次中看到 VAE 解码的性能略有提升,并且对单图像批次没有性能影响。
平铺 VAE
平铺 VAE 处理还允许在有限的 VRAM 上处理大型图像(例如,在 8GB VRAM 上生成 4k 图像),方法是将图像分割成重叠的平铺,解码平铺,然后将输出混合在一起以合成最终图像。如果你安装了 xFormers,你还可以将平铺 VAE 与 [~ModelMixin.enable_xformers_memory_efficient_attention
] 一起使用,以进一步减少内存使用。
要使用平铺 VAE 处理,请在推理之前在你的管道上调用 [~StableDiffusionPipeline.enable_vae_tiling
]:
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
#pipe.enable_xformers_memory_efficient_attention()
image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
输出图像可能存在一些瓷砖之间的色调变化,因为瓷砖是单独解码的,但你应该看不到瓷砖之间任何明显和尖锐的接缝。对于 512x512 或更小的图像,瓷砖功能将被关闭。
CPU卸载
将权重卸载到 CPU,并在执行前向传递时仅将它们加载到 GPU 上,也可以节省内存。通常,这种技术可以将内存消耗降低到 3GB 以下。
要执行 CPU 卸载,请调用 [~StableDiffusionPipeline.enable_sequential_cpu_offload
]:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
image = pipe(prompt).images[0]
CPU 卸载
CPU 卸载作用于子模块而不是整个模型。这是最大程度减少内存消耗的最佳方式,但由于扩散过程的迭代性质,推理速度要慢得多。管道中的 UNet 组件运行多次(最多 num_inference_steps
次);每次,不同的 UNet 子模块都会根据需要依次加载和卸载,导致大量内存传输。
模型卸载
顺序 CPU 卸载 保留了大量内存,但它使推理速度变慢,因为子模块根据需要被移至 GPU,并且在新的模块运行时立即返回到 CPU。
全模型卸载是一种替代方案,它将整个模型移至 GPU,而不是处理每个模型的组成 子模块。与将管道移至 cuda
相比,对推理时间的影响可以忽略不计,并且它仍然可以节省一些内存。
在模型卸载期间,管道的主要组件(通常是文本编码器、UNet 和 VAE)中只有一个被放置在 GPU 上,而其他组件则在 CPU 上等待。像 UNet 这样的运行多次迭代的组件会一直保留在 GPU 上,直到不再需要它们。
通过在管道上调用 [~StableDiffusionPipeline.enable_model_cpu_offload
] 来启用模型卸载:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]
通道最后内存格式
通道最后内存格式是另一种在内存中对 NCHW 张量进行排序的方式,以保留维度排序。通道最后张量以通道成为最密集维度(逐像素存储图像)的方式进行排序。由于并非所有运算符目前都支持通道最后格式,因此可能会导致性能下降,但你仍然应该尝试看看它是否适用于你的模型。
例如,要将管道的 UNet 设置为使用通道最后格式:
print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last) # in-place operation
print(
pipe.unet.conv_out.state_dict()["weight"].stride()
) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
追踪
追踪会将一个示例输入张量运行通过模型,并捕获该输入在模型各层中的传递过程中执行的操作。返回的可执行文件或 ScriptFunction
会通过即时编译进行优化。
要追踪一个 UNet:
import time
import torch
from diffusers import StableDiffusionPipeline
import functools
# torch disable grad
torch.set_grad_enabled(False)
# set variables
n_experiments = 2
unet_runs_per_experiment = 50
# load inputs
def generate_inputs():
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last) # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
# warmup
for _ in range(3):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")
# warmup and optimize graph
for _ in range(5):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
# benchmarking
with torch.inference_mode():
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
print(f"unet inference took {time.time() - start_time:.2f} seconds")
# save the model
unet_traced.save("unet_traced.pt")
将管道中的 unet
属性替换为追踪后的模型:
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass
@dataclass
class UNet2DConditionOutput:
sample: torch.Tensor
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")
# del pipe.unet
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.config.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
内存高效注意力机制
最近关于优化注意力块带宽的工作带来了巨大的加速和 GPU 内存使用量的减少。 最新的内存高效注意力机制是 Flash Attention(你可以在 HazyResearch/flash-attention 查看原始代码)。
要使用 Flash Attention,请安装以下内容:
- PyTorch > 1.12
- 可用的 CUDA
- xFormers
然后在管道上调用 [~ModelMixin.enable_xformers_memory_efficient_attention
]:
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
with torch.inference_mode():
sample = pipe("a small cat")
# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
使用 xformers
时的迭代速度应与 PyTorch 2.0 的迭代速度一致,如 这里 所述。