Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a very simple GUI using Kivy. #35

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
*.pyc
*.safetensors
*.json

request_history.db
36 changes: 29 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
mlx>=0.16.0
numpy>=2.0.1
certifi>=2024.8.30
charset-normalizer>=3.3.2
docutils>=0.21.2
filelock>=3.15.4
fsspec>=2024.9.0
huggingface-hub>=0.24.6
idna>=3.8
Jinja2>=3.1.4
Kivy>=2.3.0
Kivy-Garden>=0.1.5
MarkupSafe>=2.1.5
mlx>=0.17.2
mpmath>=1.3.0
networkx>=3.3
numpy>=2.1.1
packaging>=24.1
piexif>=1.1.3
pillow>=10.4.0
transformers>=4.44.0
Pygments>=2.18.0
PyYAML>=6.0.2
regex>=2024.7.24
requests>=2.32.3
safetensors>=0.4.5
sentencepiece>=0.2.0
torch>=2.3.1
setuptools>=74.1.2
sympy>=1.13.2
tokenizers==0.19.1
torch>=2.4.1
tqdm>=4.66.5
huggingface-hub>=0.24.5
safetensors>=0.4.4
piexif>=1.1.3
transformers>=4.44.2
typing_extensions>=4.12.2
urllib3>=2.2.2
5 changes: 4 additions & 1 deletion src/mflux/flux/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
if weights.quantization_level is not None:
self._set_model_weights(weights)

def generate_image(self, seed: int, prompt: str, config: Config = Config()) -> Image:
def generate_image(self, seed: int, prompt: str, config: Config = Config(), progress_callback=None) -> Image:
# Create a new runtime config based on the model type and input parameters
config = RuntimeConfig(config, self.model_config)
time_steps = tqdm(range(config.num_inference_steps))
Expand Down Expand Up @@ -103,6 +103,9 @@ def generate_image(self, seed: int, prompt: str, config: Config = Config()) -> I

# Evaluate to enable progress tracking
mx.eval(latents)
# Call the progress callback if provided
if progress_callback:
progress_callback(t + 1, config.num_inference_steps)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is worth implementing as an interface for all GUIs to hook into. Another consideration can expect progress_callback callable to return a bool to indicate Continue (True) or Abort (False) so that the UI can stop generate_image.

Should be easy to re-cut a branch from newest main and just implement this.


# 5. Decode the latent array and return the image
latents = Flux1._unpack_latents(latents, config.height, config.width)
Expand Down
12 changes: 11 additions & 1 deletion src/mflux/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import time
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

Expand Down Expand Up @@ -49,12 +50,21 @@ def main():
height=args.height,
width=args.width,
guidance=args.guidance,
)
),
progress_callback=tqdm_progress(args.steps,args.output)
)

# Save the image
image.save(path=args.output, export_json_metadata=args.metadata)


def tqdm_progress(total_steps,output_filename):
pbar = tqdm(total=total_steps, desc=output_filename)
def update(step, _):
pbar.update(1)
if step == total_steps:
pbar.close()
return update

if __name__ == '__main__':
main()
Loading