Skip to content

PixArt-Σ

PixArt-Σ: 从弱到强的扩散Transformer训练用于4K文本到图像生成 由陈俊松、于金成、葛崇健、姚乐伟、谢恩泽、吴越、王忠道、James Kwok、罗平、卢虎川和郑国立共同撰写。

论文的摘要如下:

在本文中,我们介绍了PixArt-Σ,一种能够直接生成4K分辨率图像的扩散Transformer模型(DiT)。PixArt-Σ在其前身PixArt-α的基础上取得了显著的进步,提供了明显更高保真度和与文本提示更好对齐的图像。PixArt-Σ的一个关键特点是其训练效率。利用PixArt-α的基础预训练,它通过融入更高质量的数据,从“较弱”的基线演变为“更强”的模型,我们称之为“从弱到强训练”。PixArt-Σ的进步主要体现在两个方面:(1)高质量训练数据:PixArt-Σ融入了更高质量的图像数据,并配以更精确和详细的图像描述。(2)高效的Token压缩:我们在DiT框架内提出了一种新的注意力模块,该模块压缩了键和值,显著提高了效率并促进了超高分辨率图像的生成。得益于这些改进,PixArt-Σ在模型尺寸(0.6B参数)显著小于现有文本到图像扩散模型(如SDXL的2.6B参数和SD Cascade的5.1B参数)的情况下,实现了卓越的图像质量和用户提示遵循能力。此外,PixArt-Σ生成4K图像的能力支持创建高分辨率海报和壁纸,有效增强了电影和游戏等行业高质量视觉内容的制作。

你可以在PixArt-alpha/PixArt-sigma找到原始代码库,并在PixArt-alpha找到所有可用的检查点。

关于此管道的几点说明:

  • 它使用Transformer主干(而不是UNet)进行去噪。因此,它与DiT具有类似的架构。
  • 它使用从T5计算的文本条件进行训练。这一方面使得管道在遵循复杂文本提示和精细细节方面表现更好。
  • 它在不同宽高比下生成高分辨率图像方面表现出色。为了获得最佳结果,作者推荐了一些尺寸范围,可以在这里找到。
  • 它在质量上与最先进的文本到图像生成系统(截至本文撰写时)如PixArt-α、Stable Diffusion XL、Playground V2.0和DALL-E 3相媲美,同时效率更高。
  • 它展示了生成超高分辨率图像(如2048px甚至4K)的能力。
  • 它表明,通过多种改进(如VAEs、数据集等),文本到图像模型可以从弱模型成长为强模型。

在8GB GPU VRAM下进行推理

通过以8位精度加载文本编码器,在8GB GPU VRAM下运行[PixArtSigmaPipeline]。让我们通过一个完整的示例来了解。

首先,安装bitsandbytes库:

bash
pip install -U bitsandbytes

然后以8位加载文本编码器:

python
from transformers import T5EncoderModel
from diffusers import PixArtSigmaPipeline
import torch

text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced"
)

现在,使用pipe来编码一个提示:

python
with torch.no_grad():
    prompt = "cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

由于已经计算了文本嵌入,从内存中移除text_encoderpipe,并释放一些GPU VRAM:

python
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()

del text_encoder
del pipe
flush()

然后使用提示嵌入作为输入计算潜在变量:

python
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

del pipe.transformer
flush()

一旦计算出潜在变量,将其传递给VAE以解码为真实图像:

python
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")

通过删除未使用的组件并清空 GPU VRAM,你应该能够在 8GB 以下 GPU VRAM 的情况下运行 [PixArtSigmaPipeline]。

如果你想查看内存使用情况的报告,请运行此 脚本

在加载 text_encoder 时,你将 load_in_8bit 设置为 True。你还可以指定 load_in_4bit,以进一步降低内存需求至 7GB 以下。

PixArtSigmaPipeline

[[autodoc]] PixArtSigmaPipeline - all - call