Skip to content

Commit

Permalink
[XLA:GPU] Cache the results of the Triton fusion numerics verifier.
Browse files Browse the repository at this point in the history
In some models there are many identical fusions. These are cached to avoid expensive recomputations.

PiperOrigin-RevId: 687240882
  • Loading branch information
dimitar-asenov authored and Google-ML-Automation committed Oct 18, 2024
1 parent 2cee00c commit cea48f6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
2 changes: 2 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3126,6 +3126,7 @@ cc_library(
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/stream_executor:stream",
"//xla/tools:hlo_decomposer_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
Expand All @@ -3149,6 +3150,7 @@ xla_test(
"//xla:shape_util",
"//xla:test_helpers",
"//xla/hlo/ir:hlo",
"//xla/service:backend",
"//xla/service:platform_util",
"//xla/service/gpu/autotuning:autotuner_compile_util",
"//xla/service/gpu/autotuning:autotuner_util",
Expand Down
22 changes: 20 additions & 2 deletions xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ absl::Status VerifyTritonFusion(AutotunerCompileUtil& util,
return status;
}

TritonFusionNumericsVerifier::FusionCacheKey CacheKeyForFusion(
const HloFusionInstruction& fusion) {
std::unique_ptr<HloModule> module = ExtractInstructionIntoNewModule(fusion);
HloPrintOptions print_options = HloPrintOptions::ModuleFingerprint()
.set_print_only_essential_constants(false)
.set_print_backend_config(true)
.set_sort_backend_config(true);
return module->ToString(print_options);
}

} // namespace

absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(
Expand All @@ -200,8 +210,16 @@ absl::StatusOr<bool> TritonFusionNumericsVerifier::Run(

TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions(
*module, execution_threads, [&](const HloFusionInstruction& fusion) {
return VerifyTritonFusion(*opt_compile_util, fusion, config_,
debug_options);
auto key = CacheKeyForFusion(fusion);
if (auto it = fusion_result_cache_.find(key);
it != fusion_result_cache_.end()) {
++cache_hits_;
return it->second;
}
auto result = VerifyTritonFusion(*opt_compile_util, fusion, config_,
debug_options);
fusion_result_cache_[key] = result;
return result;
}));
return false;
}
Expand Down
12 changes: 12 additions & 0 deletions xla/service/gpu/transforms/triton_fusion_numerics_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
Expand Down Expand Up @@ -49,8 +52,17 @@ class TritonFusionNumericsVerifier : public HloModulePass {
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

using FusionCacheKey = std::string;

int CacheHitsForTestingOnly() const { return cache_hits_; }

private:
AutotuneConfig config_;

// In some models there are many identical fusions. These are cached to avoid
// expensive recomputations.
absl::flat_hash_map<FusionCacheKey, absl::Status> fusion_result_cache_;
int cache_hits_ = 0; // used for testing only.
};

namespace triton_fusion_numerics_pass_internal {
Expand Down
56 changes: 56 additions & 0 deletions xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/primitive_util.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/platform_util.h"
Expand Down Expand Up @@ -245,6 +246,61 @@ ENTRY main {
::testing::HasSubstr("Failed to compile Triton fusion"));
}

TEST_F(TritonFusionNumericsVerifierTest, CacheIsUsed) {
absl::string_view hlo_text = R"(
add {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}
max {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] maximum(p0, p1)
}
reduce_0 {
p = f32[16,16] parameter(0)
c = f32[] constant(0)
ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add
}
reduce_1 {
p = f32[16,16] parameter(0)
c = f32[] constant(0)
ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=max
}
// Identical to reduce_0.
reduce_2 {
p = f32[16,16] parameter(0)
c = f32[] constant(0)
ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add
}
ENTRY main {
p0 = f32[16,16] parameter(0)
p1 = f32[16,16] parameter(1)
p2 = f32[16,16] parameter(2)
r0 = f32[16] fusion(p0), kind=kCustom, calls=reduce_0, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
r1 = f32[16] fusion(p1), kind=kCustom, calls=reduce_1, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
r2 = f32[16] fusion(p2), kind=kCustom, calls=reduce_2, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}}
add_0_1 = f32[16] add(r0, r1)
ROOT add_0_2 = f32[16] add(add_0_1, r2)
}
)";

std::unique_ptr<HloModule> module =
*ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest());
AutotuneConfig autotune_config{
DeviceConfig{backend().default_stream_executor(), GetAllocator()},
module->config().debug_options()};
TritonFusionNumericsVerifier verifier(autotune_config);
TF_EXPECT_OK(RunHloPass(verifier, module.get()));
EXPECT_EQ(verifier.CacheHitsForTestingOnly(), 1);
}

INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite,
TritonFusionNumericsVerifierTest,
::testing::Values(F32, F16, BF16));
Expand Down

0 comments on commit cea48f6

Please sign in to comment.