

  /** Called when the stage's parents are available and its missing tasks can now be submitted. */
  private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
    logDebug("submitMissingTasks(" + stage + ")")

    // First figure out the indexes of partition ids to compute.
    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()

    // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
    // with this Stage
    val properties = jobIdToActiveJob(jobId).properties

    runningStages += stage
    // SparkListenerStageSubmitted should be posted before testing whether tasks are
    // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
    // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
    // event.
    stage match {
      case s: ShuffleMapStage =>
        outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
      case s: ResultStage =>
        outputCommitCoordinator.stageStart(
          stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
    }

    // Data-locality result for each partition to compute: partition id -> preferred locations.
    val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
      stage match {
        case s: ShuffleMapStage =>
          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id)) }.toMap
        case s: ResultStage =>
          partitionsToCompute.map { id =>
            // For a ResultStage, `id` is an output index; map it to the RDD partition first.
            val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
          }.toMap
      }
    } catch {
      case NonFatal(e) =>
        // Create an attempt and post StageSubmitted first, so the completion event produced
        // while aborting still follows a matching submission event (see comment above).
        stage.makeNewStageAttempt(partitionsToCompute.size)
        listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }

    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)

    // If there are tasks to execute, record the submission time of the stage. Otherwise,
    // post the event without the submission time, which indicates that this stage was
    // skipped.
    if (partitionsToCompute.nonEmpty) {
      stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
    }
    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

    // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
    // the serialized copy of the RDD and for each task we will deserialize it, which means each
    // task gets a different copy of the RDD. This provides stronger isolation between tasks that
    // might modify state of objects referenced in their closures. This is necessary in Hadoop
    // where the JobConf/Configuration object is not thread-safe.
    var taskBinary: Broadcast[Array[Byte]] = null
    var partitions: Array[Partition] = null
    try {
      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
      // For ResultTask, serialize and broadcast (rdd, func).
      var taskBinaryBytes: Array[Byte] = null
      // taskBinaryBytes and partitions are both affected by the checkpoint status. We need
      // this synchronization in case another concurrent job is checkpointing this RDD, so we get a
      // consistent view of both variables.
      RDDCheckpointData.synchronized {
        taskBinaryBytes = stage match {
          case stage: ShuffleMapStage =>
            JavaUtils.bufferToArray(
              closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
          case stage: ResultStage =>
            JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
        }

        partitions = stage.rdd.partitions
      }

      taskBinary = sc.broadcast(taskBinaryBytes)
    } catch {
      // In the case of a failure during serialization, abort the stage.
      case e: NotSerializableException =>
        abortStage(stage, "Task not serializable: " + e.toString, Some(e))
        runningStages -= stage

        // Abort execution
        return
      case e: Throwable =>
        abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage

        // Abort execution
        return
    }

    val tasks: Seq[Task[_]] = try {
      val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
      stage match {
        case stage: ShuffleMapStage =>
          // Track only the partitions submitted by this attempt; drop any left pending earlier.
          stage.pendingPartitions.clear()
          partitionsToCompute.map { id =>
            val locs = taskIdToLocations(id)
            val part = partitions(id)
            stage.pendingPartitions += id
            new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
              Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
          }

        case stage: ResultStage =>
          partitionsToCompute.map { id =>
            val p: Int = stage.partitions(id)
            val part = partitions(p)
            val locs = taskIdToLocations(id)
            new ResultTask(stage.id, stage.latestInfo.attemptNumber,
              taskBinary, part, locs, id, properties, serializedTaskMetrics,
              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
              stage.rdd.isBarrier())
          }
      }
    } catch {
      case NonFatal(e) =>
        abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
        runningStages -= stage
        return
    }

    if (tasks.nonEmpty) {
      logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
        s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
      taskScheduler.submitTasks(new TaskSet(
        tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))
    } else {
      // Because we posted SparkListenerStageSubmitted earlier, we should mark
      // the stage as completed here in case there are no tasks to run
      markStageAsFinished(stage, None)

      stage match {
        case stage: ShuffleMapStage =>
          logDebug(s"Stage ${stage} is actually done; " +
              s"(available: ${stage.isAvailable}," +
              s"available outputs: ${stage.numAvailableOutputs}," +
              s"partitions: ${stage.numPartitions})")
          markMapStageJobsAsFinished(stage)
        case stage : ResultStage =>
          logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})")
      }
      // Nothing ran for this stage, so let stages waiting on it proceed.
      submitWaitingChildStages(stage)
    }
  }

  /**
   * Gets the locality information associated with a partition of a particular RDD.
   * This method is thread-safe and is called from both DAGScheduler and SparkContext.
   * @param rdd whose partitions are to be looked at
   * @param partition to lookup locality information for
   * @return list of machines that are preferred by the partition
   */
  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
    // Start the recursive search with an empty visited set.
    getPreferredLocsInternal(rdd, partition, new HashSet)
  }
  /**
   * Recursive implementation for getPreferredLocs.
   * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
   * methods (getCacheLocs()); please be careful when modifying this method, because any new
   * DAGScheduler state accessed by it may require additional synchronization.
   * Returns the storage locations of the partition's data blocks. TaskSetManager (in
   * addPendingTask) uses these locations to decide each task's data-locality level
   * (partition at compute time <===> block at storage time).
   */
  private def getPreferredLocsInternal(
      rdd: RDD[_],
      partition: Int,
      visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
    // If the partition has already been visited, no need to re-visit.
    // This avoids exponential path exploration.  SPARK-695
    if (!visited.add((rdd, partition))) {
      // Nil has already been returned for previously visited partitions.
      return Nil
    }
    // If the partition is cached, return the cache locations.
    // getCacheLocs goes through the BlockManager, resolving blockId -> BlockManagerId,
    // i.e. where each cached block is stored.
    val cached = getCacheLocs(rdd)(partition)
    if (cached.nonEmpty) {
      return cached
    }
    // If the RDD has some placement preferences (as is the case for input RDDs), get those.
    // Note: for a checkpointed RDD this presumably reflects the checkpoint data's locations.
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }

    // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
    // that has any placement preferences. Ideally we would choose based on transfer sizes,
    // but this will do for now.
    rdd.dependencies.foreach {
      case n: NarrowDependency[_] =>
        for (inPart <- n.getParents(partition)) {
          val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
          if (locs != Nil) {
            return locs
          }
        }

      case _ =>
    }

    // No cached copy, no RDD preference and no narrow-dependency preference.
    Nil
  }

