diff --git a/include/simsycl/sycl/accessor.hh b/include/simsycl/sycl/accessor.hh index 3d623a3..b79d59c 100644 --- a/include/simsycl/sycl/accessor.hh +++ b/include/simsycl/sycl/accessor.hh @@ -148,7 +148,7 @@ class host_access_guard { template explicit host_access_guard( const sycl::buffer &buf, const accessed_range &range) - : m_validator(&detail::get_buffer_access_validator(buf, m_lock)), m_range(range) // + : m_validator(&detail::get_buffer_state(buf).validator.with(m_lock)), m_range(range) // { m_validator->begin_host_access(m_range); } @@ -169,9 +169,9 @@ class host_access_guard { template class command_group_access_guard { public: - template - explicit command_group_access_guard(const sycl::buffer &buf) - : m_validator(&detail::get_buffer_access_validator(buf, m_lock)) {} + template + explicit command_group_access_guard(const buffer_state &state) + : m_validator(&state.validator.with(m_lock)) {} void check_access_from_command_group(const accessed_range &range) { m_validator->check_access_from_command_group(range); @@ -335,7 +335,7 @@ class accessor : public simsycl::detail::property_interface { { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer[detail::get_linear_index(m_buffer_range, index)]; + return m_buffer->data[detail::get_linear_index(m_buffer->range, index)]; } SIMSYCL_DETAIL_DEPRECATED_IN_SYCL atomic operator[]( @@ -353,7 +353,7 @@ class accessor : public simsycl::detail::property_interface { { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } std::add_pointer_t get_pointer() const noexcept @@ -361,7 +361,7 @@ class accessor : public simsycl::detail::property_interface { { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } template @@ -370,7 +370,7 @@ class accessor : public simsycl::detail::property_interface { { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return accessor_ptr(m_buffer); + return accessor_ptr(m_buffer->data); } iterator begin() const noexcept { return iterator(this, iterator::begin); } @@ -398,9 +398,8 @@ class accessor : public simsycl::detail::property_interface { struct internal_t { } constexpr inline static internal{}; - DataT *m_buffer = nullptr; + const detail::buffer_state, Dimensions> *m_buffer = nullptr; std::shared_ptr> m_guard; // shared_ptr: accessors must be copyable - range m_buffer_range; id m_access_offset; range m_access_range; // shared: require() on a copy is equivalent to require() on the original instance @@ -408,11 +407,11 @@ class accessor : public simsycl::detail::property_interface { template void init(buffer &buffer_ref) { - m_buffer = detail::get_buffer_data(buffer_ref); - m_guard = std::make_shared>(buffer_ref); - m_buffer_range = buffer_ref.get_range(); - m_access_range = m_buffer_range; + m_buffer = &detail::get_buffer_state(buffer_ref); + m_guard = std::make_shared>(*m_buffer); + m_access_range = m_buffer->range; } + void init(const id &access_offset) { m_access_offset = access_offset; } void init(const range &access_range) { m_access_range = access_range; } @@ -437,8 +436,6 @@ class accessor : public simsycl::detail::property_interface { m_guard->check_access_from_command_group({m_access_offset, m_access_range, AccessMode}); *m_required = true; } - - const range &get_buffer_range() const { return m_buffer_range; } }; template @@ -525,8 +522,8 @@ class accessor : public simsy template accessor(buffer &buffer_ref, const property_list &prop_list = {}) : simsycl::detail::property_interface(prop_list, property_compatibility()), - m_buffer(detail::get_buffer_data(buffer_ref)), - m_guard(std::make_shared>(buffer_ref)) {} + m_buffer(&detail::get_buffer_state(buffer_ref)), + m_guard(std::make_shared>(*m_buffer)) {} template accessor(buffer &buffer_ref, handler &command_group_handler_ref, @@ -559,7 +556,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return *m_buffer; + return *m_buffer->data; } const accessor &operator=(const value_type &other) const @@ -567,7 +564,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - *m_buffer = other; + *m_buffer->data = other; return *this; } @@ -576,7 +573,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - *m_buffer = std::move(other); + *m_buffer->data = std::move(other); return *this; } @@ -588,7 +585,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } std::add_pointer_t get_pointer() const noexcept @@ -596,7 +593,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } template @@ -605,7 +602,7 @@ class accessor : public simsy { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return accessor_ptr(m_buffer); + return accessor_ptr(m_buffer->data); } @@ -631,7 +628,7 @@ class accessor : public simsy template friend struct std::hash; - DataT *m_buffer = nullptr; + const detail::buffer_state, 1> *m_buffer = nullptr; std::shared_ptr> m_guard; // shared_ptr: accessors must be copyable // shared: require() on a copy is equivalent to require() on the original instance std::shared_ptr m_required = std::make_shared(false); @@ -736,8 +733,6 @@ class local_accessor final : public simsycl::detail::property_interface { void **m_allocation_ptr = nullptr; sycl::range m_range; - const range &get_buffer_range() const { return get_range(); } - id get_offset() const { return {}; } inline DataT *get_allocation() const { @@ -916,7 +911,7 @@ class host_accessor : public simsycl::detail::property_interface { requires(AccessMode != access_mode::atomic) { SIMSYCL_CHECK(m_buffer != nullptr); - return m_buffer[detail::get_linear_index(m_buffer_range, index)]; + return m_buffer->data[detail::get_linear_index(m_buffer->range, index)]; } decltype(auto) operator[](size_t index) const @@ -925,7 +920,10 @@ class host_accessor : public simsycl::detail::property_interface { return detail::subscript(*this, index); } - std::add_pointer_t get_pointer() const noexcept { return m_buffer; } + std::add_pointer_t get_pointer() const noexcept { + SIMSYCL_CHECK(m_buffer != nullptr); + return m_buffer->data; + } iterator begin() const noexcept { return iterator(this, iterator::begin); } @@ -952,8 +950,7 @@ class host_accessor : public simsycl::detail::property_interface { struct internal_t { } constexpr inline static internal{}; - DataT *m_buffer = nullptr; - range m_buffer_range; + const detail::buffer_state, Dimensions> *m_buffer = nullptr; id m_access_offset; range m_access_range; // guard is a shared_ptr because accessors need to be copyable @@ -961,9 +958,8 @@ class host_accessor : public simsycl::detail::property_interface { template void init(buffer &buffer_ref) { - m_buffer = detail::get_buffer_data(buffer_ref); - m_buffer_range = buffer_ref.get_range(); - m_access_range = m_buffer_range; + m_buffer = &detail::get_buffer_state(buffer_ref); + m_access_range = m_buffer->range; m_access_guard = std::make_shared>( buffer_ref, detail::accessed_range(m_access_offset, m_access_range, AccessMode)); } @@ -983,8 +979,6 @@ class host_accessor : public simsycl::detail::property_interface { explicit host_accessor(internal_t /* tag */, Params &&...args) { (init(args), ...); } - - const range &get_buffer_range() const { return m_buffer_range; } }; template @@ -1025,8 +1019,9 @@ class host_accessor : public simsycl::detail::property_int template host_accessor(buffer &buffer_ref, const property_list &prop_list = {}) : detail::property_interface(prop_list, property_compatibility()), - m_buffer(detail::get_buffer_data(buffer_ref)), m_access_guard(std::make_shared>( - buffer_ref, detail::accessed_range<1>(0, 1, AccessMode))) { + m_buffer(&detail::get_buffer_state(buffer_ref)), + m_access_guard( + std::make_shared>(buffer_ref, detail::accessed_range<1>(0, 1, AccessMode))) { } friend bool operator==(const host_accessor &lhs, const host_accessor &rhs) = default; @@ -1049,14 +1044,14 @@ class host_accessor : public simsycl::detail::property_int requires(AccessMode != access_mode::atomic) { SIMSYCL_CHECK(m_buffer != nullptr); - return *m_buffer; + return *m_buffer->data; } const host_accessor &operator=(const value_type &other) const requires(AccessMode != access_mode::atomic && AccessMode != access_mode::read) { SIMSYCL_CHECK(m_buffer != nullptr); - *m_buffer = other; + *m_buffer->data = other; return *this; } @@ -1064,11 +1059,14 @@ class host_accessor : public simsycl::detail::property_int requires(AccessMode != access_mode::atomic && AccessMode != access_mode::read) { SIMSYCL_CHECK(m_buffer != nullptr); - *m_buffer = std::move(other); + *m_buffer->data = std::move(other); return *this; } - std::add_pointer_t get_pointer() const noexcept { return m_buffer; } + std::add_pointer_t get_pointer() const noexcept { + SIMSYCL_CHECK(m_buffer != nullptr); + return m_buffer->data; + } iterator begin() const noexcept { return iterator(this, iterator::begin); } @@ -1092,7 +1090,7 @@ class host_accessor : public simsycl::detail::property_int template friend struct std::hash; - DataT *m_buffer = nullptr; + const detail::buffer_state, 1> *m_buffer = nullptr; std::shared_ptr> m_access_guard; }; @@ -1165,7 +1163,7 @@ class accessordata[detail::get_linear_index(m_buffer->range, index)]; } decltype(auto) operator[](size_t index) const @@ -1177,7 +1175,7 @@ class accessor get_pointer() const noexcept { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } private: @@ -1189,9 +1187,8 @@ class accessor, Dimensions> *m_buffer = nullptr; std::shared_ptr> m_guard; // shared_ptr: accessors must be copyable - range m_buffer_range; id m_access_offset; range m_access_range; // shared: require() on a copy is equivalent to require() on the original instance @@ -1199,11 +1196,11 @@ class accessor void init(buffer &buffer_ref) { - m_buffer = detail::get_buffer_data(buffer_ref); - m_guard = std::make_shared>(buffer_ref); - m_buffer_range = buffer_ref.get_range(); - m_access_range = m_buffer_range; + m_buffer = &detail::get_buffer_state(buffer_ref); + m_guard = std::make_shared>(*m_buffer); + m_access_range = m_buffer->range; } + void init(const id &access_offset) { m_access_offset = access_offset; } void init(const range &access_range) { m_access_range = access_range; } @@ -1226,8 +1223,6 @@ class accessorcheck_access_from_command_group({m_access_offset, m_access_range, AccessMode}); *m_required = true; } - - const range &get_buffer_range() const { return m_buffer_range; } }; template @@ -1247,8 +1242,8 @@ class accessor fin template accessor(buffer &buffer_ref, const property_list &prop_list = {}) : simsycl::detail::property_interface(prop_list, property_compatibility()), - m_buffer(detail::get_buffer_data(buffer_ref)), - m_guard(std::make_shared>(buffer_ref)) {} + m_buffer(&detail::get_buffer_state(buffer_ref)), + m_guard(std::make_shared>(*m_buffer)) {} template accessor(buffer &buffer_ref, handler &command_group_handler_ref, @@ -1269,13 +1264,13 @@ class accessor fin operator reference() const { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return *m_buffer; + return *m_buffer->data; } global_ptr get_pointer() const noexcept { SIMSYCL_CHECK(m_buffer != nullptr); SIMSYCL_CHECK(*m_required); - return m_buffer; + return m_buffer->data; } private: @@ -1284,7 +1279,7 @@ class accessor fin template friend struct std::hash; - DataT *m_buffer = nullptr; + const detail::buffer_state, 1> *m_buffer = nullptr; std::shared_ptr> m_guard; // shared: require() on a copy is equivalent to require() on the original instance std::shared_ptr m_required = std::make_shared(false); @@ -1325,8 +1320,7 @@ class accessor &buffer_ref, range access_range, id access_offset, const property_list &prop_list = {}) : detail::property_interface(prop_list, property_compatibility()), - m_buffer(detail::get_buffer_data(buffer_ref)), m_buffer_range(buffer_ref.get_range()), - m_access_offset(access_offset), m_access_range(access_range), + m_buffer(&detail::get_buffer_state(buffer_ref)), m_access_offset(access_offset), m_access_range(access_range), m_access_guard(std::make_shared>( buffer_ref, detail::accessed_range(m_access_offset, m_access_range, AccessMode))) {} @@ -1344,7 +1338,7 @@ class accessor index) const { SIMSYCL_CHECK(m_buffer != nullptr); - return m_buffer[detail::get_linear_index(m_buffer_range, index)]; + return m_buffer->data[detail::get_linear_index(m_buffer->range, index)]; } decltype(auto) operator[](size_t index) const @@ -1353,7 +1347,10 @@ class accessor(*this, index); } - std::add_pointer_t get_pointer() const noexcept { return m_buffer; } + std::add_pointer_t get_pointer() const noexcept { + SIMSYCL_CHECK(m_buffer != nullptr); + return m_buffer->data; + } friend bool operator==(const accessor &lhs, const accessor &rhs) = default; @@ -1361,8 +1358,7 @@ class accessor friend struct std::hash; - DataT *m_buffer = nullptr; - range m_buffer_range; + const detail::buffer_state, Dimensions> *m_buffer = nullptr; id m_access_offset; range m_access_range; // guard is a shared_ptr because accessors need to be copyable @@ -1386,8 +1382,9 @@ class accessor : publi template accessor(buffer &buffer_ref, const property_list &prop_list = {}) : detail::property_interface(prop_list, property_compatibility()), - m_buffer(detail::get_buffer_data(buffer_ref)), m_access_guard(std::make_shared>( - buffer_ref, detail::accessed_range<1>(0, 1, AccessMode))) { + m_buffer(&detail::get_buffer_state(buffer_ref)), + m_access_guard( + std::make_shared>(buffer_ref, detail::accessed_range<1>(0, 1, AccessMode))) { } // non-copyable and immovable, because it holds a system_lock @@ -1400,10 +1397,13 @@ class accessor : publi operator reference() const { SIMSYCL_CHECK(m_buffer != nullptr); - return *m_buffer; + return *m_buffer->data; } - std::add_pointer_t get_pointer() const noexcept { return m_buffer; } + std::add_pointer_t get_pointer() const noexcept { + SIMSYCL_CHECK(m_buffer != nullptr); + return m_buffer->data; + } friend bool operator==(const accessor &lhs, const accessor &rhs) = default; @@ -1412,7 +1412,7 @@ class accessor : publi friend struct std::hash; detail::system_lock m_lock; // active host accessors must block command-group submission on other threads - DataT *m_buffer = nullptr; + const detail::buffer_state, 1> *m_buffer = nullptr; std::shared_ptr> m_access_guard; }; @@ -1473,8 +1473,6 @@ class accessor fina void **m_allocation_ptr; sycl::range m_range; - const range &get_buffer_range() const { return get_range(); } - inline DataT *get_allocation() const { return static_cast(*m_allocation_ptr); } }; diff --git a/include/simsycl/sycl/buffer.hh b/include/simsycl/sycl/buffer.hh index c394d94..0b7ae19 100644 --- a/include/simsycl/sycl/buffer.hh +++ b/include/simsycl/sycl/buffer.hh @@ -9,6 +9,7 @@ #include "../detail/reference_type.hh" #include +#include #include #include #include @@ -128,34 +129,46 @@ struct buffer_access_validator { } }; -template +template struct buffer_state { using write_back_fn = std::function; + using deallocate_fn = std::function; + struct raw_tag {}; sycl::range range; - AllocatorT allocator; - T *buffer; - write_back_fn write_back; - bool write_back_enabled; - std::shared_ptr shared_host_ptr; // keep the std::shared_ptr host pointer alive + T *data = nullptr; + // buffer_state must not be dependent on AllocatorT because it's used in accessor<>, so we type-erase allocation + deallocate_fn deallocate; + mutable shared_value write_back_on_destruction = false; + mutable shared_value write_back; + std::shared_ptr host_ptr_lifetime_extender; mutable shared_value> validator; - buffer_state(sycl::range range, const AllocatorT &allocator = {}, const T *init_from = nullptr, - write_back_fn write_back = {}, const std::shared_ptr &shared_host_ptr = nullptr) - : range(range), allocator(allocator), buffer(this->allocator.allocate(range.size())), write_back(write_back), - write_back_enabled(static_cast(write_back)), shared_host_ptr(shared_host_ptr) { + template + buffer_state(raw_tag /* tag */, sycl::range range, AllocatorT allocator, write_back_fn write_back = {}, + std::shared_ptr lifetime_extend_host_ptr = nullptr) + : range(range), data(allocator.allocate(range.size())), + // AllocatorT::deallocate isn't const-qualified, and std::function doesn't take mutable lambdas, so we copy + // the allocator (which is always trivial) + deallocate([allocator](T *ptr, size_t n) { AllocatorT(allocator).deallocate(ptr, n); }), + write_back_on_destruction(static_cast(write_back)), write_back(std::move(write_back)), + host_ptr_lifetime_extender(std::move(lifetime_extend_host_ptr)) {} + + template + buffer_state(sycl::range range, AllocatorT allocator = {}, const T *init_from = nullptr, + write_back_fn write_back = {}, std::shared_ptr lifetime_extend_host_ptr = nullptr) + : buffer_state(raw_tag{}, range, allocator, std::move(write_back), std::move(lifetime_extend_host_ptr)) { if(init_from) { - memcpy(buffer, init_from, range.size() * sizeof(T)); + memcpy(data, init_from, range.size() * sizeof(T)); } else { - memset(buffer, static_cast(detail::uninitialized_memory_pattern), range.size() * sizeof(T)); + memset(data, static_cast(detail::uninitialized_memory_pattern), range.size() * sizeof(T)); } } - template + template buffer_state(InputIterator first, InputIterator last, const AllocatorT &allocator) - : range(static_cast(std::distance(first, last))), allocator(allocator), - buffer(this->allocator.allocate(range.size())), write_back_enabled(false) { - std::copy(first, last, buffer); + : buffer_state(raw_tag{}, static_cast(std::distance(first, last)), allocator) { + std::copy(first, last, data); } buffer_state(const buffer_state &) = delete; @@ -164,11 +177,9 @@ struct buffer_state { buffer_state &operator=(buffer_state &&) = delete; ~buffer_state() { - if(write_back_enabled) { - system_lock lock; // writeback must not overlap with command groups in other threads - write_back(buffer, range.size()); - } - allocator.deallocate(buffer, range.size()); + system_lock lock; // writeback must not overlap with command groups in other threads + if(write_back_on_destruction.with(lock)) { write_back.with(lock)(data, range.size()); } + deallocate(data, range.size()); } }; @@ -178,10 +189,10 @@ namespace simsycl::sycl { template class buffer final : public detail::reference_type, - detail::buffer_state, Dimensions, AllocatorT>>, + detail::buffer_state, Dimensions>>, public detail::property_interface { private: - using state_type = detail::buffer_state, Dimensions, AllocatorT>; + using state_type = detail::buffer_state, Dimensions>; using reference_type = detail::reference_type, state_type>; using write_back_fn = typename state_type::write_back_fn; using property_compatibility = detail::property_compatibility_with, @@ -198,7 +209,7 @@ class buffer final : public detail::reference_type &buffer_range, AllocatorT allocator, const property_list &prop_list = {}) : reference_type(std::in_place, buffer_range, allocator), - property_interface(prop_list, property_compatibility()) {} + property_interface(prop_list, property_compatibility()), m_allocator(allocator) {} buffer(T *host_data, const range &buffer_range, const property_list &prop_list = {}) requires(!std::is_const_v) @@ -208,7 +219,8 @@ class buffer final : public detail::reference_type &buffer_range, AllocatorT allocator, const property_list &prop_list = {}) requires(!std::is_const_v) : property_interface(prop_list, property_compatibility()), - reference_type(std::in_place, buffer_range, allocator, host_data, write_back_to(host_data)) {} + reference_type(std::in_place, buffer_range, allocator, host_data, write_back_to(host_data)), + m_allocator(allocator) {} buffer(const T *host_data, const range &buffer_range, const property_list &prop_list = {}) : buffer(host_data, buffer_range, AllocatorT(), prop_list) {} @@ -216,7 +228,7 @@ class buffer final : public detail::reference_type &buffer_range, AllocatorT allocator, const property_list &prop_list = {}) : property_interface(prop_list, property_compatibility()), - reference_type(std::in_place, buffer_range, allocator, host_data) {} + reference_type(std::in_place, buffer_range, allocator, host_data), m_allocator(allocator) {} template Container> requires(Dimensions == 1) @@ -232,7 +244,8 @@ class buffer final : public detail::reference_type) : property_interface(prop_list, property_compatibility()), reference_type(std::in_place, buffer_range, allocator, host_data.get(), - write_back_to_if_non_const(host_data.get()), host_data) {} + write_back_to_if_non_const(host_data.get()), host_data), + m_allocator(allocator) {} buffer( const std::shared_ptr &host_data, const range &buffer_range, const property_list &prop_list = {}) @@ -242,7 +255,8 @@ class buffer final : public detail::reference_type &host_data, const range &buffer_range, const property_list &prop_list = {}) @@ -252,7 +266,7 @@ class buffer final : public detail::reference_type requires(Dimensions == 1) @@ -273,7 +287,7 @@ class buffer final : public detail::reference_type accessor get_access(handler &command_group_handler) { @@ -313,16 +327,20 @@ class buffer final : public detail::reference_type void set_final_data(Destination final_data = nullptr) { + detail::system_lock lock; if constexpr(std::is_same_v) { - state().write_back = {}; - state().write_back_enabled = false; + state().write_back.with(lock) = {}; + state().write_back_on_destruction.with(lock) = false; } else { - state().write_back = write_back_to(final_data); - state().write_back_enabled = true; + state().write_back.with(lock) = write_back_to(final_data); + state().write_back_on_destruction.with(lock) = true; } } - void set_write_back(bool flag = true) { state().write_back_enabled = state().write_back && flag; } + void set_write_back(bool flag = true) { + detail::system_lock lock; + state().write_back_on_destruction.with(lock) = state().write_back.with(lock) && flag; + } bool is_sub_buffer() const { // sub-buffers are unimplemented @@ -345,14 +363,13 @@ class buffer final : public detail::reference_type - friend U *simsycl::detail::get_buffer_data(sycl::buffer &buf); - - template - friend detail::buffer_access_validator &detail::get_buffer_access_validator( - const sycl::buffer &buf, detail::system_lock &lock); + friend const detail::buffer_state, D> &detail::get_buffer_state( + const sycl::buffer &buf); using reference_type::state; + AllocatorT m_allocator; // required purely for get_allocator() - buffer_state type-erases a copy of this + static write_back_fn write_back_to(T out) { return [out](const T *buffer, size_t size) { return memcpy(out, buffer, size * sizeof(T)); }; } @@ -363,7 +380,7 @@ class buffer final : public detail::reference_type + template OutputIterator> static write_back_fn write_back_to(OutputIterator out) { return [out](const T *buffer, size_t size) { return std::copy_n(buffer, size, out); }; } @@ -381,16 +398,16 @@ class buffer final : public detail::reference_type -buffer(InputIterator, InputIterator, AllocatorT, const property_list & = {}) - -> buffer::value_type, 1, AllocatorT>; +buffer(InputIterator, InputIterator, AllocatorT, + const property_list & = {}) -> buffer::value_type, 1, AllocatorT>; template -buffer(InputIterator, InputIterator, const property_list & = {}) - -> buffer::value_type, 1>; +buffer(InputIterator, InputIterator, + const property_list & = {}) -> buffer::value_type, 1>; template -buffer(const T *, const range &, AllocatorT, const property_list & = {}) - -> buffer; +buffer( + const T *, const range &, AllocatorT, const property_list & = {}) -> buffer; template buffer(const T *, const range &, const property_list & = {}) -> buffer; @@ -408,19 +425,14 @@ buffer(Container &, const property_list & = {}) -> buffer struct std::hash> : std::hash, - simsycl::detail::buffer_state, Dimensions, AllocatorT>>> {}; + simsycl::detail::buffer_state, Dimensions>>> {}; namespace simsycl::detail { template -T *get_buffer_data(sycl::buffer &buf) { - return buf.state().buffer; -} - -template -buffer_access_validator &get_buffer_access_validator( - const sycl::buffer &buf, system_lock &lock) { - return buf.state().validator.with(lock); +const buffer_state, Dimensions> &get_buffer_state( + const sycl::buffer &buf) { + return buf.state(); } } // namespace simsycl::detail diff --git a/include/simsycl/sycl/forward.hh b/include/simsycl/sycl/forward.hh index e4bc77d..44fd9f7 100644 --- a/include/simsycl/sycl/forward.hh +++ b/include/simsycl/sycl/forward.hh @@ -163,15 +163,15 @@ concurrent_sub_group &get_concurrent_group(const sycl::sub_group &g); template concurrent_group &get_concurrent_group(const sycl::group &g); +template +struct buffer_state; + template struct buffer_access_validator; template -T *get_buffer_data(sycl::buffer &buf); - -template -buffer_access_validator &get_buffer_access_validator( - const sycl::buffer &buf, system_lock &lock); +const buffer_state, Dimensions> &get_buffer_state( + const sycl::buffer &buf); sycl::handler make_handler(const sycl::device &device); diff --git a/include/simsycl/sycl/handler.hh b/include/simsycl/sycl/handler.hh index 499cf94..7b14e24 100644 --- a/include/simsycl/sycl/handler.hh +++ b/include/simsycl/sycl/handler.hh @@ -171,8 +171,8 @@ class handler { typename DestT> void copy(accessor src, DestT *dest) { static_assert(sizeof(SrcT) == sizeof(DestT)); - detail::memcpy_strided_host(src.get_pointer(), dest, sizeof(SrcT), src.get_buffer_range(), src.get_offset(), - src.get_range(), sycl::id(), src.get_range()); + detail::memcpy_strided_host(src.get_pointer(), dest, sizeof(SrcT), get_buffer_state(src).range, + src.get_offset(), src.get_range(), sycl::id(), src.get_range()); } template dest) { static_assert(sizeof(SrcT) == sizeof(DestT)); detail::memcpy_strided_host(src, dest.get_pointer(), sizeof(SrcT), dest.get_range(), sycl::id(), - dest.get_buffer_range(), dest.get_offset(), dest.get_range()); + get_buffer_state(dest).range, dest.get_offset(), dest.get_range()); } template void update_host(accessor acc) { - acc.update_host(); + const auto &buffer = get_buffer_state(acc); + detail::system_lock lock; // writeback must not overlap with command groups in other threads + const auto &write_back = buffer.write_back.with(lock); + SIMSYCL_CHECK_MSG(static_cast(write_back), + "Cannot update_host on an buffer that was not constructed with a host pointer"); + write_back(buffer.data, buffer.range.size()); } template @@ -248,6 +254,12 @@ class handler { } return nullptr; } + + template + const auto &get_buffer_state(const Accessor &acc) { + SIMSYCL_CHECK(acc.m_buffer); + return *acc.m_buffer; + } }; template diff --git a/include/simsycl/sycl/reduction.hh b/include/simsycl/sycl/reduction.hh index dc3b69b..f54aa74 100644 --- a/include/simsycl/sycl/reduction.hh +++ b/include/simsycl/sycl/reduction.hh @@ -162,7 +162,7 @@ auto reduction(buffer &vars, handler &cgh, BinaryOper const property_list &prop_list = {}) { (void)cgh; SIMSYCL_CHECK(vars.get_range().size() == 1); - T *value = detail::get_buffer_data(vars); + T *value = detail::get_buffer_state(vars).data; detail::begin_reduction(value, combiner, nullptr, prop_list); return detail::reducer(value, combiner); } @@ -183,7 +183,7 @@ auto reduction(buffer &vars, handler &cgh, const T &i const property_list &prop_list = {}) { (void)cgh; SIMSYCL_CHECK(vars.get_range().size() == 1); - T *value = detail::get_buffer_data(vars); + T *value = detail::get_buffer_state(vars).data; detail::begin_reduction(value, combiner, &identity, prop_list); return detail::reducer(value, combiner); } diff --git a/test/reduction_tests.cc b/test/reduction_tests.cc index 30559b8..438da12 100644 --- a/test/reduction_tests.cc +++ b/test/reduction_tests.cc @@ -58,11 +58,11 @@ TEMPLATE_TEST_CASE( }) .wait(); - CHECK(*detail::get_buffer_data(plus_buf) == 100 + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9); + CHECK(*detail::get_buffer_state(plus_buf).data == 100 + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9); CHECK(mult_var == 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10); - CHECK(*detail::get_buffer_data(bit_and_buf) == 16); + CHECK(*detail::get_buffer_state(bit_and_buf).data == 16); CHECK(bit_or_var == (128 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9)); - CHECK(*detail::get_buffer_data(bit_xor_buf) == (0 ^ 1 ^ 2 ^ 3 ^ 4 ^ 5 ^ 6 ^ 7 ^ 8 ^ 9)); + CHECK(*detail::get_buffer_state(bit_xor_buf).data == (0 ^ 1 ^ 2 ^ 3 ^ 4 ^ 5 ^ 6 ^ 7 ^ 8 ^ 9)); CHECK(min_var == -4.0f); CHECK(max_var == 9.0f); }