@@ -14,24 +14,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

Removed lines (the old ExecutorPodState definitions):

package org.apache.spark.scheduler.cluster.k8s

import io.fabric8.kubernetes.api.model.Pod

sealed trait ExecutorPodState {
  def pod: Pod
}

case class PodRunning(pod: Pod) extends ExecutorPodState

case class PodPending(pod: Pod) extends ExecutorPodState

sealed trait FinalPodState extends ExecutorPodState

case class PodSucceeded(pod: Pod) extends FinalPodState

case class PodFailed(pod: Pod) extends FinalPodState

case class PodDeleted(pod: Pod) extends FinalPodState

case class PodUnknown(pod: Pod) extends ExecutorPodState

Added lines (the new ExternalShuffleDataIO):

package org.apache.spark.shuffle.sort.io.external;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.api.shuffle.ShuffleDataIO;

public class ExternalShuffleDataIO implements ShuffleDataIO {

  private final SparkConf sparkConf;

  public ExternalShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleExecutorComponents executor() {
    return new ExternalShuffleExecutorComponents(sparkConf);
  }
}
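For orientation, here is a minimal sketch (not part of the diff) of how this plugin entry point is exercised, using only the constructor and executor() method shown above. The config key that selects the plugin class is defined elsewhere and not shown in this hunk, so the sketch constructs the class directly.

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.io.external.ExternalShuffleDataIO

object ExternalShuffleDataIOSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("shuffle-io-sketch")
    // In a real run Spark instantiates the plugin reflectively from configuration;
    // constructing it directly here only shows the shape of the API.
    val dataIO = new ExternalShuffleDataIO(conf)
    // executor() returns the executor-side components, which must be initialized
    // (see ExternalShuffleExecutorComponents below) before writers are requested.
    val components = dataIO.executor()
    println(components.getClass.getName)
  }
}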
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.sort.io.external;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleServiceAddressProvider;
import org.apache.spark.storage.BlockManager;

public class ExternalShuffleExecutorComponents implements ShuffleExecutorComponents {

  private final SparkConf sparkConf;
  private BlockManager blockManager;
  private IndexShuffleBlockResolver blockResolver;
  private ShuffleServiceAddressProvider shuffleServiceAddressProvider;

  public ExternalShuffleExecutorComponents(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public void initializeExecutor(String appId, String execId) {
    blockManager = SparkEnv.get().blockManager();
    blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
    shuffleServiceAddressProvider = SparkEnv.get().shuffleServiceAddressProvider();
  }

  @Override
  public ShuffleWriteSupport writes() {
    if (blockResolver == null) {
      throw new IllegalStateException(
          "Executor components must be initialized before getting writers.");
    }
    return null;
  }
}
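A brief usage sketch (not part of the diff) of the contract these components enforce: initializeExecutor must run before writes(), otherwise the IllegalStateException above is thrown, and in this revision writes() still returns null rather than a ShuffleWriteSupport, so the write path is not wired up yet. The appId/execId values below are placeholders, and the code assumes it runs where SparkEnv.get() is live.

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.io.external.ExternalShuffleExecutorComponents

// Hypothetical helper showing the required call order.
object ExecutorComponentsSketch {
  def initialize(conf: SparkConf): ExternalShuffleExecutorComponents = {
    val components = new ExternalShuffleExecutorComponents(conf)
    // Needs a live SparkEnv: initializeExecutor pulls the BlockManager and the
    // ShuffleServiceAddressProvider out of SparkEnv.get().
    components.initializeExecutor("app-20190101000000-0001", "1")
    // Legal only after initialization; calling writes() first throws IllegalStateException.
    // In this revision it still returns null, so no writers can actually be obtained yet.
    components.writes()
    components
  }
}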
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -546,6 +546,9 @@ class SparkContext(config: SparkConf) extends Logging {
None
}

// Start the ShuffleServiceAddressProvider
_env.shuffleServiceAddressProvider.start()

// Optionally scale number of executors dynamically based on workload. Exposed for testing.
val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf)
_executorAllocationManager =
19 changes: 18 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -41,7 +41,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory}
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}

@@ -67,6 +67,7 @@ class SparkEnv (
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val shuffleServiceAddressProvider: ShuffleServiceAddressProvider,
val conf: SparkConf) extends Logging {

private[spark] var isStopped = false
@@ -90,6 +91,7 @@
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
shuffleServiceAddressProvider.stop()
rpcEnv.shutdown()
rpcEnv.awaitTermination()

@@ -365,6 +367,20 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)

// ShuffleServiceAddressProvider initialization
val master = conf.get("spark.master")
val shuffleProvider = conf.get(SHUFFLE_SERVICE_PROVIDER_CLASS)
.map(clazz => Utils.loadExtensions(classOf[ShuffleServiceAddressProviderFactory],
Seq(clazz), conf)).getOrElse(Seq())
val serviceLoaders = shuffleProvider.filter(_.canCreate(master))
if (serviceLoaders.size > 1) {
throw new SparkException(
s"Multiple external cluster managers registered for the url $master: $serviceLoaders")
}
val shuffleServiceAddressProvider = serviceLoaders.headOption
.map(_.create(conf))
.getOrElse(DefaultShuffleServiceAddressProvider)

val envInstance = new SparkEnv(
executorId,
rpcEnv,
Expand All @@ -379,6 +395,7 @@ object SparkEnv extends Logging {
metricsSystem,
memoryManager,
outputCommitCoordinator,
shuffleServiceAddressProvider,
conf)

// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
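As a quick illustration (not part of the diff) of the selection logic added above: com.example.MyShuffleProviderFactory stands in for a user-supplied ShuffleServiceAddressProviderFactory on the classpath; nothing else here comes from the diff.

import org.apache.spark.SparkConf

object ProviderSelectionSketch {
  // Hypothetical configuration; the master URL and factory class are placeholders.
  val conf = new SparkConf()
    .set("spark.master", "k8s://https://example-apiserver:6443")
    .set("spark.shuffle.provider.plugin.class", "com.example.MyShuffleProviderFactory")
  // SparkEnv loads the configured factory with Utils.loadExtensions, keeps it only if
  // canCreate(master) returns true, and otherwise falls back to
  // DefaultShuffleServiceAddressProvider, whose address list is empty.
}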
@@ -120,6 +120,7 @@ private[spark] class Executor(
env.metricsSystem.registerSource(executorSource)
env.metricsSystem.registerSource(new JVMCPUSource())
env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource)
env.shuffleServiceAddressProvider.start()
}

// Whether to load classes in user jars before those in Spark jars
@@ -360,6 +360,9 @@ package object config {
private[spark] val SHUFFLE_SERVICE_ENABLED =
ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false)

private[spark] val K8S_SHUFFLE_SERVICE_ENABLED =
ConfigBuilder("spark.k8s.shuffle.service.enabled").booleanConf.createWithDefault(false)

private[spark] val SHUFFLE_SERVICE_DB_ENABLED =
ConfigBuilder("spark.shuffle.service.db.enabled")
.doc("Whether to use db in ExternalShuffleService. Note that this only affects " +
@@ -773,6 +776,12 @@ package object config {
.stringConf
.createWithDefault(classOf[DefaultShuffleDataIO].getName)

private[spark] val SHUFFLE_SERVICE_PROVIDER_CLASS =
ConfigBuilder("spark.shuffle.provider.plugin.class")
.doc("Experimental. Specify a class that can handle detecting shuffle service pods.")
.stringConf
.createOptional

private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
ConfigBuilder("spark.shuffle.file.buffer")
.doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

trait ShuffleServiceAddressProvider {
  def start(): Unit = {}
  def getShuffleServiceAddresses(): List[(String, Int)]
  def stop(): Unit = {}
}

private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider {
  override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)]
}
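To make the extension point concrete, here is a minimal sketch (not part of the diff) of a provider that serves a fixed list of host/port pairs; the package, class name, and addresses are made up for illustration.

package com.example.shuffle  // hypothetical package, not part of the diff

import org.apache.spark.shuffle.ShuffleServiceAddressProvider

// Hypothetical provider returning a static set of external shuffle service endpoints.
// start() and stop() keep the no-op defaults from the trait.
class StaticShuffleServiceAddressProvider(addresses: List[(String, Int)])
  extends ShuffleServiceAddressProvider {

  override def getShuffleServiceAddresses(): List[(String, Int)] = addresses
}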
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import org.apache.spark.SparkConf

trait ShuffleServiceAddressProviderFactory {
  def canCreate(masterUrl: String): Boolean
  def create(conf: SparkConf): ShuffleServiceAddressProvider
}
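A matching factory sketch, again hypothetical and building on the StaticShuffleServiceAddressProvider above: canCreate gates on the master URL, mirroring how SparkEnv filters candidate factories, and create builds the provider. Such a factory would be registered by setting spark.shuffle.provider.plugin.class to its class name.

package com.example.shuffle  // hypothetical package, not part of the diff

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.{ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory}

// Hypothetical factory: only active for k8s:// masters.
class StaticShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory {

  override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://")

  override def create(conf: SparkConf): ShuffleServiceAddressProvider =
    // The endpoints are hard-coded purely for illustration.
    new StaticShuffleServiceAddressProvider(List(("shuffle-svc-0", 7337), ("shuffle-svc-1", 7337)))
}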
@@ -180,6 +180,7 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase {
null,
null,
null,
null,
defaultConf
))

@@ -65,6 +65,7 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
null,
null,
null,
null,
defaultConf
))

@@ -325,6 +325,26 @@ private[spark] object Config extends Logging {
.booleanConf
.createWithDefault(true)

val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE =
ConfigBuilder("spark.kubernetes.shuffle.service.remote.pods.namespace")
.doc("Namespace of the pods that are running the shuffle service instances for remote" +
" pushing of shuffle data.")
.stringConf
.createOptional

val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT =
ConfigBuilder("spark.kubernetes.shuffle.service.remote.port")
.doc("Port of the external k8s shuffle service pods")
.intConf
.createWithDefault(7337)

val KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL =
ConfigBuilder("spark.kubernetes.shuffle.service.cleanup.interval")
.doc("Cleanup interval for the shuffle service to take down an app id")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("30s")


val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label."
val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation."
val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets."
@@ -349,4 +369,7 @@
val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit"

val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv."

val KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS =
"spark.kubernetes.shuffle.service.remote.label."
}
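Putting the new Kubernetes settings together, a hedged example of how a job might opt into the remote shuffle service; the master URL, namespace, and label values are placeholders rather than values taken from this diff.

import org.apache.spark.SparkConf

object RemoteShuffleServiceConfSketch {
  // Hypothetical driver-side configuration exercising the new settings.
  val conf = new SparkConf()
    .set("spark.master", "k8s://https://example-apiserver:6443")
    .set("spark.kubernetes.shuffle.service.remote.pods.namespace", "shuffle-services")
    .set("spark.kubernetes.shuffle.service.remote.port", "7337")
    .set("spark.kubernetes.shuffle.service.cleanup.interval", "30s")
    // Prefix-style setting: every key under spark.kubernetes.shuffle.service.remote.label.
    // must match a label on the shuffle service pods.
    .set("spark.kubernetes.shuffle.service.remote.label.app", "spark-shuffle-service")
}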
@@ -21,12 +21,13 @@ import java.io.File
import com.google.common.base.Charsets
import com.google.common.io.Files
import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient}
import io.fabric8.kubernetes.client.Config.autoConfigure
import io.fabric8.kubernetes.client.Config._
import io.fabric8.kubernetes.client.utils.HttpClientUtils
import okhttp3.Dispatcher

import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.util.ThreadUtils
@@ -37,6 +38,36 @@ import org.apache.spark.util.ThreadUtils
* options for different components.
*/
private[spark] object SparkKubernetesClientFactory extends Logging {
def getDriverKubernetesClient(conf: SparkConf, masterURL: String): KubernetesClient = {
val wasSparkSubmittedInClusterMode = conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK)
val (authConfPrefix,
apiServerUri,
defaultServiceAccountToken,
defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) {
require(conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined,
"If the application is deployed using spark-submit in cluster mode, the driver pod name " +
"must be provided.")
(KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX,
KUBERNETES_MASTER_INTERNAL_URL,
Some(new File(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)),
Some(new File(KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH)))
} else {
(KUBERNETES_AUTH_CLIENT_MODE_PREFIX,
KubernetesUtils.parseMasterUrl(masterURL),
None,
None)
}

val kubernetesClient = createKubernetesClient(
apiServerUri,
Some(conf.get(KUBERNETES_NAMESPACE)),
authConfPrefix,
SparkKubernetesClientFactory.ClientType.Driver,
conf,
defaultServiceAccountToken,
defaultServiceAccountCaCrt)
kubernetesClient
}

def createKubernetesClient(
master: String,
@@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging
/**
* An immutable view of the current executor pods that are running in the cluster.
*/
private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) {
private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, SparkPodState]) {

import ExecutorPodsSnapshot._

@@ -42,15 +42,15 @@ object ExecutorPodsSnapshot extends Logging {
ExecutorPodsSnapshot(toStatesByExecutorId(executorPods))
}

def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState])
def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, SparkPodState])

private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = {
private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, SparkPodState] = {
executorPods.map { pod =>
(pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod))
}.toMap
}

private def toState(pod: Pod): ExecutorPodState = {
private def toState(pod: Pod): SparkPodState = {
if (isDeleted(pod)) {
PodDeleted(pod)
} else {