PixArt-Σ
PixArt-Σ: 从弱到强的扩散Transformer训练用于4K文本到图像生成 由陈俊松、于金成、葛崇健、姚乐伟、谢恩泽、吴越、王忠道、James Kwok、罗平、卢虎川和郑国立共同撰写。
论文的摘要如下:
在本文中,我们介绍了PixArt-Σ,一种能够直接生成4K分辨率图像的扩散Transformer模型(DiT)。PixArt-Σ在其前身PixArt-α的基础上取得了显著的进步,提供了明显更高保真度和与文本提示更好对齐的图像。PixArt-Σ的一个关键特点是其训练效率。利用PixArt-α的基础预训练,它通过融入更高质量的数据,从“较弱”的基线演变为“更强”的模型,我们称之为“从弱到强训练”。PixArt-Σ的进步主要体现在两个方面:(1)高质量训练数据:PixArt-Σ融入了更高质量的图像数据,并配以更精确和详细的图像描述。(2)高效的Token压缩:我们在DiT框架内提出了一种新的注意力模块,该模块压缩了键和值,显著提高了效率并促进了超高分辨率图像的生成。得益于这些改进,PixArt-Σ在模型尺寸(0.6B参数)显著小于现有文本到图像扩散模型(如SDXL的2.6B参数和SD Cascade的5.1B参数)的情况下,实现了卓越的图像质量和用户提示遵循能力。此外,PixArt-Σ生成4K图像的能力支持创建高分辨率海报和壁纸,有效增强了电影和游戏等行业高质量视觉内容的制作。
你可以在PixArt-alpha/PixArt-sigma找到原始代码库,并在PixArt-alpha找到所有可用的检查点。
关于此管道的几点说明:
- 它使用Transformer主干(而不是UNet)进行去噪。因此,它与DiT具有类似的架构。
- 它使用从T5计算的文本条件进行训练。这一方面使得管道在遵循复杂文本提示和精细细节方面表现更好。
- 它在不同宽高比下生成高分辨率图像方面表现出色。为了获得最佳结果,作者推荐了一些尺寸范围,可以在这里找到。
- 它在质量上与最先进的文本到图像生成系统(截至本文撰写时)如PixArt-α、Stable Diffusion XL、Playground V2.0和DALL-E 3相媲美,同时效率更高。
- 它展示了生成超高分辨率图像(如2048px甚至4K)的能力。
- 它表明,通过多种改进(如VAEs、数据集等),文本到图像模型可以从弱模型成长为强模型。
在8GB GPU VRAM下进行推理
通过以8位精度加载文本编码器,在8GB GPU VRAM下运行[PixArtSigmaPipeline
]。让我们通过一个完整的示例来了解。
首先,安装bitsandbytes库:
pip install -U bitsandbytes
然后以8位加载文本编码器:
from transformers import T5EncoderModel
from diffusers import PixArtSigmaPipeline
import torch
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
)
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
text_encoder=text_encoder,
transformer=None,
device_map="balanced"
)
现在,使用pipe
来编码一个提示:
with torch.no_grad():
prompt = "cute cat"
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
由于已经计算了文本嵌入,从内存中移除text_encoder
和pipe
,并释放一些GPU VRAM:
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
del text_encoder
del pipe
flush()
然后使用提示嵌入作为输入计算潜在变量:
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
text_encoder=None,
torch_dtype=torch.float16,
).to("cuda")
latents = pipe(
negative_prompt=None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_images_per_prompt=1,
output_type="latent",
).images
del pipe.transformer
flush()
一旦计算出潜在变量,将其传递给VAE以解码为真实图像:
with torch.no_grad():
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
通过删除未使用的组件并清空 GPU VRAM,你应该能够在 8GB 以下 GPU VRAM 的情况下运行 [PixArtSigmaPipeline
]。
如果你想查看内存使用情况的报告,请运行此 脚本。
在加载 text_encoder
时,你将 load_in_8bit
设置为 True
。你还可以指定 load_in_4bit
,以进一步降低内存需求至 7GB 以下。
PixArtSigmaPipeline
[[autodoc]] PixArtSigmaPipeline - all - call