Skip to content

Commit

Permalink
support metadata files as CLI arg supplier
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonywu committed Oct 19, 2024
1 parent 231b988 commit 50cab7b
Show file tree
Hide file tree
Showing 12 changed files with 451 additions and 41 deletions.
98 changes: 88 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ like [Numpy](https://numpy.org) and [Pillow](https://pypi.org/project/pillow/) f


### 💿 Installation
For users, the easiest way to install MFLUX is to use `uv tool`: If you have [installed `uv`](https://github.com/astral-sh/uv?tab=readme-ov-file#installation), simply:
For users, the easiest way to install MFLUX is to use `uv tool`: If you have [installed `uv`](https://github.com/astral-sh/uv?tab=readme-ov-file#installation), simply:

```sh
uv tool install --upgrade mflux
```
```

to get the `mflux-generate` and related command line executables. You can skip to the usage guides below.

Expand Down Expand Up @@ -80,7 +80,7 @@ pip install -U mflux
```sh
make install
```
3. To run the test suite
3. To run the test suite
```sh
make test
```
Expand Down Expand Up @@ -152,6 +152,76 @@ mflux-generate --model dev --prompt "Luxury food photograph" --steps 25 --seed 2

- **`--controlnet-save-canny`** (optional, bool, default: False): If set, saves the Canny edge detection reference image used by ControlNet.

- **`--config-from-metadata`** or **`-C`** (optional, `str`): [EXPERIMENTAL] Path to a prior file saved via `--metadata`, or a compatible handcrafted config file adhering to the expected args schema.

<details>
<summary>parameters supported by config files</summary>

#### How configs are used

- all config properties are optional and applied to the image generation if applicable
- invalid or incompatible properties will be ignored

#### Config schema

```json
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"seed": {
"type": ["integer", "null"]
},
"steps": {
"type": ["integer", "null"]
},
"guidance": {
"type": ["number", "null"]
},
"quantize": {
"type": ["null", "string"]
},
"lora_paths": {
"type": ["array", "null"],
"items": {
"type": "string"
}
},
"lora_scales": {
"type": ["array", "null"],
"items": {
"type": "number"
}
},
"prompt": {
"type": ["string", "null"]
}
}
}
```

#### Example

```json
{
"model": "dev",
"seed": 42,
"steps": 8,
"guidance": 3.0,
"quantize": 4,
"lora_paths": [
"/some/path1/to/subject.safetensors",
"/some/path2/to/style.safetensors"
],
"lora_scales": [
0.8,
0.4
],
"prompt": "award winning modern art, MOMA"
}
```
</details>

Or, with the correct python environment active, create and run a separate script like the following:

```python
Expand Down Expand Up @@ -304,7 +374,7 @@ mflux-save \
*Note that when saving a quantized version, you will need the original huggingface weights.*
It is also possible to specify [LoRA](#-lora) adapters when saving the model, e.g
It is also possible to specify [LoRA](#-lora) adapters when saving the model, e.g
```sh
mflux-save \
Expand Down Expand Up @@ -453,7 +523,7 @@ To report additional formats, examples or other any suggestions related to LoRA
### 🕹️ Controlnet
MFLUX has [Controlnet](https://huggingface.co/docs/diffusers/en/using-diffusers/controlnet) support for an even more fine-grained control
of the image generation. By providing a reference image via `--controlnet-image-path` and a strength parameter via `--controlnet-strength`, you can guide the generation toward the reference image.
of the image generation. By providing a reference image via `--controlnet-image-path` and a strength parameter via `--controlnet-strength`, you can guide the generation toward the reference image.
```sh
mflux-generate-controlnet \
Expand All @@ -474,10 +544,10 @@ mflux-generate-controlnet \
*This example combines the controlnet reference image with the LoRA [Dark Comic Flux](https://civitai.com/models/742916/dark-comic-flux)*.
⚠️ *Note: Controlnet requires an additional one-time download of ~3.58GB of weights from Huggingface. This happens automatically the first time you run the `generate-controlnet` command.
At the moment, the Controlnet used is [InstantX/FLUX.1-dev-Controlnet-Canny](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny), which was trained for the `dev` model.
It can work well with `schnell`, but performance is not guaranteed.*
At the moment, the Controlnet used is [InstantX/FLUX.1-dev-Controlnet-Canny](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny), which was trained for the `dev` model.
It can work well with `schnell`, but performance is not guaranteed.*
⚠️ *Note: The output can be highly sensitive to the controlnet strength and is very much dependent on the reference image.
⚠️ *Note: The output can be highly sensitive to the controlnet strength and is very much dependent on the reference image.
Too high settings will corrupt the image. A recommended starting point a value like 0.4 and to play around with the strength.*
Expand All @@ -492,7 +562,15 @@ with different prompts and LoRA adapters active.
- Negative prompts not supported.
- LoRA weights are only supported for the transformer part of the network.
- Some LoRA adapters does not work.
- Currently, the supported controlnet is the [canny-only version](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny).
- Currently, the supported controlnet is the [canny-only version](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny).
### Workflow Tips
- To hide the model fetching status progress bars, `export HF_HUB_DISABLE_PROGRESS_BARS=1`
- Use config files to save complex job parameters in a file instead of passing many `--args`
- Set up shell aliases for required args examples:
- shortcut for dev model: `alias mflux-dev='mflux-generate --model dev'`
- shortcut for schnell model *and* always save metadata: `alias mflux-schnell='mflux-generate --model schnell --metadata'`
### ✅ TODO
Expand All @@ -505,4 +583,4 @@ with different prompts and LoRA adapters active.
### License
This project is licensed under the [MIT License](LICENSE).
This project is licensed under the [MIT License](LICENSE).
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mflux"
version = "0.3.0"
version = "0.4.0"
description = "A MLX port of FLUX based on the Huggingface Diffusers implementation."
readme = "README.md"
keywords = ["diffusers", "flux", "mlx"]
Expand Down Expand Up @@ -36,9 +36,7 @@ classifiers = [
]

[project.optional-dependencies]
dev = [
"pytest>=8.0.0,<9.0"
]
dev = ["pytest>=8.0.0,<9.0"]

[project.urls]
homepage = "https://github.com/filipstrand/mflux"
Expand Down
4 changes: 2 additions & 2 deletions src/mflux/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
width: int = 1024,
height: int = 1024,
guidance: float = 4.0,
init_image: Path | None = None,
init_image_path: Path | None = None,
init_image_strength: float | None = None,
seed: float | None = None,
):
Expand All @@ -26,7 +26,7 @@ def __init__(
self.height = 16 * (height // 16)
self.num_inference_steps = num_inference_steps
self.guidance = guidance
self.init_image = init_image
self.init_image_path = init_image_path
self.init_image_strength = init_image_strength
self.seed = seed or int(time.time())

Expand Down
14 changes: 11 additions & 3 deletions src/mflux/config/runtime_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,17 @@ def precision(self) -> mx.Dtype:
def num_train_steps(self) -> int:
return self.model_config.num_train_steps

@property
def init_image_path(self) -> int:
return self.config.init_image_path

@property
def init_image_strength(self) -> int:
return self.config.init_image_strength

@property
def init_time_step(self) -> int:
if self.config.init_image is None:
if self.config.init_image_path is None:
# text to image, always begin at time step 0
return 0
else:
Expand All @@ -69,8 +77,8 @@ def init_latents(self) -> mx.array:
shape=[1, (self.height // 16) * (self.width // 16), 64],
key=mx.random.key(self.seed)
) # fmt: off
if self.config.init_image is not None:
user_image = ImageUtil.load_image(self.config.init_image).convert("RGB")
if self.config.init_image_path is not None:
user_image = ImageUtil.load_image(self.config.init_image_path).convert("RGB")
latents = ArrayUtil.pack_latents(
self.vae.encode(ImageUtil.to_array(ImageUtil.scale_to_dimensions(user_image, self.width, self.height))),
self.width,
Expand Down
1 change: 1 addition & 0 deletions src/mflux/controlnet/flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def generate_image(
lora_scales=self.lora_scales,
config=config,
controlnet_image_path=controlnet_image_path,
controlnet_strength=config.controlnet_strength,
)

def _set_model_weights(self, weights):
Expand Down
2 changes: 2 additions & 0 deletions src/mflux/flux/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def generate_image(
generation_time=time_steps.format_dict["elapsed"],
lora_paths=self.lora_paths,
lora_scales=self.lora_scales,
init_image_path=config.init_image_path,
init_image_strength=config.init_image_strength,
config=config,
)

Expand Down
4 changes: 2 additions & 2 deletions src/mflux/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def main():
parser = CommandLineParser(description="Generate an image based on a prompt.")
parser.add_model_arguments()
parser.add_lora_arguments()
parser.add_image_generator_arguments()
parser.add_image_generator_arguments(supports_metadata_config=True)
parser.add_image_to_image_arguments(required=False)
parser.add_output_arguments()
args = parser.parse_args()
Expand All @@ -35,7 +35,7 @@ def main():
height=args.height,
width=args.width,
guidance=args.guidance,
init_image=args.init_image,
init_image_path=args.init_image_path,
init_image_strength=args.init_image_strength,
seed=args.seed
),
Expand Down
2 changes: 1 addition & 1 deletion src/mflux/generate_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def main():
parser = CommandLineParser(description="Generate an image based on a prompt and a controlnet reference image.") # fmt: off
parser.add_model_arguments()
parser.add_lora_arguments()
parser.add_image_generator_arguments()
parser.add_image_generator_arguments(supports_metadata_config=True)
parser.add_controlnet_arguments()
parser.add_output_arguments()
args = parser.parse_args()
Expand Down
35 changes: 23 additions & 12 deletions src/mflux/post_processing/generated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def __init__(
generation_time: float,
lora_paths: list[str],
lora_scales: list[float],
controlnet_image_path: str | None = None,
controlnet_image_path: str | pathlib.Path | None = None,
controlnet_strength: float | None = None,
init_image_path: str | pathlib.Path | None = None,
init_image_strength: float | None = None,
):
self.image = image
self.model_config = model_config
Expand All @@ -36,29 +38,38 @@ def __init__(
self.generation_time = generation_time
self.lora_paths = lora_paths
self.lora_scales = lora_scales
self.controlnet_image = controlnet_image_path
self.controlnet_image_path = controlnet_image_path
self.controlnet_strength = controlnet_strength
self.init_image_path = init_image_path
self.init_image_strength = init_image_strength

def save(self, path: t.Union[str, pathlib.Path], export_json_metadata: bool = False) -> None:
from mflux import ImageUtil

ImageUtil.save_image(self.image, path, self._get_metadata(), export_json_metadata)

def _get_metadata(self) -> dict:
"""Generate metadata for reference as well as input data for
command line --config-from-metadata arg in future generations.
"""
return {
# mflux_version is used by future metadata readers
# to determine supportability of metadata-derived workflows
"mflux_version": str(GeneratedImage.get_version()),
"model": str(self.model_config.alias),
"seed": str(self.seed),
"steps": str(self.steps),
"guidance": "None" if self.model_config == ModelConfig.FLUX1_SCHNELL else str(self.guidance),
"precision": f"{self.precision}",
"quantization": "None" if self.quantization is None else f"{self.quantization} bit",
"generation_time": f"{self.generation_time:.2f} seconds",
"lora_paths": ", ".join(self.lora_paths) if self.lora_paths else "None",
"lora_scales": ", ".join([f"{scale:.2f}" for scale in self.lora_scales]) if self.lora_scales else "None",
"seed": self.seed,
"steps": self.steps,
"guidance": self.guidance if ModelConfig.FLUX1_DEV else None, # only the dev model supports guidance
"precision": str(self.precision),
"quantize": self.quantization,
"generation_time_seconds": round(self.generation_time, 2),
"lora_paths": [str(p) for p in self.lora_paths] if self.lora_paths else None,
"lora_scales": [round(scale, 2) for scale in self.lora_scales] if self.lora_scales else None,
"init_image_path": str(self.init_image_path) if self.init_image_path else None,
"init_image_strength": self.init_image_strength if self.init_image_path else None,
"controlnet_image_path": str(self.controlnet_image_path) if self.controlnet_image_path else None,
"controlnet_strength": round(self.controlnet_strength, 2) if self.controlnet_strength else None,
"prompt": self.prompt,
"controlnet_image": "None" if self.controlnet_image is None else self.controlnet_image,
"controlnet_strength": "None" if self.controlnet_strength is None else f"{self.controlnet_strength:.2f}",
}

@staticmethod
Expand Down
4 changes: 4 additions & 0 deletions src/mflux/post_processing/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def to_image(
lora_scales: list[float],
config: RuntimeConfig,
controlnet_image_path: str | None = None,
init_image_path: str | None = None,
init_image_strength: float | None = None,
) -> GeneratedImage:
normalized = ImageUtil._denormalize(decoded_latents)
normalized_numpy = ImageUtil._to_numpy(normalized)
Expand All @@ -44,6 +46,8 @@ def to_image(
generation_time=generation_time,
lora_paths=lora_paths,
lora_scales=lora_scales,
init_image_path=init_image_path,
init_image_strength=init_image_strength,
controlnet_image_path=controlnet_image_path,
controlnet_strength=config.controlnet_strength if isinstance(config.config, ConfigControlnet) else None,
)
Expand Down
Loading

0 comments on commit 50cab7b

Please sign in to comment.