From cea48f6f61f1a161fd75f652c86791b46f492a54 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov"
Date: Fri, 18 Oct 2024 03:40:55 -0700
Subject: [PATCH] [XLA:GPU] Cache the results of the Triton fusion numerics verifier.

In some models there are many identical fusions. These are cached to avoid
expensive recomputations.

PiperOrigin-RevId: 687240882
---
 xla/service/gpu/transforms/BUILD              |  2 +
 .../triton_fusion_numerics_verifier.cc        | 22 +++++++-
 .../triton_fusion_numerics_verifier.h         | 12 ++++
 .../triton_fusion_numerics_verifier_test.cc   | 56 +++++++++++++++++++
 4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index 087913113432c..7fad10686f041 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -3126,6 +3126,7 @@ cc_library(
         "//xla/service/gpu/autotuning:autotuner_util",
         "//xla/stream_executor:stream",
         "//xla/tools:hlo_decomposer_lib",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/functional:any_invocable",
         "@com_google_absl//absl/log",
@@ -3149,6 +3150,7 @@ xla_test(
         "//xla:shape_util",
         "//xla:test_helpers",
         "//xla/hlo/ir:hlo",
+        "//xla/service:backend",
         "//xla/service:platform_util",
         "//xla/service/gpu/autotuning:autotuner_compile_util",
         "//xla/service/gpu/autotuning:autotuner_util",
diff --git a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
index b39a50bde5020..998834f0dd745 100644
--- a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
+++ b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
@@ -177,6 +177,16 @@ absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
   return status;
 }
 
+TritonFusionNumericsVerifier::FusionCacheKey CacheKeyForFusion(
+    const HloFusionInstruction& fusion) {
+  std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
+  HloPrintOptions print_options = HloPrintOptions::ModuleFingerprint()
+                                      .set_print_only_essential_constants(false)
+                                      .set_print_backend_config(true)
+                                      .set_sort_backend_config(true);
+  return module->ToString(print_options);
+}
+
 }  // namespace
 
 absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
@@ -200,8 +210,16 @@ absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
 
   TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
       *module, execution_threads, [&](const HloFusionInstruction& fusion) {
-        return VerifyTritonFusion(*opt_compile_util, fusion, config_,
-                                  debug_options);
+        auto key = CacheKeyForFusion(fusion);
+        if (auto it = fusion_result_cache_.find(key);
+            it != fusion_result_cache_.end()) {
+          ++cache_hits_;
+          return it->second;
+        }
+        auto result = VerifyTritonFusion(*opt_compile_util, fusion, config_,
+                                         debug_options);
+        fusion_result_cache_[key] = result;
+        return result;
       }));
   return false;
 }
diff --git a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
index f23a90bff8e4b..d5c4d31eb6a0c 100644
--- a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
+++ b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
@@ -16,6 +16,9 @@ limitations under the License.
 #ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
 #define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
 
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
@@ -49,8 +52,17 @@ class TritonFusionNumericsVerifier : public HloModulePass {
       HloModule* module,
       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
 
+  using FusionCacheKey = std::string;
+
+  int CacheHitsForTestingOnly() const { return cache_hits_; }
+
  private:
   AutotuneConfig config_;
+
+  // In some models there are many identical fusions. These are cached to avoid
+  // expensive recomputations.
+  absl::flat_hash_map<FusionCacheKey, absl::Status> fusion_result_cache_;
+  int cache_hits_ = 0;  // used for testing only.
 };
 
 namespace triton_fusion_numerics_pass_internal {
diff --git a/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
index dd562e07d38aa..9c3b8cb35b650 100644
--- a/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
+++ b/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/primitive_util.h"
+#include "xla/service/backend.h"
 #include "xla/service/gpu/autotuning/autotuner_compile_util.h"
 #include "xla/service/gpu/autotuning/autotuner_util.h"
 #include "xla/service/platform_util.h"
@@ -245,6 +246,61 @@ ENTRY main {
               ::testing::HasSubstr("Failed to compile Triton fusion"));
 }
 
+TEST_F(TritonFusionNumericsVerifierTest, CacheIsUsed) {
+  absl::string_view hlo_text = R"(
+add {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] add(p0, p1)
+}
+
+max {
+  p0 = f32[] parameter(0)
+  p1 = f32[] parameter(1)
+  ROOT add = f32[] maximum(p0, p1)
+}
+
+reduce_0 {
+  p = f32[16,16] parameter(0)
+  c = f32[] constant(0)
+  ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add
+}
+
+reduce_1 {
+  p = f32[16,16] parameter(0)
+  c = f32[] constant(0)
+  ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=max
+}
+
+// Identical to reduce_0.
+reduce_2 {
+  p = f32[16,16] parameter(0)
+  c = f32[] constant(0)
+  ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add
+}
+
+ENTRY main {
+  p0 = f32[16,16] parameter(0)
+  p1 = f32[16,16] parameter(1)
+  p2 = f32[16,16] parameter(2)
+  r0 = f32[16] fusion(p0), kind=kCustom, calls=reduce_0, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
+  r1 = f32[16] fusion(p1), kind=kCustom, calls=reduce_1, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
+  r2 = f32[16] fusion(p2), kind=kCustom, calls=reduce_2, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
+  add_0_1 = f32[16] add(r0, r1)
+  ROOT add_0_2 = f32[16] add(add_0_1, r2)
+}
+  )";
+
+  std::unique_ptr<HloModule> module =
+      *ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
+  AutotuneConfig autotune_config{
+      DeviceConfig{backend().default_stream_executor(), GetAllocator()},
+      module->config().debug_options()};
+  TritonFusionNumericsVerifier verifier(autotune_config);
+  TF_EXPECT_OK(RunHloPass(verifier, module.get()));
+  EXPECT_EQ(verifier.CacheHitsForTestingOnly(), 1);
+}
+
 INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
                          TritonFusionNumericsVerifierTest,
                          ::testing::Values(F32, F16, BF16));