# Flux

Flux is a series of text-to-image generation models based on diffusion transformers. To learn more about Flux, check out the original blog post by the creators of Flux, Black Forest Labs.

The original model checkpoints for Flux can be found here. The original inference code can be found here.
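
If you want the checkpoints on disk before building a pipeline, you can pre-download the repository with huggingface_hub. A minimal sketch (this explicit download step is an illustration, not part of the original docs):

```python
from huggingface_hub import snapshot_download

# Downloads all files in the repo (weights + configs) into the local HF cache
# and returns the local path; from_pretrained will reuse the cached files.
local_path = snapshot_download("black-forest-labs/FLUX.1-schnell")
print(local_path)
```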

Flux comes in two variants:

- Timestep-distilled (black-forest-labs/FLUX.1-schnell)
- Guidance-distilled (black-forest-labs/FLUX.1-dev)

Both checkpoints have slightly different usage, which we detail below.

## Timestep-distilled

- max_sequence_length cannot be more than 256.
- guidance_scale needs to be 0.
- As this is a timestep-distilled model, it benefits from fewer sampling steps.
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt=prompt,
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("image.png")
```

## Guidance-distilled

- The guidance-distilled variant takes about 50 sampling steps to generate good-quality images.
- It doesn't have any limitations around max_sequence_length.
```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "a tiny astronaut hatching from an egg on the moon"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
).images[0]
out.save("image.png")
```

## Running FP16 inference

Flux can generate high-quality images in FP16 (i.e., to accelerate inference on Turing/Volta GPUs), but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing the text encoders to run with FP32 inference removes this output difference. See here for details.

FP16 inference code:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()  # decode the batch one image at a time to lower peak VRAM
pipe.vae.enable_tiling()   # decode each image in tiles, useful at large resolutions

pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once

prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt=prompt,
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("image.png")
```
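
If the FP16 output drift matters for your use case, the note above suggests running the text encoders in FP32. A minimal sketch of that idea, assuming you encode the prompt in FP32 first and then cast the embeddings before denoising (this exact recipe is an illustration, not an official one):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
pipe.text_encoder.to(torch.float32)    # keep the CLIP encoder in FP32
pipe.text_encoder_2.to(torch.float32)  # keep the T5 encoder in FP32
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
# Encode in FP32 so no activations are clipped, then hand FP16 embeddings to the transformer.
prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt=prompt, prompt_2=prompt, max_sequence_length=256
)
out = pipe(
    prompt_embeds=prompt_embeds.to(torch.float16),
    pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
).images[0]
out.save("image.png")
```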

## Single File Loading for the FluxTransformer2DModel

The FluxTransformer2DModel supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.

The following example demonstrates how to run Flux with less than 16GB of VRAM.

First install optimum-quanto:

```shell
pip install optimum-quanto
```

Then run the following example:

```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)  # quantize the transformer weights to 8-bit floats
freeze(transformer)  # make the quantization permanent by replacing the float weights

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Load the pipeline without the two components quantized above, then attach them
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=3.5,
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]

image.save("flux-fp8-dev.png")
```
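
As a design note, quantizing the transformer and the T5 text encoder to qfloat8 roughly halves their weight memory relative to bf16, which is what brings the pipeline under the 16GB VRAM budget mentioned above; freeze() then replaces the float weights with their quantized versions so the savings persist.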

## FluxPipeline

[[autodoc]] FluxPipeline
    - all
    - __call__

## FluxImg2ImgPipeline

[[autodoc]] FluxImg2ImgPipeline
    - all
    - __call__

## FluxInpaintPipeline

[[autodoc]] FluxInpaintPipeline
    - all
    - __call__

## FluxControlNetInpaintPipeline

[[autodoc]] FluxControlNetInpaintPipeline
    - all
    - __call__

## FluxControlNetImg2ImgPipeline

[[autodoc]] FluxControlNetImg2ImgPipeline
    - all
    - __call__