Skip to content

Commit

Permalink
Add support for 64 bit integer atomic operations in shaders.
Browse files Browse the repository at this point in the history
Add the following flags to `wgpu_types::Features`:

- `SHADER_INT64_ATOMIC_ALL_OPS` enables all atomic operations on `atomic<i64>` and
  `atomic<u64>` values.

- `SHADER_INT64_ATOMIC_MIN_MAX` is a subset of the above, enabling only
  `AtomicFunction::Min` and `AtomicFunction::Max` operations on `atomic<i64>` and
  `atomic<u64>` values in the `Storage` address space. These are the only 64-bit
  atomic operations available on Metal as of 3.1.

Add corresponding flags to `naga::valid::Capabilities`. These are supported by the
WGSL front end, and all Naga backends.

Platform support:

- On Direct3d 12, in `D3D12_FEATURE_DATA_D3D12_OPTIONS9`, if
  `AtomicInt64OnTypedResourceSupported` and `AtomicInt64OnGroupSharedSupported` are
  both available, then both wgpu features described above are available.

- On Metal, `SHADER_INT64_ATOMIC_MIN_MAX` is available on Apple9 hardware, and on
  hardware that advertises both Apple8 and Mac2 support. This also requires Metal
  Shading Language 2.4 or later. Metal does not yet support the more general
  `SHADER_INT64_ATOMIC_ALL_OPS`.

- On Vulkan, if the `VK_KHR_shader_atomic_int64` extension is available with both the
  `shader_buffer_int64_atomics` and `shader_shared_int64_atomics` features, then both
  wgpu features described above are available.
  • Loading branch information
atlv24 authored and jimblandy committed Jun 9, 2024
1 parent 583cc6a commit 9a0a381
Show file tree
Hide file tree
Showing 51 changed files with 1,885 additions and 160 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ Supported on Vulkan, DX12 (requires DXC) and Metal (with MSL 2.3+ support).

By @atlv24 and @cwfitzgerald in [#5154](https://github.com/gfx-rs/wgpu/pull/5154)

64 bit integer atomic support in shaders. By @atlv24 in [#5383](https://github.com/gfx-rs/wgpu/pull/5383)

### New features

#### General
Expand Down
4 changes: 3 additions & 1 deletion naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ impl StatementGraph {
value,
result,
} => {
self.emits.push((id, result));
if let Some(result) = result {
self.emits.push((id, result));
}
self.dependencies.push((id, pointer, "pointer"));
self.dependencies.push((id, value, "value"));
if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
Expand Down
12 changes: 7 additions & 5 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2368,11 +2368,13 @@ impl<'a, W: Write> Writer<'a, W> {
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.resolve_type(result, &self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
if let Some(result) = result {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.resolve_type(result, &self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
}

let fun_str = fun.to_glsl();
write!(self.out, "atomic{fun_str}(")?;
Expand Down
32 changes: 24 additions & 8 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1919,11 +1919,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
let res_name = match result {
None => None,
Some(result) => {
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => {
self.write_type(module, handle)?
}
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name}; ")?;
Some((result, name))
}
};

Expand All @@ -1934,7 +1943,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
.unwrap();

let fun_str = fun.to_hlsl_suffix();
write!(self.out, " {res_name}; ")?;
match pointer_space {
crate::AddressSpace::WorkGroup => {
write!(self.out, "Interlocked{fun_str}(")?;
Expand Down Expand Up @@ -1970,8 +1978,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
_ => {}
}
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ", {res_name});")?;
self.named_expressions.insert(result, res_name);

// The `original_value` out parameter is optional for all the
// `Interlocked` functions we generate other than
// `InterlockedExchange`.
if let Some((result, name)) = res_name {
write!(self.out, ", {name}")?;
self.named_expressions.insert(result, name);
}

writeln!(self.out, ");")?;
}
Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
Expand Down
23 changes: 5 additions & 18 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/*!
Backend for [MSL][msl] (Metal Shading Language).
This backend does not support the [`SHADER_INT64_ATOMIC_ALL_OPS`][all-atom]
capability.
## Binding model
Metal's bindings are flat per resource. Since there isn't an obvious mapping
Expand All @@ -24,6 +27,8 @@ For the result type, if it's a structure, we re-compose it with a temporary valu
holding the result.
[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
[all-atom]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
*/

use crate::{arena::Handle, proc::index, valid::ModuleInfo};
Expand Down Expand Up @@ -661,21 +666,3 @@ fn test_error_size() {
use std::mem::size_of;
assert_eq!(size_of::<Error>(), 32);
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
Self::InclusiveOr => "fetch_or",
Self::ExclusiveOr => "fetch_xor",
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
}
}
47 changes: 43 additions & 4 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3058,11 +3058,22 @@ impl<W: Write> Writer<W> {
value,
result,
} => {
// This backend supports `SHADER_INT64_ATOMIC_MIN_MAX` but not
// `SHADER_INT64_ATOMIC_ALL_OPS`, so we can assume that if `result` is
// `Some`, we are not operating on a 64-bit value, and that if we are
// operating on a 64-bit value, `result` is `None`.
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &res_name)?;
self.named_expressions.insert(result, res_name);
let fun_str = fun.to_msl()?;
let fun_str = if let Some(result) = result {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &res_name)?;
self.named_expressions.insert(result, res_name);
fun.to_msl()?
} else if context.expression.resolve_type(value).scalar_width() == Some(8) {
fun.to_msl_64_bit()?
} else {
fun.to_msl()?
};

self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
writeln!(self.out, ";")?;
Expand Down Expand Up @@ -5914,3 +5925,31 @@ fn test_stack_size() {
}
}
}

