//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2025 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

#import "Renderer.h"
#import "DXRCompiler.h"
#import <simd/simd.h>
#import <metal_irconverter_runtime/metal_irconverter_runtime.h>

// Upper bound on frames whose command buffers may be in flight at once. It seeds the semaphore that
// drawInMTKView: waits on before encoding and that each command buffer's completion handler signals.
static const NSUInteger kMaxBuffersInFlight = 1;
// Fixed size of the ray traced result texture; the ray dispatch launches one thread per texel.
// NOTE(review): presumably 2x an 800x600 base resolution - confirm the intent of the "* 2".
static const NSUInteger kTextureWidth = 800 * 2;
static const NSUInteger kTextureHeight = 600 * 2;

// We'll be drawing triangles in a 3 x 3 grid. So a total of 9 triangles.
// Each triangle is one instance in the top-level acceleration structure and owns one hit group record.
static const uint32_t kNumTrianglesX = 3;
static const uint32_t kNumTrianglesY = 3;
static const uint32_t kNumTriangles = kNumTrianglesX * kNumTrianglesY;

// The Shader Binding Table layout (ShaderTable).
// One ShaderTable lives in a single shared-storage MTLBuffer. Ray dispatch locates each sub-table
// by adding offsetof(ShaderTable, <field>) to the buffer's GPU address, so the GPU reads this exact
// memory layout: reordering or resizing fields changes what the shaders see.

// Record with no local root signature data (used for ray generation and miss).
typedef struct ShaderRecord {
    IRShaderIdentifier shaderIdentifier;
} ShaderRecord;

// Hit group record: shader identifier followed by this record's local root signature data.
typedef struct ShaderRecordTriangle {
    IRShaderIdentifier shaderIdentifier;
    int32_t numHoles; // local root signature data
} ShaderRecordTriangle;

// Complete table. The hit group table stride is sizeof(ShaderRecordTriangle), and record i is filled
// with numHoles = i + 1 when the table is created.
typedef struct ShaderTable {
    ShaderRecord rayGenRecord;
    ShaderRecord missRecord;
    ShaderRecordTriangle hitGroupRecords[kNumTriangles] __attribute__((aligned(64))); // The start of the hitgroup table must be 64-byte aligned
} ShaderTable;

static NSString *MTLFunctionTypeAsString(MTLFunctionType functionType);

@implementation Renderer
{
    // Bounds CPU run-ahead: waited on at the start of each frame, signaled when a frame's command buffer completes.
    dispatch_semaphore_t _inFlightSemaphore;
    // Device taken from the MTKView at init; every resource below is created on it.
    id<MTLDevice> _device;
    // Used for the one-time acceleration structure builds and for every frame's command buffer.
    id<MTLCommandQueue> _commandQueue;

    // Ray tracing (compute)
    // Compute pipeline and function tables produced by converting the DXR pipeline.
    id<MTLComputePipelineState> _raytracingPipelineState;
    id<MTLVisibleFunctionTable> _visibleFunctionTable;
    // May be nil; drawInMTKView: only marks it resident when present.
    id<MTLIntersectionFunctionTable> _intersectionFunctionTable;
    // Shared-storage buffer holding one ShaderTable (ray gen, miss and per-instance hit group records).
    id<MTLBuffer> _shaderTable;
    // Instance (top-level) structure with kNumTriangles instances of the single primitive (bottom-level) structure.
    id<MTLAccelerationStructure> _accelerationStructureTLAS;
    NSArray<id<MTLAccelerationStructure>> *_accelerationStructureBLAS;
    // kTextureWidth x kTextureHeight private texture written by the ray dispatch and sampled when presenting.
    id<MTLTexture> _resultTexture;
    
    // Present (render)
    // Full-screen quad pipeline plus its geometry and the sampler used to read _resultTexture.
    id<MTLRenderPipelineState> _presentPipeline;
    id<MTLSamplerState> _sampler;
    id<MTLBuffer> _vertexBuffer;
    id<MTLBuffer> _indexBuffer;
    NSUInteger _indexCount;
    
    // Shader conversion runtime
    // Descriptor data and top-level argument buffers (TLABs) that the converted shaders read resources through.
    id<MTLBuffer> _asSRVBuffer;
    id<MTLBuffer> _textureUAVBuffer;
    id<MTLBuffer> _textureSRVBuffer;
    id<MTLBuffer> _samplerSMPBuffer;
    id<MTLBuffer> _computeTLAB;
    id<MTLBuffer> _renderTLAB;
    
    // Shader conversion compiler
    // Converts the DXIL ray tracing library into Metal pipeline objects at startup.
    DXRCompiler *_dxrCompiler;
}

