diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilder.scala b/spark/src/main/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilder.scala index 0f92e15f870..f9b752ba6e0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilder.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging import io.unitycatalog.client.auth.TokenProvider import org.apache.spark.internal.MDC import org.apache.spark.sql.SparkSession +import org.apache.spark.util.Utils /** * Builder for Unity Catalog Commit Coordinator Clients. @@ -53,16 +54,13 @@ object UCCommitCoordinatorBuilder final private[delta] val UNITY_CATALOG_CONNECTOR_CLASS: String = "io.unitycatalog.spark.UCSingleCatalog" - /** Suffix for the URI configuration of a catalog. */ - final private val URI_SUFFIX = "uri" - /** Cache for UCCommitCoordinatorClient instances. */ private val commitCoordinatorClientCache = new ConcurrentHashMap[String, UCCommitCoordinatorClient]() - // Helper cache for (uri, authConfig) to metastoreId to avoid redundant calls to getMetastoreId - private val uriAuthConfigToMetastoreIdCache = - new ConcurrentHashMap[(String, Map[String, String]), String]() + // Cache for ucConfig to metastoreId to avoid redundant calls to getMetastoreId. + private val ucConfigToMetastoreIdCache = + new ConcurrentHashMap[Map[String, String], String]() // Use a var instead of val for ease of testing by injecting different UCClientFactory. private[delta] var ucClientFactory: UCClientFactory = UCTokenBasedRestClientFactory @@ -84,7 +82,7 @@ object UCCommitCoordinatorBuilder spark: SparkSession, catalogName: String): CommitCoordinatorClient = { val client = getCatalogConfigs(spark).find(_._1 == catalogName) match { - case Some((_, uri, authConfig)) => ucClientFactory.createUCClient(uri, authConfig) + case Some((_, ucConfig)) => ucClientFactory.createUCClient(ucConfig) case None => throw new IllegalArgumentException( s"Catalog $catalogName not found in the provided SparkSession configurations.") @@ -102,41 +100,45 @@ object UCCommitCoordinatorBuilder * appropriate exception. */ private def getMatchingUCClient(spark: SparkSession, metastoreId: String): UCClient = { - val matchingClients: List[(String, Map[String, String])] = getCatalogConfigs(spark) - .map { case (name, uri, authConfig) => (uri, authConfig) } - .distinct // Remove duplicates since multiple catalogs can have the same uri and config - .filter { case (uri, authConfig) => getMetastoreId(uri, authConfig).contains(metastoreId) } + val matchingConfigs: List[Map[String, String]] = getCatalogConfigs(spark) + .map(_._2) + .distinct + .filter { ucConfig => getMetastoreId(ucConfig).contains(metastoreId) } - matchingClients match { + matchingConfigs match { case Nil => throw noMatchingCatalogException(metastoreId) - case (uri, authConfig) :: Nil => ucClientFactory.createUCClient(uri, authConfig) - case multiple => throw multipleMatchingCatalogs(metastoreId, multiple.map(_._1)) + case ucConfig :: Nil => ucClientFactory.createUCClient(ucConfig) + case multiple => + throw multipleMatchingCatalogs(metastoreId, multiple.map(_.getOrElse("uri", ""))) } } /** - * Retrieves the metastore ID for a given URI and auth configuration map. + * Retrieves the metastore ID for a given UC configuration map. * - * This method creates a UCClient using the provided URI and auth configuration map, then - * retrieves its metastore ID. The result is cached to avoid unnecessary getMetastoreId requests - * in future calls. If there's an error, it returns None and logs a warning. + * This method creates a UCClient using the provided config, then retrieves its metastore ID. + * The result is cached to avoid unnecessary getMetastoreId requests in future calls. If there's + * an error, it returns None and logs a warning. */ - private def getMetastoreId(uri: String, authConfig: Map[String, String]): Option[String] = { + private def getMetastoreId(ucConfig: Map[String, String]): Option[String] = { try { - val metastoreId = uriAuthConfigToMetastoreIdCache.computeIfAbsent( - (uri, authConfig), + val metastoreId = ucConfigToMetastoreIdCache.computeIfAbsent( + ucConfig, _ => { - val ucClient = ucClientFactory.createUCClient(uri, authConfig) + val ucClient = ucClientFactory.createUCClient(ucConfig) try { ucClient.getMetastoreId } finally { - safeClose(ucClient, uri) + safeClose(ucClient, ucConfig.getOrElse("uri", "")) } }) Some(metastoreId) } catch { case NonFatal(e) => - logWarning(log"Failed to getMetastoreSummary with ${MDC(DeltaLogKeys.URI, uri)}", e) + logWarning( + log"Failed to getMetastoreSummary with " + + log"${MDC(DeltaLogKeys.URI, ucConfig.getOrElse("uri", ""))}", + e) None } } @@ -164,98 +166,50 @@ object UCCommitCoordinatorBuilder /** * Retrieves the catalog configurations from the SparkSession. * - * This method supports both the new auth.* format and the legacy token format for backward - * compatibility: - * - * New format: - * spark.sql.catalog.catalog1.uri = "https://dbc-123abc.databricks.com" - * spark.sql.catalog.catalog1.auth.type = "static" - * spark.sql.catalog.catalog1.auth.token = "dapi1234567890" - * - * Legacy format (for backward compatibility): - * spark.sql.catalog.catalog1.uri = "https://dbc-123abc.databricks.com" - * spark.sql.catalog.catalog1.token = "dapi1234567890" + * For each Unity Catalog catalog, collects all sub-keys under + * `spark.sql.catalog..*` (stripping the prefix) into a flat config map. + * This includes `uri`, `auth.*`, `token` (legacy), `deltaRestApi.enabled`, and any + * other catalog-specific settings. * - * When the legacy format is detected (token without auth. prefix), it is automatically - * converted to the new format (type=static, token=value) for TokenProvider. - * - * @return - * A list of tuples containing (catalogName, uri, authConfigMap) for each properly configured - * catalog. The authConfigMap contains authentication configurations ready to be passed to - * TokenProvider.create(). + * @return A list of (catalogName, ucConfig) for each properly configured UC catalog. */ private[delta] def getCatalogConfigs( - spark: SparkSession): List[(String, String, Map[String, String])] = { - val catalogConfigs = spark.conf.getAll.filterKeys(_.startsWith(SPARK_SQL_CATALOG_PREFIX)) + spark: SparkSession): List[(String, Map[String, String])] = { + val allConfigs = spark.conf.getAll.filterKeys(_.startsWith(SPARK_SQL_CATALOG_PREFIX)) // First, identify all Unity Catalog catalogs - val ucCatalogNames = catalogConfigs + val ucCatalogNames = allConfigs .keys .map(_.split("\\.")) .filter(_.length == 4) .map(_(3)) .filter { catalogName: String => - val connector = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName") + val connector = allConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName") connector.contains(UNITY_CATALOG_CONNECTOR_CLASS) } - // For each UC catalog, extract its URI and auth configurations + // For each UC catalog, collect all sub-keys into a flat config map ucCatalogNames .flatMap { catalogName: String => - val catalogPrefix = s"$SPARK_SQL_CATALOG_PREFIX$catalogName." - val authPrefix = s"${catalogPrefix}auth." - val uriOpt = catalogConfigs.get(s"$catalogPrefix$URI_SUFFIX") - - uriOpt match { - case Some(uri) => - try { - new URI(uri) // Validate the URI - - // Extract all auth.* configuration keys for this catalog - // and strip the "spark.sql.catalog..auth." prefix - var authConfigMap = catalogConfigs - .filterKeys(_.startsWith(authPrefix)) - .map { case (fullKey, value) => - // Remove the auth prefix to get just the auth config key - // e.g., "spark.sql.catalog.catalog1.auth.type" -> "type" - // e.g., "spark.sql.catalog.catalog1.auth.oauth.uri" -> "oauth.uri" - val authKey = fullKey.stripPrefix(authPrefix) - (authKey, value) - } - .toMap - - // Support legacy format: if no auth.* configs but token exists, - // convert to new format (type=static, token=value) - if (authConfigMap.isEmpty) { - val legacyTokenOpt = catalogConfigs.get(s"${catalogPrefix}token") - legacyTokenOpt match { - case Some(token) => - authConfigMap = Map("type" -> "static", "token" -> token) - case None => - // No auth configs found - } - } - - if (authConfigMap.isEmpty) { - logWarning( - log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it " + - "does not have any authentication configurations in Spark Session.") - None - } else { - Some((catalogName, uri, authConfigMap)) - } - } catch { - case _: URISyntaxException => - logWarning( - log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it " + - log"does not have a valid URI ${MDC(DeltaLogKeys.URI, uri)}.") - None - } + val prefix = s"$SPARK_SQL_CATALOG_PREFIX$catalogName." + val ucConfig = allConfigs + .filterKeys(_.startsWith(prefix)) + .map { case (k, v) => (k.stripPrefix(prefix), v) } + .toMap + + ucConfig.get("uri") match { case None => logWarning( log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it does " + "not have uri configured in Spark Session.") None + case Some(uri) if !isValidUri(uri) => + logWarning( + log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it has" + + log" an invalid uri ${MDC(DeltaLogKeys.URI, uri)}.") + None + case _ => + Some((catalogName, ucConfig)) } } .toList @@ -267,10 +221,14 @@ object UCCommitCoordinatorBuilder */ private[delta] def getCatalogConfigMap(spark: SparkSession): Map[String, UCCatalogConfig] = { getCatalogConfigs(spark).map { - case (name, uri, authConfig) => name -> UCCatalogConfig(name, uri, authConfig) + case (name, ucConfig) => name -> UCCatalogConfig(name, ucConfig) }.toMap } + private def isValidUri(uri: String): Boolean = { + try { new URI(uri); true } catch { case _: URISyntaxException => false } + } + private def safeClose(ucClient: UCClient, uri: String): Unit = { try { ucClient.close() @@ -282,33 +240,120 @@ object UCCommitCoordinatorBuilder def clearCache(): Unit = { commitCoordinatorClientCache.clear() - uriAuthConfigToMetastoreIdCache.clear() + ucConfigToMetastoreIdCache.clear() } } +/** Factory trait for creating [[UCClient]] instances from a unified configuration map. */ trait UCClientFactory { - def createUCClient(uri: String, authConfig: Map[String, String]): UCClient + def createUCClient(ucConfig: Map[String, String]): UCClient } +/** + * Default [[UCClientFactory]] that uses reflection to instantiate the [[UCClient]] + * implementation so this module has no compile-time dependency on specific implementations + * (e.g. UCDeltaTokenBasedRestClient). + * + * The `ucConfig` map is typically built from Spark catalog configuration. For example, + * given these Spark configs: + * + * {{{ + * spark.sql.catalog.my_catalog = io.unitycatalog.spark.UCSingleCatalog + * spark.sql.catalog.my_catalog.uri = https://my-uc-server.com + * spark.sql.catalog.my_catalog.auth.type = static + * spark.sql.catalog.my_catalog.auth.token = dapi1234567890 + * }}} + * + * The resulting `ucConfig` map (with the `spark.sql.catalog.my_catalog.` prefix stripped) + * would be: `Map("uri" -> "...", "auth.type" -> "static", "auth.token" -> "...")`. + * + * Legacy format (token without auth. prefix) is also supported for backward compatibility: + * + * {{{ + * spark.sql.catalog.my_catalog.uri = https://my-uc-server.com + * spark.sql.catalog.my_catalog.token = dapi1234567890 + * }}} + * + * Recognised ucConfig keys: + * - `uri` (required) -- the UC server endpoint. + * - `auth.*` / `token` (legacy) -- authentication parameters for [[TokenProvider]]. + * - `deltaRestApi.enabled` -- if `"true"`, uses [[UCDeltaTokenBasedRestClient]]; + * otherwise uses [[UCTokenBasedRestClient]]. + * - `appVersions.*` -- caller-supplied version entries merged with defaults; e.g. + * `appVersions.Kernel -> "0.7.0"` adds a `"Kernel"` entry to the version map. + */ object UCTokenBasedRestClientFactory extends UCClientFactory { - override def createUCClient(uri: String, authConfig: Map[String, String]): UCClient = { - createUCClientWithVersions(uri, authConfig, defaultAppVersions) + + final val URI_KEY = "uri" + final val AUTH_PREFIX = "auth." + final val DELTA_REST_API_ENABLED_KEY = "deltaRestApi.enabled" + final val APP_VERSIONS_PREFIX = "appVersions." + + private val DEFAULT_UC_CLIENT_CLASS: String = classOf[UCTokenBasedRestClient].getName + + private val DELTA_UC_CLIENT_CLASS: String = + "io.delta.storage.commit.uccommitcoordinator.UCDeltaTokenBasedRestClient" + + override def createUCClient(ucConfig: Map[String, String]): UCClient = { + val uri = ucConfig.getOrElse(URI_KEY, + throw new IllegalArgumentException(s"UC config must contain '$URI_KEY'")) + + val authConfig = extractAuthConfig(ucConfig) + val tokenProvider = TokenProvider.create(authConfig.asJava) + + val className = + if (ucConfig.get(DELTA_REST_API_ENABLED_KEY).exists(_.equalsIgnoreCase("true"))) { + DELTA_UC_CLIENT_CLASS + } else { + DEFAULT_UC_CLIENT_CLASS + } + + val cls = Utils.classForName(className) + require(classOf[UCClient].isAssignableFrom(cls), + s"$className does not implement ${classOf[UCClient].getName}") + val appVersions = extractAppVersions(ucConfig) + val ctor = cls.getConstructor( + classOf[String], classOf[TokenProvider], classOf[java.util.Map[_, _]]) + ctor.newInstance(uri, tokenProvider, appVersions.asJava).asInstanceOf[UCClient] + } + + /** Java-friendly overload that accepts a java.util.Map. */ + def createUCClient(ucConfig: java.util.Map[String, String]): UCClient = { + createUCClient(ucConfig.asScala.toMap) } /** - * Creates a UC client with the given application versions for telemetry. - * The provided `appVersions` map is used as-is; callers are responsible for - * including all desired version entries. + * Extracts authentication configuration from ucConfig. + * Prefers `auth.*` keys; falls back to legacy `token` key. */ - def createUCClientWithVersions( - uri: String, - authConfig: Map[String, String], - appVersions: Map[String, String]): UCClient = { - // Create TokenProvider from the authentication configuration map - // We pass the configuration through without interpreting any specific keys, - // as those are managed by the Unity Catalog client library - val tokenProvider = TokenProvider.create(authConfig.asJava) - new UCTokenBasedRestClient(uri, tokenProvider, appVersions.asJava) + private[coordinatedcommits] def extractAuthConfig( + ucConfig: Map[String, String]): Map[String, String] = { + val authConfig = ucConfig + .filterKeys(_.startsWith(AUTH_PREFIX)) + .map { case (k, v) => (k.stripPrefix(AUTH_PREFIX), v) } + .toMap + + if (authConfig.nonEmpty) { + authConfig + } else { + ucConfig.get("token") match { + case Some(token) => Map("type" -> "static", "token" -> token) + case None => Map.empty + } + } + } + + /** + * Merges default app versions with any `appVersions.*` entries from ucConfig. + * Caller-supplied entries override defaults with the same key. + */ + private[coordinatedcommits] def extractAppVersions( + ucConfig: Map[String, String]): Map[String, String] = { + val extra = ucConfig + .filterKeys(_.startsWith(APP_VERSIONS_PREFIX)) + .map { case (k, v) => (k.stripPrefix(APP_VERSIONS_PREFIX), v) } + .toMap + defaultAppVersions ++ extra } private[coordinatedcommits] def defaultAppVersions: Map[String, String] = { @@ -319,28 +364,23 @@ object UCTokenBasedRestClientFactory extends UCClientFactory { "Java" -> System.getProperty("java.version") ) } - - /** Returns the default app versions as a mutable Java map for easy extension. */ - def defaultAppVersionsAsJava: java.util.Map[String, String] = { - new java.util.HashMap(defaultAppVersions.asJava) - } - - /** Java-friendly overload that accepts a java.util.Map */ - def createUCClient(uri: String, authConfig: java.util.Map[String, String]): UCClient = { - createUCClient(uri, authConfig.asScala.toMap) - } - - /** Java-friendly overload that accepts application versions for telemetry. */ - def createUCClientWithVersions( - uri: String, - authConfig: java.util.Map[String, String], - appVersions: java.util.Map[String, String]): UCClient = { - createUCClientWithVersions(uri, authConfig.asScala.toMap, appVersions.asScala.toMap) - } } /** * Holder for Unity Catalog configuration extracted from Spark configs. + * The `ucConfig` map contains all sub-keys under `spark.sql.catalog..*` + * with the prefix stripped. * Used by [[UCCommitCoordinatorBuilder.getCatalogConfigMap]]. */ -case class UCCatalogConfig(catalogName: String, uri: String, authConfig: Map[String, String]) +case class UCCatalogConfig(catalogName: String, ucConfig: Map[String, String]) { + + def uri: String = ucConfig.getOrElse("uri", + throw new NoSuchElementException(s"No URI in config for catalog $catalogName")) + + /** + * Returns the authentication config suitable for [[TokenProvider.create]]. + * Prefers `auth.*` keys; falls back to legacy `token` key. + */ + def authConfig: Map[String, String] = + UCTokenBasedRestClientFactory.extractAuthConfig(ucConfig) +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilderSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilderSuite.scala index e7926e73e8b..a073a70a625 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilderSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorBuilderSuite.scala @@ -44,7 +44,17 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess configMap: Map[String, String] = Map.empty, metastoreId: Option[String] = None, path: Option[String] = Some("io.unitycatalog.spark.UCSingleCatalog") - ) + ) { + /** + * The ucConfig as it would appear after getCatalogConfigs + * parsing: all sub-keys under spark.sql.catalog..* + * with the prefix stripped. Includes `uri` when present. + */ + def expectedUcConfig: Map[String, String] = { + val base = configMap + uri.map(u => base + ("uri" -> u)).getOrElse(base) + } + } def setupCatalogs(configs: CatalogTestConfig*)(testCode: => Unit): Unit = { val allConfigs = configs.flatMap { config => @@ -53,7 +63,6 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess config.uri.map(uri => s"spark.sql.catalog.${config.name}.uri" -> uri) ).flatten - // Add all additional configurations from configMap (without any prefix) val additionalConfigs = config.configMap.map { case (key, value) => s"spark.sql.catalog.${config.name}.$key" -> value } @@ -64,14 +73,13 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess withSQLConf(allConfigs: _*) { configs.foreach { config => (config.uri, config.configMap.isEmpty, config.metastoreId) match { - case (Some(uri), false, Some(id)) => - registerMetastoreId(uri, config.configMap, id) - case (Some(uri), false, None) => + case (Some(_), false, Some(id)) => + registerMetastoreId(config.expectedUcConfig, id) + case (Some(_), false, None) => registerMetastoreIdException( - uri, - config.configMap, + config.expectedUcConfig, new RuntimeException("Invalid metastore ID")) - case _ => // Do nothing for incomplete configs + case _ => } } testCode @@ -97,14 +105,14 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val result = getCommitCoordinatorClient(expectedMetastoreId) assert(result.isInstanceOf[UCCommitCoordinatorClient]) - verify(mockFactory, times(2)).createUCClient(catalog1.uri.get, catalog1.configMap) - verify(mockFactory).createUCClient(catalog2.uri.get, catalog2.configMap) - verify(mockFactory.createUCClient(catalog1.uri.get, catalog1.configMap)) + verify(mockFactory, times(2)).createUCClient(catalog1.expectedUcConfig) + verify(mockFactory).createUCClient(catalog2.expectedUcConfig) + verify(mockFactory.createUCClient(catalog1.expectedUcConfig)) .getMetastoreId - verify(mockFactory.createUCClient(catalog2.uri.get, catalog2.configMap)) + verify(mockFactory.createUCClient(catalog2.expectedUcConfig)) .getMetastoreId - verify(mockFactory.createUCClient(catalog2.uri.get, catalog2.configMap)).close() - verify(mockFactory.createUCClient(catalog1.uri.get, catalog1.configMap)).close() + verify(mockFactory.createUCClient(catalog2.expectedUcConfig)).close() + verify(mockFactory.createUCClient(catalog1.expectedUcConfig)).close() } } @@ -116,20 +124,6 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess assert(defaults("Java") === System.getProperty("java.version")) } - test("createUCClientWithVersions passes custom app versions to UCClient") { - val customVersions = Map( - "Delta" -> io.delta.VERSION, - "Kernel" -> "4.0.0", - "Delta V2 connector" -> "true" - ) - val defaults = UCTokenBasedRestClientFactory.defaultAppVersions - val merged = defaults ++ customVersions - assert(merged("Kernel") === "4.0.0") - assert(merged("Delta V2 connector") === "true") - assert(merged("Delta") === io.delta.VERSION) - assert(merged("Spark") === org.apache.spark.SPARK_VERSION) - } - test("build with missing metastore ID") { val exception = intercept[IllegalArgumentException] { CommitCoordinatorProvider.getCommitCoordinatorClient( @@ -154,9 +148,9 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess getCommitCoordinatorClient(metastoreId) } assert(exception.getMessage.contains("No matching catalog found")) - verify(mockFactory).createUCClient(catalog.uri.get, catalog.configMap) - verify(mockFactory.createUCClient(catalog.uri.get, catalog.configMap)).getMetastoreId - verify(mockFactory.createUCClient(catalog.uri.get, catalog.configMap)).close() + verify(mockFactory).createUCClient(catalog.expectedUcConfig) + verify(mockFactory.createUCClient(catalog.expectedUcConfig)).getMetastoreId + verify(mockFactory.createUCClient(catalog.expectedUcConfig)).close() } } @@ -180,14 +174,14 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess getCommitCoordinatorClient(metastoreId) } assert(exception.getMessage.contains("Found multiple catalogs")) - verify(mockFactory).createUCClient(catalog1.uri.get, catalog1.configMap) - verify(mockFactory).createUCClient(catalog2.uri.get, catalog2.configMap) - verify(mockFactory.createUCClient(catalog1.uri.get, catalog1.configMap)) + verify(mockFactory).createUCClient(catalog1.expectedUcConfig) + verify(mockFactory).createUCClient(catalog2.expectedUcConfig) + verify(mockFactory.createUCClient(catalog1.expectedUcConfig)) .getMetastoreId - verify(mockFactory.createUCClient(catalog2.uri.get, catalog2.configMap)) + verify(mockFactory.createUCClient(catalog2.expectedUcConfig)) .getMetastoreId - verify(mockFactory.createUCClient(catalog1.uri.get, catalog1.configMap)).close() - verify(mockFactory.createUCClient(catalog2.uri.get, catalog2.configMap)).close() + verify(mockFactory.createUCClient(catalog1.expectedUcConfig)).close() + verify(mockFactory.createUCClient(catalog2.expectedUcConfig)).close() } } @@ -220,10 +214,9 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess assert(result.isInstanceOf[UCCommitCoordinatorClient]) verify(mockFactory, times(2)).createUCClient( - validCatalog.uri.get, - validCatalog.configMap + validCatalog.expectedUcConfig ) - verify(mockFactory.createUCClient(validCatalog.uri.get, validCatalog.configMap), + verify(mockFactory.createUCClient(validCatalog.expectedUcConfig), times(1)).close() } } @@ -248,6 +241,7 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val metastoreId = "shared-metastore-id" val sharedUri = "https://shared-test-uri.com" val sharedConfigMap = Map("type" -> "static", "token" -> "shared-test-token") + val sharedUcConfig = sharedConfigMap + ("uri" -> sharedUri) val catalog1 = CatalogTestConfig( name = "catalog1", uri = Some(sharedUri), @@ -271,9 +265,9 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val result = getCommitCoordinatorClient(metastoreId) assert(result.isInstanceOf[UCCommitCoordinatorClient]) - verify(mockFactory, times(2)).createUCClient(sharedUri, sharedConfigMap) - verify(mockFactory.createUCClient(sharedUri, sharedConfigMap)).getMetastoreId - verify(mockFactory.createUCClient(sharedUri, sharedConfigMap)).close() + verify(mockFactory, times(2)).createUCClient(sharedUcConfig) + verify(mockFactory.createUCClient(sharedUcConfig)).getMetastoreId + verify(mockFactory.createUCClient(sharedUcConfig)).close() } } @@ -293,30 +287,28 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess getCommitCoordinatorClient(metastoreId) } assert(e.getMessage.contains("No matching catalog found")) - verify(mockFactory, never()).createUCClient(catalog.uri.get, catalog.configMap) + verify(mockFactory, never()).createUCClient(catalog.expectedUcConfig) } } private def registerMetastoreId( - uri: String, - configMap: Map[String, String], + ucConfig: Map[String, String], metastoreId: String): Unit = { val mockClient = org.mockito.Mockito.mock(classOf[UCClient]) when(mockClient.getMetastoreId).thenReturn(metastoreId) - when(mockFactory.createUCClient(meq(uri), meq(configMap))).thenReturn(mockClient) + when(mockFactory.createUCClient(meq(ucConfig))).thenReturn(mockClient) } private def registerMetastoreIdException( - uri: String, - configMap: Map[String, String], + ucConfig: Map[String, String], exception: Throwable): Unit = { val mockClient = org.mockito.Mockito.mock(classOf[UCClient]) when(mockClient.getMetastoreId).thenThrow(exception) - when(mockFactory.createUCClient(meq(uri), meq(configMap))).thenReturn(mockClient) + when(mockFactory.createUCClient(meq(ucConfig))).thenReturn(mockClient) } - test("getCatalogConfigs with legacy token format") { - val catalogName = "legacy_catalog" + test("getCatalogConfigs returns all sub-keys") { + val catalogName = "test_catalog" val uri = "https://test-uri.com" val token = "test-token" @@ -328,15 +320,10 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) assert(configs.length == 1) - val (name, catalogUri, authConfigMap) = configs.head + val (name, ucConfig) = configs.head assert(name == catalogName) - assert(catalogUri == uri) - - // Legacy token should be converted to new format - assert(authConfigMap.contains("type")) - assert(authConfigMap("type") == "static") - assert(authConfigMap.contains("token")) - assert(authConfigMap("token") == token) + assert(ucConfig("uri") == uri) + assert(ucConfig("token") == token) } } @@ -354,11 +341,11 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) assert(configs.length == 1) - val (name, catalogUri, authConfigMap) = configs.head + val (name, ucConfig) = configs.head assert(name == catalogName) - assert(catalogUri == uri) - assert(authConfigMap("type") == "static") - assert(authConfigMap("token") == token) + assert(ucConfig("uri") == uri) + assert(ucConfig("auth.type") == "static") + assert(ucConfig("auth.token") == token) } } @@ -377,30 +364,42 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) assert(configs.length == 1) - val (name, catalogUri, authConfigMap) = configs.head + val (name, ucConfig) = configs.head assert(name == catalogName) - assert(catalogUri == uri) - assert(authConfigMap("type") == "oauth") - assert(authConfigMap("oauth.uri") == "https://oauth.example.com") - assert(authConfigMap("oauth.client_id") == "client123") - assert(authConfigMap("oauth.client_secret") == "secret456") + assert(ucConfig("uri") == uri) + assert(ucConfig("auth.type") == "oauth") + assert(ucConfig("auth.oauth.uri") == "https://oauth.example.com") + assert(ucConfig("auth.oauth.client_id") == "client123") + assert(ucConfig("auth.oauth.client_secret") == "secret456") } } - test("getCatalogConfigs skips catalog with no auth configurations") { - val catalogName = "no_auth_catalog" - val uri = "https://test-uri.com" + test("getCatalogConfigs skips catalog without uri") { + val catalogName = "no_uri_catalog" + + withSQLConf( + s"spark.sql.catalog.$catalogName" -> "io.unitycatalog.spark.UCSingleCatalog", + s"spark.sql.catalog.$catalogName.token" -> "some-token" + ) { + val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) + assert(configs.isEmpty, "Catalog without uri should be skipped") + } + } + + test("getCatalogConfigs skips catalog with invalid uri") { + val catalogName = "bad_uri_catalog" withSQLConf( s"spark.sql.catalog.$catalogName" -> "io.unitycatalog.spark.UCSingleCatalog", - s"spark.sql.catalog.$catalogName.uri" -> uri + s"spark.sql.catalog.$catalogName.uri" -> "://missing scheme", + s"spark.sql.catalog.$catalogName.token" -> "some-token" ) { val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) - assert(configs.isEmpty, "Catalog without auth config should be skipped") + assert(configs.isEmpty, "Catalog with invalid uri should be skipped") } } - test("getCatalogConfigs prefers new auth.* format over legacy token") { + test("getCatalogConfigs prefers auth.* keys (both present)") { val catalogName = "mixed_catalog" val uri = "https://test-uri.com" val legacyToken = "legacy-token" @@ -416,13 +415,15 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) assert(configs.length == 1) - val (name, catalogUri, authConfigMap) = configs.head - assert(name == catalogName) - assert(catalogUri == uri) - // New format should take precedence - assert(authConfigMap("type") == "static") - assert(authConfigMap("token") == newToken) - assert(!authConfigMap.contains(legacyToken)) + val (_, ucConfig) = configs.head + assert(ucConfig("uri") == uri) + assert(ucConfig("auth.type") == "static") + assert(ucConfig("auth.token") == newToken) + assert(ucConfig("token") == legacyToken) + + val catalogConfig = UCCatalogConfig(catalogName, ucConfig) + assert(catalogConfig.authConfig("type") == "static") + assert(catalogConfig.authConfig("token") == newToken) } } @@ -435,25 +436,46 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess "spark.sql.catalog.catalog2.uri" -> "https://uri2.com", "spark.sql.catalog.catalog2.auth.type" -> "static", "spark.sql.catalog.catalog2.auth.token" -> "token2", - "spark.sql.catalog.catalog3" -> "io.unitycatalog.spark.UCSingleCatalog", - "spark.sql.catalog.catalog3.uri" -> "https://uri3.com" + "spark.sql.catalog.catalog3" -> "io.unitycatalog.spark.UCSingleCatalog" ) { val configs = UCCommitCoordinatorBuilder.getCatalogConfigs(spark) - // Only catalog1 and catalog2 should be included (catalog3 has no auth) assert(configs.length == 2) val catalog1 = configs.find(_._1 == "catalog1") assert(catalog1.isDefined) - assert(catalog1.get._3("type") == "static") - assert(catalog1.get._3("token") == "token1") + assert(catalog1.get._2("uri") == "https://uri1.com") + assert(catalog1.get._2("token") == "token1") val catalog2 = configs.find(_._1 == "catalog2") assert(catalog2.isDefined) - assert(catalog2.get._3("type") == "static") - assert(catalog2.get._3("token") == "token2") + assert(catalog2.get._2("uri") == "https://uri2.com") + assert(catalog2.get._2("auth.type") == "static") + assert(catalog2.get._2("auth.token") == "token2") } } + test("extractAuthConfig prefers auth.* over legacy token") { + val ucConfig = Map( + "uri" -> "https://test.com", + "token" -> "legacy-token", + "auth.type" -> "static", + "auth.token" -> "new-token" + ) + val auth = UCTokenBasedRestClientFactory.extractAuthConfig(ucConfig) + assert(auth("type") == "static") + assert(auth("token") == "new-token") + } + + test("extractAuthConfig falls back to legacy token") { + val ucConfig = Map( + "uri" -> "https://test.com", + "token" -> "legacy-token" + ) + val auth = UCTokenBasedRestClientFactory.extractAuthConfig(ucConfig) + assert(auth("type") == "static") + assert(auth("token") == "legacy-token") + } + test("buildForCatalog with legacy token format") { val catalogName = "test_catalog" val uri = "https://test-uri.com" @@ -467,9 +489,7 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess val result = UCCommitCoordinatorBuilder.buildForCatalog(spark, catalogName) assert(result.isInstanceOf[UCCommitCoordinatorClient]) - // Verify that createUCClient was called with the converted auth config verify(mockFactory).createUCClient( - meq(uri), any[Map[String, String]]() ) } @@ -490,12 +510,24 @@ class UCCommitCoordinatorBuilderSuite extends SparkFunSuite with SharedSparkSess assert(result.isInstanceOf[UCCommitCoordinatorClient]) verify(mockFactory).createUCClient( - meq(uri), any[Map[String, String]]() ) } } + test("extractAppVersions merges defaults with ucConfig entries") { + val ucConfig = Map( + "uri" -> "https://test.com", + "appVersions.Kernel" -> "0.7.0", + "appVersions.Delta V2 connector" -> "true" + ) + val versions = UCTokenBasedRestClientFactory.extractAppVersions(ucConfig) + assert(versions("Delta") === io.delta.VERSION) + assert(versions("Spark") === org.apache.spark.SPARK_VERSION) + assert(versions("Kernel") === "0.7.0") + assert(versions("Delta V2 connector") === "true") + } + test("buildForCatalog with non-existent catalog") { val exception = intercept[IllegalArgumentException] { UCCommitCoordinatorBuilder.buildForCatalog(spark, "non_existent_catalog") diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorClientSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorClientSuiteBase.scala index ab12be57c6f..1653fdac1e4 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorClientSuiteBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/coordinatedcommits/UCCommitCoordinatorClientSuiteBase.scala @@ -34,7 +34,7 @@ import io.delta.storage.commit.{ } import io.delta.storage.commit.uccommitcoordinator.{UCClient, UCCommitCoordinatorClient} import org.apache.hadoop.fs.Path -import org.mockito.ArgumentMatchers.{any, anyString} +import org.mockito.ArgumentMatchers.any import org.mockito.Mock import org.mockito.Mockito import org.mockito.Mockito.{mock, when} @@ -84,7 +84,7 @@ trait UCCommitCoordinatorClientSuiteBase extends CommitCoordinatorClientImplSuit CommitCoordinatorProvider.registerBuilder(UCCommitCoordinatorBuilder) ucCommitCoordinator = new InMemoryUCCommitCoordinator() ucClient = new InMemoryUCClient(metastoreId.toString, ucCommitCoordinator) - when(mockFactory.createUCClient(anyString(), any[Map[String, String]]())).thenReturn(ucClient) + when(mockFactory.createUCClient(any[Map[String, String]]())).thenReturn(ucClient) } override protected def createTableCommitCoordinatorClient( deltaLog: DeltaLog): TableCommitCoordinatorClient = { diff --git a/spark/v2/src/main/java/io/delta/spark/internal/v2/ddl/CreateTableBuilder.java b/spark/v2/src/main/java/io/delta/spark/internal/v2/ddl/CreateTableBuilder.java index b9cdb3dd34a..ef720c0b466 100644 --- a/spark/v2/src/main/java/io/delta/spark/internal/v2/ddl/CreateTableBuilder.java +++ b/spark/v2/src/main/java/io/delta/spark/internal/v2/ddl/CreateTableBuilder.java @@ -123,8 +123,7 @@ private static CreateTableTransactionBuilder buildUCTransactionBuilder( String tableName, io.delta.kernel.types.StructType kernelSchema) { UCClient ucClient = - UCTokenBasedRestClientFactory$.MODULE$.createUCClient( - ucTableInfo.getUcUri(), ucTableInfo.getAuthConfig()); + UCTokenBasedRestClientFactory$.MODULE$.createUCClient(ucTableInfo.toUcConfig()); return new UCCatalogManagedClient(ucClient) .buildCreateTableTransaction( ucTableInfo.getTableId(), diff --git a/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/SnapshotManagerFactory.java b/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/SnapshotManagerFactory.java index 7cb61d1eafb..0cc4860d3e3 100644 --- a/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/SnapshotManagerFactory.java +++ b/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/SnapshotManagerFactory.java @@ -22,6 +22,7 @@ import io.delta.spark.internal.v2.snapshot.unitycatalog.UCTableInfo; import io.delta.spark.internal.v2.snapshot.unitycatalog.UCUtils; import io.delta.storage.commit.uccommitcoordinator.UCClient; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import org.apache.spark.annotation.Experimental; @@ -71,14 +72,10 @@ public static DeltaSnapshotManager create( private static UCManagedTableSnapshotManager createUCManagedSnapshotManager( UCTableInfo tableInfo, Engine kernelEngine) { - // Start from defaults (Delta, Spark, Scala, Java) and add connector-specific entries - Map appVersions = - UCTokenBasedRestClientFactory$.MODULE$.defaultAppVersionsAsJava(); - appVersions.put("Kernel", Meta.KERNEL_VERSION); - appVersions.put("Delta V2 connector", "true"); - UCClient ucClient = - UCTokenBasedRestClientFactory$.MODULE$.createUCClientWithVersions( - tableInfo.getUcUri(), tableInfo.getAuthConfig(), appVersions); + Map ucConfig = new HashMap<>(tableInfo.toUcConfig()); + ucConfig.put("appVersions.Kernel", Meta.KERNEL_VERSION); + ucConfig.put("appVersions.Delta V2 connector", "true"); + UCClient ucClient = UCTokenBasedRestClientFactory$.MODULE$.createUCClient(ucConfig); UCCatalogManagedClient ucCatalogClient = new UCCatalogManagedClient(ucClient); return new UCManagedTableSnapshotManager(ucCatalogClient, tableInfo, kernelEngine); } diff --git a/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/unitycatalog/UCTableInfo.java b/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/unitycatalog/UCTableInfo.java index 7a9195b2c93..5563097f828 100644 --- a/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/unitycatalog/UCTableInfo.java +++ b/spark/v2/src/main/java/io/delta/spark/internal/v2/snapshot/unitycatalog/UCTableInfo.java @@ -18,6 +18,7 @@ import static java.util.Objects.requireNonNull; import java.util.Collections; +import java.util.HashMap; import java.util.Map; /** @@ -56,4 +57,15 @@ public String getUcUri() { public Map getAuthConfig() { return authConfig; } + + /** + * Builds a flat config map suitable for {@code UCTokenBasedRestClientFactory.createUCClient}. + * Re-adds the {@code auth.} prefix to auth config keys and includes {@code uri}. + */ + public Map toUcConfig() { + Map ucConfig = new HashMap<>(); + ucConfig.put("uri", ucUri); + authConfig.forEach((k, v) -> ucConfig.put("auth." + k, v)); + return ucConfig; + } }