Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subgroup Operations #2523

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8f1ccbe
subgroup: Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hl…
exrook Sep 30, 2023
d8325a4
subgroup: subgroupBallot metal out
exrook Sep 30, 2023
a5e9a9e
subgroup: require GroupNonUnifomBallot capability
exrook Sep 30, 2023
a915b14
subgroup: Add subgroup invocation id and subgroup size builtins
exrook Oct 1, 2023
de6d371
subgroup: SubgroupInvocationId is only valid in compute stages
exrook Oct 2, 2023
820c72b
subgroup: expierment with subgroupBarrier() based on OpControlbarrier
exrook Oct 3, 2023
522e4db
subgroup: add statement for rest of subgroup ops
exrook Oct 14, 2023
789776d
subgroup: fix doc error on SubgroupBroadcast
exrook Oct 5, 2023
ad99c49
subgroup: wgsl-in and spv-out for subgroup operations
exrook Oct 5, 2023
4e818f4
subgroup: add optional predicate for subgroupBallot
exrook Oct 5, 2023
f2dd590
Renames SubgroupBroadcast => SubgroupGather and BroadcastMode => Gath…
Lichtso Oct 11, 2023
35d7105
General fixes.
Lichtso Oct 14, 2023
fc93567
Adds BuiltIn::NumSubgroups, BuiltIn::SubgroupId.
Lichtso Oct 14, 2023
06ed9ab
Adds GatherMode::Shuffle, GatherMode::ShuffleDown, GatherMode::Shuffl…
Lichtso Oct 14, 2023
325908d
Implements all frontends and backends.
Lichtso Oct 14, 2023
6d3a9e8
Adjusts metal backend test_stack_size().
Lichtso Oct 14, 2023
9995399
Adds test and snapshots.
Lichtso Oct 20, 2023
cca0b24
subgroup: fix spv-in
exrook Oct 21, 2023
caae497
subgroup: Add 0 arg subgroupBallot to tests and fix wgsl-out whitespace
exrook Oct 21, 2023
f7fcf32
subgroup: Add spv-in test
exrook Oct 21, 2023
0bdd3b5
subgroup: resolve fixmes & fix typo
exrook Oct 21, 2023
698789e
subgroup: Add subgroup capability
exrook Oct 21, 2023
c9a70af
subgroup: Treat subgroup operation results as non-uniform
exrook Oct 18, 2023
406a7dc
subgroup: refactor wgsl subgroup gather parsing
exrook Oct 21, 2023
36844c8
subgroup: doc comments for subgroup `Statement`s and `Expression`s
exrook Oct 21, 2023
63e2da9
subgroup: add validation for each subgroup operation type
exrook Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {

// Validate the IR before compaction.
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
.subgroup_stages(naga::valid::ShaderStages::all())
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
.validate(&module)
{
Ok(info) => Some(info),
Expand Down
90 changes: 90 additions & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
}
self.emits.push((id, result));
"SubgroupBallot"
}
S::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
"SubgroupAll"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
"SubgroupAny"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
"SubgroupAdd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
"SubgroupMul"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
"SubgroupMax"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
"SubgroupMin"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
"SubgroupAnd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
"SubgroupOr"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
"SubgroupXor"
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixExclusiveAdd",
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixExclusiveMul",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixInclusiveAdd",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixInclusiveMul",
_ => unimplemented!(),
}
}
S::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.dependencies.push((id, index, "index"))
}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match mode {
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
}
}
};
// Set the last node to the merge node
last_node = merge_id;
Expand Down Expand Up @@ -586,6 +674,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};

// give uniform expressions an outline
Expand Down
23 changes: 23 additions & 0 deletions src/back/glsl/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ bitflags::bitflags! {
const IMAGE_SIZE = 1 << 20;
/// Dual source blending
const DUAL_SOURCE_BLENDING = 1 << 21;
/// Subgroup operations
const SUBGROUP_OPERATIONS = 1 << 22;
}
}

Expand Down Expand Up @@ -106,6 +108,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
Expand Down Expand Up @@ -235,6 +238,22 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_blend_func_extended : require")?;
}

if self.0.contains(Features::SUBGROUP_OPERATIONS) {
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_arithmetic : require"
)?;
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
)?;
}

Ok(())
}
}
Expand Down Expand Up @@ -455,6 +474,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
Expression::SubgroupBallotResult |
Expression::SubgroupOperationResult { .. } => {
features.request(Features::SUBGROUP_OPERATIONS)
}
_ => {}
}
}
Expand Down
131 changes: 130 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2250,6 +2250,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

write!(self.out, "subgroupBallot(")?;
match predicate {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
}

Ok(())
Expand Down Expand Up @@ -3423,7 +3542,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::SubgroupOperationResult { .. }
| Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
Expand Down Expand Up @@ -3980,6 +4101,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
}
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
Expand Down Expand Up @@ -4156,6 +4280,11 @@ const fn glsl_built_in(
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
// subgroup
Bi::NumSubgroups => "gl_NumSubgroups",
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
Self::SubgroupSize
| Self::SubgroupInvocationId
| Self::NumSubgroups
| Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}
Expand Down
Loading
Loading