/// Creates the renderer and all GPU state it needs to draw into `view`.
///
/// Setup order matters: the shader converter runtime bindings reference the acceleration structure
/// and result texture created by the ray tracing setup, and the sampler created by the present setup.
///
/// @param view The MetalKit view to render into. The renderer adopts its device and sets its color
///             pixel format to the one the present pipeline is built for.
- (nonnull instancetype)initWithMetalKitView:(nonnull MTKView *)view
{
    self = [super init];
    if (self)
    {
        _device = view.device;

        // The present pipeline's color attachment is built as BGRA8Unorm_sRGB, and a render pass target
        // must match its pipeline's attachment format. MTKView's documented default is BGRA8Unorm,
        // so make the view agree with the pipeline instead of relying on the caller to do it.
        view.colorPixelFormat = MTLPixelFormatBGRA8Unorm_sRGB;

        _dxrCompiler = [[DXRCompiler alloc] initWithDevice:_device];
        _inFlightSemaphore = dispatch_semaphore_create(kMaxBuffersInFlight);
        _commandQueue = [_device newCommandQueue];

        [self _createRaytracingComputePipelineWithResources];
        [self _createPresentRenderPipelineWithResources];
        [self _createShaderConverterRuntimeBindings];
    }

    return self;
}

/// Creates the resources used to present the ray traced image: a full-screen quad (vertex and index
/// buffers), the sampler that reads the result texture, and the render pipeline loaded from the
/// precompiled "present.metallib" produced by metal-shaderconverter.
///
/// Aborts the process if the metallib is missing, lacks MainVS/MainFS, or the pipeline fails to build.
- (void)_createPresentRenderPipelineWithResources
{
    // Interleaved per-vertex layout; the vertex descriptor below mirrors it.
    typedef struct VertexData
    {
        simd_float4 position;
        simd_float2 texcoord;
    } VertexData;

    //
    // Create vertex data for present image: a quad spanning clip space [-1, 1] with texture
    // coordinates covering [0, 1], drawn as two indexed triangles.
    //
    {
        VertexData vertexData[] = {
            { .position = simd_make_float4(-1.0,  1.0, 0.0, 1.0), .texcoord = simd_make_float2(0.0, 0.0) },
            { .position = simd_make_float4(-1.0, -1.0, 0.0, 1.0), .texcoord = simd_make_float2(0.0, 1.0) },
            { .position = simd_make_float4( 1.0, -1.0, 0.0, 1.0), .texcoord = simd_make_float2(1.0, 1.0) },
            { .position = simd_make_float4( 1.0,  1.0, 0.0, 1.0), .texcoord = simd_make_float2(1.0, 0.0) },
        };
        uint16_t indexData[] = { 0, 1, 2, 2, 3, 0 };

        _vertexBuffer = [_device newBufferWithBytes:vertexData length:sizeof(vertexData) options:MTLResourceStorageModeShared];
        _indexBuffer = [_device newBufferWithBytes:indexData length:sizeof(indexData) options:MTLResourceStorageModeShared];
        _indexCount = sizeof(indexData) / sizeof(indexData[0]);
    }

    //
    // Create sampler
    //

    MTLSamplerDescriptor *samplerDescriptor = [[MTLSamplerDescriptor alloc] init];
    {
        // Required so the sampler can be encoded into the shader converter's sampler descriptor table.
        samplerDescriptor.supportArgumentBuffers = YES;
        samplerDescriptor.magFilter = MTLSamplerMinMagFilterLinear;
        samplerDescriptor.minFilter = MTLSamplerMinMagFilterLinear;
        // For a 2D texture, s addresses the width and t the height; r only applies to 3D depth.
        // Set t explicitly, otherwise the vertical axis silently keeps the default address mode.
        samplerDescriptor.sAddressMode = MTLSamplerAddressModeRepeat;
        samplerDescriptor.tAddressMode = MTLSamplerAddressModeRepeat;
        samplerDescriptor.rAddressMode = MTLSamplerAddressModeRepeat;
    }
    _sampler = [_device newSamplerStateWithDescriptor:samplerDescriptor];

    //
    // Create render pipeline for presenting raytraced image
    //

    // Load precompiled metallib created with metal-shaderconverter
    NSError *error = nil;
    NSURL *URL = [[NSBundle mainBundle] URLForResource:@"present" withExtension:@"metallib"];
    if (URL == nil)
    {
        NSLog(@"Check build logs - metal-shaderconverter possibly not found");
        abort();
    }

    id<MTLLibrary> library = [_device newLibraryWithURL:URL error:&error];
    if (library == nil)
    {
        NSLog(@"Failed to create library: %@", error);
        abort();
    }
    library.label = @"Present";

    // Converted shaders read stage-in attributes and the vertex buffer at fixed runtime slots.
    NSUInteger stageInAttributePosition = kIRStageInAttributeStartIndex;
    NSUInteger stageInAttributeTexcoord = kIRStageInAttributeStartIndex + 1;
    NSUInteger vertexBufferIndex = kIRVertexBufferBindPoint;

    MTLVertexDescriptor *vertexDescriptor = [[MTLVertexDescriptor alloc] init];
    vertexDescriptor.attributes[stageInAttributePosition].format = MTLVertexFormatFloat4;
    vertexDescriptor.attributes[stageInAttributePosition].offset = 0;
    vertexDescriptor.attributes[stageInAttributePosition].bufferIndex = vertexBufferIndex;
    vertexDescriptor.attributes[stageInAttributeTexcoord].format = MTLVertexFormatFloat2;
    vertexDescriptor.attributes[stageInAttributeTexcoord].offset = offsetof(VertexData, texcoord);
    vertexDescriptor.attributes[stageInAttributeTexcoord].bufferIndex = vertexBufferIndex;
    vertexDescriptor.layouts[vertexBufferIndex].stride = sizeof(VertexData);
    vertexDescriptor.layouts[vertexBufferIndex].stepRate = 1;
    vertexDescriptor.layouts[vertexBufferIndex].stepFunction = MTLVertexStepFunctionPerVertex;

    id<MTLFunction> vertexFunction = [library newFunctionWithName:@"MainVS"];
    id<MTLFunction> fragmentFunction = [library newFunctionWithName:@"MainFS"];

    // Checked explicitly rather than with assert() so release builds also fail with a clear message.
    if (vertexFunction == nil || fragmentFunction == nil)
    {
        NSLog(@"present.metallib is missing MainVS and/or MainFS");
        abort();
    }

    // The color format must match the drawable the present pass renders into.
    MTLRenderPipelineDescriptor *pipelineStateDescriptor = [[MTLRenderPipelineDescriptor alloc] init];
    pipelineStateDescriptor.vertexDescriptor = vertexDescriptor;
    pipelineStateDescriptor.vertexFunction = vertexFunction;
    pipelineStateDescriptor.fragmentFunction = fragmentFunction;
    pipelineStateDescriptor.colorAttachments[0].pixelFormat = MTLPixelFormatBGRA8Unorm_sRGB;
    pipelineStateDescriptor.colorAttachments[0].blendingEnabled = YES;
    pipelineStateDescriptor.colorAttachments[0].sourceRGBBlendFactor = MTLBlendFactorSourceAlpha;
    pipelineStateDescriptor.colorAttachments[0].sourceAlphaBlendFactor = MTLBlendFactorSourceAlpha;
    pipelineStateDescriptor.colorAttachments[0].destinationRGBBlendFactor = MTLBlendFactorOneMinusSourceAlpha;
    pipelineStateDescriptor.colorAttachments[0].destinationAlphaBlendFactor = MTLBlendFactorOneMinusSourceAlpha;

    id<MTLRenderPipelineState> pipelineState = [_device newRenderPipelineStateWithDescriptor:pipelineStateDescriptor error:&error];
    if (pipelineState == nil)
    {
        NSLog(@"Pipeline state creation failed: %@", error);
        abort();
    }

    _presentPipeline = pipelineState;
}

