Skip to content

Commit 9aefea0

Browse files
[SYCL] Fix memory leak in reduction resources (#5653)
Reductions that require additional resources, such as buffers, can currently create a circular dependency between those resources and the commands issued by the reductions, which leaks memory. These changes break that dependency in the same way stream ownership is already handled: ownership of the resources is transferred to the commands, and the resources are released when the commands are cleaned up.
1 parent 191a62a commit 9aefea0

16 files changed

+129
-28
lines changed

sycl/include/CL/sycl/detail/cg.hpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ class CGExecKernel : public CG {
248248
std::string MKernelName;
249249
detail::OSModuleHandle MOSModuleHandle;
250250
std::vector<std::shared_ptr<detail::stream_impl>> MStreams;
251+
std::vector<std::shared_ptr<const void>> MAuxiliaryResources;
251252

252253
CGExecKernel(NDRDescT NDRDesc, std::unique_ptr<HostKernelBase> HKernel,
253254
std::shared_ptr<detail::kernel_impl> SyclKernel,
@@ -259,14 +260,16 @@ class CGExecKernel : public CG {
259260
std::vector<ArgDesc> Args, std::string KernelName,
260261
detail::OSModuleHandle OSModuleHandle,
261262
std::vector<std::shared_ptr<detail::stream_impl>> Streams,
263+
std::vector<std::shared_ptr<const void>> AuxiliaryResources,
262264
CGTYPE Type, detail::code_location loc = {})
263265
: CG(Type, std::move(ArgsStorage), std::move(AccStorage),
264266
std::move(SharedPtrStorage), std::move(Requirements),
265267
std::move(Events), std::move(loc)),
266268
MNDRDesc(std::move(NDRDesc)), MHostKernel(std::move(HKernel)),
267269
MSyclKernel(std::move(SyclKernel)), MArgs(std::move(Args)),
268270
MKernelName(std::move(KernelName)), MOSModuleHandle(OSModuleHandle),
269-
MStreams(std::move(Streams)) {
271+
MStreams(std::move(Streams)),
272+
MAuxiliaryResources(std::move(AuxiliaryResources)) {
270273
assert((getType() == RunOnHostIntel || getType() == Kernel) &&
271274
"Wrong type of exec kernel CG.");
272275
}
@@ -277,6 +280,10 @@ class CGExecKernel : public CG {
277280
return MStreams;
278281
}
279282

283+
std::vector<std::shared_ptr<const void>> getAuxiliaryResources() const {
284+
return MAuxiliaryResources;
285+
}
286+
280287
std::shared_ptr<detail::kernel_bundle_impl> getKernelBundle() {
281288
const std::shared_ptr<std::vector<ExtendedMemberT>> &ExtendedMembers =
282289
getExtendedMembers();
@@ -291,6 +298,9 @@ class CGExecKernel : public CG {
291298

292299
void clearStreams() { MStreams.clear(); }
293300
bool hasStreams() { return !MStreams.empty(); }
301+
302+
void clearAuxiliaryResources() { MAuxiliaryResources.clear(); }
303+
bool hasAuxiliaryResources() { return !MAuxiliaryResources.empty(); }
294304
};
295305

296306
/// "Copy memory" command group class.

sycl/include/CL/sycl/handler.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,9 @@ class __SYCL_EXPORT handler {
472472
/// Saves buffers created by handling reduction feature in handler.
473473
/// They are then forwarded to command group and destroyed only after
474474
/// the command group finishes the work on device/host.
475-
/// The 'MSharedPtrStorage' suits that need.
476475
///
477476
/// @param ReduObj is a pointer to object that must be stored.
478-
void addReduction(const std::shared_ptr<const void> &ReduObj) {
479-
MSharedPtrStorage.push_back(ReduObj);
480-
}
477+
void addReduction(const std::shared_ptr<const void> &ReduObj);
481478

482479
~handler() = default;
483480

@@ -1280,6 +1277,7 @@ class __SYCL_EXPORT handler {
12801277
}
12811278

12821279
std::shared_ptr<detail::handler_impl> getHandlerImpl() const;
1280+
std::shared_ptr<detail::handler_impl> evictHandlerImpl() const;
12831281

12841282
void setStateExplicitKernelBundle();
12851283
void setStateSpecConstSet();

sycl/include/sycl/ext/oneapi/reduction.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ class reduction_impl : private reduction_impl_base {
718718
auto RWReduVal = std::make_shared<T>(MIdentity);
719719
CGH.addReduction(RWReduVal);
720720
MOutBufPtr = std::make_shared<buffer<T, 1>>(RWReduVal.get(), range<1>(1));
721+
MOutBufPtr->set_final_data();
721722
CGH.addReduction(MOutBufPtr);
722723
return createHandlerWiredReadWriteAccessor(CGH, *MOutBufPtr);
723724
}
@@ -728,6 +729,7 @@ class reduction_impl : private reduction_impl_base {
728729
auto CounterMem = std::make_shared<int>(0);
729730
CGH.addReduction(CounterMem);
730731
auto CounterBuf = std::make_shared<buffer<int, 1>>(CounterMem.get(), 1);
732+
CounterBuf->set_final_data();
731733
CGH.addReduction(CounterBuf);
732734
return {*CounterBuf, CGH};
733735
}

sycl/source/detail/handler_impl.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class handler_impl {
6565
/// equal to the queue associated with the handler if the corresponding
6666
/// submission is a fallback from a previous submission.
6767
std::shared_ptr<queue_impl> MSubmissionSecondaryQueue;
68+
69+
// Stores auxiliary resources used by internal operations.
70+
std::vector<std::shared_ptr<const void>> MAuxiliaryResources;
6871
};
6972

7073
} // namespace detail

sycl/source/detail/scheduler/commands.cpp

+18-2
Original file line numberDiff line numberDiff line change
@@ -1378,11 +1378,23 @@ std::vector<StreamImplPtr> ExecCGCommand::getStreams() const {
13781378
return {};
13791379
}
13801380

1381+
std::vector<std::shared_ptr<const void>>
1382+
ExecCGCommand::getAuxiliaryResources() const {
1383+
if (MCommandGroup->getType() == CG::Kernel)
1384+
return ((CGExecKernel *)MCommandGroup.get())->getAuxiliaryResources();
1385+
return {};
1386+
}
1387+
13811388
void ExecCGCommand::clearStreams() {
13821389
if (MCommandGroup->getType() == CG::Kernel)
13831390
((CGExecKernel *)MCommandGroup.get())->clearStreams();
13841391
}
13851392

1393+
void ExecCGCommand::clearAuxiliaryResources() {
1394+
if (MCommandGroup->getType() == CG::Kernel)
1395+
((CGExecKernel *)MCommandGroup.get())->clearAuxiliaryResources();
1396+
}
1397+
13861398
cl_int UpdateHostRequirementCommand::enqueueImp() {
13871399
waitForPreparedHostEvents();
13881400
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
@@ -1673,7 +1685,9 @@ ExecCGCommand::ExecCGCommand(std::unique_ptr<detail::CG> CommandGroup,
16731685
static_cast<detail::CGHostTask *>(MCommandGroup.get())->MQueue;
16741686
MEvent->setNeedsCleanupAfterWait(true);
16751687
} else if (MCommandGroup->getType() == CG::CGTYPE::Kernel &&
1676-
(static_cast<CGExecKernel *>(MCommandGroup.get()))->hasStreams())
1688+
(static_cast<CGExecKernel *>(MCommandGroup.get())->hasStreams() ||
1689+
static_cast<CGExecKernel *>(MCommandGroup.get())
1690+
->hasAuxiliaryResources()))
16771691
MEvent->setNeedsCleanupAfterWait(true);
16781692

16791693
emitInstrumentationDataProxy();
@@ -2481,7 +2495,9 @@ bool ExecCGCommand::supportsPostEnqueueCleanup() const {
24812495
return Command::supportsPostEnqueueCleanup() &&
24822496
(MCommandGroup->getType() != CG::CGTYPE::CodeplayHostTask) &&
24832497
(MCommandGroup->getType() != CG::CGTYPE::Kernel ||
2484-
!(static_cast<CGExecKernel *>(MCommandGroup.get()))->hasStreams());
2498+
(!static_cast<CGExecKernel *>(MCommandGroup.get())->hasStreams() &&
2499+
!static_cast<CGExecKernel *>(MCommandGroup.get())
2500+
->hasAuxiliaryResources()));
24852501
}
24862502

24872503
} // namespace detail

sycl/source/detail/scheduler/commands.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,10 @@ class ExecCGCommand : public Command {
541541
ExecCGCommand(std::unique_ptr<detail::CG> CommandGroup, QueueImplPtr Queue);
542542

543543
std::vector<StreamImplPtr> getStreams() const;
544+
std::vector<std::shared_ptr<const void>> getAuxiliaryResources() const;
544545

545546
void clearStreams();
547+
void clearAuxiliaryResources();
546548

547549
void printDot(std::ostream &Stream) const final;
548550
void emitInstrumentationData() final;

sycl/source/detail/scheduler/graph_builder.cpp

+23-2
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,8 @@ void Scheduler::GraphBuilder::decrementLeafCountersForRecord(
10451045

10461046
void Scheduler::GraphBuilder::cleanupCommandsForRecord(
10471047
MemObjRecord *Record,
1048-
std::vector<std::shared_ptr<stream_impl>> &StreamsToDeallocate) {
1048+
std::vector<std::shared_ptr<stream_impl>> &StreamsToDeallocate,
1049+
std::vector<std::shared_ptr<const void>> &AuxResourcesToDeallocate) {
10491050
std::vector<AllocaCommandBase *> &AllocaCommands = Record->MAllocaCommands;
10501051
if (AllocaCommands.empty())
10511052
return;
@@ -1097,10 +1098,19 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(
10971098
// Collect stream objects for a visited command.
10981099
if (Cmd->getType() == Command::CommandType::RUN_CG) {
10991100
auto ExecCmd = static_cast<ExecCGCommand *>(Cmd);
1101+
1102+
// Transfer ownership of stream implementations.
11001103
std::vector<std::shared_ptr<stream_impl>> Streams = ExecCmd->getStreams();
11011104
ExecCmd->clearStreams();
11021105
StreamsToDeallocate.insert(StreamsToDeallocate.end(), Streams.begin(),
11031106
Streams.end());
1107+
1108+
// Transfer ownership of auxiliary resources.
1109+
std::vector<std::shared_ptr<const void>> AuxResources =
1110+
ExecCmd->getAuxiliaryResources();
1111+
ExecCmd->clearAuxiliaryResources();
1112+
AuxResourcesToDeallocate.insert(AuxResourcesToDeallocate.end(),
1113+
AuxResources.begin(), AuxResources.end());
11041114
}
11051115

11061116
for (Command *UserCmd : Cmd->MUsers)
@@ -1160,6 +1170,7 @@ void Scheduler::GraphBuilder::cleanupCommand(Command *Cmd) {
11601170
if (ExecCGCmd->getCG().getType() == CG::CGTYPE::Kernel) {
11611171
auto *ExecKernelCG = static_cast<CGExecKernel *>(&ExecCGCmd->getCG());
11621172
assert(!ExecKernelCG->hasStreams());
1173+
assert(!ExecKernelCG->hasAuxiliaryResources());
11631174
}
11641175
}
11651176
#endif
@@ -1191,7 +1202,8 @@ void Scheduler::GraphBuilder::cleanupCommand(Command *Cmd) {
11911202

11921203
void Scheduler::GraphBuilder::cleanupFinishedCommands(
11931204
Command *FinishedCmd,
1194-
std::vector<std::shared_ptr<stream_impl>> &StreamsToDeallocate) {
1205+
std::vector<std::shared_ptr<stream_impl>> &StreamsToDeallocate,
1206+
std::vector<std::shared_ptr<const void>> &AuxResourcesToDeallocate) {
11951207
assert(MCmdsToVisit.empty());
11961208
MCmdsToVisit.push(FinishedCmd);
11971209
MVisitedCmds.clear();
@@ -1207,10 +1219,19 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands(
12071219
// Collect stream objects for a visited command.
12081220
if (Cmd->getType() == Command::CommandType::RUN_CG) {
12091221
auto ExecCmd = static_cast<ExecCGCommand *>(Cmd);
1222+
1223+
// Transfer ownership of stream implementations.
12101224
std::vector<std::shared_ptr<stream_impl>> Streams = ExecCmd->getStreams();
12111225
ExecCmd->clearStreams();
12121226
StreamsToDeallocate.insert(StreamsToDeallocate.end(), Streams.begin(),
12131227
Streams.end());
1228+
1229+
// Transfer ownership of auxiliary resources.
1230+
std::vector<std::shared_ptr<const void>> AuxResources =
1231+
ExecCmd->getAuxiliaryResources();
1232+
ExecCmd->clearAuxiliaryResources();
1233+
AuxResourcesToDeallocate.insert(AuxResourcesToDeallocate.end(),
1234+
AuxResources.begin(), AuxResources.end());
12141235
}
12151236

12161237
for (const DepDesc &Dep : Cmd->MDeps) {

sycl/source/detail/scheduler/scheduler.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) {
239239
// objects, this is needed to guarantee that streamed data is printed and
240240
// resources are released.
241241
std::vector<std::shared_ptr<stream_impl>> StreamsToDeallocate;
242+
// Similar to streams, we also collect the auxiliary resources used by the
243+
// commands. Cleanup will make sure the commands do not own the resources
244+
// anymore, so we just need them to survive the graph lock then they can die
245+
// as they go out of scope.
246+
std::vector<std::shared_ptr<const void>> AuxResourcesToDeallocate;
242247
{
243248
// Avoiding deadlock situation, where one thread is in the process of
244249
// enqueueing (with a locked mutex) a currently blocked task that waits for
@@ -249,7 +254,8 @@ void Scheduler::cleanupFinishedCommands(EventImplPtr FinishedEvent) {
249254
// The command might have been cleaned up (and set to nullptr) by another
250255
// thread
251256
if (FinishedCmd)
252-
MGraphBuilder.cleanupFinishedCommands(FinishedCmd, StreamsToDeallocate);
257+
MGraphBuilder.cleanupFinishedCommands(FinishedCmd, StreamsToDeallocate,
258+
AuxResourcesToDeallocate);
253259
}
254260
}
255261
deallocateStreams(StreamsToDeallocate);
@@ -261,6 +267,11 @@ void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) {
261267
// objects, this is needed to guarantee that streamed data is printed and
262268
// resources are released.
263269
std::vector<std::shared_ptr<stream_impl>> StreamsToDeallocate;
270+
// Similar to streams, we also collect the auxiliary resources used by the
271+
// commands. Cleanup will make sure the commands do not own the resources
272+
// anymore, so we just need them to survive the graph lock then they can die
273+
// as they go out of scope.
274+
std::vector<std::shared_ptr<const void>> AuxResourcesToDeallocate;
264275

265276
{
266277
MemObjRecord *Record = nullptr;
@@ -282,7 +293,8 @@ void Scheduler::removeMemoryObject(detail::SYCLMemObjI *MemObj) {
282293
WriteLockT Lock(MGraphLock, std::defer_lock);
283294
acquireWriteLock(Lock);
284295
MGraphBuilder.decrementLeafCountersForRecord(Record);
285-
MGraphBuilder.cleanupCommandsForRecord(Record, StreamsToDeallocate);
296+
MGraphBuilder.cleanupCommandsForRecord(Record, StreamsToDeallocate,
297+
AuxResourcesToDeallocate);
286298
MGraphBuilder.removeRecordForMemObj(MemObj);
287299
}
288300
}

sycl/source/detail/scheduler/scheduler.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,8 @@ class Scheduler {
514514
/// (assuming that all its commands have been waited for).
515515
void cleanupFinishedCommands(
516516
Command *FinishedCmd,
517-
std::vector<std::shared_ptr<cl::sycl::detail::stream_impl>> &);
517+
std::vector<std::shared_ptr<cl::sycl::detail::stream_impl>> &,
518+
std::vector<std::shared_ptr<const void>> &);
518519

519520
/// Reschedules the command passed using Queue provided.
520521
///
@@ -540,7 +541,8 @@ class Scheduler {
540541
/// Removes commands that use the given MemObjRecord from the graph.
541542
void cleanupCommandsForRecord(
542543
MemObjRecord *Record,
543-
std::vector<std::shared_ptr<cl::sycl::detail::stream_impl>> &);
544+
std::vector<std::shared_ptr<cl::sycl::detail::stream_impl>> &,
545+
std::vector<std::shared_ptr<const void>> &);
544546

545547
/// Removes the MemObjRecord for the memory object passed.
546548
void removeRecordForMemObj(SYCLMemObjI *MemObject);

sycl/source/handler.cpp

+37-12
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,40 @@ handler::handler(std::shared_ptr<detail::queue_impl> Queue,
4949
MSharedPtrStorage.push_back(std::move(ExtendedMembers));
5050
}
5151

52+
static detail::ExtendedMemberT &getHandlerImplMember(
53+
std::vector<std::shared_ptr<const void>> &SharedPtrStorage) {
54+
assert(!SharedPtrStorage.empty());
55+
std::shared_ptr<std::vector<detail::ExtendedMemberT>> ExtendedMembersVec =
56+
detail::convertToExtendedMembers(SharedPtrStorage[0]);
57+
assert(ExtendedMembersVec->size() > 0);
58+
auto &HandlerImplMember = (*ExtendedMembersVec)[0];
59+
assert(detail::ExtendedMembersType::HANDLER_IMPL == HandlerImplMember.MType);
60+
return HandlerImplMember;
61+
}
62+
5263
/// Gets the handler_impl at the start of the extended members.
5364
std::shared_ptr<detail::handler_impl> handler::getHandlerImpl() const {
5465
std::lock_guard<std::mutex> Lock(
5566
detail::GlobalHandler::instance().getHandlerExtendedMembersMutex());
67+
return std::static_pointer_cast<detail::handler_impl>(
68+
getHandlerImplMember(MSharedPtrStorage).MData);
69+
}
5670

57-
assert(!MSharedPtrStorage.empty());
58-
59-
std::shared_ptr<std::vector<detail::ExtendedMemberT>> ExtendedMembersVec =
60-
detail::convertToExtendedMembers(MSharedPtrStorage[0]);
61-
62-
assert(ExtendedMembersVec->size() > 0);
63-
64-
auto HandlerImplMember = (*ExtendedMembersVec)[0];
71+
/// Gets the handler_impl at the start of the extended members and removes it.
72+
std::shared_ptr<detail::handler_impl> handler::evictHandlerImpl() const {
73+
std::lock_guard<std::mutex> Lock(
74+
detail::GlobalHandler::instance().getHandlerExtendedMembersMutex());
75+
auto &HandlerImplMember = getHandlerImplMember(MSharedPtrStorage);
76+
auto Impl =
77+
std::static_pointer_cast<detail::handler_impl>(HandlerImplMember.MData);
6578

66-
assert(detail::ExtendedMembersType::HANDLER_IMPL == HandlerImplMember.MType);
79+
// Reset the data of the member.
80+
// NOTE: We let it stay because removing the front can be expensive. This will
81+
// be improved when the impl is made a member of handler. In fact eviction is
82+
// likely to not be needed when that happens.
83+
HandlerImplMember.MData.reset();
6784

68-
return std::static_pointer_cast<detail::handler_impl>(
69-
HandlerImplMember.MData);
85+
return Impl;
7086
}
7187

7288
// Sets the submission state to indicate that an explicit kernel bundle has been
@@ -281,6 +297,10 @@ event handler::finalize() {
281297
return MLastEvent;
282298
}
283299

300+
// Evict handler_impl from extended members to make sure the command group
301+
// does not keep it alive.
302+
std::shared_ptr<detail::handler_impl> Impl = evictHandlerImpl();
303+
284304
std::unique_ptr<detail::CG> CommandGroup;
285305
switch (type) {
286306
case detail::CG::Kernel:
@@ -293,7 +313,8 @@ event handler::finalize() {
293313
std::move(MArgsStorage), std::move(MAccStorage),
294314
std::move(MSharedPtrStorage), std::move(MRequirements),
295315
std::move(MEvents), std::move(MArgs), MKernelName, MOSModuleHandle,
296-
std::move(MStreamStorage), MCGType, MCodeLoc));
316+
std::move(MStreamStorage), std::move(Impl->MAuxiliaryResources),
317+
MCGType, MCodeLoc));
297318
break;
298319
}
299320
case detail::CG::CodeplayInteropTask:
@@ -382,6 +403,10 @@ event handler::finalize() {
382403
return MLastEvent;
383404
}
384405

406+
void handler::addReduction(const std::shared_ptr<const void> &ReduObj) {
407+
getHandlerImpl()->MAuxiliaryResources.push_back(ReduObj);
408+
}
409+
385410
void handler::associateWithHandler(detail::AccessorBaseHost *AccBase,
386411
access::target AccTarget) {
387412
detail::AccessorImplPtr AccImpl = detail::getSyclObjImpl(*AccBase);

sycl/test/abi/sycl_symbols_linux.dump

+2
Original file line numberDiff line numberDiff line change
@@ -3994,6 +3994,7 @@ _ZN2cl4sycl7handler10depends_onERKSt6vectorINS0_5eventESaIS3_EE
39943994
_ZN2cl4sycl7handler10mem_adviseEPKvmi
39953995
_ZN2cl4sycl7handler10processArgEPvRKNS0_6detail19kernel_param_kind_tEimRmb
39963996
_ZN2cl4sycl7handler10processArgEPvRKNS0_6detail19kernel_param_kind_tEimRmbb
3997+
_ZN2cl4sycl7handler12addReductionERKSt10shared_ptrIKvE
39973998
_ZN2cl4sycl7handler13getKernelNameB5cxx11Ev
39983999
_ZN2cl4sycl7handler17use_kernel_bundleERKNS0_13kernel_bundleILNS0_12bundle_stateE2EEE
39994000
_ZN2cl4sycl7handler18RangeRoundingTraceEv
@@ -4391,6 +4392,7 @@ _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE65552EEENS3_12param_traitsIS4_X
43914392
_ZNK2cl4sycl7context8get_infoILNS0_4info7contextE65553EEENS3_12param_traitsIS4_XT_EE11return_typeEv
43924393
_ZNK2cl4sycl7context9getNativeEv
43934394
_ZNK2cl4sycl7handler14getHandlerImplEv
4395+
_ZNK2cl4sycl7handler16evictHandlerImplEv
43944396
_ZNK2cl4sycl7handler27isStateExplicitKernelBundleEv
43954397
_ZNK2cl4sycl7handler30getOrInsertHandlerKernelBundleEb
43964398
_ZNK2cl4sycl7program10get_kernelENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE

sycl/test/abi/sycl_symbols_windows.dump

+1
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,7 @@
17571757
?erfc@__host_std@cl@@YA?AVhalf@half_impl@detail@sycl@2@V34562@@Z
17581758
?erfc@__host_std@cl@@YAMM@Z
17591759
?erfc@__host_std@cl@@YANN@Z
1760+
?evictHandlerImpl@handler@sycl@cl@@AEBA?AV?$shared_ptr@Vhandler_impl@detail@sycl@cl@@@std@@XZ
17601761
?exp10@__host_std@cl@@YA?AV?$vec@M$00@sycl@2@V342@@Z
17611762
?exp10@__host_std@cl@@YA?AV?$vec@M$01@sycl@2@V342@@Z
17621763
?exp10@__host_std@cl@@YA?AV?$vec@M$02@sycl@2@V342@@Z

0 commit comments

Comments
 (0)