I have an M1 Pro with a 16-core GPU. When I run a shader with 8193 threads, the ordering I expect from atomic_thread_fence appears to be violated across the boundary between thread 8191 (the last thread of the 8th threadgroup) and thread 8192 (the first thread of the 9th threadgroup).
I've attached the Metal and Swift files, but I'll repost the relevant kernel here. It launches N threads that walk up a binary tree from the leaves: at each parent, the first thread to arrive terminates, and the second thread populates the parent with the sum of the node's two children.
[[kernel]]
// clang-format off
void sum(device const int& size,
         device const int* __restrict__ in,
         device int* __restrict__ out,
         device atomic_int* visited,
         uint i [[thread_position_in_grid]]) {
    // clang-format on
    int val = in[i];
    uint cur = (size + i - 1);   // leaf slot this thread writes
    out[cur] = val;
    atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);
    cur = (cur - 1) / 2;         // move to the parent node
    int proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
    // The second thread to arrive at a parent (proceed == 1) sums its children.
    while (proceed == 1) {
        uint left = 2 * cur + 1;
        uint right = 2 * cur + 2;
        uint val_left = out[left];
        uint val_right = out[right];
        uint val_cur = val_left + val_right;
        out[cur] = val_cur;
        if (cur == 0) {
            break;
        }
        cur = (cur - 1) / 2;
        atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);
        proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
    }
}
What I'm observing is that thread 8192 hits the atomic_fetch_add first and terminates, while thread 8191 hits it second (observes that thread 8192 had incremented it by 1) and proceeds into the loop. Thread 8191 reads out[16383] (which it populated with 8191) and out[16384] (which thread 8192 populated with 8192 prior to the atomic_thread_fence). Instead of reading 8192 from out[16384] though, it reads 0.
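For reference, here's the index arithmetic for the two boundary threads, written out as a small standalone C++ sketch of the same expressions the kernel uses (size = 8193 as in the attached test):

#include <cstdio>

int main() {
    const unsigned size = 8193;
    for (unsigned i : {8191u, 8192u}) {
        unsigned cur = size + i - 1;      // leaf slot this thread writes
        unsigned parent = (cur - 1) / 2;  // counter both threads increment
        std::printf("thread %u -> out[%u], visited[%u]\n", i, cur, parent);
    }
}
// prints:
// thread 8191 -> out[16383], visited[8191]
// thread 8192 -> out[16384], visited[8191]

So both threads increment visited[8191], and whichever one loses the race then reads out[16383] and out[16384] as that node's children.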
Maybe I'm missing something, but this seems like a pretty clear violation of atomic_thread_fence, which (I thought) was supposed to guarantee that the write from thread 8192 to out[16384] would be visible to any thread observing the effects of the subsequent atomic_fetch_add. Is atomic_fetch_add not a store operation? Changing it to an atomic_store or atomic_exchange still reproduces the bug, and adding another atomic_thread_fence between the atomic_fetch_add and the reads of out doesn't change anything either.
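To make my expectation concrete, here's the CPU-side C++ analogue of the pattern I believe the kernel is using (winner/loser are just illustrative names and 42 is a placeholder value). Under the C++ fence rules, which I understood the MSL memory model to follow, the assert below can never fire:

#include <atomic>
#include <cassert>
#include <thread>

int out_slot = 0;               // stands in for out[16384]
std::atomic<int> visited{0};    // stands in for visited[8191]

void winner() {                 // stands in for thread 8192
    out_slot = 42;              // plain (non-atomic) store
    std::atomic_thread_fence(std::memory_order_seq_cst);
    visited.fetch_add(1, std::memory_order_relaxed);
}

void loser() {                  // stands in for thread 8191
    int proceed = visited.fetch_add(1, std::memory_order_relaxed);
    if (proceed == 1) {         // observed the winner's increment
        std::atomic_thread_fence(std::memory_order_seq_cst);
        assert(out_slot == 42); // the fences make this store visible
    }
}

int main() {
    std::thread a(winner), b(loser);
    a.join();
    b.join();
}

The loser's second fence corresponds to the extra atomic_thread_fence I mentioned trying between the atomic_fetch_add and the reads of out.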
I only begin to observe this at grid sizes of 8193 and upwards. With a threadgroup size of 1024, that's 9 threadgroups per grid, which I assume could be related to my M1 Pro GPU having 16 cores.
Running the same example on an A17 Pro GPU doesn't show any of this behavior up through a tested grid size of 4194303 (2^22 - 1); beyond that, testing runs into unrelated issues, so I can't check larger sizes.
Removing the atomic_thread_fences on both the M1 and the A17 causes the test to fail at much smaller grid sizes, as expected.
//
// test.metal
// test
//
// Created by Teguh Hofstee on 12/18/24.
//
#include <metal_stdlib>
using namespace metal;
constant os_log logger(/*subsystem=*/"com.metal.xyz", /*category=*/"abc");
[[kernel]]
// clang-format off
void init(device int& threadgroup_order) {
    // clang-format on
    threadgroup_order = 0;
}
[[kernel]]
// clang-format off
void sum(device const int& size,
         device const int* __restrict__ in,
         device int* __restrict__ out,
         device atomic_int* visited,
         device atomic_int* threadgroup_order,
         uint i [[thread_position_in_grid]],
         uint k [[thread_position_in_threadgroup]],
         uint j [[threadgroup_position_in_grid]]) {
    // clang-format on
    if (k == 0) {
        int order = atomic_fetch_add_explicit(threadgroup_order, 1, memory_order_relaxed);
        logger.log_info("threadgroup %d executed %d", j, order);
    }
    int val = in[i];
    uint cur = (size + i - 1);
    out[cur] = val;
    atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);
    int MAX_ITER = 1000;
    cur = (cur - 1) / 2;
    int proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
    if (i == 8191 || i == 8192) {
        logger.log_info("[%d] got %d", i, proceed);
    }
    while (proceed == 1 && MAX_ITER-- > 0) {
        uint left = 2 * cur + 1;
        uint right = 2 * cur + 2;
        uint val_left = out[left];
        uint val_right = out[right];
        uint val_cur = val_left + val_right;
        out[cur] = val_cur;
        if (cur == 0) {
            break;
        }
        cur = (cur - 1) / 2;
        atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst);
        proceed = atomic_fetch_add_explicit(&visited[cur], 1, memory_order_relaxed);
    }
}
//
// main.swift
// test
//
// Created by Teguh Hofstee on 12/18/23.
//
import Metal
func makeBuffer<T>(_ items: [T]) -> MTLBuffer? {
    if items.count > 0 {
        items.withUnsafeBytes { ptr in
            device.makeBuffer(bytes: ptr.baseAddress!, length: ptr.count)
        }
    } else {
        device.makeBuffer(length: MemoryLayout<T>.stride)
    }
}
let device = MTLCreateSystemDefaultDevice()!
let captureManager = MTLCaptureManager.shared()
let CAPTURE = true
if CAPTURE {
    let captureDescriptor = MTLCaptureDescriptor()
    captureDescriptor.captureObject = device
    do {
        try captureManager.startCapture(with: captureDescriptor)
    } catch {
        fatalError("error when trying to capture: \(error)")
    }
}
let library = device.makeDefaultLibrary()!
print(library.functionNames)
let cmdQueue = device.makeCommandQueue()!
let cmdBuf = cmdQueue.makeCommandBuffer()!
let cmdEnc = cmdBuf.makeComputeCommandEncoder()!
let data: [Int32] = Array(1...8193)
var size: Int32 = Int32(data.count)
let input = makeBuffer(data)!
let output = device.makeBuffer(length: MemoryLayout<Int32>.stride * (2 * data.count - 1))!
let visited = device.makeBuffer(length: MemoryLayout<Int32>.stride * (data.count - 1))!
var order = device.makeBuffer(length: MemoryLayout<Int32>.stride)
do {
    let fn = library.makeFunction(name: "init")!
    let ps = try! device.makeComputePipelineState(function: fn)
    cmdEnc.setComputePipelineState(ps)
    cmdEnc.setBuffer(order, offset: 0, index: 0)
    let gridSize = MTLSizeMake(1, 1, 1)
    var threadGroupSize = ps.maxTotalThreadsPerThreadgroup
    if threadGroupSize > 1 {
        threadGroupSize = 1
    }
    let threadgroupSize = MTLSizeMake(1, 1, 1)
    cmdEnc.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
}
do {
    let fn = library.makeFunction(name: "sum")!
    let ps = try! device.makeComputePipelineState(function: fn)
    cmdEnc.setComputePipelineState(ps)
    cmdEnc.setBytes(&size, length: MemoryLayout<Int32>.stride, index: 0)
    cmdEnc.setBuffer(input, offset: 0, index: 1)
    cmdEnc.setBuffer(output, offset: 0, index: 2)
    cmdEnc.setBuffer(visited, offset: 0, index: 3)
    cmdEnc.setBuffer(order, offset: 0, index: 4)
    let gridSize = MTLSizeMake(data.count, 1, 1)
    var threadGroupSize = ps.maxTotalThreadsPerThreadgroup
    if threadGroupSize > data.count {
        threadGroupSize = data.count
    }
    let threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1)
    cmdEnc.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize)
    print(gridSize)
}
cmdEnc.endEncoding()
cmdBuf.commit()
cmdBuf.waitUntilCompleted()
print("done!")
let result = output.contents().bindMemory(to: Int32.self, capacity: 2 * data.count - 1)
for k in 0 ..< data.count - 1 {
    if result[k] != result[2*k+1] + result[2*k+2] {
        print("\(k) -> (\(2*k+1), \(2*k+2)): expected: \(result[2*k+1] + result[2*k+2]), actual: \(result[k])")
    }
}
print("expected: \(data.reduce(0, +))")
print("actual: \(result[0])")
if CAPTURE {
    captureManager.stopCapture()
}