RD-14980: CRUD support #19

Draft · wants to merge 4 commits into main

Changes from 1 commit
RD-14980: Support for INSERT, UPDATE, DELETE
bgaidioz committed Dec 30, 2024
commit a11a9268e8be1306e2800a78e4fba87bc9387d37
12 changes: 10 additions & 2 deletions src/main/scala/com/rawlabs/das/databricks/DASDatabricks.scala
@@ -14,7 +14,7 @@ package com.rawlabs.das.databricks

import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.DatabricksConfig
import com.databricks.sdk.service.catalog.ListTablesRequest
import com.databricks.sdk.service.catalog.{GetTableRequest, ListTablesRequest}
import com.databricks.sdk.service.sql.ListWarehousesRequest
import com.rawlabs.das.sdk.{DASFunction, DASSdk, DASTable}
import com.rawlabs.protocol.das.{FunctionDefinition, TableDefinition}
@@ -41,7 +41,15 @@ class DASDatabricks(options: Map[String, String]) extends DASSdk {
val databricksTables = databricksClient.tables().list(req)
val tables = mutable.Map.empty[String, DASDatabricksTable]
databricksTables.forEach { databricksTable =>
tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, databricksTable))
// `databricksTable` is a `TableInfo`, whose `getTableConstraints` would tell us whether the
// table has a primary key column (needed for UPDATE and DELETE calls). But the list API
// doesn't populate constraints, so we issue an individual `GetTableRequest` per table: it
// returns the same object, with constraints provided.
val tableDetails = {
val tableReq = new GetTableRequest().setFullName(catalog + '.' + schema + '.' + databricksTable.getName)
databricksClient.tables().get(tableReq)
}
tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, tableDetails))
}
tables.toMap
}
150 changes: 136 additions & 14 deletions src/main/scala/com/rawlabs/das/databricks/DASDatabricksTable.scala
@@ -15,18 +15,21 @@ package com.rawlabs.das.databricks
import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.service.catalog.{ColumnInfo, ColumnTypeName, TableInfo}
import com.databricks.sdk.service.sql._
import com.rawlabs.das.sdk.{DASExecuteResult, DASTable}
import com.rawlabs.das.sdk.{DASExecuteResult, DASSdkException, DASTable}
import com.rawlabs.protocol.das._
import com.rawlabs.protocol.raw.{Type, Value}
import com.typesafe.scalalogging.StrictLogging

import scala.annotation.tailrec
import scala.collection.JavaConverters.collectionAsScalaIterableConverter
import scala.collection.mutable

class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databricksTable: TableInfo)
extends DASTable
with StrictLogging {

private val tableFullName = databricksTable.getSchemaName + '.' + databricksTable.getName

override def getRelSize(quals: Seq[Qual], columns: Seq[String]): (Int, Int) = REL_SIZE

override def execute(
@@ -36,8 +39,7 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
maybeLimit: Option[Long]
): DASExecuteResult = {
val databricksColumns = if (columns.isEmpty) Seq("NULL") else columns.map(databricksColumnName)
var query =
s"SELECT ${databricksColumns.mkString(",")} FROM " + databricksTable.getSchemaName + '.' + databricksTable.getName
var query = s"SELECT ${databricksColumns.mkString(",")} FROM " + tableFullName
val stmt = new ExecuteStatementRequest()
val parameters = new java.util.LinkedList[StatementParameterListItem]
if (quals.nonEmpty) {
@@ -93,9 +95,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick

stmt.setStatement(query).setWarehouseId(warehouseID).setDisposition(Disposition.INLINE).setFormat(Format.JSON_ARRAY)
val executeAPI = client.statementExecution()
val response1 = executeAPI.executeStatement(stmt)
val response = getResult(response1)
new DASDatabricksExecuteResult(executeAPI, response)
val response = executeAPI.executeStatement(stmt)
getResult(response) match {
case Left(error) => throw new DASSdkException(error)
case Right(result) => new DASDatabricksExecuteResult(executeAPI, result)
}
}

private def databricksColumnName(name: String): String = {
@@ -119,21 +123,19 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
override def canSort(sortKeys: Seq[SortKey]): Seq[SortKey] = sortKeys

@tailrec
private def getResult(response: StatementResponse): StatementResponse = {
private def getResult(response: StatementResponse): Either[String, StatementResponse] = {
val state = response.getStatus.getState
logger.info(s"Query ${response.getStatementId} state: $state")
state match {
case StatementState.PENDING | StatementState.RUNNING =>
logger.info(s"Query is still running, polling again in $POLLING_TIME ms")
Thread.sleep(POLLING_TIME)
val response2 = client.statementExecution().getStatement(response.getStatementId)
getResult(response2)
case StatementState.SUCCEEDED => response
case StatementState.FAILED =>
throw new RuntimeException(s"Query failed: ${response.getStatus.getError.getMessage}")
case StatementState.CLOSED =>
throw new RuntimeException(s"Query closed: ${response.getStatus.getError.getMessage}")
case StatementState.CANCELED =>
throw new RuntimeException(s"Query canceled: ${response.getStatus.getError.getMessage}")
case StatementState.SUCCEEDED => Right(response)
case StatementState.FAILED => Left(s"Query failed: ${response.getStatus.getError.getMessage}")
case StatementState.CLOSED => Left(s"Query closed: ${response.getStatus.getError.getMessage}")
case StatementState.CANCELED => Left(s"Query canceled: ${response.getStatus.getError.getMessage}")
}
}

@@ -161,6 +163,26 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
definition.build()
}

// Primary key column name, if one is found in the table's constraint metadata.
private var primaryKeyColumn: Option[String] = None

// Try to find a primary key constraint over one column.
if (databricksTable.getTableConstraints == null) {
logger.warn(s"No constraints found for table $tableFullName")
} else {
databricksTable.getTableConstraints.forEach { constraint =>
val primaryKeyConstraint = constraint.getPrimaryKeyConstraint
if (primaryKeyConstraint != null) {
if (primaryKeyConstraint.getChildColumns.size != 1) {
logger.warn("Ignoring composite primary key")
} else {
val column = primaryKeyConstraint.getChildColumns.iterator().next()
primaryKeyColumn = Some(column)
logger.info(s"Found primary key ($column)")
}
}
}
}

private def columnType(info: ColumnInfo): Option[Type] = {
val builder = Type.newBuilder()
val columnType = info.getTypeName
@@ -230,6 +252,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
}
}

override def uniqueColumn: String = {
// Fall back to the first column if the table has no primary key.
primaryKeyColumn.getOrElse(databricksTable.getColumns.asScala.head.getName)
}

private def rawValueToParameter(v: Value): StatementParameterListItem = {
logger.debug(s"Converting value to parameter: $v")
val parameter = new StatementParameterListItem()
@@ -286,4 +313,99 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
}
}

override def insert(row: Row): Row = {
bulkInsert(Seq(row)).head
}

// INSERTs can be batched, but only by inlining values in the query string. To avoid
// accidentally sending gigantic query strings, we keep each query around this size.
private val MAX_INSERT_CODE_SIZE = 2048

override def bulkInsert(rows: Seq[Row]): Seq[Row] = {
// There's no bulk insert call in Databricks, so we inline values. We build batches of
// query strings of roughly MAX_INSERT_CODE_SIZE each and loop until all rows are consumed.
val columnNames = databricksTable.getColumns.asScala.map(_.getName)
val values = rows.map { row =>
val data = row.getDataMap
columnNames
.map { name =>
val value = data.get(name)
if (value == null) {
"DEFAULT"
} else {
rawValueToDatabricksQueryString(value)
}
}
.mkString("(", ",", ")")
}
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)

val items = values.iterator
while (items.hasNext) {
val code = new StringBuilder
code.append(s"INSERT INTO $tableFullName VALUES ${items.next()}")
while (code.length < MAX_INSERT_CODE_SIZE && items.hasNext) {
code.append(s",${items.next()}")
}
stmt.setStatement(code.toString())
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new DASSdkException(error))
}
rows
}
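The loop above packs inlined tuples greedily: each statement starts with one tuple (so progress is always made) and keeps appending until it reaches MAX_INSERT_CODE_SIZE, possibly overshooting by one tuple. The same packing logic in isolation, as a standalone sketch (the `packBatches` name and signature are illustrative, not part of this PR):

// Illustrative sketch: greedy size-bounded batching, mirroring the bulkInsert loop above.
private def packBatches(values: Iterator[String], prefix: String, maxSize: Int): Seq[String] = {
  val batches = Seq.newBuilder[String]
  while (values.hasNext) {
    // Start each statement with one tuple, then grow it until it reaches maxSize.
    val code = new StringBuilder(prefix).append(values.next())
    while (code.length < maxSize && values.hasNext) {
      code.append(',').append(values.next())
    }
    batches += code.toString
  }
  batches.result()
}
// e.g. packBatches(values.iterator, s"INSERT INTO $tableFullName VALUES ", MAX_INSERT_CODE_SIZE)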

override def delete(rowId: Value): Unit = {
if (primaryKeyColumn.isEmpty) {
throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
}
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)
stmt.setStatement(
s"DELETE FROM ${databricksTable.getName} WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
)
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new DASSdkException(error))
}
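Note: `rawValueToDatabricksQueryString`, used by insert, delete and update, is not shown in this commit's hunks. A minimal sketch of what such a value-to-SQL-literal helper could look like, assuming scalar `Value` variants and standard single-quote escaping (the accessor names below are assumptions, not the PR's actual code):

// Hypothetical sketch only: render a Value as an inline SQL literal.
// Accessor names (hasNull, getInt.getV, ...) are assumed, not confirmed by this diff.
private def rawValueToDatabricksQueryString(v: Value): String = {
  if (v.hasNull) "NULL"
  else if (v.hasBool) v.getBool.getV.toString
  else if (v.hasInt) v.getInt.getV.toString
  else if (v.hasDouble) v.getDouble.getV.toString
  else if (v.hasString) "'" + v.getString.getV.replace("'", "''") + "'"
  else throw new IllegalArgumentException(s"Unsupported value: $v")
}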

// How many rows are accepted in a batch update. Technically unlimited, since
// updates are sent one by one.
private val MODIFY_BATCH_SIZE = 1000

override def modifyBatchSize: Int = {
MODIFY_BATCH_SIZE
}

override def update(rowId: Value, newValues: Row): Row = {
if (primaryKeyColumn.isEmpty) {
throw new IllegalArgumentException(s"Table $tableFullName has no primary key column")
}
val buffer = mutable.Buffer.empty[String]
newValues.getDataMap
.forEach {
case (name, value) =>
buffer.append(s"${databricksColumnName(name)} = ${rawValueToDatabricksQueryString(value)}")
}
val setValues = buffer.mkString(", ")
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)
stmt.setStatement(
s"UPDATE ${databricksTable.getName} SET $setValues WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
)
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new DASSdkException(error))
newValues
}
}
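A minimal end-to-end usage sketch of the new write path (hypothetical wiring: `client`, `warehouseId`, `tableInfo` and the protobuf builder calls are assumptions for illustration, not code from this PR):

// Hypothetical usage sketch; builder method names are assumed.
val table = new DASDatabricksTable(client, warehouseId, tableInfo)

val id = Value.newBuilder().setInt(ValueInt.newBuilder().setV(1)).build()
val row = Row.newBuilder()
  .putData("id", id)
  .putData("name", Value.newBuilder().setString(ValueString.newBuilder().setV("alice")).build())
  .build()

table.insert(row)     // inlined INSERT, batched by size in bulkInsert
table.update(id, row) // requires a single-column primary key
table.delete(id)      // ditto; throws IllegalArgumentException without one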
3 changes: 0 additions & 3 deletions src/main/scala/com/rawlabs/das/databricks/id.kt

This file was deleted.