Add a parameter that controls the number of StreamLoad tasks committed per partition #92 #99

Open
wants to merge 10 commits into base: master
spark-doris-connector/pom.xml: 446 changes (271 additions, 175 deletions); large diff not rendered.

@@ -90,4 +90,16 @@ public interface ConfigurationOptions {

int DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT = 50;




/*
When enabled, only one StreamLoad task is submitted per partition, so that a failed
task's retry does not re-submit StreamLoad requests for batches of data from the
same partition that were already loaded.
*/
String DORIS_SINK_PER_PARTITION_TASK_ATOMICITY = "doris.sink.per.partition.task.atomicity";

boolean DORIS_SINK_PER_PARTITION_TASK_ATOMICITY_DEFAULT = false;

}
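
For context, a minimal Scala sketch of how a batch write could enable the new switch. Only doris.sink.per.partition.task.atomicity comes from this diff; the format short name "doris" matches DorisSourceProvider's SHORT_NAME, while the remaining option keys and values (FE address, table identifier, credentials) are illustrative assumptions, not taken from this patch.

// Illustrative sketch; every option key except the atomicity switch is an assumption.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("doris-sink-example").getOrCreate()
val df = spark.range(10).selectExpr("id", "cast(id as string) as name")   // placeholder data

df.write
  .format("doris")                                                // SHORT_NAME registered by DorisSourceProvider
  .option("doris.fenodes", "fe_host:8030")                        // assumed FE address key
  .option("doris.table.identifier", "example_db.example_table")   // assumed table identifier key
  .option("user", "root")                                         // assumed credential keys
  .option("password", "")
  .option("doris.sink.per.partition.task.atomicity", "true")      // new switch added by this PR
  .save()

With the switch on, DorisWriter (added below) flushes each partition in a single StreamLoad call instead of splitting it into batch-sized chunks, so a retried Spark task re-submits the partition as one load rather than replaying individual batches.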
@@ -15,17 +15,12 @@
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark;
package org.apache.doris.spark.load;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.cache.RemovalNotification;
import org.apache.doris.spark.cfg.SparkSettings;
import org.apache.doris.spark.exception.DorisException;

import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

@@ -14,22 +14,23 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.spark;
package org.apache.doris.spark.load;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.doris.spark.cfg.ConfigurationOptions;
import org.apache.doris.spark.cfg.SparkSettings;
import org.apache.doris.spark.exception.StreamLoadException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.BackendV2;
import org.apache.doris.spark.rest.models.RespContent;
import org.apache.doris.spark.util.ListUtils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
@@ -45,10 +46,17 @@
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.sql.Timestamp;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@@ -73,13 +81,11 @@ public class DorisStreamLoad implements Serializable {
private String tbl;
private String authEncoded;
private String columns;
private String[] dfColumns;
private String maxFilterRatio;
private Map<String, String> streamLoadProp;
private static final long cacheExpireTimeout = 4 * 60;
private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
private final String fileType;
private final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSSSSS");

public DorisStreamLoad(SparkSettings settings) {
String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
@@ -101,11 +107,6 @@ public DorisStreamLoad(SparkSettings settings) {
}
}

public DorisStreamLoad(SparkSettings settings, String[] dfColumns) {
this(settings);
this.dfColumns = dfColumns;
}

public String getLoadUrlStr() {
if (StringUtils.isEmpty(loadUrlStr)) {
return "";
@@ -168,7 +169,7 @@ public String listToString(List<List<Object>> rows) {
}


public void loadV2(List<List<Object>> rows) throws StreamLoadException, JsonProcessingException {
public void loadV2(List<List<Object>> rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException {
if (fileType.equals("csv")) {
load(listToString(rows));
} else if(fileType.equals("json")) {
@@ -17,9 +17,9 @@

package org.apache.doris.spark.sql

import org.apache.doris.spark.DorisStreamLoad
import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.cfg.SparkSettings
import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME
import org.apache.doris.spark.writer.DorisWriter
import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
@@ -28,12 +28,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.slf4j.{Logger, LoggerFactory}

import java.io.IOException
import java.time.Duration
import java.util
import java.util.Objects
import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.util.{Failure, Success}

private[sql] class DorisSourceProvider extends DataSourceRegister
with RelationProvider
@@ -60,58 +55,9 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
val sparkSettings = new SparkSettings(sqlContext.sparkContext.getConf)
sparkSettings.merge(Utils.params(parameters, logger).asJava)
// init stream loader
val dorisStreamLoader = new DorisStreamLoad(sparkSettings, data.columns)
val writer = new DorisWriter(sparkSettings)
writer.write(data)

val maxRowCount = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
val maxRetryTimes = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES, ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
val sinkTaskPartitionSize = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
val sinkTaskUseRepartition = sparkSettings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION, ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
val batchInterValMs = sparkSettings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS, ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)

logger.info(s"maxRowCount ${maxRowCount}")
logger.info(s"maxRetryTimes ${maxRetryTimes}")
logger.info(s"batchInterVarMs ${batchInterValMs}")

var resultRdd = data.rdd
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}

resultRdd.foreachPartition(partition => {
val rowsBuffer: util.List[util.List[Object]] = new util.ArrayList[util.List[Object]](maxRowCount)
partition.foreach(row => {
val line: util.List[Object] = new util.ArrayList[Object]()
for (i <- 0 until row.size) {
val field = row.get(i)
line.add(field.asInstanceOf[AnyRef])
}
rowsBuffer.add(line)
if (rowsBuffer.size > maxRowCount - 1 ) {
flush()
}
})
// flush buffer
if (!rowsBuffer.isEmpty) {
flush()
}

/**
* flush data to Doris and do retry when flush error
*
*/
def flush(): Unit = {
Utils.retry[Unit, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
dorisStreamLoader.loadV2(rowsBuffer)
rowsBuffer.clear()
} match {
case Success(_) =>
case Failure(e) =>
throw new IOException(
s"Failed to load $maxRowCount batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
}
}

})
new BaseRelation {
override def sqlContext: SQLContext = unsupportedException

@@ -17,69 +17,27 @@

package org.apache.doris.spark.sql

import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.spark.rdd.RDD
import org.apache.doris.spark.cfg.SparkSettings
import org.apache.doris.spark.writer.DorisWriter
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.slf4j.{Logger, LoggerFactory}

import java.io.IOException
import java.time.Duration
import java.util
import java.util.Objects
import scala.collection.JavaConverters._
import scala.util.{Failure, Success}

private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSettings) extends Sink with Serializable {

private val logger: Logger = LoggerFactory.getLogger(classOf[DorisStreamLoadSink].getName)
@volatile private var latestBatchId = -1L
val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES, ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
val sinkTaskPartitionSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
val sinkTaskUseRepartition = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION, ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
val batchInterValMs = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS, ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)

val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
private val writer = new DorisWriter(settings)

override def addBatch(batchId: Long, data: DataFrame): Unit = {
if (batchId <= latestBatchId) {
logger.info(s"Skipping already committed batch $batchId")
} else {
write(data.rdd)
writer.write(data)
latestBatchId = batchId
}
}

def write(rdd: RDD[Row]): Unit = {
var resultRdd = rdd
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
resultRdd
.map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
.foreachPartition(partition => {
partition
.grouped(batchSize)
.foreach(batch => flush(batch))
})

/**
* flush data to Doris and do retry when flush error
*
*/
def flush(batch: Iterable[util.List[Object]]): Unit = {
Utils.retry[Unit, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
dorisStreamLoader.loadV2(batch.toList.asJava)
} match {
case Success(_) =>
case Failure(e) =>
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max $maxRetryTimes retry times.", e)
}
}
}

override def toString: String = "DorisStreamLoadSink"
}
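
For completeness, a hedged Scala sketch of how DorisStreamLoadSink is typically driven from Structured Streaming, assuming DorisSourceProvider also registers the "doris" format as a stream sink (its full extends list is collapsed in this diff). The checkpoint path, placeholder source, and table identifier key are illustrative.

// Illustrative sketch; streaming support and the non-atomicity option keys are assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("doris-streaming-sink-example").getOrCreate()
val streamingDf = spark.readStream.format("rate").load()              // placeholder streaming source

val query = streamingDf.writeStream
  .format("doris")
  .option("checkpointLocation", "/tmp/doris-sink-checkpoint")          // required by Structured Streaming
  .option("doris.table.identifier", "example_db.example_table")        // assumed table identifier key
  .option("doris.sink.per.partition.task.atomicity", "true")           // new switch, applied inside DorisWriter
  .start()
query.awaitTermination()

Each micro-batch then flows through addBatch, which skips already-committed batch ids and delegates the actual write to DorisWriter.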
@@ -31,7 +31,7 @@ import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

private[sql] object Utils {
private[spark] object Utils {
/**
* quote column name
* @param colName column name
@@ -169,7 +169,9 @@ private[sql] object Utils {
assert(retryTimes >= 0)
val result = Try(f)
result match {
case Success(result) => Success(result)
case Success(result) =>
LockSupport.parkNanos(interval.toNanos)
Success(result)
case Failure(exception: T) if retryTimes > 0 =>
logger.warn(s"Execution failed caused by: ", exception)
logger.warn(s"$retryTimes times retry remaining, the next will be in ${interval.toMillis}ms")
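
To make the two-line change above easier to read in isolation, here is a Scala sketch of the surrounding retry helper, reconstructed from this hunk and its call sites (Utils.retry[Unit, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { ... }). The exact signature and the final failure case are assumptions; only the success and retry branches come from the diff.

import java.time.Duration
import java.util.concurrent.locks.LockSupport
import org.slf4j.Logger
import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

object RetrySketch {
  // Retry f up to retryTimes times, pausing `interval` between attempts; after the change
  // above, the same pause also runs after a success, throttling consecutive flushes.
  @tailrec
  def retry[R, T <: Throwable: ClassTag](retryTimes: Int, interval: Duration, logger: Logger)(f: => R): Try[R] = {
    assert(retryTimes >= 0)
    Try(f) match {
      case Success(result) =>
        LockSupport.parkNanos(interval.toNanos)
        Success(result)
      case Failure(exception: T) if retryTimes > 0 =>
        logger.warn("Execution failed caused by: ", exception)
        logger.warn(s"$retryTimes times retry remaining, the next will be in ${interval.toMillis}ms")
        LockSupport.parkNanos(interval.toNanos)
        retry(retryTimes - 1, interval, logger)(f)
      case Failure(exception) =>
        Failure(exception)
    }
  }
}

The notable effect of the change is that the interval pause now also runs after a successful attempt, so consecutive flushes from the same task are spaced by the configured batch interval (DORIS_SINK_BATCH_INTERVAL_MS) rather than issued back-to-back.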
@@ -0,0 +1,90 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.writer

import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
import org.apache.spark.sql.DataFrame
import org.slf4j.{Logger, LoggerFactory}

import java.io.IOException
import java.time.Duration
import java.util
import java.util.Objects
import scala.collection.JavaConverters._
import scala.util.{Failure, Success}

class DorisWriter(settings: SparkSettings) extends Serializable {

private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])

val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean
private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
val partitionTaskAtomicity = settings.getProperty(ConfigurationOptions.DORIS_SINK_PER_PARTITION_TASK_ATOMICITY,
ConfigurationOptions.DORIS_SINK_PER_PARTITION_TASK_ATOMICITY_DEFAULT.toString).toBoolean
private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)

def write(dataFrame: DataFrame): Unit = {
var resultRdd = dataFrame.rdd
val dfColumns = dataFrame.columns
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
if (partitionTaskAtomicity) {
resultRdd
.map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
.foreachPartition(partition => {
flush(partition.toIterable, dfColumns)
})
} else {
resultRdd
.map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
.foreachPartition(partition => {
partition
.grouped(batchSize)
.foreach(batch => flush(batch, dfColumns))
})
}

/**
* Flush data to Doris, retrying on failure up to the configured retry limit.
*/
def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
Utils.retry[Unit, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns)
} match {
case Success(_) =>
case Failure(e) =>
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
}
}

}


}
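
Finally, a hedged Scala sketch of how the new writer is wired up outside the DataFrame API, mirroring the DorisSourceProvider and DorisStreamLoadSink hunks above. The merged parameter values are placeholders, and the assumption that SparkSettings.merge accepts a java.util.Map follows its usage in DorisSourceProvider.

import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.writer.DorisWriter
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.collection.JavaConverters.mapAsJavaMapConverter

// Illustrative sketch; parameter values are placeholders.
def writeToDoris(sqlContext: SQLContext, data: DataFrame): Unit = {
  val settings = new SparkSettings(sqlContext.sparkContext.getConf)
  settings.merge(Map(
    "doris.table.identifier" -> "example_db.example_table",                 // placeholder identifier
    ConfigurationOptions.DORIS_SINK_PER_PARTITION_TASK_ATOMICITY -> "true"  // new key from this diff
  ).asJava)
  new DorisWriter(settings).write(data)
}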