Learning Scala (1): Implementing a Hash Join in Scala
I've recently been reading up on Scala, so I wrote a simple hash join.
Initial implementation
The join process:
- Read the two List[List[Any]] inputs from the data source; all operations are done on List containers.
- Hash both datasets into a simple hand-rolled SimpleHashTable; each put returns an Int, which is used to record the set of bucket numbers each dataset occupies.
- Because both datasets are hashed with the same hash function, the join only takes place between corresponding buckets of the two hash tables (see the sketch after this list).
- Iterate over the buckets that need to be joined; the join between each pair of buckets reduces to a plain two-level nested loop.
- The whole thing consists of two files, SimpleHashTable.scala and HashJoin.scala.
- The input is two two-dimensional Lists, the output is one two-dimensional List, and only a single-key inner join is supported.
- Speed test: a hash join of two 10,000-row wide tables with 20 columns each takes about 0.4s.
- The hash table's M value can be tuned to the dataset size, so that the records are spread across the buckets as evenly as possible.
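To make the bucket-number bookkeeping concrete, here is a tiny sketch (the keys are made up): both sides use the same hash function, so rows with equal join keys always land in the same bucket index, and only the intersection of the two bucket-number sets ever needs to be scanned.

  val M = 991
  def bucket(key: String): Int = (key.hashCode & 0x7fffffff) % M

  val keys1 = Set(bucket("111"), bucket("333"), bucket("222"))   // buckets occupied by table 1
  val keys2 = Set(bucket("111"), bucket("444"))                  // buckets occupied by table 2
  val jointKeys = keys1 & keys2   // only these buckets can hold matching rows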
There is plenty that could be improved; this hash join is still quite simple. I lean heavily on foldLeft and map, and the experience showed me how little code Scala needs; it is comfortable and powerful to use.
class SimpleHashTable {
  val M = 991                                  // number of buckets; tune to the dataset size
  val container = new Array[List[Any]](M)      // each bucket is a List of records
  for (i <- 0 to M - 1) {
    container(i) = List[Any]()
  }

  def hash(key: String): Int = (key.hashCode() & 0x7fffffff) % M

  // Prepend the record to its bucket and return the bucket number
  def put(key: String, value: List[Any]): Int = {
    val indice = hash(key)
    container(indice) = value :: container(indice)
    indice
  }

  def get(indice: Int): List[Any] = container(indice)
  def get(key: String): List[Any] = get(hash(key))
  def dataset() = container
}
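A quick usage sketch of SimpleHashTable (the values are only illustrative): put returns the bucket number a row landed in, and get retrieves the whole bucket either by number or by key. If M should be tuned per dataset, as mentioned above, one option would be to make it a constructor parameter, e.g. class SimpleHashTable(val M: Int = 991); the hard-coded constant is kept here for simplicity.

  val sht = new SimpleHashTable
  val row = List(111, "asfd", 23)
  val b = sht.put(row(0).toString, row)   // b is the bucket number this row was hashed into
  println(sht.get(b))                     // the bucket, by number: List(List(111, asfd, 23))
  println(sht.get("111"))                 // the same bucket, looked up by key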
class HashJoin(list1: List[List[Any]], list2: List[List[Any]]) {
  val _list1 = list1
  val _list2 = list2

  def innerHashJion(col: Int): List[Any] = {
    val start = System.currentTimeMillis()
    var keys1 = Set[Int]()
    var keys2 = Set[Int]()
    // Hash each input into its own table; keys1/keys2 collect the bucket numbers used
    val sht1 = _list1.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys1 = keys1 + i
      sht
    }
    val sht2 = _list2.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys2 = keys2 + i
      sht
    }
    val end = System.currentTimeMillis()
    println("Hash took: " + (end - start) + "ms")
    // Only buckets present in both key sets can contain matching rows
    getJointRecords((keys1 & keys2).toArray, sht1, sht2, col)
  }

  def getJointRecords(inds: Array[Int], sht1: SimpleHashTable, sht2: SimpleHashTable, col: Int): List[Any] = {
    println("joint-keys: " + inds.size)
    var ret = scala.collection.immutable.List[Any]()
    inds.par.foreach(ind => {
      println(Thread.currentThread)
      // Nested loop over the two matching buckets; note that all parallel tasks prepend to the shared `ret`
      sht1.get(ind).map(record1 => {
        sht2.get(ind).map(record2 => {
          val r1 = record1.asInstanceOf[List[Any]]
          val r2 = record2.asInstanceOf[List[Any]]
          if (r1(col) == r2(col)) ret = (r1 ::: r2) :: ret
        })
      })
    })
    ret
  }
}
The following singleton object can be used for testing:
object HashJoinTest {
  def main(args: Array[String]): Unit = {
    test()
  }

  def test(): Unit = {
    val c1 = List(111, "asfd", 23)
    val c11 = List(111, "asf", 231)
    val c2 = List(333, "e", 1)
    val c3 = List(222, "ewr", 80)
    val t1 = List(111, "e", 40)
    val t11 = List(111, "fge", 30)
    val t2 = List(333, "asfd", 80)
    val t3 = List(444, "e", 1)
    val list1 = List(c1, c11, c2, c3)
    val list2 = List(t1, t11, t2, t3)
    val hj = new HashJoin(list1, list2)
    val ret = hj.innerHashJion(2)
    for (i <- (0 to 1)) println(ret(i))
  }
}
Optimization
With the implementation above, there is a performance bottleneck when the join results are written concurrently into the same List() container: write throughput only reaches roughly 100K-1M rows per second, and the writes have to be wrapped in synchronized. Although scala.collection.immutable.List itself is immutable and safe to share between threads, the var that points to the result list is not: in the 10K-join-10K test, writing about 100K rows within 0.4s lost data. Adding synchronized around the update avoids the problem in a simple way, but brings extra overhead.
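As a standalone illustration of the lost-update problem described above (not part of the join code, and assuming a Scala setup where .par is available on standard collections, as in the rest of this post): many threads prepending to one immutable List through an unsynchronized var can lose writes, while wrapping the update in synchronized does not.

  object LostUpdateDemo {
    def main(args: Array[String]): Unit = {
      var unsafe = List[Int]()
      (1 to 100000).par.foreach(i => unsafe = i :: unsafe)   // racy read-modify-write: updates can get lost
      var safe = List[Int]()
      val lock = new Object
      (1 to 100000).par.foreach(i => lock.synchronized { safe = i :: safe })
      println("unsafe size: " + unsafe.size + ", safe size: " + safe.size)   // safe is always 100000
    }
  }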
The new HashJoin.scala class below reserves one array slot for each bucket that needs to be joined, so that the per-bucket join result returned by each thread is stored in its own slot of a shared array, and the result sets are merged at the end. The join itself is still computed in parallel.
After this optimization of HashJoin.scala, the test takes only 0.1s for 10K join 10K, and 0.2s-0.4s for 20K join 20K (with M = 991; M can be adjusted).
class HashJoin(list1: List[List[Any]], list2: List[List[Any]]) {
  val _list1 = list1
  val _list2 = list2
  val M = 991                                     // must match SimpleHashTable's bucket count
  val retContainer = new Array[List[Any]](M)      // one result slot per bucket
  for (i <- 0 to M - 1) retContainer(i) = List[Any]()
  var ret = List[Any]()

  def innerHashJion(col: Int): Unit = {
    val start = System.currentTimeMillis()
    var keys1 = Set[Int]()
    var keys2 = Set[Int]()
    val sht1 = _list1.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys1 = keys1 + i
      sht
    }
    val sht2 = _list2.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys2 = keys2 + i
      sht
    }
    val end = System.currentTimeMillis()
    println("Hash took: " + (end - start) + "ms")
    val jointKeys = (keys1 & keys2).toArray
    println("JointKeys Size: " + jointKeys.size)

    // Join one pair of matching buckets with a nested loop and return that bucket's result list
    def getBucketRecords(ind: Int, sht1: SimpleHashTable, sht2: SimpleHashTable, col: Int): List[Any] = {
      var bucketRet = List[Any]()
      sht1.get(ind).map(record1 => {
        sht2.get(ind).map(record2 => {
          val r1 = record1.asInstanceOf[List[Any]]
          val r2 = record2.asInstanceOf[List[Any]]
          if (r1(col) == r2(col)) bucketRet = (r1 ::: r2) :: bucketRet
        })
      })
      bucketRet
    }

    // Each parallel task writes to its own slot, so no synchronization is needed
    jointKeys.par.foreach(ind => retContainer(ind) = getBucketRecords(ind, sht1, sht2, col))
  }

  def getRet: List[Any] = {
    mergeRets
    ret
  }

  def mergeRets = {
    val t1 = System.currentTimeMillis()
    retContainer.foreach({ r =>
      ret = r ::: ret
    })
    val t2 = System.currentTimeMillis()
    println("Merge Rets took: " + (t2 - t1) + " ms")
  }
}
My test singleton is shown below. The data comes from MongoDB and goes through a BSON-to-List conversion; you can swap out the list1 and list2 that are passed in with whatever test data you like (a synthetic-data sketch follows the code).
object HashJoinTest {
  def main(args: Array[String]): Unit = {
    mongo()
  }

  def mongo(): Unit = {
    val loadS = System.currentTimeMillis()
    val list1 = BsonToList.getMongoList(0, 10000)
    val list2 = BsonToList.getMongoList(100000, 10000)
    val loadE = System.currentTimeMillis()
    println("Load Data took: " + (loadE - loadS) + "ms")
    val hj = new HashJoin(list1, list2)
    hj.innerHashJion(8)
    val ret = hj.getRet
    val joinE = System.currentTimeMillis()
    println("HashJoin Totally took: " + (joinE - loadE) + "ms")
    println("Result size: " + ret.size)
    for (i <- (0 to 1)) println(ret(i))
  }
}
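If MongoDB is not at hand, list1 and list2 can be filled with synthetic data instead. Here is a minimal sketch (RandomData and its make method are my own stand-in for BsonToList.getMongoList, not part of the original code) that matches the 10K-row, 20-column shape used in the benchmark:

  import scala.util.Random

  object RandomData {
    // Build `rows` records of `cols` fields each; every field is an Int drawn from [0, keyRange)
    def make(rows: Int, cols: Int, keyRange: Int = 5000): List[List[Any]] =
      List.fill(rows)(List.fill(cols)(Random.nextInt(keyRange)): List[Any])
  }

  // e.g. replace the two getMongoList calls in mongo() with:
  //   val list1 = RandomData.make(10000, 20)
  //   val list2 = RandomData.make(10000, 20)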
(End)