diff --git a/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAO.java b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAO.java index 2e686f308..cc7dd302e 100644 --- a/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAO.java +++ b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAO.java @@ -2,30 +2,15 @@ import com.linkedin.common.urn.Urn; import com.linkedin.data.template.RecordTemplate; -import com.linkedin.metadata.dao.exception.RetryLimitReached; -import com.linkedin.metadata.dao.utils.Statement; -import com.linkedin.metadata.validator.EntityValidator; -import com.linkedin.metadata.validator.RelationshipValidator; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; import javax.annotation.Nonnull; - -import lombok.AllArgsConstructor; -import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang.time.StopWatch; import org.neo4j.driver.Driver; -import org.neo4j.driver.Record; -import org.neo4j.driver.Session; +import org.neo4j.driver.Query; import org.neo4j.driver.SessionConfig; -import org.neo4j.driver.exceptions.Neo4jException; import static com.linkedin.metadata.dao.Neo4jUtil.*; import static com.linkedin.metadata.dao.utils.ModelUtils.*; @@ -123,14 +108,23 @@ public void onRelationshipsRemoved(int relationshipCount, long updateTimeMs, int } } - private static final int MAX_TRANSACTION_RETRY = 3; - private final Driver _driver; - private SessionConfig _sessionConfig; - private static Map _urnToEntityMap = null; private DelegateMetricListener _metricListener = new DelegateMetricListener(); + private final Neo4jQueriesTransformer _queriesTransformer; + private final Neo4jQueryExecutor _queryExecutor; + + private Neo4jGraphWriterDAO(@Nonnull Neo4jQueriesTransformer queriesTransformer, + @Nonnull Neo4jQueryExecutor queryExecutor) { + _queriesTransformer = queriesTransformer; + _queryExecutor = queryExecutor; + } public Neo4jGraphWriterDAO(@Nonnull Driver driver) { - this(driver, SessionConfig.defaultConfig()); + this(new Neo4jQueriesTransformer(), new Neo4jQueryExecutor(driver)); + } + + /* Should only be used for testing */ + public Neo4jGraphWriterDAO(@Nonnull Driver driver, @Nonnull Set> allEntities) { + this(new Neo4jQueriesTransformer(allEntities), new Neo4jQueryExecutor(driver)); } /** @@ -140,19 +134,7 @@ public Neo4jGraphWriterDAO(@Nonnull Driver driver) { * And Java 11 build is blocked by ES7 migration. */ public Neo4jGraphWriterDAO(@Nonnull Driver driver, @Nonnull String databaseName) { - this(driver, SessionConfig.forDatabase(databaseName)); - } - - public Neo4jGraphWriterDAO(@Nonnull Driver driver, @Nonnull SessionConfig sessionConfig) { - this(driver, sessionConfig, getAllEntities()); - } - - /* Should only be used for testing */ - public Neo4jGraphWriterDAO(@Nonnull Driver driver, @Nonnull SessionConfig sessionConfig, - @Nonnull Set> allEntities) { - this._driver = driver; - this._sessionConfig = sessionConfig; - buildUrnToEntityMap(allEntities); + this(new Neo4jQueriesTransformer(), new Neo4jQueryExecutor(driver, SessionConfig.forDatabase(databaseName))); } public void addMetricListener(@Nonnull MetricListener metricListener) { @@ -161,381 +143,92 @@ public void addMetricListener(@Nonnull MetricListener metricListener) { @Override public void addEntities(@Nonnull List entities) { + final List list = new ArrayList<>(); - for (ENTITY entity1 : entities) { - EntityValidator.validateEntitySchema(entity1.getClass()); - } - List list = new ArrayList<>(); for (ENTITY entity : entities) { - Statement statement = addNode(entity); - list.add(statement); + list.add(_queriesTransformer.addEntityQuery(entity)); } - final ExecutionResult e = executeStatements(list); - log.trace("Added {} entities over {} retries, which took {} millis", entities.size(), e.getTookMs(), - e.getRetries()); - _metricListener.onEntitiesAdded(entities.size(), e.getTookMs(), e.getRetries()); + final Neo4jQueryResult result = _queryExecutor.execute(list); + log.trace("Added {} entities over {} retries, which took {} millis", entities.size(), result.getTookMs(), + result.getRetries()); + _metricListener.onEntitiesAdded(entities.size(), result.getTookMs(), result.getRetries()); } @Override public void removeEntities(@Nonnull List urns) { - List list = new ArrayList<>(); + final List list = new ArrayList<>(); for (URN urn : urns) { - Statement statement = removeNode(urn); - list.add(statement); + list.add(_queriesTransformer.removeEntityQuery(urn)); } - final ExecutionResult e = executeStatements(list); - log.trace("Removed {} entities over {} retries, which took {} millis", urns.size(), e.getTookMs(), - e.getRetries()); - _metricListener.onEntitiesRemoved(urns.size(), e.getTookMs(), e.getRetries()); + final Neo4jQueryResult result = _queryExecutor.execute(list); + log.trace("Removed {} entities over {} retries, which took {} millis", urns.size(), result.getTookMs(), + result.getRetries()); + _metricListener.onEntitiesRemoved(urns.size(), result.getTookMs(), result.getRetries()); } @Override public void addRelationships(@Nonnull List relationships, @Nonnull RemovalOption removalOption) { - - for (RELATIONSHIP relationship : relationships) { - RelationshipValidator.validateRelationshipSchema(relationship.getClass()); - } - - final ExecutionResult e = executeStatements(addEdges(relationships, removalOption)); - log.trace("Added {} relationships over {} retries, which took {} millis", relationships.size(), e.getTookMs(), - e.getRetries()); - _metricListener.onRelationshipsAdded(relationships.size(), e.getTookMs(), e.getRetries()); - } - - @Override - public void removeRelationships(@Nonnull List relationships) { - - for (RELATIONSHIP relationship : relationships) { - RelationshipValidator.validateRelationshipSchema(relationship.getClass()); - } - List list = new ArrayList<>(); - for (RELATIONSHIP relationship : relationships) { - Statement statement = removeEdge(relationship); - list.add(statement); + if (relationships.isEmpty()) { + return; } - final ExecutionResult e = executeStatements(list); - log.trace("Removed {} relationships over {} retries, which took {} millis", relationships.size(), e.getTookMs(), - e.getRetries()); - _metricListener.onRelationshipsRemoved(relationships.size(), e.getTookMs(), e.getRetries()); - } + final List list = new ArrayList<>(); - @AllArgsConstructor - @Data - private static final class ExecutionResult { - private long tookMs; - private int retries; - } + _queriesTransformer.relationshipRemovalOptionQuery(relationships.get(0), removalOption).ifPresent(list::add); - /** - * Executes a list of statements with parameters in one transaction. - * - * @param statements List of statements with parameters to be executed in order - */ - private ExecutionResult executeStatements(@Nonnull List statements) { - int retry = 0; - final StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - Exception lastException; - try (final Session session = _driver.session(_sessionConfig)) { - do { - try { - session.writeTransaction(tx -> { - for (Statement statement : statements) { - tx.run(statement.getCommandText(), statement.getParams()); - } - return 0; - }); - lastException = null; - break; - } catch (Neo4jException e) { - lastException = e; - } - } while (++retry <= MAX_TRANSACTION_RETRY); - } + checkSameUrn(relationships, removalOption); - if (lastException != null) { - throw new RetryLimitReached("Failed to execute Neo4j write transaction after " - + MAX_TRANSACTION_RETRY + " retries", lastException); + for (RELATIONSHIP relationship : relationships) { + list.add(_queriesTransformer.addRelationshipQuery(relationship)); } - stopWatch.stop(); - return new ExecutionResult(stopWatch.getTime(), retry); + final Neo4jQueryResult result = _queryExecutor.execute(list); + log.trace("Added {} relationships over {} retries, which took {} millis", relationships.size(), result.getTookMs(), + result.getRetries()); + _metricListener.onRelationshipsAdded(relationships.size(), result.getTookMs(), result.getRetries()); } - /** - * Run a query statement with parameters and return StatementResult. - * - * @param statement a statement with parameters to be executed - */ - @Nonnull - private List runQuery(@Nonnull Statement statement) { - try (final Session session = _driver.session(_sessionConfig)) { - return session.run(statement.getCommandText(), statement.getParams()).list(); + @Override + public void removeRelationships(@Nonnull List relationships) { + if (relationships.isEmpty()) { + return; } - } - // used in testing - @Nonnull - Optional> getNode(@Nonnull Urn urn) { - List> nodes = getAllNodes(urn); - if (nodes.isEmpty()) { - return Optional.empty(); + final List list = new ArrayList<>(); + for (RELATIONSHIP relationship : relationships) { + list.add(_queriesTransformer.removeEdge(relationship)); } - return Optional.of(nodes.get(0)); - } - - // used in testing - @Nonnull - List> getAllNodes(@Nonnull Urn urn) { - final String matchTemplate = "MATCH (node%s {urn: $urn}) RETURN node"; - - final String sourceType = getNodeType(urn); - final String statement = String.format(matchTemplate, sourceType); - - final Map params = new HashMap<>(); - params.put("urn", urn.toString()); - - final List result = runQuery(buildStatement(statement, params)); - return result.stream().map(record -> record.values().get(0).asMap()).collect(Collectors.toList()); - } - - // used in testing - @Nonnull - List> getEdges(@Nonnull RELATIONSHIP relationship) { - final Urn sourceUrn = getSourceUrnFromRelationship(relationship); - final Urn destinationUrn = getDestinationUrnFromRelationship(relationship); - final String relationshipType = getType(relationship); - - final String sourceType = getNodeType(sourceUrn); - final String destinationType = getNodeType(destinationUrn); - - final String matchTemplate = - "MATCH (source%s {urn: $sourceUrn})-[r:%s]->(destination%s {urn: $destinationUrn}) RETURN r"; - final String statement = String.format(matchTemplate, sourceType, relationshipType, destinationType); - - final Map params = new HashMap<>(); - params.put("sourceUrn", sourceUrn.toString()); - params.put("destinationUrn", destinationUrn.toString()); - - final List result = runQuery(buildStatement(statement, params)); - return result.stream().map(record -> record.values().get(0).asMap()).collect(Collectors.toList()); - } - - // used in testing - @Nonnull - List> getEdgesFromSource( - @Nonnull Urn sourceUrn, @Nonnull Class relationshipClass) { - final String relationshipType = getType(relationshipClass); - final String sourceType = getNodeType(sourceUrn); - - final String matchTemplate = "MATCH (source%s {urn: $sourceUrn})-[r:%s]->() RETURN r"; - final String statement = String.format(matchTemplate, sourceType, relationshipType); - - final Map params = new HashMap<>(); - params.put("sourceUrn", sourceUrn.toString()); - - final List result = runQuery(buildStatement(statement, params)); - return result.stream().map(record -> record.values().get(0).asMap()).collect(Collectors.toList()); - } - - @Nonnull - private Statement addNode(@Nonnull ENTITY entity) { - final Urn urn = getUrnFromEntity(entity); - final String nodeType = getNodeType(urn); - - // Use += to ensure this doesn't override the node but merges in the new properties to allow for partial updates. - final String mergeTemplate = "MERGE (node%s {urn: $urn}) SET node += $properties RETURN node"; - final String statement = String.format(mergeTemplate, nodeType); - - final Map params = new HashMap<>(); - params.put("urn", urn.toString()); - final Map props = entityToNode(entity); - props.remove("urn"); // no need to set twice (this is implied by MERGE), and they can be quite long. - params.put("properties", props); - - return buildStatement(statement, params); - } - - @Nonnull - private Statement removeNode(@Nonnull URN urn) { - // also delete any relationship going to or from it - final String nodeType = getNodeType(urn); - - final String matchTemplate = "MATCH (node%s {urn: $urn}) DETACH DELETE node"; - final String statement = String.format(matchTemplate, nodeType); - - final Map params = new HashMap<>(); - params.put("urn", urn.toString()); - - return buildStatement(statement, params); - } - - /** - * Gets Node based on Urn, if not exist, creates placeholder node. - */ - @Nonnull - private Statement getOrInsertNode(@Nonnull Urn urn) { - final String nodeType = getNodeType(urn); - final String mergeTemplate = "MERGE (node%s {urn: $urn}) RETURN node"; - final String statement = String.format(mergeTemplate, nodeType); - - final Map params = new HashMap<>(); - params.put("urn", urn.toString()); - - return buildStatement(statement, params); + final Neo4jQueryResult result = _queryExecutor.execute(list); + log.trace("Removed {} relationships over {} retries, which took {} millis", relationships.size(), + result.getTookMs(), result.getRetries()); + _metricListener.onRelationshipsRemoved(relationships.size(), result.getTookMs(), result.getRetries()); } - @Nonnull - private - List addEdges(@Nonnull List relationships, @Nonnull RemovalOption removalOption) { - - // if no relationships, return - if (relationships.isEmpty()) { - return Collections.emptyList(); - } - - final List statements = new ArrayList<>(); - - // remove existing edges according to RemovalOption + private void checkSameUrn(@Nonnull List relationships, + @Nonnull RemovalOption removalOption) { final Urn source0Urn = getSourceUrnFromRelationship(relationships.get(0)); final Urn destination0Urn = getDestinationUrnFromRelationship(relationships.get(0)); - final String relationType = getType(relationships.get(0)); - - final String sourceType = getNodeType(source0Urn); - final String destinationType = getNodeType(destination0Urn); - - final Map params = new HashMap<>(); if (removalOption == RemovalOption.REMOVE_ALL_EDGES_FROM_SOURCE) { checkSameUrn(relationships, SOURCE_FIELD, source0Urn); - - final String removeTemplate = "MATCH (source%s {urn: $urn})-[relation:%s]->() DELETE relation"; - final String statement = String.format(removeTemplate, sourceType, relationType); - - params.put("urn", source0Urn.toString()); - - statements.add(buildStatement(statement, params)); } else if (removalOption == RemovalOption.REMOVE_ALL_EDGES_TO_DESTINATION) { checkSameUrn(relationships, DESTINATION_FIELD, destination0Urn); - - final String removeTemplate = "MATCH ()-[relation:%s]->(destination%s {urn: $urn}) DELETE relation"; - final String statement = String.format(removeTemplate, relationType, destinationType); - - params.put("urn", destination0Urn.toString()); - - statements.add(buildStatement(statement, params)); } else if (removalOption == RemovalOption.REMOVE_ALL_EDGES_FROM_SOURCE_TO_DESTINATION) { checkSameUrn(relationships, SOURCE_FIELD, source0Urn); checkSameUrn(relationships, DESTINATION_FIELD, destination0Urn); - - final String removeTemplate = - "MATCH (source%s {urn: $sourceUrn})-[relation:%s]->(destination%s {urn: $destinationUrn}) DELETE relation"; - final String statement = String.format(removeTemplate, sourceType, relationType, destinationType); - - params.put("sourceUrn", source0Urn.toString()); - params.put("destinationUrn", destination0Urn.toString()); - - statements.add(buildStatement(statement, params)); - } - - for (RELATIONSHIP relationship : relationships) { - final Urn srcUrn = getSourceUrnFromRelationship(relationship); - final Urn destUrn = getDestinationUrnFromRelationship(relationship); - final String sourceNodeType = getNodeType(srcUrn); - final String destinationNodeType = getNodeType(destUrn); - - // Add/Update source & destination node first - statements.add(getOrInsertNode(srcUrn)); - statements.add(getOrInsertNode(destUrn)); - - // Add/Update relationship - final String mergeRelationshipTemplate = - "MATCH (source%s {urn: $sourceUrn}),(destination%s {urn: $destinationUrn}) MERGE (source)-[r:%s]->(destination) SET r += $properties"; - final String statement = - String.format(mergeRelationshipTemplate, sourceNodeType, destinationNodeType, getType(relationship)); - - final Map paramsMerge = new HashMap<>(); - paramsMerge.put("sourceUrn", srcUrn.toString()); - paramsMerge.put("destinationUrn", destUrn.toString()); - paramsMerge.put("properties", relationshipToEdge(relationship)); - - statements.add(buildStatement(statement, paramsMerge)); } - - return statements; } - private void checkSameUrn(@Nonnull List records, @Nonnull String field, + private void checkSameUrn(@Nonnull List records, @Nonnull String field, @Nonnull Urn compare) { - for (T relation : records) { + for (RecordTemplate relation : records) { if (!compare.equals(getRecordTemplateField(relation, field, Urn.class))) { throw new IllegalArgumentException("Records have different " + field + " urn"); } } } - - @Nonnull - private Statement removeEdge(@Nonnull RELATIONSHIP relationship) { - - final Urn sourceUrn = getSourceUrnFromRelationship(relationship); - final Urn destinationUrn = getDestinationUrnFromRelationship(relationship); - - final String sourceType = getNodeType(sourceUrn); - final String destinationType = getNodeType(destinationUrn); - - final String removeMatchTemplate = - "MATCH (source%s {urn: $sourceUrn})-[relation:%s %s]->(destination%s {urn: $destinationUrn}) DELETE relation"; - final String criteria = relationshipToCriteria(relationship); - final String statement = - String.format(removeMatchTemplate, sourceType, getType(relationship), criteria, destinationType); - - final Map params = new HashMap<>(); - params.put("sourceUrn", sourceUrn.toString()); - params.put("destinationUrn", destinationUrn.toString()); - - return buildStatement(statement, params); - } - - // visible for testing - @Nonnull - Statement buildStatement(@Nonnull String queryTemplate, @Nonnull Map params) { - for (Map.Entry entry : params.entrySet()) { - String k = entry.getKey(); - Object v = entry.getValue(); - params.put(k, toPropertyValue(v)); - } - return new Statement(queryTemplate, params); - } - - @Nonnull - private Object toPropertyValue(@Nonnull Object obj) { - if (obj instanceof Urn) { - return obj.toString(); - } - return obj; - } - - @Nonnull - public String getNodeType(@Nonnull Urn urn) { - return ":" + _urnToEntityMap.getOrDefault(urn.getEntityType(), "UNKNOWN"); - } - - @Nonnull - private Map buildUrnToEntityMap(@Nonnull Set> entitiesSet) { - if (_urnToEntityMap == null) { - Map map = new HashMap<>(); - for (Class entity : entitiesSet) { - if (map.put(getEntityTypeFromUrnClass(urnClassForEntity(entity)), getType(entity)) != null) { - throw new IllegalStateException("Duplicate key"); - } - } - _urnToEntityMap = map; - } - return _urnToEntityMap; - } } diff --git a/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformer.java b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformer.java new file mode 100644 index 000000000..cb0f9eaa0 --- /dev/null +++ b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformer.java @@ -0,0 +1,195 @@ +package com.linkedin.metadata.dao.internal; + +import com.linkedin.common.urn.Urn; +import com.linkedin.data.template.RecordTemplate; +import com.linkedin.metadata.validator.EntityValidator; +import com.linkedin.metadata.validator.RelationshipValidator; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import javax.annotation.Nonnull; +import org.neo4j.driver.Query; + +import static com.linkedin.metadata.dao.Neo4jUtil.*; +import static com.linkedin.metadata.dao.utils.ModelUtils.*; + + +/** + * Can transform GMA entities and relationships into Neo4j queries for upserting. + * + *

This separates out transformation logic from query execution logic ({@link Neo4jGraphWriterDAO}). + */ +public final class Neo4jQueriesTransformer { + private static final Map DEFAULT_URN_TO_ENTITY_MAP = buildUrnToEntityMap(getAllEntities()); + private final Map _urnToEntityMap; + + public Neo4jQueriesTransformer() { + this(DEFAULT_URN_TO_ENTITY_MAP); + } + + /** + * For use in unit testing. + */ + public Neo4jQueriesTransformer(@Nonnull Set> entitiesSet) { + this(buildUrnToEntityMap(entitiesSet)); + } + + private Neo4jQueriesTransformer(@Nonnull Map urnToEntityMap) { + _urnToEntityMap = urnToEntityMap; + } + + @Nonnull + private static Map buildUrnToEntityMap(@Nonnull Set> entitiesSet) { + Map map = new HashMap<>(); + for (Class entity : entitiesSet) { + if (map.put(getEntityTypeFromUrnClass(urnClassForEntity(entity)), getType(entity)) != null) { + throw new IllegalStateException("Duplicate key"); + } + } + return map; + } + + @Nonnull + private Object toPropertyValue(@Nonnull Object obj) { + if (obj instanceof Urn) { + return obj.toString(); + } + return obj; + } + + // visible for testing + @Nonnull + String getNodeType(@Nonnull Urn urn) { + return ":" + _urnToEntityMap.getOrDefault(urn.getEntityType(), "UNKNOWN"); + } + + @Nonnull + private Query buildQuery(@Nonnull String queryTemplate, @Nonnull Map params) { + for (Map.Entry entry : params.entrySet()) { + String k = entry.getKey(); + Object v = entry.getValue(); + params.put(k, toPropertyValue(v)); + } + return new Query(queryTemplate, params); + } + + @Nonnull + public Query addEntityQuery(@Nonnull RecordTemplate entity) { + EntityValidator.validateEntitySchema(entity.getClass()); + + final Urn urn = getUrnFromEntity(entity); + final String nodeType = getNodeType(urn); + + // Use += to ensure this doesn't override the node but merges in the new properties to allow for partial updates. + final String mergeTemplate = "MERGE (node%s {urn: $urn}) SET node += $properties RETURN node"; + final String statement = String.format(mergeTemplate, nodeType); + + final Map params = new HashMap<>(); + params.put("urn", urn.toString()); + final Map props = entityToNode(entity); + props.remove("urn"); // no need to set twice (this is implied by MERGE), and they can be quite long. + params.put("properties", props); + + return buildQuery(statement, params); + } + + @Nonnull + public Query removeEntityQuery(@Nonnull Urn urn) { + // also delete any relationship going to or from it + final String nodeType = getNodeType(urn); + + final String matchTemplate = "MATCH (node%s {urn: $urn}) DETACH DELETE node"; + final String statement = String.format(matchTemplate, nodeType); + + final Map params = new HashMap<>(); + params.put("urn", urn.toString()); + + return buildQuery(statement, params); + } + + @Nonnull + public Optional relationshipRemovalOptionQuery(@Nonnull RecordTemplate relationship, + BaseGraphWriterDAO.RemovalOption removalOption) { + // remove existing edges according to RemovalOption + final Urn source0Urn = getSourceUrnFromRelationship(relationship); + final Urn destination0Urn = getDestinationUrnFromRelationship(relationship); + final String relationType = getType(relationship); + + final String sourceType = getNodeType(source0Urn); + final String destinationType = getNodeType(destination0Urn); + + final Map params = new HashMap<>(); + + if (removalOption == BaseGraphWriterDAO.RemovalOption.REMOVE_ALL_EDGES_FROM_SOURCE) { + final String removeTemplate = "MATCH (source%s {urn: $urn})-[relation:%s]->() DELETE relation"; + final String statement = String.format(removeTemplate, sourceType, relationType); + + params.put("urn", source0Urn.toString()); + + return Optional.of(buildQuery(statement, params)); + } else if (removalOption == BaseGraphWriterDAO.RemovalOption.REMOVE_ALL_EDGES_TO_DESTINATION) { + final String removeTemplate = "MATCH ()-[relation:%s]->(destination%s {urn: $urn}) DELETE relation"; + final String statement = String.format(removeTemplate, relationType, destinationType); + + params.put("urn", destination0Urn.toString()); + + return Optional.of(buildQuery(statement, params)); + } else if (removalOption == BaseGraphWriterDAO.RemovalOption.REMOVE_ALL_EDGES_FROM_SOURCE_TO_DESTINATION) { + final String removeTemplate = + "MATCH (source%s {urn: $sourceUrn})-[relation:%s]->(destination%s {urn: $destinationUrn}) DELETE relation"; + final String statement = String.format(removeTemplate, sourceType, relationType, destinationType); + + params.put("sourceUrn", source0Urn.toString()); + params.put("destinationUrn", destination0Urn.toString()); + + return Optional.of(buildQuery(statement, params)); + } + + return Optional.empty(); + } + + @Nonnull + public Query addRelationshipQuery(@Nonnull RecordTemplate relationship) { + RelationshipValidator.validateRelationshipSchema(relationship.getClass()); + final Urn srcUrn = getSourceUrnFromRelationship(relationship); + final Urn destUrn = getDestinationUrnFromRelationship(relationship); + final String sourceNodeType = getNodeType(srcUrn); + final String destinationNodeType = getNodeType(destUrn); + + // Add/Update relationship. Use MERGE on nodes to prevent needing to have separate queries to create them. + final String mergeRelationshipTemplate = + "MERGE (source%s {urn: $sourceUrn}) " + "MERGE (destination%s {urn: $destinationUrn}) " + + "MERGE (source)-[r:%s]->(destination) SET r += $properties"; + final String statement = + String.format(mergeRelationshipTemplate, sourceNodeType, destinationNodeType, getType(relationship)); + + final Map paramsMerge = new HashMap<>(); + paramsMerge.put("sourceUrn", srcUrn.toString()); + paramsMerge.put("destinationUrn", destUrn.toString()); + paramsMerge.put("properties", relationshipToEdge(relationship)); + + return new Query(statement, paramsMerge); + } + + @Nonnull + public Query removeEdge(@Nonnull RecordTemplate relationship) { + final Urn sourceUrn = getSourceUrnFromRelationship(relationship); + final Urn destinationUrn = getDestinationUrnFromRelationship(relationship); + + final String sourceType = getNodeType(sourceUrn); + final String destinationType = getNodeType(destinationUrn); + + final String removeMatchTemplate = + "MATCH (source%s {urn: $sourceUrn})-[relation:%s %s]->(destination%s {urn: $destinationUrn}) DELETE relation"; + final String criteria = relationshipToCriteria(relationship); + final String statement = + String.format(removeMatchTemplate, sourceType, getType(relationship), criteria, destinationType); + + final Map params = new HashMap<>(); + params.put("sourceUrn", sourceUrn.toString()); + params.put("destinationUrn", destinationUrn.toString()); + + return buildQuery(statement, params); + } +} diff --git a/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryExecutor.java b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryExecutor.java new file mode 100644 index 000000000..a4ebbb726 --- /dev/null +++ b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryExecutor.java @@ -0,0 +1,64 @@ +package com.linkedin.metadata.dao.internal; + +import com.linkedin.metadata.dao.exception.RetryLimitReached; +import java.util.List; +import javax.annotation.Nonnull; +import org.apache.commons.lang.time.StopWatch; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Query; +import org.neo4j.driver.Session; +import org.neo4j.driver.SessionConfig; +import org.neo4j.driver.exceptions.Neo4jException; + + +public final class Neo4jQueryExecutor { + private static final int MAX_TRANSACTION_RETRY = 3; + private final Driver _driver; + private final SessionConfig _sessionConfig; + + public Neo4jQueryExecutor(@Nonnull Driver driver, @Nonnull SessionConfig sessionConfig) { + _driver = driver; + _sessionConfig = sessionConfig; + } + + public Neo4jQueryExecutor(@Nonnull Driver driver) { + this(driver, SessionConfig.defaultConfig()); + } + + /** + * Executes a list of queries with parameters in one transaction. + * + * @param queries List of queries with parameters to be executed in order + */ + @Nonnull + public Neo4jQueryResult execute(@Nonnull List queries) { + int retry = 0; + final StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + Exception lastException; + try (final Session session = _driver.session(_sessionConfig)) { + do { + try { + session.writeTransaction(tx -> { + for (Query query : queries) { + tx.run(query); + } + return null; + }); + lastException = null; + break; + } catch (Neo4jException e) { + lastException = e; + } + } while (++retry <= MAX_TRANSACTION_RETRY); + } + + if (lastException != null) { + throw new RetryLimitReached( + "Failed to execute Neo4j write transaction after " + MAX_TRANSACTION_RETRY + " retries", lastException); + } + + stopWatch.stop(); + return Neo4jQueryResult.builder().tookMs(stopWatch.getTime()).retries(retry).build(); + } +} diff --git a/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryResult.java b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryResult.java new file mode 100644 index 000000000..ffe058ec1 --- /dev/null +++ b/dao-impl/neo4j-dao/src/main/java/com/linkedin/metadata/dao/internal/Neo4jQueryResult.java @@ -0,0 +1,12 @@ +package com.linkedin.metadata.dao.internal; + +import lombok.Builder; +import lombok.Data; + + +@Builder +@Data +public final class Neo4jQueryResult { + private final long tookMs; + private final int retries; +} diff --git a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/Neo4jQueryDAOTest.java b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/Neo4jQueryDAOTest.java index c2c8712b0..087c25c13 100644 --- a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/Neo4jQueryDAOTest.java +++ b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/Neo4jQueryDAOTest.java @@ -6,11 +6,11 @@ import com.linkedin.metadata.query.Filter; import com.linkedin.metadata.query.RelationshipDirection; import com.linkedin.metadata.query.RelationshipFilter; +import com.linkedin.testing.EntityBar; import com.linkedin.testing.EntityBaz; +import com.linkedin.testing.EntityFoo; import com.linkedin.testing.RelationshipBar; import com.linkedin.testing.RelationshipFoo; -import com.linkedin.testing.EntityFoo; -import com.linkedin.testing.EntityBar; import com.linkedin.testing.TestUtils; import com.linkedin.testing.urn.BarUrn; import com.linkedin.testing.urn.BazUrn; @@ -26,7 +26,6 @@ import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; import org.neo4j.driver.Record; -import org.neo4j.driver.SessionConfig; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -50,7 +49,7 @@ public void init() { final Driver driver = GraphDatabase.driver(_serverBuilder.boltURI()); _dao = new Neo4jQueryDAO(driver); - _writer = new Neo4jGraphWriterDAO(driver, SessionConfig.defaultConfig(), TestUtils.getAllTestEntities()); + _writer = new Neo4jGraphWriterDAO(driver, TestUtils.getAllTestEntities()); } @AfterMethod diff --git a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAOTest.java b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAOTest.java index 16ceaaf1f..a0eabaf58 100644 --- a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAOTest.java +++ b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jGraphWriterDAOTest.java @@ -4,7 +4,6 @@ import com.linkedin.metadata.dao.BaseQueryDAO; import com.linkedin.metadata.dao.Neo4jQueryDAO; import com.linkedin.metadata.dao.Neo4jTestServerBuilder; -import com.linkedin.metadata.dao.utils.Statement; import com.linkedin.metadata.query.Criterion; import com.linkedin.metadata.query.CriterionArray; import com.linkedin.metadata.query.Filter; @@ -14,17 +13,14 @@ import com.linkedin.testing.TestUtils; import com.linkedin.testing.urn.BarUrn; import com.linkedin.testing.urn.FooUrn; - import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import javax.annotation.Nonnull; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; -import org.neo4j.driver.SessionConfig; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -38,8 +34,8 @@ public class Neo4jGraphWriterDAOTest { private Neo4jTestServerBuilder _serverBuilder; - private Driver _driver; private Neo4jGraphWriterDAO _dao; + private Neo4jTestHelper _helper; private BaseQueryDAO _queryDao; private TestMetricListener _testMetricListener; @@ -83,9 +79,10 @@ public void init() { _serverBuilder = new Neo4jTestServerBuilder(); _serverBuilder.newServer(); _testMetricListener = new TestMetricListener(); - _driver = GraphDatabase.driver(_serverBuilder.boltURI()); - _dao = new Neo4jGraphWriterDAO(_driver, SessionConfig.defaultConfig(), TestUtils.getAllTestEntities()); - _queryDao = new Neo4jQueryDAO(_driver); + final Driver driver = GraphDatabase.driver(_serverBuilder.boltURI()); + _dao = new Neo4jGraphWriterDAO(driver, TestUtils.getAllTestEntities()); + _helper = new Neo4jTestHelper(driver, TestUtils.getAllTestEntities()); + _queryDao = new Neo4jQueryDAO(driver); _dao.addMetricListener(_testMetricListener); } @@ -100,13 +97,13 @@ public void testAddRemoveEntity() throws Exception { EntityFoo entity = new EntityFoo().setUrn(urn).setValue("foo"); _dao.addEntity(entity); - Optional> node = _dao.getNode(urn); + Optional> node = _helper.getNode(urn); assertEntityFoo(node.get(), entity); assertEquals(_testMetricListener.entitiesAdded, 1); assertEquals(_testMetricListener.entityAddedEvents, 1); _dao.removeEntity(urn); - node = _dao.getNode(urn); + node = _helper.getNode(urn); assertFalse(node.isPresent()); assertEquals(_testMetricListener.entitiesRemoved, 1); assertEquals(_testMetricListener.entityRemovedEvents, 1); @@ -118,21 +115,21 @@ public void testPartialUpdateEntity() throws Exception { EntityFoo entity = new EntityFoo().setUrn(urn); _dao.addEntity(entity); - Optional> node = _dao.getNode(urn); + Optional> node = _helper.getNode(urn); assertEntityFoo(node.get(), entity); // add value for optional field EntityFoo entity2 = new EntityFoo().setUrn(urn).setValue("IamTheSameEntity"); _dao.addEntity(entity2); - node = _dao.getNode(urn); - assertEquals(_dao.getAllNodes(urn).size(), 1); + node = _helper.getNode(urn); + assertEquals(_helper.getAllNodes(urn).size(), 1); assertEntityFoo(node.get(), entity2); // change value for optional field EntityFoo entity3 = new EntityFoo().setUrn(urn).setValue("ChangeValue"); _dao.addEntity(entity3); - node = _dao.getNode(urn); - assertEquals(_dao.getAllNodes(urn).size(), 1); + node = _helper.getNode(urn); + assertEquals(_helper.getAllNodes(urn).size(), 1); assertEntityFoo(node.get(), entity3); } @@ -144,16 +141,16 @@ public void testAddRemoveEntities() throws Exception { List entities = Arrays.asList(entity1, entity2, entity3); _dao.addEntities(entities); - assertEntityFoo(_dao.getNode(entity1.getUrn()).get(), entity1); - assertEntityFoo(_dao.getNode(entity2.getUrn()).get(), entity2); - assertEntityFoo(_dao.getNode(entity3.getUrn()).get(), entity3); + assertEntityFoo(_helper.getNode(entity1.getUrn()).get(), entity1); + assertEntityFoo(_helper.getNode(entity2.getUrn()).get(), entity2); + assertEntityFoo(_helper.getNode(entity3.getUrn()).get(), entity3); assertEquals(_testMetricListener.entitiesAdded, 3); assertEquals(_testMetricListener.entityAddedEvents, 1); _dao.removeEntities(Arrays.asList(entity1.getUrn(), entity3.getUrn())); - assertFalse(_dao.getNode(entity1.getUrn()).isPresent()); - assertTrue(_dao.getNode(entity2.getUrn()).isPresent()); - assertFalse(_dao.getNode(entity3.getUrn()).isPresent()); + assertFalse(_helper.getNode(entity1.getUrn()).isPresent()); + assertTrue(_helper.getNode(entity2.getUrn()).isPresent()); + assertFalse(_helper.getNode(entity3.getUrn()).isPresent()); assertEquals(_testMetricListener.entitiesRemoved, 2); assertEquals(_testMetricListener.entityRemovedEvents, 1); } @@ -166,9 +163,9 @@ public void testAddRelationshipNodeNonExist() throws Exception { _dao.addRelationship(relationship, REMOVE_NONE); - assertRelationshipFoo(_dao.getEdges(relationship), 1); - assertEntityFoo(_dao.getNode(urn1).get(), new EntityFoo().setUrn(urn1)); - assertEntityBar(_dao.getNode(urn2).get(), new EntityBar().setUrn(urn2)); + assertRelationshipFoo(_helper.getEdges(relationship), 1); + assertEntityFoo(_helper.getNode(urn1).get(), new EntityFoo().setUrn(urn1)); + assertEntityBar(_helper.getNode(urn2).get(), new EntityBar().setUrn(urn2)); assertEquals(_testMetricListener.relationshipsAdded, 1); assertEquals(_testMetricListener.relationshipAddedEvents, 1); } @@ -183,7 +180,7 @@ public void testPartialUpdateEntityCreatedByRelationship() throws Exception { // Check if adding an entity with same urn and with label creates a new node _dao.addEntity(new EntityFoo().setUrn(urn1)); - assertEquals(_dao.getAllNodes(urn1).size(), 1); + assertEquals(_helper.getAllNodes(urn1).size(), 1); } @Test @@ -192,36 +189,36 @@ public void testAddRemoveRelationships() throws Exception { FooUrn urn1 = makeFooUrn(1); EntityFoo entity1 = new EntityFoo().setUrn(urn1).setValue("foo"); _dao.addEntity(entity1); - assertEntityFoo(_dao.getNode(urn1).get(), entity1); + assertEntityFoo(_helper.getNode(urn1).get(), entity1); // Add entity2 BarUrn urn2 = makeBarUrn(2); EntityBar entity2 = new EntityBar().setUrn(urn2).setValue("bar"); _dao.addEntity(entity2); - assertEntityBar(_dao.getNode(urn2).get(), entity2); + assertEntityBar(_helper.getNode(urn2).get(), entity2); // add relationship1 (urn1 -> urn2) RelationshipFoo relationship1 = new RelationshipFoo().setSource(urn1).setDestination(urn2); _dao.addRelationship(relationship1, REMOVE_NONE); - assertRelationshipFoo(_dao.getEdges(relationship1), 1); + assertRelationshipFoo(_helper.getEdges(relationship1), 1); // add relationship1 again _dao.addRelationship(relationship1); - assertRelationshipFoo(_dao.getEdges(relationship1), 1); + assertRelationshipFoo(_helper.getEdges(relationship1), 1); // add relationship2 (urn1 -> urn3) Urn urn3 = makeUrn(3); RelationshipFoo relationship2 = new RelationshipFoo().setSource(urn1).setDestination(urn3); _dao.addRelationship(relationship2); - assertRelationshipFoo(_dao.getEdgesFromSource(urn1, RelationshipFoo.class), 2); + assertRelationshipFoo(_helper.getEdgesFromSource(urn1, RelationshipFoo.class), 2); // remove relationship1 _dao.removeRelationship(relationship1); - assertRelationshipFoo(_dao.getEdges(relationship1), 0); + assertRelationshipFoo(_helper.getEdges(relationship1), 0); // remove relationship1 & relationship2 _dao.removeRelationships(Arrays.asList(relationship1, relationship2)); - assertRelationshipFoo(_dao.getEdgesFromSource(urn1, RelationshipFoo.class), 0); + assertRelationshipFoo(_helper.getEdgesFromSource(urn1, RelationshipFoo.class), 0); assertEquals(_testMetricListener.relationshipsAdded, 3); @@ -237,40 +234,40 @@ public void testAddRelationshipRemoveAll() throws Exception { FooUrn urn1 = makeFooUrn(1); EntityFoo entity1 = new EntityFoo().setUrn(urn1).setValue("foo"); _dao.addEntity(entity1); - assertEntityFoo(_dao.getNode(urn1).get(), entity1); + assertEntityFoo(_helper.getNode(urn1).get(), entity1); // Add entity2 BarUrn urn2 = makeBarUrn(2); EntityBar entity2 = new EntityBar().setUrn(urn2).setValue("bar"); _dao.addEntity(entity2); - assertEntityBar(_dao.getNode(urn2).get(), entity2); + assertEntityBar(_helper.getNode(urn2).get(), entity2); // add relationship1 (urn1 -> urn2) RelationshipFoo relationship1 = new RelationshipFoo().setSource(urn1).setDestination(urn2); _dao.addRelationship(relationship1, REMOVE_NONE); - assertRelationshipFoo(_dao.getEdges(relationship1), 1); + assertRelationshipFoo(_helper.getEdges(relationship1), 1); // add relationship2 (urn1 -> urn3), removeAll from source Urn urn3 = makeUrn(3); RelationshipFoo relationship2 = new RelationshipFoo().setSource(urn1).setDestination(urn3); _dao.addRelationship(relationship2, REMOVE_ALL_EDGES_FROM_SOURCE); - assertRelationshipFoo(_dao.getEdgesFromSource(urn1, RelationshipFoo.class), 1); + assertRelationshipFoo(_helper.getEdgesFromSource(urn1, RelationshipFoo.class), 1); // add relationship3 (urn4 -> urn3), removeAll from destination Urn urn4 = makeUrn(4); RelationshipFoo relationship3 = new RelationshipFoo().setSource(urn4).setDestination(urn3); _dao.addRelationship(relationship3, REMOVE_ALL_EDGES_TO_DESTINATION); - assertRelationshipFoo(_dao.getEdgesFromSource(urn1, RelationshipFoo.class), 0); - assertRelationshipFoo(_dao.getEdgesFromSource(urn4, RelationshipFoo.class), 1); + assertRelationshipFoo(_helper.getEdgesFromSource(urn1, RelationshipFoo.class), 0); + assertRelationshipFoo(_helper.getEdgesFromSource(urn4, RelationshipFoo.class), 1); // add relationship3 again without removal _dao.addRelationship(relationship3); - assertRelationshipFoo(_dao.getEdgesFromSource(urn4, RelationshipFoo.class), 1); + assertRelationshipFoo(_helper.getEdgesFromSource(urn4, RelationshipFoo.class), 1); // add relationship3 again, removeAll from source & destination _dao.addRelationship(relationship3, REMOVE_ALL_EDGES_FROM_SOURCE_TO_DESTINATION); - assertRelationshipFoo(_dao.getEdgesFromSource(urn1, RelationshipFoo.class), 0); - assertRelationshipFoo(_dao.getEdgesFromSource(urn4, RelationshipFoo.class), 1); + assertRelationshipFoo(_helper.getEdgesFromSource(urn1, RelationshipFoo.class), 0); + assertRelationshipFoo(_helper.getEdgesFromSource(urn4, RelationshipFoo.class), 1); } @Test @@ -285,7 +282,7 @@ public void upsertNodeAddNewProperty() throws Exception { _dao.addEntity(updatedEntity); // then - assertEntityFoo(_dao.getNode(urn).get(), updatedEntity); + assertEntityFoo(_helper.getNode(urn).get(), updatedEntity); } @Test @@ -326,7 +323,7 @@ public void upsertNodeChangeProperty() throws Exception { _dao.addEntity(updatedEntity); // then - assertEntityFoo(_dao.getNode(urn).get(), updatedEntity); + assertEntityFoo(_helper.getNode(urn).get(), updatedEntity); } @Test @@ -368,7 +365,7 @@ public void upsertNodeRemovedProperty() throws Exception { // then // Upsert won't ever delete properties. - assertEntityFoo(_dao.getNode(urn).get(), initialEntity); + assertEntityFoo(_helper.getNode(urn).get(), initialEntity); } @Test @@ -398,31 +395,6 @@ RelationshipFoo.class, new Filter().setCriteria(new CriterionArray()), 0, 10), Collections.singletonList(initialRelationship)); } - @Test - public void testGetNodeTypeFromUrn() { - assertEquals(_dao.getNodeType(makeBarUrn(1)), ":`com.linkedin.testing.EntityBar`"); - assertEquals(_dao.getNodeType(makeFooUrn(1)), ":`com.linkedin.testing.EntityFoo`"); - assertEquals(_dao.getNodeType(makeUrn(1, "foo")), ":`com.linkedin.testing.EntityFoo`"); - assertEquals(_dao.getNodeType(makeUrn("1")), ":UNKNOWN"); - - // test consistency !! - assertEquals(_dao.getNodeType(makeBarUrn(1)), getTypeOrEmptyString(EntityBar.class)); - assertEquals(_dao.getNodeType(makeFooUrn(1)), getTypeOrEmptyString(EntityFoo.class)); - } - - @Test - public void testBuildStatement() { - final FooUrn urn = makeFooUrn(0); - final String queryTemplate = "dummy query template"; - final Map queryParams = new HashMap<>(); - queryParams.put("urn", urn); - - final Statement queryStatement = _dao.buildStatement(queryTemplate, queryParams); - - assertEquals(queryStatement.getCommandText(), queryTemplate); - assertEquals(queryStatement.getParams().get("urn"), urn.toString()); - } - private void assertEntityFoo(@Nonnull Map node, @Nonnull EntityFoo entity) { assertEquals(node.get("urn"), entity.getUrn().toString()); assertEquals(node.get("value"), entity.getValue()); diff --git a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformerTest.java b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformerTest.java new file mode 100644 index 000000000..d31268866 --- /dev/null +++ b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jQueriesTransformerTest.java @@ -0,0 +1,27 @@ +package com.linkedin.metadata.dao.internal; + +import com.linkedin.testing.EntityBar; +import com.linkedin.testing.EntityFoo; +import org.testng.annotations.Test; + +import static com.linkedin.metadata.dao.Neo4jUtil.*; +import static com.linkedin.testing.TestUtils.*; +import static org.testng.Assert.*; + + +public class Neo4jQueriesTransformerTest { + + @Test + public void testGetNodeTypeFromUrn() { + final Neo4jQueriesTransformer transformer = new Neo4jQueriesTransformer(getAllTestEntities()); + + assertEquals(transformer.getNodeType(makeBarUrn(1)), ":`com.linkedin.testing.EntityBar`"); + assertEquals(transformer.getNodeType(makeFooUrn(1)), ":`com.linkedin.testing.EntityFoo`"); + assertEquals(transformer.getNodeType(makeUrn(1, "foo")), ":`com.linkedin.testing.EntityFoo`"); + assertEquals(transformer.getNodeType(makeUrn("1")), ":UNKNOWN"); + + // test consistency !! + assertEquals(transformer.getNodeType(makeBarUrn(1)), getTypeOrEmptyString(EntityBar.class)); + assertEquals(transformer.getNodeType(makeFooUrn(1)), getTypeOrEmptyString(EntityFoo.class)); + } +} \ No newline at end of file diff --git a/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jTestHelper.java b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jTestHelper.java new file mode 100644 index 000000000..809b1b076 --- /dev/null +++ b/dao-impl/neo4j-dao/src/test/java/com/linkedin/metadata/dao/internal/Neo4jTestHelper.java @@ -0,0 +1,106 @@ +package com.linkedin.metadata.dao.internal; + +import com.linkedin.common.urn.Urn; +import com.linkedin.data.template.RecordTemplate; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Query; +import org.neo4j.driver.Session; + +import static com.linkedin.metadata.dao.Neo4jUtil.*; +import static com.linkedin.metadata.dao.utils.ModelUtils.*; + + +/** + * Helper for making queries in unit tests. + */ +public final class Neo4jTestHelper { + private final Driver _driver; + private final Neo4jQueriesTransformer _neo4jQueriesTransformer; + + private Neo4jTestHelper(@Nonnull Driver driver, @Nonnull Neo4jQueriesTransformer neo4jQueriesTransformer) { + _driver = driver; + _neo4jQueriesTransformer = neo4jQueriesTransformer; + } + + public Neo4jTestHelper(@Nonnull Driver driver) { + this(driver, new Neo4jQueriesTransformer()); + } + + public Neo4jTestHelper(@Nonnull Driver driver, @Nonnull Set> entitiesSet) { + this(driver, new Neo4jQueriesTransformer(entitiesSet)); + } + + private List> execute(@Nonnull Query query) { + try (Session session = _driver.session()) { + return session.run(query) + .list() + .stream() + .map(record -> record.values().get(0).asMap()) + .collect(Collectors.toList()); + } + } + + @Nonnull + public Optional> getNode(@Nonnull Urn urn) { + List> nodes = getAllNodes(urn); + if (nodes.isEmpty()) { + return Optional.empty(); + } + return Optional.of(nodes.get(0)); + } + + @Nonnull + public List> getAllNodes(@Nonnull Urn urn) { + final String matchTemplate = "MATCH (node%s {urn: $urn}) RETURN node"; + + final String sourceType = _neo4jQueriesTransformer.getNodeType(urn); + final String statement = String.format(matchTemplate, sourceType); + + final Map params = new HashMap<>(); + params.put("urn", urn.toString()); + + return execute(new Query(statement, params)); + } + + @Nonnull + public List> getEdges(@Nonnull RecordTemplate relationship) { + final Urn sourceUrn = getSourceUrnFromRelationship(relationship); + final Urn destinationUrn = getDestinationUrnFromRelationship(relationship); + final String relationshipType = getType(relationship); + + final String sourceType = _neo4jQueriesTransformer.getNodeType(sourceUrn); + final String destinationType = _neo4jQueriesTransformer.getNodeType(destinationUrn); + + final String matchTemplate = + "MATCH (source%s {urn: $sourceUrn})-[r:%s]->(destination%s {urn: $destinationUrn}) RETURN r"; + final String statement = String.format(matchTemplate, sourceType, relationshipType, destinationType); + + final Map params = new HashMap<>(); + params.put("sourceUrn", sourceUrn.toString()); + params.put("destinationUrn", destinationUrn.toString()); + + return execute(new Query(statement, params)); + } + + @Nonnull + public List> getEdgesFromSource(@Nonnull Urn sourceUrn, + @Nonnull Class relationshipClass) { + final String relationshipType = getType(relationshipClass); + final String sourceType = _neo4jQueriesTransformer.getNodeType(sourceUrn); + + final String matchTemplate = "MATCH (source%s {urn: $sourceUrn})-[r:%s]->() RETURN r"; + final String statement = String.format(matchTemplate, sourceType, relationshipType); + + final Map params = new HashMap<>(); + params.put("sourceUrn", sourceUrn.toString()); + + return execute(new Query(statement, params)); + } +}