Hi there, I'm trying to convert my Core ML model (it's actually an .mlpackage) to .mpsgraphpackage so I can test my model's performance with the MPSGraph API. I ran the command you provided in the terminal, but it does nothing (the command runs forever), and in Activity Monitor the terminal process sits at 0.0% CPU.
My Xcode version is 15.0 beta 6 (15A5219j), and I'm running macOS Sonoma 14.0 beta (23A5312d).
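For reference, one way to rule out a broken package before blaming the conversion step is to load and run the .mlpackage with coremltools. This is only a sketch; the file name, input name, and shape below are placeholders, not taken from my actual project.

import numpy as np
import coremltools as ct

# Placeholder file name -- substitute the actual .mlpackage path.
model = ct.models.MLModel("MyModel.mlpackage")
print(model.get_spec().description)   # lists the input/output names and shapes

# One dummy prediction to confirm the package itself loads and runs.
out = model.predict({"input": np.random.rand(1, 3, 224, 224).astype(np.float32)})
print(list(out.keys()))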
Hi all,
I am new to Metal with PyTorch. I am trying to build the demo code for custom ops in PyTorch (the "Customizing a PyTorch Operation" sample).
However, it seems the torch namespace doesn't have "mps" anymore? "torch::mps" cannot be found when I try to compile the .mm file as a PyTorch C++ extension.
After some digging, it looks like everyone is using the ATen namespace with "at::". How can I use the functions in "mps" and make this demo code work? (A sketch of how I'm building and calling the extension follows the code listing below.)
Thanks in advance.
Error message
In file included from /Users/ethan/Downloads/CustomizingAPyTorchOperation/CustomSoftshrink.mm:10:
/Users/ethan/Downloads/CustomizingAPyTorchOperation/CustomSoftshrink.h:11:30: warning: ISO C++11 does not allow conversion from string literal to 'char *' [-Wwritable-strings]
static char *CUSTOM_KERNEL = R"MPS_SOFTSHRINK(
^
/Users/ethan/Downloads/CustomizingAPyTorchOperation/CustomSoftshrink.mm:43:53: error: no member named 'mps' in namespace 'torch'
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
~~~~~~~^
/Users/ethan/Downloads/CustomizingAPyTorchOperation/CustomSoftshrink.mm:47:47: error: no member named 'mps' in namespace 'torch'
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
~~~~~~~^
/Users/ethan/Downloads/CustomizingAPyTorchOperation/CustomSoftshrink.mm:76:20: error: no member named 'mps' in namespace 'torch'
torch::mps::commit();
~~~~~~~^
1 warning and 3 errors generated.
ninja: build stopped: subcommand failed.
CustomSoftshrink.mm code
/*
See the LICENSE.txt file for this sample’s licensing information.
Abstract:
The code that registers a PyTorch custom operation.
*/
#include <torch/extension.h>
#include "CustomSoftshrink.h"
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

torch::Tensor& dispatchSoftShrinkKernel(const torch::Tensor& input, torch::Tensor& output, float lambda) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        NSError *error = nil;

        // Set the number of threads equal to the number of elements within the input tensor.
        int numThreads = input.numel();

        // Load the custom soft shrink shader.
        id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL]
                                                                  options:nil
                                                                    error:&error];
        TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

        std::string kernel_name = std::string("softshrink_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
        id<MTLFunction> customSoftShrinkFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
        TORCH_CHECK(customSoftShrinkFunction, "Failed to create function state object for ", kernel_name.c_str());

        // Create a compute pipeline state object for the soft shrink kernel.
        id<MTLComputePipelineState> softShrinkPSO = [device newComputePipelineStateWithFunction:customSoftShrinkFunction error:&error];
        TORCH_CHECK(softShrinkPSO, error.localizedDescription.UTF8String);

        // Get a reference to the command buffer for the MPS stream.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^(){
            // Start a compute pass.
            id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

            // Encode the pipeline state object and its parameters.
            [computeEncoder setComputePipelineState:softShrinkPSO];
            [computeEncoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() atIndex:0];
            [computeEncoder setBuffer:getMTLBufferStorage(output) offset:output.storage_offset() * output.element_size() atIndex:1];
            [computeEncoder setBytes:&lambda length:sizeof(float) atIndex:2];

            MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

            // Calculate a thread group size.
            NSUInteger threadGroupSize = softShrinkPSO.maxTotalThreadsPerThreadgroup;
            if (threadGroupSize > numThreads) {
                threadGroupSize = numThreads;
            }
            MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

            // Encode the compute command.
            [computeEncoder dispatchThreads:gridSize
                      threadsPerThreadgroup:threadgroupSize];
            [computeEncoder endEncoding];

            // Commit the work.
            torch::mps::commit();
        });
    }
    return output;
}

// C++ op dispatching the Metal soft shrink shader.
torch::Tensor mps_softshrink(const torch::Tensor &input, float lambda = 0.5) {
    // Check whether the input tensor resides on the MPS device and whether it's contiguous.
    TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

    // Check the supported data types for soft shrink.
    TORCH_CHECK(input.scalar_type() == torch::kFloat ||
                input.scalar_type() == torch::kHalf, "Unsupported data type: ", input.scalar_type());

    // Allocate the output, same shape as the input.
    torch::Tensor output = torch::empty_like(input);
    return dispatchSoftShrinkKernel(input, output, lambda);
}

// Create Python bindings for the Objective-C++ code.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("mps_softshrink", &mps_softshrink);
}
Hello,
I'm interested in trying the new JAX Metal plug-in and followed the steps at https://developer.apple.com/metal/jax/. After installing it, I don't see any difference between the backend device JAX detects and a pure CPU setup:
>>> import jax
>>> jax.devices()
[CpuDevice(id=0)]
>>> jax.devices()[0].platform
'cpu'
>>> jax.devices()[0].device_kind
'cpu'
>>> jax.devices()[0].client.platform
'cpu'
>>> jax.devices()[0].client.runtime_type
'tfrt'
Is this really using a Metal backend? How can I determine for sure?
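For context, this is the kind of check I was hoping to use to tell the two apart. The platform name "METAL" is my assumption about what the plug-in would register, not something I have confirmed.

import jax

# With only the CPU backend the default platform is 'cpu'; with the Metal plug-in
# active I would expect a different platform name here.
print(jax.default_backend())
print(jax.devices())

# Asking for a specific platform raises an error if that backend is not registered,
# which distinguishes "plug-in loaded" from "silently falling back to CPU".
print(jax.devices("cpu"))        # always available
# print(jax.devices("METAL"))    # assumed plug-in platform name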
Thank you!