diff --git a/okio/api/okio.api b/okio/api/okio.api index 05c4ab6191..ca3bb3816e 100644 --- a/okio/api/okio.api +++ b/okio/api/okio.api @@ -310,6 +310,9 @@ public class okio/ByteString : java/io/Serializable, java/lang/Comparable { public final fun endsWith (Lokio/ByteString;)Z public final fun endsWith ([B)Z public fun equals (Ljava/lang/Object;)Z + public final fun equals (Lokio/ByteString;)Z + public final fun equals (Lokio/ByteString;Z)Z + public static synthetic fun equals$default (Lokio/ByteString;Lokio/ByteString;ZILjava/lang/Object;)Z public final fun getByte (I)B public fun hashCode ()I public fun hex ()Ljava/lang/String; diff --git a/okio/src/appleMain/kotlin/okio/ByteString.kt b/okio/src/appleMain/kotlin/okio/ByteString.kt index 7d49ebd693..81259c0204 100644 --- a/okio/src/appleMain/kotlin/okio/ByteString.kt +++ b/okio/src/appleMain/kotlin/okio/ByteString.kt @@ -168,6 +168,8 @@ internal actual constructor( actual override fun equals(other: Any?) = commonEquals(other) + actual fun equals(other: ByteString, constantTime: Boolean) = commonEquals(other, constantTime) + actual override fun hashCode() = commonHashCode() actual override fun compareTo(other: ByteString) = commonCompareTo(other) diff --git a/okio/src/commonMain/kotlin/okio/ByteString.kt b/okio/src/commonMain/kotlin/okio/ByteString.kt index 3f2f4cf51a..1ed7fb2783 100644 --- a/okio/src/commonMain/kotlin/okio/ByteString.kt +++ b/okio/src/commonMain/kotlin/okio/ByteString.kt @@ -170,6 +170,9 @@ internal constructor(data: ByteArray) : Comparable { override fun equals(other: Any?): Boolean + @JvmOverloads + fun equals(other: ByteString, constantTime: Boolean = false): Boolean + override fun hashCode(): Int override fun compareTo(other: ByteString): Int diff --git a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt index 60dac84e4b..045178b66b 100644 --- a/okio/src/commonMain/kotlin/okio/internal/ByteString.kt +++ b/okio/src/commonMain/kotlin/okio/internal/ByteString.kt @@ -229,11 +229,23 @@ internal inline fun ByteString.commonLastIndexOf(other: ByteArray, fromIndex: In internal inline fun ByteString.commonEquals(other: Any?): Boolean { return when { other === this -> true - other is ByteString -> other.size == data.size && other.rangeEquals(0, data, 0, data.size) + other is ByteString -> equals(other, constantTime = false) else -> false } } +@Suppress("NOTHING_TO_INLINE") +internal inline fun ByteString.commonEquals(other: ByteString, constantTime: Boolean): Boolean { + if (other.size != data.size) return false + if (!constantTime) return other.rangeEquals(0, data, 0, data.size) + + var result = true + for (i in data.indices) { + result = result and ((data[i].toInt() xor other.data[i].toInt()) == 0) + } + return result +} + @Suppress("NOTHING_TO_INLINE") internal inline fun ByteString.commonHashCode(): Int { val result = hashCode diff --git a/okio/src/commonTest/kotlin/okio/ByteStringTest.kt b/okio/src/commonTest/kotlin/okio/ByteStringTest.kt index da1a89888b..eb3089042a 100644 --- a/okio/src/commonTest/kotlin/okio/ByteStringTest.kt +++ b/okio/src/commonTest/kotlin/okio/ByteStringTest.kt @@ -17,6 +17,9 @@ package okio import app.cash.burst.Burst +import assertk.assertThat +import assertk.assertions.isLessThanOrEqualTo +import kotlin.math.absoluteValue import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals @@ -26,6 +29,7 @@ import kotlin.test.assertNotEquals import kotlin.test.assertSame import kotlin.test.assertTrue import kotlin.test.fail +import kotlin.time.measureTime import okio.ByteString.Companion.decodeBase64 import okio.ByteString.Companion.decodeHex import okio.ByteString.Companion.encodeUtf8 @@ -201,6 +205,40 @@ class ByteStringTest( assertEquals(ByteString.of(), factory.decodeHex("")) } + @Test fun equalsConstantTimeTest() { + val bytes = Random(1234).nextBytes(1024 * 1024) + val subject = bytes.toByteString() + val subjectSame = bytes.toByteString() + + // This instance is the same as the subject except with the first byte changed. + bytes[0] = (bytes[0] + 1).toByte() + val subjectDifferent = bytes.toByteString() + + val iterations = 1000 + var result = true + + val equalsTime = measureTime { + repeat(iterations) { + result = result and subject.equals(subjectSame, constantTime = true) + } + } + val notEqualsTime = measureTime { + repeat(iterations) { + result = result and subject.equals(subjectDifferent, constantTime = true) + } + } + assertFalse(result) + + // If the equals method was short-circuiting then the difference percentage will be huge >99%. + // We'll just check if it's under 50% to account for variance, especially on CI machines. + val equalsMean = equalsTime.inWholeNanoseconds + val notEqualsMean = notEqualsTime.inWholeNanoseconds + val maxMean = maxOf(equalsMean, notEqualsMean) + val difference = (equalsMean - notEqualsMean).absoluteValue + val differencePercentage = 100 * difference / maxMean + assertThat(differencePercentage).isLessThanOrEqualTo(50) + } + private val bronzeHorseman = "На берегу пустынных волн" @Test fun utf8() { diff --git a/okio/src/jvmMain/kotlin/okio/ByteString.kt b/okio/src/jvmMain/kotlin/okio/ByteString.kt index 37922adfed..c63b861148 100644 --- a/okio/src/jvmMain/kotlin/okio/ByteString.kt +++ b/okio/src/jvmMain/kotlin/okio/ByteString.kt @@ -187,6 +187,9 @@ internal actual constructor( actual override fun equals(other: Any?) = commonEquals(other) + @JvmOverloads + actual fun equals(other: ByteString, constantTime: Boolean) = commonEquals(other, constantTime) + actual override fun hashCode() = commonHashCode() actual override fun compareTo(other: ByteString) = commonCompareTo(other) diff --git a/okio/src/nonAppleMain/kotlin/okio/ByteString.kt b/okio/src/nonAppleMain/kotlin/okio/ByteString.kt index e1b8cc5168..01129fd8a9 100644 --- a/okio/src/nonAppleMain/kotlin/okio/ByteString.kt +++ b/okio/src/nonAppleMain/kotlin/okio/ByteString.kt @@ -162,6 +162,8 @@ internal actual constructor( actual override fun equals(other: Any?) = commonEquals(other) + actual fun equals(other: ByteString, constantTime: Boolean) = commonEquals(other, constantTime) + actual override fun hashCode() = commonHashCode() actual override fun compareTo(other: ByteString) = commonCompareTo(other)