impl crate::AtomicFunction {
fn to_msl(self) -> Result<&'static str, Error> {
Ok(match self {
Self::Add => "fetch_add",
Self::Subtract => "fetch_sub",
Self::And => "fetch_and",
Self::InclusiveOr => "fetch_or",
Self::ExclusiveOr => "fetch_xor",
Self::Min => "fetch_min",
Self::Max => "fetch_max",
Self::Exchange { compare: None } => "exchange",
Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
))?,
})
}

fn to_msl_64_bit(self) -> Result<&'static str, Error> {
Ok(match self {
Self::Min => "min",
Self::Max => "max",
_ => Err(Error::FeatureNotImplemented(
"64-bit atomic operation other than min/max".to_string(),
))?,
})
}
}
4 changes: 3 additions & 1 deletion naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,9 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
} => {
adjust(pointer);
adjust(value);
adjust(result);
if let Some(ref mut result) = *result {
adjust(result);
}
match *fun {
crate::AtomicFunction::Exchange {
compare: Some(ref mut compare),
Expand Down
12 changes: 9 additions & 3 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2423,9 +2423,15 @@ impl<'w> BlockContext<'w> {
result,
} => {
let id = self.gen_id();
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);

self.cached[result] = id;
// Compare-and-exchange operations produce a struct result,
// so use `result`'s type if it is available. For no-result
// operations, fall back to `value`'s type.
let result_type_id =
self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);

if let Some(result) = result {
self.cached[result] = id;
}

let pointer_id =
match self.write_expression_pointer(pointer, &mut block, None)? {
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,9 @@ impl Writer {
crate::TypeInner::RayQuery => {
self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
}
crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
}
_ => {}
}
Ok(())
Expand Down
8 changes: 5 additions & 3 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,11 @@ impl<W: Write> Writer<W> {
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
if let Some(result) = result {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
}

let fun_str = fun.to_wgsl();
write!(self.out, "atomic{fun_str}(")?;
Expand Down
8 changes: 6 additions & 2 deletions naga/src/compact/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ impl FunctionTracer<'_> {
self.expressions_used.insert(pointer);
self.trace_atomic_function(fun);
self.expressions_used.insert(value);
self.expressions_used.insert(result);
if let Some(result) = result {
self.expressions_used.insert(result);
}
}
St::WorkGroupUniformLoad { pointer, result } => {
self.expressions_used.insert(pointer);
Expand Down Expand Up @@ -255,7 +257,9 @@ impl FunctionMap {
adjust(pointer);
self.adjust_atomic_function(fun);
adjust(value);
adjust(result);
if let Some(ref mut result) = *result {
adjust(result);
}
}
St::WorkGroupUniformLoad {
ref mut pointer,
Expand Down
3 changes: 2 additions & 1 deletion naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Int8,
spirv::Capability::Int16,
spirv::Capability::Int64,
spirv::Capability::Int64Atomics,
spirv::Capability::Float16,
spirv::Capability::Float64,
spirv::Capability::Geometry,
Expand Down Expand Up @@ -4028,7 +4029,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
pointer: p_lexp_handle,
fun: crate::AtomicFunction::Add,
value: one_lexp_handle,
result: r_lexp_handle,
result: Some(r_lexp_handle),
};
block.push(stmt, span);
}
Expand Down
4 changes: 2 additions & 2 deletions naga/src/front/type_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ impl crate::Module {
name: Some("exchanged".to_string()),
ty: bool_ty,
binding: None,
offset: 4,
offset: scalar.width as u32,
},
],
span: 8,
span: scalar.width as u32 * 2,
},
}
}
Expand Down
47 changes: 31 additions & 16 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
function,
arguments,
&mut ctx.as_expression(block, &mut emitter),
true,
)?;
block.extend(emitter.finish(&ctx.function.expressions));
return Ok(());
Expand Down Expand Up @@ -1747,7 +1748,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ref arguments,
} => {
let handle = self
.call(span, function, arguments, ctx)?
.call(span, function, arguments, ctx, false)?
.ok_or(Error::FunctionReturnsVoid(function.span))?;
return Ok(Typed::Plain(handle));
}
Expand Down Expand Up @@ -1941,6 +1942,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
function: &ast::Ident<'source>,
arguments: &[Handle<ast::Expression<'source>>],
ctx: &mut ExpressionContext<'source, '_, '_>,
is_statement: bool,
) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
match ctx.globals.get(function.name) {
Some(&LoweredGlobalDecl::Type(ty)) => {
Expand Down Expand Up @@ -2086,7 +2088,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
self.subgroup_gather_helper(span, mode, arguments, ctx)?,
));
} else if let Some(fun) = crate::AtomicFunction::map(function.name) {
return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?));
return self.atomic_helper(span, fun, arguments, is_statement, ctx);
} else {
match function.name {
"select" => {
Expand Down Expand Up @@ -2168,7 +2170,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
compare: Some(compare),
},
value,
result,
result: Some(result),
},
span,
);
Expand Down Expand Up @@ -2459,25 +2461,38 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
span: Span,
fun: crate::AtomicFunction,
args: &[Handle<ast::Expression<'source>>],
is_statement: bool,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
let mut args = ctx.prepare_args(args, 2, span);

let pointer = self.atomic_pointer(args.next()?, ctx)?;

let value = args.next()?;
let value = self.expression(value, ctx)?;
let ty = ctx.register_type(value)?;

let value = self.expression(args.next()?, ctx)?;
let value_inner = resolve_inner!(ctx, value);
args.finish()?;

let result = ctx.interrupt_emitter(
crate::Expression::AtomicResult {
ty,
comparison: false,
},
span,
)?;
// If we don't use the return value of a 64-bit `min` or `max`
// operation, generate a no-result form of the `Atomic` statement, so
// that we can pass validation with only `SHADER_INT64_ATOMIC_MIN_MAX`
// whenever possible.
let is_64_bit_min_max =
matches!(fun, crate::AtomicFunction::Min | crate::AtomicFunction::Max)
&& matches!(
*value_inner,
crate::TypeInner::Scalar(crate::Scalar { width: 8, .. })
);
let result = if is_64_bit_min_max && is_statement {
None
} else {
let ty = ctx.register_type(value)?;
Some(ctx.interrupt_emitter(
crate::Expression::AtomicResult {
ty,
comparison: false,
},
span,
)?)
};
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::Atomic {
Expand Down
Loading

0 comments on commit 9a0a381

Please sign in to comment.