From 351cc8bbf127c6224cb6f168c5dc418375b3af19 Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Sun, 20 Oct 2024 09:18:49 -0700 Subject: [PATCH 1/4] support metadata files as CLI arg supplier --- README.md | 98 ++++++- pyproject.toml | 7 +- src/mflux/config/runtime_config.py | 8 + src/mflux/flux/flux.py | 2 + src/mflux/generate.py | 4 +- src/mflux/generate_controlnet.py | 2 +- src/mflux/post_processing/generated_image.py | 35 ++- src/mflux/post_processing/image_util.py | 4 + src/mflux/save.py | 2 +- src/mflux/ui/cli/parsers.py | 94 ++++++- tests/test_cli_argparser.py | 258 +++++++++++++++++++ 11 files changed, 476 insertions(+), 38 deletions(-) create mode 100644 tests/test_cli_argparser.py diff --git a/README.md b/README.md index e835a10..13f5b72 100644 --- a/README.md +++ b/README.md @@ -45,11 +45,11 @@ like [Numpy](https://numpy.org) and [Pillow](https://pypi.org/project/pillow/) f ### đŸ’ŋ Installation -For users, the easiest way to install MFLUX is to use `uv tool`: If you have [installed `uv`](https://github.com/astral-sh/uv?tab=readme-ov-file#installation), simply: +For users, the easiest way to install MFLUX is to use `uv tool`: If you have [installed `uv`](https://github.com/astral-sh/uv?tab=readme-ov-file#installation), simply: ```sh uv tool install --upgrade mflux -``` +``` to get the `mflux-generate` and related command line executables. You can skip to the usage guides below. @@ -80,7 +80,7 @@ pip install -U mflux ```sh make install ``` -3. To run the test suite +3. To run the test suite ```sh make test ``` @@ -152,6 +152,76 @@ mflux-generate --model dev --prompt "Luxury food photograph" --steps 25 --seed 2 - **`--controlnet-save-canny`** (optional, bool, default: False): If set, saves the Canny edge detection reference image used by ControlNet. 
+- **`--config-from-metadata`** or **`-C`** (optional, `str`): [EXPERIMENTAL] Path to a prior file saved via `--metadata`, or a compatible handcrafted config file adhering to the expected args schema. + +
+parameters supported by config files + +#### How configs are used + +- all config properties are optional and applied to the image generation if applicable +- invalid or incompatible properties will be ignored + +#### Config schema + +```json +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "seed": { + "type": ["integer", "null"] + }, + "steps": { + "type": ["integer", "null"] + }, + "guidance": { + "type": ["number", "null"] + }, + "quantize": { + "type": ["integer", "null"] + }, + "lora_paths": { + "type": ["array", "null"], + "items": { + "type": "string" + } + }, + "lora_scales": { + "type": ["array", "null"], + "items": { + "type": "number" + } + }, + "prompt": { + "type": ["string", "null"] + } + } +} +``` + +#### Example + +```json +{ + "model": "dev", + "seed": 42, + "steps": 8, + "guidance": 3.0, + "quantize": 4, + "lora_paths": [ + "/some/path1/to/subject.safetensors", + "/some/path2/to/style.safetensors" + ], + "lora_scales": [ + 0.8, + 0.4 + ], + "prompt": "award winning modern art, MOMA" +} +``` +
+ Or, with the correct python environment active, create and run a separate script like the following: ```python @@ -304,7 +374,7 @@ mflux-save \ *Note that when saving a quantized version, you will need the original huggingface weights.* -It is also possible to specify [LoRA](#-lora) adapters when saving the model, e.g +It is also possible to specify [LoRA](#-lora) adapters when saving the model, e.g ```sh mflux-save \ @@ -453,7 +523,7 @@ To report additional formats, examples or other any suggestions related to LoRA ### 🕹️ Controlnet MFLUX has [Controlnet](https://huggingface.co/docs/diffusers/en/using-diffusers/controlnet) support for an even more fine-grained control -of the image generation. By providing a reference image via `--controlnet-image-path` and a strength parameter via `--controlnet-strength`, you can guide the generation toward the reference image. +of the image generation. By providing a reference image via `--controlnet-image-path` and a strength parameter via `--controlnet-strength`, you can guide the generation toward the reference image. ```sh mflux-generate-controlnet \ @@ -474,10 +544,10 @@ mflux-generate-controlnet \ *This example combines the controlnet reference image with the LoRA [Dark Comic Flux](https://civitai.com/models/742916/dark-comic-flux)*. ⚠️ *Note: Controlnet requires an additional one-time download of ~3.58GB of weights from Huggingface. This happens automatically the first time you run the `generate-controlnet` command. -At the moment, the Controlnet used is [InstantX/FLUX.1-dev-Controlnet-Canny](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny), which was trained for the `dev` model. -It can work well with `schnell`, but performance is not guaranteed.* +At the moment, the Controlnet used is [InstantX/FLUX.1-dev-Controlnet-Canny](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny), which was trained for the `dev` model. 
+It can work well with `schnell`, but performance is not guaranteed.* -⚠ī¸ *Note: The output can be highly sensitive to the controlnet strength and is very much dependent on the reference image. +⚠ī¸ *Note: The output can be highly sensitive to the controlnet strength and is very much dependent on the reference image. Too high settings will corrupt the image. A recommended starting point a value like 0.4 and to play around with the strength.* @@ -492,7 +562,15 @@ with different prompts and LoRA adapters active. - Negative prompts not supported. - LoRA weights are only supported for the transformer part of the network. - Some LoRA adapters does not work. -- Currently, the supported controlnet is the [canny-only version](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny). +- Currently, the supported controlnet is the [canny-only version](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny). + +### Workflow Tips + +- To hide the model fetching status progress bars, `export HF_HUB_DISABLE_PROGRESS_BARS=1` +- Use config files to save complex job parameters in a file instead of passing many `--args` +- Set up shell aliases for required args examples: + - shortcut for dev model: `alias mflux-dev='mflux-generate --model dev'` + - shortcut for schnell model *and* always save metadata: `alias mflux-schnell='mflux-generate --model schnell --metadata'` ### ✅ TODO @@ -505,4 +583,4 @@ with different prompts and LoRA adapters active. ### License -This project is licensed under the [MIT License](LICENSE). \ No newline at end of file +This project is licensed under the [MIT License](LICENSE). diff --git a/pyproject.toml b/pyproject.toml index be18f64..818c128 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mflux" -version = "0.3.0" +version = "0.4.0" description = "A MLX port of FLUX based on the Huggingface Diffusers implementation." 
readme = "README.md" keywords = ["diffusers", "flux", "mlx"] @@ -37,7 +37,8 @@ classifiers = [ [project.optional-dependencies] dev = [ - "pytest>=8.0.0,<9.0" + "pytest>=8.3.0,<9.0", + "pytest-timer>=1.0,<2.0", ] [project.urls] @@ -102,7 +103,7 @@ docstring-code-line-length = "dynamic" [tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*.py" -addopts = "-v" +addopts = "-v --exitfirst --failed-first --showlocals --tb=long --full-trace" # https://docs.astral.sh/ruff/settings/#lintisort [tool.ruff.lint.isort] diff --git a/src/mflux/config/runtime_config.py b/src/mflux/config/runtime_config.py index 018f0c5..59d2d58 100644 --- a/src/mflux/config/runtime_config.py +++ b/src/mflux/config/runtime_config.py @@ -39,6 +39,14 @@ def precision(self) -> mx.Dtype: def num_train_steps(self) -> int: return self.model_config.num_train_steps + @property + def init_image_path(self) -> int: + return self.config.init_image_path + + @property + def init_image_strength(self) -> int: + return self.config.init_image_strength + @property def init_time_step(self) -> int: if self.config.init_image_path is None: diff --git a/src/mflux/flux/flux.py b/src/mflux/flux/flux.py index b714d91..0105c21 100644 --- a/src/mflux/flux/flux.py +++ b/src/mflux/flux/flux.py @@ -139,6 +139,8 @@ def generate_image( generation_time=time_steps.format_dict["elapsed"], lora_paths=self.lora_paths, lora_scales=self.lora_scales, + init_image_path=config.init_image_path, + init_image_strength=config.init_image_strength, config=config, ) diff --git a/src/mflux/generate.py b/src/mflux/generate.py index b7c5c17..3334d1d 100644 --- a/src/mflux/generate.py +++ b/src/mflux/generate.py @@ -10,7 +10,7 @@ def main(): parser = CommandLineParser(description="Generate an image based on a prompt.") parser.add_model_arguments() parser.add_lora_arguments() - parser.add_image_generator_arguments() + parser.add_image_generator_arguments(supports_metadata_config=True) parser.add_image_to_image_arguments(required=False) 
parser.add_output_arguments() args = parser.parse_args() @@ -36,7 +36,7 @@ def main(): width=args.width, guidance=args.guidance, init_image_path=args.init_image_path, - init_image_strength=args.init_image_strength + init_image_strength=args.init_image_strength, ), ) diff --git a/src/mflux/generate_controlnet.py b/src/mflux/generate_controlnet.py index 6495fcf..122ed04 100644 --- a/src/mflux/generate_controlnet.py +++ b/src/mflux/generate_controlnet.py @@ -9,7 +9,7 @@ def main(): parser = CommandLineParser(description="Generate an image based on a prompt and a controlnet reference image.") # fmt: off parser.add_model_arguments() parser.add_lora_arguments() - parser.add_image_generator_arguments() + parser.add_image_generator_arguments(supports_metadata_config=True) parser.add_controlnet_arguments() parser.add_output_arguments() args = parser.parse_args() diff --git a/src/mflux/post_processing/generated_image.py b/src/mflux/post_processing/generated_image.py index 5d85c4f..ca8c21e 100644 --- a/src/mflux/post_processing/generated_image.py +++ b/src/mflux/post_processing/generated_image.py @@ -22,8 +22,10 @@ def __init__( generation_time: float, lora_paths: list[str], lora_scales: list[float], - controlnet_image_path: str | None = None, + controlnet_image_path: str | pathlib.Path | None = None, controlnet_strength: float | None = None, + init_image_path: str | pathlib.Path | None = None, + init_image_strength: float | None = None, ): self.image = image self.model_config = model_config @@ -36,8 +38,10 @@ def __init__( self.generation_time = generation_time self.lora_paths = lora_paths self.lora_scales = lora_scales - self.controlnet_image = controlnet_image_path + self.controlnet_image_path = controlnet_image_path self.controlnet_strength = controlnet_strength + self.init_image_path = init_image_path + self.init_image_strength = init_image_strength def save(self, path: t.Union[str, pathlib.Path], export_json_metadata: bool = False) -> None: from mflux import ImageUtil 
@@ -45,20 +49,27 @@ def save(self, path: t.Union[str, pathlib.Path], export_json_metadata: bool = Fa ImageUtil.save_image(self.image, path, self._get_metadata(), export_json_metadata) def _get_metadata(self) -> dict: + """Generate metadata for reference as well as input data for + command line --config-from-metadata arg in future generations. + """ return { + # mflux_version is used by future metadata readers + # to determine supportability of metadata-derived workflows "mflux_version": str(GeneratedImage.get_version()), "model": str(self.model_config.alias), - "seed": str(self.seed), - "steps": str(self.steps), - "guidance": "None" if self.model_config == ModelConfig.FLUX1_SCHNELL else str(self.guidance), - "precision": f"{self.precision}", - "quantization": "None" if self.quantization is None else f"{self.quantization} bit", - "generation_time": f"{self.generation_time:.2f} seconds", - "lora_paths": ", ".join(self.lora_paths) if self.lora_paths else "None", - "lora_scales": ", ".join([f"{scale:.2f}" for scale in self.lora_scales]) if self.lora_scales else "None", + "seed": self.seed, + "steps": self.steps, + "guidance": self.guidance if ModelConfig.FLUX1_DEV else None, # only the dev model supports guidance + "precision": str(self.precision), + "quantize": self.quantization, + "generation_time_seconds": round(self.generation_time, 2), + "lora_paths": [str(p) for p in self.lora_paths] if self.lora_paths else None, + "lora_scales": [round(scale, 2) for scale in self.lora_scales] if self.lora_scales else None, + "init_image_path": str(self.init_image_path) if self.init_image_path else None, + "init_image_strength": self.init_image_strength if self.init_image_path else None, + "controlnet_image_path": str(self.controlnet_image_path) if self.controlnet_image_path else None, + "controlnet_strength": round(self.controlnet_strength, 2) if self.controlnet_strength else None, "prompt": self.prompt, - "controlnet_image": "None" if self.controlnet_image is None else 
self.controlnet_image, - "controlnet_strength": "None" if self.controlnet_strength is None else f"{self.controlnet_strength:.2f}", } @staticmethod diff --git a/src/mflux/post_processing/image_util.py b/src/mflux/post_processing/image_util.py index f8c6af4..832d55f 100644 --- a/src/mflux/post_processing/image_util.py +++ b/src/mflux/post_processing/image_util.py @@ -28,6 +28,8 @@ def to_image( lora_scales: list[float], config: RuntimeConfig, controlnet_image_path: str | None = None, + init_image_path: str | None = None, + init_image_strength: float | None = None, ) -> GeneratedImage: normalized = ImageUtil._denormalize(decoded_latents) normalized_numpy = ImageUtil._to_numpy(normalized) @@ -44,6 +46,8 @@ def to_image( generation_time=generation_time, lora_paths=lora_paths, lora_scales=lora_scales, + init_image_path=init_image_path, + init_image_strength=init_image_strength, controlnet_image_path=controlnet_image_path, controlnet_strength=config.controlnet_strength if isinstance(config.config, ConfigControlnet) else None, ) diff --git a/src/mflux/save.py b/src/mflux/save.py index f26db41..db473ac 100644 --- a/src/mflux/save.py +++ b/src/mflux/save.py @@ -4,7 +4,7 @@ def main(): parser = CommandLineParser(description="Save a quantized version of Flux.1 to disk.") # fmt: off - parser.add_model_arguments() + parser.add_model_arguments(path_type="save") parser.add_lora_arguments() args = parser.parse_args() diff --git a/src/mflux/ui/cli/parsers.py b/src/mflux/ui/cli/parsers.py index 2521a57..3dad873 100644 --- a/src/mflux/ui/cli/parsers.py +++ b/src/mflux/ui/cli/parsers.py @@ -1,4 +1,6 @@ import argparse +import json +import typing as t from pathlib import Path from mflux.ui import defaults as ui_defaults @@ -7,33 +9,51 @@ # fmt: off class CommandLineParser(argparse.ArgumentParser): - def add_model_arguments(self) -> None: + def __init__(self, *pargs, **kwargs): + super().__init__(*pargs, **kwargs) + self.supports_metadata_config = False + self.supports_image_generation = 
False + self.supports_controlnet = False + self.supports_image_to_image = False + self.supports_lora = False + + def add_model_arguments(self, path_type: t.Literal["load", "save"] = "load") -> None: + self.add_argument("--model", "-m", type=str, required=True, choices=ui_defaults.MODEL_CHOICES, help=f"The model to use ({' or '.join(ui_defaults.MODEL_CHOICES)}).") - self.add_argument("--path", type=str, default=None, help="Local path for loading a model from disk") + + if path_type == "load": + self.add_argument("--path", type=str, default=None, help="Local path for loading a model from disk") + else: + self.add_argument("--path", type=str, required=True, help="Local path for saving a model to disk.") self.add_argument("--quantize", "-q", type=int, choices=ui_defaults.QUANTIZE_CHOICES, default=None, help=f"Quantize the model ({' or '.join(map(str, ui_defaults.QUANTIZE_CHOICES))}, Default is None)") def add_lora_arguments(self) -> None: + self.supports_lora = True self.add_argument("--lora-paths", type=str, nargs="*", default=None, help="Local safetensors for applying LORA from disk") self.add_argument("--lora-scales", type=float, nargs="*", default=None, help="Scaling factor to adjust the impact of LoRA weights on the model. 
A value of 1.0 applies the LoRA weights as they are.") def _add_image_generator_common_arguments(self) -> None: + self.supports_image_generation = True self.add_argument("--height", type=int, default=ui_defaults.HEIGHT, help=f"Image height (Default is {ui_defaults.HEIGHT})") self.add_argument("--width", type=int, default=ui_defaults.WIDTH, help=f"Image width (Default is {ui_defaults.HEIGHT})") self.add_argument("--steps", type=int, default=None, help="Inference Steps") self.add_argument("--guidance", type=float, default=ui_defaults.GUIDANCE_SCALE, help=f"Guidance Scale (Default is {ui_defaults.GUIDANCE_SCALE})") - def add_image_generator_arguments(self) -> None: - self.add_argument("--prompt", type=str, required=True, help="The textual description of the image to generate.") + def add_image_generator_arguments(self, supports_metadata_config=False) -> None: + self.add_argument("--prompt", type=str, required=(not supports_metadata_config), default=None, help="The textual description of the image to generate.") self.add_argument("--seed", type=int, default=None, help="Entropy Seed (Default is time-based random-seed)") self._add_image_generator_common_arguments() + if supports_metadata_config: + self.add_metadata_config() def add_image_to_image_arguments(self, required=False) -> None: - self.add_argument("--init-image-path", type=Path, required=required, help="Local path to init image") + self.supports_image_to_image = True + self.add_argument("--init-image-path", type=Path, required=required, default=None, help="Local path to init image") self.add_argument("--init-image-strength", type=float, required=False, default=ui_defaults.INIT_IMAGE_STRENGTH, help=f"Controls how strongly the init image influences the output image. A value of 0.0 means no influence. 
(Default is {ui_defaults.INIT_IMAGE_STRENGTH})") def add_batch_image_generator_arguments(self) -> None: - self.add_argument("--prompts-file", type=Path, required=True, help="Local path for a file that holds a batch of prompts.") - self.add_argument("--global-seed", type=int, default=None, help="Entropy Seed (used for all prompts in the batch)") + self.add_argument("--prompts-file", type=Path, required=True, default=argparse.SUPPRESS, help="Local path for a file that holds a batch of prompts.") + self.add_argument("--global-seed", type=int, default=argparse.SUPPRESS, help="Entropy Seed (used for all prompts in the batch)") self._add_image_generator_common_arguments() def add_output_arguments(self) -> None: @@ -42,14 +62,70 @@ def add_output_arguments(self) -> None: self.add_argument('--stepwise-image-output-dir', type=str, default=None, help='[EXPERIMENTAL] Output dir to write step-wise images and their final composite image to. This feature may change in future versions.') def add_controlnet_arguments(self) -> None: - self.add_argument("--controlnet-image-path", type=str, required=True, help="Local path of the image to use as input for controlnet.") + self.supports_controlnet = True + self.add_argument("--controlnet-image-path", type=str, required=False, help="Local path of the image to use as input for controlnet.") self.add_argument("--controlnet-strength", type=float, default=ui_defaults.CONTROLNET_STRENGTH, help=f"Controls how strongly the control image influences the output image. A value of 0.0 means no influence. (Default is {ui_defaults.CONTROLNET_STRENGTH})") self.add_argument("--controlnet-save-canny", action="store_true", help="If set, save the Canny edge detection reference input image.") + def add_metadata_config(self) -> None: + self.supports_metadata_config = True + self.add_argument("--config-from-metadata", "-C", type=Path, required=False, default=argparse.SUPPRESS, help="Re-use the parameters from prior metadata. 
Params from metadata are secondary to other args you provide.") + def parse_args(self, **kwargs) -> argparse.Namespace: namespace = super().parse_args() if hasattr(namespace, "path") and namespace.path is not None and namespace.model is None: namespace.error("--model must be specified when using --path") - if hasattr(namespace, "steps") and namespace.steps is None: + + if getattr(namespace, "config_from_metadata", False): + prior_gen_metadata = json.load(namespace.config_from_metadata.open("rt")) + + if namespace.prompt is None: + namespace.prompt = prior_gen_metadata.get("prompt", None) + + # all configs from the metadata config defers to any explicitly defined args + guidance_default = self.get_default("guidance") + guidance_from_metadata = prior_gen_metadata.get("guidance") + if namespace.guidance == guidance_default and guidance_from_metadata: + namespace.guidance = guidance_from_metadata + if namespace.quantize is None: + namespace.quantize = prior_gen_metadata.get("quantize", None) + if namespace.seed is None: + namespace.seed = prior_gen_metadata.get("seed", None) + if namespace.steps is None: + namespace.steps = prior_gen_metadata.get("steps", None) + + if self.supports_lora: + if namespace.lora_paths is None: + namespace.lora_paths = prior_gen_metadata.get("lora_paths", None) + elif namespace.lora_paths: + # merge the loras from cli and config file + namespace.lora_paths = prior_gen_metadata.get("lora_paths", []) + namespace.lora_paths + + if namespace.lora_scales is None: + namespace.lora_scales = prior_gen_metadata.get("lora_scales", None) + elif namespace.lora_scales: + # merge the loras from cli and config file + namespace.lora_scales = prior_gen_metadata.get("lora_scales", []) + namespace.lora_scales + + if self.supports_image_to_image: + if namespace.init_image_path is None: + namespace.init_image_path = prior_gen_metadata.get("init_image_path", None) + if namespace.init_image_strength == self.get_default("init_image_strength") and 
(init_img_strength_from_metadata := prior_gen_metadata.get("init_image_strength", None)): + namespace.init_image_strength = init_img_strength_from_metadata + + if self.supports_controlnet: + if namespace.controlnet_image_path is None: + namespace.controlnet_image_path = prior_gen_metadata.get("controlnet_image_path", None) + if namespace.controlnet_strength == self.get_default("controlnet_strength") and (cnet_strength_from_metadata := prior_gen_metadata.get("controlnet_strength", None)): + namespace.controlnet_strength = cnet_strength_from_metadata + + + + if self.supports_image_generation and namespace.prompt is None: + # not supplied by CLI and not supplied by metadata config file + self.error("--prompt argument required or 'prompt' required in metadata config file") + + if self.supports_image_generation and namespace.steps is None: namespace.steps = ui_defaults.MODEL_INFERENCE_STEPS.get(namespace.model, None) + return namespace diff --git a/tests/test_cli_argparser.py b/tests/test_cli_argparser.py new file mode 100644 index 0000000..d7a5698 --- /dev/null +++ b/tests/test_cli_argparser.py @@ -0,0 +1,258 @@ +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from mflux.ui.cli.parsers import CommandLineParser + + +def _create_mflux_generate_parser(with_controlnet=False) -> CommandLineParser: + parser = CommandLineParser(description="Generate an image based on a prompt.") + parser.add_model_arguments() + parser.add_image_generator_arguments(supports_metadata_config=True) + parser.add_lora_arguments() + parser.add_image_to_image_arguments(required=False) + if with_controlnet: + parser.add_controlnet_arguments() + parser.add_output_arguments() + return parser + + +@pytest.fixture +def mflux_generate_parser() -> CommandLineParser: + return _create_mflux_generate_parser(with_controlnet=False) + + +@pytest.fixture +def mflux_generate_controlnet_parser() -> CommandLineParser: + return 
_create_mflux_generate_parser(with_controlnet=True) + + +@pytest.fixture +def mflux_save_parser() -> CommandLineParser: + parser = CommandLineParser(description="Save a quantized version of Flux.1 to disk.") # fmt: off + parser.add_model_arguments(path_type="save") + parser.add_lora_arguments() + return parser + + +@pytest.fixture +def mflux_generate_minimal_argv() -> list[str]: + return ["mflux-generate", "--model", "schnell", "--prompt", "meaning of life"] + + +@pytest.fixture +def mflux_generate_controlnet_minimal_argv() -> list[str]: + return ["mflux-generate-controlnet", "--model", "dev", "--prompt", "meaning of life, imitated"] + + +@pytest.fixture +def temp_dir(tmp_path_factory) -> Path: + # Create a temporary directory for the module + temp_dir = tmp_path_factory.mktemp("mflux_cli_argparser_tests") + return Path(temp_dir) + + +@pytest.fixture +def base_metadata_dict() -> dict: + return { + "mflux_version": "0.4.0", + "model": "schnell", + "seed": 42042, + "steps": 4, + "guidance": None, + "precision": "mlx.core.bfloat16", + "quantize": None, + "generation_time_seconds": 42.0, + "lora_paths": None, + "lora_scales": None, + "init_image": None, + "init_image_strength": None, + "controlnet_image": None, + "controlnet_strength": None, + } + + +def test_model_path_requires_model_arg(mflux_generate_parser): + # when loading a model via --path, the model name still need to be specified + with patch("sys.argv", "mflux-generate", "--path", "/some/saved/model"): + assert pytest.raises(SystemExit, mflux_generate_parser.parse_args) + + +def test_prompt_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): + metadata_file = temp_dir / "prompt.json" + file_prompt = "origin of the universe" + with metadata_file.open("wt") as m: + base_metadata_dict["prompt"] = file_prompt + json.dump(base_metadata_dict, m, indent=4) + # test metadata config accepted, use mflux_generate_minimal_argv without fixture --prompt + with patch('sys.argv', 
mflux_generate_minimal_argv[:-2] + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.prompt == file_prompt + # test CLI override, use mflux_generate_minimal_argv without fixture --prompt + cli_prompt = "place where monsters come from" + with patch('sys.argv', mflux_generate_minimal_argv[:-2] + ['--prompt', cli_prompt, '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.prompt == cli_prompt + + +def test_guidance_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + metadata_file = temp_dir / "guidance.json" + with metadata_file.open("wt") as m: + base_metadata_dict["guidance"] = 4.2 + json.dump(base_metadata_dict, m, indent=4) + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.guidance == pytest.approx(4.2) + # test CLI override + with patch('sys.argv', mflux_generate_minimal_argv + ['--guidance', '5.0', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.guidance == pytest.approx(5.0) + + +def test_quantize_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + metadata_file = temp_dir / "quantize.json" + with metadata_file.open("wt") as m: + base_metadata_dict["quantize"] = 4 + json.dump(base_metadata_dict, m, indent=4) + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.quantize == 4 + # test CLI override + with patch('sys.argv', mflux_generate_minimal_argv + ['--quantize', '8', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + 
args = mflux_generate_parser.parse_args() + assert args.quantize == 8 + + +def test_seed_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + metadata_file = temp_dir / "seed.json" + with metadata_file.open("wt") as m: + base_metadata_dict["seed"] = 24 + json.dump(base_metadata_dict, m, indent=4) + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.seed == 24 + # test CLI override + with patch('sys.argv', mflux_generate_minimal_argv + ['--seed', '2424', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.seed == 2424 + + +def test_steps_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + metadata_file = temp_dir / "steps.json" + with metadata_file.open("wt") as m: + base_metadata_dict["steps"] = 8 + json.dump(base_metadata_dict, m, indent=4) + + # test user default value + with patch("sys.argv", mflux_generate_minimal_argv): + args = mflux_generate_parser.parse_args() + assert args.steps == 4 + + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.steps == 8 + + # test CLI override + with patch('sys.argv', mflux_generate_minimal_argv + ['--steps', '12', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.steps == 12 + + +def test_lora_args(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + test_paths = ["/some/lora/1.safetensors", "/some/lora/2.safetensors"] + metadata_file = temp_dir / "lora_args.json" + with metadata_file.open("wt") as m: + base_metadata_dict["lora_paths"] 
= test_paths + base_metadata_dict["lora_scales"] = [0.3, 0.7] + json.dump(base_metadata_dict, m, indent=4) + + # test user default value + with patch("sys.argv", mflux_generate_minimal_argv): + args = mflux_generate_parser.parse_args() + assert args.lora_paths is None + assert args.lora_scales is None + + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.lora_paths == test_paths + assert args.lora_scales == [pytest.approx(0.3), pytest.approx(0.7)] + + # test CLI override + new_loras = ["--lora-paths", "/some/lora/3.safetensors", "/some/lora/4.safetensors", "--lora-scales", "0.1", "0.9"] + with patch('sys.argv', mflux_generate_minimal_argv + new_loras + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert len(args.lora_paths) == 4 + assert args.lora_paths == test_paths + new_loras[1:3] + assert len(args.lora_scales) == 4 + assert args.lora_scales == [pytest.approx(v) for v in [0.3, 0.7, 0.1, 0.9]] + + +def test_image_to_image_args(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + metadata_file = temp_dir / "image_to_image.json" + test_path = "/some/awesome/image.png" + with metadata_file.open("wt") as m: + base_metadata_dict["init_image_path"] = test_path + json.dump(base_metadata_dict, m, indent=4) + + # test user default value + with patch("sys.argv", mflux_generate_minimal_argv): + args = mflux_generate_parser.parse_args() + assert args.init_image_path is None + assert args.init_image_strength == 0.4 # default + + # test metadata config accepted + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.init_image_path == test_path + assert args.init_image_strength == 0.4 # 
default + + # test strength override + with patch('sys.argv', mflux_generate_minimal_argv + ['--init-image-strength', '0.7', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.init_image_path == test_path + assert args.init_image_strength == 0.7 + + # test image path override + with patch('sys.argv', mflux_generate_minimal_argv + ['--init-image-path', '/some/better/image.png', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.init_image_path == Path("/some/better/image.png") + assert args.init_image_strength == 0.4 # default + + +def test_controlnet_args(mflux_generate_controlnet_parser, mflux_generate_controlnet_minimal_argv, base_metadata_dict, temp_dir): # fmt: off + test_path = "/some/cnet/1.safetensors" + metadata_file = temp_dir / "cnet_args.json" + with metadata_file.open("wt") as m: + base_metadata_dict["controlnet_image_path"] = test_path + base_metadata_dict["controlnet_strength"] = 0.48 + json.dump(base_metadata_dict, m, indent=4) + + # test metadata config accepted + with patch('sys.argv', mflux_generate_controlnet_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_controlnet_parser.parse_args() + assert args.controlnet_image_path == test_path + assert args.controlnet_strength == pytest.approx(0.48) + + # test CLI override + override_cnet = ["--controlnet-image-path", "/some/lora/2.safetensors", "--controlnet-strength", "0.85"] + with patch('sys.argv', mflux_generate_controlnet_minimal_argv + override_cnet + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_controlnet_parser.parse_args() + assert args.controlnet_image_path == "/some/lora/2.safetensors" + assert args.controlnet_strength == pytest.approx(0.85) + + +def test_save_args(mflux_save_parser): + with patch("sys.argv", ["mflux-save", "--model", "dev"]): + # 
required --path not provided, exits to error + assert pytest.raises(SystemExit, mflux_save_parser.parse_args) + with patch("sys.argv", ["mflux-save", "--model", "dev", "--path", "/some/model/folder"]): + # required --path provided, parses successfully + args = mflux_save_parser.parse_args() + assert args.path == "/some/model/folder" From ffaf562a8cebb6d13f2872f2ff0e5e015e21f6a3 Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Sat, 26 Oct 2024 10:29:33 -0700 Subject: [PATCH 2/4] fix type hint in test file --- src/mflux/config/runtime_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mflux/config/runtime_config.py b/src/mflux/config/runtime_config.py index 59d2d58..490f9ba 100644 --- a/src/mflux/config/runtime_config.py +++ b/src/mflux/config/runtime_config.py @@ -44,7 +44,7 @@ def init_image_path(self) -> int: return self.config.init_image_path @property - def init_image_strength(self) -> int: + def init_image_strength(self) -> float: return self.config.init_image_strength @property From 98d7ef57517392a52b50b7b121fc79299a276f44 Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Sat, 26 Oct 2024 20:22:32 -0700 Subject: [PATCH 3/4] update: do not require model arg, add controlnet_save_canny, fix bugs, add tests --- src/mflux/generate.py | 2 +- src/mflux/generate_controlnet.py | 4 +- src/mflux/save.py | 2 +- src/mflux/ui/cli/parsers.py | 15 ++++-- tests/test_cli_argparser.py | 80 +++++++++++++++++++++++++++----- 5 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/mflux/generate.py b/src/mflux/generate.py index 3334d1d..236b6b0 100644 --- a/src/mflux/generate.py +++ b/src/mflux/generate.py @@ -8,7 +8,7 @@ def main(): # fmt: off parser = CommandLineParser(description="Generate an image based on a prompt.") - parser.add_model_arguments() + parser.add_model_arguments(require_model_arg=False) parser.add_lora_arguments() 
parser.add_image_generator_arguments(supports_metadata_config=True) parser.add_image_to_image_arguments(required=False) diff --git a/src/mflux/generate_controlnet.py b/src/mflux/generate_controlnet.py index 122ed04..acd7443 100644 --- a/src/mflux/generate_controlnet.py +++ b/src/mflux/generate_controlnet.py @@ -7,9 +7,9 @@ def main(): parser = CommandLineParser(description="Generate an image based on a prompt and a controlnet reference image.") # fmt: off - parser.add_model_arguments() + parser.add_model_arguments(require_model_arg=True) parser.add_lora_arguments() - parser.add_image_generator_arguments(supports_metadata_config=True) + parser.add_image_generator_arguments(supports_metadata_config=False) parser.add_controlnet_arguments() parser.add_output_arguments() args = parser.parse_args() diff --git a/src/mflux/save.py b/src/mflux/save.py index db473ac..21c0b57 100644 --- a/src/mflux/save.py +++ b/src/mflux/save.py @@ -4,7 +4,7 @@ def main(): parser = CommandLineParser(description="Save a quantized version of Flux.1 to disk.") # fmt: off - parser.add_model_arguments(path_type="save") + parser.add_model_arguments(path_type="save", require_model_arg=True) parser.add_lora_arguments() args = parser.parse_args() diff --git a/src/mflux/ui/cli/parsers.py b/src/mflux/ui/cli/parsers.py index 3dad873..c6da341 100644 --- a/src/mflux/ui/cli/parsers.py +++ b/src/mflux/ui/cli/parsers.py @@ -17,9 +17,9 @@ def __init__(self, *pargs, **kwargs): self.supports_image_to_image = False self.supports_lora = False - def add_model_arguments(self, path_type: t.Literal["load", "save"] = "load") -> None: + def add_model_arguments(self, path_type: t.Literal["load", "save"] = "load", require_model_arg: bool = True) -> None: - self.add_argument("--model", "-m", type=str, required=True, choices=ui_defaults.MODEL_CHOICES, help=f"The model to use ({' or '.join(ui_defaults.MODEL_CHOICES)}).") + self.add_argument("--model", "-m", type=str, required=require_model_arg, 
choices=ui_defaults.MODEL_CHOICES, help=f"The model to use ({' or '.join(ui_defaults.MODEL_CHOICES)}).") if path_type == "load": self.add_argument("--path", type=str, default=None, help="Local path for loading a model from disk") @@ -74,11 +74,15 @@ def add_metadata_config(self) -> None: def parse_args(self, **kwargs) -> argparse.Namespace: namespace = super().parse_args() if hasattr(namespace, "path") and namespace.path is not None and namespace.model is None: - namespace.error("--model must be specified when using --path") + self.error("--model must be specified when using --path") if getattr(namespace, "config_from_metadata", False): prior_gen_metadata = json.load(namespace.config_from_metadata.open("rt")) + if namespace.model is None: + # when not provided by CLI flag, find it in the config file + namespace.model = prior_gen_metadata.get("model", None) + if namespace.prompt is None: namespace.prompt = prior_gen_metadata.get("prompt", None) @@ -118,8 +122,11 @@ def parse_args(self, **kwargs) -> argparse.Namespace: namespace.controlnet_image_path = prior_gen_metadata.get("controlnet_image_path", None) if namespace.controlnet_strength == self.get_default("controlnet_strength") and (cnet_strength_from_metadata := prior_gen_metadata.get("controlnet_strength", None)): namespace.controlnet_strength = cnet_strength_from_metadata + if namespace.controlnet_save_canny == self.get_default("controlnet_save_canny") and (cnet_canny_from_metadata := prior_gen_metadata.get("controlnet_save_canny", None)): + namespace.controlnet_save_canny = cnet_canny_from_metadata - + if namespace.model is None: + self.error("--model / -m must be provided, or 'model' must be specified in the config file.") if self.supports_image_generation and namespace.prompt is None: # not supplied by CLI and not supplied by metadata config file diff --git a/tests/test_cli_argparser.py b/tests/test_cli_argparser.py index d7a5698..0a043e1 100644 --- a/tests/test_cli_argparser.py +++ 
b/tests/test_cli_argparser.py @@ -9,7 +9,7 @@ def _create_mflux_generate_parser(with_controlnet=False) -> CommandLineParser: parser = CommandLineParser(description="Generate an image based on a prompt.") - parser.add_model_arguments() + parser.add_model_arguments(require_model_arg=False) parser.add_image_generator_arguments(supports_metadata_config=True) parser.add_lora_arguments() parser.add_image_to_image_arguments(required=False) @@ -32,19 +32,19 @@ def mflux_generate_controlnet_parser() -> CommandLineParser: @pytest.fixture def mflux_save_parser() -> CommandLineParser: parser = CommandLineParser(description="Save a quantized version of Flux.1 to disk.") # fmt: off - parser.add_model_arguments(path_type="save") + parser.add_model_arguments(path_type="save", require_model_arg=True) parser.add_lora_arguments() return parser @pytest.fixture def mflux_generate_minimal_argv() -> list[str]: - return ["mflux-generate", "--model", "schnell", "--prompt", "meaning of life"] + return ["mflux-generate", "--prompt", "meaning of life"] @pytest.fixture def mflux_generate_controlnet_minimal_argv() -> list[str]: - return ["mflux-generate-controlnet", "--model", "dev", "--prompt", "meaning of life, imitated"] + return ["mflux-generate-controlnet", "--prompt", "meaning of life, imitated"] @pytest.fixture @@ -58,9 +58,9 @@ def temp_dir(tmp_path_factory) -> Path: def base_metadata_dict() -> dict: return { "mflux_version": "0.4.0", - "model": "schnell", + "model": "dev", "seed": 42042, - "steps": 4, + "steps": 14, "guidance": None, "precision": "mlx.core.bfloat16", "quantize": None, @@ -71,6 +71,7 @@ def base_metadata_dict() -> dict: "init_image_strength": None, "controlnet_image": None, "controlnet_strength": None, + "controlnet_save_canny": False, } @@ -80,6 +81,39 @@ def test_model_path_requires_model_arg(mflux_generate_parser): assert pytest.raises(SystemExit, mflux_generate_parser.parse_args) +def test_model_arg_not_in_file(mflux_generate_parser, mflux_generate_minimal_argv, 
base_metadata_dict, temp_dir): + metadata_file = temp_dir / "model.json" + with metadata_file.open("wt") as m: + del base_metadata_dict["model"] + json.dump(base_metadata_dict, m, indent=4) + # test model arg not provided in either flag or file + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + pytest.raises(SystemExit, mflux_generate_parser.parse_args) + # test value read from flag + with patch('sys.argv', mflux_generate_minimal_argv + ['--model', 'dev', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.model == "dev" + # test value read from flag + with patch('sys.argv', mflux_generate_minimal_argv + ['--model', 'schnell', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.model == "schnell" + + +def test_model_arg_in_file(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): + metadata_file = temp_dir / "model.json" + with metadata_file.open("wt") as m: + base_metadata_dict["model"] = "dev" + json.dump(base_metadata_dict, m, indent=4) + # test value read from file + with patch('sys.argv', mflux_generate_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.model == "dev" + # test value read from flag, overrides value from file + with patch('sys.argv', mflux_generate_minimal_argv + ['--model', 'schnell', '--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_parser.parse_args() + assert args.model == "schnell" + + def test_prompt_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_metadata_dict, temp_dir): metadata_file = temp_dir / "prompt.json" file_prompt = "origin of the universe" @@ -148,8 +182,13 @@ def test_steps_arg(mflux_generate_parser, mflux_generate_minimal_argv, base_meta 
base_metadata_dict["steps"] = 8 json.dump(base_metadata_dict, m, indent=4) - # test user default value - with patch("sys.argv", mflux_generate_minimal_argv): + # test user default value for dev + with patch("sys.argv", mflux_generate_minimal_argv + ["--model", "dev"]): + args = mflux_generate_parser.parse_args() + assert args.steps == 14 + + # test user default value for schnell + with patch("sys.argv", mflux_generate_minimal_argv + ["--model", "schnell"]): args = mflux_generate_parser.parse_args() assert args.steps == 4 @@ -173,7 +212,7 @@ def test_lora_args(mflux_generate_parser, mflux_generate_minimal_argv, base_meta json.dump(base_metadata_dict, m, indent=4) # test user default value - with patch("sys.argv", mflux_generate_minimal_argv): + with patch("sys.argv", mflux_generate_minimal_argv + ["-m", "schnell"]): args = mflux_generate_parser.parse_args() assert args.lora_paths is None assert args.lora_scales is None @@ -184,7 +223,7 @@ def test_lora_args(mflux_generate_parser, mflux_generate_minimal_argv, base_meta assert args.lora_paths == test_paths assert args.lora_scales == [pytest.approx(0.3), pytest.approx(0.7)] - # test CLI override + # test CLI override that merges CLI loras and config file loras new_loras = ["--lora-paths", "/some/lora/3.safetensors", "/some/lora/4.safetensors", "--lora-scales", "0.1", "0.9"] with patch('sys.argv', mflux_generate_minimal_argv + new_loras + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off args = mflux_generate_parser.parse_args() @@ -202,7 +241,7 @@ def test_image_to_image_args(mflux_generate_parser, mflux_generate_minimal_argv, json.dump(base_metadata_dict, m, indent=4) # test user default value - with patch("sys.argv", mflux_generate_minimal_argv): + with patch("sys.argv", mflux_generate_minimal_argv + ["-m", "dev"]): args = mflux_generate_parser.parse_args() assert args.init_image_path is None assert args.init_image_strength == 0.4 # default @@ -239,13 +278,30 @@ def 
test_controlnet_args(mflux_generate_controlnet_parser, mflux_generate_contro args = mflux_generate_controlnet_parser.parse_args() assert args.controlnet_image_path == test_path assert args.controlnet_strength == pytest.approx(0.48) + assert args.controlnet_save_canny is False # test CLI override - override_cnet = ["--controlnet-image-path", "/some/lora/2.safetensors", "--controlnet-strength", "0.85"] + override_cnet = [ + "--controlnet-image-path", + "/some/lora/2.safetensors", + "--controlnet-strength", + "0.85", + "--controlnet-save-canny", + ] with patch('sys.argv', mflux_generate_controlnet_minimal_argv + override_cnet + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off args = mflux_generate_controlnet_parser.parse_args() assert args.controlnet_image_path == "/some/lora/2.safetensors" assert args.controlnet_strength == pytest.approx(0.85) + assert args.controlnet_save_canny is True + + # test controlnet_save_canny is False when not specified + with metadata_file.open("wt") as m: + del base_metadata_dict["controlnet_save_canny"] + json.dump(base_metadata_dict, m, indent=4) + + with patch('sys.argv', mflux_generate_controlnet_minimal_argv + ['--config-from-metadata', metadata_file.as_posix()]): # fmt: off + args = mflux_generate_controlnet_parser.parse_args() + assert args.controlnet_save_canny is False def test_save_args(mflux_save_parser): From 960447f92e768abaa49ffdcb5730fef535dc2fd8 Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Sat, 26 Oct 2024 20:24:28 -0700 Subject: [PATCH 4/4] update type hint --- src/mflux/config/runtime_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mflux/config/runtime_config.py b/src/mflux/config/runtime_config.py index 490f9ba..8b65bca 100644 --- a/src/mflux/config/runtime_config.py +++ b/src/mflux/config/runtime_config.py @@ -40,7 +40,7 @@ def num_train_steps(self) -> int: return self.model_config.num_train_steps @property - def 
init_image_path(self) -> int: + def init_image_path(self) -> str: return self.config.init_image_path @property