Skip to content

Commit fb6c188

Browse files
authored
[SYCL] Optimize context passing in the ProgramManager (#17835)
1 parent 4f3cbba commit fb6c188

File tree

4 files changed

+59
-58
lines changed

4 files changed

+59
-58
lines changed

sycl/source/detail/helpers.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ retrieveKernelBinary(const QueueImplPtr &Queue, const char *KernelName,
6161
return {nullptr, nullptr};
6262
}
6363
auto ContextImpl = Queue->getContextImplPtr();
64-
auto Context = detail::createSyclObjFromImpl<context>(ContextImpl);
6564
auto DeviceImpl = Queue->getDeviceImplPtr();
6665
auto Device = detail::createSyclObjFromImpl<device>(DeviceImpl);
6766
ur_program_handle_t Program =
6867
detail::ProgramManager::getInstance().createURProgram(
69-
**DeviceImage, Context, {std::move(Device)});
68+
**DeviceImage, ContextImpl, {std::move(Device)});
7069
return {*DeviceImage, Program};
7170
}
7271

@@ -85,13 +84,12 @@ retrieveKernelBinary(const QueueImplPtr &Queue, const char *KernelName,
8584
Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref();
8685
} else {
8786
auto ContextImpl = Queue->getContextImplPtr();
88-
auto Context = detail::createSyclObjFromImpl<context>(ContextImpl);
8987
auto DeviceImpl = Queue->getDeviceImplPtr();
9088
auto Device = detail::createSyclObjFromImpl<device>(DeviceImpl);
9189
DeviceImage = &detail::ProgramManager::getInstance().getDeviceImage(
92-
KernelName, Context, Device);
90+
KernelName, ContextImpl, Device);
9391
Program = detail::ProgramManager::getInstance().createURProgram(
94-
*DeviceImage, Context, {std::move(Device)});
92+
*DeviceImage, ContextImpl, {std::move(Device)});
9593
}
9694
return {DeviceImage, Program};
9795
}

sycl/source/detail/memory_manager.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ getOrBuildProgramForDeviceGlobal(QueueImplPtr Queue,
12241224
auto Context = createSyclObjFromImpl<context>(ContextImpl);
12251225
ProgramManager &PM = ProgramManager::getInstance();
12261226
RTDeviceBinaryImage &Img =
1227-
PM.getDeviceImage(DeviceGlobalEntry->MImages, Context, Device);
1227+
PM.getDeviceImage(DeviceGlobalEntry->MImages, ContextImpl, Device);
12281228
device_image_plain DeviceImage =
12291229
PM.getDeviceImageFromBinaryImage(&Img, Context, Device);
12301230
device_image_plain BuiltImage =

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 49 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,19 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr &Context,
115115
}
116116