/// Converts the DXIL ray tracing library into a Metal compute pipeline. Also creates every resource
/// the ray dispatch consumes: the shader binding table, the bottom- and top-level acceleration
/// structures, and the result texture.
///
/// Aborts the process if a bundled input is missing or malformed, or if conversion fails.
- (void)_createRaytracingComputePipelineWithResources
{
    NSError *error = nil;

    // Load DXIL file
    NSData *DXIL = [self _dataForBundleResource:@"rt_triangle" withExtension:@"dxil" error:&error];
    if (DXIL == nil)
    {
        NSLog(@"Failed to create DXIL data: %@", error);
        abort();
    }

    // Load root signatures with the metal shader converter JSON format
    NSDictionary *globalRootSignatureDescriptor = [self _JSONDictionaryForBundleResource:@"root_sig_global" error:&error];
    if (globalRootSignatureDescriptor == nil)
    {
        NSLog(@"Failed to create global root signature data: %@", error);
        abort();
    }

    NSDictionary *localRootSignatureDescriptor = [self _JSONDictionaryForBundleResource:@"root_sig_local" error:&error];
    if (localRootSignatureDescriptor == nil)
    {
        NSLog(@"Failed to create local root signature data: %@", error);
        abort();
    }

    // Describe the DXR pipeline options
    DXRPipelineOptions *options = [[DXRPipelineOptions alloc] init];
    options.maxRecursiveDepth = 3;

    // Describe the DXR raytracing pipeline; every entry point comes from the same DXIL library.
    DXRShader *rayGenerationShader = [[DXRShader alloc] initWithName:@"MainRayGen" DXIL:DXIL];
    DXRShader *missShader = [[DXRShader alloc] initWithName:@"MainMiss" DXIL:DXIL];
    DXRHitgroup *hitGroup = [[DXRHitgroup alloc] init];
    hitGroup.anyHitShader = [[DXRShader alloc] initWithName:@"TriangleAnyHit" DXIL:DXIL];
    hitGroup.closestHitShader = [[DXRShader alloc] initWithName:@"TriangleClosestHit" DXIL:DXIL];

    // Describe the pipeline
    DXRPipelineDescriptor *dxrDescriptor = [[DXRPipelineDescriptor alloc] init];
    dxrDescriptor.rayGenerationShader = rayGenerationShader;
    dxrDescriptor.hitGroups = @[hitGroup];
    dxrDescriptor.missShaders = @[missShader];
    dxrDescriptor.options = options;
    dxrDescriptor.globalRootSignatureDescriptor = globalRootSignatureDescriptor;
    dxrDescriptor.localRootSignatureDescriptor = localRootSignatureDescriptor;

    // Now create a DXRPipelineState object converting input DXIL to Metal IR along the way
    DXRPipelineReflection *dxrPipelineReflection = nil;
    DXRPipelineState *dxrPipelineState = [_dxrCompiler newDXRPipelineWithDescriptor:dxrDescriptor reflection:&dxrPipelineReflection error:&error];
    if (dxrPipelineState == nil)
    {
        NSLog(@"Failed to create ray tracing pipeline: %@", error);
        abort();
    }

    // DXRPipeReflection will contain details about the conversion process as well as the metal compute pipeline reflection
    for (DXRShaderConversion *conversion in dxrPipelineReflection.conversions)
    {
        NSString *functionName = conversion.function.name;
        NSString *functionType = MTLFunctionTypeAsString(conversion.function.functionType);
        if (conversion.isSynthesized)
        {
            NSLog(@"Synthesized MTLFunction \'%@\' (%@)", functionName, functionType);
            continue;
        }

        NSString *shaderName = conversion.name;
        NSLog(@"Converted DXRShader \'%@\' to MTLFunction \'%@\' (%@)", shaderName, functionName, functionType);
    }

    _raytracingPipelineState = dxrPipelineState.computePipelineState;
    _visibleFunctionTable = dxrPipelineState.visibleFunctionTable;
    _intersectionFunctionTable = dxrPipelineState.intersectionFunctionTable;

    //
    // Create the Shader Binding Table for the scene
    //

    {
        // Now use the shader identifiers from the dxr pipeline state and construct the IRShaderIdentifier to assign to the shader binding table
        uint64_t rayGenShaderIdentifier = [dxrPipelineState shaderIdentifierForName:rayGenerationShader.uniqueName];
        uint64_t missShaderIdentifier = [dxrPipelineState shaderIdentifierForName:missShader.uniqueName];
        uint64_t anyHitShaderIdentifier = [dxrPipelineState shaderIdentifierForName:hitGroup.anyHitShader.uniqueName];
        uint64_t closestHitShaderIdentifier = [dxrPipelineState shaderIdentifierForName:hitGroup.closestHitShader.uniqueName];

        assert(rayGenShaderIdentifier != DXRShaderIdentifierNull);
        assert(missShaderIdentifier != DXRShaderIdentifierNull);
        assert(anyHitShaderIdentifier != DXRShaderIdentifierNull);
        assert(closestHitShaderIdentifier != DXRShaderIdentifierNull);

        id<MTLBuffer> shaderTableBuffer = [_device newBufferWithLength:sizeof(ShaderTable) options:MTLResourceStorageModeShared];
        shaderTableBuffer.label = @"Shader Binding Table";
        ShaderTable *shaderTable = (ShaderTable *)shaderTableBuffer.contents;

        // Ray gen shader record
        IRShaderIdentifierInit(&(shaderTable->rayGenRecord.shaderIdentifier), rayGenShaderIdentifier);

        // Miss shader record
        IRShaderIdentifierInit(&(shaderTable->missRecord.shaderIdentifier), missShaderIdentifier);

        // Hit group records: same shaders for every record, distinct local root signature data per record.
        for (uint32_t i = 0; i < kNumTriangles; ++i)
        {
            IRShaderIdentifierInitWithCustomIntersection(&(shaderTable->hitGroupRecords[i].shaderIdentifier), closestHitShaderIdentifier, anyHitShaderIdentifier);
            shaderTable->hitGroupRecords[i].numHoles = (int32_t)(i + 1);
        }

        _shaderTable = shaderTableBuffer;
    }

    //
    // Create the acceleration structures
    //

    {
        assert(_device != nil);
        assert(_commandQueue != nil);

        // We need to determine the intersection shader identifier for our acceleration structure
        uint64_t triangleIntersectionShaderIdentifier = [dxrPipelineState shaderIdentifierForSynthesizedTriangleIntersection];

        // Geometry descriptor for triangle
        simd_float3 vertexData[] = {
            { -1.0, -1.0, -1.5 },
            { +1.0, -1.0, -1.5 },
            {  0.0, +1.0, -1.5 }
        };
        uint16_t indexData[] = { 0, 1, 2 };

        id<MTLBuffer> vertexBuffer = [_device newBufferWithBytes:vertexData length:sizeof(vertexData) options:MTLResourceStorageModeShared];
        id<MTLBuffer> indexBuffer = [_device newBufferWithBytes:indexData length:sizeof(indexData) options:MTLResourceStorageModeShared];

        MTLAccelerationStructureTriangleGeometryDescriptor *triGeoDescriptor = [[MTLAccelerationStructureTriangleGeometryDescriptor alloc] init];
        triGeoDescriptor.vertexBuffer = vertexBuffer;
        triGeoDescriptor.vertexStride = sizeof(simd_float3);
        triGeoDescriptor.indexBuffer = indexBuffer;
        triGeoDescriptor.indexType = MTLIndexTypeUInt16;
        triGeoDescriptor.triangleCount = 1;
        triGeoDescriptor.intersectionFunctionTableOffset = triangleIntersectionShaderIdentifier; // Defaults to 0 if _intersectionFunctionTable is nil.

        // First, build the primitive acceleration structures - This corresponds to a DXR bottom-level acceleration structures (BLAS)
        MTLPrimitiveAccelerationStructureDescriptor *primitiveASDescriptor = [[MTLPrimitiveAccelerationStructureDescriptor alloc] init];
        primitiveASDescriptor.geometryDescriptors = @[triGeoDescriptor];
        id<MTLAccelerationStructure> primitiveAS = [self _buildAccelerationStructureWithDescriptor:primitiveASDescriptor];
        _accelerationStructureBLAS = @[primitiveAS];

        // Define triangle instances
        MTLAccelerationStructureInstanceDescriptor triangleInstances[kNumTriangles] = {0};

        // MTLAccelerationStructureInstanceDescriptor.intersectionFunctionTableOffset depends on the intersection function compilation mode
        // If we have handles, pipeline was compiled with IRIntersectionFunctionCompilationIntersectionFunctionBufferFunction and
        // in this case, we must set it to the instance index
        const BOOL isIFB = (dxrPipelineState.intersectionFunctionBufferHandles != nil);
        const float posRange = 3.0f;
        uint32_t instanceCount = 0;
        for (uint32_t posX = 0; posX < kNumTrianglesX; ++posX)
        {
            for (uint32_t posY = 0; posY < kNumTrianglesY; ++posY)
            {
                // Place each instance at the center of its grid cell, spread over [-posRange, +posRange].
                float x = simd_lerp(-posRange, +posRange, (posX + 0.5f) / kNumTrianglesX);
                float y = simd_lerp(+posRange, -posRange, (posY + 0.5f) / kNumTrianglesY);

                MTLPackedFloat4x3 transformationMatrix;
                transformationMatrix.columns[0] = MTLPackedFloat3Make(1, 0, 0);
                transformationMatrix.columns[1] = MTLPackedFloat3Make(0, 1, 0);
                transformationMatrix.columns[2] = MTLPackedFloat3Make(0, 0, 1);
                transformationMatrix.columns[3] = MTLPackedFloat3Make(x, y, -3);

                uint32_t instanceIndex = instanceCount++;
                MTLAccelerationStructureInstanceDescriptor *instance = &triangleInstances[instanceIndex];
                instance->accelerationStructureIndex = 0;
                instance->intersectionFunctionTableOffset = isIFB ? instanceIndex : 0;
                instance->mask = 0xFF;
                instance->options = MTLAccelerationStructureInstanceOptionDisableTriangleCulling;
                instance->transformationMatrix = transformationMatrix;
            }
        }

        // Next, build the instance acceleration structure - This corresponds to a DXR top-level acceleration structure (TLAS)
        id<MTLBuffer> instanceDescBuffer = [_device newBufferWithBytes:triangleInstances length:sizeof(triangleInstances) options:MTLResourceStorageModeShared];
        MTLInstanceAccelerationStructureDescriptor *instanceASDescriptor = [[MTLInstanceAccelerationStructureDescriptor alloc] init];
        instanceASDescriptor.instancedAccelerationStructures = _accelerationStructureBLAS;
        instanceASDescriptor.instanceCount = instanceCount;
        instanceASDescriptor.instanceDescriptorBuffer = instanceDescBuffer;

        _accelerationStructureTLAS = [self _buildAccelerationStructureWithDescriptor:instanceASDescriptor];
    }

    //
    // Finally, create the output result texture
    //

    {
        // Private storage: only the GPU accesses it. It is written by the ray dispatch and sampled by the present pass.
        MTLTextureDescriptor *textureDescriptor = [[MTLTextureDescriptor alloc] init];
        textureDescriptor.width = kTextureWidth;
        textureDescriptor.height = kTextureHeight;
        textureDescriptor.depth = 1;
        textureDescriptor.pixelFormat = MTLPixelFormatBGRA8Unorm;
        textureDescriptor.textureType = MTLTextureType2D;
        textureDescriptor.storageMode = MTLStorageModePrivate;
        // MTLTextureUsage flags; MTLResourceUsage values only happen to share the same bit values.
        textureDescriptor.usage = (MTLTextureUsageShaderRead | MTLTextureUsageShaderWrite);

        id<MTLTexture> texture = [_device newTextureWithDescriptor:textureDescriptor];
        texture.label = @"Result Texture";
        _resultTexture = texture;
    }
}

