MPSImage image processing
This is fairly easy to do. Set up a render pass that targets the "upsampled" texture, then cover it with a full-viewport quad. The vertex shader is trivial, and the fragment shader needs your "source" texture bound to it. Then, for each fragment, do something like this (assuming the fragment position is in "position"):
return source.read( uint2( position.xy ) / 2 );
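To make the index mapping concrete, here is a minimal CPU-side sketch in C++ of what that one-line fragment shader computes. It assumes a single-channel image stored row by row in a std::vector; the function name upsampleNearest2x is made up for illustration and is not part of Metal or MPS. It can also serve as a reference to compare GPU output against.

```cpp
#include <cstddef>
#include <vector>

// CPU reference for 2x nearest-neighbor upsampling of a single-channel image.
// Each output texel (x, y) copies source texel (x / 2, y / 2), mirroring
// source.read(uint2(position.xy) / 2) in the fragment shader.
// upsampleNearest2x is a hypothetical helper name, not a Metal or MPS API.
std::vector<float> upsampleNearest2x(const std::vector<float>& src,
                                     std::size_t srcWidth,
                                     std::size_t srcHeight)
{
    const std::size_t dstWidth = 2 * srcWidth;
    const std::size_t dstHeight = 2 * srcHeight;
    std::vector<float> dst(dstWidth * dstHeight);

    for (std::size_t y = 0; y < dstHeight; ++y) {
        for (std::size_t x = 0; x < dstWidth; ++x) {
            // Integer division maps two neighboring output texels
            // onto the same source texel.
            dst[y * dstWidth + x] = src[(y / 2) * srcWidth + (x / 2)];
        }
    }
    return dst;
}
```

Each source texel ends up replicated into a 2x2 block of the output, which is what a nearest filter does at an exact 2x scale.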
Regards
Michal
Alternatively, you could use a compute kernel that writes all four output texels for a given input texel. The disadvantage of using a fragment shader is that you can only output one value per invocation.
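As a rough illustration of the "one input texel, four output texels" idea, here is a CPU-side C++ sketch in which each iteration of the outer loops stands in for one kernel invocation. The name upsampleScatter2x and the single-channel std::vector layout are assumptions made for this example only. Note that the kernels posted in this reply launch one thread per destination texel instead; the sketch shows the other variant.

```cpp
#include <cstddef>
#include <vector>

// Scatter-style 2x nearest upsample: each (sx, sy) iteration plays the role
// of one compute-kernel invocation that reads a single source texel and
// writes the 2x2 block of destination texels it covers.
// upsampleScatter2x is a hypothetical helper name, not a Metal or MPS API.
std::vector<float> upsampleScatter2x(const std::vector<float>& src,
                                     std::size_t srcWidth,
                                     std::size_t srcHeight)
{
    const std::size_t dstWidth = 2 * srcWidth;
    std::vector<float> dst(dstWidth * 2 * srcHeight);

    for (std::size_t sy = 0; sy < srcHeight; ++sy) {
        for (std::size_t sx = 0; sx < srcWidth; ++sx) {
            const float value = src[sy * srcWidth + sx];  // one read
            for (std::size_t dy = 0; dy < 2; ++dy) {
                for (std::size_t dx = 0; dx < 2; ++dx) {  // four writes
                    dst[(2 * sy + dy) * dstWidth + (2 * sx + dx)] = value;
                }
            }
        }
    }
    return dst;
}
```

With this layout the grid would be sized to the source texture rather than the destination, so a quarter as many invocations are needed.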
Below are the Metal kernel and the host code that dispatches it. Together they should let you upsample an MPSImage by a factor of two with a nearest filter (i.e. the data is simply replicated). Hopefully this is of help.
Here is the Metal kernel code:
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Upsample kernel for 2D array textures
kernel void upsample_kernel_2darray(texture2d_array<float, access::sample> src [[ texture(0) ]],
texture2d_array<float, access::write> dst [[ texture(1) ]],
uint3 tid [[thread_position_in_grid]])
{
if (tid.x >= dst.get_width() || tid.y >= dst.get_height())
return;
uint2 srcTid = tid.xy / 2;
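// Create a nearest-filter sampler that takes pixel (non-normalized) coordinates
constexpr sampler nearestSampler(address::clamp_to_edge, filter::nearest, coord::pixel);
// Read one pixel from src; the 0.5 offset samples at the texel center
float4 value = src.sample(nearestSampler, static_cast<float2>(srcTid.xy) + 0.5f, tid.z);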
// Write value into dst
dst.write(value, tid.xy, tid.z);
}
// Upsample kernel for 2D textures
kernel void upsample_kernel_2d(texture2d<float, access::sample> src [[ texture(0) ]],
texture2d<float, access::write> dst [[ texture(1) ]],
uint2 tid [[thread_position_in_grid]])
{
if (tid.x >= dst.get_width() || tid.y >= dst.get_height())
return;
uint2 srcTid = tid / 2;
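// Create a nearest-filter sampler that takes pixel (non-normalized) coordinates
constexpr sampler nearestSampler(address::clamp_to_edge, filter::nearest, coord::pixel);
// Read one pixel from src; the 0.5 offset samples at the texel center
float4 value = src.sample(nearestSampler, static_cast<float2>(srcTid) + 0.5f);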
// Write value into dst
dst.write(value, tid);
}
Here is the main app code:
// Get underlying MTLTextures from MPSImages
id <MTLTexture> underlyingSrcTexture = [srcMPSImage texture];
id <MTLTexture> underlyingDstTexture = [dstMPSImage texture];
NSUInteger srcWidth = [underlyingSrcTexture width];
NSUInteger srcHeight = [underlyingSrcTexture height];
NSUInteger srcArrayLength = [underlyingSrcTexture arrayLength];
// Prepare to dispatch the upsample compute function
static dispatch_once_t onceToken;
static id <MTLFunction> upsampleFunction = nil;
static id <MTLComputePipelineState> computePipelineState = nil;
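// Get upsample compute function from default library and build the pipeline once
dispatch_once (&onceToken, ^{
NSError *error = nil;
upsampleFunction = [defaultLibrary newFunctionWithName:@"upsample_kernel_2darray"];
computePipelineState = [device newComputePipelineStateWithFunction:upsampleFunction error:&error];
if (!computePipelineState) NSLog(@"Failed to create compute pipeline state: %@", error);
});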
assert(upsampleFunction);
assert(computePipelineState);
assert([underlyingDstTexture width] >= 2 * [underlyingSrcTexture width]);
assert([underlyingDstTexture height] >= 2 * [underlyingSrcTexture height]);
assert([underlyingDstTexture arrayLength] >= [underlyingSrcTexture arrayLength]);
// Set up a compute command encoder
id <MTLComputeCommandEncoder> computeCommandEncoder = [commandBuffer computeCommandEncoder];
[computeCommandEncoder setComputePipelineState:computePipelineState];
[computeCommandEncoder setTexture:underlyingSrcTexture atIndex:0];
[computeCommandEncoder setTexture:underlyingDstTexture atIndex:1];
// Dispatch the upsample compute function
MTLSize numThreadgroups = {(2 * srcWidth + 7) / 8, (2 * srcHeight + 7) / 8, srcArrayLength};
MTLSize numThreadsPerThreadgroup = {8, 8, 1};
printf("Executing with numThreadgroups = {%lu %lu %lu}, numThreadsPerThreadgroup = {%lu %lu %lu}\n",
(unsigned long)numThreadgroups.width, (unsigned long)numThreadgroups.height, (unsigned long)numThreadgroups.depth,
(unsigned long)numThreadsPerThreadgroup.width, (unsigned long)numThreadsPerThreadgroup.height, (unsigned long)numThreadsPerThreadgroup.depth);
[computeCommandEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:numThreadsPerThreadgroup];
[computeCommandEncoder endEncoding];
[commandBuffer commit];
[commandBuffer waitUntilCompleted];
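The threadgroup counts in the host code, such as (2 * srcWidth + 7) / 8, are a ceiling division: they round up so that the grid covers the whole destination even when its size is not a multiple of 8. A small C++ sketch of that arithmetic follows; threadgroupCount is a hypothetical helper name, not a Metal API.

```cpp
#include <cstddef>

// Number of threadgroups needed along one axis so that
// groups * threadsPerGroup >= totalThreads (ceiling division).
// threadgroupCount is a hypothetical helper name, not a Metal API.
constexpr std::size_t threadgroupCount(std::size_t totalThreads,
                                       std::size_t threadsPerGroup)
{
    return (totalThreads + threadsPerGroup - 1) / threadsPerGroup;
}

// A source width of 101 gives 202 destination texels, which needs 26 groups
// of 8 threads (208 threads in total).
static_assert(threadgroupCount(2 * 101, 8) == 26, "rounds up");
// A source width of 100 gives 200 destination texels, exactly 25 groups.
static_assert(threadgroupCount(2 * 100, 8) == 25, "exact multiple");
```

In the 101-wide case the last threadgroup has 6 threads that fall outside the destination, which is why the kernels start with a bounds check against dst.get_width() and dst.get_height().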
Thank you very much.