117117
// TODO replace this with a new UR API function
118-
static bool isDeviceBinaryTypeSupported(const context &C,
118+
static bool isDeviceBinaryTypeSupported(const ContextImplPtr &ContextImpl,
119119
ur::DeviceBinaryType Format) {
120120
// All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
121121
if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
122122
return true;
123123

124-
const backend ContextBackend = detail::getSyclObjImpl(C)->getBackend();
124+
const backend ContextBackend = ContextImpl->getBackend();
125125

126126
// The CUDA backend cannot use SPIR-V
127127
if (ContextBackend == backend::ext_oneapi_cuda)
128128
return false;
129129

130-
std::vector<device> Devices = C.get_devices();
130+
const std::vector<device> &Devices = ContextImpl->getDevices();
131131

132132
// Program type is SPIR-V, so we need a device compiler to do JIT.
133133
for (const device &D : Devices) {
@@ -137,7 +137,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
137137

138138
// OpenCL 2.1 and greater require clCreateProgramWithIL
139139
if (ContextBackend == backend::opencl) {
140-
std::string ver = C.get_platform().get_info<info::platform::version>();
140+
std::string ver = ContextImpl->get_info<info::context::platform>()
141+
.get_info<info::platform::version>();
141142
if (ver.find("OpenCL 1.0") == std::string::npos &&
142143
ver.find("OpenCL 1.1") == std::string::npos &&
143144
ver.find("OpenCL 1.2") == std::string::npos &&
@@ -187,16 +188,15 @@ static bool isDeviceBinaryTypeSupported(const context &C,
187188

188189
ur_program_handle_t
189190
ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
190-
const context &Context,
191+
const ContextImplPtr &ContextImpl,
191192
const std::vector<device> &Devices) {
192193
if constexpr (DbgProgMgr > 0) {
193194
std::vector<ur_device_handle_t> URDevices;
194195
std::transform(
195196
Devices.begin(), Devices.end(), std::back_inserter(URDevices),
196197
[](const device &Dev) { return getSyclObjImpl(Dev)->getHandleRef(); });
197198
std::cerr << ">>> ProgramManager::createPIProgram(" << &Img << ", "
198-
<< getSyclObjImpl(Context).get() << ", " << VecToString(URDevices)
199-
<< ")\n";
199+
<< ContextImpl.get() << ", " << VecToString(URDevices) << ")\n";
200200
}
201201
const sycl_device_binary_struct &RawImg = Img.getRawData();
202202

@@ -224,7 +224,7 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
224224
// sycl::detail::pi::PiDeviceBinaryType Format = Img->Format;
225225
// assert(Format != SYCL_DEVICE_BINARY_TYPE_NONE && "Image format not set");
226226

227-
if (!isDeviceBinaryTypeSupported(Context, Format))
227+
if (!isDeviceBinaryTypeSupported(ContextImpl, Format))
228228
throw sycl::exception(
229229
sycl::errc::feature_not_supported,
230230
"SPIR-V online compilation is not supported in this context");
@@ -233,23 +233,22 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
233233
const auto &ProgMetadata = Img.getProgramMetadataUR();
234234

235235
// Load the image
236-
const ContextImplPtr &Ctx = getSyclObjImpl(Context);
237236
std::vector<const uint8_t *> Binaries(
238237
Devices.size(), const_cast<uint8_t *>(RawImg.BinaryStart));
239238
std::vector<size_t> Lengths(Devices.size(), ImgSize);
240239
ur_program_handle_t Res =
241240
Format == SYCL_DEVICE_BINARY_TYPE_SPIRV
242-
? createSpirvProgram(Ctx, RawImg.BinaryStart, ImgSize)
243-
: createBinaryProgram(Ctx, Devices, Binaries.data(), Lengths.data(),
244-
ProgMetadata);
241+
? createSpirvProgram(ContextImpl, RawImg.BinaryStart, ImgSize)
242+
: createBinaryProgram(ContextImpl, Devices, Binaries.data(),
243+
Lengths.data(), ProgMetadata);
245244

246245
{
247246
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
248247
// associate the UR program with the image it was created for
249-
NativePrograms.insert({Res, {Ctx, &Img}});
248+
NativePrograms.insert({Res, {ContextImpl, &Img}});
250249
}
251250

252-
Ctx->addDeviceGlobalInitializer(Res, Devices, &Img);
251+
ContextImpl->addDeviceGlobalInitializer(Res, Devices, &Img);
253252

254253
if constexpr (DbgProgMgr > 1)
255254
std::cerr << "created program: " << Res
@@ -518,7 +517,7 @@ static void applyOptionsFromEnvironment(std::string &CompileOpts,
518517
std::pair<ur_program_handle_t, bool> ProgramManager::getOrCreateURProgram(
519518
const RTDeviceBinaryImage &MainImg,
520519
const std::vector<const RTDeviceBinaryImage *> &AllImages,
521-
const context &Context, const std::vector<device> &Devices,
520+
const ContextImplPtr &ContextImpl, const std::vector<device> &Devices,
522521
const std::string &CompileAndLinkOptions, SerializedObj SpecConsts) {
523522
ur_program_handle_t NativePrg;
524523

@@ -540,11 +539,10 @@ std::pair<ur_program_handle_t, bool> ProgramManager::getOrCreateURProgram(
540539
ProgMetadataVector.insert(ProgMetadataVector.end(),
541540
ImgProgMetadata.begin(), ImgProgMetadata.end());
542541
}
543-
NativePrg =
544-
createBinaryProgram(getSyclObjImpl(Context), Devices, BinPtrs.data(),
545-
Lengths.data(), ProgMetadataVector);
542+
NativePrg = createBinaryProgram(ContextImpl, Devices, BinPtrs.data(),
543+
Lengths.data(), ProgMetadataVector);
546544
} else {
547-
NativePrg = createURProgram(MainImg, Context, Devices);
545+
NativePrg = createURProgram(MainImg, ContextImpl, Devices);
548546
}
549547
return {NativePrg, Binaries.size()};
550548
}
@@ -857,10 +855,10 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
857855
sizeof(ur_bool_t), &MustBuildOnSubdevice, nullptr);
858856
}
859857

860-
auto Context = createSyclObjFromImpl<context>(ContextImpl);
861858
auto Device = createSyclObjFromImpl<device>(
862859
MustBuildOnSubdevice == true ? DeviceImpl : RootDevImpl);
863-
const RTDeviceBinaryImage &Img = getDeviceImage(KernelName, Context, Device);
860+
const RTDeviceBinaryImage &Img =
861+
getDeviceImage(KernelName, ContextImpl, Device);
864862

865863
// Check that device supports all aspects used by the kernel
866864
if (auto exception = checkDevSupportDeviceRequirements(Device, Img, NDRDesc))
@@ -879,19 +877,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
879877
std::copy(DeviceImagesToLink.begin(), DeviceImagesToLink.end(),
880878
std::back_inserter(AllImages));
881879

882-
return getBuiltURProgram(std::move(AllImages), Context, {std::move(Device)});
880+
return getBuiltURProgram(std::move(AllImages), ContextImpl,
881+
{std::move(Device)});
883882
}
884883

885884
ur_program_handle_t ProgramManager::getBuiltURProgram(
886-
const BinImgWithDeps &ImgWithDeps, const context &Context,
885+
const BinImgWithDeps &ImgWithDeps, const ContextImplPtr &ContextImpl,
887886
const std::vector<device> &Devs, const DevImgPlainWithDeps *DevImgWithDeps,
888887
const SerializedObj &SpecConsts) {
889888
std::string CompileOpts;
890889
std::string LinkOpts;
891890
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
892-
auto BuildF = [this, &ImgWithDeps, &DevImgWithDeps, &Context, &Devs,
891+
auto BuildF = [this, &ImgWithDeps, &DevImgWithDeps, &ContextImpl, &Devs,
893892
&CompileOpts, &LinkOpts, &SpecConsts] {
894-
const ContextImplPtr &ContextImpl = getSyclObjImpl(Context);
895893
const AdapterPtr &Adapter = ContextImpl->getAdapter();
896894
const RTDeviceBinaryImage &MainImg = *ImgWithDeps.getMain();
897895
applyOptionsFromImage(CompileOpts, LinkOpts, MainImg, Devs, Adapter);
@@ -900,7 +898,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
900898
appendLinkEnvironmentVariablesThatAppend(LinkOpts);
901899

902900
auto [NativePrg, DeviceCodeWasInCache] =
903-
getOrCreateURProgram(MainImg, ImgWithDeps.getAll(), Context, Devs,
901+
getOrCreateURProgram(MainImg, ImgWithDeps.getAll(), ContextImpl, Devs,
904902
CompileOpts + LinkOpts, SpecConsts);
905903

906904
if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants()) {
@@ -940,7 +938,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
940938
if (UseDeviceLibs)
941939
DeviceLibReqMask |= getDeviceLibReqMask(*BinImg);
942940

943-
ur_program_handle_t NativePrg = createURProgram(*BinImg, Context, Devs);
941+
ur_program_handle_t NativePrg =
942+
createURProgram(*BinImg, ContextImpl, Devs);
944943

945944
if (BinImg->supportsSpecConstants()) {
946945
enableITTAnnotationsIfNeeded(NativePrg, Adapter);
@@ -1005,7 +1004,6 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
10051004
auto CacheKey =
10061005
std::make_pair(std::make_pair(SpecConsts, ImgId), URDevicesSet);
10071006

1008-
const ContextImplPtr &ContextImpl = getSyclObjImpl(Context);
10091007
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache();
10101008
auto GetCachedBuildF = [&Cache, &CacheKey]() {
10111009
return Cache.getOrInsertProgram(CacheKey);
@@ -1480,7 +1478,8 @@ sycl_device_binary getRawImg(RTDeviceBinaryImage *Img) {
14801478
template <typename StorageKey>
14811479
RTDeviceBinaryImage *getBinImageFromMultiMap(
14821480
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
1483-
const StorageKey &Key, const context &Context, const device &Device) {
1481+
const StorageKey &Key, const ContextImplPtr &ContextImpl,
1482+
const device &Device) {
14841483
auto [ItBegin, ItEnd] = ImagesSet.equal_range(Key);
14851484
if (ItBegin == ItEnd)
14861485
return nullptr;
@@ -1510,19 +1509,20 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
15101509
uint32_t ImgInd = 0;
15111510
// Ask the native runtime under the given context to choose the device image
15121511
// it prefers.
1513-
getSyclObjImpl(Context)->getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
1512+
ContextImpl->getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
15141513
getSyclObjImpl(Device)->getHandleRef(), UrBinaries.data(),
15151514
UrBinaries.size(), &ImgInd);
15161515
return DeviceFilteredImgs[ImgInd];
15171516
}
15181517

15191518
RTDeviceBinaryImage &
15201519
ProgramManager::getDeviceImage(const std::string &KernelName,
1521-
const context &Context, const device &Device) {
1520+
const ContextImplPtr &ContextImpl,
1521+
const device &Device) {
15221522
if constexpr (DbgProgMgr > 0) {
15231523
std::cerr << ">>> ProgramManager::getDeviceImage(\"" << KernelName << "\", "
1524-
<< getSyclObjImpl(Context).get() << ", "
1525-
<< getSyclObjImpl(Device).get() << ")\n";
1524+
<< ContextImpl.get() << ", " << getSyclObjImpl(Device).get()
1525+
<< ")\n";
15261526

15271527
std::cerr << "available device images:\n";
15281528
debugPrintBinaryImages();
@@ -1532,7 +1532,7 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15321532
assert(m_SpvFileImage);
15331533
return getDeviceImage(
15341534
std::unordered_set<RTDeviceBinaryImage *>({m_SpvFileImage.get()}),
1535-
Context, Device);
1535+
ContextImpl, Device);
15361536
}
15371537

15381538
RTDeviceBinaryImage *Img = nullptr;
@@ -1541,9 +1541,9 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15411541
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
15421542
KernelId != m_KernelName2KernelIDs.end()) {
15431543
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
1544-
Context, Device);
1544+
ContextImpl, Device);
15451545
} else {
1546-
Img = getBinImageFromMultiMap(m_ServiceKernels, KernelName, Context,
1546+
Img = getBinImageFromMultiMap(m_ServiceKernels, KernelName, ContextImpl,
15471547
Device);
15481548
}
15491549
}
@@ -1565,13 +1565,13 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
15651565

15661566
RTDeviceBinaryImage &ProgramManager::getDeviceImage(
15671567
const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1568-
const context &Context, const device &Device) {
1568+
const ContextImplPtr &ContextImpl, const device &Device) {
15691569
assert(ImageSet.size() > 0);
15701570

15711571
if constexpr (DbgProgMgr > 0) {
15721572
std::cerr << ">>> ProgramManager::getDeviceImage(Custom SPV file "
1573-
<< getSyclObjImpl(Context).get() << ", "
1574-
<< getSyclObjImpl(Device).get() << ")\n";
1573+
<< ContextImpl.get() << ", " << getSyclObjImpl(Device).get()
1574+
<< ")\n";
15751575

15761576
std::cerr << "available device images:\n";
15771577
debugPrintBinaryImages();
@@ -1593,7 +1593,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
15931593
getUrDeviceTarget(RawImgs[BinaryCount]->DeviceTargetSpec);
15941594
}
15951595

1596-
getSyclObjImpl(Context)->getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
1596+
ContextImpl->getAdapter()->call<UrApiKind::urDeviceSelectBinary>(
15971597
getSyclObjImpl(Device)->getHandleRef(), UrBinaries.data(),
15981598
UrBinaries.size(), &ImgInd);
15991599

@@ -2888,8 +2888,9 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28882888
const AdapterPtr &Adapter =
28892889
getSyclObjImpl(InputImpl->get_context())->getAdapter();
28902890

2891-
ur_program_handle_t Prog = createURProgram(*InputImpl->get_bin_image_ref(),
2892-
InputImpl->get_context(), Devs);
2891+
ur_program_handle_t Prog =
2892+
createURProgram(*InputImpl->get_bin_image_ref(),
2893+
getSyclObjImpl(InputImpl->get_context()), Devs);
28932894

28942895
if (InputImpl->get_bin_image_ref()->supportsSpecConstants())
28952896
setSpecializationConstants(InputImpl, Prog, Adapter);
@@ -3097,7 +3098,8 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
30973098
const std::shared_ptr<device_image_impl> &MainInputImpl =
30983099
getSyclObjImpl(DevImgWithDeps.getMain());
30993100

3100-
const context Context = MainInputImpl->get_context();
3101+
const context &Context = MainInputImpl->get_context();
3102+
const ContextImplPtr &ContextImpl = detail::getSyclObjImpl(Context);
31013103

31023104
std::vector<const RTDeviceBinaryImage *> BinImgs;
31033105
BinImgs.reserve(DevImgWithDeps.size());
@@ -3138,7 +3140,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31383140
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
31393141

31403142
ur_program_handle_t ResProgram = getBuiltURProgram(
3141-
std::move(BinImgs), Context, Devs, &DevImgWithDeps, SpecConstBlob);
3143+
std::move(BinImgs), ContextImpl, Devs, &DevImgWithDeps, SpecConstBlob);
31423144

31433145
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
31443146
MainInputImpl->get_bin_image_ref(), Context, Devs,
@@ -3259,7 +3261,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
32593261

32603262
if constexpr (DbgProgMgr > 0)
32613263
std::cerr << ">>> Adding the kernel to the cache.\n";
3262-
auto Program = createURProgram(Img, Context, {Device});
3264+
const ContextImplPtr &ContextImpl = detail::getSyclObjImpl(Context);
3265+
auto Program = createURProgram(Img, ContextImpl, {Device});
32633266
auto DeviceImpl = detail::getSyclObjImpl(Device);
32643267
auto &Adapter = DeviceImpl->getAdapter();
32653268
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
@@ -3274,8 +3277,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
32743277
std::vector<ur_program_handle_t> ExtraProgramsToLink;
32753278
std::vector<ur_device_handle_t> Devs = {DeviceImpl->getHandleRef()};
32763279
auto BuildProgram =
3277-
build(std::move(ProgramManaged), detail::getSyclObjImpl(Context),
3278-
CompileOpts, LinkOpts, Devs,
3280+
build(std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts, Devs,
32793281
/*For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0,
32803282
ExtraProgramsToLink);
32813283
ur_kernel_handle_t UrKernel{nullptr};

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,15 @@ class ProgramManager {
135135
static ProgramManager &getInstance();
136136

137137
RTDeviceBinaryImage &getDeviceImage(const std::string &KernelName,
138-
const context &Context,
138+
const ContextImplPtr &ContextImpl,
139139
const device &Device);
140140

141141
RTDeviceBinaryImage &getDeviceImage(
142142
const std::unordered_set<RTDeviceBinaryImage *> &ImagesToVerify,
143-
const context &Context, const device &Device);
143+
const ContextImplPtr &ContextImpl, const device &Device);
144144

145145
ur_program_handle_t createURProgram(const RTDeviceBinaryImage &Img,
146-
const context &Context,
146+
const ContextImplPtr &ContextImpl,
147147
const std::vector<device> &Devices);
148148
/// Creates a UR program using either a cached device code binary if present
149149
/// in the persistent cache or from the supplied device image otherwise.
@@ -167,7 +167,7 @@ class ProgramManager {
167167
std::pair<ur_program_handle_t, bool> getOrCreateURProgram(
168168
const RTDeviceBinaryImage &Img,
169169
const std::vector<const RTDeviceBinaryImage *> &AllImages,
170-
const context &Context, const std::vector<device> &Devices,
170+
const ContextImplPtr &ContextImpl, const std::vector<device> &Devices,
171171
const std::string &CompileAndLinkOptions, SerializedObj SpecConsts);
172172
/// Builds or retrieves from cache a program defining the kernel with given
173173
/// name.
@@ -192,7 +192,8 @@ class ProgramManager {
192192
/// \param SpecConsts is an optional parameter containing spec constant values
193193
/// the program should be built with.
194194
ur_program_handle_t
195-
getBuiltURProgram(const BinImgWithDeps &ImgWithDeps, const context &Context,
195+
getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,
196+
const ContextImplPtr &ContextImpl,
196197
const std::vector<device> &Devs,
197198
const DevImgPlainWithDeps *DevImgWithDeps = nullptr,
198199
const SerializedObj &SpecConsts = {});

0 commit comments

Comments
 (0)