/// Reads the main-bundle resource `name.extension`, memory-mapped when safe.
///
/// @param name      Resource base name.
/// @param extension Resource file extension.
/// @param error     On failure, receives an error describing why the resource could not be read.
/// @return The resource contents, or nil on failure.
- (NSData *)_dataForBundleResource:(NSString *)name withExtension:(NSString *)extension error:(NSError **)error
{
    NSURL *URL = [[NSBundle mainBundle] URLForResource:name withExtension:extension];
    if (URL == nil)
    {
        // Report a missing resource explicitly instead of handing a nil URL to NSData.
        if (error != NULL)
        {
            NSString *description = [NSString stringWithFormat:@"Resource '%@.%@' not found in main bundle", name, extension];
            *error = [NSError errorWithDomain:NSCocoaErrorDomain
                                         code:NSFileReadNoSuchFileError
                                     userInfo:@{ NSLocalizedDescriptionKey : description }];
        }
        return nil;
    }
    return [[NSData alloc] initWithContentsOfURL:URL options:NSDataReadingMappedIfSafe error:error];
}

/// Loads the main-bundle resource `name.json` and parses it with NSJSONSerialization.
///
/// @param name  Resource base name (without the ".json" extension).
/// @param error On failure, receives the read/parse error, or a corrupt-file error if the top-level
///              JSON value is not an object.
/// @return The parsed top-level JSON object, or nil on failure.
- (NSDictionary *)_JSONDictionaryForBundleResource:(NSString *)name error:(NSError **)error
{
    NSData *data = [self _dataForBundleResource:name withExtension:@"json" error:error];
    if (data == nil)
    {
        return nil;
    }

    id object = [NSJSONSerialization JSONObjectWithData:data options:0 error:error];
    if (object != nil && ![object isKindOfClass:[NSDictionary class]])
    {
        if (error != NULL)
        {
            NSString *description = [NSString stringWithFormat:@"Resource '%@.json' is not a JSON object", name];
            *error = [NSError errorWithDomain:NSCocoaErrorDomain
                                         code:NSFileReadCorruptFileError
                                     userInfo:@{ NSLocalizedDescriptionKey : description }];
        }
        return nil;
    }
    return object;
}

