Skip to content

文本到图像

像 Stable Diffusion 这样的文本到图像模型是根据给定的文本提示生成图像的。

训练模型可能会对你的硬件造成很大负担,但如果你启用 gradient_checkpointingmixed_precision,则可以在单个 24GB GPU 上训练模型。如果你使用更大的批量或希望训练得更快,最好使用内存超过 30GB 的 GPU。你可以通过启用 xFormers 的内存高效注意力机制来减少内存占用。JAX/Flax 训练也支持在 TPUs 和 GPUs 上高效训练,但不支持梯度检查点、梯度累积或 xFormers。建议使用至少 30GB 内存的 GPU 或 TPU v3 进行 Flax 训练。

本指南将探讨 train_text_to_image.py 训练脚本,帮助你熟悉它,并了解如何根据自己的需求进行调整。

在运行脚本之前,请确保从源代码安装库:

bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

然后导航到包含训练脚本的示例文件夹,并安装你所使用脚本所需的依赖项:

初始化一个 🤗 Accelerate 环境:

bash
accelerate config

要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:

bash
accelerate config default

或者,如果你的环境不支持交互式 shell,比如笔记本,你可以使用:

py
from accelerate.utils import write_basic_config

write_basic_config()

最后,如果你想在自己的数据集上训练模型,请参阅创建用于训练的数据集指南,了解如何创建与训练脚本兼容的数据集。

脚本参数

训练脚本提供了许多参数,帮助你自定义训练过程。所有参数及其描述都在 parse_args() 函数中。此函数为每个参数提供了默认值,例如训练批次大小和学习率,但你也可以在训练命令中设置自己的值。

例如,要使用 fp16 格式通过混合精度加速训练,可以在训练命令中添加 --mixed_precision 参数:

bash
accelerate launch train_text_to_image.py \
  --mixed_precision="fp16"

一些基本且重要的参数包括:

  • --pretrained_model_name_or_path:Hub 上的模型名称或预训练模型的本地路径
  • --dataset_name:Hub 上的数据集名称或要训练的数据集的本地路径
  • --image_column:数据集中用于训练的图像列的名称
  • --caption_column:数据集中用于训练的文本列的名称
  • --output_dir:保存训练模型的位置
  • --push_to_hub:是否将训练好的模型推送到 Hub
  • --checkpointing_steps:在模型训练过程中保存检查点的频率;如果训练因某种原因中断,可以通过在训练命令中添加 --resume_from_checkpoint 来从该检查点继续训练

Min-SNR 权重

Min-SNR 权重策略可以通过重新平衡损失来帮助训练,从而实现更快的收敛。训练脚本支持预测 epsilon(噪声)或 v_prediction,但 Min-SNR 与这两种预测类型都兼容。此权重策略仅由 PyTorch 支持,Flax 训练脚本中不可用。

添加 --snr_gamma 参数并将其设置为推荐值 5.0:

bash
accelerate launch train_text_to_image.py \
  --snr_gamma=5.0

你可以在这个 Weights and Biases 报告中比较不同 snr_gamma 值的损失曲面。对于较小的数据集,Min-SNR 的效果可能不如较大数据集那么明显。

训练脚本

数据集预处理代码和训练循环位于 main() 函数中。如果你需要调整训练脚本,这里是你需要进行修改的地方。

train_text_to_image 脚本首先 加载调度器 和分词器。如果你希望使用不同的调度器,可以在这里进行选择:

py
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)

然后脚本加载 UNet 模型:

py
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())

接下来,需要对数据集的文本和图像列进行预处理。tokenize_captions 函数负责对输入进行分词,而 train_transforms 函数则指定了要对图像应用的变换类型。这两个函数都被打包在 preprocess_train 中:

py
def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    examples["input_ids"] = tokenize_captions(examples)
    return examples

最后,训练循环 处理所有其他事情。它将图像编码到潜在空间,向潜在变量添加噪声,计算文本嵌入以进行条件化,更新模型参数,并将模型保存并推送到 Hub。如果你想了解更多关于训练循环的工作原理,可以查看 理解管道、模型和调度器 教程,该教程详细介绍了去噪过程的基本模式。

启动脚本

一旦你完成了所有更改或对默认配置满意,就可以启动训练脚本了!🚀

一旦训练完成,你可以使用新训练的模型进行推理:

下一步

恭喜你训练了自己的文本到图像模型!要了解更多如何使用你的新模型,以下指南可能会对你有所帮助:

  • 如果你使用 LoRA 训练了模型,可以学习如何 加载 LoRA 权重 以进行推理。
  • 文本到图像 任务指南中,了解更多关于如何通过指导比例或提示权重等参数和技术来控制推理。