if [ -z"${SPARK_HOME}" ]; thensource"$(dirname "$0")"/find-spark-homefi# disable randomized hash for string in Python 3.3+export PYTHONHASHSEED=0exec"${SPARK_HOME}"/bin/spark-classorg.apache.spark.deploy.SparkSubmit"$@"
#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

if [ -z "${SPARK_HOME}" ]; then
  source "$(dirname "$0")"/find-spark-home
fi

. "${SPARK_HOME}"/bin/load-spark-env.sh

# Find the java binary
if [ -n "${JAVA_HOME}" ]; then
  RUNNER="${JAVA_HOME}/bin/java"
else
  if [ "$(command -v java)" ]; then
    RUNNER="java"
  else
    echo "JAVA_HOME is not set" >&2
    exit 1
  fi
fi

# Find Spark jars.
if [ -d "${SPARK_HOME}/jars" ]; then
  SPARK_JARS_DIR="${SPARK_HOME}/jars"
else
  SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars"
fi

if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then
  echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2
  echo "You need to build Spark with the target \"package\" before running this program." 1>&2
  exit 1
else
  LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*"
fi

# Add the launcher build dir to the classpath if requested.
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
  LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH"
fi

# For tests
if [[ -n "$SPARK_TESTING" ]]; then
  unset YARN_CONF_DIR
  unset HADOOP_CONF_DIR
fi

# The launcher library will print arguments separated by a NULL character, to allow arguments with
# characters that would be otherwise interpreted by the shell. Read that in a while loop, populating
# an array that will be used to exec the final command.
#
# The exit code of the launcher is appended to the output, so the parent shell removes it from the
# command array and checks the value to see if the launcher succeeded.
build_command() {
  "$RUNNER" -Xmx128m $SPARK_LAUNCHER_OPTS -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@"
  printf "%d\0" $?
}

# Turn off posix mode since it does not allow process substitution
set +o posix

CMD=()
DELIM=$'\n'
CMD_START_FLAG="false"
while IFS= read -d "$DELIM" -r ARG; do
  if [ "$CMD_START_FLAG" == "true" ]; then
    CMD+=("$ARG")
  else
    if [ "$ARG" == $'\0' ]; then
      # After NULL character is consumed, change the delimiter and consume command string.
      DELIM=''
      CMD_START_FLAG="true"
    elif [ "$ARG" != "" ]; then
      echo "$ARG"
    fi
  fi
done < <(build_command "$@")

COUNT=${#CMD[@]}
LAST=$((COUNT - 1))
LAUNCHER_EXIT_CODE=${CMD[$LAST]}

# Certain JVM failures result in errors being printed to stdout (instead of stderr), which causes
# the code that parses the output of the launcher to get confused. In those cases, check if the
# exit code is an integer, and if it's not, handle it as a special error case.
if ! [[ $LAUNCHER_EXIT_CODE =~ ^[0-9]+$ ]]; then
  echo "${CMD[@]}" | head -n-1 1>&2
  exit 1
fi

if [ $LAUNCHER_EXIT_CODE != 0 ]; then
  exit $LAUNCHER_EXIT_CODE
fi

CMD=("${CMD[@]:0:$LAST}")
exec "${CMD[@]}"
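The handshake between org.apache.spark.launcher.Main and the shell wrapper is easy to miss in the bash above: the launcher prints the final command as NUL-separated arguments and appends its own exit code. As a minimal sketch (not Spark code, and with the protocol simplified to "NUL-separated args, exit code last"), the same wire format can be parsed like this:

// Hypothetical illustration of the launcher's output format: NUL-separated
// arguments, with the exit code printed last (also NUL-terminated).
object LauncherOutputDemo {
  def parse(raw: String): (Seq[String], Int) = {
    // Split on NUL; drop empty tokens left by the trailing '\0'
    // (a real parser would preserve legitimately empty arguments)
    val tokens = raw.split('\u0000').toSeq.filter(_.nonEmpty)
    // The last token is the launcher's exit code; everything before it is the command
    (tokens.init, tokens.last.toInt)
  }

  def main(args: Array[String]): Unit = {
    val raw = "java\u0000-cp\u0000/opt/spark/jars/*\u0000org.apache.spark.deploy.SparkSubmit\u00000\u0000"
    val (cmd, exitCode) = parse(raw)
    println(s"exit=$exitCode cmd=${cmd.mkString(" ")}")
  }
}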
// This is where the TransportClient is actually created
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
// Install the encoders/decoders on the channel pipeline
channel.pipeline()
  .addLast("encoder", ENCODER)
  .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
  .addLast("decoder", DECODER)
  .addLast("idleStateHandler",
    new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
  // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
  // would require more logic to guarantee if this were not part of the same event loop.
  .addLast("handler", channelHandler);
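The ordering matters: inbound bytes traverse the pipeline in addLast order, which is why the frame decoder must sit before the message decoder. A toy Scala sketch (not Spark's pipeline; assumes netty-all on the classpath, and Tap is an illustrative stand-in for the real codecs) makes the traversal order visible:

// Shows that inbound events visit handlers in addLast order.
import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.channel.embedded.EmbeddedChannel

object PipelineOrderDemo {
  // A named pass-through handler that logs its position in the pipeline
  class Tap(name: String) extends ChannelInboundHandlerAdapter {
    override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef): Unit = {
      println(s"$name saw: $msg")
      ctx.fireChannelRead(msg) // hand off to the next handler, like a decoder would
    }
  }

  def main(args: Array[String]): Unit = {
    val ch = new EmbeddedChannel()
    ch.pipeline()
      .addLast("frameDecoder", new Tap("frameDecoder"))
      .addLast("decoder", new Tap("decoder"))
      .addLast("handler", new Tap("handler"))
    ch.writeInbound("raw bytes") // prints frameDecoder, decoder, handler in that order
  }
}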
Dispatcher
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rpc.netty

import java.util.concurrent._

import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.concurrent.Promise
import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
import org.apache.spark.util.ThreadUtils

/**
 * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
 * Each registered endpoint is paired with an Inbox that buffers its RPC messages, so the
 * Dispatcher effectively maintains the endpoint-to-inbox binding.
 *
 * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
 *                       If 0, will consider the available CPUs on the host.
 */
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

  /**
   * Binds an endpoint, its reference, and its Inbox together.
   *
   * @param name     name of the endpoint
   * @param endpoint the endpoint itself
   * @param ref      the NettyRpcEndpointRef pointing at the endpoint
   */
  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    // The inbox where this endpoint's incoming messages are queued
    val inbox = new Inbox(ref, endpoint)
  }

  // All registered endpoints, keyed by name
  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]

  // All endpoint references, keyed by endpoint
  private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
    new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]

  // Track the receivers whose inboxes may contain messages.
  private val receivers = new LinkedBlockingQueue[EndpointData]

  /**
   * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
   * immediately.
*/@GuardedBy("this")privatevar stopped =false /** * 注册RpcEndpont * * @param name * @param endpoint * @return */defregisterRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {// 创建RpcEndpoint地址val addr = RpcEndpointAddress(nettyEnv.address, name)// 创建一个服务端引用,根据地址创建一个endpointRefval endpointRef =new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)synchronized {// 如果rpc已经被关闭直接抛出rpc关闭异常if (stopped) {thrownew IllegalStateException("RpcEnv has been stopped") }// 如果本地缓存中已经存在一个相同名字的rpcEndpoint则抛出异常if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) !=null) {thrownew IllegalArgumentException(s"There is already an RpcEndpoint called $name") }// 从缓存中拿到EndpointData数据(包含name,endpoint,endpointRef,inbox)val data = endpoints.get(name)// endpoint和ref放入endpointRefs endpointRefs.put(data.endpoint, data.ref)//将数据放入放入receivers中 receivers.offer(data) // for the OnStart message } endpointRef }defgetRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)defremoveRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)// Should be idempotentprivatedefunregisterRpcEndpoint(name: String): Unit = {val data = endpoints.remove(name)if (data !=null) { data.inbox.stop()// 数据放入receivers receivers.offer(data) // for the OnStop message }// Don't clean `endpointRefs` here because it's possible that some messages are being processed// now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via// `removeRpcEndpointRef`. }defstop(rpcEndpointRef: RpcEndpointRef): Unit = {synchronized {if (stopped) {// This endpoint will be stopped by Dispatcher.stop() method.return } unregisterRpcEndpoint(rpcEndpointRef.name) } } /** * Send a message to all registered [[RpcEndpoint]]s in this process. * 发送消息给全部已经注册的rpcEndpoint * This can be used to make network events known to all end points (e.g. "a new node connected"). */defpostToAll(message: InboxMessage): Unit = {// 拿到全部的endpointval iter = endpoints.keySet().iterator()while (iter.hasNext) {val name = iter.next// 发送消息 postMessage(name, message, (e) => { e match {case e: RpcEnvStoppedException => logDebug(s"Message $message dropped. ${e.getMessage}")case e: Throwable => logWarning(s"Message $message dropped. ${e.getMessage}") } } ) } } /** * 发送远程消息 */ /** Posts a message sent by a remote endpoint. */defpostRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {// 创建远程RPC回调上下文val rpcCallContext = {new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) }// 创建RPC消息val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)// 发送消息 postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } /** Posts a message sent by a local endpoint. */// 发送本地消息defpostLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {val rpcCallContext =new LocalNettyRpcCallContext(message.senderAddress, p)val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) } /** Posts a one-way message. */defpostOneWayMessage(message: RequestMessage): Unit = { postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), (e) =>throw e) } /** * Posts a message to a specific endpoint. * * @param endpointName name of the endpoint. * @param message the message to post * @param callbackIfStopped callback function if the endpoint is stopped. 
   */
  private def postMessage(
      endpointName: String,
      message: InboxMessage,
      callbackIfStopped: (Exception) => Unit): Unit = {
    val error = synchronized {
      val data = endpoints.get(endpointName)
      if (stopped) {
        Some(new RpcEnvStoppedException())
      } else if (data == null) {
        Some(new SparkException(s"Could not find $endpointName."))
      } else {
        // Queue the message in the endpoint's inbox
        data.inbox.post(message)
        // Mark the endpoint as having pending messages
        receivers.offer(data)
        None
      }
    }
    // We don't need to call `onStop` in the `synchronized` block
    error.foreach(callbackIfStopped)
  }

  def stop(): Unit = {
    synchronized {
      if (stopped) {
        return
      }
      stopped = true
    }
    // Stop all endpoints. This will queue all endpoints for processing by the message loops.
    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
    // Enqueue a poison-pill message that tells the message loops to stop.
    receivers.offer(PoisonPill)
    threadpool.shutdown()
  }

  def awaitTermination(): Unit = {
    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
  }

  /**
   * Return if the endpoint exists
   */
  def verify(name: String): Boolean = {
    endpoints.containsKey(name)
  }

  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, availableCores))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    // Start one MessageLoop per thread
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

  /** Message loop used for dispatching messages. */
  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            // On the poison pill, put it back so the other MessageLoops see it too, then exit
            if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            // Process the queued messages; the Dispatcher is passed in mainly so the Inbox
            // can remove the endpoint's reference when it stops
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case _: InterruptedException => // exit
        case t: Throwable =>
          try {
            // Re-submit a MessageLoop so that Dispatcher will still work if
            // UncaughtExceptionHandler decides to not kill JVM.
            threadpool.execute(new MessageLoop)
          } finally {
            throw t
          }
      }
    }
  }

  /** A poison endpoint that indicates MessageLoop should exit its message loop. */
  private val PoisonPill = new EndpointData(null, null, null)
}
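The PoisonPill trick above is the whole shutdown protocol: one sentinel stops N loop threads because each thread that consumes it re-offers it before returning. A minimal self-contained sketch (not Spark code; the names are illustrative) of the same pattern:

// Poison-pill shutdown for multiple consumers of one blocking queue.
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

object PoisonPillDemo {
  private val queue = new LinkedBlockingQueue[String]()
  // Sentinel value, analogous to Dispatcher's EndpointData(null, null, null)
  private val PoisonPill = "__POISON_PILL__"

  private class Loop extends Runnable {
    override def run(): Unit = {
      while (true) {
        val msg = queue.take()
        if (msg == PoisonPill) {
          queue.offer(PoisonPill) // put it back so the other loops can see it
          return
        }
        println(s"${Thread.currentThread().getName} processed $msg")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(2)
    for (_ <- 0 until 2) pool.execute(new Loop)
    Seq("a", "b", "c").foreach(queue.offer(_))
    queue.offer(PoisonPill) // one pill eventually stops every loop
    pool.shutdown()
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}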
private class AsyncEventQueue(
    val name: String,
    conf: SparkConf,
    metrics: LiveListenerBusMetrics,
    bus: LiveListenerBus)
  extends SparkListenerBus
  with Logging {

  import AsyncEventQueue._

  // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if
  // it's perpetually being added to more quickly than it's being drained.
  private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](
    conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY))

  // Keep the event count separately, so that waitUntilEmpty() can be implemented properly;
  // this allows that method to return only when the events in the queue have been fully
  // processed (instead of just dequeued).
  private val eventCount = new AtomicLong()

  /** A counter for dropped events. It will be reset every time we log it. */
  private val droppedEventsCounter = new AtomicLong(0L)

  /** When `droppedEventsCounter` was logged last time in milliseconds. */
  @volatile private var lastReportTimestamp = 0L

  private val logDroppedEvent = new AtomicBoolean(false)

  private var sc: SparkContext = null

  private val started = new AtomicBoolean(false)
  private val stopped = new AtomicBoolean(false)

  private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents")
  private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime")

  // Remove the queue size gauge first, in case it was created by a previous incarnation of
  // this queue that was removed from the listener bus.
  metrics.metricRegistry.remove(s"queue.$name.size")
  metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] {
    override def getValue: Int = eventQueue.size()
  })

  // Background thread that drains the queue and forwards events to the listeners
  private val dispatchThread = new Thread(s"spark-listener-group-$name") {
    setDaemon(true)
    override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
      // Deliver the queued events
      dispatch()
    }
  }

  // The dispatch loop
  private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
    var next: SparkListenerEvent = eventQueue.take()
    // Keep calling postToAll until the poison pill is seen
    while (next != POISON_PILL) {
      val ctx = processingTime.time()
      try {
        // Forward the event to every attached listener via doPostEvent
        super.postToAll(next)
      } finally {
        ctx.stop()
      }
      eventCount.decrementAndGet()
      next = eventQueue.take()
    }
    eventCount.decrementAndGet()
  }

  override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
    metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface]))
  }

  /**
   * Start an asynchronous thread to dispatch events to the underlying listeners.
   *
   * @param sc Used to stop the SparkContext in case the async dispatcher fails.
   */
  private[scheduler] def start(sc: SparkContext): Unit = {
    if (started.compareAndSet(false, true)) {
      this.sc = sc
      // Start the dispatchThread, which delivers listener events
      dispatchThread.start()
    } else {
      throw new IllegalStateException(s"$name already started!")
    }
  }

  /**
   * Stop the listener bus. It will wait until the queued events have been processed, but new
   * events will be dropped.
   */
  private[scheduler] def stop(): Unit = {
    if (!started.get()) {
      throw new IllegalStateException(s"Attempted to stop $name that has not yet started!")
    }
    if (stopped.compareAndSet(false, true)) {
      eventCount.incrementAndGet()
      // Enqueue the poison pill so the dispatch thread exits
      eventQueue.put(POISON_PILL)
    }
    // this thread might be trying to stop itself as part of error handling -- we can't join
    // in that case.
    if (Thread.currentThread() != dispatchThread) {
      // The current thread is not dispatchThread, so wait for dispatchThread to die
      dispatchThread.join()
    }
  }

  def post(event: SparkListenerEvent): Unit = {
    if (stopped.get()) {
      return
    }

    eventCount.incrementAndGet()
    if (eventQueue.offer(event)) {
      return
    }

    // The queue is full: count the event as dropped instead of blocking the caller
    eventCount.decrementAndGet()
    droppedEvents.inc()
    droppedEventsCounter.incrementAndGet()
    if (logDroppedEvent.compareAndSet(false, true)) {
      // Only log the following message once to avoid duplicated annoying logs.
      logError(s"Dropping event from queue $name. " +
        "This likely means one of the listeners is too slow and cannot keep up with " +
        "the rate at which tasks are being started by the scheduler.")
    }
    logTrace(s"Dropping event $event")

    val droppedCount = droppedEventsCounter.get
    if (droppedCount > 0) {
      // Don't log too frequently
      if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) {
        // There may be multiple threads trying to decrease droppedEventsCounter.
        // Use "compareAndSet" to make sure only one thread can win.
        // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and
        // then that thread will update it.
        if (droppedEventsCounter.compareAndSet(droppedCount, 0)) {
          val prevLastReportTimestamp = lastReportTimestamp
          lastReportTimestamp = System.currentTimeMillis()
          val previous = new java.util.Date(prevLastReportTimestamp)
          logWarning(s"Dropped $droppedCount events from $name since $previous.")
        }
      }
    }
  }

  /**
   * For testing only. Wait until there are no more events in the queue.
   *
   * @return true if the queue is empty.
   */
  def waitUntilEmpty(deadline: Long): Boolean = {
    while (eventCount.get() != 0) {
      if (System.currentTimeMillis > deadline) {
        return false
      }
      Thread.sleep(10)
    }
    true
  }

  /**
   * Delegates to LiveListenerBus: remove a listener that failed unrecoverably.
   *
   * @param listener the listener to remove
   */
  override def removeListenerOnError(listener: SparkListenerInterface): Unit = {
    // the listener failed in an unrecoverably way, we want to remove it from the entire
    // LiveListenerBus (potentially stopping a queue if it is empty)
    bus.removeListener(listener)
  }
}

private object AsyncEventQueue {
  val POISON_PILL = new SparkListenerEvent() { }
}
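The key design choice in post() is the non-blocking offer(): a slow listener never stalls the scheduler, it just costs dropped events, which the counters above track. A tiny sketch (not Spark code) of that offer-or-drop behavior on a deliberately small queue:

// offer() on a bounded queue returns false instead of blocking when full.
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicLong

object BoundedPostDemo {
  private val queue = new LinkedBlockingQueue[String](2) // tiny capacity to force drops
  private val dropped = new AtomicLong(0L)

  def post(event: String): Unit = {
    if (!queue.offer(event)) {
      dropped.incrementAndGet() // queue full: record the drop, do not block the caller
    }
  }

  def main(args: Array[String]): Unit = {
    Seq("e1", "e2", "e3", "e4").foreach(post)
    println(s"queued=${queue.size()} dropped=${dropped.get()}") // queued=2 dropped=2
  }
}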
// allow all users/groups to have view/modify permissions
private val WILDCARD_ACL = "*"

// Whether authentication is enabled; configured via spark.authenticate, false by default.
private val authOn = sparkConf.get(NETWORK_AUTH_ENABLED)

// keep spark.ui.acls.enable for backwards compatibility with 1.0
// Whether ACL-based authorization is enabled.
private var aclsOn =
  sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false))

// admin acls should be set before view or modify acls
// Administrator accounts, configured via spark.admin.acls; empty by default.
private var adminAcls: Set[String] = stringToSet(sparkConf.get("spark.admin.acls", ""))

// admin group acls should be set before view or modify group acls
private var adminAclsGroups: Set[String] = stringToSet(sparkConf.get("spark.admin.acls.groups", ""))

// Accounts with view permission, configured via spark.ui.view.acls.
private var viewAcls: Set[String] = _

// Groups whose members have view permission, configured via spark.ui.view.acls.groups.
private var viewAclsGroups: Set[String] = _

// list of users who have permission to modify the application. This should
// apply to both UI and CLI for things like killing the application.
// Includes adminAcls, defaultAclUsers and the users configured via spark.modify.acls.
private var modifyAcls: Set[String] = _

// Groups whose members have modify permission; includes adminAclsGroups and the groups
// configured via spark.modify.acls.groups.
private var modifyAclsGroups: Set[String] = _

// always add the current user and SPARK_USER to the viewAcls
// Default users: the user from the user.name system property (or the logged-in user), plus
// the user set through the SPARK_USER environment variable.
private val defaultAclUsers =
  Set[String](System.getProperty("user.name", ""), Utils.getCurrentUserName())

setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", ""))
setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", ""))
setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", ""))

// The secret key. In YARN mode, the key is first looked up in the Hadoop UGI using sparkCookie;
// if the UGI does not hold one, a new key is generated (its length is controlled by
// spark.authenticate.secretBitLength) and stored in the UGI. In other modes, the key must be
// provided through the _SPARK_AUTH_SECRET environment variable (higher priority) or the
// spark.authenticate.secret property.
private var secretKey: String = _

logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") +
  "; ui acls " + (if (aclsOn) "enabled" else "disabled") +
  "; users with view permissions: " + viewAcls.toString() +
  "; groups with view permissions: " + viewAclsGroups.toString() +
  "; users with modify permissions: " + modifyAcls.toString() +
  "; groups with modify permissions: " + modifyAclsGroups.toString())

// Set our own authenticator to properly negotiate user/password for HTTP connections.
// This is needed by the HTTP client fetching from the HttpServer. Put here so its
// only set once.
// Only install it if authentication is enabled
if (authOn) {
  // Install the authenticator via an anonymous inner class
  Authenticator.setDefault(
    new Authenticator() {
      override def getPasswordAuthentication(): PasswordAuthentication = {
        var passAuth: PasswordAuthentication = null
        // Extract the user info from the requesting URL
        val userInfo = getRequestingURL().getUserInfo()
        if (userInfo != null) {
          val parts = userInfo.split(":", 2)
          // Build the credentials from the "user:password" pair
          passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
        }
        return passAuth
      }
    }
  )
}
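The anonymous Authenticator is only doing URL parsing: credentials are embedded as the userInfo part of the URL ("user:password@host") and turned into a PasswordAuthentication. A standalone sketch of that step (illustrative values, not Spark's code path):

// Turn a URL's userInfo into a PasswordAuthentication, as the anonymous
// Authenticator in SecurityManager does.
import java.net.{PasswordAuthentication, URL}

object UserInfoDemo {
  def toPasswordAuth(url: URL): Option[PasswordAuthentication] =
    Option(url.getUserInfo).collect {
      // userInfo has the form "user:password"; split only on the first ':'
      case info if info.contains(":") =>
        val parts = info.split(":", 2)
        new PasswordAuthentication(parts(0), parts(1).toCharArray)
    }

  def main(args: Array[String]): Unit = {
    val url = new URL("http://sparkuser:secret@host:8080/jars/app.jar")
    toPasswordAuth(url).foreach { auth =>
      println(s"user=${auth.getUserName} passwordLength=${auth.getPassword.length}")
    }
  }
}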
private def writeBlocks(value: T): Int = {
  import StorageLevel._
  // Store a copy of the broadcast variable in the driver so that tasks run on the driver
  // do not create a duplicate copy of the broadcast variable's value.
  // Get the BlockManager from the current SparkEnv
  val blockManager = SparkEnv.get.blockManager
  // Write the broadcast value as a single-object block
  if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
    throw new SparkException(s"Failed to store $broadcastId in BlockManager")
  }
  val blocks = {
    // Split the object into blocks
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  }
  if (checksumEnabled) {
    // Allocate one checksum slot per block
    checksums = new Array[Int](blocks.length)
  }
  blocks.zipWithIndex.foreach { case (block, i) =>
    if (checksumEnabled) {
      // Compute each block's checksum and store it in the array
      checksums(i) = calcChecksum(block)
    }
    // Build the piece id for this block
    val pieceId = BroadcastBlockId(id, "piece" + i)
    val bytes = new ChunkedByteBuffer(block.duplicate())
    // Write the piece into the local store of the driver or executor
    if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
      throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
    }
  }
  blocks.length
}
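Conceptually, writeBlocks is "chop the serialized value into blockSize pieces, checksum each piece, store each piece under its own BlockId". A simplified sketch of that shape (assumptions, not TorrentBroadcast's exact code; it skips serialization and compression and works on a raw byte array):

// Chop a payload into fixed-size pieces and keep one checksum per piece.
import java.util.zip.Adler32

object BlockifyDemo {
  def blockify(payload: Array[Byte], blockSize: Int): Array[Array[Byte]] =
    payload.grouped(blockSize).toArray

  def checksum(block: Array[Byte]): Int = {
    val adler = new Adler32() // Spark's calcChecksum is, to my understanding, also Adler32-based
    adler.update(block, 0, block.length)
    adler.getValue.toInt
  }

  def main(args: Array[String]): Unit = {
    val payload = Array.fill[Byte](10)(1)
    val blocks = blockify(payload, 4)    // 3 pieces: 4 + 4 + 2 bytes
    val checksums = blocks.map(checksum) // one checksum slot per piece
    println(s"pieces=${blocks.length} checksums=${checksums.mkString(",")}")
  }
}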
readBroadcastBlock
private def readBroadcastBlock(): T = Utils.tryOrIOException {
  TorrentBroadcast.synchronized {
    val broadcastCache: ReferenceMap = SparkEnv.get.broadcastManager.cachedValues

    Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
      setConf(SparkEnv.get.conf)
      // Get the current BlockManager
      val blockManager = SparkEnv.get.blockManager
      // First try to read the broadcast object locally, i.e. the copy written into the
      // storage system via BlockManager.putSingle.
      blockManager.getLocalValues(broadcastId) match {
        case Some(blockResult) =>
          if (blockResult.data.hasNext) {
            val x = blockResult.data.next().asInstanceOf[T]
            // Release the lock on the block
            releaseLock(broadcastId)
            // Cache the value in broadcastCache
            if (x != null) {
              broadcastCache.put(broadcastId, x)
            }
            x
          } else {
            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
          }
        // The broadcast value was not written into this BlockManager via putSingle, so fetch
        // the broadcast pieces stored on the driver or other executors through readBlocks
        case None =>
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          val blocks = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          try {
            val obj = TorrentBroadcast.unBlockifyObject[T](
              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
            // Store the merged copy in BlockManager so other tasks on this executor don't
            // need to re-fetch it.
            val storageLevel = StorageLevel.MEMORY_AND_DISK
            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
            }
            if (obj != null) {
              broadcastCache.put(broadcastId, obj)
            }
            obj
          } finally {
            blocks.foreach(_.dispose())
          }
      }
    }
  }
}
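The lookup order is the whole point of this method: in-process cache first, then the local block store, and only then a remote fetch, with each lower level re-populating the one above it. A compressed sketch of that pattern (illustrative names, not Spark's types):

// Cache -> local store -> remote fetch, re-populating on the way back.
import scala.collection.mutable

object LookupOrderDemo {
  private val cache = mutable.Map[String, String]()                 // like broadcastManager.cachedValues
  private val localStore = mutable.Map[String, String]("b1" -> "v") // like the local BlockManager

  private def fetchRemote(id: String): String = s"remote($id)"      // stands in for readBlocks()

  def read(id: String): String =
    cache.getOrElseUpdate(id,
      localStore.getOrElse(id, {
        val v = fetchRemote(id)
        localStore(id) = v // re-insert locally so later readers skip the remote fetch
        v
      }))

  def main(args: Array[String]): Unit = {
    println(read("b1")) // served from the local store, then cached
    println(read("b2")) // fetched remotely, stored locally, then cached
    println(read("b2")) // now served from the cache
  }
}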
readBlocks
private def readBlocks(): Array[BlockData] = {
  // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
  // to the driver, so other executors can pull these chunks from this executor as well.
  // One slot per block
  val blocks = new Array[BlockData](numBlocks)
  // The local BlockManager
  val bm = SparkEnv.get.blockManager

  // Fetch the pieces in random order
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    // First try getLocalBytes because there is a chance that previous attempts to fetch the
    // broadcast blocks have already fetched some of the blocks. In that case, some blocks
    // would be available locally (on this executor).
    // Try the local store first; if the piece is not there, pull it from a remote executor
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            // We found the block from remote executors/driver's BlockManager, so put the block
            // in this executor's BlockManager.
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}
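Note the Random.shuffle: if every executor fetched piece 0 first, the original seeder would be hit by all of them for the same piece at once. Randomizing the order spreads requests across pieces (and, once pieces are re-stored, across executors), BitTorrent style. A quick illustration (not Spark code):

// Each "executor" fetches the pieces in its own random order.
import scala.util.Random

object ShuffledFetchDemo {
  def main(args: Array[String]): Unit = {
    val numBlocks = 5
    for (executor <- 1 to 3) {
      val order = Random.shuffle(Seq.range(0, numBlocks))
      println(s"executor-$executor fetch order: ${order.mkString(", ")}")
    }
  }
}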
private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging {

  // Initialized by SparkEnv; the RPC endpoint reference (analogous to an ActorRef)
  var coordinatorRef: Option[RpcEndpointRef] = None

  // Class used to identify a committer. The task ID for a committer is implicitly defined by
  // the partition being processed, but the coordinator needs to keep track of both the stage
  // attempt and the task attempt, because in some situations the same task may be running
  // concurrently in two different attempts of the same stage.
  private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

  private case class StageState(numPartitions: Int) {
    val authorizedCommitters: Array[TaskIdentifier] =
      Array.fill[TaskIdentifier](numPartitions)(null)
    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
  }

  /**
   * Map from active stage's id => authorized task attempts for each partition id, which hold an
   * exclusive lock on committing task output for that partition, as well as any known failed
   * attempts in the stage.
   *
   * Entries are added to the top-level map when stages start and are removed when they finish
   * (either successfully or unsuccessfully).
   *
   * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
   */
  private val stageStates = mutable.Map[Int, StageState]()

  /**
   * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
   */
  def isEmpty: Boolean = {
    stageStates.isEmpty
  }

  /**
   * Called by tasks to ask whether they can commit their output to HDFS.
   *
   * If a task attempt has been authorized to commit, then all other attempts to commit the same
   * task will be denied. If the authorized task attempt fails (e.g. due to its executor being
   * lost), then a subsequent task attempt may be authorized to commit its output.
   *
   * @param stage the stage number
   * @param partition the partition number
   * @param attemptNumber how many times this task has been attempted
   *                      (see [[TaskContext.attemptNumber()]])
   * @return true if this task is authorized to commit, false otherwise
   */
  def canCommit(
      stage: Int,
      stageAttempt: Int,
      partition: Int,
      attemptNumber: Int): Boolean = {
    // Wrap the request as an AskPermissionToCommitOutput message and send it to the
    // OutputCommitCoordinatorEndpoint through its RpcEndpointRef
    val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
    coordinatorRef match {
      case Some(endpointRef: RpcEndpointRef) =>
        // Wait for the reply that says whether this attempt may commit
        ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
          RpcUtils.askRpcTimeout(conf).duration)
      case None =>
        logError(
          "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?")
        false
    }
  }

  /**
   * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
   * yet been initialized.
   *
   * @param stage the stage id.
   * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
   *                       the maximum possible value of `context.partitionId`).
   */
  private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
    // Reuse the existing stage state, or create it on first start
    stageStates.get(stage) match {
      case Some(state) =>
        require(state.authorizedCommitters.length == maxPartitionId + 1)
        logInfo(s"Reusing state from previous attempt of stage $stage.")
      case _ =>
        stageStates(stage) = new StageState(maxPartitionId + 1)
    }
  }

  // Called by DAGScheduler
  private[scheduler] def stageEnd(stage: Int): Unit = synchronized {
    stageStates.remove(stage)
  }

  // Called by DAGScheduler
  private[scheduler] def taskCompleted(
      stage: Int,
      stageAttempt: Int,
      partition: Int,
      attemptNumber: Int,
      reason: TaskEndReason): Unit = synchronized {
    val stageState = stageStates.getOrElse(stage, {
      logDebug(s"Ignoring task completion for completed stage")
      return
    })
    reason match {
      // The task succeeded
      case Success =>
      // The task output has been committed successfully
      // The commit was denied
      case _: TaskCommitDenied =>
        logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
          s"partition: $partition, attempt: $attemptNumber")
      // Any other failure reason
      case _ =>
        // Mark the attempt as failed to blacklist from future commit protocol
        val taskId = TaskIdentifier(stageAttempt, attemptNumber)
        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
        if (stageState.authorizedCommitters(partition) == taskId) {
          logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
            s"partition=$partition) failed; clearing lock")
          stageState.authorizedCommitters(partition) = null
        }
    }
  }

  def stop(): Unit = synchronized {
    // On the driver, send the stop command through the endpoint reference
    if (isDriver) {
      coordinatorRef.foreach(_ send StopCoordinator)
      coordinatorRef = None
      stageStates.clear()
    }
  }

  // Marked private[scheduler] instead of private so this can be mocked in tests
  private[scheduler] def handleAskPermissionToCommit(
      stage: Int,
      stageAttempt: Int,
      partition: Int,
      attemptNumber: Int): Boolean = synchronized {
    stageStates.get(stage) match {
      case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
        logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
          s"task attempt $attemptNumber already marked as failed.")
        false
      case Some(state) =>
        val existing = state.authorizedCommitters(partition)
        if (existing == null) {
          logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
            s"task attempt $attemptNumber")
          state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
          true
        } else {
          logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
            s"already committed by $existing")
          false
        }
      case None =>
        logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
          "stage already marked as completed.")
        false
    }
  }

  private def attemptFailed(
      stageState: StageState,
      stageAttempt: Int,
      partition: Int,
      attempt: Int): Boolean = synchronized {
    val failInfo = TaskIdentifier(stageAttempt, attempt)
    stageState.failures.get(partition).exists(_.contains(failInfo))
  }
}
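Stripped of the RPC plumbing, handleAskPermissionToCommit implements a first-committer-wins lock per partition: the first attempt to ask is granted, later attempts are denied, and the lock is only released when the holder is reported as failed. A minimal sketch of that rule (illustrative names, no failure bookkeeping):

// First-committer-wins arbitration per partition.
object CommitArbiterDemo {
  // Per-partition holder of the commit lock: (stageAttempt, taskAttempt)
  private val authorized = Array.fill[Option[(Int, Int)]](4)(None)

  def canCommit(partition: Int, stageAttempt: Int, taskAttempt: Int): Boolean = synchronized {
    authorized(partition) match {
      case None =>
        authorized(partition) = Some((stageAttempt, taskAttempt)) // grant the lock
        true
      case Some(_) =>
        false // someone already holds the commit lock for this partition
    }
  }

  def main(args: Array[String]): Unit = {
    println(canCommit(partition = 0, stageAttempt = 0, taskAttempt = 0)) // true
    println(canCommit(partition = 0, stageAttempt = 0, taskAttempt = 1)) // false: denied
    println(canCommit(partition = 1, stageAttempt = 0, taskAttempt = 1)) // true: other partition
  }
}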