FLUX

FLUX implementation in MLX. The implementation is ported directly from https://github.com/black-forest-labs/flux and the model weights are downloaded directly from the Hugging Face Hub.

The goal of this example is to be clean, educational and to allow for experimentation with finetuning FLUX models as well as adding extra functionality such as in-/outpainting, guidance with custom losses etc.

MLX image
Image generated using FLUX-dev in MLX and the prompt 'An image in the style of tron emanating futuristic technology with the word "MLX" in the center with capital red letters.'

Installation

The dependencies are minimal, namely:

  • huggingface-hub to download the checkpoints
  • regex for the tokenization
  • tqdm, PIL, and numpy for the scripts
  • sentencepiece for the T5 tokenizer
  • datasets for using an HF dataset directly

You can install all of the above with the requirements.txt as follows:

pip install -r requirements.txt

Usage

Use the following command to generate an image. Pass --output to choose where the image is saved; it defaults to out.png.

python txt2image.py --model schnell \
    --n-images 1 \
    --image-size 256x512 \
    --verbose \
    'A photo of an astronaut riding a horse on Mars.'

For the full list of parameters, run the script with --help:

python txt2image.py --help

Inference

Inference in this example is similar to the stable diffusion example. The class to get you started is FluxPipeline from the flux module.

import mlx.core as mx
from flux import FluxPipeline

# This will download all the weights from HF hub
flux = FluxPipeline("flux-schnell")

# Make a generator that returns the latent variables from the reverse diffusion
# process
latent_generator = flux.generate_latents(
    "A photo of an astronaut riding a horse on Mars",
    num_steps=4,
    latent_size=(32, 64),  # 256x512 image
)

# The first return value of the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
    x_T,                # The initial noise
    x_positions,        # The integer positions used for image positional encoding
    t5_conditioning,    # The T5 features from the text prompt
    t5_positions,       # Integer positions for text (normally all 0s)
    clip_conditioning,  # The clip text features from the text prompt
) = conditioning

# Returning the conditioning as the first output from the generator allows us
# to unload T5 and clip before running the diffusion transformer.
mx.eval(conditioning)

# Evaluate each diffusion step
for x_t in latent_generator:
    mx.eval(x_t)

# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))

The above is essentially the implementation of the txt2image.py script, except for some additional logic to quantize the model and/or load trained adapters. One can use the script as follows:

python txt2image.py \
    --n-images 4 \
    --n-rows 2 \
    --image-size 256x512 \
    'A photo of an astronaut riding a horse on Mars.'
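
The inference snippet pairs latent_size=(32, 64) with a 256x512 image, which implies that each latent dimension is the corresponding image dimension divided by 8. A minimal sketch of that conversion follows. The helper name latent_size_from_image_size is hypothetical and not part of the flux module, and the factor of 8 is inferred from that single example rather than read from the code.

```python
def latent_size_from_image_size(image_size: str, factor: int = 8) -> tuple[int, int]:
    """Convert a string such as '256x512' into a latent size tuple.

    `factor` is an assumption inferred from the README example, where a
    256x512 image corresponds to latent_size=(32, 64).
    """
    first, second = (int(v) for v in image_size.lower().split("x"))
    if first % factor or second % factor:
        raise ValueError(f"both dimensions must be multiples of {factor}")
    return first // factor, second // factor


print(latent_size_from_image_size("256x512"))  # (32, 64), matching the snippet above
```

The resulting tuple is what one would pass as latent_size to both generate_latents and decode, so that the two calls agree.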

Experimental Options

FLUX pads the prompt to a specific size of 512 tokens for the dev model and 256 for the schnell model. Not applying padding results in faster generation but it is not clear how it may affect the generated images. To enable that option in this example pass --no-t5-padding to the txt2image.py script or instantiate the pipeline with FluxPipeline("flux-schnell", t5_padding=False).
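
To make the padding trade-off concrete, here is a minimal sketch of fixed-length padding applied to a list of token ids. The function pad_tokens, the pad id of 0, the truncation step and the token ids are illustrative assumptions, not the tokenizer code used by this example. Only the target lengths of 512 and 256 come from the text above.

```python
# Target prompt lengths stated above for each model.
MAX_T5_TOKENS = {"dev": 512, "schnell": 256}


def pad_tokens(token_ids, model="schnell", pad_id=0, padding=True):
    """Truncate to the model's limit and optionally right-pad with pad_id.

    Hypothetical helper: pad_id=0 and the truncation are assumptions made
    for illustration only.
    """
    limit = MAX_T5_TOKENS[model]
    token_ids = list(token_ids[:limit])
    if padding:
        token_ids += [pad_id] * (limit - len(token_ids))
    return token_ids


prompt_ids = [71, 1202, 13, 46, 26535]  # made-up ids for a short prompt
print(len(pad_tokens(prompt_ids, "schnell")))                 # 256
print(len(pad_tokens(prompt_ids, "dev")))                     # 512
print(len(pad_tokens(prompt_ids, "schnell", padding=False)))  # 5
```

With padding disabled the short prompt stays at 5 tokens, so later layers have fewer text positions to process, which is consistent with the faster generation noted above.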

Finetuning

The dreambooth.py script supports LoRA finetuning of FLUX-dev (and schnell but ymmv) on a provided image dataset. The dataset folder must have a train.jsonl file with the following format:

{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
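
Before starting a long training run it can help to check that a dataset folder matches this format. The sketch below is a hypothetical standalone checker, not part of dreambooth.py: it verifies that every non-empty line of train.jsonl is a JSON object with "image" and "prompt" keys and that each image path exists relative to the folder.

```python
import json
from pathlib import Path


def check_dataset(folder):
    """Return a list of problems found in folder/train.jsonl (empty if none)."""
    folder = Path(folder)
    problems = []
    with open(folder / "train.jsonl") as f:
        for n, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as e:
                problems.append(f"line {n}: invalid JSON ({e.msg})")
                continue
            if not isinstance(entry, dict):
                problems.append(f"line {n}: expected a JSON object")
                continue
            for key in ("image", "prompt"):
                if key not in entry:
                    problems.append(f"line {n}: missing '{key}'")
            # Image paths are interpreted relative to the dataset folder.
            if "image" in entry and not (folder / entry["image"]).is_file():
                problems.append(f"line {n}: image not found: {entry['image']}")
    return problems
```

Calling check_dataset("path/to/dataset") returns an empty list when the folder is consistent with the format described above.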

The training script by default trains for 600 iterations with a batch size of 1, gradient accumulation of 4 and LoRA rank of 8. Run python dreambooth.py --help for the list of hyperparameters you can tune.
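
For intuition about the LoRA rank hyperparameter: in the standard LoRA formulation, a rank-r adapter on a linear layer with d_in inputs and d_out outputs adds two small matrices with r * (d_in + d_out) trainable parameters in total, while the d_in * d_out original weights stay frozen. The sketch below does that arithmetic. The 3072x3072 layer shape is purely illustrative and is not read from the model, and lora_param_count is a hypothetical helper.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by a rank-`rank` LoRA adapter on one linear layer."""
    return rank * (d_in + d_out)


d_in = d_out = 3072  # illustrative layer shape, an assumption
frozen = d_in * d_out
for rank in (4, 8):
    trainable = lora_param_count(d_in, d_out, rank)
    print(f"rank {rank}: {trainable} trainable vs {frozen} frozen ({trainable / frozen:.2%})")
```

Doubling the rank doubles the number of trainable adapter parameters per layer, which is the main trade-off behind the default rank of 8.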

Note

FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and should reduce this number significantly.

Training Example

This is a step-by-step finetuning example. We will be using the data from https://github.com/google/dreambooth. In particular, we will use dog6 which is a popular example for showcasing dreambooth [1].

The training images are the following 5 images [2]:

dog6

We start by making the following train.jsonl file and placing it in the same folder as the images.

{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
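
For larger folders the same file can be generated instead of written by hand. The sketch below is one possible way to do it, assuming every image shares a single prompt; write_train_jsonl and its extensions parameter are hypothetical and not part of this example's scripts.

```python
import json
from pathlib import Path


def write_train_jsonl(folder, prompt, extensions=(".jpg", ".jpeg", ".png")):
    """Write folder/train.jsonl with one entry per image, all sharing `prompt`.

    Returns the list of image file names that were written, in sorted order.
    """
    folder = Path(folder)
    images = sorted(
        p.name for p in folder.iterdir() if p.suffix.lower() in extensions
    )
    with open(folder / "train.jsonl", "w") as f:
        for name in images:
            # Paths are stored relative to the dataset folder.
            f.write(json.dumps({"image": name, "prompt": prompt}) + "\n")
    return images
```

For the dog6 folder, write_train_jsonl("path/to/dreambooth/dataset/dog6", "A photo of sks dog") would produce the five lines shown above, provided the folder contains only those five images.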

Subsequently we finetune FLUX using the following command:

python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    path/to/dreambooth/dataset/dog6

Or you can directly use the pre-processed Hugging Face dataset mlx-community/dreambooth-dog6 for fine-tuning.

python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    mlx-community/dreambooth-dog6

The training requires approximately 50GB of RAM and on an M2 Ultra it takes a bit more than 1 hour.

Using the Adapter

The adapters are saved in mlx_output and can be used directly by the txt2image.py script. For instance,

python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
    --adapter mlx_output/final_adapters.safetensors \
    --fuse-adapter \
    --no-t5-padding \
    'A photo of an sks dog lying on the sand at a beach in Greece'

generates an image that looks like the following,

dog image

and of course we can pass --image-size 512x1024 to get larger images with different aspect ratios,

wide dog image

The arguments relevant to adapters are --adapter and --fuse-adapter. The first gives the path to an adapter to apply to the model; the second fuses the adapter back into the model for slightly faster generation.
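
As a rough illustration of why fusing helps, the sketch below uses the generic LoRA identity x W + (x A) B = x (W + A B) on tiny hand-written matrices in pure Python. It is not the code path used by txt2image.py, it omits any LoRA scaling factor, and all names and values are made up for the example.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Element-wise sum of two matrices of the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight, shape (d_in, d_out)
A = [[1.0], [2.0]]            # adapter down-projection, shape (d_in, r) with r = 1
B = [[0.5, -1.0]]             # adapter up-projection, shape (r, d_out)
x = [[1.0, 1.0]]              # one input row

# Unfused: base output plus the adapter path, three matrix products per call.
unfused = add(matmul(x, W), matmul(matmul(x, A), B))

# Fused: fold the adapter into the weight once, then one product per call.
W_fused = add(W, matmul(A, B))
fused = matmul(x, W_fused)

print(unfused, fused)  # [[5.5, 3.0]] [[5.5, 3.0]]
```

Both paths give the same output, but the fused weight needs a single matrix product per layer at generation time.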

Footnotes

  1. Refer to the arXiv paper for more details.

  2. The images are from Unsplash by https://unsplash.com/@alvannee.