diff --git a/lib/API/DX/Device.cpp b/lib/API/DX/Device.cpp index 64282d7..b31e066 100644 --- a/lib/API/DX/Device.cpp +++ b/lib/API/DX/Device.cpp @@ -74,7 +74,7 @@ class DXDevice : public offloadtest::Device { CComPtr Device; Capabilities Caps; - struct UAVResourceSet { + struct ResourceSet { CComPtr Upload; CComPtr Buffer; CComPtr Readback; @@ -89,7 +89,7 @@ class DXDevice : public offloadtest::Device { CComPtr CmdList; CComPtr Fence; HANDLE Event; - llvm::SmallVector Resources; + llvm::SmallVector Resources; }; public: @@ -288,8 +288,79 @@ class DXDevice : public offloadtest::Device { llvm::Error createSRV(Resource &R, InvocationState &IS, const uint32_t HeapIdx) { - return llvm::createStringError(std::errc::not_supported, - "DXDevice::createSRV not supported."); + llvm::outs() << "Creating SRV: { Size = " << R.Size << ", Register = t" + << R.DXBinding.Register << ", Space = " << R.DXBinding.Space + << " }\n"; + CComPtr Buffer; + CComPtr UploadBuffer; + + const D3D12_HEAP_PROPERTIES HeapProp = + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + const D3D12_RESOURCE_DESC ResDesc = { + D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + R.Size, + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + if (auto Err = HR::toError(Device->CreateCommittedResource( + &HeapProp, D3D12_HEAP_FLAG_NONE, &ResDesc, + D3D12_RESOURCE_STATE_COMMON, nullptr, + IID_PPV_ARGS(&Buffer)), + "Failed to create committed resource (buffer).")) + return Err; + + const D3D12_HEAP_PROPERTIES UploadHeapProp = + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); + const D3D12_RESOURCE_DESC UploadResDesc = + CD3DX12_RESOURCE_DESC::Buffer(R.Size); + + if (auto Err = + HR::toError(Device->CreateCommittedResource( + &UploadHeapProp, D3D12_HEAP_FLAG_NONE, + &UploadResDesc, D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, IID_PPV_ARGS(&UploadBuffer)), + "Failed to create committed resource (upload buffer).")) + return Err; + + // Initialize the SRV data + void *ResDataPtr = nullptr; + if (auto Err = HR::toError(UploadBuffer->Map(0, nullptr, &ResDataPtr), + "Failed to acquire UAV data pointer.")) + return Err; + memcpy(ResDataPtr, R.Data.get(), R.Size); + UploadBuffer->Unmap(0, nullptr); + + addResourceUploadCommands(R, IS, Buffer, UploadBuffer); + + const uint32_t EltSize = R.getElementSize(); + const uint32_t NumElts = R.Size / EltSize; + DXGI_FORMAT EltFormat = + R.isRaw() ? DXGI_FORMAT_UNKNOWN : getDXFormat(R.Format, R.Channels); + const D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = { + EltFormat, + D3D12_SRV_DIMENSION_BUFFER, + D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING, + {D3D12_BUFFER_SRV{0, NumElts, static_cast(R.RawSize), + D3D12_BUFFER_SRV_FLAG_NONE}}}; + + llvm::outs() << "SRV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize + << " NumElts = " << NumElts << "\n"; + D3D12_CPU_DESCRIPTOR_HANDLE SRVHandle = + IS.DescHeap->GetCPUDescriptorHandleForHeapStart(); + SRVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize( + D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); + Device->CreateShaderResourceView(Buffer, &SRVDesc, SRVHandle); + + ResourceSet Resources = {UploadBuffer, Buffer, nullptr}; + IS.Resources.push_back(Resources); + + return llvm::Error::success(); } llvm::Error createUAV(Resource &R, InvocationState &IS, @@ -385,7 +456,7 @@ class DXDevice : public offloadtest::Device { D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); Device->CreateUnorderedAccessView(Buffer, nullptr, &UAVDesc, UAVHandle); - UAVResourceSet Resources = {UploadBuffer, Buffer, ReadBackBuffer}; + ResourceSet Resources = {UploadBuffer, Buffer, ReadBackBuffer}; IS.Resources.push_back(Resources); return llvm::Error::success(); @@ -519,11 +590,12 @@ class DXDevice : public offloadtest::Device { IS.CmdList->Dispatch(P.DispatchSize[0], P.DispatchSize[1], P.DispatchSize[2]); - for (auto &Out : IS.Resources) { - addReadbackBeginBarrier(IS, Out.Buffer); - IS.CmdList->CopyResource(Out.Readback, Out.Buffer); - addReadbackEndBarrier(IS, Out.Buffer); - } + for (auto &Out : IS.Resources) + if (Out.Readback != nullptr) { + addReadbackBeginBarrier(IS, Out.Buffer); + IS.CmdList->CopyResource(Out.Readback, Out.Buffer); + addReadbackEndBarrier(IS, Out.Buffer); + } } llvm::Error readBack(Pipeline &P, InvocationState &IS) { diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 1e54394..a485712 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -109,7 +109,8 @@ class MTLDevice : public offloadtest::Device { if (R.isRaw()) { MTL::Buffer *Buf = - Device->newBuffer(R.Data.get(), R.Size, MTL::StorageModeManaged); + Device->newBuffer(R.Data.get(), R.Size, + MTL::ResourceStorageModeManaged); IRBufferView View = {}; View.buffer = Buf; View.bufferSize = R.Size; @@ -137,8 +138,33 @@ class MTLDevice : public offloadtest::Device { llvm::Error createSRV(Resource &R, InvocationState &IS, const uint32_t HeapIdx) { - return llvm::createStringError(std::errc::not_supported, - "MTLDevice::createSRV not supported."); + auto *TablePtr = (IRDescriptorTableEntry *)IS.ArgBuffer->contents(); + + if (R.isRaw()) { + MTL::Buffer *Buf = Device->newBuffer(R.Data.get(), R.Size, + MTL::ResourceStorageModeManaged); + IRBufferView View = {}; + View.buffer = Buf; + View.bufferSize = R.Size; + + IRDescriptorTableSetBufferView(&TablePtr[HeapIdx], &View); + IS.Buffers.push_back(Buf); + } else { + uint64_t Width = R.Size / R.getElementSize(); + MTL::TextureDescriptor *Desc = + MTL::TextureDescriptor::textureBufferDescriptor( + getMTLFormat(R.Format, R.Channels), Width, + MTL::ResourceStorageModeManaged, MTL::ResourceUsageRead); + + MTL::Texture *NewTex = Device->newTexture(Desc); + NewTex->replaceRegion(MTL::Region(0, 0, Width, 1), 0, R.Data.get(), 0); + + IS.Textures.push_back(NewTex); + + IRDescriptorTableSetTexture(&TablePtr[HeapIdx], NewTex, 0, 0); + } + + return llvm::Error::success(); } llvm::Error createCBV(Resource &R, InvocationState &IS, @@ -223,9 +249,16 @@ class MTLDevice : public offloadtest::Device { break; } case DataAccess::ReadOnly: + // Nothing to copy back, just increment the appropriate index. + if (R.isRaw()) + ++BufferIndex; + else + ++TextureIndex; + break; case DataAccess::Constant: - return llvm::createStringError(std::errc::not_supported, - "MTLDevice only supports ReadWrite."); + return llvm::createStringError( + std::errc::not_supported, "MTLDevice does not support Constant."); + break; } } } diff --git a/test/Basic/StructuredBuffer.test b/test/Basic/StructuredBuffer.test index c4f5c35..9fdfb8a 100644 --- a/test/Basic/StructuredBuffer.test +++ b/test/Basic/StructuredBuffer.test @@ -8,8 +8,8 @@ struct S2 { int4 i; }; -RWStructuredBuffer In : register(u0); -RWStructuredBuffer Out : register(u1); +StructuredBuffer In : register(t0); +RWStructuredBuffer Out : register(u0); [numthreads(1,1,1)] void main(uint GI : SV_GroupIndex) { @@ -21,7 +21,7 @@ void main(uint GI : SV_GroupIndex) { DispatchSize: [1, 1, 1] DescriptorSets: - Resources: - - Access: ReadWrite + - Access: ReadOnly Format: Hex32 RawSize: 32 Data: [0x00000000, 0x00000001, 0x00000002, 0x00000003, @@ -34,7 +34,7 @@ DescriptorSets: RawSize: 32 ZeroInitSize: 32 DirectXBinding: - Register: 1 + Register: 0 Space: 0 ... #--- end @@ -50,6 +50,19 @@ DescriptorSets: # RUN: %if Metal %{ metal-shaderconverter %t.dxil -o=%t.metallib %} # RUN: %if Metal %{ %offloader %t/pipeline.yaml %t.metallib | FileCheck %s %} +# CHECK: Access: ReadOnly +# CHECK: Data: [ +# CHECK: 0x0, +# CHECK: 0x1, +# CHECK: 0x2, +# CHECK: 0x3, +# CHECK: 0x0, +# CHECK: 0x3F800000, +# CHECK: 0x40000000, +# CHECK: 0x40400000 +# CHECK: ] + +# CHECK: Access: ReadWrite # CHECK: Data: [ # CHECK: 0x0, # CHECK: 0x3F800000,