PixArt-α
PixArt-α: 快速训练扩散Transformer用于逼真文本到图像合成 由Junsong Chen、Jincheng Yu、Chongjian Ge、Lewei Yao、Enze Xie、Yue Wu、Zhongdao Wang、James Kwok、Ping Luo、Huchuan Lu和Zhenguo Li撰写。
论文的摘要如下:
最先进的文本到图像(T2I)模型需要显著的训练成本(例如,数百万GPU小时),严重阻碍了AIGC社区的基础创新,同时增加了二氧化碳排放。本文介绍了PIXART-α,一种基于Transformer的T2I扩散模型,其图像生成质量与最先进的图像生成器(例如,Imagen、SDXL,甚至Midjourney)相当,达到了接近商业应用的标准。此外,它支持以低训练成本合成高达1024px分辨率的高分辨率图像,如图1和图2所示。为了实现这一目标,提出了三个核心设计:(1)训练策略分解:我们设计了三个不同的训练步骤,分别优化像素依赖性、文本-图像对齐和图像美学质量;(2)高效的T2I Transformer:我们将交叉注意力模块引入扩散Transformer(DiT),以注入文本条件并简化计算密集型的类别条件分支;(3)高信息量数据:我们强调文本-图像对中概念密度的重要性,并利用大型视觉-语言模型自动标注密集伪标签,以辅助文本-图像对齐学习。因此,PIXART-α的训练速度显著超过现有的大规模T2I模型,例如,PIXART-α仅占用Stable Diffusion v1.5训练时间的10.8%(675 vs. 6,250 A100 GPU天),节省了近30万美元(26,000 vs. 320,000美元)并减少了90%的二氧化碳排放。此外,与更大的SOTA模型RAPHAEL相比,我们的训练成本仅为1%。广泛的实验表明,PIXART-α在图像质量、艺术性和语义控制方面表现出色。我们希望PIXART-α能为AIGC社区和初创公司提供新的见解,以加速构建他们自己的高质量且低成本的生成模型。
你可以在PixArt-alpha/PixArt-alpha找到原始代码库,并在PixArt-alpha找到所有可用的检查点。
关于此管道的几点说明:
- 它使用Transformer主干(而不是UNet)进行去噪。因此,它与DiT具有类似的架构。
- 它使用从T5计算的文本条件进行训练。这一方面使得管道在遵循复杂文本提示和精细细节方面表现更好。
- 它擅长生成不同宽高比的高分辨率图像。为了获得最佳结果,作者推荐了一些尺寸范围,可以在这里找到。
- 它在质量上与最先进的文本到图像生成系统(截至撰写本文时)如Stable Diffusion XL、Imagen和DALL-E 2相媲美,同时比它们更高效。
在8GB GPU VRAM下进行推理
通过以8位精度加载文本编码器,在8GB GPU VRAM下运行[PixArtAlphaPipeline
]。让我们通过一个完整的示例来了解。
首先,安装bitsandbytes库:
pip install -U bitsandbytes
然后以8位加载文本编码器:
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline
import torch
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
)
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=text_encoder,
transformer=None,
device_map="auto"
)
现在,使用pipe
来编码一个提示:
with torch.no_grad():
prompt = "cute cat"
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
由于已经计算了文本嵌入,从内存中移除text_encoder
和pipe
,并释放一些GPU VRAM:
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
del text_encoder
del pipe
flush()
然后使用提示嵌入作为输入计算潜在变量:
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=None,
torch_dtype=torch.float16,
).to("cuda")
latents = pipe(
negative_prompt=None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_images_per_prompt=1,
output_type="latent",
).images
del pipe.transformer
flush()
一旦计算出潜在变量,将其传递给VAE以解码为真实图像:
with torch.no_grad():
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
通过删除未使用的组件并清空 GPU VRAM,你应该能够在 8GB 以下 GPU VRAM 的情况下运行 [PixArtAlphaPipeline
]。
如果你想查看内存使用情况的报告,请运行此 脚本。
在加载 text_encoder
时,你将 load_in_8bit
设置为 True
。你还可以指定 load_in_4bit
,以进一步降低内存需求至 7GB 以下。
PixArtAlphaPipeline
[[autodoc]] PixArtAlphaPipeline - all - call