Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the TPCDS schema based on the Spark codebase #201

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/main/scala/com/databricks/spark/sql/perf/Tables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.slf4j.LoggerFactory

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.ColumnName
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext, SaveMode}
Expand Down Expand Up @@ -95,7 +96,8 @@ trait DataGenerator extends Serializable {


abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false)
useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false,
useStringForCharVarchar: Boolean = true)
extends Serializable {

def dataGenerator: DataGenerator
Expand All @@ -105,11 +107,21 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,

def sparkContext = sqlContext.sparkContext

case class Table(name: String, partitionColumns: Seq[String], fields: StructField*) {
val schema = StructType(fields)
object Table {

def apply(name: String, partitionColumns: Seq[String], fields: StructField*): Table = {
Table(name, partitionColumns, StructType(fields))
}

def apply(name: String, partitionColumns: Seq[String], schemaString: String): Table = {
Table(name, partitionColumns, StructType.fromDDL(schemaString))
}
}

case class Table(name: String, partitionColumns: Seq[String], schema: StructType) {

def nonPartitioned: Table = {
Table(name, Nil, fields : _*)
Table(name, Nil, schema)
}

/**
Expand Down Expand Up @@ -144,7 +156,12 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,

val convertedData = {
val columns = schema.fields.map { f =>
col(f.name).cast(f.dataType).as(f.name)
val expr = f.dataType match {
// Needs right-side padding for char types
case CharType(n) => rpad(new ColumnName(f.name), n, " ")
case _ => new ColumnName(f.name).cast(f.dataType)
}
expr.as(f.name)
}
stringData.select(columns: _*)
}
Expand All @@ -156,16 +173,17 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
}

def convertTypes(): Table = {
val newFields = fields.map { field =>
val newFields = schema.fields.map { field =>
val newDataType = field.dataType match {
case decimal: DecimalType if useDoubleForDecimal => DoubleType
case date: DateType if useStringForDate => StringType
case _: CharType | _: VarcharType if useStringForCharVarchar => StringType
case other => other
}
field.copy(dataType = newDataType)
}

Table(name, partitionColumns, newFields:_*)
Table(name, partitionColumns, StructType(newFields))
}

def genData(
Expand Down Expand Up @@ -274,7 +292,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
log.info(s"Analyzing table $name.")
sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS")
if (analyzeColumns) {
val allColumns = fields.map(_.name).mkString(", ")
val allColumns = schema.fields.map(_.name).mkString(", ")
println(s"Analyzing table $name columns $allColumns.")
log.info(s"Analyzing table $name columns $allColumns.")
sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ case class GenTPCDSDataConfig(
format: String = null,
useDoubleForDecimal: Boolean = false,
useStringForDate: Boolean = false,
useStringForCharVarchar: Boolean = true,
overwrite: Boolean = false,
partitionTables: Boolean = true,
clusterByPartitionColumns: Boolean = true,
Expand Down Expand Up @@ -65,6 +66,9 @@ object GenTPCDSData {
opt[Boolean]('e', "useStringForDate")
.action((x, c) => c.copy(useStringForDate = x))
.text("true to replace DateType with StringType")
opt[Boolean]('r', "useStringForCharVarchar")
.action((x, c) => c.copy(useStringForCharVarchar = x))
.text("true to replace CharType/VarcharType with StringType")
opt[Boolean]('o', "overwrite")
.action((x, c) => c.copy(overwrite = x))
.text("overwrite the data that is already there")
Expand Down Expand Up @@ -106,7 +110,8 @@ object GenTPCDSData {
dsdgenDir = config.dsdgenDir,
scaleFactor = config.scaleFactor,
useDoubleForDecimal = config.useDoubleForDecimal,
useStringForDate = config.useStringForDate)
useStringForDate = config.useStringForDate,
useStringForCharVarchar = config.useStringForCharVarchar)

tables.genData(
location = config.location,
Expand Down
Loading