/// Allocates and builds an acceleration structure for `descriptor` on _commandQueue.
/// Blocks the calling thread until the GPU has finished, and aborts if the build command buffer fails.
///
/// @param descriptor Primitive or instance acceleration structure descriptor to build.
/// @return The built acceleration structure.
- (id<MTLAccelerationStructure>)_buildAccelerationStructureWithDescriptor:(MTLAccelerationStructureDescriptor *)descriptor
{
    MTLAccelerationStructureSizes sizes = [_device accelerationStructureSizesWithDescriptor:descriptor];
    id<MTLAccelerationStructure> accelerationStructure = [_device newAccelerationStructureWithSize:sizes.accelerationStructureSize];
    id<MTLBuffer> scratchBuffer = [_device newBufferWithLength:sizes.buildScratchBufferSize options:MTLResourceStorageModePrivate];

    id<MTLCommandBuffer> commandBuffer = [_commandQueue commandBuffer];
    id<MTLAccelerationStructureCommandEncoder> encoder = [commandBuffer accelerationStructureCommandEncoder];
    [encoder buildAccelerationStructure:accelerationStructure descriptor:descriptor scratchBuffer:scratchBuffer scratchBufferOffset:0];
    [encoder endEncoding];
    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    if (commandBuffer.status == MTLCommandBufferStatusError)
    {
        NSLog(@"Acceleration structure build failed: %@", commandBuffer.error);
        abort();
    }

    return accelerationStructure;
}

