
Use other model file based on Flux-schnell? #102

puppyapple opened this issue Nov 22, 2024 · 4 comments

Comments

@puppyapple

Thanks for the great work!
Recently I've seen new models like shuttleai/shuttle-3-diffusion on Hugging Face, which is fine-tuned from flux-schnell.
Is there any way to use mflux with these models?

@anthonywu
Collaborator

@puppyapple if you're up for installing a dev build, try:

pip install git+https://github.com/anthonywu/mflux.git@support-hf-models

mflux-generate \
  --base-model schnell \
  --steps 4 \
  --model shuttleai/shuttle-3-diffusion \
  --prompt "A dog holding a sign that says mflux rocks"

Make sure to accept the repo's terms of service at https://huggingface.co/shuttleai/shuttle-3-diffusion before using mflux / huggingface-hub to download the model to disk.
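
If you want to pre-fetch the weights yourself instead of letting mflux trigger the download, here is a minimal sketch using huggingface_hub (assuming you have already accepted the terms and authenticated, e.g. via huggingface-cli login):

# Optional pre-fetch of the repo; mflux / huggingface-hub would otherwise
# download it on first use. Requires accepted terms + a valid HF token.
from huggingface_hub import snapshot_download

local_path = snapshot_download(repo_id="shuttleai/shuttle-3-diffusion")
print(local_path)  # resolves into ~/.cache/huggingface/hub/...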

sample from my local gen using the branch in #103

[sample image]

@tariks

tariks commented Nov 23, 2024

@anthonywu for this syntax, how does quantization work? For shuttle-3-diffusion, what's the difference between --quant 8 on the full model and cloning the fp8 version?

@anthonywu
Collaborator

anthonywu commented Nov 23, 2024

Here's my investigation.

TL;DR

  • The mflux-converted q8 model works, as I confirm in this reply.
  • The upstream fp8 doesn't work because that HF repo does not contain the full assets. You may be able to get it working by locally stitching the repo assets together (see the sketch near the end of this comment).

download the official models

# official bfloat16 model
huggingface-cli download shuttleai/shuttle-3-diffusion 

# official fp8 model
huggingface-cli download shuttleai/shuttle-3-diffusion-fp8

observe the model size on disk

du -sh ~/.cache/huggingface/hub/models--shuttleai--shuttle-3-diffusion*
 54G	~/.cache/huggingface/hub/models--shuttleai--shuttle-3-diffusion
 11G	~/.cache/huggingface/hub/models--shuttleai--shuttle-3-diffusion-fp8
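
If you prefer to get those numbers programmatically rather than via du, a rough sketch with huggingface_hub's cache scanner (assuming the default cache location) is:

# List cached repos and their size on disk via the hub cache scanner.
from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()
for repo in cache_info.repos:
    if "shuttle-3-diffusion" in repo.repo_id:
        print(f"{repo.repo_id}: {repo.size_on_disk / 1e9:.1f} GB")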

use mflux to save a q8 model

this should work in the latest commit of PR #103

mflux-save \
  -m shuttleai/shuttle-3-diffusion \
  --base-model schnell \
  -q 8 \
  --path /tmp/shuttle-3-diffusion-q8

this produces a local 17G converted model (compared to the 11G fp8 produced by their util: https://huggingface.co/shuttleai/shuttle-3-diffusion-fp8/blob/main/convert.py)

du -sh /tmp/shuttle-3-diffusion-q8
 17G	/tmp/shuttle-3-diffusion-q8

In mflux's model_saver.py, the conversion is passed to mlx.core.save_safetensors; this is its help():

save_safetensors = <nanobind.nb_func object>
    save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)

    Save array(s) to a binary file in ``.safetensors`` format.

    See the `Safetensors documentation
    <https://huggingface.co/docs/safetensors/index>`_ for more
    information on the format.

    Args:
        file (file, str): File in which the array is saved.
        arrays (dict(str, array)): The dictionary of names to arrays
            to be saved.
        metadata (dict(str, str), optional): The dictionary of metadata
            to be saved.

the caller is:

    def save_weights(base_path: str, bits: int, model: nn.Module, subdir: str):
        path = Path(base_path) / subdir
        path.mkdir(parents=True, exist_ok=True)
        weights = ModelSaver._split_weights(base_path, dict(tree_flatten(model.parameters())))
        for i, weight in enumerate(weights):
            mx.save_safetensors(
                str(path / f"{i}.safetensors"),
                weight,
                {"quantization_level": str(bits)},
            )
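
For illustration only (not mflux code), here is a minimal round-trip showing what that quantization_level tag looks like on disk; it assumes your mlx version supports return_metadata on mx.load:

# Write a tiny .safetensors file with the same kind of metadata tag that
# ModelSaver attaches to each shard, then read it back.
import mlx.core as mx

weights = {"layer.weight": mx.zeros((4, 4))}
mx.save_safetensors("/tmp/example.safetensors", weights, {"quantization_level": "8"})

arrays, metadata = mx.load("/tmp/example.safetensors", return_metadata=True)
print(metadata)  # {'quantization_level': '8'}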

So the difference between the official 11G fp8 version and the 17G mflux-converted q8 version comes down to low-level differences between MLX quantization and the PyTorch fp8 conversion.
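
To make that concrete: MLX uses group-wise affine quantization (packed integer weights plus per-group scales and biases), which is a different on-disk representation from a straight per-element fp8 cast. A small sketch of the MLX side:

# Group-wise affine quantization in MLX: weights are packed into uint32
# arrays and accompanied by per-group scales/biases used to dequantize.
import mlx.core as mx

w = mx.random.normal((4096, 4096))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=8)

print(w_q.shape, w_q.dtype)      # packed representation, not one byte per element
print(mx.abs(w - w_hat).max())   # small reconstruction error

This packed-plus-scales layout is also why an mflux q8 shard isn't byte-compatible with a torch fp8 checkpoint.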

compare outputs

generating with official bfloat16

mflux-generate \
  --seed 42 \
  --base-model schnell \
  --model shuttleai/shuttle-3-diffusion \
  --prompt "a cat holding a sign saying meow" \
  --steps 4

[output image: official bfloat16]

generating with mflux-converted q8

mflux-generate \
  --seed 42 \
  --model schnell \
  --path /tmp/shuttle-3-diffusion-q8 \
  --prompt "a cat holding a sign saying meow" \
  --steps 4 

[output image: mflux q8 (seed42-mflux-shuttle-q8)]

The two images look almost identical; watch for a small difference in the lower right corner of the sign.

On my M1 Max (64 GB), generation time was 1m24s with q8 and 1m28s with bfloat16, which is almost identical, so my conclusion is that q8 isn't worth it for the speed improvement alone for now. I presume q8 would use less RAM, but I haven't had time to instrument that.
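
For a rough RAM comparison without touching mflux itself, one stdlib-only option is to wrap each run and read the child process's peak RSS (a sketch; note that ru_maxrss is reported in bytes on macOS but kilobytes on Linux):

# Usage: python peak_rss.py mflux-generate --seed 42 --model schnell ...
import resource
import subprocess
import sys

subprocess.run(sys.argv[1:], check=True)
peak = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
print(f"peak RSS of child processes: {peak / 1e9:.1f} GB (macOS units)")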

generating with upstream fp8

mflux-generate \
  --seed 42 \
  --base-model schnell \
  --model shuttleai/shuttle-3-diffusion-fp8 \
  --prompt "a cat holding a sign saying meow" \
  --steps 4

This doesn't actually work out of the box because the HF repo for fp8 does not contain all the assets from the bfloat16 version! That's probably why it weighs in at 11G rather than closer to 17G, so a full fp8 repo may not be much smaller after all.


The error from mflux is: OSError: Incorrect path_or_model_id: '/Users/anthonywu/.cache/huggingface/hub/models--shuttleai--shuttle-3-diffusion-fp8/snapshots/c91030ae2199a2a57f9e0814f6f427692f15aabd/tokenizer'. Please provide either the path to a local folder or the repo_id of a model on the Hub. The reason is that the tokenizer and adjacent assets are not in the fp8 repo.

If the official fp8 repo were laid out like the full version, then I think the HF model would just work.
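
For anyone who wants to try the "locally stitching the repo assets together" idea from the TL;DR, here is a hypothetical sketch; the folder list is my guess based on the tokenizer error above, not something I have verified:

# Start from the fp8 snapshot and copy in the folders it appears to lack
# from the full bfloat16 repo (folder names are assumptions).
import shutil
from pathlib import Path

from huggingface_hub import snapshot_download

missing = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "vae"]

fp8_dir = Path(snapshot_download("shuttleai/shuttle-3-diffusion-fp8"))
full_dir = Path(snapshot_download(
    "shuttleai/shuttle-3-diffusion",
    allow_patterns=[f"{d}/*" for d in missing],
))

merged = Path("/tmp/shuttle-3-diffusion-fp8-stitched")
shutil.copytree(fp8_dir, merged, dirs_exist_ok=True)
for sub in missing:
    if (full_dir / sub).exists() and not (merged / sub).exists():
        shutil.copytree(full_dir / sub, merged / sub)

Whether mflux would then accept the stitched layout (and the torch fp8 transformer weights inside it) is exactly the open question above.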

[screenshot: file layout of the full repo]

So as of this writing, the only way to get out-of-the-box q8 functionality is to use mflux to convert the model to q8 locally.

This does create a TODO to better surface errors for HF repos that do not ship their full assets.
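
Something like this hypothetical pre-flight check (illustrative only; the expected folder list is an assumption about the diffusers-style layout) could turn that OSError into a friendlier message:

# Hypothetical check mflux could run before loading, to explain exactly
# which parts of the repo snapshot are missing.
from pathlib import Path

EXPECTED_SUBDIRS = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "transformer", "vae"]

def check_repo_layout(snapshot_path: str) -> None:
    missing = [d for d in EXPECTED_SUBDIRS if not (Path(snapshot_path) / d).is_dir()]
    if missing:
        raise ValueError(
            f"Repo at {snapshot_path} is missing expected folders: {missing}. "
            "It may be a partial upload; use the full repo or copy the missing assets in locally."
        )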

@tariks

tariks commented Nov 23, 2024

Thank you! Happy to confirm this all works as you say with the PR.
