Metal encoder dispatch based on the value of an atomic counter

I'm trying to find a way to reduce synchronization time between two compute shader calls, where one dispatch depends on an atomic counter from the other.

Example: I have two Metal kernels, select and execute. select scans the numbers buffer and, using an atomic counter, stores the index of every number < 10 in a new buffer, selectedNumberIndices. execute then runs once per selected index (counter threads in total) to do something with those indices.

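#include <metal_stdlib>
using namespace metal;

kernel void select (device atomic_uint &counter,
                    const device uint *numbers,
                    device uint *selectedNumberIndices,
                    uint id [[thread_position_in_grid]]) {
    if (numbers[id] < 10) {
        // reserve the next free slot; which thread gets which slot depends on scheduling
        uint idx = atomic_fetch_add_explicit(&counter, 1, memory_order_relaxed);
        selectedNumberIndices[idx] = id;
    }
}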

kernel void execute (const device uint *selectedNumberIndices,
                     uint id [[thread_position_in_grid]]) {
    // one thread per selected element: id ranges over 0 ..< counter
}
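To check the GPU results, a CPU reference of what select computes can help. The sketch below is a hypothetical helper in Swift (selectReference and matchesReference are illustrative names, not from the thread). Because each GPU thread claims its slot with atomic_fetch_add_explicit, the set of selected indices is deterministic but their order in selectedNumberIndices is not, so the comparison sorts the GPU output before comparing.

```swift
// Hypothetical CPU reference for the `select` kernel: collect the indices of all values < threshold.
// On the GPU the same *set* of indices is produced, but in a scheduling-dependent order.
func selectReference(_ numbers: [UInt32], threshold: UInt32 = 10) -> (count: UInt32, indices: [UInt32]) {
    var counter: UInt32 = 0          // plays the role of the atomic counter
    var indices: [UInt32] = []
    for (id, value) in numbers.enumerated() where value < threshold {
        indices.append(UInt32(id))   // slot `counter` receives thread id `id`
        counter += 1
    }
    return (counter, indices)
}

// Hypothetical check: compare the GPU's counter and the first `gpuCount` indices, ignoring order.
func matchesReference(gpuIndices: [UInt32], gpuCount: UInt32, numbers: [UInt32]) -> Bool {
    let expected = selectReference(numbers)
    guard gpuCount == expected.count, gpuIndices.count >= Int(gpuCount) else { return false }
    return Array(gpuIndices.prefix(Int(gpuCount))).sorted() == expected.indices
}

let numbers: [UInt32] = [3, 42, 7, 10, 0, 15]
let ref = selectReference(numbers)
print(ref.count, ref.indices) // 3 [0, 2, 4]
```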

Currently I do this by calling waitUntilCompleted() between the two dispatches, so that the counter value is final before I read it on the CPU, something like:

// select
counterBuffer.contents().storeBytes(of: UInt32(0), as: UInt32.self) // reset the counter (assumes CPU-visible storage)
var buffer = queue.makeCommandBuffer()!
var encoder = buffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(selectState)
encoder.setBuffer(counterBuffer, offset: 0, index: 0)
encoder.setBuffer(numbersBuffer, offset: 0, index: 1)
encoder.setBuffer(selectedNumberIndicesBuffer, offset: 0, index: 2)
encoder.dispatchThreads(.init(width: Int(numbersCount), height: 1, depth: 1),
                        threadsPerThreadgroup: .init(width: selectState.threadExecutionWidth, height: 1, depth: 1))
encoder.endEncoding()
buffer.commit()

// wait: the CPU blocks until the GPU has finished select
buffer.waitUntilCompleted()

// execute
let counterValue = counterBuffer.contents().load(as: UInt32.self) // read the final atomic counter value
if counterValue > 0 { // nothing selected: skip the dispatch
    buffer = queue.makeCommandBuffer()!
    encoder = buffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(executeState)
    encoder.setBuffer(selectedNumberIndicesBuffer, offset: 0, index: 0)
    encoder.dispatchThreads(.init(width: Int(counterValue), height: 1, depth: 1),
                            threadsPerThreadgroup: .init(width: executeState.threadExecutionWidth, height: 1, depth: 1))
    encoder.endEncoding()
    buffer.commit()
}

My question: is there any way to get the same functionality without the costly buffer.waitUntilCompleted() call? Or am I going about this the wrong way entirely, or missing something else?

Answered by gopatrik in 762836022
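dispatchThreadgroups(indirectBuffer:indirectBufferOffset:threadsPerThreadgroup:) solves this: the threadgroup count is read from a GPU buffer when the dispatch executes, so execute can be encoded in the same command buffer as select with no CPU wait in between.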

Does anyone know why threadsPerThreadgroup: executeState.threadExecutionWidth literally freezes my Mac when I use it with the indirect-buffer dispatch?

(Using MTLSize(1, 1, 1) works fine, and executeState.threadExecutionWidth works well when I dispatch normally.)
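A likely cause, offered as an assumption rather than something confirmed in the thread: dispatchThreadgroups(indirectBuffer:indirectBufferOffset:threadsPerThreadgroup:) reads the buffer as MTLDispatchThreadgroupsIndirectArguments, i.e. three UInt32 counts of threadgroups (not threads), and the y and z counts must be 1. If counterBuffer is passed directly, the number of selected items becomes the number of threadgroups, so with a threadgroup width of 32 the GPU launches counter × 32 threads, and every thread with id >= counter reads past the valid entries of selectedNumberIndicesBuffer; that extra work and the out-of-bounds accesses may stall the GPU. With a width of 1 the two counts coincide, which would explain why MTLSize(1, 1, 1) works. The Swift sketch below models the arithmetic on the CPU; ThreadgroupsIndirectArgs, indirectArgs and launchedThreads are hypothetical names. On the GPU, the same ceiling division would go in a small extra kernel run after select that writes (groups, 1, 1) into the indirect buffer, and execute would need a guard such as `if (id >= counter) return;` because the last threadgroup is only partly used.

```swift
// Hypothetical CPU-side model of an indirect threadgroup dispatch.
// The field layout mirrors Metal's MTLDispatchThreadgroupsIndirectArguments:
// three UInt32 *threadgroup* counts, one per grid dimension.
struct ThreadgroupsIndirectArgs: Equatable {
    var x: UInt32
    var y: UInt32
    var z: UInt32
}

/// Threadgroups needed so that at least `selectedCount` threads run
/// (ceiling division, written to avoid UInt32 overflow).
func indirectArgs(selectedCount: UInt32, threadsPerThreadgroup width: UInt32) -> ThreadgroupsIndirectArgs {
    precondition(width > 0, "threadgroup width must be positive")
    let groups = selectedCount / width + (selectedCount % width == 0 ? 0 : 1)
    return ThreadgroupsIndirectArgs(x: groups, y: 1, z: 1)
}

/// Total number of threads the GPU launches for these arguments.
func launchedThreads(_ args: ThreadgroupsIndirectArgs, threadsPerThreadgroup width: UInt32) -> UInt64 {
    UInt64(args.x) * UInt64(args.y) * UInt64(args.z) * UInt64(width)
}

let counter: UInt32 = 100   // e.g. 100 numbers were selected
let width: UInt32 = 32      // e.g. executeState.threadExecutionWidth

// Raw counter used as the threadgroup count: 32x too many threads.
let naive = ThreadgroupsIndirectArgs(x: counter, y: 1, z: 1)
print(launchedThreads(naive, threadsPerThreadgroup: width))          // 3200

// Counter converted to threadgroups first: 28 spare threads in the last group.
let fixed = indirectArgs(selectedCount: counter, threadsPerThreadgroup: width)
print(fixed.x, launchedThreads(fixed, threadsPerThreadgroup: width))  // 4 128
```

The spare threads in the last group are why the bounds check in execute is still needed even after the conversion.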

I'm still trying to understand how to encode this into an indirect command buffer.