/// Builds the descriptor data and top-level argument buffers (TLABs) that the converted shaders use
/// to reach their resources.
///
/// Compute TLAB: { acceleration structure SRV buffer address, result texture UAV table address }.
/// Render TLAB:  { result texture SRV table address, sampler table address }.
///
/// Requires _accelerationStructureTLAS, _resultTexture and _sampler to exist already.
- (void)_createShaderConverterRuntimeBindings
{
    // COMPUTE
    {
        assert(_accelerationStructureTLAS != nil);
        assert(_resultTexture != nil);

        // SRV buffer for acceleration structure:
        // This consists of a IRRaytracingAccelerationStructureGPUHeader for locating the acceleration structure on the GPU
        // and a list of all instance contribution indexes to help determine the closest hit.
        const size_t headerSize = sizeof(IRRaytracingAccelerationStructureGPUHeader);
        const size_t asSRVBufferSize = headerSize + (sizeof(uint32_t) * kNumTriangles);
        id<MTLBuffer> asSRVBuffer = [_device newBufferWithLength:asSRVBufferSize options:MTLResourceStorageModeShared];
        assert(asSRVBuffer != nil); // checked before its contents are written below

        // Use a byte pointer so the offset arithmetic is standard C (arithmetic on void * is a GNU extension).
        uint8_t *asSRVBytes = (uint8_t *)asSRVBuffer.contents;
        IRRaytracingAccelerationStructureGPUHeader *header = (IRRaytracingAccelerationStructureGPUHeader *)asSRVBytes;
        header->accelerationStructureID = _accelerationStructureTLAS.gpuResourceID._impl;
        header->addressOfInstanceContributions = asSRVBuffer.gpuAddress + headerSize;
        uint32_t *instanceContributions = (uint32_t *)(asSRVBytes + headerSize);
        for (uint32_t i = 0; i < kNumTriangles; ++i)
        {
            // Identity mapping: instance i contributes i.
            instanceContributions[i] = i;
        }
        _asSRVBuffer = asSRVBuffer;

        // UAV for writing to result texture
        IRDescriptorTableEntry uavEntry;
        IRDescriptorTableSetTexture(&uavEntry, _resultTexture, 0, 0);
        id<MTLBuffer> uavTableBuffer = [_device newBufferWithBytes:&uavEntry length:sizeof(uavEntry) options:MTLResourceStorageModeShared];
        assert(uavTableBuffer != nil);
        uavTableBuffer.label = @"UAV (Texture)";
        _textureUAVBuffer = uavTableBuffer;

        // Create top-level argument buffer for compute pipeline (ray tracing)
        uint64_t TLAB[2] = { _asSRVBuffer.gpuAddress, _textureUAVBuffer.gpuAddress };
        id<MTLBuffer> topLevelArgumentBuffer = [_device newBufferWithBytes:TLAB length:sizeof(TLAB) options:MTLResourceStorageModeShared];
        topLevelArgumentBuffer.label = @"TLAB (Compute)";
        _computeTLAB = topLevelArgumentBuffer;
    }

    // RENDER
    {
        assert(_resultTexture != nil);
        assert(_sampler != nil);

#if !DEBUG
        // The SRV entry would be identical to the UAV entry (same texture, same arguments), so reuse it.
        _textureSRVBuffer = _textureUAVBuffer;
#else
        // SRV for reading from result texture (separate, labeled buffer in debug builds)
        IRDescriptorTableEntry srvEntry;
        IRDescriptorTableSetTexture(&srvEntry, _resultTexture, 0, 0);
        id<MTLBuffer> srvTableBuffer = [_device newBufferWithBytes:&srvEntry length:sizeof(srvEntry) options:MTLResourceStorageModeShared];
        srvTableBuffer.label = @"SRV (Texture)";
        _textureSRVBuffer = srvTableBuffer;
#endif

        // SMP for sampling from result texture
        IRDescriptorTableEntry smpEntry;
        IRDescriptorTableSetSampler(&smpEntry, _sampler, 0);
        id<MTLBuffer> smpTableBuffer = [_device newBufferWithBytes:&smpEntry length:sizeof(smpEntry) options:MTLResourceStorageModeShared];
        smpTableBuffer.label = @"SMP (Sampler)";
        _samplerSMPBuffer = smpTableBuffer;

        // Create top-level argument buffer for render pipeline (present)
        uint64_t TLAB[2] = { _textureSRVBuffer.gpuAddress, _samplerSMPBuffer.gpuAddress };
        id<MTLBuffer> topLevelArgumentBuffer = [_device newBufferWithBytes:TLAB length:sizeof(TLAB) options:MTLResourceStorageModeShared];
        topLevelArgumentBuffer.label = @"TLAB (Render)";
        _renderTLAB = topLevelArgumentBuffer;
    }
}

