diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/constraints/DeltaInvariantCheckerExec.scala b/spark/src/main/scala/org/apache/spark/sql/delta/constraints/DeltaInvariantCheckerExec.scala index b3c9985692..0a1e063eff 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/constraints/DeltaInvariantCheckerExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/constraints/DeltaInvariantCheckerExec.scala @@ -27,10 +27,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.delta.util.AnalysisHelper import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy, UnaryExecNode} import org.apache.spark.sql.types.StructType @@ -71,9 +73,17 @@ case class DeltaInvariantCheckerExec( if (constraints.isEmpty) return child.execute() val invariantChecks = DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, session) - val boundRefs = invariantChecks.map(_.withBoundReferences(child.output)) + + // Resolve current_date()/current_timestamp() expressions. + // We resolve the current timestamp for all invariants together to make sure they all use the same timestamp.
+ val invariantsFakePlan = AnalysisHelper.FakeLogicalPlan(invariantChecks, Nil) + val newInvariantsPlan = optimizer.ComputeCurrentTime(invariantsFakePlan) + val localOutput = child.output child.execute().mapPartitionsInternal { rows => + val boundRefs = newInvariantsPlan.expressions + .asInstanceOf[Seq[CheckDeltaInvariant]] + .map(_.withBoundReferences(localOutput)) val assertions = UnsafeProjection.create(boundRefs) rows.map { row => assertions(row) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/CheckConstraintsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/CheckConstraintsSuite.scala index ef67d70138..8799ed020c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/CheckConstraintsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/CheckConstraintsSuite.scala @@ -277,18 +277,20 @@ class CheckConstraintsSuite extends QueryTest } } - testQuietly("constraint with analyzer-evaluated expressions") { + for (expression <- Seq("year(current_date())", "unix_timestamp()")) + testQuietly(s"constraint with analyzer-evaluated expressions. Expression: $expression") { withTestTable { table => - // We use current_timestamp() as the most convenient analyzer-evaluated expression - of course - // in a realistic use case it'd probably not be right to add a constraint on a + // We use current_timestamp()/current_date() as the most convenient + // analyzer-evaluated expressions - of course in a realistic use case + // it'd probably not be right to add a constraint on a // nondeterministic expression. sql(s"ALTER TABLE $table ADD CONSTRAINT maxWithAnalyzerEval " + - s"CHECK (num < unix_timestamp())") + s"CHECK (num < $expression)") val e = intercept[InvariantViolationException] { sql(s"INSERT INTO $table VALUES (${Int.MaxValue}, 'data')") } errorContains(e.getMessage, - "maxwithanalyzereval (num < unix_timestamp()) violated by row") + s"maxwithanalyzereval (num < $expression) violated by row") } }