接下来,如果task需要从其他位置拉取数据,数据会被拉取到task的执行位置并暂存在那里【详见源码BlockTransferService,这一部分暂时还没有深入阅读源码】。这样,第二个stage在执行时,会发现数据已经存在于task执行位置,默认的数据本地性级别为NODE_LOCAL,但如果之前有cache操作,数据本地性级别就为PROCESS_LOCAL。这就可以解释:在k8s中,Spark读取Ceph上的数据时,第一个stage的data locality为NO_PREF、ANY,第二个stage的data locality为NODE_LOCAL、ANY,符合我们的预期。另外我们发现,data locality中都包含了ANY。我的理解是,Spark为了确保所有的task都有可用的数据本地性级别,在计算有效的数据本地性级别时,总会额外添加ANY级别【详见源码TaskSetManager#computeValidLocalityLevels】。

  /**
   * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
   * added to queues using addPendingTask.
   */
  private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
    val levels = new ArrayBuffer[TaskLocality.TaskLocality]
    if (!pendingTasksForExecutor.isEmpty &&
        pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
      levels += PROCESS_LOCAL
    }
    // NODE_LOCAL is valid only if some pending task prefers a host with a live executor.
    // Observations from running on Kubernetes:
    // - Each executor pod gets its own (virtual) pod IP, which is not the IP of its host node.
    // - Enabling hostNetwork on the executor pods makes a pod share its host's IP.
    // - pendingTasksForHost maps each host to the tasks pending on it:
    //   host <===> ArrayBuffer(task1, task2, task3)
    // - sched.hasExecutorsAliveOnHost(_) checks whether that host is among the addresses of
    //   live executors: the node (server) IP outside k8s, or the pod IP on k8s.
    if (!pendingTasksForHost.isEmpty &&
        pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
      levels += NODE_LOCAL
    }
    if (!pendingTasksWithNoPrefs.isEmpty) {
      levels += NO_PREF
    }
    if (!pendingTasksForRack.isEmpty &&
        pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
      levels += RACK_LOCAL
    }
    // ANY is always appended, so every task set has at least one valid locality level.
    levels += ANY
    logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
    levels.toArray
  }


由于在k8s中Spark的执行宿主是pod【非k8s环境中是服务器】,而k8s的特性决定了pod用完即销毁,所以现在面对的最主要的问题是:如何把之前的app在计算时拉取到pod中的数据保存下来,供下一个app使用。我最开始的思路是,在Spark的executor中挂载一个hostPath类型的pv,让数据持久化到执行节点,供下一个app使用。这需要修改Spark的源码:Spark读取数据时,原本会把数据缓存在执行节点,现在要把数据缓存的位置改为pv的数据路径。不过,在评估这个方案时,有以下两点考虑:


  1. 对Spark源码的理解还不够深入、透彻,尤其是数据读写这部分更不熟悉。修改源码技术难度大,后期Spark的维护和版本升级也会比较困难,而且也担心修改后会影响整个Spark计算的稳定性。

  2. 发现社区中有一个组件Alluxio,它为Spark在k8s中提供了数据本地性保障的特性,而且还提供了远端缓存的功能。其原理与上述思路差不多:通过数据落本地+挂载pv的方式,实现了Spark在k8s中较好的数据本地性保障。详见GitHub项目Alluxio。



