diff --git a/configure.ac b/configure.ac index 1fa8dd8ff..64ba63bfd 100644 --- a/configure.ac +++ b/configure.ac @@ -112,6 +112,14 @@ AC_ARG_ENABLE([Werror], AS_HELP_STRING([--enable-Werror], [Make all compiler warnings fatal]), [enable_Werror="$enableval"], [enable_Werror=no]) +AC_ARG_ENABLE([WP4-16-stacks], + AS_HELP_STRING([--enable-WP4-16-stacks], + [Implement stacks strictly as per the P4_16 specification instead of legacy behavior]), + [enable_WP4_16_stacks="$enableval"], [enable_WP4_16_stacks=no]) + +AS_IF([test "$enable_WP4_16_stacks" = "yes"], + [MY_CPPFLAGS="$MY_CPPFLAGS -DBM_WP4_16_STACKS"]) + # Checks for programs. AC_PROG_CXX AC_PROG_CC diff --git a/include/bm/bm_sim/stacks.h b/include/bm/bm_sim/stacks.h index 6ea929bbb..e99fe472f 100644 --- a/include/bm/bm_sim/stacks.h +++ b/include/bm/bm_sim/stacks.h @@ -47,13 +47,13 @@ class StackIface { virtual ~StackIface() { } //! Removes the first element of the stack. Returns the number of elements - //! removed, which is `0` if the stack is empty and `1` otherwise. The second - //! element of the stack becomes the first element, and so on... + //! removed. The second element of the stack becomes the first element, and so + //! on... virtual size_t pop_front() = 0; //! Removes the first \p num element of the stack. Returns the number of - //! elements removed, which is `0` if the stack is empty. Calling this - //! function is more efficient than calling pop_front() multiple times. + //! elements removed. Calling this function is more efficient than calling + //! pop_front() multiple times. virtual size_t pop_front(size_t num) = 0; //! Pushes an element to the front of the stack. If the stack is already full, @@ -88,12 +88,16 @@ class StackIface { virtual void reset() = 0; }; +// We use CRTP for Stack class to implement either legacy behavior or strict +// P4_16 behavior by providing different implementations of push_front and +// pop_front. + //! Stack is used to represent header and union stacks in P4. The Stack class //! itself does not store any union / header / field data itself, but stores //! references to the HeaderUnion / Header instances which constitute the stack, //! as well as the stack internal state (e.g. number of valid headers in the //! stack). -template +template class Stack : public StackIface, public NamedP4Object { public: friend class PHV; @@ -125,7 +129,7 @@ class Stack : public StackIface, public NamedP4Object { T &at(size_t idx); const T &at(size_t idx) const; - private: + protected: using TRef = std::reference_wrapper; // To be called by PHV class @@ -138,6 +142,48 @@ class Stack : public StackIface, public NamedP4Object { size_t next{0}; }; +namespace detail { + +// Implements legacy behavior for stacks. push_front and pop_front only shift +// the portion of the stack up to the next index. Pushed elements are marked as +// valid. +template +class StackLegacy : public Stack > { + public: + StackLegacy(const std::string &name, p4object_id_t id); + + size_t pop_front(); + size_t pop_front(size_t num); + + size_t push_front(); + size_t push_front(size_t num); +}; + +// Implements strict P4_16 behavior for stacks. push_front and pop_front shift +// the entire stack. Pushed elements are marked as invalid. +template +class StackP4_16 : public Stack > { + public: + StackP4_16(const std::string &name, p4object_id_t id); + + size_t pop_front(); + size_t pop_front(size_t num); + + size_t push_front(); + size_t push_front(size_t num); +}; + +#ifdef BM_WP4_16_STACKS +template +// using MyStack = StackP4_16; +using MyStack = StackP4_16; +#else +template +using MyStack = StackLegacy; +#endif // BM_WP4_16_STACKS + +} // namespace detail + using header_stack_id_t = p4object_id_t; using header_union_stack_id_t = p4object_id_t; @@ -149,7 +195,7 @@ using header_union_stack_id_t = p4object_id_t; //! // ... //! }; //! @endcode -using HeaderStack = Stack
; +using HeaderStack = detail::MyStack
; //! Convenience alias for stacks of header unions //! A HeaderUnionStack reference can be used in an action primitive. For //! example: @@ -160,7 +206,9 @@ using HeaderStack = Stack
; //! // ... //! }; //! @endcode -using HeaderUnionStack = Stack; +using HeaderUnionStack = detail::MyStack; + +#undef _P4_16_STACKS } // namespace bm diff --git a/src/bm_sim/stacks.cpp b/src/bm_sim/stacks.cpp index eb84fcba6..4f1cd14d6 100644 --- a/src/bm_sim/stacks.cpp +++ b/src/bm_sim/stacks.cpp @@ -27,155 +27,257 @@ namespace bm { +namespace detail { + +// legacy implementation of push_front and pop_front + template -Stack::Stack(const std::string &name, p4object_id_t id) - : StackIface(), NamedP4Object(name, id) { } +StackLegacy::StackLegacy(const std::string &name, p4object_id_t id) + : Stack >(name, id) { } template size_t -Stack::pop_front() { - if (next == 0) return 0u; - next--; - for (size_t i = 0; i < next; i++) { - elements[i].get().swap_values(&elements[i + 1].get()); +StackLegacy::pop_front() { + if (this->next == 0) return 0u; + this->next--; + for (size_t i = 0; i < this->next; i++) { + this->elements[i].get().swap_values(&this->elements[i + 1].get()); } - elements[next].get().mark_invalid(); + this->elements[this->next].get().mark_invalid(); return 1u; } template size_t -Stack::pop_front(size_t num) { +StackLegacy::pop_front(size_t num) { if (num == 0) return 0; - size_t popped = std::min(next, num); - next -= popped; - for (size_t i = 0; i < next; i++) { - elements[i].get().swap_values(&elements[i + num].get()); + size_t popped = std::min(this->next, num); + this->next -= popped; + for (size_t i = 0; i < this->next; i++) { + this->elements[i].get().swap_values(&this->elements[i + num].get()); } - for (size_t i = next; i < next + popped; i++) { - elements[i].get().mark_invalid(); + for (size_t i = this->next; i < this->next + popped; i++) { + this->elements[i].get().mark_invalid(); } return popped; } template size_t -Stack::push_front() { - if (next < elements.size()) next++; - for (size_t i = next - 1; i > 0; i--) { - elements[i].get().swap_values(&elements[i - 1].get()); +StackLegacy::push_front() { + if (this->next < this->elements.size()) this->next++; + for (size_t i = this->next - 1; i > 0; i--) { + this->elements[i].get().swap_values(&this->elements[i - 1].get()); } - // TODO(antonin): do I want to reset the element as well? - // this may be complicated given the design - elements[0].get().mark_valid(); + this->elements[0].get().mark_valid(); return 1u; } template size_t -Stack::push_front(size_t num) { +StackLegacy::push_front(size_t num) { if (num == 0) return 0; - next = std::min(elements.size(), next + num); - for (size_t i = next - 1; i > num - 1; i--) { - elements[i].get().swap_values(&elements[i - num].get()); + this->next = std::min(this->elements.size(), this->next + num); + for (size_t i = this->next - 1; i > num - 1; i--) { + this->elements[i].get().swap_values(&this->elements[i - num].get()); } - size_t pushed = std::min(elements.size(), num); + size_t pushed = std::min(this->elements.size(), num); for (size_t i = 0; i < pushed; i++) { - elements[i].get().mark_valid(); + this->elements[i].get().mark_valid(); } return pushed; } +// P4_16-conformant implementation of push_front and pop_front + +template +StackP4_16::StackP4_16(const std::string &name, p4object_id_t id) + : Stack >(name, id) { } + template size_t -Stack::pop_back() { +StackP4_16::pop_front() { + if (this->next > 0) this->next--; + auto size = this->elements.size(); + for (size_t i = 0; i < size - 1; i++) { + this->elements[i].get().swap_values(&this->elements[i + 1].get()); + } + this->elements[size - 1].get().mark_invalid(); + return 1u; +} + +template +size_t +StackP4_16::pop_front(size_t num) { + if (num == 0) return 0; + auto size = this->elements.size(); + this->next -= std::min(this->next, num); + size_t i = 0; + for (; i < size - num; i++) { + this->elements[i].get().swap_values(&this->elements[i + num].get()); + } + for (; i < size; i++) { + this->elements[i].get().mark_invalid(); + } + return std::min(num, size); +} + +template +size_t +StackP4_16::push_front() { + auto size = this->elements.size(); + if (this->next < size) this->next++; + for (size_t i = size - 1; i > 0; i--) { + this->elements[i].get().swap_values(&this->elements[i - 1].get()); + } + this->elements[0].get().mark_invalid(); + return 1u; +} + +template +size_t +StackP4_16::push_front(size_t num) { + if (num == 0) return 0; + auto size = this->elements.size(); + this->next = std::min(size, this->next + num); + for (size_t i = size - 1; i > num - 1; i--) { + this->elements[i].get().swap_values(&this->elements[i - num].get()); + } + size_t pushed = std::min(size, num); + for (size_t i = 0; i < pushed; i++) { + this->elements[i].get().mark_invalid(); + } + return pushed; +} + +// explicit instantiation +template class StackLegacy
; +template class StackLegacy; +template class StackP4_16
; +template class StackP4_16; + +} // namespace detail + +template +Stack::Stack(const std::string &name, p4object_id_t id) + : StackIface(), NamedP4Object(name, id) { } + +template +size_t +Stack::pop_front() { + return static_cast(this)->pop_front(); +} + +template +size_t +Stack::pop_front(size_t num) { + return static_cast(this)->pop_front(num); +} + +template +size_t +Stack::push_front() { + return static_cast(this)->push_front(); +} + +template +size_t +Stack::push_front(size_t num) { + return static_cast(this)->push_front(num); +} + +template +size_t +Stack::pop_back() { if (next == 0) return 0u; next--; elements[next].get().mark_invalid(); return 1u; } -template +template size_t -Stack::push_back() { +Stack::push_back() { if (next == elements.size()) return 0u; elements[next].get().mark_valid(); next++; return 1u; } -template +template size_t -Stack::get_depth() const { +Stack::get_depth() const { return elements.size(); } -template +template size_t -Stack::get_count() const { +Stack::get_count() const { return next; } -template +template bool -Stack::is_full() const { +Stack::is_full() const { return (next >= elements.size()); } -template +template void -Stack::reset() { +Stack::reset() { next = 0; } -template +template T & -Stack::get_last() { +Stack::get_last() { assert(next > 0 && "stack empty"); return elements[next - 1]; } -template +template const T & -Stack::get_last() const { +Stack::get_last() const { assert(next > 0 && "stack empty"); return elements[next - 1]; } -template +template T & -Stack::get_next() { +Stack::get_next() { assert(next < elements.size() && "stack full"); return elements[next]; } -template +template const T & -Stack::get_next() const { +Stack::get_next() const { assert(next < elements.size() && "stack full"); return elements[next]; } -template +template T & -Stack::at(size_t idx) { +Stack::at(size_t idx) { return elements.at(idx); } -template +template const T & -Stack::at(size_t idx) const { +Stack::at(size_t idx) const { return elements.at(idx); } -template +template void -Stack::set_next_element(T &e) { // NOLINT(runtime/references) +Stack::set_next_element(T &e) { // NOLINT(runtime/references) elements.emplace_back(e); } // explicit instantiation -template class Stack
; -template class Stack; +template class Stack >; +template class Stack >; +template class Stack >; +template class Stack >; } // namespace bm diff --git a/tests/test_header_stacks.cpp b/tests/test_header_stacks.cpp index 13d74e1cf..d01fb55f1 100644 --- a/tests/test_header_stacks.cpp +++ b/tests/test_header_stacks.cpp @@ -23,11 +23,66 @@ #include #include +#include #include -using namespace bm; +namespace bm { + +namespace testing { + +// Because we want to test both implementations, we cannot use +// PHVFactory::push_back_header_stack. We therefore subclass +// bm::detail::StackLegacy to access the set_next_element method and +// build the header stacks ourselves. + +class HeaderStackLegacy : public detail::StackLegacy
{ + public: + HeaderStackLegacy(const std::string &name, p4object_id_t id) + : detail::StackLegacy(name, id) { } + + void set_next_header(Header &hdr) { // NOLINT(runtime/references) + this->set_next_element(hdr); + } +}; + +// if PushValid is true, emulate legacy behavior by making pushed headers valid +template +class HeaderStackP4_16 : public detail::StackP4_16
{ + public: + HeaderStackP4_16(const std::string &name, p4object_id_t id) + : detail::StackP4_16(name, id) { } + + void set_next_header(Header &hdr) { // NOLINT(runtime/references) + this->set_next_element(hdr); + } + + template + typename std::enable_if::type push_front() { + auto s = detail::StackP4_16
::push_front(); + this->at(0).mark_valid(); + return s; + } + + template + typename std::enable_if::type push_front(size_t num) { + auto s = detail::StackP4_16
::push_front(num); + for (size_t i = 0; i < s; i++) this->at(i).mark_valid(); + return s; + } + + template + typename std::enable_if::type push_front() { + return detail::StackP4_16
::push_front(); + } + + template + typename std::enable_if::type push_front(size_t num) { + return detail::StackP4_16
::push_front(num); + } +}; // Google Test fixture for header stack tests +template class HeaderStackTest : public ::testing::Test { protected: PHVFactory phv_factory; @@ -37,44 +92,57 @@ class HeaderStackTest : public ::testing::Test { header_id_t testHeader_0{0}, testHeader_1{1}, testHeader_2{2}; header_stack_id_t testHeaderStack{0}; + HSType stack; size_t stack_depth{3}; HeaderStackTest() - : testHeaderType("test_t", 0) { + : testHeaderType("test_t", 0), + stack("test_stack", testHeaderStack) { testHeaderType.push_back_field("f16", 16); testHeaderType.push_back_field("f48", 48); phv_factory.push_back_header("test_0", testHeader_0, testHeaderType); phv_factory.push_back_header("test_1", testHeader_1, testHeaderType); phv_factory.push_back_header("test_2", testHeader_2, testHeaderType); - const std::vector headers = - {testHeader_0, testHeader_1, testHeader_2}; - phv_factory.push_back_header_stack("test_stack", testHeaderStack, - testHeaderType, headers); + // Cannot use this if we want to test both stack implementations + // independently of whether preprocessor flag BM_WP4_16_STACKS is used. + + // const std::vector headers = + // {testHeader_0, testHeader_1, testHeader_2}; + // phv_factory.push_back_header_stack("test_stack", testHeaderStack, + // testHeaderType, headers); } virtual void SetUp() { phv = phv_factory.create(); + for (auto header_id : {testHeader_0, testHeader_1, testHeader_2}) + stack.set_next_header(phv->get_header(header_id)); } // virtual void TearDown() {} }; -TEST_F(HeaderStackTest, Basic) { - const HeaderStack &stack = phv->get_header_stack(testHeaderStack); +using HeaderStackTypes = + ::testing::Types >; + +TYPED_TEST_CASE(HeaderStackTest, HeaderStackTypes); + +TYPED_TEST(HeaderStackTest, Basic) { + auto &stack = this->stack; - ASSERT_EQ(stack_depth, stack.get_depth()); + ASSERT_EQ(this->stack_depth, stack.get_depth()); ASSERT_EQ(0u, stack.get_count()); } -TEST_F(HeaderStackTest, PushBack) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PushBack) { + auto &stack = this->stack; + auto *phv = this->phv.get(); - const Header &h0 = phv->get_header(testHeader_0); - const Header &h1 = phv->get_header(testHeader_1); - const Header &h2 = phv->get_header(testHeader_2); + const auto &h0 = phv->get_header(this->testHeader_0); + const auto &h1 = phv->get_header(this->testHeader_1); + const auto &h2 = phv->get_header(this->testHeader_2); ASSERT_FALSE(h0.is_valid()); ASSERT_FALSE(h1.is_valid()); @@ -82,23 +150,24 @@ TEST_F(HeaderStackTest, PushBack) { ASSERT_EQ(0u, stack.get_count()); - for (size_t i = 0; i < stack_depth; i++) { + for (size_t i = 0; i < this->stack_depth; i++) { ASSERT_EQ(1u, stack.push_back()); ASSERT_EQ(i + 1, stack.get_count()); } ASSERT_EQ(0u, stack.push_back()); - ASSERT_EQ(stack_depth, stack.get_count()); + ASSERT_EQ(this->stack_depth, stack.get_count()); ASSERT_TRUE(h0.is_valid()); ASSERT_TRUE(h1.is_valid()); ASSERT_TRUE(h2.is_valid()); } -TEST_F(HeaderStackTest, PopBack) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PopBack) { + auto &stack = this->stack; + auto *phv = this->phv.get(); - const Header &h0 = phv->get_header(testHeader_0); + const auto &h0 = phv->get_header(this->testHeader_0); ASSERT_FALSE(h0.is_valid()); ASSERT_EQ(1u, stack.push_back()); @@ -110,18 +179,19 @@ TEST_F(HeaderStackTest, PopBack) { ASSERT_EQ(0u, stack.pop_back()); // empty so nothing to pop } -TEST_F(HeaderStackTest, PushFront) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PushFront) { + auto &stack = this->stack; + auto *phv = this->phv.get(); - Header &h0 = phv->get_header(testHeader_0); - Header &h1 = phv->get_header(testHeader_1); - Header &h2 = phv->get_header(testHeader_2); + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); ASSERT_FALSE(h0.is_valid()); - Field &f0_0 = h0.get_field(0); - Field &f1_0 = h1.get_field(0); - Field &f2_0 = h2.get_field(0); + auto &f0_0 = h0.get_field(0); + auto &f1_0 = h1.get_field(0); + auto &f2_0 = h2.get_field(0); unsigned int v0 = 10u; unsigned int v1 = 11u; unsigned int v2 = 12u; @@ -158,15 +228,16 @@ TEST_F(HeaderStackTest, PushFront) { ASSERT_EQ(v2, f1_0.get_uint()); } -TEST_F(HeaderStackTest, PushFrontNum) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PushFrontNum) { + auto &stack = this->stack; + auto *phv = this->phv.get(); - Header &h0 = phv->get_header(testHeader_0); - Header &h1 = phv->get_header(testHeader_1); - Header &h2 = phv->get_header(testHeader_2); + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); - Field &f0_0 = h0.get_field(0); - Field &f2_0 = h2.get_field(0); + auto &f0_0 = h0.get_field(0); + auto &f2_0 = h2.get_field(0); unsigned int v0 = 10u; @@ -185,21 +256,22 @@ TEST_F(HeaderStackTest, PushFrontNum) { ASSERT_EQ(v0, f2_0.get_uint()); } -TEST_F(HeaderStackTest, PopFront) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PopFront) { + auto &stack = this->stack; + auto *phv = this->phv.get(); ASSERT_EQ(2u, stack.push_front(2)); // add 2 headers - Header &h0 = phv->get_header(testHeader_0); - Header &h1 = phv->get_header(testHeader_1); - Header &h2 = phv->get_header(testHeader_2); + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); ASSERT_TRUE(h0.is_valid()); ASSERT_TRUE(h1.is_valid()); ASSERT_FALSE(h2.is_valid()); - Field &f0_0 = h0.get_field(0); - Field &f1_0 = h1.get_field(0); + auto &f0_0 = h0.get_field(0); + auto &f1_0 = h1.get_field(0); const unsigned int v0 = 10u; const unsigned int v1 = 11u; const std::string v1_hex("0x000b"); @@ -220,26 +292,30 @@ TEST_F(HeaderStackTest, PopFront) { ASSERT_FALSE(h1.is_valid()); ASSERT_FALSE(h0.is_valid()); - ASSERT_EQ(0u, stack.pop_front()); // empty so nothing popped + // not true for P4_16 stacks for which we always shift independently of the + // value of next + // ASSERT_EQ(0u, stack.pop_front()); // empty so nothing popped + stack.pop_front(); ASSERT_EQ(0u, stack.get_count()); } -TEST_F(HeaderStackTest, PopFrontNum) { - HeaderStack &stack = phv->get_header_stack(testHeaderStack); +TYPED_TEST(HeaderStackTest, PopFrontNum) { + auto &stack = this->stack; + auto *phv = this->phv.get(); ASSERT_EQ(3u, stack.push_front(3)); // add 3 headers - Header &h0 = phv->get_header(testHeader_0); - Header &h1 = phv->get_header(testHeader_1); - Header &h2 = phv->get_header(testHeader_2); + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); ASSERT_TRUE(h0.is_valid()); ASSERT_TRUE(h1.is_valid()); ASSERT_TRUE(h2.is_valid()); - Field &f0_0 = h0.get_field(0); - Field &f1_0 = h1.get_field(0); - Field &f2_0 = h2.get_field(0); + auto &f0_0 = h0.get_field(0); + auto &f1_0 = h1.get_field(0); + auto &f2_0 = h2.get_field(0); unsigned int v0 = 10u; unsigned int v1 = 11u; unsigned int v2 = 12u; f0_0.set(v0); f1_0.set(v1); f2_0.set(v2); @@ -259,9 +335,62 @@ TEST_F(HeaderStackTest, PopFrontNum) { ASSERT_FALSE(h2.is_valid()); ASSERT_EQ(v2, f0_0.get_uint()); - ASSERT_EQ(1u, stack.pop_front(2)); + // not true for P4_16 stacks for which we always shift independently of the + // value of next + // ASSERT_EQ(1u, stack.pop_front(2)); + stack.pop_front(2); ASSERT_EQ(0u, stack.get_count()); ASSERT_FALSE(h0.is_valid()); ASSERT_FALSE(h1.is_valid()); ASSERT_FALSE(h2.is_valid()); } + +// inheritance to reuse members, constructor, etc +class HeaderStackP4_16Test + : public HeaderStackTest > { }; + +TEST_F(HeaderStackP4_16Test, PushFront) { + auto &stack = this->stack; + auto *phv = this->phv.get(); + + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); + + h1.mark_valid(); + EXPECT_EQ(1u, stack.push_front()); + EXPECT_FALSE(h0.is_valid()); // push doesn't make valid + EXPECT_FALSE(h1.is_valid()); + EXPECT_TRUE(h2.is_valid()); + + h0.mark_valid(); + EXPECT_EQ(2u, stack.push_front(2)); + EXPECT_FALSE(h0.is_valid()); + EXPECT_FALSE(h1.is_valid()); + EXPECT_TRUE(h2.is_valid()); +} + +TEST_F(HeaderStackP4_16Test, PopFront) { + auto &stack = this->stack; + auto *phv = this->phv.get(); + + auto &h0 = phv->get_header(this->testHeader_0); + auto &h1 = phv->get_header(this->testHeader_1); + auto &h2 = phv->get_header(this->testHeader_2); + + h0.mark_valid(); h2.mark_valid(); + EXPECT_EQ(1u, stack.pop_front()); + EXPECT_FALSE(h0.is_valid()); + EXPECT_TRUE(h1.is_valid()); + EXPECT_FALSE(h2.is_valid()); + + h2.mark_valid(); + EXPECT_EQ(2u, stack.pop_front(2)); + EXPECT_TRUE(h0.is_valid()); + EXPECT_FALSE(h1.is_valid()); + EXPECT_FALSE(h2.is_valid()); +} + +} // namespace testing + +} // namespace bm diff --git a/tests/test_phv.cpp b/tests/test_phv.cpp index b12c9ec29..7916cb330 100644 --- a/tests/test_phv.cpp +++ b/tests/test_phv.cpp @@ -101,15 +101,18 @@ TEST_F(PHVTest, CopyHeadersWithStack) { auto &stack = phv->get_header_stack(testHeaderStack); ASSERT_EQ(1u, stack.push_back()); - ASSERT_TRUE(phv->get_header(testHeader1).is_valid()); - - ASSERT_FALSE(phv_2->get_header(testHeader1).is_valid()); + // needs to work for both legacy stacks and P4_16 stacks, so we explicitly + // mark the header valid + phv->get_header(testHeader1).mark_valid(); + EXPECT_TRUE(phv->get_header(testHeader1).is_valid()); + EXPECT_FALSE(phv_2->get_header(testHeader1).is_valid()); + EXPECT_EQ(stack.get_count(), 1u); phv_2->copy_headers(*phv); const auto &stack_2 = phv_2->get_header_stack(testHeaderStack); - ASSERT_EQ(stack_2.get_count(), stack.get_count()); - ASSERT_TRUE(phv_2->get_header(testHeader1).is_valid()); + EXPECT_EQ(stack_2.get_count(), stack.get_count()); + EXPECT_TRUE(phv_2->get_header(testHeader1).is_valid()); } TEST_F(PHVTest, CopyHeadersWithUnion) {