#include <metal_stdlib>
using namespace metal;


// Object stage: launches the mesh threadgroups that generate the rectangles,
// one generated rectangle per mesh threadgroup.
//
// A threadgroup's position in the mesh grid is also the position of the
// rectangle it generates.
[[object, max_total_threadgroups_per_mesh_grid(1024)]]
void objectFunction(mesh_grid_properties meshGridProperties) {
  // Some grid configurations work, some don't.
  //
  // 1D configurations work. Every size below was fine, so dispatching
  // 512 threadgroups is no problem:
  //
  //   uint3(32, 1, 1)
  //   uint3(64, 1, 1)
  //   uint3(128, 1, 1)
  //   uint3(256, 1, 1)
  //   uint3(512, 1, 1)
  //
  // 2D grids up to 16x8 or 32x5 work fine too:
  //
  //   uint3(32, 2, 1)
  //   uint3(32, 4, 1)
  //   uint3(32, 5, 1)
  //   uint3(16, 2, 1)
  //   uint3(16, 4, 1)
  //   uint3(16, 8, 1)
  //
  // 2D grids starting at 32x6 or 16x9 hang the GPU + WindowServer with
  //
  //   Execution of the command buffer was aborted due to an error during execution.
  //   Discarded (victim of GPU error/recovery) (00000005:kIOGPUCommandBufferCallbackErrorInnocentVictim)
  //
  //   uint3(32, 6, 1)
  //   uint3(16, 9, 1)   <- currently selected (reproduces the hang)
  const uint3 threadgroups_per_grid = uint3(16, 9, 1);
  meshGridProperties.set_threadgroups_per_grid(threadgroups_per_grid);
}

// Per-vertex data written by meshFunction and read by fragmentFunction.
struct Vertex {
  // Vertex position, tagged with the [[position]] attribute.
  float4 position [[position]];
  // RGB color; fragmentFunction returns it with alpha set to 1.
  float3 color;
};

// Output mesh type produced by meshFunction.
//
// metal::mesh is a class template and must be given its arguments. The values
// below are the smallest ones that fit what meshFunction emits:
//   Vertex                    - per-vertex data type
//   void                      - no per-primitive data is written
//   4                         - maximum vertices (set_vertex is called for 0..3)
//   2                         - maximum primitives (set_primitive_count(2))
//   metal::topology::triangle - three indices per primitive (6 indices total)
// NOTE(review): reconstructed from usage in this file - confirm against the
// intended limits if the mesh function is extended.
using mesh_t = metal::mesh<Vertex, void, 4, 2, metal::topology::triangle>;


// Mesh stage: generates one simple rectangle (two triangles) per threadgroup,
// with vertices oriented like this:
//
//  2 - 3
//  | \ |
//  0 - 1
//
// The threadgroup position in the grid determines both the position of the
// rectangle and its color.
//
// Parameters:
//   mesh              - output mesh; receives 4 vertices, 6 indices and the primitive count
//   thread_id         - index of the thread within its threadgroup; only thread 0 writes
//   group_id          - 2D position of this threadgroup in the mesh grid
//   total_treadgroups - 2D size of the mesh grid, used to normalize group_id
[[mesh, max_total_threads_per_threadgroup(256)]]
void meshFunction(
    // the mesh object to produce
    mesh_t mesh,
//    // payload from the object shader (number of vertices to generate)
//    const object_data Payload& payload [[payload]],
    // vertex id
    uint thread_id [[thread_index_in_threadgroup]],
    // threadgroup id
    uint2 group_id [[threadgroup_position_in_grid]],
    uint2 total_treadgroups [[threadgroups_per_grid]]
) {
  // everything is done from the first thread; the other threads write nothing
  if (thread_id == 0) {
    // rect parameters: side length, and the lower-left corner (vertex 0).
    // Neighbouring rectangles are spaced 1.1 * rect_size apart, leaving a small gap.
    float rect_size = 0.05;
    float2 rect_origin = float2(-0.9, -0.9) + (rect_size*1.1) * float2(group_id);

    // threadgroup position normalized by the grid size: (group_id + 1) / grid size
    float2 threadgroup_pos = float2(group_id + 1)/float2(total_treadgroups);

    // use the threadgroup position to make pretty colors
    float3 color_u = float3(1, 0, 0)*threadgroup_pos.x + float3(0, 1, 0)*(1 - threadgroup_pos.x);
    float3 color_v = float3(0.5, 0.5, 0)*threadgroup_pos.y + float3(0, 0, 1)*(1 - threadgroup_pos.y);
    float3 color = color_u + color_v;

    // A float4 constructor needs exactly four components:
    // xy from the corner, z = 0, w = 1.
    mesh.set_vertex(0, {float4(rect_origin, 0.0, 1.0), color});
    mesh.set_vertex(1, {float4(rect_origin + float2(rect_size, 0.0), 0.0, 1.0), color});
    mesh.set_vertex(2, {float4(rect_origin + float2(0.0, rect_size), 0.0, 1.0), color});
    mesh.set_vertex(3, {float4(rect_origin + float2(rect_size, rect_size), 0.0, 1.0), color});

    // two triangles: (0, 2, 1) and (1, 2, 3)
    mesh.set_index(0, 0);
    mesh.set_index(1, 2);
    mesh.set_index(2, 1);

    mesh.set_index(3, 1);
    mesh.set_index(4, 2);
    mesh.set_index(5, 3);

    mesh.set_primitive_count(2);
  }
}


// Fragment stage: outputs the incoming vertex color as an opaque pixel.
//
// Parameters:
//   in - stage-in Vertex data; only its color field is read
// Returns the color with alpha set to 1.
[[fragment]]
float4 fragmentFunction(Vertex in [[stage_in]]) {
  const float3 rgb = in.color;
  return float4(rgb, 1.0);
}