/// Encodes and commits one frame. A ray tracing compute dispatch fills the result texture. Then,
/// if a drawable is available, a render pass draws a full-screen quad sampling that texture and
/// presents it.
///
/// @param view The view whose current drawable receives the frame.
- (void)drawInMTKView:(nonnull MTKView *)view
{
    // Block until fewer than kMaxBuffersInFlight frames are outstanding on the GPU.
    dispatch_semaphore_wait(_inFlightSemaphore, DISPATCH_TIME_FOREVER);

    id <MTLCommandBuffer> commandBuffer = [_commandQueue commandBuffer];
    // Capture the semaphore through a local so the completion handler does not retain self.
    dispatch_semaphore_t inFlightSemaphore = _inFlightSemaphore;
    [commandBuffer addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
        dispatch_semaphore_signal(inFlightSemaphore);
    }];

    // Ray trace
    {
        // Dispatch compute threads to match the output/result texture
        NSUInteger w = _raytracingPipelineState.threadExecutionWidth;
        NSUInteger h = _raytracingPipelineState.maxTotalThreadsPerThreadgroup / w;
        MTLSize threadsPerThreadgroup = MTLSizeMake(w, h, 1);
        MTLSize threads = MTLSizeMake(_resultTexture.width, _resultTexture.height, 1);

        // Construct the dispatch rays argument for the shader binding table.
        // Zero-initialized so any field not assigned below is not uninitialized stack memory sent to the GPU.
        IRDispatchRaysArgument dispatchRaysArg = {0};
        {
            ShaderTable *shaderTable = (ShaderTable *)_shaderTable.contents;
            uint64_t shaderTableGPUAddress = _shaderTable.gpuAddress;

            dispatchRaysArg.VisibleFunctionTable = _visibleFunctionTable.gpuResourceID;
            dispatchRaysArg.IntersectionFunctionTable = _intersectionFunctionTable.gpuResourceID;
            dispatchRaysArg.GRS = _computeTLAB.gpuAddress;

            IRDispatchRaysDescriptor *dispatchRaysDesc = &dispatchRaysArg.DispatchRaysDesc;
            dispatchRaysDesc->RayGenerationShaderRecord.StartAddress = shaderTableGPUAddress + offsetof(ShaderTable, rayGenRecord);
            dispatchRaysDesc->RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderRecord);
            dispatchRaysDesc->HitGroupTable.StartAddress = shaderTableGPUAddress + offsetof(ShaderTable, hitGroupRecords);
            dispatchRaysDesc->HitGroupTable.SizeInBytes = sizeof(shaderTable->hitGroupRecords);
            dispatchRaysDesc->HitGroupTable.StrideInBytes = sizeof(shaderTable->hitGroupRecords[0]);
            dispatchRaysDesc->MissShaderTable.StartAddress = shaderTableGPUAddress + offsetof(ShaderTable, missRecord);
            dispatchRaysDesc->MissShaderTable.SizeInBytes = sizeof(shaderTable->missRecord);
            dispatchRaysDesc->MissShaderTable.StrideInBytes = sizeof(shaderTable->missRecord);
            dispatchRaysDesc->CallableShaderTable.StartAddress = 0;
            dispatchRaysDesc->CallableShaderTable.SizeInBytes = 0;
            dispatchRaysDesc->CallableShaderTable.StrideInBytes = 0;
            dispatchRaysDesc->Width = (uint)threads.width;
            dispatchRaysDesc->Height = (uint)threads.height;
            dispatchRaysDesc->Depth = (uint)threads.depth;
        }

        id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
        [computeEncoder setComputePipelineState:_raytracingPipelineState];
        // These resources are reached only indirectly (GPU addresses and resource IDs inside buffers),
        // so each must be made resident explicitly.
        [computeEncoder useResource:_resultTexture usage:MTLResourceUsageWrite];
        [computeEncoder useResource:_textureUAVBuffer usage:MTLResourceUsageRead];
        [computeEncoder useResource:_asSRVBuffer usage:MTLResourceUsageRead];
        [computeEncoder useResource:_computeTLAB usage:MTLResourceUsageRead];
        [computeEncoder useResource:_accelerationStructureTLAS usage:MTLResourceUsageRead];
        // The primitive structures the instance structure references must be resident as well.
        for (id<MTLAccelerationStructure> primitiveAS in _accelerationStructureBLAS)
        {
            [computeEncoder useResource:primitiveAS usage:MTLResourceUsageRead];
        }
        [computeEncoder useResource:_shaderTable usage:MTLResourceUsageRead];
        if (_visibleFunctionTable != nil)
        {
            [computeEncoder useResource:_visibleFunctionTable usage:MTLResourceUsageRead];
        }
        if (_intersectionFunctionTable != nil)
        {
            [computeEncoder useResource:_intersectionFunctionTable usage:MTLResourceUsageRead];
        }
        [computeEncoder setBytes:&dispatchRaysArg length:sizeof(dispatchRaysArg) atIndex:kIRRayDispatchArgumentsBindPoint];
        [computeEncoder dispatchThreads:threads threadsPerThreadgroup:threadsPerThreadgroup];
        [computeEncoder endEncoding];
    }

    // Present result. currentRenderPassDescriptor is nil when no drawable is available. In that case,
    // skip presenting but still commit, so the ray tracing work runs and the semaphore is signaled.
    MTLRenderPassDescriptor *renderPassDescriptor = view.currentRenderPassDescriptor;
    if (renderPassDescriptor != nil)
    {
        id<MTLRenderCommandEncoder> renderEncoder = [commandBuffer renderCommandEncoderWithDescriptor:renderPassDescriptor];
        [renderEncoder setRenderPipelineState:_presentPipeline];
        [renderEncoder setVertexBuffer:_vertexBuffer offset:0 atIndex:kIRVertexBufferBindPoint];
        [renderEncoder setFragmentBuffer:_renderTLAB offset:0 atIndex:kIRArgumentBufferBindPoint];
        [renderEncoder useResource:_resultTexture usage:MTLResourceUsageRead stages:MTLRenderStageFragment];
        [renderEncoder useResource:_textureSRVBuffer usage:MTLResourceUsageRead stages:MTLRenderStageFragment];
        [renderEncoder useResource:_samplerSMPBuffer usage:MTLResourceUsageRead stages:MTLRenderStageFragment];
        [renderEncoder drawIndexedPrimitives:MTLPrimitiveTypeTriangle
                                  indexCount:_indexCount
                                   indexType:MTLIndexTypeUInt16
                                 indexBuffer:_indexBuffer
                           indexBufferOffset:0];
        [renderEncoder endEncoding];

        [commandBuffer presentDrawable:view.currentDrawable];
    }

    [commandBuffer commit];
}

- (void)mtkView:(nonnull MTKView *)view drawableSizeWillChange:(CGSize)size
{
    // Intentionally empty. The ray traced result texture keeps its fixed kTextureWidth x kTextureHeight
    // size, and the present pass stretches it over the whole drawable with a quad whose texture
    // coordinates span [0, 1]. Nothing has to be rebuilt when the drawable is resized.
}

@end

/// Returns a short human-readable name for `functionType`, used when logging shader conversions.
/// Any value without a case below is reported as "Unknown".
static NSString *MTLFunctionTypeAsString(MTLFunctionType functionType)
{
    NSString *typeName = @"Unknown";
    switch (functionType)
    {
        case MTLFunctionTypeVertex:       typeName = @"Vertex";       break;
        case MTLFunctionTypeFragment:     typeName = @"Fragment";     break;
        case MTLFunctionTypeKernel:       typeName = @"Kernel";       break;
        case MTLFunctionTypeVisible:      typeName = @"Visible";      break;
        case MTLFunctionTypeIntersection: typeName = @"Intersection"; break;
        case MTLFunctionTypeMesh:         typeName = @"Mesh";         break;
        case MTLFunctionTypeObject:       typeName = @"Object";       break;
        default:                                                      break;
    }
    return typeName;
}

