Dear developers, i need support to develop a simple computation on the GPU. I would like to perform matrix multiplication: this will be good with metal-cpp because i need to export as cpp library. Following documentation:
file Multiply.metal :
kernel void multiply(device float *pMatA, device float *pMatB
, device float *pMatC, device float *pMatR)
{
simdgroup_float8x8 sgMatA;
simdgroup_float8x8 sgMatB;
simdgroup_float8x8 sgMatR;
simdgroup_load(sgMatA, pMatA);
simdgroup_load(sgMatB, pMatB);
simdgroup_multiply(sgMatR, sgMatA, sgMatB);
simdgroup_store(sgMatR, pMatR);
}
File Multiply.hpp
#include <Foundation/Foundation.hpp>
#include <Metal/Metal.hpp>
class Multiply {
public:
MTL::Device* m_device;
MTL::ComputePipelineState *m_add_function_pso;
MTL::CommandQueue *m_command_queue;
MTL::Buffer *m_buffer_A;
MTL::Buffer *m_buffer_B;
MTL::Buffer *m_buffer_result;
void init_with_device(MTL::Device*);
void prepare_data();
void send_compute_command();
private:
void generate_random_float_data(MTL::Buffer* buffer);
void encode_dot_command(MTL::ComputeCommandEncoder* compute_encoder);
void verify_results();
};
File Multiply.cpp
#include <iostream>
#include "Multiply.hpp"
const unsigned int array_length = 1 << 5;
const unsigned int buffer_size = array_length * sizeof(float);
void Multiply::init_with_device(MTL::Device* device){ m_device = device;
NS::Error* error;
auto default_library = m_device->newDefaultLibrary(); if(!default_library){
std::cerr << "Failed to load default library."; std::exit(-1);
}
auto function_name = NS::String::string("multiply", NS::ASCIIStringEncoding);
auto dot_function = default_library->newFunction(function_name);
if(!dot_function){
std::cerr << "Failed to find the dot function.";
}
m_dot_function_pso = m_device->newComputePipelineState(dot_function, &error);
m_command_queue = m_device->newCommandQueue(); };
void Multiply::prepare_data(){
m_buffer_A = m_device->newBuffer(buffer_size, MTL::ResourceStorageModeShared);
m_buffer_B = m_device->newBuffer(buffer_size, MTL::ResourceStorageModeShared);
m_buffer_result = m_device->newBuffer(buffer_size, MTL::ResourceStorageModeShared); generate_random_float_data(m_buffer_A); generate_random_float_data(m_buffer_B);
}
void Multiply::generate_random_float_data(MTL::Buffer* buffer)
{
float* data_ptr = (float*)buffer->contents();
for (unsigned long index = 0; index < array_length; index++)
{
for(unsigned long index2 =0; index2 < array_length; index2++)
{
data_ptr[index][index2] = (float)rand() / (float)(RAND_MAX);
}
}
void Multiply::send_compute_command() {
MTL::CommandBuffer* command_buffer = m_command_queue->commandBuffer();
// assert(command_buffer != nullptr); MTL::ComputeCommandEncoder* compute_encoder = command_buffer->computeCommandEncoder(); encode_dot_command(compute_encoder); compute_encoder->endEncoding();// MTL::CommandBufferStatus status = command_buffer->status();
// std::cout << status << std::endl;
command_buffer->commit();
command_buffer->waitUntilCompleted(); verify_results();
}
void Multiply::encode_dot_command(MTL::ComputeCommandEncoder* compute_encoder){
compute_encoder->setComputePipelineState(m_dot_function_pso); compute_encoder->setBuffer(m_buffer_A, 0, 0); compute_encoder->setBuffer(m_buffer_B, 0, 1); compute_encoder->setBuffer(m_buffer_result, 0, 2); MTL::Size grid_size = MTL::Size(array_length, 1, 1); NS::UInteger thread_group_size_ = m_dot_function_pso->maxTotalThreadsPerThreadgroup(); if(thread_group_size_ > array_length){ thread_group_size_ = array_length;
}
MTL::Size thread_group_size = MTL::Size(thread_group_size_, 1, 1); compute_encoder->dispatchThreads(grid_size, thread_group_size);
}
void Multiply::verify_results(){
auto a = (float*) m_buffer_A->contents();
auto b = (float*) m_buffer_B->contents();
auto result = (float*) m_buffer_result->contents();
for (unsigned long index = 0; index < array_length; index++) {
for (unsigned long index2 = 0; index < array_length; index2++) {
if (result[index][index2] != (a[index][index2] * b[index][index2]))
{
std::cout << "Comput ERROR: index=" << index << "result=" << result[index][index2] << "vs " << a[index][index2] + b[index][index2] << "=a*b\n"; assert(result[index][index2] == (a[index][index2] * b[index][index2]));
}
} std::cout << "Compute results as expected\n";}}
Is all this implementation correct? Can someone kindly give suggestions about speed improvement or other solutions? Thank you in advance.