From a3376d9e015fddfddc2e6cccc4a55de639db0511 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 16:31:36 -0500 Subject: [PATCH 01/89] Bump punet revision to d30d6ff --- models/turbine_models/custom_models/sdxl_inference/unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index bd36db763..9e9d17bb7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,11 +92,11 @@ def get_punet_model(hf_model_name, external_weight_path, precision="i8"): if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "942e771bf0c2657a8b33380103d04747a75dfa4a" + revision = "d30d6ff79abb584bf2addc7866738df5242f315a" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" - revision = "76d28af79639c28a79fa5c6c6468febd3490a37e" + revision = "d30d6ff79abb584bf2addc7866738df5242f315a" def download(filename): return hf_hub_download( From 7cabac0597a8b3c5495e9fef17c615e66934ff46 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 16:33:54 -0500 Subject: [PATCH 02/89] Enable punet t2i test. --- models/turbine_models/tests/sdxl_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 216b6ff59..a06ffb657 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -447,7 +447,6 @@ def test05_t2i_generate_images(self): ) assert output is not None - @pytest.mark.skip(reason="Needs sdxl_quantized branch of IREE") def test06_t2i_generate_images_punet(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( From 7dfd4c86a1f50e5acc492c322cea6b0107401695 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 17:09:54 -0500 Subject: [PATCH 03/89] Use formatted strings as input to printer. 
--- models/turbine_models/custom_models/pipeline_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 5c02649a1..57a990ac2 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -723,7 +723,7 @@ def export_submodel( def load_map(self): for submodel in self.map.keys(): if not self.map[submodel]["load"]: - self.printer.print("Skipping load for ", submodel) + self.printer.print(f"Skipping load for {submodel}") continue self.load_submodel(submodel) From 1cd3ee9a22cd6dc2a3f4ff8b0ea398d48954a708 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 18:29:22 -0500 Subject: [PATCH 04/89] Rework sdxl test to setup with a pipeline, fix unloading submodels, factor out punet setup from pipe init --- .../custom_models/pipeline_base.py | 4 + .../custom_models/sd_inference/sd_pipeline.py | 20 +- models/turbine_models/tests/sdxl_test.py | 220 +++++------------- 3 files changed, 68 insertions(+), 176 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 57a990ac2..22b44e4e0 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -751,6 +751,10 @@ def load_submodel(self, submodel): def unload_submodel(self, submodel): self.map[submodel]["runner"].unload() + self.map[submodel]["vmfb"] = None + self.map[submodel]["mlir"] = None + self.map[submodel]["weights"] = None + self.map[submodel]["export_args"]["input_mlir"] = None setattr(self, submodel, None) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 277f74cb6..79270167d 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -351,13 +351,15 @@ def __init__( self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet + if self.use_punet: + self.setup_punet() + else: + self.map["unet"]["keywords"].append("!punet") + self.map["unet"]["function_name"] = "run_forward" + + def setup_punet(self): if self.use_i8_punet: self.map["unet"]["export_args"]["precision"] = "i8" - self.map["unet"]["export_args"]["use_punet"] = True - self.map["unet"]["use_weights_for_export"] = True - self.map["unet"]["keywords"].append("punet") - self.map["unet"]["module_name"] = "compiled_punet" - self.map["unet"]["function_name"] = "main" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" ) @@ -365,9 +367,11 @@ def __init__( if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" break - else: - self.map["unet"]["keywords"].append("!punet") - self.map["unet"]["function_name"] = "run_forward" + self.map["unet"]["export_args"]["use_punet"] = True + self.map["unet"]["use_weights_for_export"] = True + self.map["unet"]["keywords"].append("punet") + self.map["unet"]["module_name"] = "compiled_punet" + self.map["unet"]["function_name"] = "main" # LOAD diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index a06ffb657..5d16d55d8 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -80,28 +80,46 @@ def 
command_line_args(request): @pytest.mark.usefixtures("command_line_args") class StableDiffusionXLTest(unittest.TestCase): def setUp(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) + self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") + decomp_attn = { + "text_encoder": True, + "unet": False, + "vae": True, + } + self.pipe = SharkSDPipeline( + arguments["hf_model_name"], + arguments["height"], + arguments["width"], + arguments["batch_size"], + arguments["max_length"], + arguments["precision"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags=None, # ireec_flags + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, # shift + use_i8_punet=False, + ) + self.pipe.prepare_all() - def test01_ExportPromptEncoder(self): + def test01_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." ) - arguments["external_weight_path"] = ( - "prompt_encoder." + arguments["external_weights"] - ) - prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( - arguments["hf_model_name"], - hf_auth_token=None, - max_length=arguments["max_length"], - batch_size=arguments["batch_size"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights="safetensors", - external_weight_path=arguments["external_weight_path"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ) + arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] tokenizer_1 = CLIPTokenizer.from_pretrained( arguments["hf_model_name"], subfolder="tokenizer", @@ -157,36 +175,10 @@ def test01_ExportPromptEncoder(self): def test02_ExportUnetModel(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") - unet_vmfb = unet.export_unet_model( - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - max_length=arguments["max_length"], - hf_auth_token=None, - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_unet." - + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=arguments["decomp_attn"], - attn_spec=arguments["attn_spec"], - exit_on_vmfb=False, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_unet." 
- + arguments["external_weights"] - ) - arguments["vmfb_path"] = unet_vmfb + + arguments["vmfb_path"] = self.pipe.map["unet"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 sample = torch.rand( ( @@ -252,34 +244,9 @@ def test02_ExportUnetModel(self): def test03_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - vae_vmfb = vae.export_vae_model( - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_decode." - + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=True, - attn_spec=arguments["attn_spec"], - exit_on_vmfb=False, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_decode." - + arguments["external_weights"] - ) - arguments["vmfb_path"] = vae_vmfb + + arguments["vmfb_path"] = self.pipe.map["unet"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] example_input = torch.ones( arguments["batch_size"], 4, @@ -328,35 +295,8 @@ def test04_ExportVaeModelEncode(self): self.skipTest( "Compilation error on cpu, vulkan and rocm; To be tested on cuda." ) - vae_vmfb = vae.export_vae_model( - vae_model=self.vae_model, - # This is a public model, so no auth required - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_encode." - + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=True, - exit_on_vmfb=True, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_encode." 
- + arguments["external_weights"] - ) - arguments["vmfb_path"] = vae_vmfb + arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["vae"]["weights"] example_input = torch.ones( arguments["batch_size"], 3, @@ -402,39 +342,9 @@ def test04_ExportVaeModelEncode(self): def test05_t2i_generate_images(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Have issues with submodels on vulkan, cuda") - from turbine_models.custom_models.sd_inference.sd_pipeline import ( - SharkSDPipeline, - ) - decomp_attn = { - "text_encoder": False, - "unet": False, - "vae": True, - } - sd_pipe = SharkSDPipeline( - arguments["hf_model_name"], - arguments["height"], - arguments["width"], - arguments["batch_size"], - arguments["max_length"], - arguments["precision"], - arguments["device"], - arguments["iree_target_triple"], - ireec_flags=None, # ireec_flags - attn_spec=arguments["attn_spec"], - decomp_attn=decomp_attn, - pipeline_dir="test_vmfbs", # pipeline_dir - external_weights_dir="test_weights", # external_weights_dir - external_weights=arguments["external_weights"], - num_inference_steps=arguments["num_inference_steps"], - cpu_scheduling=True, - scheduler_id=arguments["scheduler_id"], - shift=None, # shift - use_i8_punet=False, - ) - sd_pipe.prepare_all() - sd_pipe.load_map() - output = sd_pipe.generate_images( + self.pipe.load_map() + output = self.pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], arguments["num_inference_steps"], @@ -452,39 +362,13 @@ def test06_t2i_generate_images_punet(self): self.skipTest( "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." ) - from turbine_models.custom_models.sd_inference.sd_pipeline import ( - SharkSDPipeline, - ) - - decomp_attn = { - "text_encoder": False, - "unet": False, - "vae": True, - } - sd_pipe = SharkSDPipeline( - arguments["hf_model_name"], - arguments["height"], - arguments["width"], - arguments["batch_size"], - arguments["max_length"], - arguments["precision"], - arguments["device"], - arguments["iree_target_triple"], - ireec_flags=None, # ireec_flags - attn_spec=arguments["attn_spec"], - decomp_attn=decomp_attn, - pipeline_dir="test_vmfbs", # pipeline_dir - external_weights_dir="test_weights", # external_weights_dir - external_weights=arguments["external_weights"], - num_inference_steps=arguments["num_inference_steps"], - cpu_scheduling=True, - scheduler_id=arguments["scheduler_id"], - shift=None, # shift - use_i8_punet=True, - ) - sd_pipe.prepare_all() - sd_pipe.load_map() - output = sd_pipe.generate_images( + self.pipe.unload_submodel("unet") + self.pipe.use_punet = True + self.pipe.use_i8_punet = True + self.pipe.setup_punet() + self.pipe.prepare_all() + self.pipe.load_map() + output = self.pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], arguments["num_inference_steps"], From 1a90abd22d59d1870094d37ab62b761b50490d99 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 19:25:05 -0500 Subject: [PATCH 05/89] Add switch for punet preprocessing flags --- .../custom_models/sd_inference/utils.py | 25 ++++++++++++------- .../custom_models/sdxl_inference/unet.py | 2 ++ models/turbine_models/tests/sdxl_test.py | 8 +++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index cc8591b9e..d3ac79ee0 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ 
b/models/turbine_models/custom_models/sd_inference/utils.py @@ -28,6 +28,9 @@ "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], @@ -66,6 +69,9 @@ "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], @@ -153,7 +159,7 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - masked_attention=False, + flagset_keyword="", debug=False, ): flags = [] @@ -235,15 +241,19 @@ def compile_to_vmfb( elif "vae" in safe_name: flags.extend(MI_flags["vae"]) flags.extend(MI_flags["all"]) - if masked_attention: - flags.extend(GFX11_flags["pad_attention"]) + if "masked_attention" in flagset_keyword: + flags.extend(MI_flags["pad_attention"]) + elif "punet" in flagset_keyword: + flags.extend(MI_flags["punet"]) else: - flags.extend(GFX11_flags["preprocess_default"]) + flags.extend(MI_flags["preprocess_default"]) if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) - if masked_attention: + if "masked_attention" in flagset_keyword: flags.extend(GFX11_flags["pad_attention"]) + elif "punet" in flagset_keyword: + flags.extend(GFX11_flags["punet"]) else: flags.extend(GFX11_flags["preprocess_default"]) @@ -257,15 +267,12 @@ def compile_to_vmfb( attn_spec = get_mfma_spec_path( target_triple, os.path.dirname(safe_name), - masked_attention, use_punet=use_punet, ) flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): - attn_spec = get_wmma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention - ) + attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name)) if attn_spec: flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec and attn_spec != "None": diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 9e9d17bb7..4d8db3ac4 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -208,6 +208,7 @@ def 
export_unet_model( mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, + flagset_keyword="punet" if use_punet else None, ) return vmfb_path elif use_punet: @@ -340,6 +341,7 @@ class CompiledUnet(CompiledModule): safe_name, return_path=True, attn_spec=attn_spec, + flagset_keyword="punet" if use_punet else None, ) if exit_on_vmfb: exit() diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 5d16d55d8..4e1cdc9f7 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -99,16 +99,16 @@ def setUp(self): arguments["precision"], arguments["device"], arguments["iree_target_triple"], - ireec_flags=None, # ireec_flags + ireec_flags=None, attn_spec=arguments["attn_spec"], decomp_attn=decomp_attn, - pipeline_dir="test_vmfbs", # pipeline_dir - external_weights_dir="test_weights", # external_weights_dir + pipeline_dir="test_vmfbs", + external_weights_dir="test_weights", external_weights=arguments["external_weights"], num_inference_steps=arguments["num_inference_steps"], cpu_scheduling=True, scheduler_id=arguments["scheduler_id"], - shift=None, # shift + shift=None, use_i8_punet=False, ) self.pipe.prepare_all() From b70318dd7404243cce10b4c7cc157a74b07e285c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 12 Jul 2024 23:08:17 -0500 Subject: [PATCH 06/89] Xfail punet e2e test. --- models/turbine_models/tests/sdxl_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 4e1cdc9f7..0959c6ea2 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -357,6 +357,7 @@ def test05_t2i_generate_images(self): ) assert output is not None + @pytest.mark.xfail(reason="compilation issue on gfx90a") def test06_t2i_generate_images_punet(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( From 2d7ebcd5b293437ec9b964e711a8386c56c1e555 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 11:42:29 -0500 Subject: [PATCH 07/89] Fixups to sdxl test arguments --- models/turbine_models/tests/sdxl_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 0959c6ea2..38de26dd7 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -144,7 +144,7 @@ def test01_PromptEncoder(self): turbine_output1, turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( - prompt_encoder_vmfb, + arguments["vmfb_path"], arguments["rt_device"], arguments["external_weight_path"], text_input_ids_list, @@ -161,7 +161,7 @@ def test01_PromptEncoder(self): if arguments["benchmark"] or arguments["tracy_profile"]: run_benchmark( "prompt_encoder", - prompt_encoder_vmfb, + arguments["vmfb_path"], arguments["external_weight_path"], arguments["rt_device"], max_length=arguments["max_length"], @@ -245,7 +245,7 @@ def test03_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - arguments["vmfb_path"] = self.pipe.map["unet"]["vmfb"] + arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] example_input = torch.ones( arguments["batch_size"], From feebc87315f703f16bb5e5fef1a9a126ecb7456f Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 14:10:26 
-0500 Subject: [PATCH 08/89] Fix flagset arg and enable vae encode. --- .../custom_models/sd_inference/utils.py | 10 +++++----- .../custom_models/sd_inference/vae.py | 19 +++++++++++++------ .../custom_models/sdxl_inference/unet.py | 4 ++-- models/turbine_models/tests/sdxl_test.py | 3 ++- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index d3ac79ee0..448b67910 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -159,7 +159,7 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - flagset_keyword="", + flagset_keywords=[], debug=False, ): flags = [] @@ -241,18 +241,18 @@ def compile_to_vmfb( elif "vae" in safe_name: flags.extend(MI_flags["vae"]) flags.extend(MI_flags["all"]) - if "masked_attention" in flagset_keyword: + if "masked_attention" in flagset_keywords: flags.extend(MI_flags["pad_attention"]) - elif "punet" in flagset_keyword: + elif "punet" in flagset_keywords: flags.extend(MI_flags["punet"]) else: flags.extend(MI_flags["preprocess_default"]) if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) - if "masked_attention" in flagset_keyword: + if "masked_attention" in flagset_keywords: flags.extend(GFX11_flags["pad_attention"]) - elif "punet" in flagset_keyword: + elif "punet" in flagset_keywords: flags.extend(GFX11_flags["punet"]) else: flags.extend(GFX11_flags["preprocess_default"]) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 7ccd12c48..56e932760 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -65,6 +65,11 @@ def decode(self, inp): return (x / 2 + 0.5).clamp(0, 1) def encode(self, inp): + image_np = inp / 255.0 + image_np = np.moveaxis(image_np, 2, 0) + batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + image_torch = torch.from_numpy(batch_images) + image_torch = 2.0 * image_torch - 1.0 latents = self.vae.encode(inp).latent_dist.sample() return self.vae.config.scaling_factor * latents @@ -97,7 +102,7 @@ def encode(self, inp): latent = self.vae.encode(image_torch) return latent - +@torch.no_grad() def export_vae_model( hf_model_name, batch_size, @@ -167,12 +172,12 @@ def export_vae_model( if weights_only: return external_weight_path - input_image_shape = (height, width, 3) + input_image_shape = (batch_size, 3, height, width) input_latents_shape = (batch_size, num_channels, height // 8, width // 8) encode_args = [ torch.empty( input_image_shape, - dtype=torch.float32, + dtype=dtype, ) ] decode_args = [ @@ -195,9 +200,9 @@ def export_vae_model( fxb = FxProgramsBuilder(vae_model) # TODO: fix issues with exporting the encode function. 
- # @fxb.export_program(args=(encode_args,)) - # def _encode(module, inputs,): - # return module.encode(*inputs) + @fxb.export_program(args=(encode_args,)) + def _encode(module, inputs,): + return module.encode(*inputs) @fxb.export_program(args=(decode_args,)) def _decode(module, inputs): @@ -205,6 +210,7 @@ def _decode(module, inputs): class CompiledVae(CompiledModule): decode = _decode + encode = _encode if external_weights: externalize_module_parameters(vae_model) @@ -228,6 +234,7 @@ class CompiledVae(CompiledModule): "output_dtypes": [np_dtype], } module = AddMetadataPass(module, model_metadata_decode, "decode").run() + module = AddMetadataPass(module, model_metadata_decode, "encode").run() if compile_to != "vmfb": return str(module) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 4d8db3ac4..a5f9a5b21 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -208,7 +208,7 @@ def export_unet_model( mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, - flagset_keyword="punet" if use_punet else None, + flagset_keyword=["punet"] if use_punet else [], ) return vmfb_path elif use_punet: @@ -341,7 +341,7 @@ class CompiledUnet(CompiledModule): safe_name, return_path=True, attn_spec=attn_spec, - flagset_keyword="punet" if use_punet else None, + flagset_keywords=["punet"] if use_punet else [], ) if exit_on_vmfb: exit() diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 38de26dd7..4a44ecc69 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -88,7 +88,7 @@ def setUp(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": True, + "vae": False, } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -257,6 +257,7 @@ def test03_ExportVaeModelDecode(self): example_input_torch = example_input if arguments["precision"] == "fp16": example_input = example_input.half() + breakpoint() turbine = vae_runner.run_vae( arguments["rt_device"], example_input, From af7782b0066ce052f65ee648b85ec68fa2a095cb Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 15:25:31 -0500 Subject: [PATCH 09/89] Enable VAE encode validation, mark as xfail --- .../custom_models/sd_inference/vae.py | 5 -- .../custom_models/sd_inference/vae_runner.py | 62 +++++-------------- models/turbine_models/tests/sdxl_test.py | 46 +++++--------- 3 files changed, 30 insertions(+), 83 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 56e932760..4c4cd5e84 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -65,11 +65,6 @@ def decode(self, inp): return (x / 2 + 0.5).clamp(0, 1) def encode(self, inp): - image_np = inp / 255.0 - image_np = np.moveaxis(image_np, 2, 0) - batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) - image_torch = torch.from_numpy(batch_images) - image_torch = 2.0 * image_torch - 1.0 latents = self.vae.encode(inp).latent_dist.sample() return self.vae.config.scaling_factor * latents diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index 166021631..f46ec791b 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py 
+++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -16,63 +16,29 @@ def run_vae_decode( return results +def run_vae_encode( + device, example_input, vmfb_path, hf_model_name, external_weight_path +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) -def run_torch_vae_decode(hf_model_name, variant, example_input): - from diffusers import AutoencoderKL + inputs = [ireert.asdevicearray(runner.config.device, example_input)] - class VaeModel(torch.nn.Module): - def __init__( - self, - hf_model_name, - base_vae=False, - custom_vae="", - low_cpu_mem_usage=False, - hf_auth_token="", - ): - super().__init__() - self.vae = None - if custom_vae == "": - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - else: - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - self.vae.load_state_dict(custom_vae) - self.base_vae = base_vae - - def decode_inp(self, input): - with torch.no_grad(): - input = 1 / 0.18215 * input - x = self.vae.decode(input, return_dict=False)[0] - return (x / 2 + 0.5).clamp(0, 1) - - def encode_inp(self, inp): - latents = self.vae.encode(inp).latent_dist.sample() - return 0.18215 * latents + results = runner.ctx.modules.compiled_vae["encode"](*inputs).to_host() + + return results + +def run_torch_vae(hf_model_name, variant, example_input): + from diffusers import AutoencoderKL + from turbine_models.custom_models.sd_inference.vae import VaeModel vae_model = VaeModel( hf_model_name, ) if variant == "decode": - results = vae_model.decode_inp(example_input) + results = vae_model.decode(example_input) elif variant == "encode": - results = vae_model.encode_inp(example_input) + results = vae_model.encode(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 4a44ecc69..81bef87a9 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -10,16 +10,12 @@ import shutil from transformers import CLIPTokenizer from turbine_models.custom_models.sd_inference.utils import create_safe_name -from turbine_models.custom_models.sd_inference import schedulers, vae +from turbine_models.custom_models.sd_inference import schedulers, vae, vae_runner from turbine_models.custom_models.sdxl_inference import ( sdxl_prompt_encoder, sdxl_prompt_encoder_runner, unet, unet_runner, - sdxl_scheduled_unet, - sdxl_scheduled_unet_runner, - vae_runner, - sdxl_compiled_pipeline, ) from turbine_models.utils.sdxl_benchmark import run_benchmark import unittest @@ -88,7 +84,7 @@ def setUp(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": False, + "vae": True, } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -145,7 +141,7 @@ def test01_PromptEncoder(self): turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( arguments["vmfb_path"], - arguments["rt_device"], + self.pipe.map["text_encoder"]["driver"], arguments["external_weight_path"], text_input_ids_list, uncond_input_ids_list, @@ -163,7 +159,7 @@ def test01_PromptEncoder(self): "prompt_encoder", arguments["vmfb_path"], 
arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["text_encoder"]["driver"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) @@ -199,7 +195,7 @@ def test02_ExportUnetModel(self): guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) turbine = unet_runner.run_unet( - arguments["rt_device"], + self.pipe.map["unet"]["driver"], sample, timestep, prompt_embeds, @@ -227,7 +223,7 @@ def test02_ExportUnetModel(self): "unet", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["unet"]["driver"], max_length=arguments["max_length"], height=arguments["height"], width=arguments["width"], @@ -257,21 +253,15 @@ def test03_ExportVaeModelDecode(self): example_input_torch = example_input if arguments["precision"] == "fp16": example_input = example_input.half() - breakpoint() - turbine = vae_runner.run_vae( - arguments["rt_device"], + turbine = vae_runner.run_vae_decode( + self.pipe.map["vae"]["driver"], example_input, arguments["vmfb_path"], arguments["hf_model_name"], - arguments["external_weight_path"], + self.pipe.map["vae"]["weights"], ) torch_output = vae_runner.run_torch_vae( arguments["hf_model_name"], - ( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else "" - ), "decode", example_input_torch, ) @@ -280,7 +270,7 @@ def test03_ExportVaeModelDecode(self): "vae_decode", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["vae"]["driver"], height=arguments["height"], width=arguments["width"], precision=arguments["precision"], @@ -291,6 +281,7 @@ def test03_ExportVaeModelDecode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) + @pytest.xfail(reason="NaN output on rocm, needs triage and file") def test04_ExportVaeModelEncode(self): if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: self.skipTest( @@ -308,20 +299,15 @@ def test04_ExportVaeModelEncode(self): example_input_torch = example_input if arguments["precision"] == "fp16": example_input = example_input.half() - turbine = vae_runner.run_vae( - arguments["rt_device"], + turbine = vae_runner.run_vae_encode( + self.pipe.map["vae"]["driver"], example_input, arguments["vmfb_path"], arguments["hf_model_name"], - arguments["external_weight_path"], + self.pipe.map["vae"]["weights"], ) torch_output = vae_runner.run_torch_vae( arguments["hf_model_name"], - ( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else "" - ), "encode", example_input_torch, ) @@ -329,8 +315,8 @@ def test04_ExportVaeModelEncode(self): run_benchmark( "vae_encode", arguments["vmfb_path"], - arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["vae"]["weights"], + self.pipe.map["vae"]["driver"], height=arguments["height"], width=arguments["width"], precision=arguments["precision"], From eff59a9288b8e3e39202082a6c838c575b54ea09 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 15:28:55 -0500 Subject: [PATCH 10/89] Fix formatting --- models/turbine_models/custom_models/sd_inference/vae.py | 6 +++++- .../turbine_models/custom_models/sd_inference/vae_runner.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 4c4cd5e84..8b18f65fa 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -97,6 
+97,7 @@ def encode(self, inp): latent = self.vae.encode(image_torch) return latent + @torch.no_grad() def export_vae_model( hf_model_name, @@ -196,7 +197,10 @@ def export_vae_model( # TODO: fix issues with exporting the encode function. @fxb.export_program(args=(encode_args,)) - def _encode(module, inputs,): + def _encode( + module, + inputs, + ): return module.encode(*inputs) @fxb.export_program(args=(decode_args,)) diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index f46ec791b..81a1735df 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -16,6 +16,7 @@ def run_vae_decode( return results + def run_vae_encode( device, example_input, vmfb_path, hf_model_name, external_weight_path ): @@ -27,6 +28,7 @@ def run_vae_encode( return results + def run_torch_vae(hf_model_name, variant, example_input): from diffusers import AutoencoderKL from turbine_models.custom_models.sd_inference.vae import VaeModel From 63fb0539a4fadb8c2a44d3b44d1b8a9218f8b917 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 15:48:00 -0500 Subject: [PATCH 11/89] fix runner function name in old sd test. --- models/turbine_models/tests/sd_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 674e7d81b..98a3cfca2 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -208,7 +208,7 @@ def testExportVaeModelDecode(self): current_args["hf_model_name"], current_args["external_weight_path"], ) - torch_output = vae_runner.run_torch_vae_decode( + torch_output = vae_runner.run_torch_vae( current_args["hf_model_name"], "decode", example_input, From aff48ab3ff5ee895a5aac8032881fad91e7a817b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 16:09:41 -0500 Subject: [PATCH 12/89] Fix xfail syntax. 
--- models/turbine_models/tests/sdxl_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 81bef87a9..060e3a138 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -281,7 +281,7 @@ def test03_ExportVaeModelDecode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) - @pytest.xfail(reason="NaN output on rocm, needs triage and file") + @pytest.mark.xfail(reason="NaN output on rocm, needs triage and file") def test04_ExportVaeModelEncode(self): if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: self.skipTest( From b10ad8d414763c8ea17fef1ace3adea4a017a85d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 15 Jul 2024 18:04:35 -0500 Subject: [PATCH 13/89] Update unet script for compile function signature change --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index a5f9a5b21..999599039 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -208,7 +208,7 @@ def export_unet_model( mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, - flagset_keyword=["punet"] if use_punet else [], + flagset_keywords=["punet"] if use_punet else [], ) return vmfb_path elif use_punet: From 321d21d6a5fc5d08b3dfbbd23be00e8cc2450b7d Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Mon, 15 Jul 2024 17:48:50 -0700 Subject: [PATCH 14/89] Update punet to 4d4f955 --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 999599039..1ee531026 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, precision="i8"): if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "d30d6ff79abb584bf2addc7866738df5242f315a" + revision = "4d4f95554bb991b95eb8ae4d57b38eb139b3e23f" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From 2de912e8af6434c4a8210ad13ab9da714473325f Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Mon, 15 Jul 2024 21:50:06 -0500 Subject: [PATCH 15/89] Disable vulkan test on MI250 runner. 
--- .github/workflows/test_models.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 03872dea3..2b2b6062a 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -70,7 +70,6 @@ jobs: pytest -v models/turbine_models/tests/sd_test.py pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 - pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 - pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 \ No newline at end of file + pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 From 9fdc07fe7c268df0f04466475f3e7ca20536cb1b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 17 Jul 2024 12:17:38 -0500 Subject: [PATCH 16/89] Change tqdm disable conditions and deepcopy model map on init. --- .../custom_models/sd_inference/sd_pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 79270167d..dc743c504 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -247,7 +247,7 @@ def __init__( "external_weights": None, "external_weight_path": None, } - sd_model_map = get_sd_model_map(hf_model_name) + sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name)) for submodel in sd_model_map: if "load" not in sd_model_map[submodel]: sd_model_map[submodel]["load"] = True @@ -569,9 +569,11 @@ def _produce_latents_sdxl( [guidance_scale], dtype=self.map["unet"]["np_dtype"], ) + # Disable progress bar if we aren't in verbose mode or if we're printing + # benchmark latencies for unet. for i, t in tqdm( enumerate(timesteps), - disable=(self.map["unet"].get("benchmark") and self.verbose), + disable=(self.map["unet"].get("benchmark") or not self.verbose), ): if self.cpu_scheduling: latent_model_input, t = self.scheduler.scale_model_input( From b20be32e8a93807737a2960acf1bc3a0d37b1e78 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:47:26 -0500 Subject: [PATCH 17/89] Don't break workarounds for model path Too many regressions based on workarounds for problems that are now fixed. Until submission just accept all cases. 
--- models/turbine_models/custom_models/sd_inference/sd_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index dc743c504..d249f38c5 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -190,6 +190,7 @@ def get_sd_model_map(hf_model_name): "stabilityai/sdxl-turbo", "stabilityai/stable-diffusion-xl-base-1.0", "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe", + "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe", ]: return sdxl_model_map elif "stabilityai/stable-diffusion-3" in name: From 02705a9dffb0d8a20a9e3de1dac1a55ab3bddea4 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 17 Jul 2024 20:44:47 -0500 Subject: [PATCH 18/89] Fix for passing a path as attn_spec. --- .../custom_models/sd_inference/sd_pipeline.py | 2 +- .../custom_models/sdxl_inference/unet.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index d249f38c5..256bcd8ee 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -244,7 +244,7 @@ def __init__( "exit_on_vmfb": False, "pipeline_dir": pipeline_dir, "input_mlir": None, - "attn_spec": None, + "attn_spec": attn_spec, "external_weights": None, "external_weight_path": None, } diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 1ee531026..38f743066 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -182,12 +182,13 @@ def export_unet_model( submodel_name = "punet" else: submodel_name = "unet" - if (not decomp_attn) and use_punet: - attn_spec = "punet" - elif (not decomp_attn) and "gfx9" in target: - attn_spec = "mfma" - elif (not decomp_attn) and "gfx11" in target: - attn_spec = "wmma" + if not attn_spec: + if (not decomp_attn) and use_punet: + attn_spec = "punet" + elif (not decomp_attn) and "gfx9" in target: + attn_spec = "mfma" + elif (not decomp_attn) and "gfx11" in target: + attn_spec = "wmma" safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", From 9229aedd01bd828c8a361597e9f93eb8a26eedcc Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 19 Jul 2024 11:14:33 -0500 Subject: [PATCH 19/89] Bump punet revision to defeb489fe2bb17b77d587924db9e58048a8c140 --- models/turbine_models/custom_models/sdxl_inference/unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 38f743066..ec3a81d5a 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,11 +92,11 @@ def get_punet_model(hf_model_name, external_weight_path, precision="i8"): if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "4d4f95554bb991b95eb8ae4d57b38eb139b3e23f" + revision = "defeb489fe2bb17b77d587924db9e58048a8c140" elif precision in ["fp16", "fp32"]: 
repo_id = hf_model_name subfolder = "unet" - revision = "d30d6ff79abb584bf2addc7866738df5242f315a" + revision = "defeb489fe2bb17b77d587924db9e58048a8c140" def download(filename): return hf_hub_download( From f09ef4a663b04eff471531221f07d930876a60ad Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 19 Jul 2024 12:54:11 -0500 Subject: [PATCH 20/89] Move JIT cpu scheduling load helpers inside conditional. --- .../custom_models/sd_inference/sd_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 256bcd8ee..57610f392 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -381,10 +381,6 @@ def load_scheduler( scheduler_id: str, steps: int = 30, ): - if self.is_sd3: - scheduler_device = self.mmdit.device - else: - scheduler_device = self.unet.device if not self.cpu_scheduling: self.map["scheduler"] = { "module_name": "compiled_scheduler", @@ -431,6 +427,10 @@ def load_scheduler( print("JIT export of scheduler failed. Loading CPU scheduler.") self.cpu_scheduling = True if self.cpu_scheduling: + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) self.scheduler = schedulers.SharkSchedulerCPUWrapper( scheduler, From bbcc4243d76be0753f809b7e9a2ee5fba2cb0bb8 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 19 Jul 2024 12:55:09 -0500 Subject: [PATCH 21/89] formatting --- .../custom_models/sd_inference/sd_pipeline.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 57610f392..d0303a4a0 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -486,9 +486,12 @@ def prepare_latents( elif self.is_sdxl and self.cpu_scheduling: self.scheduler.do_guidance = False self.scheduler.repeat_sample = False - sample, add_time_ids, step_indexes, timesteps = ( - self.scheduler.initialize_sdxl(noise, num_inference_steps) - ) + ( + sample, + add_time_ids, + step_indexes, + timesteps, + ) = self.scheduler.initialize_sdxl(noise, num_inference_steps) return sample, add_time_ids, step_indexes, timesteps elif self.is_sdxl: return self.scheduler("run_initialize", noise) From 1f19c7fef715eb44068a45f740d6e5727c8a04c7 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 19 Jul 2024 15:01:34 -0500 Subject: [PATCH 22/89] Don't pass benchmark as an export arg. 
--- models/turbine_models/custom_models/pipeline_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 22b44e4e0..b54b158c2 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -391,7 +391,8 @@ def __init__( ) for submodel in self.map.keys(): for key, value in map_arguments.items(): - self.map = merge_export_arg(self.map, value, key) + if key != "benchmark": + self.map = merge_export_arg(self.map, value, key) for key, value in self.map[submodel].get("export_args", {}).items(): if key == "hf_model_name": self.map[submodel]["keywords"].append( From 39c0c00653bd3c19c3cbdd8fa9d687df37cc878f Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:50:34 -0700 Subject: [PATCH 23/89] Changes so no external downloads. (#781) --- .../custom_models/sd_inference/sd_pipeline.py | 5 ++++ .../custom_models/sdxl_inference/unet.py | 26 ++++++++++++++----- .../custom_models/sdxl_inference/vae.py | 7 ++--- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index d0303a4a0..f5e58fca0 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -234,6 +234,8 @@ def __init__( benchmark: bool | dict[bool] = False, verbose: bool = False, batch_prompts: bool = False, + punet_quant_paths: dict[str] = None, + vae_weight_path: str = None, ): common_export_args = { "hf_model_name": None, @@ -304,6 +306,7 @@ def __init__( self.cpu_scheduling = cpu_scheduling self.scheduler_id = scheduler_id self.num_inference_steps = num_inference_steps + self.punet_quant_paths = punet_quant_paths self.text_encoder = None self.unet = None @@ -340,6 +343,7 @@ def __init__( self.scheduler_device = self.map["unet"]["device"] self.scheduler_driver = self.map["unet"]["driver"] self.scheduler_target = self.map["unet"]["target"] + self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -364,6 +368,7 @@ def setup_punet(self): self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" ) + self.map["unet"]["export_args"]["quant_paths"] = self.punet_quant_paths for idx, word in enumerate(self.map["unet"]["keywords"]): if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index ec3a81d5a..95b072e68 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -82,7 +82,7 @@ def forward( return noise_pred -def get_punet_model(hf_model_name, external_weight_path, precision="i8"): +def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision="i8"): from sharktank.models.punet.model import ( Unet2DConditionModel as sharktank_unet2d, ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel, @@ -103,14 +103,23 @@ def download(filename): repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision ) - 
results = { - "config.json": download("config.json"), - "params.safetensors": download("params.safetensors"), - } + if not quant_paths: + results = { + "config.json": download("config.json"), + "params.safetensors": download("params.safetensors"), + } + else: + results = { + "config.json": quant_paths["config"], + "params.safetensors": quant_paths["params"], + } output_dir = os.path.dirname(external_weight_path) if precision == "i8": - results["quant_params.json"] = download("quant_params.json") + if quant_paths: + results["quant_params.json"] = quant_paths["quant_params"] + else: + results["quant_params.json"] = download("quant_params.json") ds_filename = os.path.basename(external_weight_path) output_path = os.path.join(output_dir, ds_filename) ds = get_punet_dataset( @@ -177,6 +186,7 @@ def export_unet_model( input_mlir=None, weights_only=False, use_punet=False, + quant_paths=None, ): if use_punet: submodel_name = "punet" @@ -213,7 +223,9 @@ def export_unet_model( ) return vmfb_path elif use_punet: - unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + unet_model = get_punet_model( + hf_model_name, external_weight_path, quant_paths, precision + ) else: unet_model = UnetModel(hf_model_name, hf_auth_token, precision) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index 753cbb9e7..8a02dc192 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -119,9 +119,10 @@ def export_vae_model( mapper = {} - utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path - ) + if not os.path.exists(external_weight_path): + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_path + ) if weights_only: return external_weight_path From 3c59b25375dd00ad3ae5302a5de4c1c6a8e41347 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:58:16 -0700 Subject: [PATCH 24/89] fix so that we check exact paths as well for is_prepared (#782) --- models/turbine_models/custom_models/pipeline_base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index b54b158c2..b4755283e 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -540,7 +540,11 @@ def is_prepared(self, vmfbs, weights): avail_files = os.listdir(self.external_weights_dir) candidates = [] for filename in avail_files: - if all(str(x) in filename for x in w_keywords): + if all( + str(x) in filename + or str(x) == os.path.join(self.external_weights_dir, filename) + for x in w_keywords + ): candidates.append( os.path.join(self.external_weights_dir, filename) ) From 2e9de46a1bc97ed85a8ded34308c7323858b2be7 Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:22:59 -0700 Subject: [PATCH 25/89] Update punet to 60edc91 --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 95b072e68..0ce7a808d 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ 
b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "defeb489fe2bb17b77d587924db9e58048a8c140" + revision = "60edc91ded9f0f59a6ea37aa18e8ea774d9fd3f0" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From aa0ac2bbdb04717e5e7823c73a01363b4be1ccff Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Sun, 21 Jul 2024 11:51:11 -0700 Subject: [PATCH 26/89] Vae weight path none check (#784) --- .../turbine_models/custom_models/sd_inference/sd_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index f5e58fca0..5d8408605 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -343,7 +343,8 @@ def __init__( self.scheduler_device = self.map["unet"]["device"] self.scheduler_driver = self.map["unet"]["driver"] self.scheduler_target = self.map["unet"]["target"] - self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path + if vae_weight_path is not None: + self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" From 6556a360c24cb2f78aee0b4ac8d1a1dcd0fb9332 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:58:58 -0500 Subject: [PATCH 27/89] Bump punet to mi300_all_sym_8_step10 (62785ea) --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 0ce7a808d..83e33f6b8 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "60edc91ded9f0f59a6ea37aa18e8ea774d9fd3f0" + revision = "62785eafa1d40fed9a34cd748d2c3a5b3b299204" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From 2c49cb66cdbed95288206c3bb8a6380a663b46f7 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:41:13 -0700 Subject: [PATCH 28/89] Changes so that the default run without quant docker will work as well with all the new changes. (#785) This PR makes the necessary changes so that you don't have to run the quantization docker to run the normal harness. With stricter checks for the quant paths and adding a special switch to the save_external_weights so that it is compliant with this change that we need for the docker generated weights https://github.com/iree-org/iree-turbine/commit/cd916ec7aec3e9fcc91898a3b6b93e575148e3f3. 
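For reviewers, a minimal sketch of how the new knobs from this series might be wired up from user code: `punet_quant_paths` and `vae_weight_path` (added to the pipeline a few commits back) plus the `vae_harness` switch added here. All file paths, the resolution, and the target triple below are placeholders rather than values taken from this repo, and constructor kwargs not shown are assumed to keep their defaults.

```python
# Hypothetical wiring of the new arguments; every literal path and target here
# is a placeholder chosen for illustration only.
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline

# Quantized punet artifacts produced offline (e.g. by the quantization docker).
# Keys match what get_punet_model() looks up: "config", "params", "quant_params".
punet_quant_paths = {
    "config": "/models/punet_int8/config.json",
    "params": "/models/punet_int8/params.safetensors",
    "quant_params": "/models/punet_int8/quant_params.json",
}

pipe = SharkSDPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",
    1024,            # height
    1024,            # width
    1,               # batch_size
    64,              # max_length
    "fp16",          # precision
    "rocm",          # device
    "gfx90a",        # iree_target_triple
    external_weights="safetensors",
    num_inference_steps=20,
    use_i8_punet=True,                    # routes unet export through setup_punet()
    punet_quant_paths=punet_quant_paths,  # skips HF downloads when these paths exist
    vae_weight_path="/models/vae/sdxl_vae_fp16.irpa",  # pre-generated VAE weights
    vae_harness=True,                     # strips the "vae." prefix when saving VAE params
)
pipe.prepare_all()  # exports/compiles submodels, reusing existing weights where found
pipe.load_map()     # loads the compiled modules; generation then proceeds via
                    # pipe.generate_images(...) as exercised in tests/sdxl_test.py
```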
--- .../custom_models/sd_inference/sd_pipeline.py | 2 ++ .../custom_models/sd_inference/utils.py | 6 ++++++ .../custom_models/sd_inference/vae.py | 12 ++++++++--- .../custom_models/sdxl_inference/unet.py | 20 +++++++++++++------ 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 5d8408605..60225ff60 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -236,6 +236,7 @@ def __init__( batch_prompts: bool = False, punet_quant_paths: dict[str] = None, vae_weight_path: str = None, + vae_harness: bool = False, ): common_export_args = { "hf_model_name": None, @@ -345,6 +346,7 @@ def __init__( self.scheduler_target = self.map["unet"]["target"] if vae_weight_path is not None: self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path + self.map["vae"]["export_args"]["vae_harness"] = vae_harness elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 448b67910..7f2f336e9 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -387,14 +387,20 @@ def save_external_weights( external_weights=None, external_weight_file=None, force_format=False, + vae_harness=False, ): if external_weights is not None: if external_weights in ["safetensors", "irpa"]: mod_params = dict(model.named_parameters()) mod_buffers = dict(model.named_buffers()) mod_params.update(mod_buffers) + vae_params = {} for name in mod_params: + if vae_harness: + vae_params[name.replace("vae.", "")] = mod_params[name] mapper["params." 
+ name] = name + if vae_harness: + mod_params = vae_params if external_weight_file and not os.path.isfile(external_weight_file): if not force_format: safetensors.torch.save_file(mod_params, external_weight_file) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 8b18f65fa..493c8bb79 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -119,6 +119,7 @@ def export_vae_model( input_mlir=None, weights_only=False, upload_ir=False, + vae_harness=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 np_dtype = "float16" if precision == "fp16" else "float32" @@ -162,9 +163,14 @@ def export_vae_model( if dtype == torch.float16: vae_model = vae_model.half() mapper = {} - utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path - ) + if not os.path.exists(external_weight_path): + utils.save_external_weights( + mapper, + vae_model, + external_weights, + external_weight_path, + vae_harness=vae_harness, + ) if weights_only: return external_weight_path diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 83e33f6b8..f51ae1fa1 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -103,20 +103,28 @@ def download(filename): repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision ) - if not quant_paths: + if quant_paths and quant_paths["config"] and os.path.exists(quant_paths["config"]): results = { - "config.json": download("config.json"), - "params.safetensors": download("params.safetensors"), + "config.json": quant_paths["config"], } else: results = { - "config.json": quant_paths["config"], - "params.safetensors": quant_paths["params"], + "config.json": download("config.json"), } + + if quant_paths and quant_paths["params"] and os.path.exists(quant_paths["params"]): + results["params.safetensors"] = quant_paths["params"] + else: + results["params.safetensors"] = download("params.safetensors") + output_dir = os.path.dirname(external_weight_path) if precision == "i8": - if quant_paths: + if ( + quant_paths + and quant_paths["quant_params"] + and os.path.exists(quant_paths["quant_params"]) + ): results["quant_params.json"] = quant_paths["quant_params"] else: results["quant_params.json"] = download("quant_params.json") From cb911b143ff2726a04dd998e0de31cc8357e7348 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 22 Jul 2024 13:40:09 -0500 Subject: [PATCH 29/89] Bump punet to 361df65844e0a7c766484707c57f6248cea9587f --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index f51ae1fa1..3b4f8554f 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "62785eafa1d40fed9a34cd748d2c3a5b3b299204" + revision = "361df65844e0a7c766484707c57f6248cea9587f" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From 
d857f7768944064440760666e69948892cf6ca74 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:06:37 -0700 Subject: [PATCH 30/89] Sync flags to sdxl-scripts repo (#786) Sync flags to sdxl-scripts repo --- .../custom_models/sd_inference/utils.py | 47 ++++++++++++------- .../custom_models/sd_inference/vae.py | 4 +- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 7f2f336e9..1caba490e 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -17,13 +17,8 @@ "all": [ "--iree-global-opt-propagate-transposes=true", "--iree-opt-const-eval=false", - "--iree-opt-outer-dim-concat=true", - "--iree-vm-target-truncate-unsupported-floats", "--iree-llvmgpu-enable-prefetch=true", - "--iree-opt-data-tiling=false", - "--iree-codegen-gpu-native-math-precision=true", - "--iree-rocm-waves-per-eu=2", - "--iree-flow-inline-constants-max-byte-length=1", + "--iree-execution-model=async-external", ], "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", @@ -31,23 +26,37 @@ "punet": [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], + "vae_preprocess": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-opt-outer-dim-concat=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-vm-target-truncate-unsupported-floats", ], "clip": [ "--iree-flow-enable-aggressive-fusion", "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-rocm-waves-per-eu=2", + "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "vae": [ "--iree-flow-enable-aggressive-fusion", + "--iree-flow-enable-fuse-horizontal-contractions", + "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-vm-target-truncate-unsupported-floats", ], "winograd": [""], } @@ -210,8 +219,6 @@ def 
compile_to_vmfb( "--iree-vm-bytecode-module-output-format=flatbuffer-binary", ] ) - if target_triple == "gfx942": - flags.extend(["--iree-rocm-waves-per-eu=2"]) elif device == "cuda": flags.extend( [ @@ -245,6 +252,8 @@ def compile_to_vmfb( flags.extend(MI_flags["pad_attention"]) elif "punet" in flagset_keywords: flags.extend(MI_flags["punet"]) + elif "vae" in safe_name: + flags.extend(MI_flags["vae_preprocess"]) else: flags.extend(MI_flags["preprocess_default"]) @@ -263,20 +272,22 @@ def compile_to_vmfb( # the TD spec is implemented in C++. if attn_spec in ["default", "mfma", "punet"]: - use_punet = True if attn_spec in ["punet", "i8"] else False - attn_spec = get_mfma_spec_path( - target_triple, - os.path.dirname(safe_name), - use_punet=use_punet, - ) - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + use_punet=use_punet, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name)) if attn_spec: flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec and attn_spec != "None": - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) for i, flag in enumerate(ireec_flags): k = flag.strip().split("=")[0] diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 493c8bb79..c18fb6da5 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -163,7 +163,9 @@ def export_vae_model( if dtype == torch.float16: vae_model = vae_model.half() mapper = {} - if not os.path.exists(external_weight_path): + if (external_weight_path is not None) and ( + not os.path.exists(external_weight_path) + ): utils.save_external_weights( mapper, vae_model, From 37548f259928304b962ca447794b4225a4d0246d Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:47:19 -0700 Subject: [PATCH 31/89] Integrate int8 tk kernels (#783) --- .../custom_models/sd_inference/utils.py | 97 ++++++++++++++++++- .../sdxl_inference/sdxl_cmd_opts.py | 6 ++ .../custom_models/sdxl_inference/unet.py | 8 ++ 3 files changed, 109 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 1caba490e..0300c7907 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -155,6 +155,70 @@ def iree_backend_map(device): return iree_device +def replace_with_tk_kernels( + flow_dialect_ir, +): + kernels = [ + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" + ] + + # Replace all calls to old kernel with new kernel + print("Inserting kernels and updating calls to kernels...") + kernel_name = {} + for kernel in kernels: + kernel_name[kernel] = kernel.split("/")[-1].split(".")[0] + kernel_map = {} + 
prefix_map = {} + + base = flow_dialect_ir.split("\n") + new_base = [] + for line in base: + for kernel in kernels: + suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1] + bias_explicit = False + if "bias" in suffix: + bias_explicit = True + kernel_args = 3 + int(suffix[4:]) + suffix = kernel.split(".")[0].split("_")[-2] + B, M, N, K = suffix.split("x") + old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" + if not old_kernel in line: + continue + if old_kernel in line and "func.func" in line: + if bias_explicit: + num_args = line.count("arg") + if num_args != kernel_args: + continue + kernel_map[kernel] = line.strip().split(" ")[1][1:-7] + prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1] + if ( + old_kernel in line + and "flow.dispatch" in line + and not "func.func" in line + ): + line = line.replace(kernel_map[kernel], kernel_name[kernel]) + line = line.replace(prefix_map[kernel], kernel_name[kernel]) + new_base.append(line) + # Insert kernels in appropriate locations + final_ir = [] + for line in new_base: + for kernel in kernels: + if ( + prefix_map[kernel] + " {" in line + and "flow.executable" in line + and "private" in line + ): + data = urlopen(kernel).read().decode("utf-8") + data = data.split("\n") + translation_info = data[0].split("#translation = ")[1].strip() + data[10] = data[10].replace("#translation", translation_info) + final_ir.append("\n".join(data[2:-3])) + final_ir.append(line) + + print("tk kernels added") + return final_ir + + def compile_to_vmfb( module_str, device, @@ -170,6 +234,7 @@ def compile_to_vmfb( winograd=False, flagset_keywords=[], debug=False, + add_tk_kernels=False, ): flags = [] if mlir_source == "file" and not isinstance(module_str, str): @@ -307,6 +372,34 @@ def compile_to_vmfb( for idx, flag in enumerate(flags): if flag is None: flags.pop(idx) + input_ir_type = "torch" + if add_tk_kernels: + print("Adding tk kernels") + flags.extend(["--compile-to=flow"]) + if mlir_source == "file": + flatbuffer_blob = ireec.compile_file( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + elif mlir_source == "str": + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + + flow_ir = flatbuffer_blob.decode("utf-8") + + flow_ir_tk = replace_with_tk_kernels(flow_ir) + module_str = "\n".join(flow_ir_tk) + flags.pop() + flags.extend(["--compile-from=flow"]) + mlir_source = "str" + input_ir_type = "auto" + print("Compiling to", device, "with flags:", flags) # Forces a standard for naming files: @@ -323,7 +416,7 @@ def compile_to_vmfb( flatbuffer_blob = ireec.compile_file( module_str, target_backends=[device], - input_type="torch", + input_type=input_ir_type, extra_args=flags, ) elif mlir_source == "str": @@ -334,7 +427,7 @@ def compile_to_vmfb( flatbuffer_blob = ireec.compile_str( module_str, target_backends=[device], - input_type="torch", + input_type=input_ir_type, extra_args=flags, ) else: diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py index 368fb0d74..626df59cc 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py @@ -369,5 +369,11 @@ def is_valid_file(arg): help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! 
Any default options should be added to sd_inference/utils.py", ) +p.add_argument( + "--add_tk_kernels", + default=False, + action="store_true", + help="Flag to add compiled tk kernels.", +) args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 3b4f8554f..c6808dcbd 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -195,6 +195,7 @@ def export_unet_model( weights_only=False, use_punet=False, quant_paths=None, + add_tk_kernels=False, ): if use_punet: submodel_name = "punet" @@ -217,6 +218,10 @@ def export_unet_model( if decomp_attn == True: ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" + # Currently, only int8 tk kernels are integrated + if add_tk_kernels and precision != "i8": + add_tk_kernels = False + if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, @@ -228,6 +233,7 @@ def export_unet_model( return_path=not exit_on_vmfb, attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], + add_tk_kernels=add_tk_kernels, ) return vmfb_path elif use_punet: @@ -363,6 +369,7 @@ class CompiledUnet(CompiledModule): return_path=True, attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], + add_tk_kernels=add_tk_kernels, ) if exit_on_vmfb: exit() @@ -401,6 +408,7 @@ class CompiledUnet(CompiledModule): args.decomp_attn, attn_spec=args.attn_spec, input_mlir=args.input_mlir, + add_tk_kernels=args.add_tk_kernels, ) if args.input_mlir: exit() From 25b24629f65b59bdd770d014cdf732c19094ebf6 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 23 Jul 2024 12:11:09 -0500 Subject: [PATCH 32/89] Update punet revision to deterministic version (42e9407) Bumps punet revision -- https://huggingface.co/amd-shark/sdxl-quant-models/commit/42e94070478ed0599c0225a4879b69b253206eb6 --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index c6808dcbd..acccc3919 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" subfolder = "unet/int8" - revision = "361df65844e0a7c766484707c57f6248cea9587f" + revision = "42e94070478ed0599c0225a4879b69b253206eb6" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From 0e57b4edc51aec9468c88e19d9478f54c296a36a Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:39:03 -0700 Subject: [PATCH 33/89] Integration of tk kernels into pipeline (#789) Currently using a link, but Nithin will be pushing the fix to use a file name asap --- .../custom_models/sd_inference/sd_pipeline.py | 4 ++++ .../custom_models/sd_inference/utils.py | 20 ++++++++++++------- .../custom_models/sdxl_inference/unet.py | 1 + 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 60225ff60..2456ae4bf 100644 --- 
a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -237,6 +237,7 @@ def __init__( punet_quant_paths: dict[str] = None, vae_weight_path: str = None, vae_harness: bool = False, + add_tk_kernels: bool = False, ): common_export_args = { "hf_model_name": None, @@ -316,6 +317,7 @@ def __init__( self.scheduler = None self.split_scheduler = True + self.add_tk_kernels = add_tk_kernels self.base_model_name = ( hf_model_name @@ -367,6 +369,8 @@ def __init__( def setup_punet(self): if self.use_i8_punet: + if self.add_tk_kernels: + self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels self.map["unet"]["export_args"]["precision"] = "i8" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 0300c7907..9d5c149aa 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -155,12 +155,15 @@ def iree_backend_map(device): return iree_device -def replace_with_tk_kernels( - flow_dialect_ir, -): - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" - ] +def replace_with_tk_kernels(flow_dialect_ir, batch_size): + if batch_size == 8: + kernels = [ + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir" + ] + if batch_size == 1: + kernels = [ + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" + ] # Replace all calls to old kernel with new kernel print("Inserting kernels and updating calls to kernels...") @@ -235,7 +238,10 @@ def compile_to_vmfb( flagset_keywords=[], debug=False, add_tk_kernels=False, + batch_size=1, ): + if batch_size != 1 and batch_size != 8: + add_tk_kernels = False flags = [] if mlir_source == "file" and not isinstance(module_str, str): module_str = str(module_str) @@ -393,7 +399,7 @@ def compile_to_vmfb( flow_ir = flatbuffer_blob.decode("utf-8") - flow_ir_tk = replace_with_tk_kernels(flow_ir) + flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size) module_str = "\n".join(flow_ir_tk) flags.pop() flags.extend(["--compile-from=flow"]) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index acccc3919..4ed874e25 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -370,6 +370,7 @@ class CompiledUnet(CompiledModule): attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], add_tk_kernels=add_tk_kernels, + batch_size=batch_size, ) if exit_on_vmfb: exit() From 920dbf5d53057ed04d08c515f2bd9db70559fcd6 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Wed, 24 Jul 2024 17:07:37 -0700 Subject: [PATCH 34/89] Update unet horizontal fusion flag (#790) --- models/turbine_models/custom_models/sd_inference/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 9d5c149aa..2b6a164c1 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py 
+++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -40,6 +40,7 @@ "--iree-opt-data-tiling=false", "--iree-codegen-gpu-native-math-precision=true", "--iree-vm-target-truncate-unsupported-floats", + "--iree-flow-enable-fuse-horizontal-contractions=true", ], "clip": [ "--iree-flow-enable-aggressive-fusion", From 6f167312efa659f2e9bde5a498a94f95eacb005e Mon Sep 17 00:00:00 2001 From: saienduri Date: Wed, 24 Jul 2024 18:43:39 -0700 Subject: [PATCH 35/89] Revert "Update unet horizontal fusion flag (#790)" This reverts commit 920dbf5d53057ed04d08c515f2bd9db70559fcd6. --- models/turbine_models/custom_models/sd_inference/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 2b6a164c1..9d5c149aa 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -40,7 +40,6 @@ "--iree-opt-data-tiling=false", "--iree-codegen-gpu-native-math-precision=true", "--iree-vm-target-truncate-unsupported-floats", - "--iree-flow-enable-fuse-horizontal-contractions=true", ], "clip": [ "--iree-flow-enable-aggressive-fusion", From 15dbd93ed7174a70c98ed4be05dd5bf9473369ba Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Wed, 24 Jul 2024 22:01:15 -0700 Subject: [PATCH 36/89] [tk kernel] Add support to match kernel with number of arguments and update kernel links (#791) --- .../custom_models/sd_inference/utils.py | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 9d5c149aa..5bee5a097 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -158,11 +158,12 @@ def iree_backend_map(device): def replace_with_tk_kernels(flow_dialect_ir, batch_size): if batch_size == 8: kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir" + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir" ] if batch_size == 1: kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir" + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir", + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir", ] # Replace all calls to old kernel with new kernel @@ -178,20 +179,26 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): for line in base: for kernel in kernels: suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1] - bias_explicit = False - if "bias" in suffix: - bias_explicit = True - kernel_args = 3 + int(suffix[4:]) - suffix = kernel.split(".")[0].split("_")[-2] + # Uncomment/rework when a kernel with bias comes in + # bias_explicit = False + # if "bias" in suffix: + # bias_explicit = True + # kernel_args = 3 + int(suffix[4:]) + # suffix = kernel.split(".")[0].split("_")[-2] B, M, N, K = suffix.split("x") old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" if not old_kernel in line: continue if old_kernel in line and "func.func" in line: - if bias_explicit: - num_args = 
line.count("arg") - if num_args != kernel_args: - continue + data = urlopen(kernel).read().decode("utf-8") + data = data.split("\n") + idx_with_kernel_args = [ + idx for idx, s in enumerate(data) if "func.func" in s + ][0] + kernel_args = data[idx_with_kernel_args].count("arg") + num_args = line.count("arg") + if num_args != kernel_args: + continue kernel_map[kernel] = line.strip().split(" ")[1][1:-7] prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1] if ( @@ -547,11 +554,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["EulerAncestralDiscrete"] = ( - EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", ) # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained( # model_id, From 0c02652c7337f7ef686c45e7a31e072d5966c6eb Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:18:23 -0500 Subject: [PATCH 37/89] Add functionality to SD pipeline and abstracted components for saving output .npys (#792) --- .../custom_models/pipeline_base.py | 26 +++++++++++++++++-- .../custom_models/sd_inference/sd_cmd_opts.py | 6 +++++ .../custom_models/sd_inference/sd_pipeline.py | 13 +++++++++- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index b4755283e..3102ac3e3 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -84,7 +84,12 @@ class PipelineComponent: """ def __init__( - self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False + self, + printer, + dest_type="devicearray", + dest_dtype="float16", + benchmark=False, + save_outputs=False, ): self.runner = None self.module_name = None @@ -92,6 +97,8 @@ def __init__( self.metadata = None self.printer = printer self.benchmark = benchmark + self.save_outputs = save_outputs + self.output_counter = 0 self.dest_type = dest_type self.dest_dtype = dest_dtype @@ -218,6 +225,16 @@ def _output_cast(self, output): case _: return output + def save_output(self, function_name, output): + if isinstance(output, tuple) or isinstance(output, list): + for i in output: + self.save_output(function_name, i) + else: + np.save( + f"{function_name}_output_{self.output_counter}.npy", output.to_host() + ) + self.output_counter += 1 + def _run(self, function_name, inputs: list): return self.module[function_name](*inputs) @@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list): output = self._run_and_benchmark(function_name, inputs) else: output = self._run(function_name, inputs) + if self.save_outputs: + self.save_output(function_name, output) output = self._output_cast(output) return output @@ -340,6 +359,7 @@ def __init__( hf_model_name: str | dict[str] = None, benchmark: bool | dict[bool] = False, verbose: bool = False, + save_outputs: bool | dict[bool] = False, common_export_args: dict = {}, ): self.map = model_map @@ -374,6 +394,7 @@ def __init__( "external_weights": external_weights, "hf_model_name": hf_model_name, "benchmark": benchmark, + "save_outputs": save_outputs, } for arg in map_arguments.keys(): self.map = merge_arg_into_map(self.map, map_arguments[arg], arg) @@ -391,7 +412,7 @@ def __init__( ) for submodel in self.map.keys(): for key, value in 
map_arguments.items(): - if key != "benchmark": + if key not in ["benchmark", "save_outputs"]: self.map = merge_export_arg(self.map, value, key) for key, value in self.map[submodel].get("export_args", {}).items(): if key == "hf_model_name": @@ -744,6 +765,7 @@ def load_submodel(self, submodel): printer=self.printer, dest_type=dest_type, benchmark=self.map[submodel].get("benchmark", False), + save_outputs=self.map[submodel].get("save_outputs", False), ) self.map[submodel]["runner"].load( self.map[submodel]["driver"], diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 5e025a4d5..a852bf464 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -151,6 +151,12 @@ def is_valid_file(arg): help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.", ) +p.add_argument( + "--save_outputs", + type=str, + default=None, + help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.", +) ############################################################################## # SDXL Modelling Options # These options are used to control model defining parameters for SDXL. diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 2456ae4bf..97369c538 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -236,8 +236,9 @@ def __init__( batch_prompts: bool = False, punet_quant_paths: dict[str] = None, vae_weight_path: str = None, - vae_harness: bool = False, + vae_harness: bool = True, add_tk_kernels: bool = False, + save_outputs: bool | dict[bool] = False, ): common_export_args = { "hf_model_name": None, @@ -286,6 +287,7 @@ def __init__( hf_model_name, benchmark, verbose, + save_outputs, common_export_args, ) for submodel in sd_model_map: @@ -742,6 +744,14 @@ def numpy_to_pil_image(images): benchmark[i] = True else: benchmark = False + if args.save_outputs: + if args.save_outputs.lower() == "all": + save_outputs = True + else: + for i in args.save_outputs.split(","): + save_outputs[i] = True + else: + save_outputs = False if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]): args.decomp_attn = { "text_encoder": args.decomp_attn, @@ -772,6 +782,7 @@ def numpy_to_pil_image(images): args.use_i8_punet, benchmark, args.verbose, + save_outputs=save_outputs, ) sd_pipe.prepare_all() sd_pipe.load_map() From 3fd954b7219a19752c9e201d6e659d225b902451 Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:05:10 -0700 Subject: [PATCH 38/89] Remove download links for tk kernels and instead specify kernel directory as an argument (#793) --- .../custom_models/sd_inference/sd_pipeline.py | 3 ++ .../custom_models/sd_inference/utils.py | 51 ++++++++----------- .../sdxl_inference/sdxl_cmd_opts.py | 7 +++ 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 97369c538..a322cb083 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -238,6 +238,7 @@ 
def __init__( vae_weight_path: str = None, vae_harness: bool = True, add_tk_kernels: bool = False, + tk_kernels_dir: str | dict[str] = None, save_outputs: bool | dict[bool] = False, ): common_export_args = { @@ -320,6 +321,7 @@ def __init__( self.split_scheduler = True self.add_tk_kernels = add_tk_kernels + self.tk_kernels_dir = tk_kernels_dir self.base_model_name = ( hf_model_name @@ -373,6 +375,7 @@ def setup_punet(self): if self.use_i8_punet: if self.add_tk_kernels: self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels + self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir self.map["unet"]["export_args"]["precision"] = "i8" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 5bee5a097..84d9cb3b4 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -5,6 +5,7 @@ import safetensors import safetensors.numpy as safe_numpy import re +import glob from diffusers import ( PNDMScheduler, EulerDiscreteScheduler, @@ -155,16 +156,8 @@ def iree_backend_map(device): return iree_device -def replace_with_tk_kernels(flow_dialect_ir, batch_size): - if batch_size == 8: - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir" - ] - if batch_size == 1: - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir", - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir", - ] +def replace_with_tk_kernels(tk_kernels_dir, flow_dialect_ir, batch_size): + kernels = glob.glob(tk_kernels_dir + "/bs" + str(batch_size) + "/*") # Replace all calls to old kernel with new kernel print("Inserting kernels and updating calls to kernels...") @@ -178,25 +171,21 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): new_base = [] for line in base: for kernel in kernels: - suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1] - # Uncomment/rework when a kernel with bias comes in - # bias_explicit = False - # if "bias" in suffix: - # bias_explicit = True - # kernel_args = 3 + int(suffix[4:]) - # suffix = kernel.split(".")[0].split("_")[-2] + suffix = kernel.split(".")[0].split("_")[-1] + if "bias" in suffix: + suffix = kernel.split(".")[0].split("_")[-2] B, M, N, K = suffix.split("x") old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" if not old_kernel in line: continue if old_kernel in line and "func.func" in line: - data = urlopen(kernel).read().decode("utf-8") - data = data.split("\n") + num_args = line.count("arg") + with open(kernel, "r") as f: + data = f.readlines() idx_with_kernel_args = [ idx for idx, s in enumerate(data) if "func.func" in s ][0] kernel_args = data[idx_with_kernel_args].count("arg") - num_args = line.count("arg") if num_args != kernel_args: continue kernel_map[kernel] = line.strip().split(" ")[1][1:-7] @@ -218,11 +207,12 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): and "flow.executable" in line and "private" in line ): - data = urlopen(kernel).read().decode("utf-8") - data = data.split("\n") + with open(kernel, "r") as f: + data = f.readlines() translation_info = data[0].split("#translation = ")[1].strip() - data[10] = 
data[10].replace("#translation", translation_info) - final_ir.append("\n".join(data[2:-3])) + extract = "".join(data[2:-2]) + extract = extract.replace("#translation", translation_info) + final_ir += extract final_ir.append(line) print("tk kernels added") @@ -245,6 +235,7 @@ def compile_to_vmfb( flagset_keywords=[], debug=False, add_tk_kernels=False, + tk_kernels_dir=None, batch_size=1, ): if batch_size != 1 and batch_size != 8: @@ -406,7 +397,7 @@ def compile_to_vmfb( flow_ir = flatbuffer_blob.decode("utf-8") - flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size) + flow_ir_tk = replace_with_tk_kernels(tk_kernels_dir, flow_ir, batch_size) module_str = "\n".join(flow_ir_tk) flags.pop() flags.extend(["--compile-from=flow"]) @@ -554,11 +545,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers[ - "EulerAncestralDiscrete" - ] = EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", + schedulers["EulerAncestralDiscrete"] = ( + EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) ) # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained( # model_id, diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py index 626df59cc..017244f64 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py @@ -376,4 +376,11 @@ def is_valid_file(arg): help="Flag to add compiled tk kernels.", ) +p.add_argument( + "--tk_kernels_dir", + default=False, + action="store_true", + help="Path to directory containing tk kernels sorted by batch size.", +) + args, unknown = p.parse_known_args() From 7f8a2b02de14f86275f23ac77444edb59f92ce04 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Thu, 25 Jul 2024 14:52:31 -0700 Subject: [PATCH 39/89] Update to best iteration on unet weights (#794) --- models/turbine_models/custom_models/sdxl_inference/unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 4ed874e25..945a3fb7d 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -90,9 +90,9 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= from sharktank.utils import cli if precision == "i8": - repo_id = "amd-shark/sdxl-quant-models" - subfolder = "unet/int8" - revision = "42e94070478ed0599c0225a4879b69b253206eb6" + repo_id = "amd-shark/sdxl-quant-int8" + subfolder = "mi300_all_sym_8_step14_fp32" + revision = "2e416a4205c519f5e62ba707ddf4f5022b6276c8" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From bf63aec50878dc431f2356e52c0c6c178ddf614d Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:09:38 -0700 Subject: [PATCH 40/89] Add missing tk_kernel_args arg in function calls (#795) --- models/turbine_models/custom_models/sdxl_inference/unet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 945a3fb7d..f2de082d4 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ 
b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -196,6 +196,7 @@ def export_unet_model( use_punet=False, quant_paths=None, add_tk_kernels=False, + tk_kernels_dir=None, ): if use_punet: submodel_name = "punet" @@ -234,6 +235,7 @@ def export_unet_model( attn_spec=attn_spec, flagset_keywords=["punet"] if use_punet else [], add_tk_kernels=add_tk_kernels, + tk_kernels_dir=tk_kernels_dir, ) return vmfb_path elif use_punet: @@ -371,6 +373,7 @@ class CompiledUnet(CompiledModule): flagset_keywords=["punet"] if use_punet else [], add_tk_kernels=add_tk_kernels, batch_size=batch_size, + tk_kernels_dir=tk_kernels_dir, ) if exit_on_vmfb: exit() @@ -410,6 +413,7 @@ class CompiledUnet(CompiledModule): attn_spec=args.attn_spec, input_mlir=args.input_mlir, add_tk_kernels=args.add_tk_kernels, + tk_kernels_dir=args.tk_kernels_dir ) if args.input_mlir: exit() From a74d98ecfe4932d1b1bdbcfac0265a966a409ce8 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:36:33 -0700 Subject: [PATCH 41/89] update hash for config file --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index f2de082d4..70a1a043b 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -92,7 +92,7 @@ def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision= if precision == "i8": repo_id = "amd-shark/sdxl-quant-int8" subfolder = "mi300_all_sym_8_step14_fp32" - revision = "2e416a4205c519f5e62ba707ddf4f5022b6276c8" + revision = "efda8afb35fd72c1769e02370b320b1011622958" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" From 925cd0c6d944c61d23d4e0fb6230b41985e255f4 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 29 Jul 2024 11:40:38 -0500 Subject: [PATCH 42/89] Fix formatting --- models/turbine_models/custom_models/sdxl_inference/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 70a1a043b..73b32cf58 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -413,7 +413,7 @@ class CompiledUnet(CompiledModule): attn_spec=args.attn_spec, input_mlir=args.input_mlir, add_tk_kernels=args.add_tk_kernels, - tk_kernels_dir=args.tk_kernels_dir + tk_kernels_dir=args.tk_kernels_dir, ) if args.input_mlir: exit() From 7715fd0c31a979ecd038cf90965224d2b9ac297a Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 30 Jul 2024 10:24:43 -0500 Subject: [PATCH 43/89] Point to sdxl-vae-fix branch of iree-turbine. 
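For reference, the same pin can be reproduced outside of models/requirements.txt with pip's direct-reference syntax (illustrative command mirroring the requirement changed below):

    pip install "shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@sdxl-vae-fix"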
--- models/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/requirements.txt b/models/requirements.txt index 0aed40159..9cee9146a 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -12,5 +12,5 @@ azure-storage-blob einops pytest scipy -shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main +shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@sdxl-vae-fix -e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank From e276c782755cf27ac9706593c97434159180341e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 30 Jul 2024 16:04:48 -0500 Subject: [PATCH 44/89] Add SD3 to sd_pipeline --- .../custom_models/sd3_inference/sd3_mmdit.py | 22 ++- .../sd3_inference/sd3_schedulers.py | 66 ++++--- .../sd3_inference/sd3_text_encoders.py | 8 +- .../custom_models/sd3_inference/sd3_vae.py | 2 +- .../custom_models/sd_inference/sd_cmd_opts.py | 31 +++- .../custom_models/sd_inference/sd_pipeline.py | 162 ++++++++++++++---- .../custom_models/sd_inference/utils.py | 27 ++- .../custom_models/sd_inference/vae.py | 51 +++--- 8 files changed, 260 insertions(+), 109 deletions(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index b71d3129e..aed7839dd 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -45,6 +45,7 @@ def forward( pooled_projections, timestep, ): + timestep.expand(hidden_states.shape[0]) noise_pred = self.mmdit( hidden_states, encoder_hidden_states, @@ -71,7 +72,7 @@ def forward(self, q, k, v): def export_attn( precision="fp16", device="cpu", - target_triple="x86_64-unknown-linux-gnu", + target="x86_64-unknown-linux-gnu", ireec_flags="", compile_to="torch", decomp_attn=False, @@ -128,7 +129,7 @@ class CompiledAttn(CompiledModule): vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, return_path=True, @@ -139,7 +140,6 @@ class CompiledAttn(CompiledModule): @torch.no_grad() def export_mmdit_model( - mmdit_model, hf_model_name, batch_size, height, @@ -151,8 +151,8 @@ def export_mmdit_model( external_weights=None, external_weight_path=None, device=None, - target_triple=None, - ireec_flags=None, + target=None, + ireec_flags="pad_attention", decomp_attn=False, exit_on_vmfb=False, pipeline_dir=None, @@ -161,6 +161,9 @@ def export_mmdit_model( weights_only=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 + mmdit_model = MMDiTModel( + dtype=dtype, + ) np_dtype = "float16" if precision == "fp16" else "float32" safe_name = utils.create_safe_name( hf_model_name, @@ -169,13 +172,14 @@ def export_mmdit_model( if pipeline_dir: safe_name = os.path.join(pipeline_dir, safe_name) if decomp_attn == True: + safe_name += "_decomp_attn" ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -208,7 +212,7 @@ def export_mmdit_model( torch.empty(hidden_states_shape, dtype=dtype), torch.empty(encoder_hidden_states_shape, dtype=dtype), torch.empty(pooled_projections_shape, dtype=dtype), - torch.empty(init_batch_dim, dtype=dtype), + torch.empty(1, dtype=dtype), ] decomp_list = [] @@ -249,7 +253,7 @@ class CompiledMmdit(CompiledModule): hidden_states_shape, encoder_hidden_states_shape, 
pooled_projections_shape, - init_batch_dim, + (1,), ], "input_dtypes": [np_dtype for x in range(4)], "output_shapes": [hidden_states_shape], @@ -263,7 +267,7 @@ class CompiledMmdit(CompiledModule): vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, return_path=True, diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 2c1d04cf1..075b29017 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -77,17 +77,14 @@ def initialize(self, sample): ) def prepare_model_input(self, sample, t, timesteps): - t = timesteps[t] - if self.do_classifier_free_guidance: latent_model_input = torch.cat([sample] * 2) else: latent_model_input = sample - t = t.expand(latent_model_input.shape[0]) return latent_model_input.type(self.dtype), t.type(self.dtype) - def step(self, noise_pred, t, sample, guidance_scale, i): - self.model._step_index = i + def step(self, noise_pred, t, sample, guidance_scale): + self.model._step_index = self.index_for_timestep(t) if self.do_classifier_free_guidance: noise_preds = noise_pred.chunk(2) @@ -97,6 +94,30 @@ def step(self, noise_pred, t, sample, guidance_scale, i): sample = self.model.step(noise_pred, t, sample, return_dict=False)[0] return sample.type(self.dtype) + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.model.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + eq = torch.eq(schedule_timesteps, timestep) + eq = eq.int() + index_candidates = torch.argmax(eq) + index_candidates = index_candidates.unsqueeze(0) + + a = torch.numel(index_candidates) + cond = torch.scalar_tensor(a) + one = torch.scalar_tensor(1, dtype=torch.int64) + zero = torch.scalar_tensor(0, dtype=torch.int64) + index = torch.where(cond > 1, one, zero) + index = index.unsqueeze(0) + step_index = index_candidates.index_select(0, index) + return step_index + # Wraps a diffusers scheduler running on native pytorch+cpu. # This allows us to use it interchangeably with compiled schedulers in our pipeline(s). @@ -151,6 +172,9 @@ def step(self, noise_pred, t, latents, guidance_scale, i): )[0] +# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete and adapted for dynamo compile. + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Only used for cpu scheduling. 
def retrieve_timesteps( @@ -198,6 +222,7 @@ def retrieve_timesteps( @torch.no_grad() def export_scheduler_model( hf_model_name: str, + scheduler_id: str = "FlowEulerDiscrete", batch_size: int = 1, height: int = 512, width: int = 512, @@ -206,7 +231,7 @@ def export_scheduler_model( precision: str = "fp16", compile_to: str = "torch", device: str = None, - target_triple: str = None, + target: str = None, ireec_flags: str = None, exit_on_vmfb: bool = False, pipeline_dir: str = None, @@ -221,7 +246,7 @@ def export_scheduler_model( f"bs{batch_size}_{height}x{width}", precision, str(num_inference_steps), - target_triple, + target, ] vmfb_name = "_".join(vmfb_names) safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name) @@ -231,9 +256,9 @@ def export_scheduler_model( vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, mlir_source="file", return_path=not exit_on_vmfb, ) @@ -260,7 +285,7 @@ def export_scheduler_model( example_init_args = [torch.empty(sample, dtype=dtype)] example_prep_args = ( torch.empty(sample, dtype=dtype), - torch.empty(1, dtype=torch.int64), + torch.empty(1, dtype=torch.float32), torch.empty([19], dtype=torch.float32), ) timesteps = torch.export.Dim("timesteps") @@ -274,7 +299,6 @@ def export_scheduler_model( torch.empty(1, dtype=dtype), torch.empty(sample, dtype=dtype), torch.empty(1, dtype=dtype), - torch.empty(1, dtype=torch.int64), ] fxb = FxProgramsBuilder(scheduler_module) @@ -312,8 +336,8 @@ def _step(module, inputs): ): class CompiledScheduler(CompiledModule): - run_init = _initialize - run_prep = _prep + run_initialize = _initialize + run_scale = _prep run_step = _step import_to = "INPUT" if compile_to == "linalg" else "IMPORT" @@ -330,20 +354,20 @@ class CompiledScheduler(CompiledModule): } model_metadata_run_prep = { "model_name": "sd3_scheduler_FlowEulerDiscrete", - "input_shapes": [sample, 1, [19]], + "input_shapes": [sample, (1,), ("?",)], "input_dtypes": [np_dtype, "float32", "float32"], - "output_shapes": [noise_pred_shape, noise_pred_shape[0]], + "output_shapes": [noise_pred_shape, (1,)], "output_dtypes": [np_dtype, "float32"], } model_metadata_run_step = { "model_name": "sd3_scheduler_FlowEulerDiscrete", - "input_shapes": [noise_pred_shape, 1, sample, 1, 1], - "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype, "int64"], + "input_shapes": [noise_pred_shape, (1,), sample, (1,)], + "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype], "output_shapes": [sample], "output_dtypes": [np_dtype], } - module = AddMetadataPass(module, model_metadata_run_init, "run_init").run() - module = AddMetadataPass(module, model_metadata_run_prep, "run_prep").run() + module = AddMetadataPass(module, model_metadata_run_init, "run_initialize").run() + module = AddMetadataPass(module, model_metadata_run_prep, "run_scale").run() module = AddMetadataPass(module, model_metadata_run_step, "run_step").run() module_str = str(module) @@ -353,9 +377,9 @@ class CompiledScheduler(CompiledModule): vmfb = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, return_path=True, ) if exit_on_vmfb: diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index d3e4ecb54..6487be7ee 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ 
b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -121,7 +121,7 @@ def export_text_encoders( external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, exit_on_vmfb=False, pipeline_dir=None, @@ -134,6 +134,8 @@ def export_text_encoders( hf_model_name, f"_bs{batch_size}_{str(max_length)}_{precision}_text_encoders", ) + if decomp_attn: + safe_name += "_decomp_attn" if pipeline_dir: safe_name = os.path.join(pipeline_dir, safe_name) @@ -141,7 +143,7 @@ def export_text_encoders( vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -212,7 +214,7 @@ class CompiledTextEncoder(CompiledModule): vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, return_path=not exit_on_vmfb, diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py index ff24864a6..56300d5f3 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py @@ -98,7 +98,7 @@ def export_vae_model( vae_model = vae_model.half() mapper = {} utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path + mapper, vae_model, external_weights, external_weight_path, vae_harness=True ) if weights_only: return external_weight_path diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index a852bf464..7acf0ef4f 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -180,11 +180,11 @@ def is_valid_file(arg): p.add_argument( "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion" ) -p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") + p.add_argument( - "--return_index", + "--clip_decomp_attn", action="store_true", - help="Make scheduled unet compiled module return the step index.", + help="Decompose attention for text_encoder only at fx graph level", ) p.add_argument( @@ -199,6 +199,19 @@ def is_valid_file(arg): help="Decompose attention for unet only at fx graph level", ) +p.add_argument( + "--mmdit_decomp_attn", + action="store_true", + help="Decompose attention for unet only at fx graph level", +) + +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) + p.add_argument( "--use_i8_punet", action="store_true", @@ -227,12 +240,6 @@ def is_valid_file(arg): action="store_true", help="Runs both turbine vmfb and a torch model to compare results", ) -p.add_argument( - "--decomp_attn", - default=False, - action="store_true", - help="Decompose attention at fx graph level", -) p.add_argument( "--exit_on_vmfb", default=True, @@ -317,5 +324,11 @@ def is_valid_file(arg): help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", ) +p.add_argument( + "--mmdit_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling mmdit. Only use this for testing bleeding edge flags! 
Any default options should be added to sd_inference/utils.py", +) args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index a322cb083..2cd0a5408 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -24,7 +24,9 @@ from turbine_models.custom_models.sd3_inference import ( sd3_text_encoders, sd3_mmdit, + sd3_schedulers, ) +from turbine_models.custom_models.sd3_inference.text_encoder_impls import SD3Tokenizer from turbine_models.custom_models.pipeline_base import ( TurbinePipelineBase, merge_arg_into_map, @@ -250,6 +252,7 @@ def __init__( "exit_on_vmfb": False, "pipeline_dir": pipeline_dir, "input_mlir": None, + "ireec_flags": None, "attn_spec": attn_spec, "external_weights": None, "external_weight_path": None, @@ -346,28 +349,36 @@ def __init__( ), ] self.map["text_encoder"]["export_args"]["batch_input"] = batch_prompts - self.latents_precision = self.map["unet"]["precision"] - self.scheduler_device = self.map["unet"]["device"] - self.scheduler_driver = self.map["unet"]["driver"] - self.scheduler_target = self.map["unet"]["target"] + self.diffusion_model = self.map["unet"] if vae_weight_path is not None: self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path self.map["vae"]["export_args"]["vae_harness"] = vae_harness - elif not self.is_sd3: + elif self.is_sd3: + self.tokenizer = SD3Tokenizer() + self.scheduler_id = "EulerFlowDiscrete" + self.map["text_encoder"]["export_args"]["external_weights"] = "irpa" + self.map["text_encoder"]["export_args"][ + "external_weight_path" + ] = "stable_diffusion_3_medium_text_encoder_fp16.irpa" + self.diffusion_model = self.map["mmdit"] + else: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" ) - self.latents_precision = self.map["unet"]["precision"] - self.scheduler_device = self.map["unet"]["device"] - self.scheduler_driver = self.map["unet"]["driver"] - self.scheduler_target = self.map["unet"]["target"] - # TODO: Add SD3 init + self.diffusion_model = self.map["unet"] + + self.latents_precision = self.diffusion_model["precision"] + self.latents_channels = self.map["vae"]["export_args"]["num_channels"] + self.scheduler_device = self.diffusion_model["device"] + self.scheduler_driver = self.diffusion_model["driver"] + self.scheduler_target = self.diffusion_model["target"] self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet + self.map["vae"]["export_args"]["vae_harness"] = True if self.use_punet: self.setup_punet() - else: + elif not self.is_sd3: self.map["unet"]["keywords"].append("!punet") self.map["unet"]["function_name"] = "run_forward" @@ -395,13 +406,17 @@ def setup_punet(self): def load_scheduler( self, - scheduler_id: str, + scheduler_id: str = None, steps: int = 30, ): if not self.cpu_scheduling: + if self.is_sd3: + export_fn = sd3_schedulers.export_scheduler_model + else: + export_fn = scheduler.export_scheduler_model self.map["scheduler"] = { "module_name": "compiled_scheduler", - "export_fn": schedulers.export_scheduler_model, + "export_fn": export_fn, "driver": self.scheduler_driver, "export_args": { "hf_model_name": self.base_model_name, @@ -419,10 +434,11 @@ def load_scheduler( } self.scheduler = None self.num_inference_steps = steps - self.scheduler_id = scheduler_id + if scheduler_id: + self.scheduler_id = 
scheduler_id scheduler_uid = "_".join( [ - f"{scheduler_id}Scheduler", + f"{self.scheduler_id}Scheduler", f"bs{self.batch_size}", "x".join([str(self.width), str(self.height)]), self.latents_precision, @@ -445,10 +461,12 @@ def load_scheduler( self.cpu_scheduling = True if self.cpu_scheduling: if self.is_sd3: - scheduler_device = self.mmdit.device + raise AssertionError("CPU scheduling not yet supported for SD3") else: scheduler_device = self.unet.device - scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) + scheduler = schedulers.get_scheduler( + self.base_model_name, self.scheduler_id + ) self.scheduler = schedulers.SharkSchedulerCPUWrapper( scheduler, self.batch_size, @@ -491,6 +509,21 @@ def encode_prompts_sdxl(self, prompt, negative_prompt): ) return prompt_embeds, add_text_embeds + def encode_prompts_sd3(self, prompt, negative_prompt): + text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt) + uncond_input_ids_dict = self.tokenizer.tokenize_with_weights(negative_prompt) + text_input_ids_list = list(text_input_ids_dict.values()) + uncond_input_ids_list = list(uncond_input_ids_dict.values()) + text_encoders_inputs = [ + text_input_ids_list[0], + text_input_ids_list[1], + text_input_ids_list[2], + uncond_input_ids_list[0], + uncond_input_ids_list[1], + uncond_input_ids_list[2], + ] + return self.text_encoder("encode_tokens", text_encoders_inputs) + def prepare_latents( self, noise, @@ -510,10 +543,8 @@ def prepare_latents( timesteps, ) = self.scheduler.initialize_sdxl(noise, num_inference_steps) return sample, add_time_ids, step_indexes, timesteps - elif self.is_sdxl: + elif self.is_sdxl or self.is_sd3: return self.scheduler("run_initialize", noise) - elif self.is_sd3: - raise NotImplementedError("Stable Diffusion 3 not supported yet.") else: sample, timesteps = self.scheduler.initialize_sd(noise, num_inference_steps) return sample, timesteps @@ -529,7 +560,7 @@ def get_rand_latents(self, seed, batch_count): rand_sample = torch.randn( ( self.batch_size, - 4, + self.latents_channels, self.height // 8, self.width // 8, ), @@ -635,6 +666,50 @@ def _produce_latents_sdxl( latents = self.scheduler("run_step", [noise_pred, t, latents]) return latents + def _produce_latents_sd3( + self, + sample, + prompt_embeds, + pooled_prompt_embeds, + steps, + guidance_scale, + ): + image = None + strength = 0 + latents, steps, timesteps = self.scheduler( + "run_initialize", + sample, + ) + guidance_scale = ireert.asdevicearray( + self.mmdit.device, + [guidance_scale], + dtype=self.map["mmdit"]["np_dtype"], + ) + # Disable progress bar if we aren't in verbose mode or if we're printing + # benchmark latencies for unet. 
+ for i, t in tqdm( + enumerate(timesteps), + disable=(self.map["mmdit"].get("benchmark") or not self.verbose), + ): + step = torch.tensor([i], dtype=torch.float32) + latent_model_input, t = self.scheduler( + "run_scale", [latents, step, timesteps] + ) + mmdit_inputs = [ + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + t, + ] + noise_pred = self.mmdit( + "run_forward", + mmdit_inputs, + ) + latents = self.scheduler( + "run_step", [noise_pred, t, latents, guidance_scale] + ) + return latents + def generate_images( self, prompt: str, @@ -674,6 +749,10 @@ def generate_images( prompt_embeds, negative_embeds = self.encode_prompts_sdxl( prompt, negative_prompt ) + elif self.is_sd3: + prompt_embeds, negative_embeds = self.encode_prompts_sd3( + prompt, negative_prompt + ) else: prompt_embeds, negative_embeds = encode_prompt( self, prompt, negative_prompt @@ -689,6 +768,8 @@ def generate_images( ] if self.is_sdxl: latents = self._produce_latents_sdxl(*produce_latents_input) + elif self.is_sd3: + latents = self._produce_latents_sd3(*produce_latents_input) else: latents = self._produce_latents_sd(*produce_latents_input) image = self.vae("decode", [latents]) @@ -699,13 +780,23 @@ def generate_images( timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") images = [] for idx, image in enumerate(numpy_images): - image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy() - image = numpy_to_pil_image(image) - images.append(image[0]) + if self.is_sd3: + if image.ndim == 4: + image = image[0] + image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy() + image = (image * 255).round().astype("uint8") + out_image = Image.fromarray(image) + images.extend([out_image]) + else: + image = ( + torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy() + ) + image = numpy_to_pil_image(image) + images.append(image[0]) if return_imgs: return images for idx, image in enumerate(images): - img_path = "sdxl_output_" + timestamp + "_" + str(idx) + ".png" + img_path = "sd_output_" + timestamp + "_" + str(idx) + ".png" image.save(img_path) print(img_path, "saved") return @@ -731,9 +822,10 @@ def numpy_to_pil_image(images): from turbine_models.custom_models.sd_inference.sd_cmd_opts import args ireec_flags = { - "clip": args.ireec_flags + args.clip_flags, + "text_encoder": args.ireec_flags + args.clip_flags, "scheduler": args.ireec_flags, "unet": args.ireec_flags + args.unet_flags, + "mmdit": args.ireec_flags + args.mmdit_flags, "vae_decode": args.ireec_flags + args.vae_flags, } if not args.pipeline_dir: @@ -755,14 +847,16 @@ def numpy_to_pil_image(images): save_outputs[i] = True else: save_outputs = False - if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]): - args.decomp_attn = { - "text_encoder": args.decomp_attn, - "unet": ( - args.unet_decomp_attn if args.unet_decomp_attn else args.decomp_attn - ), - "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, - } + args.decomp_attn = { + "text_encoder": ( + args.clip_decomp_attn if args.clip_decomp_attn else args.decomp_attn + ), + "unet": (args.unet_decomp_attn if args.unet_decomp_attn else args.decomp_attn), + "mmdit": ( + args.mmdit_decomp_attn if args.mmdit_decomp_attn else args.decomp_attn + ), + "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, + } sd_pipe = SharkSDPipeline( args.hf_model_name, args.height, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 84d9cb3b4..6c33ee4a8 
100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -21,8 +21,8 @@ "--iree-llvmgpu-enable-prefetch=true", "--iree-execution-model=async-external", ], - "pad_attention": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", + "masked_attention": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" @@ -44,7 +44,7 @@ ], "clip": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-opt-outer-dim-concat=true", "--iree-rocm-waves-per-eu=2", @@ -52,7 +52,7 @@ ], "vae": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions", + "--iree-global-opt-enable-fuse-horizontal-contractions", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-opt-data-tiling=false", @@ -71,12 +71,12 @@ "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], - "pad_attention": [ + "masked_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ @@ -238,6 +238,12 @@ def compile_to_vmfb( tk_kernels_dir=None, batch_size=1, ): + if ireec_flags is not None and "masked_attention" in ireec_flags: + flagset_keywords = ["masked_attention"] + ireec_flags = "".join(ireec_flags.split("masked_attention")) + masked_attention = True + else: + masked_attention = False if batch_size != 1 and batch_size != 8: add_tk_kernels = False flags = [] @@ -318,7 +324,7 @@ def compile_to_vmfb( flags.extend(MI_flags["vae"]) flags.extend(MI_flags["all"]) if "masked_attention" in flagset_keywords: - flags.extend(MI_flags["pad_attention"]) + flags.extend(MI_flags["masked_attention"]) elif "punet" in flagset_keywords: flags.extend(MI_flags["punet"]) elif "vae" in safe_name: @@ -329,7 +335,7 @@ def compile_to_vmfb( if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) if "masked_attention" in flagset_keywords: - flags.extend(GFX11_flags["pad_attention"]) + flags.extend(GFX11_flags["masked_attention"]) elif "punet" in flagset_keywords: flags.extend(GFX11_flags["punet"]) else: @@ -347,11 +353,14 @@ def compile_to_vmfb( target_triple, 
os.path.dirname(safe_name), use_punet=use_punet, + masked_attention=masked_attention, ) flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): - attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name)) + attn_spec = get_wmma_spec_path( + target_triple, os.path.dirname(safe_name), masked_attention=masked_attention + ) if attn_spec: flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec and attn_spec != "None": diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index c18fb6da5..14422ae15 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -176,14 +176,19 @@ def export_vae_model( if weights_only: return external_weight_path - input_image_shape = (batch_size, 3, height, width) - input_latents_shape = (batch_size, num_channels, height // 8, width // 8) - encode_args = [ - torch.empty( - input_image_shape, - dtype=dtype, - ) - ] + if "stable-diffusion-3" in hf_model_name: + input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, 16, height // 8, width // 8) + else: + input_image_shape = (batch_size, 3, height, width) + input_latents_shape = (batch_size, num_channels, height // 8, width // 8) + + # encode_args = [ + # torch.empty( + # input_image_shape, + # dtype=dtype, + # ) + # ] decode_args = [ torch.empty( input_latents_shape, @@ -204,12 +209,12 @@ def export_vae_model( fxb = FxProgramsBuilder(vae_model) # TODO: fix issues with exporting the encode function. - @fxb.export_program(args=(encode_args,)) - def _encode( - module, - inputs, - ): - return module.encode(*inputs) + # @fxb.export_program(args=(encode_args,)) + # def _encode( + # module, + # inputs, + # ): + # return module.encode(*inputs) @fxb.export_program(args=(decode_args,)) def _decode(module, inputs): @@ -217,7 +222,7 @@ def _decode(module, inputs): class CompiledVae(CompiledModule): decode = _decode - encode = _encode + # encode = _encode if external_weights: externalize_module_parameters(vae_model) @@ -233,15 +238,15 @@ class CompiledVae(CompiledModule): "output_shapes": [(3, width, height) * batch_size], "output_dtypes": ["float32"], } - model_metadata_encode = { - "model_name": "vae_encode", - "input_shapes": [input_image_shape], - "input_dtypes": [np_dtype], - "output_shapes": [input_latents_shape], - "output_dtypes": [np_dtype], - } + # model_metadata_encode = { + # "model_name": "vae_encode", + # "input_shapes": [input_image_shape], + # "input_dtypes": [np_dtype], + # "output_shapes": [input_latents_shape], + # "output_dtypes": [np_dtype], + # } module = AddMetadataPass(module, model_metadata_decode, "decode").run() - module = AddMetadataPass(module, model_metadata_decode, "encode").run() + # module = AddMetadataPass(module, model_metadata_decode, "encode").run() if compile_to != "vmfb": return str(module) From de5d3debded338694842b344f00484ef681af93d Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:07:30 -0500 Subject: [PATCH 45/89] Update test_models.yml CPU runs take up to a minute, so only do two steps to validate in CI. 
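A note for readers following the new `masked_attention` plumbing in the `compile_to_vmfb` hunks above: the token is not a real iree-compile flag; it is stripped out of the raw `ireec_flags` string, recorded as a flagset keyword, and then used to select the padded-attention preprocessing pipeline and the masked variant of the MFMA/WMMA attention specs. A minimal, self-contained sketch of that idiom (the helper name here is illustrative, not part of the patch):

```python
# Sketch of the pseudo-keyword handling in utils.compile_to_vmfb.
# "masked_attention" never reaches iree-compile; it only selects flag-table
# entries and the attention spec variant, then is removed from the string.
def split_flag_keywords(ireec_flags: str, keywords=("masked_attention",)):
    found = []
    for kw in keywords:
        if kw in ireec_flags:
            found.append(kw)
            # Same idiom as the patch: drop the substring from the flag string.
            ireec_flags = "".join(ireec_flags.split(kw))
    return ireec_flags.strip(), found


flags_str, flagset_keywords = split_flag_keywords(
    "--iree-opt-const-eval=false masked_attention"
)
assert flagset_keywords == ["masked_attention"]
assert "masked_attention" not in flags_str
```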
--- .github/workflows/test_models.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 2b2b6062a..cd3351dba 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -69,7 +69,7 @@ jobs: source turbine_venv/bin/activate pytest -v models/turbine_models/tests/sd_test.py - pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 + pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 - pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 + pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 From d0d3ae6e62026c639c6065fa0765d66604a46ceb Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 30 Jul 2024 16:08:29 -0500 Subject: [PATCH 46/89] Remove default in mmdit export args. --- models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index aed7839dd..40e0f18c4 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -152,7 +152,7 @@ def export_mmdit_model( external_weight_path=None, device=None, target=None, - ireec_flags="pad_attention", + ireec_flags="", decomp_attn=False, exit_on_vmfb=False, pipeline_dir=None, From 403fe47e9eaa0c022cddb51db4cea0dd42c64c6b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 30 Jul 2024 16:49:10 -0500 Subject: [PATCH 47/89] set vae_harness to False in sdxl test. 
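Context for the mmdit default removed just above: `"pad_attention"` pointed at a flag-table key that no longer exists after the `pad_attention` to `masked_attention` rename, so an empty `ireec_flags` string, which simply falls through to the submodel's default flag set, is the safer default. A toy model of that per-submodel flag assembly, with table contents abbreviated and names illustrative (the real tables are `MI_flags`/`GFX11_flags` in `sd_inference/utils.py`):

```python
# Toy model of per-submodel flag assembly in utils.compile_to_vmfb.
# Pipeline strings are placeholders, not the real flag values.
MI_FLAGS = {
    "all": ["--iree-opt-const-eval=false"],
    "unet": ["--iree-flow-enable-aggressive-fusion"],
    "masked_attention": ["--iree-preprocessing-pass-pipeline=<padded attention pipeline>"],
    "preprocess_default": ["--iree-preprocessing-pass-pipeline=<plain pad-to-intrinsics>"],
}


def assemble_flags(submodel, flagset_keywords):
    flags = list(MI_FLAGS["all"]) + list(MI_FLAGS.get(submodel, []))
    if "masked_attention" in flagset_keywords:
        flags += MI_FLAGS["masked_attention"]
    else:
        # No keywords (e.g. mmdit after this change): use the defaults.
        flags += MI_FLAGS["preprocess_default"]
    return flags


print(assemble_flags("unet", []))
```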
--- models/turbine_models/custom_models/sd_inference/sd_pipeline.py | 1 - models/turbine_models/tests/sdxl_test.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 2cd0a5408..969a077da 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -375,7 +375,6 @@ def __init__( self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet - self.map["vae"]["export_args"]["vae_harness"] = True if self.use_punet: self.setup_punet() elif not self.is_sd3: diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 060e3a138..96b90b55d 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -106,6 +106,7 @@ def setUp(self): scheduler_id=arguments["scheduler_id"], shift=None, use_i8_punet=False, + vae_harness=False, ) self.pipe.prepare_all() From 0ac6b6417ef599a512d9be55d7c849d858219f85 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 30 Jul 2024 17:46:14 -0500 Subject: [PATCH 48/89] Switch to main branch of iree-turbine --- models/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/requirements.txt b/models/requirements.txt index 9cee9146a..0aed40159 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -12,5 +12,5 @@ azure-storage-blob einops pytest scipy -shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@sdxl-vae-fix +shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main -e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank From 1a41394487ab39e6a27e4e75b8b527cee55d4f63 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:43:28 -0500 Subject: [PATCH 49/89] Update sd3_vae.py --- models/turbine_models/custom_models/sd3_inference/sd3_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py index 56300d5f3..ff24864a6 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py @@ -98,7 +98,7 @@ def export_vae_model( vae_model = vae_model.half() mapper = {} utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path, vae_harness=True + mapper, vae_model, external_weights, external_weight_path ) if weights_only: return external_weight_path From 493f2606956f30a7b72319b7d660d35e864993a4 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:44:51 -0500 Subject: [PATCH 50/89] Remove preprocess arg that fails to parse. 
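The hunk that follows drops the `{pad-target-type=conv}` option from `iree-preprocessing-pad-to-intrinsics`, which the pinned iree-compile rejects. A quick way to sanity-check a pipeline string against whatever compiler is installed is to push an empty module through the Python API; a hedged sketch, assuming the `iree.compiler` package from the IREE build described in the README is importable:

```python
# Check that a preprocessing pipeline string is accepted by the installed
# iree-compile; an empty module is enough to exercise flag parsing.
import iree.compiler as ireec

pipeline_flag = (
    "--iree-preprocessing-pass-pipeline=builtin.module("
    "iree-preprocessing-transpose-convolution-pipeline, "
    "util.func(iree-preprocessing-pad-to-intrinsics))"
)

ireec.compile_str(
    "module {}",
    target_backends=["llvm-cpu"],
    extra_args=[pipeline_flag],
)
print("pipeline string accepted")
```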
--- models/turbine_models/custom_models/sd_inference/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 6c33ee4a8..0e5507efc 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -31,7 +31,7 @@ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", From 711403ce89ccb36771f00f51dfb3e6a17ff48aed Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 1 Aug 2024 20:48:24 -0500 Subject: [PATCH 51/89] SD3 updates, CLI arguments for multi-device --- .../custom_models/pipeline_base.py | 7 + .../sd3_inference/diffusers_ref.py | 49 +++++++ .../sd3_inference/sd3_schedulers.py | 9 +- .../sd3_inference/sd3_text_encoders.py | 6 +- .../custom_models/sd_inference/sd_cmd_opts.py | 122 +++++++++++++++--- .../custom_models/sd_inference/sd_pipeline.py | 64 +++++++-- .../custom_models/sd_inference/utils.py | 3 + 7 files changed, 226 insertions(+), 34 deletions(-) create mode 100644 models/turbine_models/custom_models/sd3_inference/diffusers_ref.py diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 3102ac3e3..7d23ef6d3 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -101,6 +101,7 @@ def __init__( self.output_counter = 0 self.dest_type = dest_type self.dest_dtype = dest_dtype + self.validate = False def load( self, @@ -252,6 +253,10 @@ def __call__(self, function_name, inputs: list): if not isinstance(inputs, list): inputs = [inputs] inputs = self._validate_or_convert_inputs(function_name, inputs) + + if self.validate: + self.save_torch_inputs(inputs) + if self.benchmark: output = self._run_and_benchmark(function_name, inputs) else: @@ -261,6 +266,8 @@ def __call__(self, function_name, inputs: list): output = self._output_cast(output) return output + # def _run_and_validate(self, iree_fn, torch_fn, inputs: list) + class Printer: def __init__(self, verbose, start_time, print_time): diff --git a/models/turbine_models/custom_models/sd3_inference/diffusers_ref.py b/models/turbine_models/custom_models/sd3_inference/diffusers_ref.py new file mode 100644 index 000000000..efea91c40 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/diffusers_ref.py @@ -0,0 +1,49 @@ +from diffusers import StableDiffusion3Pipeline +import torch +from datetime import datetime as dt + + +def run_diffusers_cpu( + hf_model_name, + prompt, + negative_prompt, + guidance_scale, + seed, + height, + width, + num_inference_steps, +): + from diffusers import StableDiffusion3Pipeline + + pipe = StableDiffusion3Pipeline.from_pretrained( + hf_model_name, torch_dtype=torch.float32 + ) + pipe = pipe.to("cpu") + generator 
= torch.Generator().manual_seed(int(seed)) + + image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + ).images[0] + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image.save(f"diffusers_reference_output_{timestamp}.png") + + +if __name__ == "__main__": + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + + run_diffusers_cpu( + args.hf_model_name, + args.prompt, + args.negative_prompt, + args.guidance_scale, + args.seed, + args.height, + args.width, + args.num_inference_steps, + ) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 075b29017..d05ff278d 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -83,8 +83,8 @@ def prepare_model_input(self, sample, t, timesteps): latent_model_input = sample return latent_model_input.type(self.dtype), t.type(self.dtype) - def step(self, noise_pred, t, sample, guidance_scale): - self.model._step_index = self.index_for_timestep(t) + def step(self, noise_pred, t, sample, guidance_scale, i): + self.model._step_index = i if self.do_classifier_free_guidance: noise_preds = noise_pred.chunk(2) @@ -299,6 +299,7 @@ def export_scheduler_model( torch.empty(1, dtype=dtype), torch.empty(sample, dtype=dtype), torch.empty(1, dtype=dtype), + torch.empty([1], dtype=torch.int64), ] fxb = FxProgramsBuilder(scheduler_module) @@ -361,8 +362,8 @@ class CompiledScheduler(CompiledModule): } model_metadata_run_step = { "model_name": "sd3_scheduler_FlowEulerDiscrete", - "input_shapes": [noise_pred_shape, (1,), sample, (1,)], - "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype], + "input_shapes": [noise_pred_shape, (1,), sample, (1,), (1,)], + "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype, "int64"], "output_shapes": [sample], "output_dtypes": [np_dtype], } diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index 6487be7ee..67484d70d 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -54,7 +54,6 @@ class TextEncoderModule(torch.nn.Module): @torch.no_grad() def __init__( self, - batch_size=1, ): super().__init__() self.dtype = torch.float16 @@ -89,7 +88,6 @@ def __init__( load_into(f, self.t5xxl.transformer, "", "cpu", self.dtype) self.do_classifier_free_guidance = True - self.batch_size = batch_size def get_cond(self, tokens_l, tokens_g, tokens_t5xxl): l_out, l_pooled = self.clip_l.forward(tokens_l) @@ -152,9 +150,7 @@ def export_text_encoders( attn_spec=attn_spec, ) return vmfb_path - model = TextEncoderModule( - batch_size=batch_size, - ) + model = TextEncoderModule(hf_model_name) mapper = {} assert ( diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 7acf0ef4f..6d0c20379 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -177,10 +177,43 @@ def is_valid_file(arg): default="fp16", help="Precision of Stable Diffusion weights and graph.", ) + +p.add_argument( + "--clip_precision", 
+ type=str, + default=None, + help="Precision of CLIP weights and graph.", +) +p.add_argument( + "--unet_precision", + type=str, + default=None, + help="Precision of CLIP weights and graph.", +) +p.add_argument( + "--mmdit_precision", + type=str, + default=None, + help="Precision of CLIP weights and graph.", +) +p.add_argument( + "--vae_precision", + type=str, + default=None, + help="Precision of CLIP weights and graph.", +) + p.add_argument( "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion" ) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) + p.add_argument( "--clip_decomp_attn", action="store_true", @@ -205,12 +238,6 @@ def is_valid_file(arg): help="Decompose attention for unet only at fx graph level", ) -p.add_argument( - "--decomp_attn", - default=False, - action="store_true", - help="Decompose attention at fx graph level", -) p.add_argument( "--use_i8_punet", @@ -270,21 +297,81 @@ def is_valid_file(arg): # IREE Compiler Options ############################################################################## -p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") - p.add_argument( - "--rt_device", + "--device", type=str, default="local-task", help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", ) +p.add_argument( + "--clip_device", + type=str, + default=None, + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) +p.add_argument( + "--unet_device", + type=str, + default=None, + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) +p.add_argument( + "--mmdit_device", + type=str, + default=None, + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) +p.add_argument( + "--vae_device", + type=str, + default=None, + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) +p.add_argument( + "--scheduler_device", + type=str, + default=None, + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + # TODO: Bring in detection for target triple p.add_argument( "--iree_target_triple", type=str, default="x86_64-linux-gnu", - help="Specify vulkan target triple or rocm/cuda target device.", + help="Specify vulkan target triple or rocm/cuda target chip.", +) + +p.add_argument( + "--clip_target", + type=str, + default=None, + help="Specify vulkan target triple or rocm/cuda target chip.", +) +p.add_argument( + "--unet_target", + type=str, + default=None, + help="Specify vulkan target triple or rocm/cuda target chip.", +) +p.add_argument( + "--mmdit_target", + type=str, + default=None, + help="Specify vulkan target triple or rocm/cuda target chip.", +) +p.add_argument( + "--vae_target", + type=str, + default=None, + help="Specify vulkan target triple or rocm/cuda target chip.", +) +p.add_argument( + "--scheduler_target", + type=str, + default=None, + help="Specify vulkan target triple or rocm/cuda target chip.", ) p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") @@ -296,13 +383,6 @@ def is_valid_file(arg): help="extra iree-compile options for models with iree_linalg_ext.attention ops.", ) -p.add_argument( - "--attn_spec", - type=str, - default=None, - help="extra iree-compile options for models with iree_linalg_ext.attention ops. 
Set this to 'default' if you are using mfma-capable hardware with ROCM.", -) - p.add_argument( "--clip_flags", type=str, @@ -331,4 +411,12 @@ def is_valid_file(arg): help="extra iree-compile options to send for compiling mmdit. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", ) +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + + args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 969a077da..3d2c50238 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -150,6 +150,7 @@ "module_name": "compiled_text_encoder", "keywords": ["text_encoder"], "export_fn": sd3_text_encoders.export_text_encoders, + "torch_module": sd3_text_encoders.TextEncoderModule, "export_args": { "batch_size": 1, "max_length": 64, @@ -159,6 +160,7 @@ "module_name": "compiled_mmdit", "keywords": ["mmdit"], "export_fn": sd3_mmdit.export_mmdit_model, + "torch_module": sd3_mmdit.MMDiTModel, "export_args": { "batch_size": 1, "height": 1024, @@ -172,6 +174,7 @@ "keywords": ["vae"], "dest_type": "numpy", "export_fn": vae.export_vae_model, + "torch_module": vae.SD3VaeModel, "export_args": { "batch_size": 1, "height": 1024, @@ -353,6 +356,7 @@ def __init__( if vae_weight_path is not None: self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path self.map["vae"]["export_args"]["vae_harness"] = vae_harness + elif self.is_sd3: self.tokenizer = SD3Tokenizer() self.scheduler_id = "EulerFlowDiscrete" @@ -372,7 +376,9 @@ def __init__( self.scheduler_device = self.diffusion_model["device"] self.scheduler_driver = self.diffusion_model["driver"] self.scheduler_target = self.diffusion_model["target"] - + self.cast_latents_to_vae = False + if self.diffusion_model["driver"] != self.map["vae"]["driver"]: + self.cast_latents_to_vae = True self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet if self.use_punet: @@ -675,7 +681,7 @@ def _produce_latents_sd3( ): image = None strength = 0 - latents, steps, timesteps = self.scheduler( + latents, indexes, timesteps = self.scheduler( "run_initialize", sample, ) @@ -684,15 +690,26 @@ def _produce_latents_sd3( [guidance_scale], dtype=self.map["mmdit"]["np_dtype"], ) + steps_list_gpu = [ + ireert.asdevicearray(self.scheduler.device, [i], dtype="int64") + for i in range(steps) + ] + timesteps_cpu = timesteps + timesteps_list_gpu = [ + ireert.asdevicearray( + self.scheduler.device, [timesteps_cpu[i]], dtype="float32" + ) + for i in range(steps) + ] + # Disable progress bar if we aren't in verbose mode or if we're printing # benchmark latencies for unet. 
for i, t in tqdm( enumerate(timesteps), disable=(self.map["mmdit"].get("benchmark") or not self.verbose), ): - step = torch.tensor([i], dtype=torch.float32) latent_model_input, t = self.scheduler( - "run_scale", [latents, step, timesteps] + "run_scale", [latents, timesteps_list_gpu[i], timesteps] ) mmdit_inputs = [ latent_model_input, @@ -705,7 +722,7 @@ def _produce_latents_sd3( mmdit_inputs, ) latents = self.scheduler( - "run_step", [noise_pred, t, latents, guidance_scale] + "run_step", [noise_pred, t, latents, guidance_scale, steps_list_gpu[i]] ) return latents @@ -771,6 +788,13 @@ def generate_images( latents = self._produce_latents_sd3(*produce_latents_input) else: latents = self._produce_latents_sd(*produce_latents_input) + + if self.cast_latents_to_vae: + latents = ireert.asdevicearray( + self.vae.device, + latents.to_host(), + dtype=self.map["vae"]["np_dtype"], + ) image = self.vae("decode", [latents]) numpy_images.append(image) pipe_end = time.time() @@ -825,7 +849,31 @@ def numpy_to_pil_image(images): "scheduler": args.ireec_flags, "unet": args.ireec_flags + args.unet_flags, "mmdit": args.ireec_flags + args.mmdit_flags, - "vae_decode": args.ireec_flags + args.vae_flags, + "vae": args.ireec_flags + args.vae_flags, + } + devices = { + "text_encoder": args.clip_device if args.clip_device else args.device, + "scheduler": args.scheduler_device if args.scheduler_device else args.device, + "unet": args.unet_device if args.unet_device else args.device, + "mmdit": args.mmdit_device if args.mmdit_device else args.device, + "vae": args.vae_device if args.vae_device else args.device, + } + targets = { + "text_encoder": ( + args.clip_target if args.clip_target else args.iree_target_triple + ), + "scheduler": ( + args.scheduler_target if args.scheduler_target else args.iree_target_triple + ), + "unet": args.unet_target if args.unet_target else args.iree_target_triple, + "mmdit": args.mmdit_target if args.mmdit_target else args.iree_target_triple, + "vae": args.vae_target if args.vae_target else args.iree_target_triple, + } + precisions = { + "text_encoder": args.clip_precision if args.clip_precision else args.precision, + "unet": args.unet_precision if args.unet_precision else args.precision, + "mmdit": args.mmdit_precision if args.mmdit_precision else args.precision, + "vae": args.vae_precision if args.vae_precision else args.precision, } if not args.pipeline_dir: args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") @@ -863,8 +911,8 @@ def numpy_to_pil_image(images): args.batch_size, args.max_length, args.precision, - args.device, - args.iree_target_triple, + devices, + targets, ireec_flags, args.attn_spec, args.decomp_attn, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 0e5507efc..362aeccaf 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -244,6 +244,9 @@ def compile_to_vmfb( masked_attention = True else: masked_attention = False + if ireec_flags is not None and "winograd" in ireec_flags: + winograd = True + ireec_flags = "".join(ireec_flags.split("winograd")) if batch_size != 1 and batch_size != 8: add_tk_kernels = False flags = [] From e554da803c27f435198c0be91958208b52f392b1 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 2 Aug 2024 12:27:17 -0500 Subject: [PATCH 52/89] Tweaks to requirements, scheduler filenames --- models/README.md | 18 ++++-------------- models/requirements.txt | 4 ++-- 
models/setup.py | 2 +- .../sd3_inference/sd3_schedulers.py | 4 +--- .../custom_models/sd_inference/sd_pipeline.py | 2 -- 5 files changed, 8 insertions(+), 22 deletions(-) diff --git a/models/README.md b/models/README.md index 4fe6ea1b2..c917d03ee 100644 --- a/models/README.md +++ b/models/README.md @@ -1,26 +1,19 @@ -# LLAMA 2 Inference +# Turbine-Models setup (source) -This example require some extra dependencies. Here's an easy way to get it running on a fresh server. - -Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens +For private/gated models, make sure you have run `huggingface-cli login`. ```bash #!/bin/bash - -# if you don't insert it, you will be prompted to log in later; -# you may need to rerun this script after logging in -YOUR_HF_TOKEN="insert token for headless" - # clone and install dependencies sudo apt install -y git git clone https://github.com/nod-ai/SHARK-Turbine.git cd SHARK-Turbine -pip install -r core/requirements.txt +pip install torch==2.5.0.dev20240801 torchvision --index-url https://download.pytorch.org/whl/nightly/cpu pip install -r models/requirements.txt # do an editable install from the cloned SHARK-Turbine -pip install --editable core models +pip install --editable models # Log in with Hugging Face CLI if token setup is required if [[ $YOUR_HF_TOKEN == hf_* ]]; then @@ -42,6 +35,3 @@ else huggingface-cli login fi -# Step 7: Run the Python script -python .\python\turbine_models\custom_models\stateless_llama.py --compile_to=torch --external_weights=safetensors --external_weight_file=llama_f32.safetensors -``` diff --git a/models/requirements.txt b/models/requirements.txt index 0aed40159..72b8a398a 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -1,6 +1,6 @@ protobuf gguf -transformers==4.37.1 +transformers==4.43.3 torchsde accelerate peft @@ -13,4 +13,4 @@ einops pytest scipy shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main --e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank +-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank \ No newline at end of file diff --git a/models/setup.py b/models/setup.py index 09d60cfe3..ae737a9fc 100644 --- a/models/setup.py +++ b/models/setup.py @@ -57,7 +57,7 @@ def load_version_info(): "Shark-Turbine", "protobuf", "sentencepiece", - "transformers>=4.37.1", + "transformers>=4.43.3", "accelerate", "diffusers==0.29.0.dev0", "azure-storage-blob", diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index d05ff278d..676717e23 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -141,8 +141,7 @@ def initialize(self, sample): if isinstance(sample, ireert.DeviceArray): sample = torch.tensor(sample.to_host(), dtype=torch.float32) step_indexes = torch.tensor(len(self.module.timesteps)) - timesteps = self.timesteps - return sample, step_indexes, timesteps + return sample, step_indexes, self.module.timesteps def scale_model_input(self, sample, t, timesteps): if self.do_classifier_free_guidance: @@ -246,7 +245,6 @@ def export_scheduler_model( f"bs{batch_size}_{height}x{width}", precision, str(num_inference_steps), - target, ] vmfb_name = "_".join(vmfb_names) safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name) diff --git 
a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 3d2c50238..18924468f 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -679,8 +679,6 @@ def _produce_latents_sd3( steps, guidance_scale, ): - image = None - strength = 0 latents, indexes, timesteps = self.scheduler( "run_initialize", sample, From cdd2f66ac0f4ef5edea3bd232d56ac2b16df2803 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:47:35 -0500 Subject: [PATCH 53/89] xfail stateless llama test --- models/turbine_models/tests/stateless_llama_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py index 4b1ffef73..203eef9a5 100644 --- a/models/turbine_models/tests/stateless_llama_test.py +++ b/models/turbine_models/tests/stateless_llama_test.py @@ -196,6 +196,7 @@ def test_streaming_vmfb_comparison(self): # See: https://github.com/nod-ai/SHARK-Turbine/issues/560 # Developed issues related to the pytorch 2.3 upgrade. + @unittest.expectedFailure def test_rerotated_torch_comparison(self): torch_str = llm_runner.run_torch_llm( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", From d23a45b1799bc2c6b92e4e40d654bdd09c4bf3d5 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 13 Aug 2024 18:48:00 -0500 Subject: [PATCH 54/89] Flag updates and parametrize a few more args. --- models/requirements.txt | 1 - .../custom_models/pipeline_base.py | 4 +-- .../sd3_inference/sd3_text_encoders.py | 15 ++++++--- .../custom_models/sd_inference/sd_cmd_opts.py | 32 +++++++++++++++++-- .../custom_models/sd_inference/sd_pipeline.py | 10 ++++-- .../custom_models/sd_inference/utils.py | 18 +++++------ 6 files changed, 58 insertions(+), 22 deletions(-) diff --git a/models/requirements.txt b/models/requirements.txt index 72b8a398a..06283efd5 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -8,7 +8,6 @@ diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b # turbine tank downloading/uploading azure-storage-blob -# microsoft/phi model einops pytest scipy diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 7d23ef6d3..c5f550e5d 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -358,7 +358,7 @@ def __init__( target: str | dict[str], ireec_flags: str | dict[str] = None, precision: str | dict[str] = "fp16", - td_spec: str | dict[str] = None, + attn_spec: str | dict[str] = None, decomp_attn: bool | dict[bool] = False, external_weights: str | dict[str] = None, pipeline_dir: str = "./shark_vmfbs", @@ -396,7 +396,7 @@ def __init__( map_arguments = { "ireec_flags": ireec_flags, "precision": precision, - "td_spec": td_spec, + "attn_spec": attn_spec, "decomp_attn": decomp_attn, "external_weights": external_weights, "hf_model_name": hf_model_name, diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index 67484d70d..3edf6b402 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ 
b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py
@@ -54,9 +54,10 @@ class TextEncoderModule(torch.nn.Module):
     @torch.no_grad()
     def __init__(
         self,
+        precision,
     ):
         super().__init__()
-        self.dtype = torch.float16
+        self.dtype = torch.float16 if precision == "fp16" else torch.float32
         self.clip_l = SDClipModel(
             layer="hidden",
             layer_idx=-2,
@@ -65,21 +66,25 @@ def __init__(
             layer_norm_hidden_state=False,
             return_projected_pooled=False,
             textmodel_json_config=CLIPL_CONFIG,
-        ).half()
+        )
+        if precision == "fp16":
+            self.clip_l = self.clip_l.half()
         clip_l_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/clip_l.safetensors",
         )
         with safe_open(clip_l_weights, framework="pt", device="cpu") as f:
             load_into(f, self.clip_l.transformer, "", "cpu", self.dtype)
-        self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half()
+        self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype)
+        if precision == "fp16":
+            self.clip_g = self.clip_g.half()
         clip_g_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/clip_g.safetensors",
         )
         with safe_open(clip_g_weights, framework="pt", device="cpu") as f:
             load_into(f, self.clip_g.transformer, "", "cpu", self.dtype)
-        self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half()
+        self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float16)
         t5_weights = hf_hub_download(
             repo_id="stabilityai/stable-diffusion-3-medium",
             filename="text_encoders/t5xxl_fp16.safetensors",
@@ -150,7 +155,7 @@ def export_text_encoders(
             attn_spec=attn_spec,
         )
         return vmfb_path

-    model = TextEncoderModule(hf_model_name)
+    model = TextEncoderModule(precision)
     mapper = {}
     assert (
diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
index 6d0c20379..87e01467a 100644
--- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
+++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
@@ -188,21 +188,47 @@ def is_valid_file(arg):
     "--unet_precision",
     type=str,
     default=None,
-    help="Precision of CLIP weights and graph.",
args.mmdit_precision if args.mmdit_precision else args.precision, "vae": args.vae_precision if args.vae_precision else args.precision, } + specs = { + "text_encoder": args.clip_spec if args.clip_spec else args.attn_spec, + "unet": args.unet_spec if args.unet_spec else args.attn_spec, + "mmdit": args.mmdit_spec if args.mmdit_spec else args.attn_spec, + "vae": args.vae_spec if args.vae_spec else args.attn_spec, + } if not args.pipeline_dir: args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") benchmark = {} @@ -908,11 +914,11 @@ def numpy_to_pil_image(images): args.width, args.batch_size, args.max_length, - args.precision, + precisions, devices, targets, ireec_flags, - args.attn_spec, + specs, args.decomp_attn, args.pipeline_dir, args.external_weights_dir, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 362aeccaf..8db8806f0 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -22,16 +22,16 @@ "--iree-execution-model=async-external", ], "masked_attention": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "vae_preprocess": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", ], "unet": [ "--iree-flow-enable-aggressive-fusion", @@ -44,7 +44,7 @@ ], "clip": [ "--iree-flow-enable-aggressive-fusion", - "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-opt-outer-dim-concat=true", "--iree-rocm-waves-per-eu=2", @@ -71,19 
+71,19 @@ "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", "--iree-flow-enable-aggressive-fusion", - "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "masked_attention": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)", ], "unet": [""], "clip": [""], From 2d7a92e360eb24397f381985b16fba1673e2816c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 15 Aug 2024 13:18:53 -0500 Subject: [PATCH 55/89] Update SDXL tests, README for running on GFX942 --- models/README.md | 40 +++++++++++++++++++ models/requirements.txt | 1 + .../custom_models/sd_inference/sd_pipeline.py | 5 +++ .../custom_models/sd_inference/utils.py | 28 ++++++------- .../sdxl_inference/sdxl_prompt_encoder.py | 3 +- models/turbine_models/tests/conftest.py | 3 ++ models/turbine_models/tests/sdxl_test.py | 26 +++++++----- 7 files changed, 81 insertions(+), 25 deletions(-) diff --git a/models/README.md b/models/README.md index c917d03ee..96f3dadd7 100644 --- a/models/README.md +++ b/models/README.md @@ -2,6 +2,46 @@ For private/gated models, make sure you have run `huggingface-cli login`. +For MI Instinct: +```bash +#!/bin/bash +sudo apt install -y git + +# Clone and build IREE at the shared/sdxl_quantized branch +git clone https://github.com/iree-org/iree && cd iree +git checkout shared/sdxl_quantized +git submodule update --init +python -m venv iree.venv +pip install pybind11 numpy nanobind +cmake -S . -B build-release \ + -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=`which clang` -DCMAKE_CXX_COMPILER=`which clang++` \ + -DIREE_HAL_DRIVER_CUDA=OFF \ + -DIREE_BUILD_PYTHON_BINDINGS=ON \ + -DPython3_EXECUTABLE="$(which python3)" && \ + cmake --build build-release/ + +export PYTHONPATH=/path/to/iree/build-release/compiler/bindings/python:/path/to/iree/build-release/runtime/bindings/python + +# Clone and setup turbine-models +cd .. 
+git clone https://github.com/nod-ai/SHARK-Turbine.git && cd SHARK-Turbine +git checkout merge_punet_sdxl +pip install torch==2.5.0.dev20240801 torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +pip install -r models/requirements.txt +pip uninstall -y iree-compiler iree-runtime + +pip install -e models + +# Run sdxl tests. +python models/turbine_models/tests/sdxl_test.py pytest --device=rocm --rt_device=hip --iree_target_triple=gfx942 --external_weights=safetensors --precision=fp16 --clip_spec=mfma --unet_spec=mfma --vae_spec=mfma + +# Generate an image. +# To reuse test artifacts/weights, add: --pipeline_dir=./test_vmfbs --external_weights_dir=./test_weights +python models/turbine_models/custom_models/sd_inference/sd_pipeline.py --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --device=hip://0 --precision=fp16 --external_weights=safetensors --iree_target_triple=gfx942 --vae_decomp_attn --clip_decomp_attn --use_i8_punet --width=1024 --height=1024 --num_inference_steps=20 --benchmark=all --verbose + +``` +For GFX11 (RDNA3 Discrete GPUs/Ryzen laptops) the following setup is validated: ```bash #!/bin/bash diff --git a/models/requirements.txt b/models/requirements.txt index 06283efd5..ec153a02a 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -4,6 +4,7 @@ transformers==4.43.3 torchsde accelerate peft +safetensors==0.4.0 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b # turbine tank downloading/uploading diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 03c4f64f0..e37c095c4 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -388,6 +388,11 @@ def __init__( self.map["unet"]["function_name"] = "run_forward" def setup_punet(self): + self.map["unet"]["mlir"] = None + self.map["unet"]["vmfb"] = None + self.map["unet"]["weights"] = None + self.map["unet"]["keywords"] = [i for i in self.map["unet"]["keywords"] if i != "!punet"] + self.map["unet"]["keywords"] += "punet" if self.use_i8_punet: if self.add_tk_kernels: self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 8db8806f0..07e609989 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -22,16 +22,16 @@ "--iree-execution-model=async-external", ], "masked_attention": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + 
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))" ], "vae_preprocess": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", @@ -52,7 +52,7 @@ ], "vae": [ "--iree-flow-enable-aggressive-fusion", - "--iree-global-opt-enable-fuse-horizontal-contractions", + "--iree-flow-enable-fuse-horizontal-contractions", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-opt-data-tiling=false", @@ -350,15 +350,15 @@ def compile_to_vmfb( # the TD spec is implemented in C++. if attn_spec in ["default", "mfma", "punet"]: - if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: - use_punet = True if attn_spec in ["punet", "i8"] else False - attn_spec = get_mfma_spec_path( - target_triple, - os.path.dirname(safe_name), - use_punet=use_punet, - masked_attention=masked_attention, - ) - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) +# if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + use_punet=use_punet, + masked_attention=masked_attention, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path( diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 40ce6c2e5..d547cadf7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -177,7 +177,7 @@ def export_prompt_encoder( safe_name = utils.create_safe_name( hf_model_name, - f"_bs{batch_size}_{str(max_length)}-{precision}-prompt-encoder-{device}", + f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder-{device}", ) if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, safe_name) @@ -275,7 +275,6 @@ def encode_prompts_turbo( } module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() module_str = str(module) - if compile_to != "vmfb": return module_str else: diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py index 4292c7390..bcd8ba91a 100644 --- 
a/models/turbine_models/tests/conftest.py +++ b/models/turbine_models/tests/conftest.py @@ -39,6 +39,9 @@ def pytest_addoption(parser): parser.addoption("--decomp_attn", action="store", default=False) parser.addoption("--vae_decomp_attn", action="store", default=False) parser.addoption("--attn_spec", action="store", default="") + parser.addoption("--clip_spec", action="store", default="") + parser.addoption("--unet_spec", action="store", default="") + parser.addoption("--vae_spec", action="store", default="") # Compiler Options parser.addoption("--device", action="store", default="cpu") parser.addoption("--rt_device", action="store", default="local-task") diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 96b90b55d..7cec4a661 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -24,6 +24,7 @@ import os import numpy as np import time +import gc torch.random.manual_seed(0) @@ -61,7 +62,11 @@ def command_line_args(request): arguments["compile_to"] = request.config.getoption("--compile_to") arguments["external_weights"] = request.config.getoption("--external_weights") arguments["decomp_attn"] = request.config.getoption("--decomp_attn") - arguments["attn_spec"] = request.config.getoption("--attn_spec") + arguments["attn_spec"] = request.config.getoption("--attn_spec") if request.config.getoption("attn_spec") else { + "text_encoder": request.config.getoption("clip_spec"), + "unet": request.config.getoption("unet_spec"), + "vae": request.config.getoption("vae_spec"), + } arguments["device"] = request.config.getoption("--device") arguments["rt_device"] = request.config.getoption("--rt_device") arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") @@ -111,9 +116,9 @@ def setUp(self): self.pipe.prepare_all() def test01_PromptEncoder(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( - "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." + "Compilation error on vulkan; To be tested on cuda." ) arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] @@ -235,7 +240,6 @@ def test02_ExportUnetModel(self): ) rtol = 4e-2 atol = 4e-1 - np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test03_ExportVaeModelDecode(self): @@ -279,7 +283,6 @@ def test03_ExportVaeModelDecode(self): ) rtol = 4e-2 atol = 4e-1 - np.testing.assert_allclose(torch_output, turbine, rtol, atol) @pytest.mark.xfail(reason="NaN output on rocm, needs triage and file") @@ -345,13 +348,13 @@ def test05_t2i_generate_images(self): ) assert output is not None - @pytest.mark.xfail(reason="compilation issue on gfx90a") def test06_t2i_generate_images_punet(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( - "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." 
+ "Have issues with submodels on vulkan, cuda" ) - self.pipe.unload_submodel("unet") + if getattr(self.pipe, "unet"): + self.pipe.unload_submodel("unet") self.pipe.use_punet = True self.pipe.use_i8_punet = True self.pipe.setup_punet() @@ -369,6 +372,11 @@ def test06_t2i_generate_images_punet(self): True, # return_img ) assert output is not None + + def tearDown(self): + del self.pipe + gc.collect() + if __name__ == "__main__": From 18bffdbc61f3379d24b01306033a3d46545259bc Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Sat, 17 Aug 2024 11:37:19 -0500 Subject: [PATCH 56/89] Fix vae script CLI and revert precision changes to sd3 text encoders export --- .../custom_models/pipeline_base.py | 1 + .../sd3_inference/sd3_text_encoders.py | 17 ++++++----------- .../custom_models/sd_inference/vae.py | 10 ---------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index c5f550e5d..eaf90a560 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -758,6 +758,7 @@ def load_map(self): if not self.map[submodel]["load"]: self.printer.print(f"Skipping load for {submodel}") continue + breakpoint() self.load_submodel(submodel) def load_submodel(self, submodel): diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index 3edf6b402..8acf4fe3f 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -13,6 +13,7 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.ops.iree import trace_tensor from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch @@ -54,10 +55,9 @@ class TextEncoderModule(torch.nn.Module): @torch.no_grad() def __init__( self, - precision, ): super().__init__() - self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.dtype = torch.float16 self.clip_l = SDClipModel( layer="hidden", layer_idx=-2, @@ -66,25 +66,21 @@ def __init__( layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG, - ) - if precision == "fp16": - self.clip_l = self.clip_l.half() + ).half() clip_l_weights = hf_hub_download( repo_id="stabilityai/stable-diffusion-3-medium", filename="text_encoders/clip_l.safetensors", ) with safe_open(clip_l_weights, framework="pt", device="cpu") as f: load_into(f, self.clip_l.transformer, "", "cpu", self.dtype) - self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype) - if precision == "fp16": - self.clip_l = self.clip_g.half() + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half() clip_g_weights = hf_hub_download( repo_id="stabilityai/stable-diffusion-3-medium", filename="text_encoders/clip_g.safetensors", ) with safe_open(clip_g_weights, framework="pt", device="cpu") as f: load_into(f, self.clip_g.transformer, "", "cpu", self.dtype) - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float16) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half() t5_weights = hf_hub_download( repo_id="stabilityai/stable-diffusion-3-medium", filename="text_encoders/t5xxl_fp16.safetensors", @@ -155,8 +151,7 @@ def export_text_encoders( 
attn_spec=attn_spec, ) return vmfb_path - model = TextEncoderModule(precision) - mapper = {} + model = TextEncoderModule() assert ( ".safetensors" not in external_weight_path diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 14422ae15..add39b353 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -265,16 +265,7 @@ class CompiledVae(CompiledModule): if __name__ == "__main__": from turbine_models.custom_models.sd_inference.sd_cmd_opts import args - - if args.input_mlir: - vae_model = None - else: - vae_model = VaeModel( - args.hf_model_name, - custom_vae=None, - ) mod_str = export_vae_model( - vae_model, args.hf_model_name, args.batch_size, height=args.height, @@ -286,7 +277,6 @@ class CompiledVae(CompiledModule): device=args.device, target=args.iree_target_triple, ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, - variant=args.vae_variant, decomp_attn=args.decomp_attn, attn_spec=args.attn_spec, input_mlir=args.input_mlir, From 674128e0fc31f27147548da28d3d86f6ecd80521 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 26 Aug 2024 19:34:38 -0500 Subject: [PATCH 57/89] Small fixes to compile modes and requirements --- models/requirements.txt | 2 +- models/turbine_models/custom_models/pipeline_base.py | 1 - .../custom_models/sd_inference/utils.py | 11 +++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/models/requirements.txt b/models/requirements.txt index ec153a02a..dcccbcdac 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -4,7 +4,7 @@ transformers==4.43.3 torchsde accelerate peft -safetensors==0.4.0 +safetensors>=0.4.0 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b # turbine tank downloading/uploading diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index eaf90a560..c5f550e5d 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -758,7 +758,6 @@ def load_map(self): if not self.map[submodel]["load"]: self.printer.print(f"Skipping load for {submodel}") continue - breakpoint() self.load_submodel(submodel) def load_submodel(self, submodel): diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 07e609989..72c8f93f2 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -71,19 +71,22 @@ "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", - "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "masked_attention": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, 
util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "punet": [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "unet": [""], "clip": [""], From 4d6198bd8b82c01d37a759dc3acd0aaa5dccd64e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 28 Aug 2024 11:57:17 -0500 Subject: [PATCH 58/89] Adds explicit model arch flag, remove commented code --- .../custom_models/pipeline_base.py | 9 --------- .../custom_models/sd_inference/sd_cmd_opts.py | 6 ++++++ .../custom_models/sd_inference/sd_pipeline.py | 15 ++++++++++++--- .../custom_models/sd_inference/vae.py | 19 ------------------- 4 files changed, 18 insertions(+), 31 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index c5f550e5d..837b95e64 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -68,13 +68,6 @@ def merge_export_arg(model_map, arg, arg_name): return model_map -# def str_to_list(string): -# out = string.strip("[]").replace(" ", "").split(";") -# for item in out: -# item = ast.literal_eval(item) -# return out - - class PipelineComponent: """ Wraps a VMFB runner with attributes for embedded metadata, device info, utilities and @@ -266,8 +259,6 @@ def __call__(self, function_name, inputs: list): output = self._output_cast(output) return output - # def _run_and_validate(self, iree_fn, torch_fn, inputs: list) - class Printer: def __init__(self, verbose, start_time, print_time): diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 87e01467a..cfc46ed88 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -43,6 +43,12 @@ def is_valid_file(arg): help="HF model name", default="stabilityai/stable-diffusion-2-1", ) +p.add_argument( + "--model_arch", + type=str, + help="SD pipeline/model architecture. Choices are [sd, sdxl, sd3]." 
+ default=None, +) p.add_argument( "--scheduler_id", type=str, diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index e37c095c4..ca3db41bc 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -185,9 +185,16 @@ }, } +arch_mappings = { + "sd": sd1_sd2_model_map, + "sdxl": sdxl_model_map, + "sd3": sd3_model_map, +} -def get_sd_model_map(hf_model_name): - if isinstance(hf_model_name, dict): +def get_sd_model_map(hf_model_name, model_arch=None): + if model_arch: + return arch_mappings[model_arch] + elif isinstance(hf_model_name, dict): name = hf_model_name["text_encoder"] else: name = hf_model_name @@ -245,6 +252,7 @@ def __init__( add_tk_kernels: bool = False, tk_kernels_dir: str | dict[str] = None, save_outputs: bool | dict[bool] = False, + model_arch: str = None, ): common_export_args = { "hf_model_name": None, @@ -260,7 +268,7 @@ def __init__( "external_weights": None, "external_weight_path": None, } - sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name)) + sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name, model_arch)) for submodel in sd_model_map: if "load" not in sd_model_map[submodel]: sd_model_map[submodel]["load"] = True @@ -936,6 +944,7 @@ def numpy_to_pil_image(images): benchmark, args.verbose, save_outputs=save_outputs, + model_arch=args.model_arch, ) sd_pipe.prepare_all() sd_pipe.load_map() diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index add39b353..485c2c221 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -183,12 +183,6 @@ def export_vae_model( input_image_shape = (batch_size, 3, height, width) input_latents_shape = (batch_size, num_channels, height // 8, width // 8) - # encode_args = [ - # torch.empty( - # input_image_shape, - # dtype=dtype, - # ) - # ] decode_args = [ torch.empty( input_latents_shape, @@ -209,12 +203,6 @@ def export_vae_model( fxb = FxProgramsBuilder(vae_model) # TODO: fix issues with exporting the encode function. 
- # @fxb.export_program(args=(encode_args,)) - # def _encode( - # module, - # inputs, - # ): - # return module.encode(*inputs) @fxb.export_program(args=(decode_args,)) def _decode(module, inputs): @@ -238,13 +226,6 @@ class CompiledVae(CompiledModule): "output_shapes": [(3, width, height) * batch_size], "output_dtypes": ["float32"], } - # model_metadata_encode = { - # "model_name": "vae_encode", - # "input_shapes": [input_image_shape], - # "input_dtypes": [np_dtype], - # "output_shapes": [input_latents_shape], - # "output_dtypes": [np_dtype], - # } module = AddMetadataPass(module, model_metadata_decode, "decode").run() # module = AddMetadataPass(module, model_metadata_decode, "encode").run() From f3e3fe3d5af9c58a97f5afb70a35779d62d000a6 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 28 Aug 2024 11:58:12 -0500 Subject: [PATCH 59/89] Fix formatting --- .../custom_models/sd_inference/sd_pipeline.py | 5 +++- .../custom_models/sd_inference/utils.py | 2 +- .../custom_models/sd_inference/vae.py | 1 + models/turbine_models/tests/sdxl_test.py | 25 +++++++++---------- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index ca3db41bc..4cd1e4b86 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -191,6 +191,7 @@ "sd3": sd3_model_map, } + def get_sd_model_map(hf_model_name, model_arch=None): if model_arch: return arch_mappings[model_arch] @@ -399,7 +400,9 @@ def setup_punet(self): self.map["unet"]["mlir"] = None self.map["unet"]["vmfb"] = None self.map["unet"]["weights"] = None - self.map["unet"]["keywords"] = [i for i in self.map["unet"]["keywords"] if i != "!punet"] + self.map["unet"]["keywords"] = [ + i for i in self.map["unet"]["keywords"] if i != "!punet" + ] self.map["unet"]["keywords"] += "punet" if self.use_i8_punet: if self.add_tk_kernels: diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 72c8f93f2..ec70141d0 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -353,7 +353,7 @@ def compile_to_vmfb( # the TD spec is implemented in C++. 
if attn_spec in ["default", "mfma", "punet"]: -# if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + # if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: use_punet = True if attn_spec in ["punet", "i8"] else False attn_spec = get_mfma_spec_path( target_triple, diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 485c2c221..9d5c0c6f4 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -246,6 +246,7 @@ class CompiledVae(CompiledModule): if __name__ == "__main__": from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + mod_str = export_vae_model( args.hf_model_name, args.batch_size, diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 7cec4a661..cf0a05f5e 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -62,11 +62,15 @@ def command_line_args(request): arguments["compile_to"] = request.config.getoption("--compile_to") arguments["external_weights"] = request.config.getoption("--external_weights") arguments["decomp_attn"] = request.config.getoption("--decomp_attn") - arguments["attn_spec"] = request.config.getoption("--attn_spec") if request.config.getoption("attn_spec") else { - "text_encoder": request.config.getoption("clip_spec"), - "unet": request.config.getoption("unet_spec"), - "vae": request.config.getoption("vae_spec"), - } + arguments["attn_spec"] = ( + request.config.getoption("--attn_spec") + if request.config.getoption("attn_spec") + else { + "text_encoder": request.config.getoption("clip_spec"), + "unet": request.config.getoption("unet_spec"), + "vae": request.config.getoption("vae_spec"), + } + ) arguments["device"] = request.config.getoption("--device") arguments["rt_device"] = request.config.getoption("--rt_device") arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") @@ -117,9 +121,7 @@ def setUp(self): def test01_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest( - "Compilation error on vulkan; To be tested on cuda." 
- ) + self.skipTest("Compilation error on vulkan; To be tested on cuda.") arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] tokenizer_1 = CLIPTokenizer.from_pretrained( @@ -350,9 +352,7 @@ def test05_t2i_generate_images(self): def test06_t2i_generate_images_punet(self): if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest( - "Have issues with submodels on vulkan, cuda" - ) + self.skipTest("Have issues with submodels on vulkan, cuda") if getattr(self.pipe, "unet"): self.pipe.unload_submodel("unet") self.pipe.use_punet = True @@ -372,13 +372,12 @@ def test06_t2i_generate_images_punet(self): True, # return_img ) assert output is not None - + def tearDown(self): del self.pipe gc.collect() - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From 7adfc7a2a0d4ee72cde240f19c0ae4455370b6f5 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 28 Aug 2024 11:59:58 -0500 Subject: [PATCH 60/89] Fix formatting --- models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index cfc46ed88..1571fd2e5 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -46,7 +46,7 @@ def is_valid_file(arg): p.add_argument( "--model_arch", type=str, - help="SD pipeline/model architecture. Choices are [sd, sdxl, sd3]." + help="SD pipeline/model architecture. Choices are [sd, sdxl, sd3].", default=None, ) p.add_argument( From ff2c3c9911ba87e9980681eaf426a668638c23ef Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:13:27 -0500 Subject: [PATCH 61/89] Update test_models.yml --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 0ee55d187..88218690e 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -57,7 +57,7 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. 
- pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt + pip install --pre torch==2.5.0.dev20240801+cpu torchvision==0.20.0.dev20240801+cpu --index-url https://download.pytorch.org/whl/nightly/cpu pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html From afdb8d6910b3520bea3b5ee8c347a260c1ed1b9d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 10 Sep 2024 14:42:52 -0500 Subject: [PATCH 62/89] Decompose CLIP attention --- models/turbine_models/custom_models/sd_inference/clip.py | 2 +- models/turbine_models/tests/sd_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py index 11705a916..5182e6e0f 100644 --- a/models/turbine_models/custom_models/sd_inference/clip.py +++ b/models/turbine_models/custom_models/sd_inference/clip.py @@ -34,7 +34,7 @@ def export_clip_model( attn_spec: str = None, weights_only: bool = False, upload_ir: bool = False, - decomp_attn: bool = False, + decomp_attn: bool = True, ): input_len = max_length safe_name = utils.create_safe_name( diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 98a3cfca2..9b67d2ecf 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -78,6 +78,7 @@ def testExportClipModel(self): target=current_args["iree_target_triple"], exit_on_vmfb=False, upload_ir=UPLOAD_IR, + decomp_attn=True, ) current_args["external_weight_path"] = safe_prefix + ".safetensors" current_args["vmfb_path"] = blob_name From a4e67e87d919ee72b7847289c89c8226dca8bcf6 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 10 Sep 2024 19:07:27 -0500 Subject: [PATCH 63/89] decompose implementation for clip --- .../custom_models/sd_inference/clip.py | 176 ++++++++++-------- 1 file changed, 95 insertions(+), 81 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py index 5182e6e0f..d5927f322 100644 --- a/models/turbine_models/custom_models/sd_inference/clip.py +++ b/models/turbine_models/custom_models/sd_inference/clip.py @@ -55,95 +55,109 @@ def export_clip_model( attn_spec=attn_spec, ) return vmfb_path - if "google/t5" in hf_model_name: - from transformers import T5Tokenizer, T5Model - tokenizer = T5Tokenizer.from_pretrained(hf_model_name) - text_encoder_model = T5Model.from_pretrained(hf_model_name) - input_len = 512 + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + if "google/t5" in hf_model_name: + from transformers import T5Tokenizer, T5Model + + tokenizer = T5Tokenizer.from_pretrained(hf_model_name) + text_encoder_model = T5Model.from_pretrained(hf_model_name) + input_len = 512 - else: - # TODO: Add better filtering mechanism for things that require CLIPProcessor - if "openai" in hf_model_name: - tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") - hf_subfolder = "" # CLIPProcessor does not have a 
subfolder - input_len = 10 else: - # Load the tokenizer and text encoder to tokenize and encode the text. - tokenizer = CLIPTokenizer.from_pretrained( - hf_model_name, - subfolder="tokenizer", - ) - hf_subfolder = "text_encoder" - - text_encoder_model = CLIPTextModel.from_pretrained( - hf_model_name, - subfolder=hf_subfolder, - ) - if precision == "fp16": - text_encoder_model = text_encoder_model.half() - mapper = {} - utils.save_external_weights( - mapper, text_encoder_model, external_weights, external_weight_path - ) - if weights_only: - return external_weight_path - - if "google/t5" in hf_model_name: - input_shapes = [(batch_size, input_len), (batch_size, input_len)] - - class CompiledTextEncoder(CompiledModule): - if external_weights: - params = export_parameters( - text_encoder_model, - external=True, - external_scope="", - name_mapper=mapper.get, + # TODO: Add better filtering mechanism for things that require CLIPProcessor + if "openai" in hf_model_name: + tokenizer = CLIPProcessor.from_pretrained( + "openai/clip-vit-large-patch14" ) + hf_subfolder = "" # CLIPProcessor does not have a subfolder + input_len = 10 else: - params = export_parameters(text_encoder_model) - - def encode_tokens( - self, - inp=AbstractTensor(1, input_len, dtype=torch.int64), - decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64), - ): - return jittable(text_encoder_model.forward)( - input_ids=inp, decoder_input_ids=decoder_input_ids + # Load the tokenizer and text encoder to tokenize and encode the text. + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", ) + hf_subfolder = "text_encoder" - else: - input_shapes = [str((batch_size, input_len)), str((batch_size, input_len))] - - class CompiledTextEncoder(CompiledModule): - if external_weights: - params = export_parameters( - text_encoder_model, - external=True, - external_scope="", - name_mapper=mapper.get, - ) - else: - params = export_parameters(text_encoder_model) - - def encode_tokens_attn_mask( - self, - inp=AbstractTensor(1, input_len, dtype=torch.int64), - attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), - ): - return jittable(text_encoder_model.forward)( - input_ids=inp, attention_mask=attn_mask - ) - - def encode_tokens( - self, - inp=AbstractTensor(1, input_len, dtype=torch.int64), - ): - return jittable(text_encoder_model.forward)(input_ids=inp) + text_encoder_model = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder=hf_subfolder, + ) + if precision == "fp16": + text_encoder_model = text_encoder_model.half() + mapper = {} + utils.save_external_weights( + mapper, text_encoder_model, external_weights, external_weight_path + ) + if weights_only: + return external_weight_path + + if "google/t5" in hf_model_name: + input_shapes = [(batch_size, input_len), (batch_size, input_len)] + + class CompiledTextEncoder(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def encode_tokens( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64), + ): + return jittable(text_encoder_model.forward)( + input_ids=inp, decoder_input_ids=decoder_input_ids + ) - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledTextEncoder(context=Context(), import_to=import_to) - module = CompiledModule.get_mlir_module(inst) + else: + input_shapes = 
[str((batch_size, input_len)), str((batch_size, input_len))] + + class CompiledTextEncoder(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def encode_tokens_attn_mask( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), + ): + return jittable(text_encoder_model.forward)( + input_ids=inp, attention_mask=attn_mask + ) + + def encode_tokens( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + ): + return jittable(text_encoder_model.forward)(input_ids=inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledTextEncoder(context=Context(), import_to=import_to) + module = CompiledModule.get_mlir_module(inst) model_metadata_attn_mask = { "model_name": hf_model_name + "_text_encoder", From 35517d956dc4087a3dda9531a06ad56be4431998 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 10 Sep 2024 21:02:43 -0500 Subject: [PATCH 64/89] Add decompose clip flag to pipe e2e test --- models/turbine_models/tests/sd_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 9b67d2ecf..48372927c 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -233,7 +233,7 @@ def testSDPipeline(self): current_args = copy.deepcopy(default_arguments) decomp_attn = { - "text_encoder": False, + "text_encoder": True, "unet": False, "vae": current_args["vae_decomp_attn"], } From 6ca109a9c96e5a038decf3fb7ebfe994572d5e78 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 10 Sep 2024 23:15:14 -0500 Subject: [PATCH 65/89] Add attention decomposition mechanism to sdxl clip exports. 
--- .../custom_models/sdxl_inference/clip.py | 50 ++++++++----- .../sdxl_inference/sdxl_prompt_encoder.py | 74 +++++++++++-------- 2 files changed, 74 insertions(+), 50 deletions(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/clip.py b/models/turbine_models/custom_models/sdxl_inference/clip.py index 2740745ed..269e87d57 100644 --- a/models/turbine_models/custom_models/sdxl_inference/clip.py +++ b/models/turbine_models/custom_models/sdxl_inference/clip.py @@ -62,6 +62,7 @@ def export_clip_model( input_mlir=None, attn_spec=None, weights_only=False, + decomp_attn=True, ): if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, "clip_" + str(index)) @@ -118,25 +119,36 @@ def export_clip_model( if weights_only: return weights_path - - class CompiledClip(CompiledModule): - if external_weights: - params = export_parameters( - text_encoder_model, - external=True, - external_scope="", - name_mapper=mapper.get, - ) - else: - params = export_parameters(text_encoder_model) - - def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)): - return jittable(text_encoder_model.forward)(inp) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledClip(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)): + return jittable(text_encoder_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) if compile_to != "vmfb": return module_str, tokenizer diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index d547cadf7..3b9fb102f 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -171,7 +171,7 @@ def export_prompt_encoder( attn_spec=None, weights_only=False, batch_input=False, - decomp_attn=False, # Compatibility + decomp_attn=True, ): do_classifier_free_guidance = True @@ -233,39 +233,51 @@ def export_prompt_encoder( if weights_only: return None, external_weight_path - class CompiledClip(CompiledModule): - if external_weights: - params = export_parameters( - prompt_encoder_module, - external=True, - external_scope="", - name_mapper=mapper.get, - ) - else: - params = export_parameters(prompt_encoder_module) - - def encode_prompts( - self, - t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - ): - return jittable(prompt_encoder_module.forward)( - 
t_ids_1, t_ids_2, uc_ids_1, uc_ids_2 - ) + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + prompt_encoder_module, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(prompt_encoder_module) + + def encode_prompts( + self, + t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + ): + return jittable(prompt_encoder_module.forward)( + t_ids_1, t_ids_2, uc_ids_1, uc_ids_2 + ) - def encode_prompts_turbo( - self, - t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - ): - return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2) + def encode_prompts_turbo( + self, + t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + ): + return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2) - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledClip(context=Context(), import_to=import_to) + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) - module = CompiledModule.get_mlir_module(inst) + module = CompiledModule.get_mlir_module(inst) model_metadata_encode = { "model_name": hf_model_name + "_text_encoder", From 453fb38f02282ddd4d1c561b0d0428d084a60095 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 11 Sep 2024 13:57:35 -0500 Subject: [PATCH 66/89] Update compile options for sdxl --- .../custom_models/sd_inference/utils.py | 32 +++++++++---------- models/turbine_models/tests/sdxl_test.py | 4 ++- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index ec70141d0..d0235d1f9 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -25,16 +25,16 @@ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "vae_preprocess": [ - 
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", ], "unet": [ - "--iree-flow-enable-aggressive-fusion", + "--iree-dispatch-creation-enable-aggressive-fusion", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-opt-outer-dim-concat=true", @@ -43,16 +43,16 @@ "--iree-vm-target-truncate-unsupported-floats", ], "clip": [ - "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-dispatch-creation-enable-aggressive-fusion", + "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-opt-outer-dim-concat=true", "--iree-rocm-waves-per-eu=2", "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "vae": [ - "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions", + "--iree-dispatch-creation-enable-aggressive-fusion", + "--iree-dispatch-creation-enable-fuse-horizontal-contractions", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-opt-data-tiling=false", @@ -70,7 +70,7 @@ "--iree-opt-data-tiling=false", "--iree-opt-const-eval=false", "--iree-opt-aggressively-propagate-transposes=true", - "--iree-flow-enable-aggressive-fusion", + "--iree-dispatch-creation-enable-aggressive-fusion", "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", ], @@ -80,12 +80,12 @@ "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" - "--iree-flow-enable-fuse-horizontal-contractions=true", + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, 
iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "unet": [""], @@ -98,9 +98,9 @@ "--iree-llvmcpu-target-cpu=znver4", "--iree-opt-const-eval=false", "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack", - "--iree-flow-collapse-reduction-dims", + "--iree-dispatch-creation-collapse-reduction-dims", "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000", - "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + "--iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops", ], "bf16": [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))", @@ -288,7 +288,7 @@ def compile_to_vmfb( "--iree-stream-resource-max-allocation-size=" + max_alloc, "--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64", - "--iree-flow-inline-constants-max-byte-length=1", + "--iree-dispatch-creation-inline-constants-max-byte-length=1", ] ) device = "vulkan-spirv" @@ -296,7 +296,7 @@ def compile_to_vmfb( flags.extend( [ "--iree-hal-target-backends=rocm", - "--iree-rocm-target-chip=" + target_triple, + "--iree-hip-target=" + target_triple, "--iree-vm-bytecode-module-output-format=flatbuffer-binary", ] ) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index cf0a05f5e..faef67417 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,7 +93,7 @@ def setUp(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": True, + "vae": False, } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -358,6 +358,8 @@ def test06_t2i_generate_images_punet(self): self.pipe.use_punet = True self.pipe.use_i8_punet = True self.pipe.setup_punet() + if arguments["iree_target_triple"] != "gfx942": + self.pipe.map["unet"]["export_args"]["attn_spec"] = None self.pipe.prepare_all() self.pipe.load_map() output = self.pipe.generate_images( From c0be575fc6e4e5ef6ee6c6103c5de0f21caf85c6 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 12 Sep 2024 11:07:54 -0500 Subject: [PATCH 67/89] Decompose VAE for cpu --- models/turbine_models/tests/sdxl_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index faef67417..b69b0c187 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,7 +93,11 @@ def setUp(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": False, + "vae": ( + False + if any(x in arguments["device"] for x in ["hip", "rocm"]) + else True + ), } self.pipe = SharkSDPipeline( arguments["hf_model_name"], From e3cd69d8385794149c10688c36ebabdc5953aa16 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 12 Sep 2024 12:25:24 -0500 Subject: [PATCH 68/89] skip i8 punet test on cpu --- models/turbine_models/tests/sdxl_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index b69b0c187..ef6dd892d 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -355,8 +355,8 @@ def test05_t2i_generate_images(self): assert output is not None def test06_t2i_generate_images_punet(self): - if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest("Have issues with submodels on vulkan, cuda") + if arguments["device"] in 
["vulkan", "cuda", "cpu"]: + self.skipTest("Have issues with submodels on vulkan, cuda, cpu") if getattr(self.pipe, "unet"): self.pipe.unload_submodel("unet") self.pipe.use_punet = True From e3e1dcbaab282c09ea307afec3e8318a4e0b7b4c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 12 Sep 2024 17:58:11 -0500 Subject: [PATCH 69/89] Don't use spec for clip by default --- models/turbine_models/tests/sdxl_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index ef6dd892d..62b934c1a 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -99,6 +99,11 @@ def setUp(self): else True ), } + attn_spec = { + "text_encoder": None, + "unet": arguments["attn_spec"], + "vae": arguments["attn_spec"], + } self.pipe = SharkSDPipeline( arguments["hf_model_name"], arguments["height"], @@ -109,7 +114,7 @@ def setUp(self): arguments["device"], arguments["iree_target_triple"], ireec_flags=None, - attn_spec=arguments["attn_spec"], + attn_spec=attn_spec, decomp_attn=decomp_attn, pipeline_dir="test_vmfbs", external_weights_dir="test_weights", From 56d6ee772f53adfccf243a1c3f0721b00c7820f6 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 12 Sep 2024 21:01:51 -0500 Subject: [PATCH 70/89] Revert change to attention spec handling in sdxl test --- models/turbine_models/tests/sdxl_test.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 62b934c1a..ef6dd892d 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -99,11 +99,6 @@ def setUp(self): else True ), } - attn_spec = { - "text_encoder": None, - "unet": arguments["attn_spec"], - "vae": arguments["attn_spec"], - } self.pipe = SharkSDPipeline( arguments["hf_model_name"], arguments["height"], @@ -114,7 +109,7 @@ def setUp(self): arguments["device"], arguments["iree_target_triple"], ireec_flags=None, - attn_spec=attn_spec, + attn_spec=arguments["attn_spec"], decomp_attn=decomp_attn, pipeline_dir="test_vmfbs", external_weights_dir="test_weights", From d330564d9a54fa9e5621816a8921bce715cea5d8 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 12 Sep 2024 21:03:29 -0500 Subject: [PATCH 71/89] Don't use td spec for clip bs2 export test --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 88218690e..dc8176b55 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -79,6 +79,6 @@ jobs: pytest -v models/turbine_models/tests/sd_test.py pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --clip_spec None --unet_spec default --vae_spec default --batch_size 2 pytest -v 
models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 From ffba3ea295139809e4405811295843c0ddd75a14 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Fri, 13 Sep 2024 10:56:55 -0500 Subject: [PATCH 72/89] disable attn spec usage for sdxl bs2 on mi250 tests --- .github/workflows/test_models.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index dc8176b55..c7e7c2100 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -77,8 +77,8 @@ jobs: source turbine_venv/bin/activate pytest -v models/turbine_models/tests/sd_test.py - pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --clip_spec None --unet_spec default --vae_spec default --batch_size 2 - pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 + pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default -x + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec None --batch_size 2 -x + pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x From fad7e6ef23231795520d8674abc066b826b69674 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:37:13 -0500 Subject: [PATCH 73/89] Update test_models.yml --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index c7e7c2100..550f0f8b9 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -79,6 +79,6 @@ jobs: pytest -v models/turbine_models/tests/sd_test.py pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default -x - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec None --batch_size 2 -x + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --batch_size 2 -x pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x From 05fa32d74f533122cfd66c7ee9ae3ed4f06a75a2 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Mon, 16 Sep 2024 
10:28:55 -0500 Subject: [PATCH 74/89] Update test_models.yml --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 550f0f8b9..086e49bce 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -79,6 +79,6 @@ jobs: pytest -v models/turbine_models/tests/sd_test.py pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default -x - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --batch_size 2 -x + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --clip_spec None --unet_spec None --vae_spec None --batch_size 2 -x pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x From 0291d43ef515e40240734ac58edef73eaad2583e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 24 Sep 2024 11:30:59 -0500 Subject: [PATCH 75/89] Small fixes to SDXL inference pipeline/exports/compile --- models/turbine_models/custom_models/pipeline_base.py | 2 ++ .../turbine_models/custom_models/sd_inference/sd_pipeline.py | 4 +++- models/turbine_models/custom_models/sd_inference/utils.py | 2 +- models/turbine_models/custom_models/sd_inference/vae.py | 1 - models/turbine_models/custom_models/sdxl_inference/unet.py | 4 ---- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 837b95e64..33bc425b9 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -368,6 +368,8 @@ def __init__( target, dict ), "Device and target triple must be both dicts or both strings." for submodel in self.map.keys(): + if self.map[submodel].get("load") == False: + continue assert submodel in device.keys(), f"Device for {submodel} not found." 
assert ( submodel in target.keys() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 4cd1e4b86..61bb66d82 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -120,6 +120,8 @@ "decomp_attn": None, }, }, +} +sdxl_compiled_pipeline_map = { "unetloop": { "module_name": "sdxl_compiled_pipeline", "load": False, @@ -434,7 +436,7 @@ def load_scheduler( if self.is_sd3: export_fn = sd3_schedulers.export_scheduler_model else: - export_fn = scheduler.export_scheduler_model + export_fn = schedulers.export_scheduler_model self.map["scheduler"] = { "module_name": "compiled_scheduler", "export_fn": export_fn, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index d0235d1f9..aee316cdf 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -476,7 +476,7 @@ def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet= url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" elif not masked_attention: suffix = "" - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir" + url = "https://raw.githubusercontent.com/iree-org/iree/refs/heads/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir" else: suffix = "_pad" url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 9d5c0c6f4..732ccc5f0 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -171,7 +171,6 @@ def export_vae_model( vae_model, external_weights, external_weight_path, - vae_harness=vae_harness, ) if weights_only: return external_weight_path diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 73b32cf58..2d96f2e6c 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -205,10 +205,6 @@ def export_unet_model( if not attn_spec: if (not decomp_attn) and use_punet: attn_spec = "punet" - elif (not decomp_attn) and "gfx9" in target: - attn_spec = "mfma" - elif (not decomp_attn) and "gfx11" in target: - attn_spec = "wmma" safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", From e337f2a2f39ba829a034ef98f55703ae2a888618 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 2 Oct 2024 10:05:42 -0500 Subject: [PATCH 76/89] Pin torch to 2.4.1 --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 086e49bce..67c7054ac 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -57,7 +57,7 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. 
- pip install --pre torch==2.5.0.dev20240801+cpu torchvision==0.20.0.dev20240801+cpu --index-url https://download.pytorch.org/whl/nightly/cpu + pip install --pre torch==2.4.1+cpu torchvision --index-url https://download.pytorch.org/whl/cpu pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html From 0fd8ad04dd8ebe390b6285ba19a573621946580b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 3 Oct 2024 11:16:25 -0500 Subject: [PATCH 77/89] Largely disables attn spec usage. --- .../turbine_models/custom_models/sd_inference/utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index aee316cdf..9b926a586 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -361,7 +361,8 @@ def compile_to_vmfb( use_punet=use_punet, masked_attention=masked_attention, ) - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if attn_spec: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path( @@ -474,12 +475,8 @@ def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet= if use_punet: suffix = "_punet" url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" - elif not masked_attention: - suffix = "" - url = "https://raw.githubusercontent.com/iree-org/iree/refs/heads/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir" else: - suffix = "_pad" - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" + return None attn_spec = urlopen(url).read().decode("utf-8") spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir") with open(spec_path, "w") as f: From e1c4ac287935f5aa21ac94f9876a3fae00f9d6ef Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 3 Oct 2024 13:22:59 -0500 Subject: [PATCH 78/89] Update canonicalization pass name, decouple model validation from pipeline one-shot/exprt --- .../custom_models/sd_inference/utils.py | 9 +- .../custom_models/sd_inference/vae.py | 5 - models/turbine_models/tests/sdxl_test.py | 309 ++++++++++++------ 3 files changed, 207 insertions(+), 116 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 9b926a586..2e5ef57dd 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -25,10 +25,7 @@ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, 
util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" - ], - "vae_preprocess": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", @@ -80,7 +77,7 @@ "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], "punet": [ - "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-dispatch-creation-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", ], @@ -333,8 +330,6 @@ def compile_to_vmfb( flags.extend(MI_flags["masked_attention"]) elif "punet" in flagset_keywords: flags.extend(MI_flags["punet"]) - elif "vae" in safe_name: - flags.extend(MI_flags["vae_preprocess"]) else: flags.extend(MI_flags["preprocess_default"]) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 732ccc5f0..12a29ace1 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -129,11 +129,6 @@ def export_vae_model( ) if decomp_attn: safe_name += "_decomp_attn" - elif not attn_spec: - if "gfx9" in target: - attn_spec = "mfma" - elif "gfx11" in target: - attn_spec = "wmma" if pipeline_dir: safe_name = os.path.join(pipeline_dir, safe_name) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index ef6dd892d..192b1ce83 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -84,7 +84,7 @@ def command_line_args(request): @pytest.mark.usefixtures("command_line_args") class StableDiffusionXLTest(unittest.TestCase): - def setUp(self): + def test00_compile_pipe(self): from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) @@ -122,12 +122,70 @@ def setUp(self): vae_harness=False, ) self.pipe.prepare_all() + self.pipe.load_map() + output = self.pipe.generate_images( + arguments["prompt"], + arguments["negative_prompt"], + arguments["num_inference_steps"], + 1, # batch count + arguments["guidance_scale"], + arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img + ) + assert output is not None + + # Switch to punet. 
+ self.pipe.unload_submodel("unet") + self.pipe.use_punet = True + self.pipe.use_i8_punet = True + self.pipe.setup_punet() + if arguments["iree_target_triple"] != "gfx942": + self.pipe.map["unet"]["export_args"]["attn_spec"] = None + self.pipe.prepare_all() + self.pipe.load_map() + output = self.pipe.generate_images( + arguments["prompt"], + arguments["negative_prompt"], + arguments["num_inference_steps"], + 1, # batch count + arguments["guidance_scale"], + arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img + ) + assert output is not None def test01_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] - arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] + clip_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["max_length"]), + arguments["precision"], + "text_encoder", + arguments["device"], + arguments["iree_target_triple"], + ) + + ".vmfb" + ) + arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename) + clip_w_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "text_encoder", + arguments["precision"], + ) + + ".safetensors" + ) + arguments["external_weight_path"] = os.path.join( + "test_weights", + clip_w_filename, + ) tokenizer_1 = CLIPTokenizer.from_pretrained( arguments["hf_model_name"], subfolder="tokenizer", @@ -153,7 +211,7 @@ def test01_PromptEncoder(self): turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( arguments["vmfb_path"], - self.pipe.map["text_encoder"]["driver"], + arguments["rt_driver"], arguments["external_weight_path"], text_input_ids_list, uncond_input_ids_list, @@ -171,7 +229,7 @@ def test01_PromptEncoder(self): "prompt_encoder", arguments["vmfb_path"], arguments["external_weight_path"], - self.pipe.map["text_encoder"]["driver"], + arguments["rt_driver"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) @@ -180,13 +238,35 @@ def test01_PromptEncoder(self): np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol) np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol) - def test02_ExportUnetModel(self): + def test02_unet(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") - - arguments["vmfb_path"] = self.pipe.map["unet"]["vmfb"] - arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] - + unet_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["max_length"]), + str(arguments["height"]) + "x" + str(arguments["width"]), + arguments["precision"], + "unet", + arguments["device"], + arguments["iree_target_triple"], + ) + + ".vmfb" + ) + arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename) + unet_w_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "unet", + arguments["precision"], + ) + + ".safetensors" + ) + arguments["external_weight_path"] = os.path.join( + "test_weights", + unet_w_filename, + ) dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 sample = torch.rand( ( @@ -207,7 +287,7 @@ def test02_ExportUnetModel(self): guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) turbine = unet_runner.run_unet( - self.pipe.map["unet"]["driver"], 
+ arguments["rt_device"], sample, timestep, prompt_embeds, @@ -235,7 +315,7 @@ def test02_ExportUnetModel(self): "unet", arguments["vmfb_path"], arguments["external_weight_path"], - self.pipe.map["unet"]["driver"], + arguments["rt_device"], max_length=arguments["max_length"], height=arguments["height"], width=arguments["width"], @@ -252,8 +332,31 @@ def test03_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] - arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] + vae_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["height"]) + "x" + str(arguments["width"]), + arguments["precision"], + "vae", + arguments["device"], + arguments["iree_target_triple"], + ) + + ".vmfb" + ) + arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename) + vae_w_filename = ( + "_".join( + create_safe_name(arguments["hf_model_name"], ""), + "vae", + arguments["precision"], + ) + + ".safetensors" + ) + arguments["external_weight_path"] = os.path.join( + "test_weights", + vae_w_filename, + ) example_input = torch.ones( arguments["batch_size"], 4, @@ -265,11 +368,11 @@ def test03_ExportVaeModelDecode(self): if arguments["precision"] == "fp16": example_input = example_input.half() turbine = vae_runner.run_vae_decode( - self.pipe.map["vae"]["driver"], + arguments["rt_device"], example_input, arguments["vmfb_path"], arguments["hf_model_name"], - self.pipe.map["vae"]["weights"], + arguments["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( arguments["hf_model_name"], @@ -281,7 +384,7 @@ def test03_ExportVaeModelDecode(self): "vae_decode", arguments["vmfb_path"], arguments["external_weight_path"], - self.pipe.map["vae"]["driver"], + arguments["rt_device"], height=arguments["height"], width=arguments["width"], precision=arguments["precision"], @@ -291,96 +394,94 @@ def test03_ExportVaeModelDecode(self): atol = 4e-1 np.testing.assert_allclose(torch_output, turbine, rtol, atol) - @pytest.mark.xfail(reason="NaN output on rocm, needs triage and file") - def test04_ExportVaeModelEncode(self): - if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: - self.skipTest( - "Compilation error on cpu, vulkan and rocm; To be tested on cuda." 
- ) - arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] - arguments["external_weight_path"] = self.pipe.map["vae"]["weights"] - example_input = torch.ones( - arguments["batch_size"], - 3, - arguments["height"], - arguments["width"], - dtype=torch.float32, - ) - example_input_torch = example_input - if arguments["precision"] == "fp16": - example_input = example_input.half() - turbine = vae_runner.run_vae_encode( - self.pipe.map["vae"]["driver"], - example_input, - arguments["vmfb_path"], - arguments["hf_model_name"], - self.pipe.map["vae"]["weights"], - ) - torch_output = vae_runner.run_torch_vae( - arguments["hf_model_name"], - "encode", - example_input_torch, - ) - if arguments["benchmark"] or arguments["tracy_profile"]: - run_benchmark( - "vae_encode", - arguments["vmfb_path"], - self.pipe.map["vae"]["weights"], - self.pipe.map["vae"]["driver"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - tracy_profile=arguments["tracy_profile"], - ) - rtol = 4e-2 - atol = 4e-2 - np.testing.assert_allclose(torch_output, turbine, rtol, atol) - - def test05_t2i_generate_images(self): - if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest("Have issues with submodels on vulkan, cuda") - - self.pipe.load_map() - output = self.pipe.generate_images( - arguments["prompt"], - arguments["negative_prompt"], - arguments["num_inference_steps"], - 1, # batch count - arguments["guidance_scale"], - arguments["seed"], - True, - arguments["scheduler_id"], - True, # return_img - ) - assert output is not None + # def test04_punet(self): + # if arguments["device"] in ["vulkan", "cuda"]: + # self.skipTest("Unknown error on vulkan; To be tested on cuda.") + # unet_filename = "_".join( + # create_safe_name(arguments["hf_model_name"], ""), + # "bs" + str(arguments["batch_size"]), + # str(arguments["max_length"]), + # str(arguments["height"]) + "x" + str(arguments["width"]), + # arguments["precision"], + # "unet", + # arguments["device"], + # arguments["iree_target_triple"], + # ) + ".vmfb" + # arguments["vmfb_path"] = os.path.join( + # "test_vmfbs", + # unet_filename + # ) + # unet_w_filename = "_".join( + # create_safe_name(arguments["hf_model_name"], ""), + # "unet", + # arguments["precision"], + # ) + ".safetensors" + # arguments["external_weight_path"] = os.path.join( + # "test_weights", + # unet_w_filename, + # ) + # dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 + # sample = torch.rand( + # ( + # arguments["batch_size"], + # arguments["in_channels"], + # arguments["height"] // 8, + # arguments["width"] // 8, + # ), + # dtype=dtype, + # ) + # timestep = torch.zeros(1, dtype=dtype) + # prompt_embeds = torch.rand( + # (2 * arguments["batch_size"], arguments["max_length"], 2048), + # dtype=dtype, + # ) + # text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype) + # time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype) + # guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) - def test06_t2i_generate_images_punet(self): - if arguments["device"] in ["vulkan", "cuda", "cpu"]: - self.skipTest("Have issues with submodels on vulkan, cuda, cpu") - if getattr(self.pipe, "unet"): - self.pipe.unload_submodel("unet") - self.pipe.use_punet = True - self.pipe.use_i8_punet = True - self.pipe.setup_punet() - if arguments["iree_target_triple"] != "gfx942": - self.pipe.map["unet"]["export_args"]["attn_spec"] = None - self.pipe.prepare_all() - self.pipe.load_map() - output = self.pipe.generate_images( - 
arguments["prompt"], - arguments["negative_prompt"], - arguments["num_inference_steps"], - 1, # batch count - arguments["guidance_scale"], - arguments["seed"], - True, - arguments["scheduler_id"], - True, # return_img - ) - assert output is not None + # turbine = unet_runner.run_punet( + # arguments["rt_device"], + # sample, + # timestep, + # prompt_embeds, + # text_embeds, + # time_ids, + # guidance_scale, + # arguments["vmfb_path"], + # arguments["hf_model_name"], + # arguments["hf_auth_token"], + # arguments["external_weight_path"], + # ) + # torch_output = unet_runner.run_torch_unet( + # arguments["hf_model_name"], + # arguments["hf_auth_token"], + # sample.float(), + # timestep, + # prompt_embeds.float(), + # text_embeds.float(), + # time_ids.float(), + # guidance_scale.float(), + # precision=arguments["precision"], + # ) + # if arguments["benchmark"] or arguments["tracy_profile"]: + # run_benchmark( + # "unet", + # arguments["vmfb_path"], + # arguments["external_weight_path"], + # arguments["rt_device"], + # max_length=arguments["max_length"], + # height=arguments["height"], + # width=arguments["width"], + # batch_size=arguments["batch_size"], + # in_channels=arguments["in_channels"], + # precision=arguments["precision"], + # tracy_profile=arguments["tracy_profile"], + # ) + # rtol = 4e-2 + # atol = 4e-1 + # np.testing.assert_allclose(torch_output, turbine, rtol, atol) def tearDown(self): - del self.pipe gc.collect() From 61bb4ef24eb7afea0e5159857a79c7ca2ec67d54 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 3 Oct 2024 13:53:59 -0500 Subject: [PATCH 79/89] Don't use punet spec. --- models/turbine_models/tests/sdxl_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 192b1ce83..2b6eab9b8 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -141,8 +141,7 @@ def test00_compile_pipe(self): self.pipe.use_punet = True self.pipe.use_i8_punet = True self.pipe.setup_punet() - if arguments["iree_target_triple"] != "gfx942": - self.pipe.map["unet"]["export_args"]["attn_spec"] = None + self.pipe.map["unet"]["export_args"]["attn_spec"] = None self.pipe.prepare_all() self.pipe.load_map() output = self.pipe.generate_images( From dfb9474e0f0b5955b338e22dee7d2c6da1ad7a7f Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 3 Oct 2024 13:57:48 -0500 Subject: [PATCH 80/89] Remove default/mfma/wmma specs from sd compile utils. --- .../custom_models/sd_inference/utils.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 2e5ef57dd..aabbdb65f 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -347,27 +347,8 @@ def compile_to_vmfb( # This is a temporary solution, and should be removed or largely disabled once the functionality of # the TD spec is implemented in C++. 
- if attn_spec in ["default", "mfma", "punet"]: - # if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: - use_punet = True if attn_spec in ["punet", "i8"] else False - attn_spec = get_mfma_spec_path( - target_triple, - os.path.dirname(safe_name), - use_punet=use_punet, - masked_attention=masked_attention, - ) - if attn_spec: - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) - - elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): - attn_spec = get_wmma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention=masked_attention - ) - if attn_spec: - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) - elif attn_spec and attn_spec != "None": - if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if os.path.exists(attn_spec): + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) for i, flag in enumerate(ireec_flags): k = flag.strip().split("=")[0] From 9fe20a62101f5fbfea4c2840d98ea43d023e10c1 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 09:45:06 -0500 Subject: [PATCH 81/89] Guard path check for attn spec --- .../custom_models/sd_inference/utils.py | 32 +------------------ 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index aabbdb65f..66e2159d0 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -347,7 +347,7 @@ def compile_to_vmfb( # This is a temporary solution, and should be removed or largely disabled once the functionality of # the TD spec is implemented in C++. 
- if os.path.exists(attn_spec): + if attn_spec and os.path.exists(attn_spec): flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) for i, flag in enumerate(ireec_flags): @@ -447,36 +447,6 @@ def create_safe_name(hf_model_name, model_name_str=""): return safe_name -def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False): - if use_punet: - suffix = "_punet" - url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" - else: - return None - attn_spec = urlopen(url).read().decode("utf-8") - spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir") - with open(spec_path, "w") as f: - f.write(attn_spec) - return spec_path - - -def get_wmma_spec_path(target_chip, save_dir, masked_attention=False): - if not masked_attention: - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_wmma.mlir" - elif target_chip == "gfx1100": - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir" - elif target_chip in ["gfx1103", "gfx1150"]: - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir" - else: - return None - attn_spec = urlopen(url).read().decode("utf-8") - suffix = "masked" if masked_attention else "" - spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_wmma{suffix}.mlir") - with open(spec_path, "w") as f: - f.write(attn_spec) - return spec_path - - def save_external_weights( mapper, model, From f39b2d2278a0a3d2cea34bd932d1c016b82bafd9 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 10:26:53 -0500 Subject: [PATCH 82/89] Separate punet run --- models/turbine_models/tests/sdxl_test.py | 138 +++++++---------------- 1 file changed, 41 insertions(+), 97 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 2b6eab9b8..f0cf63d4f 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -84,7 +84,7 @@ def command_line_args(request): @pytest.mark.usefixtures("command_line_args") class StableDiffusionXLTest(unittest.TestCase): - def test00_compile_pipe(self): + def test00_sdxl_pipe(self): from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) @@ -136,12 +136,43 @@ def test00_compile_pipe(self): ) assert output is not None - # Switch to punet. 
- self.pipe.unload_submodel("unet") - self.pipe.use_punet = True - self.pipe.use_i8_punet = True - self.pipe.setup_punet() - self.pipe.map["unet"]["export_args"]["attn_spec"] = None + def test01_sdxl_pipe_i8(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) + + self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") + decomp_attn = { + "text_encoder": True, + "unet": False, + "vae": ( + False + if any(x in arguments["device"] for x in ["hip", "rocm"]) + else True + ), + } + self.pipe = SharkSDPipeline( + arguments["hf_model_name"], + arguments["height"], + arguments["width"], + arguments["batch_size"], + arguments["max_length"], + arguments["precision"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags=None, + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", + external_weights_dir="test_weights", + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, + use_i8_punet=True, + vae_harness=False, + ) self.pipe.prepare_all() self.pipe.load_map() output = self.pipe.generate_images( @@ -157,7 +188,7 @@ def test00_compile_pipe(self): ) assert output is not None - def test01_PromptEncoder(self): + def test02_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") clip_filename = ( @@ -237,7 +268,7 @@ def test01_PromptEncoder(self): np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol) np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol) - def test02_unet(self): + def test03_unet(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") unet_filename = ( @@ -327,7 +358,7 @@ def test02_unet(self): atol = 4e-1 np.testing.assert_allclose(torch_output, turbine, rtol, atol) - def test03_ExportVaeModelDecode(self): + def test04_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") @@ -393,93 +424,6 @@ def test03_ExportVaeModelDecode(self): atol = 4e-1 np.testing.assert_allclose(torch_output, turbine, rtol, atol) - # def test04_punet(self): - # if arguments["device"] in ["vulkan", "cuda"]: - # self.skipTest("Unknown error on vulkan; To be tested on cuda.") - # unet_filename = "_".join( - # create_safe_name(arguments["hf_model_name"], ""), - # "bs" + str(arguments["batch_size"]), - # str(arguments["max_length"]), - # str(arguments["height"]) + "x" + str(arguments["width"]), - # arguments["precision"], - # "unet", - # arguments["device"], - # arguments["iree_target_triple"], - # ) + ".vmfb" - # arguments["vmfb_path"] = os.path.join( - # "test_vmfbs", - # unet_filename - # ) - # unet_w_filename = "_".join( - # create_safe_name(arguments["hf_model_name"], ""), - # "unet", - # arguments["precision"], - # ) + ".safetensors" - # arguments["external_weight_path"] = os.path.join( - # "test_weights", - # unet_w_filename, - # ) - # dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 - # sample = torch.rand( - # ( - # arguments["batch_size"], - # arguments["in_channels"], - # arguments["height"] // 8, - # arguments["width"] // 8, - # ), - # dtype=dtype, - # ) - # timestep = torch.zeros(1, dtype=dtype) - # prompt_embeds = torch.rand( - # (2 * arguments["batch_size"], 
arguments["max_length"], 2048), - # dtype=dtype, - # ) - # text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype) - # time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype) - # guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) - - # turbine = unet_runner.run_punet( - # arguments["rt_device"], - # sample, - # timestep, - # prompt_embeds, - # text_embeds, - # time_ids, - # guidance_scale, - # arguments["vmfb_path"], - # arguments["hf_model_name"], - # arguments["hf_auth_token"], - # arguments["external_weight_path"], - # ) - # torch_output = unet_runner.run_torch_unet( - # arguments["hf_model_name"], - # arguments["hf_auth_token"], - # sample.float(), - # timestep, - # prompt_embeds.float(), - # text_embeds.float(), - # time_ids.float(), - # guidance_scale.float(), - # precision=arguments["precision"], - # ) - # if arguments["benchmark"] or arguments["tracy_profile"]: - # run_benchmark( - # "unet", - # arguments["vmfb_path"], - # arguments["external_weight_path"], - # arguments["rt_device"], - # max_length=arguments["max_length"], - # height=arguments["height"], - # width=arguments["width"], - # batch_size=arguments["batch_size"], - # in_channels=arguments["in_channels"], - # precision=arguments["precision"], - # tracy_profile=arguments["tracy_profile"], - # ) - # rtol = 4e-2 - # atol = 4e-1 - # np.testing.assert_allclose(torch_output, turbine, rtol, atol) - def tearDown(self): gc.collect() From d3c8e8016ccf0eb6f5548e3189c144cdae795099 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 11:57:34 -0500 Subject: [PATCH 83/89] typo fixes --- models/turbine_models/tests/sdxl_test.py | 30 ++++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index f0cf63d4f..a8bc702aa 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -136,7 +136,7 @@ def test00_sdxl_pipe(self): ) assert output is not None - def test01_sdxl_pipe_i8(self): + def test01_sdxl_pipe_i8_punet(self): from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) @@ -192,7 +192,7 @@ def test02_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") clip_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["max_length"]), @@ -200,16 +200,16 @@ def test02_PromptEncoder(self): "text_encoder", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename) clip_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "text_encoder", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -241,7 +241,7 @@ def test02_PromptEncoder(self): turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( arguments["vmfb_path"], - arguments["rt_driver"], + arguments["rt_device"], arguments["external_weight_path"], text_input_ids_list, uncond_input_ids_list, @@ -259,7 +259,7 @@ def test02_PromptEncoder(self): "prompt_encoder", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_driver"], + arguments["rt_device"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) @@ -272,7 +272,7 @@ def test03_unet(self): if 
arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") unet_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["max_length"]), @@ -281,16 +281,16 @@ def test03_unet(self): "unet", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename) unet_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "unet", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -363,7 +363,7 @@ def test04_ExportVaeModelDecode(self): self.skipTest("Compilation error on vulkan; To be tested on cuda.") vae_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["height"]) + "x" + str(arguments["width"]), @@ -371,16 +371,16 @@ def test04_ExportVaeModelDecode(self): "vae", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename) vae_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "vae", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( From 40808dbf0e4041dba1ff265dfc9e963f7bc0556d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 12:13:44 -0500 Subject: [PATCH 84/89] Filename fixes, explicit input dtypes for i8 punet --- .../custom_models/sd_inference/sd_pipeline.py | 2 + .../sdxl_inference/sdxl_prompt_encoder.py | 2 +- models/turbine_models/tests/sdxl_test.py | 95 ++++++++++--------- 3 files changed, 55 insertions(+), 44 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 61bb66d82..a6652a234 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -407,6 +407,8 @@ def setup_punet(self): ] self.map["unet"]["keywords"] += "punet" if self.use_i8_punet: + self.map["unet"]["np_dtype"] = "int8" + self.map["unet"]["torch_dtype"] = torch.int8 if self.add_tk_kernels: self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 3b9fb102f..17328f6e2 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -177,7 +177,7 @@ def export_prompt_encoder( safe_name = utils.create_safe_name( hf_model_name, - f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder-{device}", + f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder", ) if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, safe_name) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index a8bc702aa..6eba5b8e4 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -192,24 +192,27 @@ def test02_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be 
tested on cuda.") clip_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "bs" + str(arguments["batch_size"]), - str(arguments["max_length"]), - arguments["precision"], - "text_encoder", - arguments["device"], - arguments["iree_target_triple"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["max_length"]), + arguments["precision"], + "text_encoder", + arguments["iree_target_triple"], + ] + ) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename) clip_w_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "text_encoder", - arguments["precision"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "text_encoder", + arguments["precision"], + ] + ) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -272,25 +275,28 @@ def test03_unet(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") unet_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "bs" + str(arguments["batch_size"]), - str(arguments["max_length"]), - str(arguments["height"]) + "x" + str(arguments["width"]), - arguments["precision"], - "unet", - arguments["device"], - arguments["iree_target_triple"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["max_length"]), + str(arguments["height"]) + "x" + str(arguments["width"]), + arguments["precision"], + "unet", + arguments["iree_target_triple"], + ] + ) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename) unet_w_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "unet", - arguments["precision"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "unet", + arguments["precision"], + ] + ) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -363,24 +369,27 @@ def test04_ExportVaeModelDecode(self): self.skipTest("Compilation error on vulkan; To be tested on cuda.") vae_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "bs" + str(arguments["batch_size"]), - str(arguments["height"]) + "x" + str(arguments["width"]), - arguments["precision"], - "vae", - arguments["device"], - arguments["iree_target_triple"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "bs" + str(arguments["batch_size"]), + str(arguments["height"]) + "x" + str(arguments["width"]), + arguments["precision"], + "vae", + arguments["iree_target_triple"], + ] + ) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename) vae_w_filename = ( - "_".join([ - create_safe_name(arguments["hf_model_name"], ""), - "vae", - arguments["precision"], - ]) + "_".join( + [ + create_safe_name(arguments["hf_model_name"], ""), + "vae", + arguments["precision"], + ] + ) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( From e630d39d2f55b1bf211f8cb7513b5383193110d4 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 13:01:01 -0500 Subject: [PATCH 85/89] Update CPU test configuration. 
--- .github/workflows/test_models.yml | 6 +++--- models/turbine_models/tests/sdxl_test.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 67c7054ac..a46fc9fe8 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -77,8 +77,8 @@ jobs: source turbine_venv/bin/activate pytest -v models/turbine_models/tests/sd_test.py - pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default -x - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --clip_spec None --unet_spec None --vae_spec None --batch_size 2 -x + pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x -s + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 -x -s + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --batch_size 2 -x pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2 -x diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 6eba5b8e4..3868f9192 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,11 +93,7 @@ def test00_sdxl_pipe(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": ( - False - if any(x in arguments["device"] for x in ["hip", "rocm"]) - else True - ), + "vae": False, } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -135,8 +131,12 @@ def test00_sdxl_pipe(self): True, # return_img ) assert output is not None + del output + del self.pipe def test01_sdxl_pipe_i8_punet(self): + if arguments["device"] not in ["rocm", "hip"]: + self.skipTest("Currently unimplemented/pending validation") from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) @@ -187,6 +187,8 @@ def test01_sdxl_pipe_i8_punet(self): True, # return_img ) assert output is not None + del output + del self.pipe def test02_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: From fc6d018f54f6fc816b2379ae1a62e16f27994d09 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 13:20:51 -0500 Subject: [PATCH 86/89] Decompose VAE for cpu --- models/turbine_models/tests/sdxl_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 3868f9192..05abc70f0 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -93,7 +93,11 @@ def test00_sdxl_pipe(self): decomp_attn = { "text_encoder": True, "unet": False, - "vae": False, + "vae": ( + False + if any(x in arguments["device"] for x in ["hip", "rocm"]) + else True + ), } self.pipe = SharkSDPipeline( arguments["hf_model_name"], @@ -377,7 +381,7 @@ def test04_ExportVaeModelDecode(self): "bs" + str(arguments["batch_size"]), str(arguments["height"]) + "x" + str(arguments["width"]), 
arguments["precision"], - "vae", + "vae" if arguments["device"] != "cpu" else "vae_decomp_attn", arguments["iree_target_triple"], ] ) From 7d50dc8e7ea738fd73e1f441b6982934409f6261 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 13:56:38 -0500 Subject: [PATCH 87/89] Change compile flag reporting to CLI input --- .../custom_models/sd_inference/utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 66e2159d0..4dea7bfd3 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -396,7 +396,6 @@ def compile_to_vmfb( mlir_source = "str" input_ir_type = "auto" - print("Compiling to", device, "with flags:", flags) # Forces a standard for naming files: # If safe_name has target triple in it, get rid of target triple in mlir name @@ -408,6 +407,24 @@ def compile_to_vmfb( safe_vmfb_name = safe_name safe_mlir_name = "".join(safe_name.split(target_triple)) + if os.path.exists(module_str): + in_file = module_str + else: + in_file = "" + + out_file = f"{safe_vmfb_name}.vmfb" + iree_repro_cli_list = [ + "iree_compile", + f"--iree-hal-target-backends={device}", + f"--iree-input-type={input_ir_type}", + in_file, + out_file, + ] + iree_repro_cli_list.extend(flags) + iree_repro_cli = " ".join(iree_repro_cli_list) + + print("Compiling to target:", device, " \nCLI equivalent:", iree_repro_cli) + if mlir_source == "file": flatbuffer_blob = ireec.compile_file( module_str, From f14092656a19d0e9d504f7a73adf03770ddb31b9 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 4 Oct 2024 13:59:47 -0500 Subject: [PATCH 88/89] formatting --- models/turbine_models/custom_models/sd_inference/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 4dea7bfd3..6dc68324a 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -396,7 +396,6 @@ def compile_to_vmfb( mlir_source = "str" input_ir_type = "auto" - # Forces a standard for naming files: # If safe_name has target triple in it, get rid of target triple in mlir name # From 67e6558ea0a043901ee2b6193167f57d58abc69f Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 21 Oct 2024 09:36:14 -0500 Subject: [PATCH 89/89] Rework prompt encoder export on aot.export API --- .../sdxl_inference/sdxl_prompt_encoder.py | 243 +++++++++--------- 1 file changed, 120 insertions(+), 123 deletions(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 17328f6e2..1ae7ea1f0 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -48,111 +48,106 @@ def __init__( def forward( self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2 ): - with torch.no_grad(): - prompt_embeds_1 = self.text_encoder_model_1( - text_input_ids_1, - output_hidden_states=True, - ) - prompt_embeds_2 = self.text_encoder_model_2( - text_input_ids_2, - output_hidden_states=True, - ) - neg_prompt_embeds_1 = self.text_encoder_model_1( - uncond_input_ids_1, - output_hidden_states=True, - ) - neg_prompt_embeds_2 = 
self.text_encoder_model_2( - uncond_input_ids_2, - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds_2[0] - neg_pooled_prompt_embeds = neg_prompt_embeds_2[0] + prompt_embeds_1 = self.text_encoder_model_1( + text_input_ids_1, + output_hidden_states=True, + ) + prompt_embeds_2 = self.text_encoder_model_2( + text_input_ids_2, + output_hidden_states=True, + ) + neg_prompt_embeds_1 = self.text_encoder_model_1( + uncond_input_ids_1, + output_hidden_states=True, + ) + neg_prompt_embeds_2 = self.text_encoder_model_2( + uncond_input_ids_2, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds_2[0] + neg_pooled_prompt_embeds = neg_prompt_embeds_2[0] - prompt_embeds_list = [ - prompt_embeds_1.hidden_states[-2], - prompt_embeds_2.hidden_states[-2], - ] - neg_prompt_embeds_list = [ - neg_prompt_embeds_1.hidden_states[-2], - neg_prompt_embeds_2.hidden_states[-2], - ] + prompt_embeds_list = [ + prompt_embeds_1.hidden_states[-2], + prompt_embeds_2.hidden_states[-2], + ] + neg_prompt_embeds_list = [ + neg_prompt_embeds_1.hidden_states[-2], + neg_prompt_embeds_2.hidden_states[-2], + ] - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1) + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1) - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, 1, 1) - prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( - bs_embed * 1, -1 - ) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1) + if not self.batch_input: + prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) + add_text_embeds = pooled_prompt_embeds + if not self.batch_input: + add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) + if self.do_classifier_free_guidance: if not self.batch_input: - prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) - add_text_embeds = pooled_prompt_embeds + neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view( + 1, -1 + ) + neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1) + neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1) if not self.batch_input: - add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) - if self.do_classifier_free_guidance: - if not self.batch_input: - neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat( - 1, 1 - ).view(1, -1) - neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1) - neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1) - if not self.batch_input: - neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1) - prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) - if not self.batch_input: - neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat( - self.batch_size, 1 - ) - add_text_embeds = torch.cat( - [neg_pooled_prompt_embeds, add_text_embeds], dim=0 + neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1) + prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) + if not self.batch_input: + neg_pooled_prompt_embeds = 
neg_pooled_prompt_embeds.repeat( + self.batch_size, 1 ) - add_text_embeds = add_text_embeds.to(self.torch_dtype) - prompt_embeds = prompt_embeds.to(self.torch_dtype) - return prompt_embeds, add_text_embeds + add_text_embeds = torch.cat( + [neg_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_text_embeds = add_text_embeds.to(self.torch_dtype) + prompt_embeds = prompt_embeds.to(self.torch_dtype) + return prompt_embeds, add_text_embeds def forward_turbo(self, text_input_ids_1, text_input_ids_2): - with torch.no_grad(): - prompt_embeds_1 = self.text_encoder_model_1( - text_input_ids_1, - output_hidden_states=True, - ) - prompt_embeds_2 = self.text_encoder_model_2( - text_input_ids_2, - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds_2[0] + prompt_embeds_1 = self.text_encoder_model_1( + text_input_ids_1, + output_hidden_states=True, + ) + prompt_embeds_2 = self.text_encoder_model_2( + text_input_ids_2, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds_2[0] - prompt_embeds_list = [ - prompt_embeds_1.hidden_states[-2], - prompt_embeds_2.hidden_states[-2], - ] - # neg_prompt_embeds_list = [ - # torch.zeros_like(prompt_embeds_list[0]), # dummy tensor - # torch.zeros_like(prompt_embeds_list[1]), # dummy tensor - # ] + prompt_embeds_list = [ + prompt_embeds_1.hidden_states[-2], + prompt_embeds_2.hidden_states[-2], + ] + # neg_prompt_embeds_list = [ + # torch.zeros_like(prompt_embeds_list[0]), # dummy tensor + # torch.zeros_like(prompt_embeds_list[1]), # dummy tensor + # ] - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - bs_embed, seq_len, _ = prompt_embeds.shape + bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, 1, 1) - prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( - bs_embed * 1, -1 - ) - prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) - add_text_embeds = pooled_prompt_embeds - add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1) + prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) + add_text_embeds = pooled_prompt_embeds + add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) - add_text_embeds = add_text_embeds.to(self.torch_dtype) - prompt_embeds = prompt_embeds.to(self.torch_dtype) - return prompt_embeds, add_text_embeds + add_text_embeds = add_text_embeds.to(self.torch_dtype) + prompt_embeds = prompt_embeds.to(self.torch_dtype) + return prompt_embeds, add_text_embeds +@torch.no_grad() def export_prompt_encoder( hf_model_name, hf_auth_token=None, @@ -233,6 +228,20 @@ def export_prompt_encoder( if weights_only: return None, external_weight_path + example_inputs = { + "text_input_ids_1": torch.empty( + (input_batchsize, max_length), dtype=torch.int64 + ), + "text_input_ids_2": torch.empty( + (input_batchsize, max_length), dtype=torch.int64 + ), + "uncond_input_ids_1": torch.empty( + (input_batchsize, max_length), dtype=torch.int64 + ), + "uncond_input_ids_2": torch.empty( + (input_batchsize, max_length), dtype=torch.int64 + ), + } decomp_list = [] if decomp_attn == True: 
decomp_list = [ @@ -244,40 +253,27 @@ def export_prompt_encoder( from_current=True, add_ops=decomp_list, ): - - class CompiledClip(CompiledModule): - if external_weights: - params = export_parameters( - prompt_encoder_module, - external=True, - external_scope="", - name_mapper=mapper.get, - ) - else: - params = export_parameters(prompt_encoder_module) - - def encode_prompts( - self, - t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - ): - return jittable(prompt_encoder_module.forward)( - t_ids_1, t_ids_2, uc_ids_1, uc_ids_2 - ) - - def encode_prompts_turbo( - self, - t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), - t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + if external_weights: + # Transformers (model source) registers position ids as non-persistent. + # This causes externalization to think it's a user input, and since it's not, + # we end up trying to do ops on a !torch.None instead of a tensor. + for buffer_name, buffer in prompt_encoder_module.named_buffers( + recurse=True ): - return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledClip(context=Context(), import_to=import_to) - - module = CompiledModule.get_mlir_module(inst) + mod_name_list = buffer_name.split(".") + buffer_id = mod_name_list.pop() + parent = prompt_encoder_module + for i in mod_name_list: + parent = getattr(parent, i) + parent.register_buffer(buffer_id, buffer, persistent=True) + externalize_module_parameters(prompt_encoder_module) + output = export( + prompt_encoder_module, + kwargs=example_inputs, + module_name="compiled_clip", + function_name="encode_prompts", + ) + module = output.mlir_module model_metadata_encode = { "model_name": hf_model_name + "_text_encoder", @@ -285,6 +281,7 @@ def encode_prompts_turbo( "input_dtypes": ["int64" for i in range(4)], "use_attention_mask": False, } + module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() module_str = str(module) if compile_to != "vmfb":