Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,16 @@ private fun checkOrigin(
LOGGER.trace { "${request.id()}: Skip CORS handler because Origin $origin is malformed" }
OriginCheckResult.SkipCORS
}

allowSameOrigin && isSameOrigin(origin, request.origin) -> {
if (request.isCorsPreflightRequest()) {
LOGGER.trace { "${request.id()}: Handle same-origin CORS preflight" }
return OriginCheckResult.OK
}
LOGGER.trace { "${request.id()}: Skip CORS handler because Origin $origin matches the server origin exactly" }
OriginCheckResult.SkipCORS
}

!corsCheckOrigins(
request,
origin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ public class CORSConfig {
/**
* Allows requests from the same origin.
*
* When `true` (default), same-origin requests skip CORS processing: no `Access-Control-*` headers are added
* and the request is passed to route handlers as-is.
*
* The only exception is CORS preflight requests — `OPTIONS` with an `Access-Control-Request-Method` header.
* The plugin still evaluates those (not skipped), even when the origin matches the server.
* The plugin checks the requested method and headers against the configured allowlists; if allowed, it responds
* with the appropriate CORS headers (typically `200 OK`), otherwise it responds with `403 Forbidden`.
*
* [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.server.plugins.cors.CORSConfig.allowSameOrigin)
*/
public var allowSameOrigin: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ internal fun corsCheckOrigins(
}
} else {
when {
allowsAnyHost ->
LOGGER.trace { "${request.id()}: Any * host is allowed" }
allowsAnyHost -> LOGGER.trace { "${request.id()}: Any * host is allowed" }
normalizedOrigin in hostsNormalized ->
LOGGER.trace { "${request.id()}: Origin $normalizedOrigin is allowed from $hostsNormalized" }

matchWildcardHosts ->
LOGGER.trace {
val (prefix, suffix) = hostsWithWildcard
Expand All @@ -85,6 +85,7 @@ internal fun corsCheckOrigins(
}!!
"${request.id()}: Origin $normalizedOrigin matches wildcard host $prefix*$suffix"
}

originPredicates.any { it(origin) } -> {
LOGGER.trace {
"${request.id()}: Origin $normalizedOrigin fulfills " +
Expand All @@ -108,6 +109,9 @@ internal fun corsCheckRequestHeaders(
internal fun headerMatchesAPredicate(header: String, headerPredicates: List<(String) -> Boolean>): Boolean =
headerPredicates.any { it(header) }

internal fun ApplicationRequest.isCorsPreflightRequest(): Boolean =
httpMethod == HttpMethod.Options && header(HttpHeaders.AccessControlRequestMethod) != null

internal fun ApplicationCall.corsCheckCurrentMethod(methods: Set<HttpMethod>): Boolean = request.httpMethod in methods

internal fun ApplicationCall.corsCheckRequestMethod(methods: Set<HttpMethod>): Boolean {
Expand All @@ -131,6 +135,27 @@ internal suspend fun ApplicationCall.respondCorsFailed() {
respond(HttpStatusCode.Forbidden)
}

private fun findPortDigitStartIndex(origin: String, hostStartIndex: Int): Int {
val isIpv6 = hostStartIndex < origin.length && origin[hostStartIndex] == '['
if (isIpv6) {
val ipv6LiteralEndIndex = origin.indexOf(']', hostStartIndex)
if (ipv6LiteralEndIndex == -1) {
return -1
}
val portSeparatorIndex = origin.indexOf(':', ipv6LiteralEndIndex)
return if (portSeparatorIndex != -1) portSeparatorIndex + 1 else origin.length
}

for (index in hostStartIndex until origin.length) {
when (origin[index]) {
':' -> return index + 1
'/' -> return origin.length
'?' -> return -1
}
}
return origin.length
}

internal fun isValidOrigin(origin: String): Boolean {
if (origin.isEmpty()) return false
if (origin == "null") return true
Expand All @@ -145,17 +170,11 @@ internal fun isValidOrigin(origin: String): Boolean {

if (!protoValid) return false

var portIndex = origin.length
for (index in protoDelimiter + 3 until origin.length) {
val ch = origin[index]
if (ch == ':' || ch == '/') {
portIndex = index + 1
break
}
if (ch == '?') return false
}
val hostStartIndex = protoDelimiter + 3
val portDigitStartIndex = findPortDigitStartIndex(origin, hostStartIndex)
if (portDigitStartIndex == -1) return false

for (index in portIndex until origin.length) {
for (index in portDigitStartIndex until origin.length) {
val isTrailingSlash = index == origin.length - 1 && origin[index] == '/'
if (!origin[index].isDigit() && !isTrailingSlash) return false
}
Expand All @@ -172,8 +191,13 @@ internal fun normalizeOrigin(origin: String): String {
} else {
builder.append(origin)
}
if (!builder.toString().substringAfterLast(":", "").matches(NUMBER_REGEX)) {
val port = when (builder.toString().substringBefore(':')) {
val originWithoutTrailingSlash = builder.toString()
val hostStartIndex = originWithoutTrailingSlash.indexOf("://") + 3
val portDigitStartIndex = findPortDigitStartIndex(originWithoutTrailingSlash, hostStartIndex)
val hasExplicitPort = portDigitStartIndex in originWithoutTrailingSlash.indices &&
originWithoutTrailingSlash.substring(portDigitStartIndex).matches(NUMBER_REGEX)
if (!hasExplicitPort) {
val port = when (originWithoutTrailingSlash.substringBefore(':')) {
"http" -> "80"
"https" -> "443"
else -> null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import io.ktor.server.routing.options
* A plugin that allows you to configure handling cross-origin requests.
* This plugin allows you to configure allowed hosts, HTTP methods, headers set by the client, and so on.
*
* CORS preflight requests (`OPTIONS` with an `Access-Control-Request-Method` header) are intercepted by the plugin
* and answered before routing, including when the request origin matches the server origin.
*
* The configuration below allows requests from the specified address and allows sending the `Content-Type` header:
* ```kotlin
* install(CORS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,21 @@ class CORSTest {
}.let { call ->
assertEquals(HttpStatusCode.Forbidden, call.status)
}

client.options("/") {
header(HttpHeaders.Origin, "http://localhost")
header(HttpHeaders.AccessControlRequestMethod, "GET")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertEquals("http://localhost", call.headers[HttpHeaders.AccessControlAllowOrigin])
}

// options without Access-Control-Request-Method just fall though, but there is no handler
client.options("/") {
header(HttpHeaders.Origin, "http://localhost")
}.let { call ->
assertEquals(HttpStatusCode.MethodNotAllowed, call.status)
}
}

@Test
Expand Down Expand Up @@ -981,6 +996,54 @@ class CORSTest {
}
}

@Test
fun ipv6LiteralOriginIsAccepted() = testApplication {
install(CORS) {
anyHost()
}

routing {
get("/") {
call.respond("OK")
}
}

client.get("/") {
header(HttpHeaders.Origin, "http://[::1]:22222")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertEquals("*", call.headers[HttpHeaders.AccessControlAllowOrigin])
}

client.get("/") {
header(HttpHeaders.Origin, "http://[2001:db8::1]:8080")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertEquals("*", call.headers[HttpHeaders.AccessControlAllowOrigin])
}

client.get("/") {
header(HttpHeaders.Origin, "http://[::1]:22222/")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertEquals("*", call.headers[HttpHeaders.AccessControlAllowOrigin])
}

client.get("/") {
header(HttpHeaders.Origin, "http://[::1]:notaport")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertNull(call.headers[HttpHeaders.AccessControlAllowOrigin])
}

client.get("/") {
header(HttpHeaders.Origin, "http://[::1")
}.let { call ->
assertEquals(HttpStatusCode.OK, call.status)
assertNull(call.headers[HttpHeaders.AccessControlAllowOrigin])
}
}

@Test
fun originValidation() = testApplication {
install(CORS) {
Expand Down