From d4fd63abf048b7855ce33751dc4aaa77fc16b6c8 Mon Sep 17 00:00:00 2001 From: Marcelo Shima Date: Tue, 6 Feb 2024 12:06:34 -0300 Subject: [PATCH] improve create junit integration test --- .../web/rest/_entityClass_ResourceIT.java.ejs | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/generators/spring-boot/templates/src/test/java/_package_/_entityPackage_/web/rest/_entityClass_ResourceIT.java.ejs b/generators/spring-boot/templates/src/test/java/_package_/_entityPackage_/web/rest/_entityClass_ResourceIT.java.ejs index 5ea63c213dff..081da9630489 100644 --- a/generators/spring-boot/templates/src/test/java/_package_/_entityPackage_/web/rest/_entityClass_ResourceIT.java.ejs +++ b/generators/spring-boot/templates/src/test/java/_package_/_entityPackage_/web/rest/_entityClass_ResourceIT.java.ejs @@ -652,7 +652,7 @@ _%> @Test<%= transactionalAnnotation %> void create<%= entityClass %>() throws Exception { - int databaseSizeBeforeCreate = <%= entityInstance %>Repository.findAll()<%= callListBlock %>.size(); + long databaseSizeBeforeCreate = getRepositoryCount(); <%_ if (searchEngineElasticsearch) { _%> int searchDatabaseSizeBefore = IterableUtil.sizeOf(<%= entityInstance %>SearchRepository.findAll()<%= callListBlock %>); <%_ } _%> @@ -668,31 +668,46 @@ _%> <%= dtoClass %> <%= dtoInstance %> = <%= entityInstance %>Mapper.toDto(<%= persistInstance %>); <%_ } _%> <%_ if (reactive) { _%> - webTestClient.post().uri(ENTITY_API_URL) + var returned<%- restClass %> = webTestClient.post().uri(ENTITY_API_URL) .contentType(MediaType.APPLICATION_JSON) .bodyValue(om.writeValueAsBytes(<%= restInstance %>)) .exchange() - .expectStatus().isCreated(); + .expectStatus() + .isCreated() + .expectBody(<%- persistClass %>.class) + .returnResult() + .getResponseBody(); <%_ } else { _%> - rest<%= entityClass %>MockMvc.perform(post(ENTITY_API_URL)<% if (authenticationUsesCsrf) { %>.with(csrf())<% }%> - .contentType(MediaType.APPLICATION_JSON) - .content(om.writeValueAsBytes(<%= restInstance %>))) - .andExpect(status().isCreated()); + var returned<%- restClass %> = om.readValue( + rest<%= entityClass %>MockMvc + .perform( + post(ENTITY_API_URL) + <%_ if (authenticationUsesCsrf) { _%> + .with(csrf()) + <%_ } _%> + .contentType(MediaType.APPLICATION_JSON) + .content(om.writeValueAsBytes(<%= restInstance %>)) + ) + .andExpect(status().isCreated()) + .andReturn() + .getResponse() + .getContentAsString(), + <%- restClass %>.class + ); <%_ } _%> // Validate the <%= entityClass %> in the database <%_ if (databaseTypeCouchbase) { _%> SecurityContextHolder.setContext(TestSecurityContextHolder.getContext()); <%_ } _%> - List<<%= persistClass %>> <%= entityInstance %>List = <%= entityInstance %>Repository.findAll()<%= callListBlock %>; - assertThat(<%= entityInstance %>List).hasSize(databaseSizeBeforeCreate + 1); + assertIncrementedRepositoryCount(databaseSizeBeforeCreate); <%_ if (searchEngineElasticsearch) { _%> await().atMost(5, TimeUnit.SECONDS).untilAsserted(() -> { int searchDatabaseSizeAfter = IterableUtil.sizeOf(<%= entityInstance %>SearchRepository.findAll()<%= callListBlock %>); assertThat(searchDatabaseSizeAfter).isEqualTo(searchDatabaseSizeBefore + 1); }); <%_ } _%> - <%= persistClass %> test<%= entityClass %> = <%= entityInstance %>List.get(<%= entityInstance %>List.size() - 1); + <%= persistClass %> test<%= entityClass %> = <%= entityInstance %>Repository.findById(returned<%- restClass %>.get<%- primaryKey.nameCapitalized %>())<%= reactive ? callBlock : '.orElseThrow()' %>; <%_ for (const field of fieldsToTest) { if (field.fieldTypeZonedDateTime) { _%> assertThat(test<%= entityClass %>.get<%= field.fieldInJavaBeanMethod %>()).isEqualTo(<%= 'DEFAULT_' + field.fieldNameUnderscored.toUpperCase() %>); @@ -724,7 +739,7 @@ _%> <%= dtoClass %> <%= dtoInstance %> = <%= entityInstance %>Mapper.toDto(<%= persistInstance %>); <%_ } _%> - int databaseSizeBeforeCreate = <%= entityInstance %>Repository.findAll()<%= callListBlock %>.size(); + long databaseSizeBeforeCreate = getRepositoryCount(); <%_ if (searchEngineElasticsearch) { _%> int searchDatabaseSizeBefore = IterableUtil.sizeOf(<%= entityInstance %>SearchRepository.findAll()<%= callListBlock %>); <%_ } _%> @@ -748,7 +763,7 @@ _%> SecurityContextHolder.setContext(TestSecurityContextHolder.getContext()); <%_ } _%> List<<%= persistClass %>> <%= entityInstance %>List = <%= entityInstance %>Repository.findAll()<%= callListBlock %>; - assertThat(<%= entityInstance %>List).hasSize(databaseSizeBeforeCreate); + assertSameRepositoryCount(databaseSizeBeforeCreate); <%_ if (searchEngineElasticsearch) { _%> int searchDatabaseSizeAfter = IterableUtil.sizeOf(<%= entityInstance %>SearchRepository.findAll()<%= callListBlock %>); assertThat(searchDatabaseSizeAfter).isEqualTo(searchDatabaseSizeBefore); @@ -761,7 +776,7 @@ _%> // Initialize the database <%= entityInstance %>Repository.<%= saveMethod %>(<%= persistInstance %>)<%= callBlock %>; <%_ const alreadyGeneratedEntities = []; _%> - int databaseSizeBeforeCreate = <%= entityInstance %>Repository.findAll()<%= callListBlock %>.size(); + long databaseSizeBeforeCreate = getRepositoryCount(); <%_ if (searchEngineElasticsearch) { _%> int searchDatabaseSizeBefore = IterableUtil.sizeOf(<%= entityInstance %>SearchRepository.findAll()<%= callListBlock %>); <%_ } _%> @@ -816,7 +831,7 @@ _%> SecurityContextHolder.setContext(TestSecurityContextHolder.getContext()); <%_ } _%> List<<%= persistClass %>> <%= entityInstance %>List = <%= entityInstance %>Repository.findAll()<%= callListBlock %>; - assertThat(<%= entityInstance %>List).hasSize(databaseSizeBeforeCreate); + assertSameRepositoryCount(databaseSizeBeforeCreate); <%= persistClass %> test<%= entityClass %> = <%= entityInstance %>List.get(<%= entityInstance %>List.size() - 1); // Validate the id for MapsId, the ids must be same @@ -1865,4 +1880,16 @@ _%> if (!field.fieldTypeString) { %>.toString()<% } %>))<%= !reactive ? ')' : '' %><%_ } _%>; } <%_ } _%> + + protected long getRepositoryCount() { + return <%= entityInstance %>Repository.count()<%= callBlock %>; + } + + protected void assertIncrementedRepositoryCount(long countBefore) { + assertThat(countBefore + 1).isEqualTo(getRepositoryCount()); + } + + protected void assertSameRepositoryCount(long countBefore) { + assertThat(countBefore).isEqualTo(getRepositoryCount()); + } }