This article gives a brief walkthrough of some of the key source code behind how Spark reads data from a database.
How Spark Reads Data from a Database
Like other data-mapping frameworks (Hibernate, MyBatis, etc.), Spark cannot avoid JDBC connections when it reads from a database; after all, that is the official channel through which code "talks" to the database. For Spark to read database data quickly, the problems it has to solve include, but are not limited to:
Reading in a distributed fashion
Mapping the raw data into an RDD/DataFrame
So this short article walks through the source code around these two aspects.
For the Spark APIs that operate on databases, see the article "Spark JDBC系列--取数的四种方式" (Spark JDBC series: the four ways to fetch data).
Source Code Walkthrough
1. Common entry point of the JDBC APIs
Entry-point source:
org.apache.spark.sql.DataFrameReader

  private def jdbc(
      url: String,
      table: String,
      parts: Array[Partition],
      connectionProperties: Properties): DataFrame = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) =>
      props.put(key, value)
    }
    // connectionProperties should override settings in extraOptions
    props.putAll(connectionProperties)
    // Key point: build the relation
    val relation = JDBCRelation(url, table, parts, props)(sparkSession)
    // Logical partitions are created here; the actual read is triggered by an action
    sparkSession.baseRelationToDataFrame(relation)
  }
Looking at the source, we can see that although the four data-fetching APIs take slightly different parameters, they are all eventually converted into an Array[Partition], i.e. an array of partitioning conditions.
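For reference, a simplified sketch of how two of the public overloads converge on that Array[Partition] is shown below (based on the Spark 2.x DataFrameReader; signatures and details may differ slightly between versions, and the type names are the internal ones used elsewhere in this article):

  // Overload with explicit predicates: each predicate string becomes one JDBCPartition.
  def jdbc(
      url: String,
      table: String,
      predicates: Array[String],
      connectionProperties: Properties): DataFrame = {
    val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
      JDBCPartition(part, i): Partition
    }
    jdbc(url, table, parts, connectionProperties)   // the private entry point shown above
  }

  // Overload with a numeric column: the bounds are packed into JDBCPartitioningInfo,
  // and columnPartition (analyzed in the next section) produces the Array[Partition].
  def jdbc(
      url: String,
      table: String,
      columnName: String,
      lowerBound: Long,
      upperBound: Long,
      numPartitions: Int,
      connectionProperties: Properties): DataFrame = {
    val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
    val parts = JDBCRelation.columnPartition(partitioning)
    jdbc(url, table, parts, connectionProperties)
  }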
2. How the column-based API builds its partitions
This section looks at the partitioning logic of the API that takes a Long-typed column as the partitioning column. First, the source:
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
  if (partitioning == null || partitioning.numPartitions <= 1 ||
      partitioning.lowerBound == partitioning.upperBound) {
    // Single-partition mode falls into this branch
    return Array[Partition](JDBCPartition(null, 0))
  }

  // Sanity checks
  val lowerBound = partitioning.lowerBound
  val upperBound = partitioning.upperBound
  ....
  // Adjust the number of partitions
  val numPartitions =
    if ((upperBound - lowerBound) >= partitioning.numPartitions) {
      partitioning.numPartitions
    } else {
      upperBound - lowerBound
    }
  // Compute the stride
  val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
  val column = partitioning.column
  var i: Int = 0
  var currentValue: Long = lowerBound
  var ans = new ArrayBuffer[Partition]()
  // Walk from the given lower bound towards the upper bound in steps of `stride`,
  // and assemble a WHERE condition from each pair of boundaries
  while (i < numPartitions) {
    // Note: the first and last partitions only get a one-sided bound,
    // e.g. JDBCPartition(id >= 901, 9)
    val lBound = if (i != 0) s"$column >= $currentValue" else null
    currentValue += stride
    val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
    val whereClause =
      if (uBound == null) {
        lBound
      } else if (lBound == null) {
        s"$uBound or $column is null"
      } else {
        s"$lBound AND $uBound"
      }
    ans += JDBCPartition(whereClause, i)
    i = i + 1
  }
  ans.toArray
}
The test input and the resulting partitions are as follows:
Input: lowerBound = 1, upperBound = 1000, numPartitions = 10
Resulting partition array:
JDBCPartition(id < 101 or id is null, 0)
JDBCPartition(id >= 101 AND id < 201, 1)
JDBCPartition(id >= 201 AND id < 301, 2)
JDBCPartition(id >= 301 AND id < 401, 3)
JDBCPartition(id >= 401 AND id < 501, 4)
JDBCPartition(id >= 501 AND id < 601, 5)
JDBCPartition(id >= 601 AND id < 701, 6)
JDBCPartition(id >= 701 AND id < 801, 7)
JDBCPartition(id >= 801 AND id < 901, 8)
JDBCPartition(id >= 901, 9)
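The boundaries above can be reproduced without a database by re-implementing the stride logic in a few lines; the following standalone sketch (an illustration of the same calculation, not the Spark source) prints exactly the partition list shown above:

  // Standalone sketch that reproduces the partition WHERE clauses above.
  object PartitionBoundsDemo {
    def whereClauses(column: String, lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[String] = {
      val stride = upperBound / numPartitions - lowerBound / numPartitions
      var currentValue = lowerBound
      (0 until numPartitions).map { i =>
        val lBound = if (i != 0) s"$column >= $currentValue" else null
        currentValue += stride
        val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
        if (uBound == null) lBound
        else if (lBound == null) s"$uBound or $column is null"
        else s"$lBound AND $uBound"
      }
    }

    def main(args: Array[String]): Unit = {
      whereClauses("id", 1L, 1000L, 10).zipWithIndex.foreach { case (w, i) =>
        println(s"JDBCPartition($w, $i)")
      }
    }
  }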
This API is easy to misuse: if you pass the minimum and maximum of a sub-range of IDs (rather than the true minimum and maximum of the whole table), the query will still pull the entire table, and the data will be skewed. The reason lies in how the WHERE conditions of the first and last partitions are built: they are one-sided. So if you need to restrict the range or add more conditions, use the API that accepts custom partition predicates instead, as sketched below.
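A minimal sketch of that usage, assuming a table with an `id` column where only ids 1 to 1000 are wanted (URL, table name and credentials are placeholders):

  // Predicates-based API: each predicate becomes exactly one partition, so only the
  // specified ranges are read and the boundaries are fully under our control.
  import java.util.Properties
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().appName("jdbc-predicates-demo").getOrCreate()

  val url = "jdbc:mysql://mysqlHost:3306/database"
  val props = new Properties()
  props.setProperty("user", "username")
  props.setProperty("password", "pwd")

  // Explicit, closed ranges: no "id < 101 or id is null" / "id >= 901" style open ends,
  // so rows outside ids 1..1000 are never pulled in.
  val predicates = (0 until 10).map { i =>
    val lo = 1 + i * 100
    val hi = lo + 100
    s"id >= $lo AND id < $hi"
  }.toArray

  val df = spark.read.jdbc(url, "table", predicates, props)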
3. Mapping the query results
The relevant functions:
org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation

// Build the DataFrame schema, i.e. map the database column types to Spark data types
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

The concrete implementation lives in org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD:

def resolveTable(url: String, table: String, properties: Properties): StructType = {
  // Pick the dialect to use based on the JDBC URL
  val dialect = JdbcDialects.get(url)
  val ncols = rsmd.getColumnCount
  val fields = new Array[StructField](ncols)
  var i = 0
  ....
  while (i < ncols) {
    val columnName = rsmd.getColumnLabel(i + 1)
    val dataType = rsmd.getColumnType(i + 1)
    val typeName = rsmd.getColumnTypeName(i + 1)
    val fieldSize = rsmd.getPrecision(i + 1)
    val fieldScale = rsmd.getScale(i + 1)
    ....
    // Map according to the dialect's rules; fall back to the default mapping
    // when the dialect does not define one
    val columnType =
      dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
        getCatalystType(dataType, fieldSize, fieldScale, isSigned))
    fields(i) = StructField(columnName, columnType, nullable, metadata.build())
    i = i + 1
  }
  return new StructType(fields)
}

Some examples from the default field-mapping configuration:

val answer = sqlType match {
  ....
  case java.sql.Types.BLOB => BinaryType
  case java.sql.Types.BOOLEAN => BooleanType
  case java.sql.Types.CHAR => StringType
  case java.sql.Types.CLOB => StringType
  case java.sql.Types.DATALINK => null
  case java.sql.Types.DATE => DateType
  case java.sql.Types.DECIMAL if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
  case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
  case java.sql.Types.DISTINCT => null
  case java.sql.Types.DOUBLE => DoubleType
  case java.sql.Types.FLOAT => FloatType
  ....
}
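A quick way to see how these mapping rules resolved a concrete table is simply to print the schema of the returned DataFrame (reusing the placeholder `spark`, `url` and `props` from the earlier sketch):

  // Inspect the schema that resolveTable produced for a concrete table.
  val df = spark.read.jdbc(url, "table", props)
  df.printSchema()   // shows each column with the Spark type chosen by the dialect or the default mapping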
Here is the MySQL dialect as an example. All dialect implementations live under the package org.apache.spark.sql.jdbc.*; browse them there as needed. The MySQL dialect:

private case object MySQLDialect extends JdbcDialect {

  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    // The key mapping rules
    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
      // byte arrays instead of longs.
      md.putLong("binarylong", 1)
      Option(LongType)
    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
      Option(BooleanType)
    } else None
  }
  ....
}
As the source shows, the MySQL dialect only constrains the BIT and TINYINT types; everything else falls back to Spark's default mapping. So when reading data, consider whether Spark's dialect mappings affect your existing data, to avoid distorting it; one option is to register a custom dialect, as sketched below.
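The following is a minimal sketch of such a custom dialect, assuming Spark 2.x; `TinyIntAsIntDialect` and the chosen mapping are purely illustrative, not something Spark ships with:

  // Illustrative custom dialect: map MySQL TINYINT to IntegerType instead of relying on
  // the default behaviour. Registered dialects take precedence over the built-in ones.
  import java.sql.Types
  import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
  import org.apache.spark.sql.types.{DataType, IntegerType, MetadataBuilder}

  object TinyIntAsIntDialect extends JdbcDialect {
    override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")

    override def getCatalystType(
        sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
      if (sqlType == Types.TINYINT) Some(IntegerType) else None   // None = fall back to defaults
    }
  }

  // Register before calling spark.read.jdbc so the override is picked up during schema resolution.
  JdbcDialects.registerDialect(TinyIntAsIntDialect)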
At this point the JDBCRelation object has been fully constructed.
4. RDD construction and logical partition generation
Using the JDBCRelation built earlier, the SparkSession adds the job to the logical plan; when an action is encountered, it is converted into a physical plan:
org.apache.spark.sql.SparkSession

// Logical plan construction (details omitted; the author did not dig deeply into this part)
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
  Dataset.ofRows(self, LogicalRelation(baseRelation))
}

org.apache.spark.sql.execution.datasources.DataSourceStrategy

// Physical plan
object DataSourceStrategy extends Strategy with Logging {
  def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
    case PhysicalOperation.....
    // JDBCRelation extends PrunedFilteredScan, so it enters this case branch,
    // which calls its buildScan method
    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
      pruneFilterProject(
        l,
        projects,
        filters,
        (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
    case PhysicalOperation.....
  }
}
When JDBCRelation's buildScan method executes, it calls JDBCRDD's scanTable method to create a new RDD. Any filter conditions added before the computation are merged into the WHERE clause of the JDBC query, joined with AND:
private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    fqTable: String,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    properties: Properties)
  extends RDD[InternalRow](sc, Nil) {

  override def getPartitions: Array[Partition] = partitions

  .....

  private def getWhereClause(part: JDBCPartition): String = {
    if (part.whereClause != null && filterWhereClause.length > 0) {
      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
    } else if (part.whereClause != null) {
      "WHERE " + part.whereClause
    } else if (filterWhereClause.length > 0) {
      "WHERE " + filterWhereClause
    } else {
      ""
    }
  }

  // compute is what runs when an action fires: it executes the SQL for this partition
  // and maps the result set according to the schema resolved earlier
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {
      ....
      // Implementation details omitted; mainly the JDBC query itself and the type mapping
    }
}
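Putting the pieces together, the condition that one task sends for one partition is the partition clause combined with the pushed-down filters. The small standalone sketch below re-implements that composition for illustration (not the Spark source; the partition clause and filter are example values):

  // Standalone re-implementation of the WHERE-clause composition.
  def whereClause(partitionClause: String, filterWhereClause: String): String =
    if (partitionClause != null && filterWhereClause.nonEmpty)
      s"WHERE ($filterWhereClause) AND ($partitionClause)"
    else if (partitionClause != null) s"WHERE $partitionClause"
    else if (filterWhereClause.nonEmpty) s"WHERE $filterWhereClause"
    else ""

  println(whereClause("id >= 101 AND id < 201", "name IS NOT NULL"))
  // prints: WHERE (name IS NOT NULL) AND (id >= 101 AND id < 201)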
Example of using filter conditions:
val url = "jdbc:mysql://mysqlHost:3306/database"
val tableName = "table"
val columnName = "id"
val lowerBound = getMinId()
val upperBound = getMaxId()
val numPartitions = 200

// Set the connection user & password
val prop = new java.util.Properties
prop.setProperty("user", "username")
prop.setProperty("password", "pwd")

// Filter the MySQL data
val jdbcDF = sqlContext.read
  .jdbc(url, tableName, columnName, lowerBound, upperBound, numPartitions, prop)
  .where("date='2017-11-30'")
  .filter("name is not null")
Here where and filter are equivalent: the filter conditions take effect in the generated WHERE clause, and multiple conditions are joined with AND. You can check that they were actually pushed down, as shown below.
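A minimal way to confirm the pushdown (rather than the filters being applied only in Spark after a full read) is to look at the physical plan of the jdbcDF from the example above:

  // Print the physical plan; for a JDBC scan the pushed-down conditions typically appear
  // in the "PushedFilters" part of the scan node (exact output varies by Spark version).
  jdbcDF.explain()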
结语
When reading data from a database, you can go to the corresponding source files and debug through them to analyze what is happening.
Author: wuli_小博
Link: https://www.jianshu.com/p/429e64663b0e