Skip to content

Commit

Permalink
Support for SRVs
Browse files Browse the repository at this point in the history
This is mostly cargo-culting the logic from UAVs and adjusting for the
simpler read-only case of SRVs.

TODO: vulkan support
  • Loading branch information
bogner committed Jan 21, 2025
1 parent 0b8a107 commit 2fad19b
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 19 deletions.
92 changes: 82 additions & 10 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class DXDevice : public offloadtest::Device {
CComPtr<ID3D12Device> Device;
Capabilities Caps;

struct UAVResourceSet {
struct ResourceSet {
CComPtr<ID3D12Resource> Upload;
CComPtr<ID3D12Resource> Buffer;
CComPtr<ID3D12Resource> Readback;
Expand All @@ -89,7 +89,7 @@ class DXDevice : public offloadtest::Device {
CComPtr<ID3D12GraphicsCommandList> CmdList;
CComPtr<ID3D12Fence> Fence;
HANDLE Event;
llvm::SmallVector<UAVResourceSet> Resources;
llvm::SmallVector<ResourceSet> Resources;
};

public:
Expand Down Expand Up @@ -288,8 +288,79 @@ class DXDevice : public offloadtest::Device {

llvm::Error createSRV(Resource &R, InvocationState &IS,
const uint32_t HeapIdx) {
return llvm::createStringError(std::errc::not_supported,
"DXDevice::createSRV not supported.");
llvm::outs() << "Creating SRV: { Size = " << R.Size << ", Register = t"
<< R.DXBinding.Register << ", Space = " << R.DXBinding.Space
<< " }\n";
CComPtr<ID3D12Resource> Buffer;
CComPtr<ID3D12Resource> UploadBuffer;

const D3D12_HEAP_PROPERTIES HeapProp =
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
const D3D12_RESOURCE_DESC ResDesc = {
D3D12_RESOURCE_DIMENSION_BUFFER,
0,
R.Size,
1,
1,
1,
DXGI_FORMAT_UNKNOWN,
{1, 0},
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS};

if (auto Err = HR::toError(Device->CreateCommittedResource(
&HeapProp, D3D12_HEAP_FLAG_NONE, &ResDesc,
D3D12_RESOURCE_STATE_COMMON, nullptr,
IID_PPV_ARGS(&Buffer)),
"Failed to create committed resource (buffer)."))
return Err;

const D3D12_HEAP_PROPERTIES UploadHeapProp =
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
const D3D12_RESOURCE_DESC UploadResDesc =
CD3DX12_RESOURCE_DESC::Buffer(R.Size);

if (auto Err =
HR::toError(Device->CreateCommittedResource(
&UploadHeapProp, D3D12_HEAP_FLAG_NONE,
&UploadResDesc, D3D12_RESOURCE_STATE_GENERIC_READ,
nullptr, IID_PPV_ARGS(&UploadBuffer)),
"Failed to create committed resource (upload buffer)."))
return Err;

// Initialize the SRV data
void *ResDataPtr = nullptr;
if (auto Err = HR::toError(UploadBuffer->Map(0, nullptr, &ResDataPtr),
"Failed to acquire UAV data pointer."))
return Err;
memcpy(ResDataPtr, R.Data.get(), R.Size);
UploadBuffer->Unmap(0, nullptr);

addResourceUploadCommands(R, IS, Buffer, UploadBuffer);

const uint32_t EltSize = R.getElementSize();
const uint32_t NumElts = R.Size / EltSize;
DXGI_FORMAT EltFormat =
R.isRaw() ? DXGI_FORMAT_UNKNOWN : getDXFormat(R.Format, R.Channels);
const D3D12_SHADER_RESOURCE_VIEW_DESC SRVDesc = {
EltFormat,
D3D12_SRV_DIMENSION_BUFFER,
D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
{D3D12_BUFFER_SRV{0, NumElts, static_cast<uint32_t>(R.RawSize),
D3D12_BUFFER_SRV_FLAG_NONE}}};

llvm::outs() << "SRV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
<< " NumElts = " << NumElts << "\n";
D3D12_CPU_DESCRIPTOR_HANDLE SRVHandle =
IS.DescHeap->GetCPUDescriptorHandleForHeapStart();
SRVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize(
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
Device->CreateShaderResourceView(Buffer, &SRVDesc, SRVHandle);

ResourceSet Resources = {UploadBuffer, Buffer, nullptr};
IS.Resources.push_back(Resources);

return llvm::Error::success();
}

llvm::Error createUAV(Resource &R, InvocationState &IS,
Expand Down Expand Up @@ -385,7 +456,7 @@ class DXDevice : public offloadtest::Device {
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
Device->CreateUnorderedAccessView(Buffer, nullptr, &UAVDesc, UAVHandle);

UAVResourceSet Resources = {UploadBuffer, Buffer, ReadBackBuffer};
ResourceSet Resources = {UploadBuffer, Buffer, ReadBackBuffer};
IS.Resources.push_back(Resources);

return llvm::Error::success();
Expand Down Expand Up @@ -519,11 +590,12 @@ class DXDevice : public offloadtest::Device {
IS.CmdList->Dispatch(P.DispatchSize[0], P.DispatchSize[1],
P.DispatchSize[2]);

for (auto &Out : IS.Resources) {
addReadbackBeginBarrier(IS, Out.Buffer);
IS.CmdList->CopyResource(Out.Readback, Out.Buffer);
addReadbackEndBarrier(IS, Out.Buffer);
}
for (auto &Out : IS.Resources)
if (Out.Readback != nullptr) {
addReadbackBeginBarrier(IS, Out.Buffer);
IS.CmdList->CopyResource(Out.Readback, Out.Buffer);
addReadbackEndBarrier(IS, Out.Buffer);
}
}

llvm::Error readBack(Pipeline &P, InvocationState &IS) {
Expand Down
43 changes: 38 additions & 5 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class MTLDevice : public offloadtest::Device {

if (R.isRaw()) {
MTL::Buffer *Buf =
Device->newBuffer(R.Data.get(), R.Size, MTL::StorageModeManaged);
Device->newBuffer(R.Data.get(), R.Size,
MTL::ResourceStorageModeManaged);
IRBufferView View = {};
View.buffer = Buf;
View.bufferSize = R.Size;
Expand Down Expand Up @@ -137,8 +138,33 @@ class MTLDevice : public offloadtest::Device {

llvm::Error createSRV(Resource &R, InvocationState &IS,
const uint32_t HeapIdx) {
return llvm::createStringError(std::errc::not_supported,
"MTLDevice::createSRV not supported.");
auto *TablePtr = (IRDescriptorTableEntry *)IS.ArgBuffer->contents();

if (R.isRaw()) {
MTL::Buffer *Buf = Device->newBuffer(R.Data.get(), R.Size,
MTL::ResourceStorageModeManaged);
IRBufferView View = {};
View.buffer = Buf;
View.bufferSize = R.Size;

IRDescriptorTableSetBufferView(&TablePtr[HeapIdx], &View);
IS.Buffers.push_back(Buf);
} else {
uint64_t Width = R.Size / R.getElementSize();
MTL::TextureDescriptor *Desc =
MTL::TextureDescriptor::textureBufferDescriptor(
getMTLFormat(R.Format, R.Channels), Width,
MTL::ResourceStorageModeManaged, MTL::ResourceUsageRead);

MTL::Texture *NewTex = Device->newTexture(Desc);
NewTex->replaceRegion(MTL::Region(0, 0, Width, 1), 0, R.Data.get(), 0);

IS.Textures.push_back(NewTex);

IRDescriptorTableSetTexture(&TablePtr[HeapIdx], NewTex, 0, 0);
}

return llvm::Error::success();
}

llvm::Error createCBV(Resource &R, InvocationState &IS,
Expand Down Expand Up @@ -223,9 +249,16 @@ class MTLDevice : public offloadtest::Device {
break;
}
case DataAccess::ReadOnly:
// Nothing to copy back, just increment the appropriate index.
if (R.isRaw())
++BufferIndex;
else
++TextureIndex;
break;
case DataAccess::Constant:
return llvm::createStringError(std::errc::not_supported,
"MTLDevice only supports ReadWrite.");
return llvm::createStringError(
std::errc::not_supported, "MTLDevice does not support Constant.");
break;
}
}
}
Expand Down
21 changes: 17 additions & 4 deletions test/Basic/StructuredBuffer.test
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ struct S2 {
int4 i;
};

RWStructuredBuffer<S1> In : register(u0);
RWStructuredBuffer<S2> Out : register(u1);
StructuredBuffer<S1> In : register(t0);
RWStructuredBuffer<S2> Out : register(u0);

[numthreads(1,1,1)]
void main(uint GI : SV_GroupIndex) {
Expand All @@ -21,7 +21,7 @@ void main(uint GI : SV_GroupIndex) {
DispatchSize: [1, 1, 1]
DescriptorSets:
- Resources:
- Access: ReadWrite
- Access: ReadOnly
Format: Hex32
RawSize: 32
Data: [0x00000000, 0x00000001, 0x00000002, 0x00000003,
Expand All @@ -34,7 +34,7 @@ DescriptorSets:
RawSize: 32
ZeroInitSize: 32
DirectXBinding:
Register: 1
Register: 0
Space: 0
...
#--- end
Expand All @@ -50,6 +50,19 @@ DescriptorSets:
# RUN: %if Metal %{ metal-shaderconverter %t.dxil -o=%t.metallib %}
# RUN: %if Metal %{ %offloader %t/pipeline.yaml %t.metallib | FileCheck %s %}

# CHECK: Access: ReadOnly
# CHECK: Data: [
# CHECK: 0x0,
# CHECK: 0x1,
# CHECK: 0x2,
# CHECK: 0x3,
# CHECK: 0x0,
# CHECK: 0x3F800000,
# CHECK: 0x40000000,
# CHECK: 0x40400000
# CHECK: ]

# CHECK: Access: ReadWrite
# CHECK: Data: [
# CHECK: 0x0,
# CHECK: 0x3F800000,
Expand Down

0 comments on commit 2fad19b

Please sign in to comment.