@@ -115,19 +115,19 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr &Context,
115
115
}
116
116
117
117
// TODO replace this with a new UR API function
118
- static bool isDeviceBinaryTypeSupported (const context &C ,
118
+ static bool isDeviceBinaryTypeSupported (const ContextImplPtr &ContextImpl ,
119
119
ur::DeviceBinaryType Format) {
120
120
// All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
121
121
if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
122
122
return true ;
123
123
124
- const backend ContextBackend = detail::getSyclObjImpl (C) ->getBackend ();
124
+ const backend ContextBackend = ContextImpl ->getBackend ();
125
125
126
126
// The CUDA backend cannot use SPIR-V
127
127
if (ContextBackend == backend::ext_oneapi_cuda)
128
128
return false ;
129
129
130
- std::vector<device> Devices = C. get_devices ();
130
+ const std::vector<device> & Devices = ContextImpl-> getDevices ();
131
131
132
132
// Program type is SPIR-V, so we need a device compiler to do JIT.
133
133
for (const device &D : Devices) {
@@ -137,7 +137,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
137
137
138
138
// OpenCL 2.1 and greater require clCreateProgramWithIL
139
139
if (ContextBackend == backend::opencl) {
140
- std::string ver = C.get_platform ().get_info <info::platform::version>();
140
+ std::string ver = ContextImpl->get_info <info::context::platform>()
141
+ .get_info <info::platform::version>();
141
142
if (ver.find (" OpenCL 1.0" ) == std::string::npos &&
142
143
ver.find (" OpenCL 1.1" ) == std::string::npos &&
143
144
ver.find (" OpenCL 1.2" ) == std::string::npos &&
@@ -187,16 +188,15 @@ static bool isDeviceBinaryTypeSupported(const context &C,
187
188
188
189
ur_program_handle_t
189
190
ProgramManager::createURProgram (const RTDeviceBinaryImage &Img,
190
- const context &Context ,
191
+ const ContextImplPtr &ContextImpl ,
191
192
const std::vector<device> &Devices) {
192
193
if constexpr (DbgProgMgr > 0 ) {
193
194
std::vector<ur_device_handle_t > URDevices;
194
195
std::transform (
195
196
Devices.begin (), Devices.end (), std::back_inserter (URDevices),
196
197
[](const device &Dev) { return getSyclObjImpl (Dev)->getHandleRef (); });
197
198
std::cerr << " >>> ProgramManager::createPIProgram(" << &Img << " , "
198
- << getSyclObjImpl (Context).get () << " , " << VecToString (URDevices)
199
- << " )\n " ;
199
+ << ContextImpl.get () << " , " << VecToString (URDevices) << " )\n " ;
200
200
}
201
201
const sycl_device_binary_struct &RawImg = Img.getRawData ();
202
202
@@ -224,7 +224,7 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
224
224
// sycl::detail::pi::PiDeviceBinaryType Format = Img->Format;
225
225
// assert(Format != SYCL_DEVICE_BINARY_TYPE_NONE && "Image format not set");
226
226
227
- if (!isDeviceBinaryTypeSupported (Context , Format))
227
+ if (!isDeviceBinaryTypeSupported (ContextImpl , Format))
228
228
throw sycl::exception (
229
229
sycl::errc::feature_not_supported,
230
230
" SPIR-V online compilation is not supported in this context" );
@@ -233,23 +233,22 @@ ProgramManager::createURProgram(const RTDeviceBinaryImage &Img,
233
233
const auto &ProgMetadata = Img.getProgramMetadataUR ();
234
234
235
235
// Load the image
236
- const ContextImplPtr &Ctx = getSyclObjImpl (Context);
237
236
std::vector<const uint8_t *> Binaries (
238
237
Devices.size (), const_cast <uint8_t *>(RawImg.BinaryStart ));
239
238
std::vector<size_t > Lengths (Devices.size (), ImgSize);
240
239
ur_program_handle_t Res =
241
240
Format == SYCL_DEVICE_BINARY_TYPE_SPIRV
242
- ? createSpirvProgram (Ctx , RawImg.BinaryStart , ImgSize)
243
- : createBinaryProgram (Ctx , Devices, Binaries. data (), Lengths .data (),
244
- ProgMetadata);
241
+ ? createSpirvProgram (ContextImpl , RawImg.BinaryStart , ImgSize)
242
+ : createBinaryProgram (ContextImpl , Devices, Binaries.data (),
243
+ Lengths. data (), ProgMetadata);
245
244
246
245
{
247
246
std::lock_guard<std::mutex> Lock (MNativeProgramsMutex);
248
247
// associate the UR program with the image it was created for
249
- NativePrograms.insert ({Res, {Ctx , &Img}});
248
+ NativePrograms.insert ({Res, {ContextImpl , &Img}});
250
249
}
251
250
252
- Ctx ->addDeviceGlobalInitializer (Res, Devices, &Img);
251
+ ContextImpl ->addDeviceGlobalInitializer (Res, Devices, &Img);
253
252
254
253
if constexpr (DbgProgMgr > 1 )
255
254
std::cerr << " created program: " << Res
@@ -518,7 +517,7 @@ static void applyOptionsFromEnvironment(std::string &CompileOpts,
518
517
std::pair<ur_program_handle_t , bool > ProgramManager::getOrCreateURProgram (
519
518
const RTDeviceBinaryImage &MainImg,
520
519
const std::vector<const RTDeviceBinaryImage *> &AllImages,
521
- const context &Context , const std::vector<device> &Devices,
520
+ const ContextImplPtr &ContextImpl , const std::vector<device> &Devices,
522
521
const std::string &CompileAndLinkOptions, SerializedObj SpecConsts) {
523
522
ur_program_handle_t NativePrg;
524
523
@@ -540,11 +539,10 @@ std::pair<ur_program_handle_t, bool> ProgramManager::getOrCreateURProgram(
540
539
ProgMetadataVector.insert (ProgMetadataVector.end (),
541
540
ImgProgMetadata.begin (), ImgProgMetadata.end ());
542
541
}
543
- NativePrg =
544
- createBinaryProgram (getSyclObjImpl (Context), Devices, BinPtrs.data (),
545
- Lengths.data (), ProgMetadataVector);
542
+ NativePrg = createBinaryProgram (ContextImpl, Devices, BinPtrs.data (),
543
+ Lengths.data (), ProgMetadataVector);
546
544
} else {
547
- NativePrg = createURProgram (MainImg, Context , Devices);
545
+ NativePrg = createURProgram (MainImg, ContextImpl , Devices);
548
546
}
549
547
return {NativePrg, Binaries.size ()};
550
548
}
@@ -857,10 +855,10 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
857
855
sizeof (ur_bool_t ), &MustBuildOnSubdevice, nullptr );
858
856
}
859
857
860
- auto Context = createSyclObjFromImpl<context>(ContextImpl);
861
858
auto Device = createSyclObjFromImpl<device>(
862
859
MustBuildOnSubdevice == true ? DeviceImpl : RootDevImpl);
863
- const RTDeviceBinaryImage &Img = getDeviceImage (KernelName, Context, Device);
860
+ const RTDeviceBinaryImage &Img =
861
+ getDeviceImage (KernelName, ContextImpl, Device);
864
862
865
863
// Check that device supports all aspects used by the kernel
866
864
if (auto exception = checkDevSupportDeviceRequirements (Device, Img, NDRDesc))
@@ -879,19 +877,19 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
879
877
std::copy (DeviceImagesToLink.begin (), DeviceImagesToLink.end (),
880
878
std::back_inserter (AllImages));
881
879
882
- return getBuiltURProgram (std::move (AllImages), Context, {std::move (Device)});
880
+ return getBuiltURProgram (std::move (AllImages), ContextImpl,
881
+ {std::move (Device)});
883
882
}
884
883
885
884
ur_program_handle_t ProgramManager::getBuiltURProgram (
886
- const BinImgWithDeps &ImgWithDeps, const context &Context ,
885
+ const BinImgWithDeps &ImgWithDeps, const ContextImplPtr &ContextImpl ,
887
886
const std::vector<device> &Devs, const DevImgPlainWithDeps *DevImgWithDeps,
888
887
const SerializedObj &SpecConsts) {
889
888
std::string CompileOpts;
890
889
std::string LinkOpts;
891
890
applyOptionsFromEnvironment (CompileOpts, LinkOpts);
892
- auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &Context , &Devs,
891
+ auto BuildF = [this , &ImgWithDeps, &DevImgWithDeps, &ContextImpl , &Devs,
893
892
&CompileOpts, &LinkOpts, &SpecConsts] {
894
- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
895
893
const AdapterPtr &Adapter = ContextImpl->getAdapter ();
896
894
const RTDeviceBinaryImage &MainImg = *ImgWithDeps.getMain ();
897
895
applyOptionsFromImage (CompileOpts, LinkOpts, MainImg, Devs, Adapter);
@@ -900,7 +898,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
900
898
appendLinkEnvironmentVariablesThatAppend (LinkOpts);
901
899
902
900
auto [NativePrg, DeviceCodeWasInCache] =
903
- getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), Context , Devs,
901
+ getOrCreateURProgram (MainImg, ImgWithDeps.getAll (), ContextImpl , Devs,
904
902
CompileOpts + LinkOpts, SpecConsts);
905
903
906
904
if (!DeviceCodeWasInCache && MainImg.supportsSpecConstants ()) {
@@ -940,7 +938,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
940
938
if (UseDeviceLibs)
941
939
DeviceLibReqMask |= getDeviceLibReqMask (*BinImg);
942
940
943
- ur_program_handle_t NativePrg = createURProgram (*BinImg, Context, Devs);
941
+ ur_program_handle_t NativePrg =
942
+ createURProgram (*BinImg, ContextImpl, Devs);
944
943
945
944
if (BinImg->supportsSpecConstants ()) {
946
945
enableITTAnnotationsIfNeeded (NativePrg, Adapter);
@@ -1005,7 +1004,6 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
1005
1004
auto CacheKey =
1006
1005
std::make_pair (std::make_pair (SpecConsts, ImgId), URDevicesSet);
1007
1006
1008
- const ContextImplPtr &ContextImpl = getSyclObjImpl (Context);
1009
1007
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache ();
1010
1008
auto GetCachedBuildF = [&Cache, &CacheKey]() {
1011
1009
return Cache.getOrInsertProgram (CacheKey);
@@ -1480,7 +1478,8 @@ sycl_device_binary getRawImg(RTDeviceBinaryImage *Img) {
1480
1478
template <typename StorageKey>
1481
1479
RTDeviceBinaryImage *getBinImageFromMultiMap (
1482
1480
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
1483
- const StorageKey &Key, const context &Context, const device &Device) {
1481
+ const StorageKey &Key, const ContextImplPtr &ContextImpl,
1482
+ const device &Device) {
1484
1483
auto [ItBegin, ItEnd] = ImagesSet.equal_range (Key);
1485
1484
if (ItBegin == ItEnd)
1486
1485
return nullptr ;
@@ -1510,19 +1509,20 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
1510
1509
uint32_t ImgInd = 0 ;
1511
1510
// Ask the native runtime under the given context to choose the device image
1512
1511
// it prefers.
1513
- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1512
+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1514
1513
getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
1515
1514
UrBinaries.size (), &ImgInd);
1516
1515
return DeviceFilteredImgs[ImgInd];
1517
1516
}
1518
1517
1519
1518
RTDeviceBinaryImage &
1520
1519
ProgramManager::getDeviceImage (const std::string &KernelName,
1521
- const context &Context, const device &Device) {
1520
+ const ContextImplPtr &ContextImpl,
1521
+ const device &Device) {
1522
1522
if constexpr (DbgProgMgr > 0 ) {
1523
1523
std::cerr << " >>> ProgramManager::getDeviceImage(\" " << KernelName << " \" , "
1524
- << getSyclObjImpl (Context) .get () << " , "
1525
- << getSyclObjImpl (Device). get () << " )\n " ;
1524
+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1525
+ << " )\n " ;
1526
1526
1527
1527
std::cerr << " available device images:\n " ;
1528
1528
debugPrintBinaryImages ();
@@ -1532,7 +1532,7 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
1532
1532
assert (m_SpvFileImage);
1533
1533
return getDeviceImage (
1534
1534
std::unordered_set<RTDeviceBinaryImage *>({m_SpvFileImage.get ()}),
1535
- Context , Device);
1535
+ ContextImpl , Device);
1536
1536
}
1537
1537
1538
1538
RTDeviceBinaryImage *Img = nullptr ;
@@ -1541,9 +1541,9 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
1541
1541
if (auto KernelId = m_KernelName2KernelIDs.find (KernelName);
1542
1542
KernelId != m_KernelName2KernelIDs.end ()) {
1543
1543
Img = getBinImageFromMultiMap (m_KernelIDs2BinImage, KernelId->second ,
1544
- Context , Device);
1544
+ ContextImpl , Device);
1545
1545
} else {
1546
- Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, Context ,
1546
+ Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, ContextImpl ,
1547
1547
Device);
1548
1548
}
1549
1549
}
@@ -1565,13 +1565,13 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
1565
1565
1566
1566
RTDeviceBinaryImage &ProgramManager::getDeviceImage (
1567
1567
const std::unordered_set<RTDeviceBinaryImage *> &ImageSet,
1568
- const context &Context , const device &Device) {
1568
+ const ContextImplPtr &ContextImpl , const device &Device) {
1569
1569
assert (ImageSet.size () > 0 );
1570
1570
1571
1571
if constexpr (DbgProgMgr > 0 ) {
1572
1572
std::cerr << " >>> ProgramManager::getDeviceImage(Custom SPV file "
1573
- << getSyclObjImpl (Context) .get () << " , "
1574
- << getSyclObjImpl (Device). get () << " )\n " ;
1573
+ << ContextImpl .get () << " , " << getSyclObjImpl (Device). get ()
1574
+ << " )\n " ;
1575
1575
1576
1576
std::cerr << " available device images:\n " ;
1577
1577
debugPrintBinaryImages ();
@@ -1593,7 +1593,7 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(
1593
1593
getUrDeviceTarget (RawImgs[BinaryCount]->DeviceTargetSpec );
1594
1594
}
1595
1595
1596
- getSyclObjImpl (Context) ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1596
+ ContextImpl ->getAdapter ()->call <UrApiKind::urDeviceSelectBinary>(
1597
1597
getSyclObjImpl (Device)->getHandleRef (), UrBinaries.data (),
1598
1598
UrBinaries.size (), &ImgInd);
1599
1599
@@ -2888,8 +2888,9 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
2888
2888
const AdapterPtr &Adapter =
2889
2889
getSyclObjImpl (InputImpl->get_context ())->getAdapter ();
2890
2890
2891
- ur_program_handle_t Prog = createURProgram (*InputImpl->get_bin_image_ref (),
2892
- InputImpl->get_context (), Devs);
2891
+ ur_program_handle_t Prog =
2892
+ createURProgram (*InputImpl->get_bin_image_ref (),
2893
+ getSyclObjImpl (InputImpl->get_context ()), Devs);
2893
2894
2894
2895
if (InputImpl->get_bin_image_ref ()->supportsSpecConstants ())
2895
2896
setSpecializationConstants (InputImpl, Prog, Adapter);
@@ -3097,7 +3098,8 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
3097
3098
const std::shared_ptr<device_image_impl> &MainInputImpl =
3098
3099
getSyclObjImpl (DevImgWithDeps.getMain ());
3099
3100
3100
- const context Context = MainInputImpl->get_context ();
3101
+ const context &Context = MainInputImpl->get_context ();
3102
+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
3101
3103
3102
3104
std::vector<const RTDeviceBinaryImage *> BinImgs;
3103
3105
BinImgs.reserve (DevImgWithDeps.size ());
@@ -3138,7 +3140,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
3138
3140
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge (RTCInfoPtrs);
3139
3141
3140
3142
ur_program_handle_t ResProgram = getBuiltURProgram (
3141
- std::move (BinImgs), Context , Devs, &DevImgWithDeps, SpecConstBlob);
3143
+ std::move (BinImgs), ContextImpl , Devs, &DevImgWithDeps, SpecConstBlob);
3142
3144
3143
3145
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
3144
3146
MainInputImpl->get_bin_image_ref (), Context, Devs,
@@ -3259,7 +3261,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
3259
3261
3260
3262
if constexpr (DbgProgMgr > 0 )
3261
3263
std::cerr << " >>> Adding the kernel to the cache.\n " ;
3262
- auto Program = createURProgram (Img, Context, {Device});
3264
+ const ContextImplPtr &ContextImpl = detail::getSyclObjImpl (Context);
3265
+ auto Program = createURProgram (Img, ContextImpl, {Device});
3263
3266
auto DeviceImpl = detail::getSyclObjImpl (Device);
3264
3267
auto &Adapter = DeviceImpl->getAdapter ();
3265
3268
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
@@ -3274,8 +3277,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
3274
3277
std::vector<ur_program_handle_t > ExtraProgramsToLink;
3275
3278
std::vector<ur_device_handle_t > Devs = {DeviceImpl->getHandleRef ()};
3276
3279
auto BuildProgram =
3277
- build (std::move (ProgramManaged), detail::getSyclObjImpl (Context),
3278
- CompileOpts, LinkOpts, Devs,
3280
+ build (std::move (ProgramManaged), ContextImpl, CompileOpts, LinkOpts, Devs,
3279
3281
/* For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0 ,
3280
3282
ExtraProgramsToLink);
3281
3283
ur_kernel_handle_t UrKernel{nullptr };
0 commit comments