I understand we can use an MPSImageBatch as input to the [MPSNNGraph encodeBatchToCommandBuffer: ...] method. That said, all inputs to the MPSNNGraph need to be encapsulated in MPSImage objects.
Suppose I have a machine learning application that trains on, or runs inference over, thousands of input samples, each with 4 feature channels. Metal Performance Shaders has been chosen as the primary ML backend for real-time use.
Because encodeBatchToCommandBuffer: only accepts MPSImage inputs, I first have to create an MTLTexture as a 2D texture array. The texture is 1 pixel wide and 1 pixel high, with pixel format MTLPixelFormatRGBA32Float, so each array slice holds the 4 feature channels of one sample.
The general setup is:
#define NumInputDims 4
MPSImageBatch * infBatch = @[];
const uint32_t totalFeatureSets = N;
// Each slice holds 4 (RGBA) channels.
const uint32_t totalSlices = (totalFeatureSets * NumInputDims + 3) / 4;
MTLTextureDescriptor * descriptor =
    [MTLTextureDescriptor texture2DDescriptorWithPixelFormat: MTLPixelFormatRGBA32Float
                                                       width: 1
                                                      height: 1
                                                   mipmapped: NO];
descriptor.textureType = MTLTextureType2DArray;
descriptor.arrayLength = totalSlices;
id<MTLTexture> texture = [mDevice newTextureWithDescriptor: descriptor];
// A 2D texture array is filled one slice at a time (a region with depth > 1
// is only valid for 3D textures). Each slice is a single RGBA32F pixel, so
// bytesPerRow and bytesPerImage are both 4 * sizeof(float).
const float * src = inputFeatureBuffers[0].data();
for (uint32_t slice = 0; slice < totalSlices; ++slice) {
    [texture replaceRegion: MTLRegionMake2D(0, 0, 1, 1)
               mipmapLevel: 0
                     slice: slice
                 withBytes: src + slice * 4
               bytesPerRow: 4 * sizeof(float)
             bytesPerImage: 4 * sizeof(float)];
}
MPSImage * infQueryImage = [[MPSImage alloc] initWithTexture: texture
                                             featureChannels: NumInputDims];
infBatch = [infBatch arrayByAddingObject: infQueryImage];
The training/inference will be:
MPSNNGraph * mInferenceGraph = /*some MPSNNGraph setup*/;
MPSImageBatch * returnImage = [mInferenceGraph encodeBatchToCommandBuffer: commandBuffer
                                                             sourceImages: @[infBatch]
                                                             sourceStates: nil
                                                       intermediateImages: nil
                                                        destinationStates: nil];
// Commit and wait...
// Read the return image for the inferred result.
As you can see, the setup is rather ad hoc: a lot of 1x1-pixel slices created solely for this purpose.
Is there a better way to achieve the same result while staying on Metal Performance Shaders? A follow-up question: can MPS handle general machine learning workloads other than CNNs? Judging from both the online documentation and the header files, the APIs seem to revolve around convolutional networks.
Any response will be helpful, thank you.