文本反转
文本反转 是一种训练技术,用于通过少量示例图像来个性化图像生成模型,使其学习你希望它学习的内容。该技术通过学习和更新文本嵌入(新的嵌入与你在提示中必须使用的特殊单词绑定)来匹配你提供的示例图像。
如果你在具有有限 vRAM 的 GPU 上进行训练,建议在训练命令中启用 gradient_checkpointing
和 mixed_precision
参数。你还可以通过使用 xFormers 的内存高效注意力机制来减少内存占用。JAX/Flax 训练也支持在 TPUs 和 GPUs 上高效训练,但不支持梯度检查点或 xFormers。在与 PyTorch 相同的配置和设置下,Flax 训练脚本的速度应至少快 ~70%!
本指南将探讨 textual_inversion.py 脚本,帮助你更熟悉它,并了解如何根据自己的需求进行调整。
在运行脚本之前,请确保从源代码安装库:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
导航到包含训练脚本的示例文件夹,并安装你所使用脚本所需的依赖项:
初始化一个 🤗 Accelerate 环境:
accelerate config
要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:
accelerate config default
或者,如果你的环境不支持交互式 shell,比如笔记本,你可以使用:
from accelerate.utils import write_basic_config
write_basic_config()
最后,如果你想在自己的数据集上训练模型,请参阅创建用于训练的数据集指南,了解如何创建与训练脚本兼容的数据集。
脚本参数
训练脚本有许多参数,可帮助你根据需要定制训练过程。所有参数及其描述都列在 parse_args()
函数中。在适用的情况下,Diffusers 为每个参数提供了默认值,例如训练批次大小和学习率,但如果你愿意,可以在训练命令中更改这些值。
例如,要将梯度累积步数增加到默认值 1 以上:
accelerate launch textual_inversion.py \
--gradient_accumulation_steps=4
一些其他基本且重要的参数包括:
--pretrained_model_name_or_path
: 模型在 Hub 上的名称或预训练模型的本地路径--train_data_dir
: 包含训练数据集(示例图像)的文件夹路径--output_dir
: 保存训练后模型的位置--push_to_hub
: 是否将训练后的模型推送到 Hub--checkpointing_steps
: 在模型训练过程中保存检查点的频率;如果训练因某种原因中断,可以通过在训练命令中添加--resume_from_checkpoint
从该检查点继续训练--num_vectors
: 用于学习嵌入的向量数量;增加此参数有助于模型更好地学习,但会增加训练成本--placeholder_token
: 用于绑定学习到的嵌入的特殊词(在推理时必须在提示中使用该词)--initializer_token
: 一个大致描述你要训练的对象或风格的单个词--learnable_property
: 你是在训练模型学习新的“风格”(例如,梵高的绘画风格)还是“对象”(例如,你的狗)
训练脚本
与一些其他训练脚本不同,textual_inversion.py
有一个自定义的数据集类,TextualInversionDataset
,用于创建数据集。你可以自定义图像大小、占位符词、插值方法、是否裁剪图像等。如果需要更改数据集的创建方式,可以修改 TextualInversionDataset
。
接下来,你可以在 main()
函数中找到数据集预处理代码和训练循环。
# Load tokenizer
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
特殊的 占位符标记 被添加到分词器中,并且嵌入被重新调整以适应新的标记。
然后,脚本 创建一个数据集 从 TextualInversionDataset
:
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
tokenizer=tokenizer,
size=args.resolution,
placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
repeats=args.repeats,
learnable_property=args.learnable_property,
center_crop=args.center_crop,
set="train",
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)
最后,训练循环 负责处理从预测噪声残差到更新特殊占位符标记的嵌入权重的所有其他内容。
如果你想了解更多关于训练循环的工作原理,可以查看 理解管道、模型和调度器 教程,该教程详细介绍了去噪过程的基本模式。
启动脚本
一旦你完成了所有更改或对默认配置满意,就可以启动训练脚本了!🚀
在本指南中,你将下载一些 猫玩具 的图片并将其存储在一个目录中。但请记住,你也可以创建并使用自己的数据集(参见 为训练创建数据集 指南)。
from huggingface_hub import snapshot_download
local_dir = "./cat"
snapshot_download(
"diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)
将环境变量 MODEL_NAME
设置为 Hub 上的模型 ID 或本地模型的路径,并将 DATA_DIR
设置为你刚刚下载猫图片的路径。脚本会将以下文件创建并保存到你的仓库中:
learned_embeds.bin
:与你的示例图片相对应的已学习嵌入向量token_identifier.txt
:特殊占位符令牌type_of_concept.txt
:你正在训练的概念类型(“object”或“style”)
在你启动脚本之前还有一件事。如果你对跟踪训练过程感兴趣,可以定期保存生成的图像。在训练命令中添加以下参数:
--validation_prompt="A <cat-toy> train"
--num_validation_images=4
--validation_steps=100
训练完成后,你可以像这样使用你新训练的模型进行推理:
下一步
恭喜你训练了自己的文本反转模型!🎉 要了解更多关于如何使用你的新模型,以下指南可能会有所帮助: