Flux
Flux is a series of text-to-image generation models based on diffusion transformers. To learn more about Flux, check out the original blog post by the creators of Flux, Black Forest Labs.
The original model checkpoints for Flux can be found here. The original inference code can be found here.
Flux comes in two variants:

- Timestep-distilled (black-forest-labs/FLUX.1-schnell)
- Guidance-distilled (black-forest-labs/FLUX.1-dev)
The two checkpoints have slightly different usage, which we detail below.
Timestep-distilled
- `max_sequence_length` cannot be more than 256.
- `guidance_scale` needs to be set to 0.
- As this is a timestep-distilled model, it benefits from fewer sampling steps.
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt=prompt,
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("image.png")
```
Guidance-distilled

- The guidance-distilled variant takes about 50 sampling steps to generate good-quality images.
- It doesn't have any limitations around `max_sequence_length`.
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "a tiny astronaut hatching from an egg on the moon"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
).images[0]
out.save("image.png")
```
Running FP16 inference

Flux can generate high-quality images with FP16 (i.e., to accelerate inference on Turing/Volta GPUs), but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing the text encoders to run in FP32 inference removes this output difference. See here for details.

FP16 inference code:
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once

prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt=prompt,
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("image.png")
```
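
As noted above, forcing the text encoders to run in FP32 removes the FP16 output difference. A minimal sketch of that, assuming the pipeline from the example above and the standard `text_encoder`/`text_encoder_2` attributes of `FluxPipeline`:

```python
# Sketch: cast the pipeline to FP16, then move only the text encoders back to FP32.
# text_encoder is the CLIP encoder; text_encoder_2 is the T5 encoder.
pipe.to(torch.float16)
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)
```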
Single File Loading for the FluxTransformer2DModel

The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetuned or quantized versions of the models that have been published by the community.
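
As a quick illustration before the full example, loading the transformer from a single original-format checkpoint file is a one-liner; the checkpoint URL below points at the official FLUX.1-dev weights and is assumed to be accessible:

```python
import torch
from diffusers import FluxTransformer2DModel

# Load the transformer directly from an original-format .safetensors file.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors",
    torch_dtype=torch.bfloat16,
)
```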
The following example demonstrates how to run Flux with less than 16GB of VRAM.
First install `optimum-quanto`:
```shell
pip install optimum-quanto
```
Then run the following example:
```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=3.5,
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fp8-dev.png")
```
FluxPipeline

[[autodoc]] FluxPipeline
  - all
  - __call__

FluxImg2ImgPipeline

[[autodoc]] FluxImg2ImgPipeline
  - all
  - __call__

FluxInpaintPipeline

[[autodoc]] FluxInpaintPipeline
  - all
  - __call__

FluxControlNetInpaintPipeline

[[autodoc]] FluxControlNetInpaintPipeline
  - all
  - __call__

FluxControlNetImg2ImgPipeline

[[autodoc]] FluxControlNetImg2ImgPipeline
  - all
  - __call__