diff --git a/control-plane/common/controller/validator.go b/control-plane/common/controller/validator.go index 0d708fccf..9ea32ae76 100644 --- a/control-plane/common/controller/validator.go +++ b/control-plane/common/controller/validator.go @@ -7,7 +7,7 @@ import ( "github.com/xeipuuv/gojsonschema" ) -//go:generate cp ../../../data-plane/core/config/src/grpc/schema/client-config.schema.json ./schema.json +//go:generate cp ../../../data-plane/core/config/src/schema/client-config.schema.json ./schema.json //go:embed schema.json var schemaData []byte diff --git a/data-plane/.cargo/config.toml b/data-plane/.cargo/config.toml index 182595223..57b230e63 100644 --- a/data-plane/.cargo/config.toml +++ b/data-plane/.cargo/config.toml @@ -18,6 +18,11 @@ rustflags = ["-C", "link-arg=-mios-version-min=13.4"] linker = "clang-ios-linker-wrapper.sh" rustflags = ["-C", "link-arg=-mios-simulator-version-min=13.4"] +# Enable mls_build_async for WASM targets so that mls-rs traits and methods +# become async (required by mls-rs-crypto-webcrypto which uses SubtleCrypto). +[target.wasm32-unknown-unknown] +rustflags = ["--cfg", "mls_build_async"] + [env] # CFLAGS suppresses macro redefinition warnings that occur when system headers # are mixed with Zig's headers (common with Windows cross-compilation). diff --git a/data-plane/Cargo.lock b/data-plane/Cargo.lock index d415ff994..0c69af99b 100644 --- a/data-plane/Cargo.lock +++ b/data-plane/Cargo.lock @@ -31,7 +31,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tracing", "tracing-test", @@ -46,11 +46,14 @@ dependencies = [ "agntcy-slim-version", "async-trait", "aws-lc-rs", - "base64", + "base64 0.22.1", "display-error-chain", "futures", + "getrandom 0.3.4", "headers", + "hmac", "http", + "js-sys", "jsonwebtoken", "mls-rs-core", "mls-rs-crypto-awslc", @@ -61,11 +64,12 @@ dependencies = [ "prost-types", "rand 0.9.4", "reqwest", - "schemars 1.2.1", + "schemars 1.2.0", "serde", "serde_json", + "sha2", "spiffe", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-test", "tokio-util", @@ -105,7 +109,7 @@ dependencies = [ "serde", "serde_json", "test-fork", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-stream", "tokio-test", @@ -123,12 +127,18 @@ dependencies = [ "agntcy-slim-testing", "agntcy-slim-version", "async-trait", - "base64", + "base64 0.22.1", + "bytes", "display-error-chain", "drain", "duration-string", + "fastwebsockets", "futures", + "getrandom 0.3.4", + "gloo-net", "http", + "http-body-util", + "hyper", "hyper-rustls", "hyper-util", "lazy_static", @@ -140,13 +150,14 @@ dependencies = [ "rustls", "rustls-native-certs", "rustls-pki-types", - "schemars 1.2.1", + "schemars 1.2.0", "serde", "serde_json", "serde_yaml", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-retry", + "tokio-rustls", "tokio-stream", "tokio-util", "tonic", @@ -185,7 +196,7 @@ dependencies = [ "rand 0.9.4", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-stream", "tokio-util", @@ -207,7 +218,12 @@ dependencies = [ "criterion", "display-error-chain", "drain", + "fastwebsockets", + "futures", + "getrandom 0.3.4", + "gloo-net", "h2", + "http", "opentelemetry", "parking_lot", "prost", @@ -216,10 +232,11 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-stream", "tokio-util", + "tokio_with_wasm", "tonic", "tonic-prost", "tonic-prost-build", @@ -228,6 +245,7 @@ dependencies = [ "tracing-test", "twox-hash", "uuid", + "wasm-bindgen-futures", ] [[package]] @@ -237,14 +255,18 @@ dependencies = [ "agntcy-slim-auth", "agntcy-slim-datapath", "agntcy-slim-version", - "base64", + "async-trait", + "base64 0.22.1", + "getrandom 0.3.4", "hex", + "maybe-async", "mls-rs", "mls-rs-core", "mls-rs-crypto-awslc", + "mls-rs-crypto-webcrypto", "serde_json", "tempfile", - "thiserror 2.0.18", + "thiserror 2.0.17", "tracing", ] @@ -265,7 +287,7 @@ dependencies = [ "futures", "parking_lot", "serde", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-test", "tokio-util", @@ -286,15 +308,18 @@ dependencies = [ "display-error-chain", "futures", "futures-timer", + "getrandom 0.3.4", + "maybe-async", "parking_lot", "rand 0.9.4", "serde", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-util", - "tonic", + "tokio_with_wasm", "tracing", "tracing-test", + "web-time", ] [[package]] @@ -318,7 +343,7 @@ dependencies = [ "agntcy-slim-session", "agntcy-slim-tracing", "aws-lc-rs", - "base64", + "base64 0.22.1", "bollard", "clap", "futures", @@ -328,7 +353,7 @@ dependencies = [ "rand 0.9.4", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tracing", "uuid", @@ -341,6 +366,7 @@ version = "0.3.9" dependencies = [ "agntcy-slim-config", "agntcy-slim-version", + "getrandom 0.3.4", "once_cell", "opentelemetry", "opentelemetry-otlp", @@ -348,17 +374,39 @@ dependencies = [ "opentelemetry-stdout", "opentelemetry_sdk", "serde", - "thiserror 2.0.18", + "thiserror 2.0.17", "tracing", "tracing-opentelemetry", "tracing-subscriber", "uuid", + "web-sys", ] [[package]] name = "agntcy-slim-version" version = "1.4.0-rc.0" +[[package]] +name = "agntcy-slim-wasm" +version = "0.1.0" +dependencies = [ + "agntcy-slim-auth", + "agntcy-slim-config", + "agntcy-slim-datapath", + "agntcy-slim-mls", + "agntcy-slim-session", + "agntcy-slim-tracing", + "console_error_panic_hook", + "gloo-net", + "js-sys", + "parking_lot", + "tokio_with_wasm", + "tracing", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", +] + [[package]] name = "agntcy-slimctl" version = "1.4.0-rc.0" @@ -416,9 +464,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "1.0.0" +version = "0.6.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" dependencies = [ "anstyle", "anstyle-parse", @@ -431,15 +479,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.14" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" [[package]] name = "anstyle-parse" -version = "1.0.0" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] @@ -466,9 +514,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.102" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arc-swap" @@ -518,7 +566,7 @@ dependencies = [ "memchr", "serde", "serde_derive", - "winnow 0.7.15", + "winnow 0.7.14", ] [[package]] @@ -533,7 +581,7 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror 2.0.18", + "thiserror 2.0.17", "time", ] @@ -681,6 +729,12 @@ dependencies = [ "tower-service", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -743,7 +797,7 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d41711ad46fda47cd701f6908e59d1bd6b9a2b7464c0d0aeab95c6d37096ff8a" dependencies = [ - "base64", + "base64 0.22.1", "bollard-stubs", "bytes", "futures-core", @@ -783,9 +837,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "bytes" @@ -822,7 +876,7 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror 2.0.17", ] [[package]] @@ -862,9 +916,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ "iana-time-zone", "js-sys", @@ -903,9 +957,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.6.1" +version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" dependencies = [ "clap_builder", "clap_derive", @@ -913,9 +967,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.6.0" +version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" dependencies = [ "anstream", "anstyle", @@ -925,9 +979,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.6.1" +version = "4.5.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" dependencies = [ "heck", "proc-macro2", @@ -952,9 +1006,19 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.5" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] [[package]] name = "const-oid" @@ -962,6 +1026,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "const-oid" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" + [[package]] name = "core-foundation" version = "0.9.4" @@ -1163,7 +1233,18 @@ version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ - "const-oid", + "const-oid 0.9.6", +] + +[[package]] +name = "der" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +dependencies = [ + "const-oid 0.10.2", + "der_derive", + "zeroize", ] [[package]] @@ -1180,11 +1261,22 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "der_derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59600e2c2d636fde9b65e99cc6445ac770c63d3628195ff39932b8d6d7409903" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "deranged" -version = "0.5.8" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" dependencies = [ "powerfmt", "serde_core", @@ -1198,6 +1290,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -1286,6 +1379,26 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +[[package]] +name = "fastwebsockets" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "305d3ba574508e27190906d11707dad683e0494e6b85eae9b044cb2734a5e422" +dependencies = [ + "base64 0.21.7", + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "pin-project", + "rand 0.8.6", + "sha1", + "simdutf8", + "thiserror 1.0.69", + "tokio", + "utf-8", +] + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -1345,9 +1458,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1360,9 +1473,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1370,15 +1483,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1387,15 +1500,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -1404,15 +1517,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -1422,9 +1535,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.32" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1434,6 +1547,7 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", + "pin-utils", "slab", ] @@ -1467,9 +1581,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi 5.3.0", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1492,6 +1608,40 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "gloo-net" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06f627b1a58ca3d42b45d6104bf1e1a03799df472df00988b6ba21accc10580" +dependencies = [ + "futures-channel", + "futures-core", + "futures-sink", + "gloo-utils", + "http", + "js-sys", + "pin-project", + "serde", + "serde_json", + "thiserror 1.0.69", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "gloo-utils" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5555354113b18c547c1d3a98fbf7fb32a9ff4f6fa112ce823a21641a0ba3aa" +dependencies = [ + "js-sys", + "serde", + "serde_json", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "goblin" version = "0.8.2" @@ -1560,7 +1710,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "headers-core", "http", @@ -1599,6 +1749,15 @@ dependencies = [ "serde", ] +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "1.4.0" @@ -1711,13 +1870,14 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.20" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-channel", + "futures-core", "futures-util", "http", "http-body", @@ -1751,9 +1911,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.65" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1913,9 +2073,9 @@ dependencies = [ [[package]] name = "inotify" -version = "0.11.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5b3eaf1a28b758ac0faa5a4254e8ab2705605496f1b1f3fbbc3988ad73d199" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" dependencies = [ "bitflags", "inotify-sys", @@ -1933,9 +2093,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.12.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" @@ -1974,9 +2134,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.18" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jemalloc-sys" @@ -2027,7 +2187,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" dependencies = [ "aws-lc-rs", - "base64", + "base64 0.22.1", "getrandom 0.2.17", "js-sys", "pem", @@ -2077,9 +2237,9 @@ checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "linux-raw-sys" -version = "0.12.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" [[package]] name = "litemap" @@ -2130,9 +2290,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.8.0" +version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "mime" @@ -2182,7 +2342,7 @@ dependencies = [ "serde", "spin", "subtle", - "thiserror 2.0.18", + "thiserror 2.0.17", "wasm-bindgen", "zeroize", ] @@ -2195,7 +2355,7 @@ checksum = "45bd834f164dc06c1fed805540ae307a460b7ed7c2769a35a376f1de577a0dc1" dependencies = [ "itertools 0.14.0", "mls-rs-codec-derive", - "thiserror 2.0.18", + "thiserror 2.0.17", "wasm-bindgen", ] @@ -2222,7 +2382,7 @@ dependencies = [ "maybe-async", "mls-rs-codec", "serde", - "thiserror 2.0.18", + "thiserror 2.0.17", "wasm-bindgen", "zeroize", ] @@ -2241,7 +2401,7 @@ dependencies = [ "mls-rs-crypto-hpke", "mls-rs-crypto-traits", "mls-rs-identity-x509", - "thiserror 2.0.18", + "thiserror 2.0.17", "zeroize", ] @@ -2256,7 +2416,7 @@ dependencies = [ "maybe-async", "mls-rs-core", "mls-rs-crypto-traits", - "thiserror 2.0.18", + "thiserror 2.0.17", "zeroize", ] @@ -2272,6 +2432,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "mls-rs-crypto-webcrypto" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "292a8e8a85689c029a6f6b0c7a7eef0f390a044ef176b80d0442b38948c66e35" +dependencies = [ + "async-trait", + "const-oid 0.10.2", + "der 0.8.0", + "js-sys", + "maybe-async", + "mls-rs-core", + "mls-rs-crypto-hpke", + "mls-rs-crypto-traits", + "serde", + "serde-wasm-bindgen", + "thiserror 2.0.17", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "zeroize", +] + [[package]] name = "mls-rs-identity-x509" version = "0.20.0" @@ -2281,7 +2464,7 @@ dependencies = [ "async-trait", "maybe-async", "mls-rs-core", - "thiserror 2.0.18", + "thiserror 2.0.17", "wasm-bindgen", ] @@ -2387,7 +2570,7 @@ version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "getrandom 0.2.17", "http", @@ -2412,9 +2595,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.4" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "once_cell_polyfill" @@ -2430,9 +2613,9 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" -version = "0.2.1" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +checksum = "9f50d9b3dabb09ecd771ad0aa242ca6894994c130308ca3d7684634df8037391" [[package]] name = "opentelemetry" @@ -2444,7 +2627,7 @@ dependencies = [ "futures-sink", "js-sys", "pin-project-lite", - "thiserror 2.0.18", + "thiserror 2.0.17", "tracing", ] @@ -2463,9 +2646,9 @@ dependencies = [ [[package]] name = "opentelemetry-otlp" -version = "0.31.1" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f69cd6acbb9af919df949cd1ec9e5e7fdc2ef15d234b6b795aaa525cc02f71f" +checksum = "7a2366db2dca4d2ad033cad11e6ee42844fd727007af5ad04a1730f4cb8163bf" dependencies = [ "http", "opentelemetry", @@ -2474,7 +2657,7 @@ dependencies = [ "opentelemetry_sdk", "prost", "reqwest", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tonic", "tracing", @@ -2522,7 +2705,7 @@ dependencies = [ "opentelemetry", "percent-encoding", "rand 0.9.4", - "thiserror 2.0.18", + "thiserror 2.0.17", "tokio", "tokio-stream", ] @@ -2556,7 +2739,7 @@ version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" dependencies = [ - "base64", + "base64 0.22.1", "serde_core", ] @@ -2599,9 +2782,15 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.17" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs8" @@ -2609,7 +2798,7 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der", + "der 0.7.10", "spki", ] @@ -2649,9 +2838,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.13.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" dependencies = [ "critical-section", ] @@ -2701,9 +2890,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.106" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" dependencies = [ "unicode-ident", ] @@ -2827,9 +3016,9 @@ checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" [[package]] name = "pulldown-cmark" -version = "0.13.3" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" +checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" dependencies = [ "bitflags", "memchr", @@ -2993,9 +3182,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" dependencies = [ "aho-corasick", "memchr", @@ -3005,9 +3194,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.14" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" dependencies = [ "aho-corasick", "memchr", @@ -3016,9 +3205,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" @@ -3026,7 +3215,7 @@ version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-channel", @@ -3092,9 +3281,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.4" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags", "errno", @@ -3159,9 +3348,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.23" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "same-file" @@ -3174,9 +3363,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.29" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" dependencies = [ "windows-sys 0.61.2", ] @@ -3195,9 +3384,9 @@ dependencies = [ [[package]] name = "schemars" -version = "1.2.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +checksum = "54e910108742c57a770f492731f99be216a52fadd361b06c8fb59d74ccc267d2" dependencies = [ "dyn-clone", "ref-cast", @@ -3208,9 +3397,9 @@ dependencies = [ [[package]] name = "schemars_derive" -version = "1.2.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +checksum = "4908ad288c5035a8eb12cfdf0d49270def0a268ee162b75eeee0f85d155a7c45" dependencies = [ "proc-macro2", "quote", @@ -3246,9 +3435,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "3.7.0" +version = "3.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" dependencies = [ "bitflags", "core-foundation 0.10.1", @@ -3259,9 +3448,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.17.0" +version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" dependencies = [ "core-foundation-sys", "libc", @@ -3287,6 +3476,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -3380,13 +3580,13 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", "indexmap 2.14.0", "schemars 0.9.0", - "schemars 1.2.1", + "schemars 1.2.0", "serde_core", "serde_json", "time", @@ -3461,15 +3661,21 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "simple_asn1" -version = "0.6.4" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" dependencies = [ "num-bigint", "num-traits", - "thiserror 2.0.18", + "thiserror 2.0.17", "time", ] @@ -3481,9 +3687,9 @@ checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" -version = "0.4.12" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" [[package]] name = "slim-examples" @@ -3546,7 +3752,7 @@ dependencies = [ "prost-types", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror 2.0.17", "time", "tokio", "tokio-util", @@ -3573,7 +3779,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ - "der", + "der 0.7.10", ] [[package]] @@ -3602,9 +3808,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.117" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -3633,9 +3839,9 @@ dependencies = [ [[package]] name = "system-configuration" -version = "0.7.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags", "core-foundation 0.9.4", @@ -3654,12 +3860,12 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.27.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom 0.4.2", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys 0.61.2", @@ -3712,11 +3918,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.18" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl 2.0.18", + "thiserror-impl 2.0.17", ] [[package]] @@ -3732,9 +3938,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.18" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", @@ -3884,6 +4090,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio_with_wasm" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef3ce6a8f5b5190dfe4851db6c969e8360a262759e16a0b75dfc43af19d97a86" +dependencies = [ + "js-sys", + "tokio", + "tokio_with_wasm_proc", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "tokio_with_wasm_proc" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d8aa1d26c1550eef93cfb2dafadc145b3220432dae8d156b5ba485880594ffe" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "toml" version = "0.9.12+spec-1.1.0" @@ -3896,7 +4126,7 @@ dependencies = [ "toml_datetime", "toml_parser", "toml_writer", - "winnow 0.7.15", + "winnow 0.7.14", ] [[package]] @@ -3925,13 +4155,13 @@ checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "tonic" -version = "0.14.5" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" +checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" dependencies = [ "async-trait", "axum", - "base64", + "base64 0.22.1", "bytes", "h2", "http", @@ -3954,9 +4184,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.14.5" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1882ac3bf5ef12877d7ed57aad87e75154c11931c2ba7e6cde5e22d63522c734" +checksum = "4c40aaccc9f9eccf2cd82ebc111adc13030d23e887244bc9cfa5d1d636049de3" dependencies = [ "prettyplease", "proc-macro2", @@ -3966,9 +4196,9 @@ dependencies = [ [[package]] name = "tonic-prost" -version = "0.14.5" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" dependencies = [ "bytes", "prost", @@ -3977,9 +4207,9 @@ dependencies = [ [[package]] name = "tonic-prost-build" -version = "0.14.5" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3144df636917574672e93d0f56d7edec49f90305749c668df5101751bb8f95a" +checksum = "b4a16cba4043dc3ff43fcb3f96b4c5c154c64cbd18ca8dce2ab2c6a451d058a2" dependencies = [ "prettyplease", "proc-macro2", @@ -4034,7 +4264,7 @@ version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28f0d049ccfaa566e14e9663d304d8577427b368cb4710a20528690287a738b" dependencies = [ - "base64", + "base64 0.22.1", "bitflags", "bytes", "futures-util", @@ -4121,9 +4351,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.23" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -4139,9 +4369,9 @@ dependencies = [ [[package]] name = "tracing-test" -version = "0.2.6" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19a4c448db514d4f24c5ddb9f73f2ee71bfb24c526cf0c570ba142d1119e0051" +checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" dependencies = [ "tracing-core", "tracing-subscriber", @@ -4150,9 +4380,9 @@ dependencies = [ [[package]] name = "tracing-test-macro" -version = "0.2.6" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad06847b7afb65c7866a36664b75c40b895e318cea4f71299f013fb22965329d" +checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" dependencies = [ "quote", "syn", @@ -4187,9 +4417,9 @@ checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" [[package]] name = "unicode-ident" -version = "1.0.24" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "unicode-xid" @@ -4367,6 +4597,12 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4828,9 +5064,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.15" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" dependencies = [ "memchr", ] @@ -4848,7 +5084,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" dependencies = [ "assert-json-diff", - "base64", + "base64 0.22.1", "deadpool", "futures", "http", @@ -4977,7 +5213,7 @@ dependencies = [ "nom", "oid-registry", "rusticata-macros", - "thiserror 2.0.18", + "thiserror 2.0.17", "time", ] @@ -5100,6 +5336,6 @@ dependencies = [ [[package]] name = "zmij" -version = "1.0.21" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/data-plane/Cargo.toml b/data-plane/Cargo.toml index d557ca1ce..119227715 100644 --- a/data-plane/Cargo.toml +++ b/data-plane/Cargo.toml @@ -11,6 +11,7 @@ members = [ "core/session", "core/signal", "core/slim", + "core/slim-wasm", "core/tracing", "core/version", "examples", @@ -39,17 +40,17 @@ resolver = "2" [workspace.dependencies] # Local dependencies agntcy-slim = { path = "core/slim", version = "1.4.0-rc.0" } -agntcy-slim-auth = { path = "core/auth", version = "0.7.0" } +agntcy-slim-auth = { path = "core/auth", version = "0.7.0", default-features = false } agntcy-slim-bindings = { path = "bindings/rust", version = "1.4.0-rc.0" } -agntcy-slim-config = { path = "core/config", version = "0.9.0" } -agntcy-slim-controller = { path = "core/controller", version = "0.5.0" } -agntcy-slim-datapath = { path = "core/datapath", version = "0.12.3" } -agntcy-slim-mls = { path = "core/mls", version = "0.1.15" } +agntcy-slim-config = { path = "core/config", version = "0.9.0", default-features = false } +agntcy-slim-controller = { path = "core/controller", version = "0.5.0", default-features = false } +agntcy-slim-datapath = { path = "core/datapath", version = "0.12.3", default-features = false } +agntcy-slim-mls = { path = "core/mls", version = "0.1.15", default-features = false } agntcy-slim-service = { path = "core/service", version = "0.8.12", default-features = false } -agntcy-slim-session = { path = "core/session", version = "0.1.12" } -agntcy-slim-signal = { path = "core/signal", version = "0.1.9" } +agntcy-slim-session = { path = "core/session", version = "0.1.12", default-features = false } +agntcy-slim-signal = { path = "core/signal", version = "0.1.9", default-features = false } agntcy-slim-testing = { path = "testing" } -agntcy-slim-tracing = { path = "core/tracing", version = "0.3.9" } +agntcy-slim-tracing = { path = "core/tracing", version = "0.3.9", default-features = false } agntcy-slim-version = { path = "core/version", version = "1.4.0-rc.0" } agntcy-slimctl = { path = "slimctl", version = "1.4.0-rc.0" } @@ -68,12 +69,15 @@ criterion = { version = "0.5", features = ["html_reports"] } display-error-chain = { version = "0.2" } drain = { version = "0.2", features = ["retain"] } duration-string = { version = "0.5.3", features = ["serde"] } +fastwebsockets = { version = "0.10.0", features = ["upgrade", "unstable-split"] } futures = "0.3.31" futures-timer = "3.0.3" h2 = "0.4.7" headers = "0.4.1" hex = "0.4.3" http = "1.2.0" +http-body-util = "0.1.3" +hyper = { version = "1.8.1", features = ["http1", "server", "client"] } hyper-rustls = { version = "0.27", features = ["http2", "aws-lc-rs"], default-features = false } hyper-util = "0.1.10" indexmap = "2" @@ -104,7 +108,7 @@ protoc-bin-vendored = "3.1.0" pyo3 = { version = "0.26" } pyo3-async-runtimes = { version = "0.26", features = ["tokio-runtime"] } pyo3-stub-gen = "0.14" -rand = "0.9.0" +rand = "0.9.3" regex = "1.11.1" reqwest = { version = "0.12", features = ["json", "rustls-tls-no-provider", "charset", "http2", "system-proxy", "blocking"], default-features = false } @@ -123,6 +127,7 @@ test-fork = "0.1.3" thiserror = "2.0.9" tokio = "1.52.2" tokio-retry = "0.3" +tokio-rustls = "0.26.4" tokio-stream = "0.1" tokio-test = "0.4.4" tokio-util = "0.7.14" diff --git a/data-plane/Taskfile.yaml b/data-plane/Taskfile.yaml index 7d4de4f50..02e21ba4f 100644 --- a/data-plane/Taskfile.yaml +++ b/data-plane/Taskfile.yaml @@ -126,6 +126,29 @@ tasks: cmds: - cargo run --bin generate-schema + # Feature-specific build tasks for agntcy-slim-config + data-plane:check:no-default-features: + desc: "Check agntcy-slim-config compiles with no default features" + cmds: + - cargo check --package agntcy-slim-config --no-default-features --locked + + data-plane:check:config-wasm: + desc: "Check agntcy-slim-config compiles with wasm feature for wasm target" + cmds: + - cargo check --package agntcy-slim-config --no-default-features --features wasm --target wasm32-unknown-unknown --locked + + data-plane:check:features: + desc: "Check all feature combinations compile" + cmds: + - task: data-plane:check:no-default-features + - task: data-plane:check:config-wasm + + # Build tasks for specific features + data-plane:build:config-wasm: + desc: "Build with wasm feature for wasm target" + cmds: + - cargo build --package agntcy-slim-config --no-default-features --features wasm --target wasm32-unknown-unknown --locked + # Enhanced coverage task that includes Python bindings coverage data-plane:coverage-full: desc: "Run full tests coverage including Python bindings" @@ -134,3 +157,17 @@ tasks: vars: ARGS: "{{.GLOBAL_ARGS}} --all-targets --locked --all-features" - task: data-plane:core:test + + # WASM (wasm32-unknown-unknown) build checks + data-plane:check:wasm: + desc: "Check all WASM-ready crates compile for wasm32-unknown-unknown" + cmds: + - task: data-plane:check:config-wasm + - cargo check -p agntcy-slim-auth --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-datapath --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-tracing --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-signal --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-mls --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-session --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-controller --target wasm32-unknown-unknown --no-default-features --features wasm --locked + - cargo check -p agntcy-slim-service --target wasm32-unknown-unknown --no-default-features --features "wasm,session" --locked diff --git a/data-plane/bindings/rust/Cargo.toml b/data-plane/bindings/rust/Cargo.toml index a7d47608a..4705c1cf7 100644 --- a/data-plane/bindings/rust/Cargo.toml +++ b/data-plane/bindings/rust/Cargo.toml @@ -12,14 +12,14 @@ crate-type = ["lib", "staticlib", "cdylib"] [dependencies] # Local workspace dependencies agntcy-slim = { workspace = true } -agntcy-slim-auth = { workspace = true } -agntcy-slim-config = { workspace = true } -agntcy-slim-controller = { workspace = true } -agntcy-slim-datapath = { workspace = true } -agntcy-slim-service = { workspace = true, features = ["session"] } -agntcy-slim-session = { workspace = true } -agntcy-slim-signal = { workspace = true } -agntcy-slim-tracing = { workspace = true } +agntcy-slim-auth = { workspace = true, features = ["native"] } +agntcy-slim-config = { workspace = true, features = ["native"] } +agntcy-slim-controller = { workspace = true, features = ["native"] } +agntcy-slim-datapath = { workspace = true, features = ["native"] } +agntcy-slim-service = { workspace = true, features = ["native", "session"] } +agntcy-slim-session = { workspace = true, features = ["native"] } +agntcy-slim-signal = { workspace = true, features = ["native"] } +agntcy-slim-tracing = { workspace = true, features = ["native"] } agntcy-slim-version = { workspace = true } # External dependencies diff --git a/data-plane/config/websocket/client-config-debug.yaml b/data-plane/config/websocket/client-config-debug.yaml new file mode 100644 index 000000000..5d5d80026 --- /dev/null +++ b/data-plane/config/websocket/client-config-debug.yaml @@ -0,0 +1,21 @@ +# Debug config for client +tracing: + log_level: debug + display_thread_names: true + display_thread_ids: true + +runtime: + n_cores: 0 + thread_name: "slim-data-plane" + drain_timeout: 10s + +services: + slim/0: + dataplane: + clients: + - endpoint: "ws://localhost:46357" + transport: websocket + websocket_auth_query_param: token + tls: + insecure: true + servers: [] diff --git a/data-plane/config/websocket/client-config-wss.yaml b/data-plane/config/websocket/client-config-wss.yaml new file mode 100644 index 000000000..a3b8eaac9 --- /dev/null +++ b/data-plane/config/websocket/client-config-wss.yaml @@ -0,0 +1,25 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +tracing: + log_level: info + display_thread_names: true + display_thread_ids: true + +runtime: + n_cores: 0 + thread_name: "slim-data-plane" + drain_timeout: 10s + +services: + slim/0: + dataplane: + clients: + - endpoint: "wss://localhost:46357" + transport: websocket + server_name: "localhost" + tls: + ca_source: + type: file + path: "./config/crypto/server-ca-cert.pem" + servers: [] diff --git a/data-plane/config/websocket/server-config-debug.yaml b/data-plane/config/websocket/server-config-debug.yaml new file mode 100644 index 000000000..fb01e3d70 --- /dev/null +++ b/data-plane/config/websocket/server-config-debug.yaml @@ -0,0 +1,20 @@ +# Debug config with detailed logging +tracing: + log_level: debug + display_thread_names: true + display_thread_ids: true + +runtime: + n_cores: 0 + thread_name: "slim-data-plane" + drain_timeout: 10s + +services: + slim/0: + dataplane: + servers: + - endpoint: "ws://0.0.0.0:46357" + transport: websocket + tls: + insecure: true + clients: [] diff --git a/data-plane/config/websocket/server-config-wss.yaml b/data-plane/config/websocket/server-config-wss.yaml new file mode 100644 index 000000000..58d8dcfab --- /dev/null +++ b/data-plane/config/websocket/server-config-wss.yaml @@ -0,0 +1,25 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +tracing: + log_level: info + display_thread_names: true + display_thread_ids: true + +runtime: + n_cores: 0 + thread_name: "slim-data-plane" + drain_timeout: 10s + +services: + slim/0: + dataplane: + servers: + - endpoint: "wss://0.0.0.0:46357" + transport: websocket + tls: + source: + type: file + cert: ./config/crypto/server-cert.pem + key: ./config/crypto/server-key.pem + clients: [] diff --git a/data-plane/core/auth/Cargo.toml b/data-plane/core/auth/Cargo.toml index d19de6990..260b888c3 100644 --- a/data-plane/core/auth/Cargo.toml +++ b/data-plane/core/auth/Cargo.toml @@ -5,43 +5,82 @@ license = { workspace = true } edition = { workspace = true } description = "Authentication utilities for the Agntcy Slim framework" +[package.metadata.cargo-machete] +ignored = ["getrandom"] + [lib] name = "slim_auth" +[features] +default = ["native"] +native = [ + "dep:aws-lc-rs", + "dep:display-error-chain", + "dep:futures", + "dep:headers", + "dep:jsonwebtoken", + "dep:mls-rs-core", + "dep:mls-rs-crypto-awslc", + "dep:notify", + "dep:oauth2", + "dep:parking_lot", + "dep:pin-project", + "dep:reqwest", + "dep:tokio", + "dep:tokio-util", + "dep:tower", + "dep:tower-layer", + "dep:tower-service", + "dep:url", + "dep:wiremock", + "native-spiffe", +] +native-spiffe = ["dep:spiffe"] +wasm = ["dep:hmac", "dep:js-sys", "dep:parking_lot", "dep:sha2", "dep:getrandom"] + [dependencies] +# Always available (core types, traits, metadata) agntcy-slim-version = { workspace = true } async-trait = { workspace = true } -aws-lc-rs = { workspace = true } + +# Native-only dependencies +aws-lc-rs = { workspace = true, optional = true } base64 = { workspace = true } -display-error-chain = { workspace = true } -futures = { workspace = true } -headers = { workspace = true } +display-error-chain = { workspace = true, optional = true } +futures = { workspace = true, optional = true } + +# WASM support +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } +headers = { workspace = true, optional = true } +hmac = { version = "0.12", optional = true } http = { workspace = true } -jsonwebtoken = { workspace = true } -mls-rs-core = { workspace = true } -mls-rs-crypto-awslc = { workspace = true } -notify = { workspace = true } -oauth2 = { workspace = true, features = ["reqwest"] } -parking_lot = { workspace = true } -pin-project = { workspace = true } +js-sys = { version = "0.3", optional = true } +jsonwebtoken = { workspace = true, optional = true } +mls-rs-core = { workspace = true, optional = true } +mls-rs-crypto-awslc = { workspace = true, optional = true } +notify = { workspace = true, optional = true } +oauth2 = { workspace = true, features = ["reqwest"], optional = true } +parking_lot = { workspace = true, optional = true } +pin-project = { workspace = true, optional = true } prost-types = { workspace = true } rand = { workspace = true } -reqwest = { workspace = true } +reqwest = { workspace = true, optional = true } schemars = { workspace = true } -serde = { workspace = true } +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +sha2 = { version = "0.10", optional = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } -tokio-util = { workspace = true } -tower = { workspace = true } -tower-layer = { workspace = true } -tower-service = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"], optional = true } +tokio-util = { workspace = true, optional = true } +tower = { workspace = true, optional = true } +tower-layer = { workspace = true, optional = true } +tower-service = { workspace = true, optional = true } tracing = { workspace = true } -url = { workspace = true } -wiremock = { workspace = true } +url = { workspace = true, optional = true } +wiremock = { workspace = true, optional = true } [target.'cfg(not(target_family = "windows"))'.dependencies] -spiffe = { workspace = true } +spiffe = { workspace = true, optional = true } [dev-dependencies] agntcy-slim-config = { path = "../config" } @@ -49,4 +88,3 @@ agntcy-slim-testing = { path = "../../testing" } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] } tokio-test = { workspace = true } tracing-test = { workspace = true } -# uuid moved to optional dependency under [dependencies] for feature = "testutils" diff --git a/data-plane/core/auth/src/auth_provider.rs b/data-plane/core/auth/src/auth_provider.rs index 8e05fcfc9..e5fe2c3dd 100644 --- a/data-plane/core/auth/src/auth_provider.rs +++ b/data-plane/core/auth/src/auth_provider.rs @@ -29,6 +29,7 @@ //! // Create a shared secret provider //! let provider = AuthProvider::shared_secret_from_str("service-id", "shared-secret-value-0123456789abcdef")?; //! let token = provider.get_token()?; +//! # Ok::<(), Box>(()) //! ``` //! //! ## Creating an authentication verifier @@ -39,9 +40,9 @@ //! //! # tokio_test::block_on(async { //! // Create a shared secret provider and verifier -//! let provider = AuthProvider::shared_secret_from_str("service-id", "shared-secret-value-0123456789abcdef")?; -//! let verifier = AuthVerifier::shared_secret_from_str("service-id", "shared-secret-value-0123456789abcdef")?; -//! let token = provider.get_token()?; +//! let provider = AuthProvider::shared_secret_from_str("service-id", "shared-secret-value-0123456789abcdef").unwrap(); +//! let verifier = AuthVerifier::shared_secret_from_str("service-id", "shared-secret-value-0123456789abcdef").unwrap(); +//! let token = provider.get_token().unwrap(); //! //! // Verify the token (must be generated by a provider with same secret) //! let result = verifier.verify(&token).await; @@ -61,6 +62,7 @@ //! let verifier = Arc::new(AuthVerifier::shared_secret_from_str("id", "shared-secret-value-0123456789abcdef")?); //! //! // These can now be safely shared across threads +//! # Ok::<(), Box>(()) //! ``` use std::sync::Arc; @@ -98,10 +100,11 @@ use crate::traits::{TokenProvider, Verifier}; /// // Create from shared secret /// let provider = AuthProvider::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef")?; /// let token = provider.get_token()?; -/// // Token format: id:timestamp:nonce:mac (e.g., "my-service_ABC12345:1640995200:NONCEBASE64:MACBASE64") +/// // Token format: id:timestamp:nonce:claims:mac /// let parts: Vec<&str> = token.split(':').collect(); -/// assert_eq!(parts.len(), 4); +/// assert_eq!(parts.len(), 5); /// assert!(parts[0].starts_with("my-service_")); +/// # Ok::<(), Box>(()) /// ``` #[derive(Clone)] pub enum AuthProvider { @@ -142,8 +145,8 @@ pub enum AuthProvider { /// /// # tokio_test::block_on(async { /// // Create provider and verifier with same secret -/// let provider = AuthProvider::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef")?; -/// let verifier = AuthVerifier::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef")?; +/// let provider = AuthProvider::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef").unwrap(); +/// let verifier = AuthVerifier::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef").unwrap(); /// /// // Generate and verify a valid token /// let token = provider.get_token().unwrap(); @@ -269,6 +272,24 @@ impl TokenProvider for AuthProvider { AuthProvider::Spire(spire) => spire.rotate_signature_keys(), } } + + fn set_signature_keys( + &mut self, + private_key: Vec, + public_key: Vec, + ) -> Result<(), AuthError> { + match self { + AuthProvider::JwtSigner(signer) => signer.set_signature_keys(private_key, public_key), + AuthProvider::StaticToken(provider) => { + provider.set_signature_keys(private_key, public_key) + } + AuthProvider::SharedSecret(secret) => { + secret.set_signature_keys(private_key, public_key) + } + #[cfg(not(target_family = "windows"))] + AuthProvider::Spire(spire) => spire.set_signature_keys(private_key, public_key), + } + } } #[async_trait] @@ -369,9 +390,9 @@ impl AuthProvider { /// use slim_auth::auth_provider::AuthProvider; /// use slim_auth::shared_secret::SharedSecret; /// - /// let shared_secret = SharedSecret::generated(); - /// let secret = SharedSecret::new("service-id", shared_secret); + /// let secret = SharedSecret::new("service-id", "shared-secret-value-0123456789abcdef")?; /// let provider = AuthProvider::shared_secret(secret); + /// # Ok::<(), slim_auth::errors::AuthError>(()) /// ``` pub fn shared_secret(secret: SharedSecret) -> Self { AuthProvider::SharedSecret(secret) @@ -393,10 +414,11 @@ impl AuthProvider { /// /// let provider = AuthProvider::shared_secret_from_str("my-service", "shared-secret-value-0123456789abcdef")?; /// let token = provider.get_token()?; - /// // Token format: id:timestamp:nonce:mac (e.g., "my-service_ABC12345:1640995200:NONCEBASE64:MACBASE64") + /// // Token format: id:timestamp:nonce:claims:mac /// let parts: Vec<&str> = token.split(':').collect(); - /// assert_eq!(parts.len(), 4); + /// assert_eq!(parts.len(), 5); /// assert!(parts[0].starts_with("my-service_")); + /// # Ok::<(), Box>(()) /// ``` pub fn shared_secret_from_str( id: &str, @@ -485,6 +507,7 @@ impl AuthVerifier { /// let result = verifier.verify(&token).await; /// assert!(result.is_ok()); /// # }); + /// # Ok::<(), Box>(()) /// ``` pub fn shared_secret_from_str( id: &str, diff --git a/data-plane/core/auth/src/errors.rs b/data-plane/core/auth/src/errors.rs index 10946593a..1c146338f 100644 --- a/data-plane/core/auth/src/errors.rs +++ b/data-plane/core/auth/src/errors.rs @@ -2,9 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use http::StatusCode; +#[cfg(feature = "native")] use jsonwebtoken::jwk::KeyAlgorithm; -#[cfg(not(target_family = "windows"))] +#[cfg(all(feature = "native", not(target_family = "windows")))] use spiffe::{ JwtSourceError, JwtSvidError, SpiffeIdError, TrustDomain, WorkloadApiError, X509SourceError, }; @@ -14,6 +15,7 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum AuthError { // JWT errors + #[cfg(feature = "native")] #[error("unsupported key algorithm: {0}")] JwtUnsupportedKeyAlgorithm(KeyAlgorithm), #[error("JWK does not contain the key algorithm (alg) field")] @@ -62,6 +64,7 @@ pub enum AuthError { TimeError(#[from] std::time::SystemTimeError), // URL parsing + #[cfg(feature = "native")] #[error("URL parse error")] UrlParseError(#[from] url::ParseError), @@ -72,6 +75,7 @@ pub enum AuthError { HeaderValueError(#[from] http::header::InvalidHeaderValue), // File watcher + #[cfg(feature = "native")] #[error("file watcher error")] FileWatcherError(#[from] crate::file_watcher::FileWatcherError), @@ -86,12 +90,14 @@ pub enum AuthError { TokenInvalidMissingSub, #[error("token invalid: replay")] TokenInvalidReplay, + #[cfg(feature = "native")] #[error("token invalid")] JwtTokenInvalid(#[from] jsonwebtoken::errors::Error), #[error("token invalid - missing or invalid exp claim")] TokenInvalidMissingExp, // HTTP / networking + #[cfg(feature = "native")] #[error("HTTP request error")] HttpError(#[from] reqwest::Error), @@ -110,48 +116,65 @@ pub enum AuthError { // SPIFFE / SPIRE integration #[error("spire integration is not supported on Windows")] SpireUnsupportedOnWindows, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("serde error while encoding audience: {source}")] SpiffeCustomClaimsSerialize { source: serde_json::Error }, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spiffe error")] SpiffeError(#[from] SpiffeIdError), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spiffe grpc error")] SpiffeGrpcError(#[from] WorkloadApiError), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spiffe workload api unavailable")] SpiffeWorkloadApiUnavailable, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spiffe x509 source error")] SpiffeX509SourceError(#[from] X509SourceError), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spiffe jwt source error")] SpiffeJwtSourceError(#[from] JwtSourceError), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("jwt source not initialized")] SpiffeJwtSourceNotInitialized, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("missing jwt svid")] SpiffeJwtSvidMissing, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("missing jwt bundle")] SpiffeJwtBundleMissing, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("invalid JWT svid")] SpiffeInvalidJwtSvid(#[from] JwtSvidError), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("failed to fetch x509 SVID")] SpiffeX509SvidMissing, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("x509 source not initialized")] SpiffeX509SourceNotInitialized, - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("x509 trust bundle not available: {0}")] SpiffeX509BundleMissing(TrustDomain), - #[cfg(not(target_family = "windows"))] + #[cfg(all(feature = "native", not(target_family = "windows")))] + #[error("error fetching x509 SVID: {source}")] + SpiffeX509SvidFetch { + source: Box, + }, + #[cfg(all(feature = "native", not(target_family = "windows")))] + #[error("error fetching x509 trust bundle: {source}")] + SpiffeX509BundleFetch { + source: Box, + }, + #[cfg(all(feature = "native", not(target_family = "windows")))] #[error("spire x509 empty certificate chain")] SpiffeX509EmptyCertChain, + #[cfg(all(feature = "native", not(target_family = "windows")))] + #[error("jwt source closed")] + SpiffeCustomAudiencesJwtSourceClosed, + #[cfg(all(feature = "native", not(target_family = "windows")))] + #[error("error fetching jwt svid with custom audiences")] + SpiffeCustomAudiencesError, + // Serialization #[error("JSON serialization error")] JsonError(#[from] serde_json::Error), diff --git a/data-plane/core/auth/src/jwt.rs b/data-plane/core/auth/src/jwt.rs index c2b3cf9ff..eca672558 100644 --- a/data-plane/core/auth/src/jwt.rs +++ b/data-plane/core/auth/src/jwt.rs @@ -194,7 +194,9 @@ impl Jwt { /// Internal constructor used by the builder. /// /// This should not be called directly. Use the builder pattern instead: - /// ``` + /// ```rust,ignore + /// use slim_auth::jwt::Jwt; + /// /// let jwt = Jwt::builder() /// .issuer("my-issuer") /// .audience("my-audience") diff --git a/data-plane/core/auth/src/lib.rs b/data-plane/core/auth/src/lib.rs index 414416cee..81ba6740d 100644 --- a/data-plane/core/auth/src/lib.rs +++ b/data-plane/core/auth/src/lib.rs @@ -1,18 +1,26 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +#[cfg(feature = "native")] pub mod auth_provider; +#[cfg(feature = "native")] pub mod builder; pub mod errors; +#[cfg(feature = "native")] pub mod file_watcher; pub mod identity_claims; +#[cfg(feature = "native")] pub mod jwt; +#[cfg(feature = "native")] pub mod jwt_middleware; pub mod metadata; +#[cfg(feature = "native")] pub mod oidc; +#[cfg(feature = "native")] pub mod resolver; pub mod shared_secret; -#[cfg(not(target_family = "windows"))] +#[cfg(all(feature = "native", not(target_family = "windows")))] pub mod spire; pub mod traits; +#[cfg(feature = "native")] pub mod utils; diff --git a/data-plane/core/auth/src/resolver.rs b/data-plane/core/auth/src/resolver.rs index af8880d43..8bac15251 100644 --- a/data-plane/core/auth/src/resolver.rs +++ b/data-plane/core/auth/src/resolver.rs @@ -34,11 +34,14 @@ pub struct JwksCache { /// /// Example usage: /// -/// ``` +/// ```rust,ignore +/// use slim_auth::resolver::KeyResolver; +/// use std::time::Duration; +/// /// let resolver = KeyResolver::new() /// .with_jwks_ttl(Duration::from_secs(1800)); // 30 minute cache TTL /// -/// let jwt = Jwt::builder() +/// let jwt = slim_auth::builder::JwtBuilder::new() /// .issuer("https://your-oidc-provider.com") /// .key_resolver(resolver) /// .build()?; diff --git a/data-plane/core/auth/src/shared_secret.rs b/data-plane/core/auth/src/shared_secret.rs index 2d6127bf7..56cce6c43 100644 --- a/data-plane/core/auth/src/shared_secret.rs +++ b/data-plane/core/auth/src/shared_secret.rs @@ -22,7 +22,7 @@ SPDX-License-Identifier: Apache-2.0 //! Design notes: //! * `id` is randomized per construction (`_`). //! * Replay cache stores only (nonce, timestamp) for memory efficiency. -//! * HMAC via `aws-lc-rs` for constant-time primitives. +//! * HMAC-SHA256: on native, `aws-lc-rs`; on wasm, the `hmac` + `sha2` crates. //! * `SharedSecret` is cheap to clone (Arc increment) and cloning preserves //! replay cache state (when enabled). //! @@ -51,7 +51,6 @@ SPDX-License-Identifier: Apache-2.0 //! * All other fields are immutable after construction. use async_trait::async_trait; -use aws_lc_rs::hmac; use base64::Engine; use base64::engine::general_purpose::STANDARD as STANDARD_BASE64; use base64::engine::general_purpose::URL_SAFE_NO_PAD; @@ -60,15 +59,39 @@ use rand::{Rng, distr::Alphanumeric}; use std::{ collections::{HashSet, VecDeque}, sync::Arc, - time::{SystemTime, UNIX_EPOCH}, }; +#[cfg(feature = "native")] +use std::time::{SystemTime, UNIX_EPOCH}; + +#[cfg(feature = "native")] +use aws_lc_rs::hmac as aws_hmac; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +use hmac::{Hmac, Mac}; +#[cfg(all(feature = "wasm", not(feature = "native")))] +use sha2::Sha256; + use crate::{ errors::AuthError, traits::{TokenProvider, Verifier}, - utils::generate_mls_signature_keys, }; +#[cfg(feature = "native")] +use crate::utils::generate_mls_signature_keys; + +/// WASM fallback: generate random bytes as MLS signature key material. +/// The keys are used as opaque blobs embedded in shared-secret tokens, +/// not for actual MLS crypto operations in this context. +#[cfg(all(feature = "wasm", not(feature = "native")))] +fn generate_mls_signature_keys() -> Result<(Vec, Vec), AuthError> { + let mut secret = vec![0u8; 32]; + let mut public = vec![0u8; 32]; + rand::Fill::fill(&mut secret[..], &mut rand::rng()); + rand::Fill::fill(&mut public[..], &mut rand::rng()); + Ok((secret, public)) +} + /// Minimum length (in bytes) required for the shared secret (baseline 256 bits). const MIN_SECRET_LEN: usize = 32; /// Raw nonce byte length before base64url encoding. @@ -365,6 +388,7 @@ impl SharedSecret { Ok(()) } + #[cfg(feature = "native")] fn get_current_timestamp(&self) -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -372,24 +396,52 @@ impl SharedSecret { .as_secs() } + #[cfg(all(feature = "wasm", not(feature = "native")))] + fn get_current_timestamp(&self) -> u64 { + (js_sys::Date::now() / 1000.0) as u64 + } + + #[cfg(feature = "native")] fn create_hmac_raw(&self, message: &[u8]) -> Result, AuthError> { - let key = hmac::Key::new(hmac::HMAC_SHA256, self.inner.shared_secret.as_bytes()); - let tag = hmac::sign(&key, message); + let key = aws_hmac::Key::new(aws_hmac::HMAC_SHA256, self.inner.shared_secret.as_bytes()); + let tag = aws_hmac::sign(&key, message); Ok(tag.as_ref().to_vec()) } + #[cfg(all(feature = "wasm", not(feature = "native")))] + fn create_hmac_raw(&self, message: &[u8]) -> Result, AuthError> { + let mut mac = Hmac::::new_from_slice(self.inner.shared_secret.as_bytes()) + .map_err(|_e| AuthError::HmacKeyMissing)?; + mac.update(message); + Ok(mac.finalize().into_bytes().to_vec()) + } + fn create_hmac_b64(&self, message: &str) -> Result { let raw = self.create_hmac_raw(message.as_bytes())?; Ok(URL_SAFE_NO_PAD.encode(raw)) } + #[cfg(feature = "native")] fn verify_hmac(&self, message: &str, expected_b64: &str) -> Result<(), AuthError> { let expected = URL_SAFE_NO_PAD.decode(expected_b64.as_bytes())?; if expected.len() != 32 { return Err(AuthError::TokenMalformed); } - let key = hmac::Key::new(hmac::HMAC_SHA256, self.inner.shared_secret.as_bytes()); - hmac::verify(&key, message.as_bytes(), &expected).map_err(|_e| AuthError::TokenInvalid) + let key = aws_hmac::Key::new(aws_hmac::HMAC_SHA256, self.inner.shared_secret.as_bytes()); + aws_hmac::verify(&key, message.as_bytes(), &expected).map_err(|_e| AuthError::TokenInvalid) + } + + #[cfg(all(feature = "wasm", not(feature = "native")))] + fn verify_hmac(&self, message: &str, expected_b64: &str) -> Result<(), AuthError> { + let expected = URL_SAFE_NO_PAD.decode(expected_b64.as_bytes())?; + if expected.len() != 32 { + return Err(AuthError::TokenMalformed); + } + let mut mac = Hmac::::new_from_slice(self.inner.shared_secret.as_bytes()) + .map_err(|_e| AuthError::HmacKeyMissing)?; + mac.update(message.as_bytes()); + mac.verify_slice(&expected) + .map_err(|_e| AuthError::TokenInvalid) } fn build_message(&self, id: &str, timestamp: u64, nonce: &str, claims_b64: &str) -> String { @@ -447,7 +499,8 @@ impl SharedSecret { } } -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl TokenProvider for SharedSecret { async fn initialize(&mut self) -> Result<(), AuthError> { // SharedSecret has no async initialization steps. @@ -491,9 +544,19 @@ impl TokenProvider for SharedSecret { self.signature_keys = generate_mls_signature_keys()?; Ok(()) } + + fn set_signature_keys( + &mut self, + private_key: Vec, + public_key: Vec, + ) -> Result<(), AuthError> { + self.signature_keys = (private_key, public_key); + Ok(()) + } } -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl Verifier for SharedSecret { async fn initialize(&mut self) -> Result<(), AuthError> { Ok(()) diff --git a/data-plane/core/auth/src/spire.rs b/data-plane/core/auth/src/spire.rs index 569c28e9a..6cc3d5cb4 100644 --- a/data-plane/core/auth/src/spire.rs +++ b/data-plane/core/auth/src/spire.rs @@ -25,14 +25,13 @@ //! //! Basic usage: //! ```rust,no_run -//! use slim_auth::spire::{SpireIdentityManager, SpireConfig}; +//! use slim_auth::spire::SpireIdentityManager; +//! use slim_auth::traits::{TokenProvider, Verifier}; //! //! # async fn example() -> Result<(), Box> { -//! let mut mgr = SpireIdentityManager::new(SpireConfig { -//! socket_path: None, // Use SPIFFE_ENDPOINT_SOCKET env var -//! target_spiffe_id: None, // Optional: specify a target for JWT SVID -//! jwt_audiences: vec!["my-app".into()], -//! }); +//! let mut mgr = SpireIdentityManager::builder() +//! .with_jwt_audiences(vec!["my-app".into()]) +//! .build(); //! mgr.initialize().await?; //! //! // Obtain JWT token diff --git a/data-plane/core/auth/src/traits.rs b/data-plane/core/auth/src/traits.rs index 584a465bd..4e097c67e 100644 --- a/data-plane/core/auth/src/traits.rs +++ b/data-plane/core/auth/src/traits.rs @@ -63,7 +63,8 @@ impl StandardClaims { } /// Trait for verifying JWT tokens -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] pub trait Verifier { /// Initializes the verifier asynchronously. async fn initialize(&mut self) -> Result<(), AuthError>; @@ -104,7 +105,8 @@ pub trait Signer { } /// Trait for providing JWT claims -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] pub trait TokenProvider { /// Initializes the token provider asynchronously. /// Usage notes: @@ -141,4 +143,15 @@ pub trait TokenProvider { fn rotate_signature_keys(&mut self) -> Result<(), AuthError> { Err(AuthError::MlsNotSupported) } + + /// Replace the MLS signature key pair with externally-generated keys. + /// Used by WASM builds where keys must be generated by the MLS crypto + /// provider (WebCrypto) rather than the identity provider. + fn set_signature_keys( + &mut self, + _private_key: Vec, + _public_key: Vec, + ) -> Result<(), AuthError> { + Err(AuthError::MlsNotSupported) + } } diff --git a/data-plane/core/auth/tests/spiffe_integration_test.rs b/data-plane/core/auth/tests/spiffe_integration_test.rs index ffcf8c7b3..a4eebfff0 100644 --- a/data-plane/core/auth/tests/spiffe_integration_test.rs +++ b/data-plane/core/auth/tests/spiffe_integration_test.rs @@ -545,7 +545,7 @@ async fn test_spiffe_grpc_jwt_auth() { client_cfg.tls_setting.insecure = true; client_cfg.auth = slim_config::grpc::client::AuthenticationConfig::Spire(spire_cfg.clone()); - let channel = match client_cfg.to_channel().await { + let channel = match client_cfg.to_grpc_channel().await { Ok(c) => c, Err(e) => { tracing::error!(error = %e, "Failed to create authenticated channel"); @@ -582,7 +582,7 @@ async fn test_spiffe_grpc_jwt_auth() { ); unauthenticated_cfg.tls_setting.insecure = true; - let unauth_channel = match unauthenticated_cfg.to_channel().await { + let unauth_channel = match unauthenticated_cfg.to_grpc_channel().await { Ok(c) => c, Err(e) => { tracing::error!(error = %e, "Failed to create unauthenticated channel"); diff --git a/data-plane/core/config/Cargo.toml b/data-plane/core/config/Cargo.toml index 3772f81d2..d166edaad 100644 --- a/data-plane/core/config/Cargo.toml +++ b/data-plane/core/config/Cargo.toml @@ -5,48 +5,98 @@ edition = { workspace = true } license = { workspace = true } description = "Configuration utilities" +[package.metadata.cargo-machete] +ignored = ["getrandom"] + [lib] name = "slim_config" [[bin]] name = "generate-schema" -path = "src/grpc/schema/generate_schema.rs" +path = "src/schema/generate_schema.rs" + +[features] +default = ["native"] +native = [ + "dep:agntcy-slim-auth", + "dep:bytes", + "dep:display-error-chain", + "dep:drain", + "dep:fastwebsockets", + "dep:futures", + "dep:http-body-util", + "dep:hyper", + "hyper-util/client-proxy-system", + "dep:hyper-rustls", + "dep:hyper-util", + "dep:parking_lot", + "dep:prost", + "dep:rustls", + "dep:rustls-native-certs", + "dep:rustls-pki-types", + "dep:tokio", + "tokio/io-util", + "tokio/macros", + "tokio/net", + "tokio/rt", + "tokio/rt-multi-thread", + "tokio/sync", + "tokio/time", + "dep:tokio-retry", + "dep:tokio-rustls", + "dep:tokio-stream", + "dep:tokio-util", + "dep:tonic", + "dep:tonic-prost", + "dep:tonic-tls", + "dep:tower-layer", + "dep:tower-service", +] +wasm = ["dep:getrandom", "dep:gloo-net", "uuid/js"] [dependencies] -agntcy-slim-auth = { workspace = true } +agntcy-slim-auth = { workspace = true, optional = true, features = ["native"] } agntcy-slim-version = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } -display-error-chain = { workspace = true } -drain = { workspace = true } +bytes = { workspace = true, optional = true } +display-error-chain = { workspace = true, optional = true } +drain = { workspace = true, optional = true } duration-string = { workspace = true } -futures = { workspace = true } +fastwebsockets = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } +gloo-net = { version = "0.6", optional = true } http = { workspace = true } -hyper-rustls = { workspace = true } -hyper-util = { workspace = true, features = ["client-proxy-system"] } +http-body-util = { workspace = true, optional = true } +hyper = { workspace = true, optional = true } +hyper-rustls = { workspace = true, optional = true } +hyper-util = { workspace = true, optional = true } lazy_static = { workspace = true } -parking_lot = { workspace = true } -prost = { workspace = true } +parking_lot = { workspace = true, optional = true } +prost = { workspace = true, optional = true } +rand = { workspace = true } regex = { workspace = true } -rustls = { workspace = true } -rustls-native-certs = { workspace = true } -rustls-pki-types = { workspace = true } +rustls = { workspace = true, optional = true } +rustls-native-certs = { workspace = true, optional = true } +rustls-pki-types = { workspace = true, optional = true } schemars = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yaml = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } -tokio-retry = { workspace = true } -tokio-stream = { workspace = true } -tokio-util = { workspace = true } -tonic = { workspace = true } -tonic-prost = { workspace = true } -tonic-tls = { workspace = true } +tokio = { workspace = true, optional = true } +tokio-retry = { workspace = true, optional = true } +tokio-rustls = { workspace = true, optional = true } +tokio-stream = { workspace = true, optional = true } +tokio-util = { workspace = true, optional = true } +tonic = { workspace = true, optional = true } +tonic-prost = { workspace = true, optional = true } +tonic-tls = { workspace = true, optional = true } tower = { workspace = true } tower-http = { workspace = true } -tower-layer = { workspace = true } -tower-service = { workspace = true } +tower-layer = { workspace = true, optional = true } +tower-service = { workspace = true, optional = true } tracing = { workspace = true } tracing-test = { workspace = true } uuid = { workspace = true, features = ["v4"] } @@ -57,7 +107,5 @@ tonic-prost-build = { workspace = true } [dev-dependencies] agntcy-slim-testing = { workspace = true } -rand = { workspace = true } serde_json = { workspace = true } tracing-test = { workspace = true } - diff --git a/data-plane/core/config/src/auth.rs b/data-plane/core/config/src/auth.rs index ea2c09ba5..5510fe01e 100644 --- a/data-plane/core/config/src/auth.rs +++ b/data-plane/core/config/src/auth.rs @@ -2,13 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 pub mod basic; +#[cfg(feature = "native")] pub mod identity; +#[cfg(feature = "native")] pub mod jwt; +#[cfg(feature = "native")] pub mod oidc; -#[cfg(not(target_family = "windows"))] +#[cfg(all(not(target_family = "windows"), feature = "native"))] pub mod spire; +#[cfg(feature = "native")] pub mod static_jwt; +#[cfg(feature = "native")] use slim_auth::errors::AuthError as SlimAuthError; use thiserror::Error; @@ -27,6 +32,7 @@ pub enum ConfigAuthError { AuthOidcEmptyClientSecret, // Propagated auth library errors + #[cfg(feature = "native")] #[error("internal auth error")] AuthInternalError(#[from] SlimAuthError), diff --git a/data-plane/core/config/src/backoff/exponential.rs b/data-plane/core/config/src/backoff/exponential.rs index eae2513de..f3fd12f2a 100644 --- a/data-plane/core/config/src/backoff/exponential.rs +++ b/data-plane/core/config/src/backoff/exponential.rs @@ -5,7 +5,6 @@ use duration_string::DurationString; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::time::Duration; -use tokio_retry::strategy::{ExponentialBackoff, jitter}; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] #[serde(default)] @@ -52,12 +51,33 @@ impl Default for Config { impl Strategy for Config { fn get_strategy(&self) -> Box + Send> { - let ret = ExponentialBackoff::from_millis(self.base) - .factor(self.factor) - .max_delay(self.max_delay.into()) - .take(self.max_attempts); - let jitter_flag = self.jitter; + let base = self.base; + let factor = self.factor; + let max_delay: Duration = self.max_delay.into(); + let max_attempts = self.max_attempts; + let jitter = self.jitter; - Box::new(ret.map(move |d| if jitter_flag { jitter(d) } else { d })) + Box::new((0..max_attempts).scan(base, move |current, _| { + let delay_ms = (*current).min(max_delay.as_millis() as u64); + let next = if factor == 0 { + 0 + } else { + current.saturating_mul(factor) + }; + *current = next.max(base); + + let delay = Duration::from_millis(delay_ms); + Some(if jitter { apply_jitter(delay) } else { delay }) + })) } } + +fn apply_jitter(delay: Duration) -> Duration { + if delay.is_zero() { + return delay; + } + + let cap = delay.as_millis() as u64; + let jitter: u64 = rand::random::() % (cap + 1); + Duration::from_millis(jitter) +} diff --git a/data-plane/core/config/src/backoff/fixedinterval.rs b/data-plane/core/config/src/backoff/fixedinterval.rs index 7f73836c6..34f9b3616 100644 --- a/data-plane/core/config/src/backoff/fixedinterval.rs +++ b/data-plane/core/config/src/backoff/fixedinterval.rs @@ -4,7 +4,6 @@ use duration_string::DurationString; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::time::Duration; -use tokio_retry::strategy::FixedInterval; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] #[serde(default)] @@ -35,6 +34,7 @@ impl Default for Config { impl Strategy for Config { fn get_strategy(&self) -> Box + Send> { - Box::new(FixedInterval::new(self.interval.into()).take(self.max_attempts)) + let interval: Duration = self.interval.into(); + Box::new(std::iter::repeat_n(interval, self.max_attempts)) } } diff --git a/data-plane/core/config/src/client.rs b/data-plane/core/config/src/client.rs new file mode 100644 index 000000000..a3c55f0a3 --- /dev/null +++ b/data-plane/core/config/src/client.rs @@ -0,0 +1,1683 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use duration_string::DurationString; +#[cfg(feature = "native")] +use rustls_pki_types::ServerName; +#[cfg(feature = "native")] +use tokio_retry::RetryIf; + +#[cfg(feature = "native")] +use display_error_chain::ErrorChainExt; +#[cfg(any(feature = "native", feature = "wasm"))] +use std::str::FromStr; +use std::{collections::HashMap, time::Duration}; +#[cfg(feature = "native")] +use tower::ServiceExt; +#[cfg(all(feature = "native", target_family = "unix"))] +use { + hyper_util::rt::TokioIo, + std::{error::Error as StdErrorTrait, path::PathBuf, sync::Arc}, + tokio::net::UnixStream, + tower::service_fn, +}; + +#[cfg(feature = "native")] +use base64::prelude::*; +#[cfg(feature = "native")] +use http::header::{HeaderMap, HeaderName, HeaderValue}; +#[cfg(feature = "native")] +use hyper_rustls; +#[cfg(feature = "native")] +use hyper_util::client::legacy::connect::HttpConnector; +#[cfg(feature = "native")] +use hyper_util::client::legacy::connect::proxy::Tunnel; +#[cfg(feature = "native")] +use hyper_util::client::proxy::matcher::Intercept; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "native")] +use tonic::codegen::{Body, Bytes, StdError}; +#[cfg(feature = "native")] +use tonic::transport::{Channel, Uri}; +#[cfg(feature = "native")] +use tracing::warn; + +#[cfg(feature = "native")] +use slim_auth::metadata::MetadataMap; +#[cfg(not(feature = "native"))] +type MetadataMap = HashMap; + +#[cfg(feature = "native")] +use crate::auth::ClientAuthenticator; +use crate::auth::basic::Config as BasicAuthenticationConfig; +#[cfg(feature = "native")] +use crate::auth::jwt::Config as JwtAuthenticationConfig; +#[cfg(all(feature = "native", not(target_family = "windows")))] +use crate::auth::spire::SpireConfig as SpireAuthConfig; +#[cfg(feature = "native")] +use crate::auth::static_jwt::Config as BearerAuthenticationConfig; +use crate::backoff::Strategy; +use crate::backoff::exponential::Config as ExponentialBackoff; +use crate::backoff::fixedinterval::Config as FixedIntervalBackoff; +use crate::component::configuration::Configuration; +use crate::grpc::compression::CompressionType; +use crate::grpc::errors::ConfigError; +#[cfg(feature = "native")] +use crate::grpc::headers_middleware::SetRequestHeaderLayer; +use crate::grpc::proxy::ProxyConfig; +use crate::tls::client::TlsClientConfig as TLSSetting; +#[cfg(feature = "native")] +use crate::tls::common::RustlsConfigLoader; +use crate::transport::TransportProtocol; +#[cfg(any(feature = "native", feature = "wasm"))] +use crate::websocket::client::WebSocketClientChannel; + +/// Creates an HTTPS connector with optional SNI based on the origin +#[cfg(feature = "native")] +fn https_connector( + s: S, + tls: &rustls::ClientConfig, + server_name: Option, +) -> hyper_rustls::HttpsConnector +where + S: tower::Service, +{ + let tls = tls.clone(); + let mut builder = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(tls) + .https_or_http(); + + if let Some(origin_str) = server_name { + builder = + builder.with_server_name_resolver(move |_: &_| ServerName::try_from(origin_str.clone())) + } + + builder.enable_http2().wrap_connector(s) +} + +/// Macro to create TLS-enabled or plain connectors based on TLS configuration, +/// applying the optional origin (for SNI) when TLS is enabled. +/// Supports both lazy and eager connection modes. +#[cfg(feature = "native")] +macro_rules! create_connector { + ($builder:expr, $base_connector:expr, $tls_config:expr, $server_name:expr, $lazy:expr) => { + match ($tls_config, $lazy) { + (Some(tls), true) => { + let connector = tower::ServiceBuilder::new() + .layer_fn(move |s| { + https_connector(s, &tls, $server_name.map(|s| s.to_string())) + }) + .service($base_connector); + Ok($builder.connect_with_connector_lazy(connector)) + } + (Some(tls), false) => { + let connector = tower::ServiceBuilder::new() + .layer_fn(move |s| { + https_connector(s, &tls, $server_name.map(|s| s.to_string())) + }) + .service($base_connector); + let ret = $builder.connect_with_connector(connector).await?; + Ok(ret) + } + (None, true) => Ok($builder.connect_with_connector_lazy($base_connector)), + (None, false) => { + let ret = $builder.connect_with_connector($base_connector).await?; + Ok(ret) + } + } + }; +} + +/// Macro to create authenticated service layers for auth types that don't need initialization +#[cfg(feature = "native")] +macro_rules! create_auth_service_no_init { + ($self:expr, $auth_config:expr, $header_map:expr, $channel:expr) => {{ + let auth_layer = $auth_config.get_client_layer()?; + + $self.warn_insecure_auth(); + + Ok(tower::ServiceBuilder::new() + .layer(SetRequestHeaderLayer::new($header_map)) + .layer(auth_layer) + .service($channel) + .boxed_clone()) + }}; +} + +/// Macro to create authenticated service layers for auth types that need initialization +#[cfg(feature = "native")] +macro_rules! create_auth_service_with_init { + ($self:expr, $auth_config:expr, $header_map:expr, $channel:expr) => {{ + let mut auth_layer = $auth_config.get_client_layer()?; + + // Initialize the auth layer + auth_layer.initialize().await?; + + $self.warn_insecure_auth(); + + Ok(tower::ServiceBuilder::new() + .layer(SetRequestHeaderLayer::new($header_map)) + .layer(auth_layer) + .service($channel) + .boxed_clone()) + }}; +} + +/// Enum to handle all connection types: direct connections and proxy tunnels +#[cfg(feature = "native")] +enum ConnectionType { + /// Direct HTTP connection without proxy + Direct(HttpConnector), + /// HTTP proxy tunnel connection + ProxyHttp(Tunnel), + /// HTTPS proxy tunnel connection + ProxyHttps(Tunnel>), +} + +#[cfg(feature = "native")] +pub enum TransportChannel { + #[cfg(feature = "native")] + Grpc(G), + #[cfg(any(feature = "native", feature = "wasm"))] + Websocket(Box), +} + +#[cfg(not(feature = "native"))] +pub enum TransportChannel { + #[cfg(any(feature = "native", feature = "wasm"))] + Websocket(Box), +} + +/// Keepalive configuration for the client. +/// This struct contains the keepalive time for TCP and HTTP2, +/// the timeout duration for the keepalive, and whether to permit +/// keepalive without an active stream. +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] +pub struct KeepaliveConfig { + /// The duration of the keepalive time for TCP + #[serde(default = "default_tcp_keepalive")] + #[schemars(with = "String")] + pub tcp_keepalive: DurationString, + + /// The duration of the keepalive time for HTTP2 + #[serde(default = "default_http2_keepalive")] + #[schemars(with = "String")] + pub http2_keepalive: DurationString, + + /// The timeout duration for the keepalive + #[serde(default = "default_timeout")] + #[schemars(with = "String")] + pub timeout: DurationString, + + /// Whether to permit keepalive without an active stream + #[serde(default = "default_keep_alive_while_idle")] + pub keep_alive_while_idle: bool, +} + +/// Defaults for KeepaliveConfig +impl Default for KeepaliveConfig { + fn default() -> Self { + KeepaliveConfig { + tcp_keepalive: default_tcp_keepalive(), + http2_keepalive: default_http2_keepalive(), + timeout: default_timeout(), + keep_alive_while_idle: default_keep_alive_while_idle(), + } + } +} + +fn default_tcp_keepalive() -> DurationString { + Duration::from_secs(60).into() +} + +fn default_http2_keepalive() -> DurationString { + Duration::from_secs(60).into() +} + +fn default_timeout() -> DurationString { + Duration::from_secs(10).into() +} + +fn default_keep_alive_while_idle() -> bool { + false +} + +/// Enum holding one configuration for the client. +#[derive(Debug, Serialize, Default, Deserialize, Clone, PartialEq, JsonSchema)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum AuthenticationConfig { + /// Basic authentication configuration. + Basic(BasicAuthenticationConfig), + /// Bearer authentication configuration. + #[cfg(feature = "native")] + StaticJwt(BearerAuthenticationConfig), + /// JWT authentication configuration. + #[cfg(feature = "native")] + Jwt(JwtAuthenticationConfig), + /// SPIRE/SPIFFE authentication configuration. + #[cfg(all(feature = "native", not(target_family = "windows")))] + Spire(SpireAuthConfig), + /// None + #[default] + None, +} + +/// Enum holding one configuration for the client. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum BackoffConfig { + // Exponential backoff retry config. + Exponential(ExponentialBackoff), + /// FixedInterval backoff retry config. + FixedInterval(FixedIntervalBackoff), +} + +impl BackoffConfig { + /// Creates a new Exponential backoff configuration + pub fn new_exponential( + base: u64, + factor: u64, + max_delay: Duration, + max_attempts: usize, + jitter: bool, + ) -> Self { + BackoffConfig::Exponential(ExponentialBackoff::new( + base, + factor, + max_delay, + max_attempts, + jitter, + )) + } + + /// Creates a new FixedInterval backoff configuration + pub fn new_fixed_interval(interval: Duration, max_attempts: usize) -> Self { + BackoffConfig::FixedInterval(FixedIntervalBackoff::new(interval, max_attempts)) + } +} + +impl Default for BackoffConfig { + fn default() -> Self { + BackoffConfig::Exponential(ExponentialBackoff::default()) + } +} + +impl Strategy for BackoffConfig { + fn get_strategy(&self) -> Box + Send> { + match self { + BackoffConfig::Exponential(b) => b.get_strategy(), + BackoffConfig::FixedInterval(b) => b.get_strategy(), + } + } +} + +/// Struct for the client configuration. +/// This struct contains the endpoint, origin, compression type, rate limit, +/// TLS settings, keepalive settings, proxy settings, timeout settings, buffer size settings, +/// headers, and auth settings. +/// The client configuration can be converted to a tonic channel. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] +pub struct ClientConfig { + /// The target the client will connect to. + pub endpoint: String, + + /// Transport protocol to use for dataplane communication. + #[serde(default)] + pub transport: TransportProtocol, + + /// Optional websocket authentication query parameter key. + /// This is only used when `transport=websocket`. + pub websocket_auth_query_param: Option, + + /// Origin (HTTP Host authority override) for the client. + pub origin: Option, + + /// Optional TLS SNI server name override. If set, this value is used for TLS + /// server name verification (SNI) instead of the host extracted from endpoint/origin. + pub server_name: Option, + + /// Compression type - TODO(msardara): not implemented yet. + pub compression: Option, + + /// Rate Limits + pub rate_limit: Option, + + /// TLS client configuration. + #[serde(default, rename = "tls")] + pub tls_setting: TLSSetting, + + /// Keepalive parameters. + pub keepalive: Option, + + /// HTTP Proxy configuration. + #[serde(default)] + pub proxy: ProxyConfig, + + /// Timeout for the connection. + #[serde(default = "default_connect_timeout")] + #[schemars(with = "String")] + pub connect_timeout: DurationString, + + /// Timeout per request. + #[serde(default = "default_request_timeout")] + #[schemars(with = "String")] + pub request_timeout: DurationString, + + /// ReadBufferSize. + pub buffer_size: Option, + + /// The headers associated with gRPC requests. + #[serde(default)] + pub headers: HashMap, + + /// Auth configuration for outgoing RPCs. + #[serde(default)] + pub auth: AuthenticationConfig, + + /// Backoff retry configuration. + #[serde(default)] + pub backoff: BackoffConfig, + + /// Arbitrary user-provided metadata. + pub metadata: Option, + + /// Link identifier for this connection, used during link negotiation. + /// Must be a valid UUID v4. Defaults to a randomly generated UUID v4. + #[serde(default = "default_link_id")] + pub link_id: String, +} + +/// Defaults for ClientConfig +impl Default for ClientConfig { + fn default() -> Self { + ClientConfig { + endpoint: String::new(), + transport: TransportProtocol::default(), + websocket_auth_query_param: None, + origin: None, + server_name: None, + compression: None, + rate_limit: None, + tls_setting: TLSSetting::default(), + keepalive: None, + proxy: ProxyConfig::default(), + connect_timeout: default_connect_timeout(), + request_timeout: default_request_timeout(), + buffer_size: None, + headers: HashMap::new(), + auth: AuthenticationConfig::None, + backoff: BackoffConfig::default(), + metadata: None, + link_id: default_link_id(), + } + } +} + +fn default_link_id() -> String { + uuid::Uuid::new_v4().to_string() +} + +fn default_connect_timeout() -> DurationString { + Duration::from_secs(0).into() +} + +fn default_request_timeout() -> DurationString { + Duration::from_secs(0).into() +} + +// Display for ClientConfig +impl std::fmt::Display for ClientConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ClientConfig {{ endpoint: {}, transport: {:?}, websocket_auth_query_param: {:?}, origin: {:?}, server_name: {:?}, compression: {:?}, rate_limit: {:?}, tls_setting: {:?}, keepalive: {:?}, proxy: {:?}, connect_timeout: {:?}, request_timeout: {:?}, buffer_size: {:?}, headers: {:?}, auth: {:?}, backoff: {:?}, metadata: {:?}, link_id: {:?} }}", + self.endpoint, + self.transport, + self.websocket_auth_query_param, + self.origin, + self.server_name, + self.compression, + self.rate_limit, + self.tls_setting, + self.keepalive, + self.proxy, + self.connect_timeout, + self.request_timeout, + self.buffer_size, + self.headers, + self.auth, + self.backoff, + self.metadata, + self.link_id + ) + } +} + +pub fn is_valid_uuid_v4(s: &str) -> bool { + match uuid::Uuid::parse_str(s) { + Ok(id) => id.get_version() == Some(uuid::Version::Random), + Err(_) => false, + } +} + +impl Configuration for ClientConfig { + type Error = ConfigError; + + fn validate(&self) -> Result<(), Self::Error> { + if self.endpoint.is_empty() { + return Err(ConfigError::MissingEndpoint); + } + + // Validate link_id is a UUID v4 + if !is_valid_uuid_v4(&self.link_id) { + return Err(ConfigError::InvalidLinkId); + } + + // Validate the client configuration + self.tls_setting.validate()?; + self.validate_websocket_endpoint()?; + + Ok(()) + } +} + +/// Implements configuration methods for creating and connecting gRPC/WebSocket clients. +/// +/// This impl block provides a builder pattern API for `ClientConfig`, allowing fluent +/// configuration of client connections. It handles multiple transport protocols (gRPC over +/// TCP/Unix sockets, WebSocket), TLS/security, proxies, authentication, and advanced +/// networking options like keepalive, rate limiting, and custom headers. +/// +/// # Builder Methods +/// +/// Methods prefixed with `with_` follow the builder pattern: +/// - `with_endpoint()`: Sets the target server endpoint (required) +/// - `with_transport()`: Selects gRPC or WebSocket protocol +/// - `with_origin()`, `with_server_name()`: Sets HTTP/TLS metadata +/// - `with_tls_setting()`, `with_auth()`: Configures security +/// - `with_keepalive()`, `with_connect_timeout()`: Network tuning +/// - `with_proxy()`, `with_rate_limit()`: Advanced networking options +/// +/// # Channel Creation +/// +/// - `to_channel()`: Main entry point for creating a transport channel +/// - `to_channel_internal()`: Internal implementation supporting lazy connection +/// - Methods like `connect_tcp_channel()`, `connect_unix_channel()`: Protocol-specific connections +/// - Automatic retry with exponential backoff on transport errors +/// +/// # Connection Handling +/// +/// Supports direct connections, HTTP proxies, and HTTPS proxies with configurable: +/// - TLS/mTLS for both client-to-server and client-to-proxy paths +/// - Proxy authentication and custom proxy headers +/// - Unix domain socket connections (Unix-only systems) +/// - Lazy vs eager channel initialization +/// +/// # Validation & Parsing +/// +/// - `validate_endpoint()`: Ensures endpoint is set and non-empty +/// - `parse_endpoint_uri()`: Converts endpoint strings to URIs (handles special unix:// scheme) +/// - `parse_headers()`: Converts custom headers to HTTP HeaderMap +/// +/// # Internal Helpers +/// +/// - `create_channel_builder()`: Applies all channel-level settings (buffers, keepalive, rate limits) +/// - `create_http_connector()`: Configures TCP/TLS connector with timeouts +/// - `apply_tunnel_config()`: Sets up proxy authentication and headers +/// - `apply_auth_and_headers()`: Wraps channel with authentication layers +impl ClientConfig { + /// Creates a new client configuration with the given endpoint. + /// This function will return a ClientConfig with the endpoint set + /// and all other fields set to default. + pub fn with_endpoint(endpoint: &str) -> Self { + Self { + endpoint: endpoint.to_string(), + ..Self::default() + } + } + + pub fn with_origin(self, origin: &str) -> Self { + Self { + origin: Some(origin.to_string()), + ..self + } + } + + pub fn with_transport(self, transport: TransportProtocol) -> Self { + Self { transport, ..self } + } + + pub fn with_websocket_auth_query_param(self, query_param: &str) -> Self { + Self { + websocket_auth_query_param: Some(query_param.to_string()), + ..self + } + } + + pub fn with_server_name(self, server_name: &str) -> Self { + Self { + server_name: Some(server_name.to_string()), + ..self + } + } + + pub fn with_compression(self, compression: CompressionType) -> Self { + Self { + compression: Some(compression), + ..self + } + } + + pub fn with_rate_limit(self, rate_limit: &str) -> Self { + Self { + rate_limit: Some(rate_limit.to_string()), + ..self + } + } + + pub fn with_tls_setting(self, tls_setting: TLSSetting) -> Self { + Self { + tls_setting, + ..self + } + } + + pub fn with_keepalive(self, keepalive: KeepaliveConfig) -> Self { + Self { + keepalive: Some(keepalive), + ..self + } + } + + pub fn with_proxy(self, proxy: ProxyConfig) -> Self { + Self { proxy, ..self } + } + + pub fn with_connect_timeout(self, connect_timeout: Duration) -> Self { + Self { + connect_timeout: connect_timeout.into(), + ..self + } + } + + pub fn with_request_timeout(self, request_timeout: Duration) -> Self { + Self { + request_timeout: request_timeout.into(), + ..self + } + } + + pub fn with_buffer_size(self, buffer_size: usize) -> Self { + Self { + buffer_size: Some(buffer_size), + ..self + } + } + + pub fn with_headers(self, headers: HashMap) -> Self { + Self { headers, ..self } + } + + pub fn with_auth(self, auth: AuthenticationConfig) -> Self { + Self { auth, ..self } + } + + pub fn with_backoff(self, backoff: BackoffConfig) -> Self { + Self { backoff, ..self } + } + + pub fn with_metadata(self, metadata: MetadataMap) -> Self { + Self { + metadata: Some(metadata), + ..self + } + } + + /// Converts the client configuration to a transport channel. + /// Returns either a gRPC channel or websocket channel based on `transport`. + #[cfg(feature = "native")] + pub async fn to_channel( + &self, + ) -> Result< + TransportChannel< + impl tonic::client::GrpcService< + tonic::body::Body, + Error: Into + Send, + ResponseBody: Body + std::marker::Send> + + Send + + 'static, + Future: Send, + > + + Send + + Clone + + 'static + + use<>, + >, + ConfigError, + > { + match self.transport { + TransportProtocol::Grpc => Ok(TransportChannel::Grpc(self.to_grpc_channel().await?)), + TransportProtocol::Websocket => { + #[cfg(any(feature = "native", feature = "wasm"))] + { + return Ok(TransportChannel::Websocket(Box::new( + self.to_websocket_channel().await?, + ))); + } + + #[cfg(not(any(feature = "native", feature = "wasm")))] + { + return Err(ConfigError::WebSocketFeatureDisabled); + } + } + } + } + + /// Converts the client configuration to a transport channel. + /// When `grpc` feature is disabled only websocket transport is available. + #[cfg(not(feature = "native"))] + pub async fn to_channel(&self) -> Result { + match self.transport { + TransportProtocol::Grpc => Err(ConfigError::GrpcFeatureDisabled), + TransportProtocol::Websocket => { + #[cfg(any(feature = "native", feature = "wasm"))] + { + return Ok(TransportChannel::Websocket(Box::new( + self.to_websocket_channel().await?, + ))); + } + + #[cfg(not(any(feature = "native", feature = "wasm")))] + { + return Err(ConfigError::WebSocketFeatureDisabled); + } + } + } + } + + /// Internal implementation for channel creation with optional lazy flag. + #[cfg(feature = "native")] + pub(crate) async fn to_channel_internal( + &self, + lazy: bool, + ) -> Result< + impl tonic::client::GrpcService< + tonic::body::Body, + Error: Into + Send, + ResponseBody: Body + std::marker::Send> + + Send + + 'static, + Future: Send, + > + + Send + + Clone + + 'static + + use<>, + ConfigError, + > { + if self.transport == TransportProtocol::Websocket { + return Err(ConfigError::GrpcChannelUnsupportedTransport); + } + + // Validate endpoint + self.validate_endpoint()?; + + // Parse headers + let header_map = self.parse_headers()?; + + let uri = self.parse_endpoint_uri()?; + + let channel = if uri.scheme_str() == Some("unix") { + self.connect_unix_channel(uri, lazy).await? + } else if uri.scheme_str() == Some("http") || uri.scheme_str() == Some("https") { + self.connect_tcp_channel(uri, lazy).await? + } else { + return Err(ConfigError::InvalidEndpointScheme); + }; + + // Apply authentication and headers + self.apply_auth_and_headers(channel, header_map).await + } + + /// Validates that the endpoint is set and not empty + #[cfg(feature = "native")] + fn validate_endpoint(&self) -> Result<(), ConfigError> { + if self.endpoint.is_empty() { + return Err(ConfigError::MissingEndpoint); + } + Ok(()) + } + + fn validate_websocket_endpoint(&self) -> Result<(), ConfigError> { + if self.transport != TransportProtocol::Websocket { + return Ok(()); + } + + #[cfg(any(feature = "native", feature = "wasm"))] + { + let endpoint = http::Uri::from_str(self.endpoint.as_str())?; + match endpoint.scheme_str() { + Some("ws") | Some("wss") => Ok(()), + _ => Err(ConfigError::InvalidWebSocketEndpointScheme), + } + } + + #[cfg(not(any(feature = "native", feature = "wasm")))] + { + Err(ConfigError::WebSocketFeatureDisabled) + } + } + + /// Parses the endpoint string into a URI for TCP/HTTP, Unix domain socket endpoints. + #[cfg(feature = "native")] + fn parse_endpoint_uri(&self) -> Result { + // Special case for the unix scheme because it doesn't have an + // authority in the URI and the Uri parser doesn't like this today, + // so we build our own URI with a fake localhost authority. + if self.endpoint.starts_with("unix://") { + let path = &self.endpoint[7..]; + if path.is_empty() { + return Err(ConfigError::UnixSocketMissingPath); + } + + let uri = Uri::builder() + .scheme("unix") + .authority("localhost") + .path_and_query(path) + .build() + .map_err(ConfigError::UnixSocketInvalidPath)?; + return Ok(uri); + } + Ok(Uri::from_str(&self.endpoint)?) + } + + /// Creates and configures the HTTP connector + #[cfg(feature = "native")] + fn create_http_connector(&self) -> Result { + let mut http = HttpConnector::new(); + + // NOTE(msardara): we might want to make these configurable as well. + http.enforce_http(false); + http.set_nodelay(false); + + // set the connection timeout + match self.connect_timeout.as_secs() { + 0 => http.set_connect_timeout(None), + _ => http.set_connect_timeout(Some(self.connect_timeout.into())), + } + + // set keepalive settings + if let Some(keepalive) = &self.keepalive { + http.set_keepalive(Some(keepalive.tcp_keepalive.into())); + } + + Ok(http) + } + + /// Creates the channel builder with all configuration settings + #[cfg(feature = "native")] + fn create_channel_builder(&self, uri: Uri) -> Result { + let mut builder = Channel::builder(uri); + + // set the buffer size + if let Some(size) = self.buffer_size { + builder = builder.buffer_size(size); + } + + // set keepalive settings + if let Some(keepalive) = &self.keepalive { + builder = builder + .keep_alive_timeout(keepalive.timeout.into()) + .keep_alive_while_idle(keepalive.keep_alive_while_idle) + // HTTP level keepalive + .http2_keep_alive_interval(keepalive.http2_keepalive.into()); + } + + // set origin settings + if let Some(origin) = &self.origin { + let origin_uri = Uri::from_str(origin.as_str())?; + builder = builder.origin(origin_uri); + } + + // set rate limit settings + if let Some(rate_limit) = &self.rate_limit { + let (limit, duration) = parse_rate_limit(rate_limit)?; + builder = builder.rate_limit(limit, duration); + } + + // set the request timeout + if self.request_timeout.as_secs() > 0 { + builder = builder.timeout(self.request_timeout.into()); + } + + if self.connect_timeout.as_secs() > 0 { + builder = builder.connect_timeout(self.connect_timeout.into()); + } + + Ok(builder) + } + + /// Parses headers from the configuration + #[cfg(feature = "native")] + fn parse_headers(&self) -> Result { + Self::parse_header_map(&self.headers) + } + + /// Generic helper to parse a HashMap into HeaderMap + #[cfg(feature = "native")] + fn parse_header_map(headers: &HashMap) -> Result { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + let header_name = HeaderName::from_str(key)?; + let header_value = HeaderValue::from_str(value)?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } + + #[cfg(all(feature = "native", target_family = "unix"))] + fn map_transport_error(err: tonic::transport::Error) -> ConfigError { + #[cfg(target_family = "unix")] + { + let mut source: Option<&(dyn StdErrorTrait + 'static)> = Some(&err); + while let Some(err_ref) = source { + if let Some(io_err) = err_ref.downcast_ref::() { + let cloned = std::io::Error::new(io_err.kind(), io_err.to_string()); + return ConfigError::UnixSocketConnect(cloned); + } + source = err_ref.source(); + } + } + + ConfigError::from(err) + } + + /// Helper to create basic auth header for proxy authentication + #[cfg(feature = "native")] + fn create_proxy_auth_header( + username: &str, + password: &str, + ) -> Result { + let auth_value = BASE64_STANDARD.encode(format!("{}:{}", username, password)); + Ok(HeaderValue::from_str(&format!("Basic {}", auth_value))?) + } + + /// Helper to apply authentication and headers to a tunnel + #[cfg(feature = "native")] + fn apply_tunnel_config( + &self, + mut tunnel: Tunnel, + proxy_config: &ProxyConfig, + warn_insecure: bool, + ) -> Result, ConfigError> { + // Set proxy authentication if provided + if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) { + if warn_insecure { + self.warn_insecure_auth(); + } + + let auth_header = Self::create_proxy_auth_header(username, password)?; + tunnel = tunnel.with_auth(auth_header); + } + + // Set custom headers for proxy requests + if !proxy_config.headers.is_empty() { + let proxy_headers = self.parse_proxy_headers(&proxy_config.headers)?; + tunnel = tunnel.with_headers(proxy_headers); + } + + Ok(tunnel) + } + + /// Loads TLS configuration + #[cfg(feature = "native")] + async fn load_tls_config(&self) -> Result, ConfigError> { + let tls = self.tls_setting.load_rustls_config().await?; + Ok(tls) + } + + #[cfg(all(feature = "native", target_family = "unix"))] + async fn connect_unix_channel(&self, uri: Uri, lazy: bool) -> Result { + if !self.tls_setting.insecure { + // TLS handshakes are unnecessary over local UDS and currently unsupported + return Err(ConfigError::UnixSocketTlsUnsupported); + } + + let path = uri.path(); + let socket_path = Arc::new(PathBuf::from(path)); + let builder = self.create_channel_builder(uri)?; + + let make_connector = || { + let path = socket_path.clone(); + service_fn(move |_uri: Uri| { + let path = path.clone(); + async move { UnixStream::connect(path.as_path()).await.map(TokioIo::new) } + }) + }; + + if lazy { + Ok(builder.connect_with_connector_lazy(make_connector())) + } else { + let backoff_strategy = self.backoff.get_strategy(); + RetryIf::spawn( + backoff_strategy, + || { + let builder = builder.clone(); + let connector = make_connector(); + let path = socket_path.clone(); + async move { + tracing::debug!( + socket_path = %path.display(), + "Attempting to create gRPC channel over Unix domain socket" + ); + builder + .connect_with_connector(connector) + .await + .map_err(Self::map_transport_error) + } + }, + |e: &ConfigError| match e { + ConfigError::TransportError(err) => { + tracing::warn!(error = %err.chain(), "Transport error encountered. Retrying..."); + true + } + ConfigError::UnixSocketConnect(err) => { + tracing::warn!(error = %err, "Unix socket connect error encountered. Retrying..."); + true + } + _ => { + tracing::error!(error = %e.chain(), "non-retryable error encountered"); + false + } + }, + ) + .await + } + } + + #[cfg(all(feature = "native", not(target_family = "unix")))] + async fn connect_unix_channel(&self, _uri: Uri, _lazy: bool) -> Result { + Err(ConfigError::UnixSocketUnsupported) + } + + #[cfg(feature = "native")] + async fn connect_tcp_channel(&self, uri: Uri, lazy: bool) -> Result { + let http_connector = self.create_http_connector()?; + let builder = self.create_channel_builder(uri.clone())?; + let tls_config = self.load_tls_config().await?; + + if lazy { + let connection = self.create_connection(uri, http_connector).await?; + self.create_channel_from_connection(builder, connection, tls_config, true) + .await + } else { + let backoff_strategy = self.backoff.get_strategy(); + RetryIf::spawn( + backoff_strategy, + || { + let uri = uri.clone(); + let builder = builder.clone(); + let http_connector = http_connector.clone(); + let tls_config = tls_config.clone(); + async move { + tracing::debug!(%uri, "Attempting to create gRPC channel"); + self.create_channel_with_connector(uri, builder, http_connector, tls_config) + .await + } + }, + |e: &ConfigError| { + match e { + ConfigError::TransportError(err) => { + tracing::warn!(error = %err.chain(), "Transport error encountered. Retrying..."); + true + } + _ => { + tracing::error!(error = %e.chain(), "non-retryable error encountered"); + false + } + } + }, + ) + .await + } + } + + /// Creates the channel with the appropriate connector (proxy or direct) + /// Creates a channel with the provided connector and TLS configuration. + #[cfg(feature = "native")] + async fn create_channel_with_connector( + &self, + uri: Uri, + builder: tonic::transport::Endpoint, + http_connector: HttpConnector, + tls_config: Option, + ) -> Result { + let connection = self.create_connection(uri, http_connector).await?; + self.create_channel_from_connection(builder, connection, tls_config, false) + .await + } + + /// Creates the appropriate connection type based on proxy configuration + #[cfg(feature = "native")] + async fn create_connection( + &self, + uri: Uri, + http_connector: HttpConnector, + ) -> Result { + // Check if this host should bypass the proxy + if let Some(intercept) = self.proxy.should_use_proxy(uri.to_string()) { + // Use proxy for this host + self.create_proxy_connection(intercept, http_connector) + .await + } else { + // Skip proxy for this host, use direct connection + Ok(ConnectionType::Direct(http_connector)) + } + } + + /// Creates a proxy connection + #[cfg(feature = "native")] + async fn create_proxy_connection( + &self, + intercept: Intercept, + http_connector: HttpConnector, + ) -> Result { + let proxy_uri = intercept.uri(); + + tracing::info!(%proxy_uri, "Creating proxy tunnel"); + + // Check if the proxy URL uses HTTPS + if proxy_uri.scheme_str() == Some("https") { + let proxy_tls_config = self.proxy.tls_setting.load_rustls_config().await?.unwrap(); + + // Create HTTPS connector for the proxy itself + let https_connector = hyper_rustls::HttpsConnectorBuilder::new() + .with_tls_config(proxy_tls_config) + .https_or_http() + .enable_http2() + .wrap_connector(http_connector); + + let tunnel = Tunnel::new(proxy_uri.clone(), https_connector); + let configured_tunnel = self.apply_tunnel_config(tunnel, &self.proxy, false)?; + + Ok(ConnectionType::ProxyHttps(configured_tunnel)) + } else { + // Use HTTP connector for the proxy + let tunnel = Tunnel::new(proxy_uri.clone(), http_connector); + let configured_tunnel = self.apply_tunnel_config(tunnel, &self.proxy, true)?; + + Ok(ConnectionType::ProxyHttp(configured_tunnel)) + } + } + + /// Creates a channel from any connection type with TLS support + #[cfg(feature = "native")] + async fn create_channel_from_connection( + &self, + builder: tonic::transport::Endpoint, + connection: ConnectionType, + tls_config: Option, + lazy: bool, + ) -> Result { + match connection { + ConnectionType::Direct(connector) => { + create_connector!( + builder, + connector, + tls_config, + self.server_name.as_deref(), + lazy + ) + } + ConnectionType::ProxyHttp(tunnel) => { + create_connector!( + builder, + tunnel, + tls_config, + self.server_name.as_deref(), + lazy + ) + } + ConnectionType::ProxyHttps(tunnel) => { + create_connector!( + builder, + tunnel, + tls_config, + self.server_name.as_deref(), + lazy + ) + } + } + } + + /// Parses proxy headers + #[cfg(feature = "native")] + fn parse_proxy_headers( + &self, + headers: &HashMap, + ) -> Result { + Self::parse_header_map(headers) + } + + /// Applies authentication and headers to the channel + #[cfg(feature = "native")] + async fn apply_auth_and_headers( + &self, + channel: Channel, + header_map: HeaderMap, + ) -> Result< + impl tonic::client::GrpcService< + tonic::body::Body, + Error: Into + Send, + ResponseBody: Body + std::marker::Send> + + Send + + 'static, + Future: Send, + > + + Send + + Clone + + 'static + + use<>, + ConfigError, + > { + match &self.auth { + AuthenticationConfig::Basic(basic) => { + create_auth_service_no_init!(self, basic, header_map, channel) + } + AuthenticationConfig::StaticJwt(jwt) => { + create_auth_service_with_init!(self, jwt, header_map, channel) + } + AuthenticationConfig::Jwt(jwt) => { + create_auth_service_with_init!(self, jwt, header_map, channel) + } + #[cfg(all(feature = "native", not(target_family = "windows")))] + AuthenticationConfig::Spire(spire) => { + create_auth_service_with_init!(self, spire, header_map, channel) + } + AuthenticationConfig::None => Ok(tower::ServiceBuilder::new() + .layer(SetRequestHeaderLayer::new(header_map)) + .service(channel) + .boxed_clone()), + } + } + + /// Warns if authentication is enabled without TLS + #[cfg(feature = "native")] + fn warn_insecure_auth(&self) { + if self.tls_setting.insecure { + warn!("Auth is enabled without TLS. This is not recommended."); + } + } +} + +#[cfg(test)] +mod metadata_tests { + use super::*; + + #[test] + fn client_config_with_metadata_roundtrip_json() { + let mut md = MetadataMap::default(); + md.insert("feature", "alpha"); + md.insert("level", 2u64); + + let cfg = ClientConfig::with_endpoint("http://localhost:1234").with_metadata(md.clone()); + let s = serde_json::to_string(&cfg).expect("serialize"); + let deser: ClientConfig = serde_json::from_str(&s).expect("deserialize"); + assert_eq!(deser.metadata, Some(md)); + } +} + +/// Parse the rate limit string into a limit and a duration. +/// The rate limit string should be in the format of /, +/// with duration expressed in seconds. +/// This function will return a Result with the limit and duration if the +/// rate limit is valid. +#[cfg(feature = "native")] +fn parse_rate_limit(rate_limit: &str) -> Result<(u64, Duration), ConfigError> { + let parts: Vec<&str> = rate_limit.split('/').collect(); + + if parts.len() != 2 { + // Invalid format: expected / + return Err(ConfigError::Unknown); + } + + let limit = parts[0].parse::()?; + let duration = Duration::from_secs(parts[1].parse::()?); + + Ok((limit, duration)) +} + +#[cfg(all(test, feature = "native"))] +mod test { + #[allow(unused_imports)] + use super::*; + use crate::tls::common::CaSource; + use hyper_util::rt::TokioIo; + use tower::service_fn; + use tracing_test::traced_test; + + #[test] + fn test_default_keepalive_config() { + let keepalive = KeepaliveConfig::default(); + assert_eq!(keepalive.tcp_keepalive, Duration::from_secs(60)); + assert_eq!(keepalive.http2_keepalive, Duration::from_secs(60)); + assert_eq!(keepalive.timeout, Duration::from_secs(10)); + assert!(!keepalive.keep_alive_while_idle); + } + + #[test] + fn test_default_client_config() { + let client = ClientConfig::default(); + assert_eq!(client.endpoint, String::new()); + assert_eq!(client.transport, TransportProtocol::Grpc); + assert_eq!(client.websocket_auth_query_param, None); + assert_eq!(client.origin, None); + assert_eq!(client.compression, None); + assert_eq!(client.rate_limit, None); + assert_eq!(client.tls_setting, TLSSetting::default()); + assert_eq!(client.keepalive, None); + assert_eq!(client.connect_timeout, Duration::from_secs(0)); + assert_eq!(client.request_timeout, Duration::from_secs(0)); + assert_eq!(client.buffer_size, None); + assert_eq!(client.headers, HashMap::new()); + assert_eq!(client.auth, AuthenticationConfig::None); + } + + #[test] + fn test_parse_rate_limit() { + let res = parse_rate_limit("100/10"); + assert!(res.is_ok()); + + let (limit, duration) = res.unwrap(); + + assert_eq!(limit, 100); + assert_eq!(duration, Duration::from_secs(10)); + + let res = parse_rate_limit("100"); + assert!(res.is_err()); + } + + #[test] + fn test_parse_endpoint_uri_http() { + let client = ClientConfig::with_endpoint("http://localhost:1234"); + let uri = client.parse_endpoint_uri().expect("valid http uri"); + assert_eq!(uri.scheme_str(), Some("http")); + assert_eq!( + uri.authority().map(|auth| auth.as_str()), + Some("localhost:1234") + ); + } + + #[test] + fn test_parse_endpoint_uri_unix() { + let client = ClientConfig::with_endpoint("unix://tmp/slim.sock"); + let uri = client.parse_endpoint_uri().expect("valid unix uri"); + assert_eq!(uri.scheme_str(), Some("unix")); + assert_eq!(uri.authority().map(|auth| auth.as_str()), Some("localhost")); + assert_eq!(uri.path(), "tmp/slim.sock"); + } + + #[test] + fn test_parse_endpoint_uri_unix_missing_path() { + let client = ClientConfig::with_endpoint("unix://"); + let err = client.parse_endpoint_uri().expect_err("missing unix path"); + assert!(matches!(err, ConfigError::UnixSocketMissingPath)); + } + + #[test] + fn test_websocket_transport_endpoint_validation() { + let ws_config = ClientConfig::with_endpoint("ws://localhost:46357") + .with_transport(TransportProtocol::Websocket); + assert!(ws_config.validate().is_ok()); + + let wss_config = ClientConfig::with_endpoint("wss://localhost:46357") + .with_transport(TransportProtocol::Websocket); + assert!(wss_config.validate().is_ok()); + + let invalid = ClientConfig::with_endpoint("http://localhost:46357") + .with_transport(TransportProtocol::Websocket); + let err = invalid + .validate() + .expect_err("expected invalid websocket scheme"); + assert!(matches!(err, ConfigError::InvalidWebSocketEndpointScheme)); + } + + #[tokio::test] + async fn test_connect_tcp_channel_lazy_ok() { + let client = ClientConfig::with_endpoint("http://127.0.0.1:0"); + let uri = client.parse_endpoint_uri().expect("valid http uri"); + let channel = client.connect_tcp_channel(uri, true).await; + assert!(channel.is_ok()); + } + + #[tokio::test] + async fn test_connect_tcp_channel_non_lazy_error() { + let mut client = ClientConfig::with_endpoint("http://127.0.0.1:0") + .with_connect_timeout(Duration::from_millis(50)); + client.backoff = BackoffConfig::new_fixed_interval(Duration::from_millis(0), 1); + + let uri = client.parse_endpoint_uri().expect("valid http uri"); + let err = client + .connect_tcp_channel(uri, false) + .await + .expect_err("expected connect error"); + assert!(matches!(err, ConfigError::TransportError(_))); + } + + #[cfg(target_family = "unix")] + #[tokio::test] + async fn test_connect_unix_channel_lazy_ok() { + let mut client = ClientConfig::with_endpoint("unix:///tmp/slim-test.sock"); + client.tls_setting.insecure = true; + + let uri = client.parse_endpoint_uri().expect("valid unix uri"); + let channel = client.connect_unix_channel(uri, true).await; + assert!(channel.is_ok()); + } + + #[cfg(target_family = "unix")] + #[tokio::test] + async fn test_connect_unix_channel_non_lazy_error() { + let mut client = ClientConfig::with_endpoint("unix:///tmp/slim-missing.sock"); + client.tls_setting.insecure = true; + client.backoff = BackoffConfig::new_fixed_interval(Duration::from_millis(0), 1); + + let uri = client.parse_endpoint_uri().expect("valid unix uri"); + let err = client + .connect_unix_channel(uri, false) + .await + .expect_err("expected unix socket connect error"); + assert!(matches!(err, ConfigError::UnixSocketConnect(_))); + } + + #[cfg(not(target_family = "unix"))] + #[tokio::test] + async fn test_connect_unix_channel_unsupported() { + let client = ClientConfig::with_endpoint("unix:///tmp/slim.sock"); + let uri = client.parse_endpoint_uri().expect("valid unix uri"); + let err = client + .connect_unix_channel(uri, true) + .await + .expect_err("expected unix socket unsupported"); + assert!(matches!(err, ConfigError::UnixSocketUnsupported)); + } + + #[cfg(target_family = "unix")] + #[tokio::test] + async fn test_map_transport_error_maps_io() { + let endpoint = tonic::transport::Endpoint::from_static("http://localhost"); + let connector = service_fn(|_uri: Uri| async move { + Err::, std::io::Error>(std::io::Error::other("boom")) + }); + let err = endpoint + .connect_with_connector(connector) + .await + .expect_err("expected connect error"); + let mapped = ClientConfig::map_transport_error(err); + assert!(matches!(mapped, ConfigError::UnixSocketConnect(_))); + } + + #[cfg(not(target_family = "unix"))] + #[tokio::test] + async fn test_map_transport_error_transport() { + let endpoint = tonic::transport::Endpoint::from_static("http://localhost"); + let connector = service_fn(|_uri: Uri| async move { + Err::, std::io::Error>(std::io::Error::new( + std::io::ErrorKind::Other, + "boom", + )) + }); + let err = endpoint + .connect_with_connector(connector) + .await + .expect_err("expected connect error"); + let mapped = ClientConfig::map_transport_error(err); + assert!(matches!(mapped, ConfigError::TransportError(_))); + } + + #[tokio::test] + #[traced_test] + async fn test_to_channel() { + let test_path: &str = env!("CARGO_MANIFEST_DIR"); + + // create a new client config + let mut client = ClientConfig::default(); + + // as the endpoint is missing, this should fail + let mut channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_err()); + + // Set the endpoint + client.endpoint = "http://localhost:8080".to_string(); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set the tls settings + client.tls_setting.insecure = true; + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set the tls settings + client.tls_setting = { + let mut tls = TLSSetting::default(); + // Updated for new Config fields: set CA via ca_source and leave source as default (None) + tls.config.ca_source = CaSource::File { + path: format!("{}/testdata/grpc/{}", test_path, "ca.crt"), + }; + tls.insecure = false; + tls + }; + + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set keepalive settings + client.keepalive = Some(KeepaliveConfig::default()); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set rate limit settings + client.rate_limit = Some("100/10".to_string()); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set rate limit settings wrong + client.rate_limit = Some("100".to_string()); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_err()); + + // reset config + client.rate_limit = None; + + // Set timeout settings + client.request_timeout = Duration::from_secs(10).into(); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set buffer size settings + client.buffer_size = Some(1024); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set origin settings + client.origin = Some("http://example.com".to_string()); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // set additional header to add to the request + client + .headers + .insert("X-Test".to_string(), "test".to_string()); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set proxy settings + client.proxy = ProxyConfig::new("http://proxy.example.com:8080"); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set proxy with authentication + client.proxy = ProxyConfig::new("http://proxy.example.com:8080").with_auth("user", "pass"); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set proxy with headers + let mut proxy_headers = HashMap::new(); + proxy_headers.insert("X-Proxy-Header".to_string(), "value".to_string()); + client.proxy = + ProxyConfig::new("http://proxy.example.com:8080").with_headers(proxy_headers); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set HTTPS proxy settings + client.proxy = ProxyConfig::new("https://proxy.example.com:8080"); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set HTTPS proxy with authentication + client.proxy = ProxyConfig::new("https://proxy.example.com:8080").with_auth("user", "pass"); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + + // Set HTTPS proxy with headers + let mut https_proxy_headers = HashMap::new(); + https_proxy_headers.insert("X-Proxy-Header".to_string(), "value".to_string()); + client.proxy = + ProxyConfig::new("https://proxy.example.com:8080").with_headers(https_proxy_headers); + channel = client.to_grpc_channel_lazy().await; + assert!(channel.is_ok()); + } + + #[tokio::test] + async fn test_to_channel_rejects_websocket_transport() { + let client = ClientConfig::with_endpoint("ws://localhost:46357") + .with_transport(TransportProtocol::Websocket); + let channel = client.to_grpc_channel_lazy().await; + assert!(matches!( + channel, + Err(ConfigError::GrpcChannelUnsupportedTransport) + )); + } + + #[test] + fn test_client_config_with_proxy() { + let proxy = ProxyConfig::new("http://proxy.example.com:8080").with_auth("user", "pass"); + let client = ClientConfig::with_endpoint("http://localhost:8080").with_proxy(proxy.clone()); + assert_eq!(client.proxy, proxy); + } + + #[test] + fn test_connect_and_request_timeout_valid_durations_deserialize() { + let json = r#"{ + "endpoint": "http://localhost:1234", + "connect_timeout": "1m30s", + "request_timeout": "250ms" + }"#; + + let cfg: ClientConfig = serde_json::from_str(json).expect("deserialization should succeed"); + assert_eq!(cfg.connect_timeout, Duration::from_secs(90)); + assert_eq!(cfg.request_timeout, Duration::from_millis(250)); + + // More complex duration + let json = r#"{ + "endpoint": "http://localhost:1234", + "connect_timeout": "1h2m3s4ms", + "request_timeout": "1500ms" + }"#; + let cfg: ClientConfig = + serde_json::from_str(json).expect("complex duration should deserialize"); + assert_eq!( + cfg.connect_timeout, + Duration::from_secs(3723) + Duration::from_millis(4) + ); + assert_eq!(cfg.request_timeout, Duration::from_millis(1500)); + } + + #[test] + fn test_invalid_duration_strings_fail_deserialize() { + let invalids = [ + r#"{ "endpoint": "http://localhost:1234", "connect_timeout": "abc" }"#, + r#"{ "endpoint": "http://localhost:1234", "request_timeout": "10x" }"#, + r#"{ "endpoint": "http://localhost:1234", "request_timeout": "--5s" }"#, + ]; + for js in invalids { + let res: Result = serde_json::from_str(js); + assert!(res.is_err(), "expected error for json: {}", js); + } + } + + #[test] + fn test_keepalive_config_duration_parsing() { + let json = r#"{ + "endpoint": "http://localhost:1234", + "keepalive": { + "tcp_keepalive": "30s", + "http2_keepalive": "45s", + "timeout": "5s", + "keep_alive_while_idle": true + } + }"#; + let cfg: ClientConfig = serde_json::from_str(json).expect("keepalive should deserialize"); + let ka = cfg.keepalive.expect("keepalive should be present"); + assert_eq!(ka.tcp_keepalive, Duration::from_secs(30)); + assert_eq!(ka.http2_keepalive, Duration::from_secs(45)); + assert_eq!(ka.timeout, Duration::from_secs(5)); + assert!(ka.keep_alive_while_idle); + + // Invalid keepalive duration + let invalid_json = r#"{ + "endpoint": "http://localhost:1234", + "keepalive": { "tcp_keepalive": "zz", "http2_keepalive": "10s", "timeout": "5s", "keep_alive_while_idle": false } + }"#; + let res: Result = serde_json::from_str(invalid_json); + assert!(res.is_err(), "invalid tcp_keepalive should fail"); + } + + #[test] + fn test_client_config_roundtrip_duration_serialization() { + let mut cfg = ClientConfig::with_endpoint("http://localhost:9999") + .with_connect_timeout(Duration::from_secs(90)) + .with_request_timeout(Duration::from_millis(750)); + + cfg.keepalive = Some(KeepaliveConfig { + tcp_keepalive: Duration::from_secs(11).into(), + http2_keepalive: Duration::from_secs(22).into(), + timeout: Duration::from_secs(3).into(), + keep_alive_while_idle: true, + }); + + let serialized = serde_json::to_string(&cfg).expect("serialize"); + let deserialized: ClientConfig = serde_json::from_str(&serialized).expect("deserialize"); + + assert_eq!(deserialized.connect_timeout, Duration::from_secs(90)); + assert_eq!(deserialized.request_timeout, Duration::from_millis(750)); + let ka = deserialized.keepalive.expect("keepalive present"); + assert_eq!(ka.tcp_keepalive, Duration::from_secs(11)); + assert_eq!(ka.http2_keepalive, Duration::from_secs(22)); + assert_eq!(ka.timeout, Duration::from_secs(3)); + assert!(ka.keep_alive_while_idle); + } + + #[test] + fn test_validate_rejects_non_uuid_link_id() { + let mut config = ClientConfig::with_endpoint("http://localhost:1234"); + config.link_id = "not-a-uuid".to_string(); + assert!(matches!(config.validate(), Err(ConfigError::InvalidLinkId))); + } + + #[test] + fn test_validate_rejects_non_v4_uuid() { + let mut config = ClientConfig::with_endpoint("http://localhost:1234"); + // Version 1 UUID. + config.link_id = "00000000-0000-1000-8000-000000000000".to_string(); + assert!(matches!(config.validate(), Err(ConfigError::InvalidLinkId))); + } + + #[test] + fn test_validate_accepts_default_v4_link_id() { + // default_link_id() generates a v4 UUID; validation must pass. + let config = ClientConfig::with_endpoint("http://localhost:1234"); + assert!(config.validate().is_ok()); + } +} diff --git a/data-plane/core/config/src/grpc.rs b/data-plane/core/config/src/grpc.rs index 8daa2d3b5..ee9aae121 100644 --- a/data-plane/core/config/src/grpc.rs +++ b/data-plane/core/config/src/grpc.rs @@ -1,9 +1,12 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +#[cfg(feature = "native")] pub mod client; pub mod compression; pub mod errors; +#[cfg(feature = "native")] pub mod headers_middleware; pub mod proxy; +#[cfg(feature = "native")] pub mod server; diff --git a/data-plane/core/config/src/grpc/client.rs b/data-plane/core/config/src/grpc/client.rs index a67fb2370..95ce79717 100644 --- a/data-plane/core/config/src/grpc/client.rs +++ b/data-plane/core/config/src/grpc/client.rs @@ -1,554 +1,15 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 -use duration_string::DurationString; -use rustls_pki_types::ServerName; -use tokio_retry::RetryIf; +pub use crate::client::*; -use display_error_chain::ErrorChainExt; -use std::{collections::HashMap, str::FromStr, time::Duration}; -use tower::ServiceExt; -#[cfg(target_family = "unix")] -use { - hyper_util::rt::TokioIo, - std::{error::Error as StdErrorTrait, path::PathBuf, sync::Arc}, - tokio::net::UnixStream, - tower::service_fn, -}; - -use base64::prelude::*; -use http::header::{HeaderMap, HeaderName, HeaderValue}; -use hyper_rustls; -use hyper_util::client::legacy::connect::HttpConnector; -use hyper_util::client::legacy::connect::proxy::Tunnel; -use hyper_util::client::proxy::matcher::Intercept; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; use tonic::codegen::{Body, Bytes, StdError}; -use tonic::transport::{Channel, Uri}; -use tracing::warn; - -use slim_auth::metadata::MetadataMap; -use super::compression::CompressionType; use super::errors::ConfigError; -use super::headers_middleware::SetRequestHeaderLayer; -use crate::auth::ClientAuthenticator; -use crate::auth::basic::Config as BasicAuthenticationConfig; -use crate::auth::jwt::Config as JwtAuthenticationConfig; -#[cfg(not(target_family = "windows"))] -use crate::auth::spire::SpireConfig as SpireAuthConfig; -use crate::auth::static_jwt::Config as BearerAuthenticationConfig; -use crate::backoff::Strategy; -use crate::backoff::exponential::Config as ExponentialBackoff; -use crate::backoff::fixedinterval::Config as FixedIntervalBackoff; -use crate::component::configuration::Configuration; -use crate::grpc::proxy::ProxyConfig; -use crate::tls::{client::TlsClientConfig as TLSSetting, common::RustlsConfigLoader}; -use crate::transport::TransportProtocol; - -/// Creates an HTTPS connector with optional SNI based on the origin -fn https_connector( - s: S, - tls: &rustls::ClientConfig, - server_name: Option, -) -> hyper_rustls::HttpsConnector -where - S: tower::Service, -{ - let tls = tls.clone(); - let mut builder = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(tls) - .https_or_http(); - - if let Some(origin_str) = server_name { - builder = - builder.with_server_name_resolver(move |_: &_| ServerName::try_from(origin_str.clone())) - } - - builder.enable_http2().wrap_connector(s) -} - -/// Macro to create TLS-enabled or plain connectors based on TLS configuration, -/// applying the optional origin (for SNI) when TLS is enabled. -/// Supports both lazy and eager connection modes. -macro_rules! create_connector { - ($builder:expr, $base_connector:expr, $tls_config:expr, $server_name:expr, $lazy:expr) => { - match ($tls_config, $lazy) { - (Some(tls), true) => { - let connector = tower::ServiceBuilder::new() - .layer_fn(move |s| { - https_connector(s, &tls, $server_name.map(|s| s.to_string())) - }) - .service($base_connector); - Ok($builder.connect_with_connector_lazy(connector)) - } - (Some(tls), false) => { - let connector = tower::ServiceBuilder::new() - .layer_fn(move |s| { - https_connector(s, &tls, $server_name.map(|s| s.to_string())) - }) - .service($base_connector); - let ret = $builder.connect_with_connector(connector).await?; - Ok(ret) - } - (None, true) => Ok($builder.connect_with_connector_lazy($base_connector)), - (None, false) => { - let ret = $builder.connect_with_connector($base_connector).await?; - Ok(ret) - } - } - }; -} - -/// Macro to create authenticated service layers for auth types that don't need initialization -macro_rules! create_auth_service_no_init { - ($self:expr, $auth_config:expr, $header_map:expr, $channel:expr) => {{ - let auth_layer = $auth_config.get_client_layer()?; - - $self.warn_insecure_auth(); - - Ok(tower::ServiceBuilder::new() - .layer(SetRequestHeaderLayer::new($header_map)) - .layer(auth_layer) - .service($channel) - .boxed_clone()) - }}; -} - -/// Macro to create authenticated service layers for auth types that need initialization -macro_rules! create_auth_service_with_init { - ($self:expr, $auth_config:expr, $header_map:expr, $channel:expr) => {{ - let mut auth_layer = $auth_config.get_client_layer()?; - - // Initialize the auth layer - auth_layer.initialize().await?; - - $self.warn_insecure_auth(); - - Ok(tower::ServiceBuilder::new() - .layer(SetRequestHeaderLayer::new($header_map)) - .layer(auth_layer) - .service($channel) - .boxed_clone()) - }}; -} - -/// Enum to handle all connection types: direct connections and proxy tunnels -enum ConnectionType { - /// Direct HTTP connection without proxy - Direct(HttpConnector), - /// HTTP proxy tunnel connection - ProxyHttp(Tunnel), - /// HTTPS proxy tunnel connection - ProxyHttps(Tunnel>), -} - -/// Keepalive configuration for the client. -/// This struct contains the keepalive time for TCP and HTTP2, -/// the timeout duration for the keepalive, and whether to permit -/// keepalive without an active stream. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] -pub struct KeepaliveConfig { - /// The duration of the keepalive time for TCP - #[serde(default = "default_tcp_keepalive")] - #[schemars(with = "String")] - pub tcp_keepalive: DurationString, - - /// The duration of the keepalive time for HTTP2 - #[serde(default = "default_http2_keepalive")] - #[schemars(with = "String")] - pub http2_keepalive: DurationString, - - /// The timeout duration for the keepalive - #[serde(default = "default_timeout")] - #[schemars(with = "String")] - pub timeout: DurationString, - - /// Whether to permit keepalive without an active stream - #[serde(default = "default_keep_alive_while_idle")] - pub keep_alive_while_idle: bool, -} - -/// Defaults for KeepaliveConfig -impl Default for KeepaliveConfig { - fn default() -> Self { - KeepaliveConfig { - tcp_keepalive: default_tcp_keepalive(), - http2_keepalive: default_http2_keepalive(), - timeout: default_timeout(), - keep_alive_while_idle: default_keep_alive_while_idle(), - } - } -} - -fn default_tcp_keepalive() -> DurationString { - Duration::from_secs(60).into() -} - -fn default_http2_keepalive() -> DurationString { - Duration::from_secs(60).into() -} - -fn default_timeout() -> DurationString { - Duration::from_secs(10).into() -} - -fn default_keep_alive_while_idle() -> bool { - false -} - -/// Enum holding one configuration for the client. -#[derive(Debug, Serialize, Default, Deserialize, Clone, PartialEq, JsonSchema)] -#[serde(rename_all = "snake_case", tag = "type")] -pub enum AuthenticationConfig { - /// Basic authentication configuration. - Basic(BasicAuthenticationConfig), - /// Bearer authentication configuration. - StaticJwt(BearerAuthenticationConfig), - /// JWT authentication configuration. - Jwt(JwtAuthenticationConfig), - /// SPIRE/SPIFFE authentication configuration. - #[cfg(not(target_family = "windows"))] - Spire(SpireAuthConfig), - /// None - #[default] - None, -} - -/// Enum holding one configuration for the client. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] -#[serde(rename_all = "snake_case", tag = "type")] -pub enum BackoffConfig { - // Exponential backoff retry config. - Exponential(ExponentialBackoff), - /// FixedInterval backoff retry config. - FixedInterval(FixedIntervalBackoff), -} - -impl BackoffConfig { - /// Creates a new Exponential backoff configuration - pub fn new_exponential( - base: u64, - factor: u64, - max_delay: Duration, - max_attempts: usize, - jitter: bool, - ) -> Self { - BackoffConfig::Exponential(ExponentialBackoff::new( - base, - factor, - max_delay, - max_attempts, - jitter, - )) - } - - /// Creates a new FixedInterval backoff configuration - pub fn new_fixed_interval(interval: Duration, max_attempts: usize) -> Self { - BackoffConfig::FixedInterval(FixedIntervalBackoff::new(interval, max_attempts)) - } -} - -impl Default for BackoffConfig { - fn default() -> Self { - BackoffConfig::Exponential(ExponentialBackoff::default()) - } -} - -impl Strategy for BackoffConfig { - fn get_strategy(&self) -> Box + Send> { - match self { - BackoffConfig::Exponential(b) => b.get_strategy(), - BackoffConfig::FixedInterval(b) => b.get_strategy(), - } - } -} - -/// Struct for the client configuration. -/// This struct contains the endpoint, origin, compression type, rate limit, -/// TLS settings, keepalive settings, proxy settings, timeout settings, buffer size settings, -/// headers, and auth settings. -/// The client configuration can be converted to a tonic channel. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] -pub struct ClientConfig { - /// The target the client will connect to. - pub endpoint: String, - - /// Transport protocol to use for dataplane communication. - #[serde(default)] - pub transport: TransportProtocol, - - /// Optional websocket authentication query parameter key. - /// This is only used when `transport=websocket`. - pub websocket_auth_query_param: Option, - - /// Origin (HTTP Host authority override) for the client. - pub origin: Option, - - /// Optional TLS SNI server name override. If set, this value is used for TLS - /// server name verification (SNI) instead of the host extracted from endpoint/origin. - pub server_name: Option, - - /// Compression type - TODO(msardara): not implemented yet. - pub compression: Option, - - /// Rate Limits - pub rate_limit: Option, - - /// TLS client configuration. - #[serde(default, rename = "tls")] - pub tls_setting: TLSSetting, - - /// Keepalive parameters. - pub keepalive: Option, - - /// HTTP Proxy configuration. - #[serde(default)] - pub proxy: ProxyConfig, - - /// Timeout for the connection. - #[serde(default = "default_connect_timeout")] - #[schemars(with = "String")] - pub connect_timeout: DurationString, - - /// Timeout per request. - #[serde(default = "default_request_timeout")] - #[schemars(with = "String")] - pub request_timeout: DurationString, - - /// ReadBufferSize. - pub buffer_size: Option, - - /// The headers associated with gRPC requests. - #[serde(default)] - pub headers: HashMap, - - /// Auth configuration for outgoing RPCs. - #[serde(default)] - pub auth: AuthenticationConfig, - - /// Backoff retry configuration. - #[serde(default)] - pub backoff: BackoffConfig, - - /// Arbitrary user-provided metadata. - pub metadata: Option, - - /// Link identifier for this connection, used during link negotiation. - /// Must be a valid UUID v4. Defaults to a randomly generated UUID v4. - #[serde(default = "default_link_id")] - pub link_id: String, -} - -/// Defaults for ClientConfig -impl Default for ClientConfig { - fn default() -> Self { - ClientConfig { - endpoint: String::new(), - transport: TransportProtocol::default(), - websocket_auth_query_param: None, - origin: None, - server_name: None, - compression: None, - rate_limit: None, - tls_setting: TLSSetting::default(), - keepalive: None, - proxy: ProxyConfig::default(), - connect_timeout: default_connect_timeout(), - request_timeout: default_request_timeout(), - buffer_size: None, - headers: HashMap::new(), - auth: AuthenticationConfig::None, - backoff: BackoffConfig::default(), - metadata: None, - link_id: default_link_id(), - } - } -} - -fn default_link_id() -> String { - uuid::Uuid::new_v4().to_string() -} - -fn default_connect_timeout() -> DurationString { - Duration::from_secs(0).into() -} - -fn default_request_timeout() -> DurationString { - Duration::from_secs(0).into() -} - -// Display for ClientConfig -impl std::fmt::Display for ClientConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "ClientConfig {{ endpoint: {}, transport: {:?}, websocket_auth_query_param: {:?}, origin: {:?}, server_name: {:?}, compression: {:?}, rate_limit: {:?}, tls_setting: {:?}, keepalive: {:?}, proxy: {:?}, connect_timeout: {:?}, request_timeout: {:?}, buffer_size: {:?}, headers: {:?}, auth: {:?}, backoff: {:?}, metadata: {:?}, link_id: {:?} }}", - self.endpoint, - self.transport, - self.websocket_auth_query_param, - self.origin, - self.server_name, - self.compression, - self.rate_limit, - self.tls_setting, - self.keepalive, - self.proxy, - self.connect_timeout, - self.request_timeout, - self.buffer_size, - self.headers, - self.auth, - self.backoff, - self.metadata, - self.link_id - ) - } -} - -pub fn is_valid_uuid_v4(s: &str) -> bool { - match uuid::Uuid::parse_str(s) { - Ok(id) => id.get_version() == Some(uuid::Version::Random), - Err(_) => false, - } -} - -impl Configuration for ClientConfig { - type Error = ConfigError; - - fn validate(&self) -> Result<(), Self::Error> { - if self.endpoint.is_empty() { - return Err(ConfigError::MissingEndpoint); - } - - // Validate link_id is a UUID v4 - if !is_valid_uuid_v4(&self.link_id) { - return Err(ConfigError::InvalidLinkId); - } - - // Validate the client configuration - self.tls_setting.validate()?; - self.validate_websocket_endpoint()?; - - Ok(()) - } -} impl ClientConfig { - /// Creates a new client configuration with the given endpoint. - /// This function will return a ClientConfig with the endpoint set - /// and all other fields set to default. - pub fn with_endpoint(endpoint: &str) -> Self { - Self { - endpoint: endpoint.to_string(), - ..Self::default() - } - } - - pub fn with_origin(self, origin: &str) -> Self { - Self { - origin: Some(origin.to_string()), - ..self - } - } - - pub fn with_transport(self, transport: TransportProtocol) -> Self { - Self { transport, ..self } - } - - pub fn with_websocket_auth_query_param(self, query_param: &str) -> Self { - Self { - websocket_auth_query_param: Some(query_param.to_string()), - ..self - } - } - - pub fn with_server_name(self, server_name: &str) -> Self { - Self { - server_name: Some(server_name.to_string()), - ..self - } - } - - pub fn with_compression(self, compression: CompressionType) -> Self { - Self { - compression: Some(compression), - ..self - } - } - - pub fn with_rate_limit(self, rate_limit: &str) -> Self { - Self { - rate_limit: Some(rate_limit.to_string()), - ..self - } - } - - pub fn with_tls_setting(self, tls_setting: TLSSetting) -> Self { - Self { - tls_setting, - ..self - } - } - - pub fn with_keepalive(self, keepalive: KeepaliveConfig) -> Self { - Self { - keepalive: Some(keepalive), - ..self - } - } - - pub fn with_proxy(self, proxy: ProxyConfig) -> Self { - Self { proxy, ..self } - } - - pub fn with_connect_timeout(self, connect_timeout: Duration) -> Self { - Self { - connect_timeout: connect_timeout.into(), - ..self - } - } - - pub fn with_request_timeout(self, request_timeout: Duration) -> Self { - Self { - request_timeout: request_timeout.into(), - ..self - } - } - - pub fn with_buffer_size(self, buffer_size: usize) -> Self { - Self { - buffer_size: Some(buffer_size), - ..self - } - } - - pub fn with_headers(self, headers: HashMap) -> Self { - Self { headers, ..self } - } - - pub fn with_auth(self, auth: AuthenticationConfig) -> Self { - Self { auth, ..self } - } - - pub fn with_backoff(self, backoff: BackoffConfig) -> Self { - Self { backoff, ..self } - } - - pub fn with_metadata(self, metadata: MetadataMap) -> Self { - Self { - metadata: Some(metadata), - ..self - } - } - - /// Converts the client configuration to a tonic channel. - /// This function will return a Result with the channel if the configuration is valid. - /// If the configuration is invalid, it will return a ConfigError. - /// The function will set the headers, tls settings, keepalive settings, rate limit settings - /// timeout settings, buffer size settings, and origin settings. - pub async fn to_channel( + /// Converts the client configuration to a gRPC-only channel. + pub async fn to_grpc_channel( &self, ) -> Result< impl tonic::client::GrpcService< @@ -565,15 +26,11 @@ impl ClientConfig { + use<>, ConfigError, > { - self.to_channel_internal(false).await + self.to_grpc_channel_internal().await } - /// Converts the client configuration to a tonic channel without retry logic. - /// This is useful for testing where you want to validate configuration without - /// attempting actual connections. The channel is created lazily and won't connect - /// until the first RPC call is made. - #[cfg(test)] - pub async fn to_channel_lazy( + /// Internal gRPC channel builder used by both public gRPC API and transport switching. + pub(crate) async fn to_grpc_channel_internal( &self, ) -> Result< impl tonic::client::GrpcService< @@ -583,18 +40,19 @@ impl ClientConfig { + Send + 'static, Future: Send, - > + Send + > + + Send + Clone - + 'static, + + 'static + + use<>, ConfigError, > { - self.to_channel_internal(true).await + self.to_channel_internal(false).await } - /// Internal implementation for channel creation with optional lazy flag. - async fn to_channel_internal( + /// Converts the client configuration to a gRPC-only channel without retry logic. + pub async fn to_grpc_channel_lazy( &self, - lazy: bool, ) -> Result< impl tonic::client::GrpcService< tonic::body::Body, @@ -603,439 +61,17 @@ impl ClientConfig { + Send + 'static, Future: Send, - > - + Send + > + Send + Clone - + 'static - + use<>, + + 'static, ConfigError, > { - if self.transport == TransportProtocol::Websocket { - return Err(ConfigError::GrpcChannelUnsupportedTransport); - } - - // Validate endpoint - self.validate_endpoint()?; - - // Parse headers - let header_map = self.parse_headers()?; - - let uri = self.parse_endpoint_uri()?; - - let channel = if uri.scheme_str() == Some("unix") { - self.connect_unix_channel(uri, lazy).await? - } else if uri.scheme_str() == Some("http") || uri.scheme_str() == Some("https") { - self.connect_tcp_channel(uri, lazy).await? - } else { - return Err(ConfigError::InvalidEndpointScheme); - }; - - // Apply authentication and headers - self.apply_auth_and_headers(channel, header_map).await - } - - /// Validates that the endpoint is set and not empty - fn validate_endpoint(&self) -> Result<(), ConfigError> { - if self.endpoint.is_empty() { - return Err(ConfigError::MissingEndpoint); - } - Ok(()) - } - - fn validate_websocket_endpoint(&self) -> Result<(), ConfigError> { - if self.transport != TransportProtocol::Websocket { - return Ok(()); - } - - let endpoint = Uri::from_str(self.endpoint.as_str())?; - match endpoint.scheme_str() { - Some("ws") | Some("wss") => Ok(()), - _ => Err(ConfigError::InvalidWebSocketEndpointScheme), - } + self.to_grpc_channel_lazy_internal().await } - /// Parses the endpoint string into a URI for TCP/HTTP, Unix domain socket endpoints. - fn parse_endpoint_uri(&self) -> Result { - // Special case for the unix scheme because it doesn't have an - // authority in the URI and the Uri parser doesn't like this today, - // so we build our own URI with a fake localhost authority. - if self.endpoint.starts_with("unix://") { - let path = &self.endpoint[7..]; - if path.is_empty() { - return Err(ConfigError::UnixSocketMissingPath); - } - - let uri = Uri::builder() - .scheme("unix") - .authority("localhost") - .path_and_query(path) - .build() - .map_err(ConfigError::UnixSocketInvalidPath)?; - return Ok(uri); - } - Ok(Uri::from_str(&self.endpoint)?) - } - - /// Creates and configures the HTTP connector - fn create_http_connector(&self) -> Result { - let mut http = HttpConnector::new(); - - // NOTE(msardara): we might want to make these configurable as well. - http.enforce_http(false); - http.set_nodelay(false); - - // set the connection timeout - match self.connect_timeout.as_secs() { - 0 => http.set_connect_timeout(None), - _ => http.set_connect_timeout(Some(self.connect_timeout.into())), - } - - // set keepalive settings - if let Some(keepalive) = &self.keepalive { - http.set_keepalive(Some(keepalive.tcp_keepalive.into())); - } - - Ok(http) - } - - /// Creates the channel builder with all configuration settings - fn create_channel_builder(&self, uri: Uri) -> Result { - let mut builder = Channel::builder(uri); - - // set the buffer size - if let Some(size) = self.buffer_size { - builder = builder.buffer_size(size); - } - - // set keepalive settings - if let Some(keepalive) = &self.keepalive { - builder = builder - .keep_alive_timeout(keepalive.timeout.into()) - .keep_alive_while_idle(keepalive.keep_alive_while_idle) - // HTTP level keepalive - .http2_keep_alive_interval(keepalive.http2_keepalive.into()); - } - - // set origin settings - if let Some(origin) = &self.origin { - let origin_uri = Uri::from_str(origin.as_str())?; - builder = builder.origin(origin_uri); - } - - // set rate limit settings - if let Some(rate_limit) = &self.rate_limit { - let (limit, duration) = parse_rate_limit(rate_limit)?; - builder = builder.rate_limit(limit, duration); - } - - // set the request timeout - if self.request_timeout.as_secs() > 0 { - builder = builder.timeout(self.request_timeout.into()); - } - - if self.connect_timeout.as_secs() > 0 { - builder = builder.connect_timeout(self.connect_timeout.into()); - } - - Ok(builder) - } - - /// Parses headers from the configuration - fn parse_headers(&self) -> Result { - Self::parse_header_map(&self.headers) - } - - /// Generic helper to parse a HashMap into HeaderMap - fn parse_header_map(headers: &HashMap) -> Result { - let mut header_map = HeaderMap::new(); - for (key, value) in headers { - let header_name = HeaderName::from_str(key)?; - let header_value = HeaderValue::from_str(value)?; - header_map.insert(header_name, header_value); - } - Ok(header_map) - } - - #[cfg(target_family = "unix")] - fn map_transport_error(err: tonic::transport::Error) -> ConfigError { - #[cfg(target_family = "unix")] - { - let mut source: Option<&(dyn StdErrorTrait + 'static)> = Some(&err); - while let Some(err_ref) = source { - if let Some(io_err) = err_ref.downcast_ref::() { - let cloned = std::io::Error::new(io_err.kind(), io_err.to_string()); - return ConfigError::UnixSocketConnect(cloned); - } - source = err_ref.source(); - } - } - - ConfigError::from(err) - } - - /// Helper to create basic auth header for proxy authentication - fn create_proxy_auth_header( - username: &str, - password: &str, - ) -> Result { - let auth_value = BASE64_STANDARD.encode(format!("{}:{}", username, password)); - Ok(HeaderValue::from_str(&format!("Basic {}", auth_value))?) - } - - /// Helper to apply authentication and headers to a tunnel - fn apply_tunnel_config( - &self, - mut tunnel: Tunnel, - proxy_config: &ProxyConfig, - warn_insecure: bool, - ) -> Result, ConfigError> { - // Set proxy authentication if provided - if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) { - if warn_insecure { - self.warn_insecure_auth(); - } - - let auth_header = Self::create_proxy_auth_header(username, password)?; - tunnel = tunnel.with_auth(auth_header); - } - - // Set custom headers for proxy requests - if !proxy_config.headers.is_empty() { - let proxy_headers = self.parse_proxy_headers(&proxy_config.headers)?; - tunnel = tunnel.with_headers(proxy_headers); - } - - Ok(tunnel) - } - - /// Loads TLS configuration - async fn load_tls_config(&self) -> Result, ConfigError> { - let tls = self.tls_setting.load_rustls_config().await?; - Ok(tls) - } - - #[cfg(target_family = "unix")] - async fn connect_unix_channel(&self, uri: Uri, lazy: bool) -> Result { - if !self.tls_setting.insecure { - // TLS handshakes are unnecessary over local UDS and currently unsupported - return Err(ConfigError::UnixSocketTlsUnsupported); - } - - let path = uri.path(); - let socket_path = Arc::new(PathBuf::from(path)); - let builder = self.create_channel_builder(uri)?; - - let make_connector = || { - let path = socket_path.clone(); - service_fn(move |_uri: Uri| { - let path = path.clone(); - async move { UnixStream::connect(path.as_path()).await.map(TokioIo::new) } - }) - }; - - if lazy { - Ok(builder.connect_with_connector_lazy(make_connector())) - } else { - let backoff_strategy = self.backoff.get_strategy(); - RetryIf::spawn( - backoff_strategy, - || { - let builder = builder.clone(); - let connector = make_connector(); - let path = socket_path.clone(); - async move { - tracing::debug!( - socket_path = %path.display(), - "Attempting to create gRPC channel over Unix domain socket" - ); - builder - .connect_with_connector(connector) - .await - .map_err(Self::map_transport_error) - } - }, - |e: &ConfigError| match e { - ConfigError::TransportError(err) => { - tracing::warn!(error = %err.chain(), "Transport error encountered. Retrying..."); - true - } - ConfigError::UnixSocketConnect(err) => { - tracing::warn!(error = %err, "Unix socket connect error encountered. Retrying..."); - true - } - _ => { - tracing::error!(error = %e.chain(), "non-retryable error encountered"); - false - } - }, - ) - .await - } - } - - #[cfg(not(target_family = "unix"))] - async fn connect_unix_channel(&self, _uri: Uri, _lazy: bool) -> Result { - Err(ConfigError::UnixSocketUnsupported) - } - - async fn connect_tcp_channel(&self, uri: Uri, lazy: bool) -> Result { - let http_connector = self.create_http_connector()?; - let builder = self.create_channel_builder(uri.clone())?; - let tls_config = self.load_tls_config().await?; - - if lazy { - let connection = self.create_connection(uri, http_connector).await?; - self.create_channel_from_connection(builder, connection, tls_config, true) - .await - } else { - let backoff_strategy = self.backoff.get_strategy(); - RetryIf::spawn( - backoff_strategy, - || { - let uri = uri.clone(); - let builder = builder.clone(); - let http_connector = http_connector.clone(); - let tls_config = tls_config.clone(); - async move { - tracing::debug!(%uri, "Attempting to create gRPC channel"); - self.create_channel_with_connector(uri, builder, http_connector, tls_config) - .await - } - }, - |e: &ConfigError| { - match e { - ConfigError::TransportError(err) => { - tracing::warn!(error = %err.chain(), "Transport error encountered. Retrying..."); - true - } - _ => { - tracing::error!(error = %e.chain(), "non-retryable error encountered"); - false - } - } - }, - ) - .await - } - } - - /// Creates the channel with the appropriate connector (proxy or direct) - /// Creates a channel with the provided connector and TLS configuration. - async fn create_channel_with_connector( - &self, - uri: Uri, - builder: tonic::transport::Endpoint, - http_connector: HttpConnector, - tls_config: Option, - ) -> Result { - let connection = self.create_connection(uri, http_connector).await?; - self.create_channel_from_connection(builder, connection, tls_config, false) - .await - } - - /// Creates the appropriate connection type based on proxy configuration - async fn create_connection( - &self, - uri: Uri, - http_connector: HttpConnector, - ) -> Result { - // Check if this host should bypass the proxy - if let Some(intercept) = self.proxy.should_use_proxy(uri.to_string()) { - // Use proxy for this host - self.create_proxy_connection(intercept, http_connector) - .await - } else { - // Skip proxy for this host, use direct connection - Ok(ConnectionType::Direct(http_connector)) - } - } - - /// Creates a proxy connection - async fn create_proxy_connection( - &self, - intercept: Intercept, - http_connector: HttpConnector, - ) -> Result { - let proxy_uri = intercept.uri(); - - tracing::info!(%proxy_uri, "Creating proxy tunnel"); - - // Check if the proxy URL uses HTTPS - if proxy_uri.scheme_str() == Some("https") { - let proxy_tls_config = self.proxy.tls_setting.load_rustls_config().await?.unwrap(); - - // Create HTTPS connector for the proxy itself - let https_connector = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(proxy_tls_config) - .https_or_http() - .enable_http2() - .wrap_connector(http_connector); - - let tunnel = Tunnel::new(proxy_uri.clone(), https_connector); - let configured_tunnel = self.apply_tunnel_config(tunnel, &self.proxy, false)?; - - Ok(ConnectionType::ProxyHttps(configured_tunnel)) - } else { - // Use HTTP connector for the proxy - let tunnel = Tunnel::new(proxy_uri.clone(), http_connector); - let configured_tunnel = self.apply_tunnel_config(tunnel, &self.proxy, true)?; - - Ok(ConnectionType::ProxyHttp(configured_tunnel)) - } - } - - /// Creates a channel from any connection type with TLS support - async fn create_channel_from_connection( + /// Internal lazy gRPC channel builder (no retry layer). + pub(crate) async fn to_grpc_channel_lazy_internal( &self, - builder: tonic::transport::Endpoint, - connection: ConnectionType, - tls_config: Option, - lazy: bool, - ) -> Result { - match connection { - ConnectionType::Direct(connector) => { - create_connector!( - builder, - connector, - tls_config, - self.server_name.as_deref(), - lazy - ) - } - ConnectionType::ProxyHttp(tunnel) => { - create_connector!( - builder, - tunnel, - tls_config, - self.server_name.as_deref(), - lazy - ) - } - ConnectionType::ProxyHttps(tunnel) => { - create_connector!( - builder, - tunnel, - tls_config, - self.server_name.as_deref(), - lazy - ) - } - } - } - - /// Parses proxy headers - fn parse_proxy_headers( - &self, - headers: &HashMap, - ) -> Result { - Self::parse_header_map(headers) - } - - /// Applies authentication and headers to the channel - async fn apply_auth_and_headers( - &self, - channel: Channel, - header_map: HeaderMap, ) -> Result< impl tonic::client::GrpcService< tonic::body::Body, @@ -1044,508 +80,11 @@ impl ClientConfig { + Send + 'static, Future: Send, - > - + Send + > + Send + Clone - + 'static - + use<>, + + 'static, ConfigError, > { - match &self.auth { - AuthenticationConfig::Basic(basic) => { - create_auth_service_no_init!(self, basic, header_map, channel) - } - AuthenticationConfig::StaticJwt(jwt) => { - create_auth_service_with_init!(self, jwt, header_map, channel) - } - AuthenticationConfig::Jwt(jwt) => { - create_auth_service_with_init!(self, jwt, header_map, channel) - } - #[cfg(not(target_family = "windows"))] - AuthenticationConfig::Spire(spire) => { - create_auth_service_with_init!(self, spire, header_map, channel) - } - AuthenticationConfig::None => Ok(tower::ServiceBuilder::new() - .layer(SetRequestHeaderLayer::new(header_map)) - .service(channel) - .boxed_clone()), - } - } - - /// Warns if authentication is enabled without TLS - fn warn_insecure_auth(&self) { - if self.tls_setting.insecure { - warn!("Auth is enabled without TLS. This is not recommended."); - } - } -} - -#[cfg(test)] -mod metadata_tests { - use super::*; - - #[test] - fn client_config_with_metadata_roundtrip_json() { - let mut md = MetadataMap::default(); - md.insert("feature", "alpha"); - md.insert("level", 2u64); - - let cfg = ClientConfig::with_endpoint("http://localhost:1234").with_metadata(md.clone()); - let s = serde_json::to_string(&cfg).expect("serialize"); - let deser: ClientConfig = serde_json::from_str(&s).expect("deserialize"); - assert_eq!(deser.metadata, Some(md)); - } -} - -/// Parse the rate limit string into a limit and a duration. -/// The rate limit string should be in the format of /, -/// with duration expressed in seconds. -/// This function will return a Result with the limit and duration if the -/// rate limit is valid. -fn parse_rate_limit(rate_limit: &str) -> Result<(u64, Duration), ConfigError> { - let parts: Vec<&str> = rate_limit.split('/').collect(); - - if parts.len() != 2 { - // Invalid format: expected / - return Err(ConfigError::Unknown); - } - - let limit = parts[0].parse::()?; - let duration = Duration::from_secs(parts[1].parse::()?); - - Ok((limit, duration)) -} - -#[cfg(test)] -mod test { - #[allow(unused_imports)] - use super::*; - use crate::tls::common::CaSource; - use hyper_util::rt::TokioIo; - use tower::service_fn; - use tracing_test::traced_test; - - #[test] - fn test_default_keepalive_config() { - let keepalive = KeepaliveConfig::default(); - assert_eq!(keepalive.tcp_keepalive, Duration::from_secs(60)); - assert_eq!(keepalive.http2_keepalive, Duration::from_secs(60)); - assert_eq!(keepalive.timeout, Duration::from_secs(10)); - assert!(!keepalive.keep_alive_while_idle); - } - - #[test] - fn test_default_client_config() { - let client = ClientConfig::default(); - assert_eq!(client.endpoint, String::new()); - assert_eq!(client.transport, TransportProtocol::Grpc); - assert_eq!(client.websocket_auth_query_param, None); - assert_eq!(client.origin, None); - assert_eq!(client.compression, None); - assert_eq!(client.rate_limit, None); - assert_eq!(client.tls_setting, TLSSetting::default()); - assert_eq!(client.keepalive, None); - assert_eq!(client.connect_timeout, Duration::from_secs(0)); - assert_eq!(client.request_timeout, Duration::from_secs(0)); - assert_eq!(client.buffer_size, None); - assert_eq!(client.headers, HashMap::new()); - assert_eq!(client.auth, AuthenticationConfig::None); - } - - #[test] - fn test_parse_rate_limit() { - let res = parse_rate_limit("100/10"); - assert!(res.is_ok()); - - let (limit, duration) = res.unwrap(); - - assert_eq!(limit, 100); - assert_eq!(duration, Duration::from_secs(10)); - - let res = parse_rate_limit("100"); - assert!(res.is_err()); - } - - #[test] - fn test_parse_endpoint_uri_http() { - let client = ClientConfig::with_endpoint("http://localhost:1234"); - let uri = client.parse_endpoint_uri().expect("valid http uri"); - assert_eq!(uri.scheme_str(), Some("http")); - assert_eq!( - uri.authority().map(|auth| auth.as_str()), - Some("localhost:1234") - ); - } - - #[test] - fn test_parse_endpoint_uri_unix() { - let client = ClientConfig::with_endpoint("unix://tmp/slim.sock"); - let uri = client.parse_endpoint_uri().expect("valid unix uri"); - assert_eq!(uri.scheme_str(), Some("unix")); - assert_eq!(uri.authority().map(|auth| auth.as_str()), Some("localhost")); - assert_eq!(uri.path(), "tmp/slim.sock"); - } - - #[test] - fn test_parse_endpoint_uri_unix_missing_path() { - let client = ClientConfig::with_endpoint("unix://"); - let err = client.parse_endpoint_uri().expect_err("missing unix path"); - assert!(matches!(err, ConfigError::UnixSocketMissingPath)); - } - - #[test] - fn test_websocket_transport_endpoint_validation() { - let ws_config = ClientConfig::with_endpoint("ws://localhost:46357") - .with_transport(TransportProtocol::Websocket); - assert!(ws_config.validate().is_ok()); - - let wss_config = ClientConfig::with_endpoint("wss://localhost:46357") - .with_transport(TransportProtocol::Websocket); - assert!(wss_config.validate().is_ok()); - - let invalid = ClientConfig::with_endpoint("http://localhost:46357") - .with_transport(TransportProtocol::Websocket); - let err = invalid - .validate() - .expect_err("expected invalid websocket scheme"); - assert!(matches!(err, ConfigError::InvalidWebSocketEndpointScheme)); - } - - #[tokio::test] - async fn test_connect_tcp_channel_lazy_ok() { - let client = ClientConfig::with_endpoint("http://127.0.0.1:0"); - let uri = client.parse_endpoint_uri().expect("valid http uri"); - let channel = client.connect_tcp_channel(uri, true).await; - assert!(channel.is_ok()); - } - - #[tokio::test] - async fn test_connect_tcp_channel_non_lazy_error() { - let mut client = ClientConfig::with_endpoint("http://127.0.0.1:0") - .with_connect_timeout(Duration::from_millis(50)); - client.backoff = BackoffConfig::new_fixed_interval(Duration::from_millis(0), 1); - - let uri = client.parse_endpoint_uri().expect("valid http uri"); - let err = client - .connect_tcp_channel(uri, false) - .await - .expect_err("expected connect error"); - assert!(matches!(err, ConfigError::TransportError(_))); - } - - #[cfg(target_family = "unix")] - #[tokio::test] - async fn test_connect_unix_channel_lazy_ok() { - let mut client = ClientConfig::with_endpoint("unix:///tmp/slim-test.sock"); - client.tls_setting.insecure = true; - - let uri = client.parse_endpoint_uri().expect("valid unix uri"); - let channel = client.connect_unix_channel(uri, true).await; - assert!(channel.is_ok()); - } - - #[cfg(target_family = "unix")] - #[tokio::test] - async fn test_connect_unix_channel_non_lazy_error() { - let mut client = ClientConfig::with_endpoint("unix:///tmp/slim-missing.sock"); - client.tls_setting.insecure = true; - client.backoff = BackoffConfig::new_fixed_interval(Duration::from_millis(0), 1); - - let uri = client.parse_endpoint_uri().expect("valid unix uri"); - let err = client - .connect_unix_channel(uri, false) - .await - .expect_err("expected unix socket connect error"); - assert!(matches!(err, ConfigError::UnixSocketConnect(_))); - } - - #[cfg(not(target_family = "unix"))] - #[tokio::test] - async fn test_connect_unix_channel_unsupported() { - let client = ClientConfig::with_endpoint("unix:///tmp/slim.sock"); - let uri = client.parse_endpoint_uri().expect("valid unix uri"); - let err = client - .connect_unix_channel(uri, true) - .await - .expect_err("expected unix socket unsupported"); - assert!(matches!(err, ConfigError::UnixSocketUnsupported)); - } - - #[cfg(target_family = "unix")] - #[tokio::test] - async fn test_map_transport_error_maps_io() { - let endpoint = tonic::transport::Endpoint::from_static("http://localhost"); - let connector = service_fn(|_uri: Uri| async move { - Err::, std::io::Error>(std::io::Error::other("boom")) - }); - let err = endpoint - .connect_with_connector(connector) - .await - .expect_err("expected connect error"); - let mapped = ClientConfig::map_transport_error(err); - assert!(matches!(mapped, ConfigError::UnixSocketConnect(_))); - } - - #[cfg(not(target_family = "unix"))] - #[tokio::test] - async fn test_map_transport_error_transport() { - let endpoint = tonic::transport::Endpoint::from_static("http://localhost"); - let connector = service_fn(|_uri: Uri| async move { - Err::, std::io::Error>(std::io::Error::new( - std::io::ErrorKind::Other, - "boom", - )) - }); - let err = endpoint - .connect_with_connector(connector) - .await - .expect_err("expected connect error"); - let mapped = ClientConfig::map_transport_error(err); - assert!(matches!(mapped, ConfigError::TransportError(_))); - } - - #[tokio::test] - #[traced_test] - async fn test_to_channel() { - let test_path: &str = env!("CARGO_MANIFEST_DIR"); - - // create a new client config - let mut client = ClientConfig::default(); - - // as the endpoint is missing, this should fail - let mut channel = client.to_channel_lazy().await; - assert!(channel.is_err()); - - // Set the endpoint - client.endpoint = "http://localhost:8080".to_string(); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set the tls settings - client.tls_setting.insecure = true; - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set the tls settings - client.tls_setting = { - let mut tls = TLSSetting::default(); - // Updated for new Config fields: set CA via ca_source and leave source as default (None) - tls.config.ca_source = CaSource::File { - path: format!("{}/testdata/grpc/{}", test_path, "ca.crt"), - }; - tls.insecure = false; - tls - }; - - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set keepalive settings - client.keepalive = Some(KeepaliveConfig::default()); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set rate limit settings - client.rate_limit = Some("100/10".to_string()); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set rate limit settings wrong - client.rate_limit = Some("100".to_string()); - channel = client.to_channel_lazy().await; - assert!(channel.is_err()); - - // reset config - client.rate_limit = None; - - // Set timeout settings - client.request_timeout = Duration::from_secs(10).into(); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set buffer size settings - client.buffer_size = Some(1024); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set origin settings - client.origin = Some("http://example.com".to_string()); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // set additional header to add to the request - client - .headers - .insert("X-Test".to_string(), "test".to_string()); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set proxy settings - client.proxy = ProxyConfig::new("http://proxy.example.com:8080"); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set proxy with authentication - client.proxy = ProxyConfig::new("http://proxy.example.com:8080").with_auth("user", "pass"); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set proxy with headers - let mut proxy_headers = HashMap::new(); - proxy_headers.insert("X-Proxy-Header".to_string(), "value".to_string()); - client.proxy = - ProxyConfig::new("http://proxy.example.com:8080").with_headers(proxy_headers); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set HTTPS proxy settings - client.proxy = ProxyConfig::new("https://proxy.example.com:8080"); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set HTTPS proxy with authentication - client.proxy = ProxyConfig::new("https://proxy.example.com:8080").with_auth("user", "pass"); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - - // Set HTTPS proxy with headers - let mut https_proxy_headers = HashMap::new(); - https_proxy_headers.insert("X-Proxy-Header".to_string(), "value".to_string()); - client.proxy = - ProxyConfig::new("https://proxy.example.com:8080").with_headers(https_proxy_headers); - channel = client.to_channel_lazy().await; - assert!(channel.is_ok()); - } - - #[tokio::test] - async fn test_to_channel_rejects_websocket_transport() { - let client = ClientConfig::with_endpoint("ws://localhost:46357") - .with_transport(TransportProtocol::Websocket); - let channel = client.to_channel_lazy().await; - assert!(matches!( - channel, - Err(ConfigError::GrpcChannelUnsupportedTransport) - )); - } - - #[test] - fn test_client_config_with_proxy() { - let proxy = ProxyConfig::new("http://proxy.example.com:8080").with_auth("user", "pass"); - let client = ClientConfig::with_endpoint("http://localhost:8080").with_proxy(proxy.clone()); - assert_eq!(client.proxy, proxy); - } - - #[test] - fn test_connect_and_request_timeout_valid_durations_deserialize() { - let json = r#"{ - "endpoint": "http://localhost:1234", - "connect_timeout": "1m30s", - "request_timeout": "250ms" - }"#; - - let cfg: ClientConfig = serde_json::from_str(json).expect("deserialization should succeed"); - assert_eq!(cfg.connect_timeout, Duration::from_secs(90)); - assert_eq!(cfg.request_timeout, Duration::from_millis(250)); - - // More complex duration - let json = r#"{ - "endpoint": "http://localhost:1234", - "connect_timeout": "1h2m3s4ms", - "request_timeout": "1500ms" - }"#; - let cfg: ClientConfig = - serde_json::from_str(json).expect("complex duration should deserialize"); - assert_eq!( - cfg.connect_timeout, - Duration::from_secs(3723) + Duration::from_millis(4) - ); - assert_eq!(cfg.request_timeout, Duration::from_millis(1500)); - } - - #[test] - fn test_invalid_duration_strings_fail_deserialize() { - let invalids = [ - r#"{ "endpoint": "http://localhost:1234", "connect_timeout": "abc" }"#, - r#"{ "endpoint": "http://localhost:1234", "request_timeout": "10x" }"#, - r#"{ "endpoint": "http://localhost:1234", "request_timeout": "--5s" }"#, - ]; - for js in invalids { - let res: Result = serde_json::from_str(js); - assert!(res.is_err(), "expected error for json: {}", js); - } - } - - #[test] - fn test_keepalive_config_duration_parsing() { - let json = r#"{ - "endpoint": "http://localhost:1234", - "keepalive": { - "tcp_keepalive": "30s", - "http2_keepalive": "45s", - "timeout": "5s", - "keep_alive_while_idle": true - } - }"#; - let cfg: ClientConfig = serde_json::from_str(json).expect("keepalive should deserialize"); - let ka = cfg.keepalive.expect("keepalive should be present"); - assert_eq!(ka.tcp_keepalive, Duration::from_secs(30)); - assert_eq!(ka.http2_keepalive, Duration::from_secs(45)); - assert_eq!(ka.timeout, Duration::from_secs(5)); - assert!(ka.keep_alive_while_idle); - - // Invalid keepalive duration - let invalid_json = r#"{ - "endpoint": "http://localhost:1234", - "keepalive": { "tcp_keepalive": "zz", "http2_keepalive": "10s", "timeout": "5s", "keep_alive_while_idle": false } - }"#; - let res: Result = serde_json::from_str(invalid_json); - assert!(res.is_err(), "invalid tcp_keepalive should fail"); - } - - #[test] - fn test_client_config_roundtrip_duration_serialization() { - let mut cfg = ClientConfig::with_endpoint("http://localhost:9999") - .with_connect_timeout(Duration::from_secs(90)) - .with_request_timeout(Duration::from_millis(750)); - - cfg.keepalive = Some(KeepaliveConfig { - tcp_keepalive: Duration::from_secs(11).into(), - http2_keepalive: Duration::from_secs(22).into(), - timeout: Duration::from_secs(3).into(), - keep_alive_while_idle: true, - }); - - let serialized = serde_json::to_string(&cfg).expect("serialize"); - let deserialized: ClientConfig = serde_json::from_str(&serialized).expect("deserialize"); - - assert_eq!(deserialized.connect_timeout, Duration::from_secs(90)); - assert_eq!(deserialized.request_timeout, Duration::from_millis(750)); - let ka = deserialized.keepalive.expect("keepalive present"); - assert_eq!(ka.tcp_keepalive, Duration::from_secs(11)); - assert_eq!(ka.http2_keepalive, Duration::from_secs(22)); - assert_eq!(ka.timeout, Duration::from_secs(3)); - assert!(ka.keep_alive_while_idle); - } - - #[test] - fn test_validate_rejects_non_uuid_link_id() { - let mut config = ClientConfig::with_endpoint("http://localhost:1234"); - config.link_id = "not-a-uuid".to_string(); - assert!(matches!(config.validate(), Err(ConfigError::InvalidLinkId))); - } - - #[test] - fn test_validate_rejects_non_v4_uuid() { - let mut config = ClientConfig::with_endpoint("http://localhost:1234"); - // Version 1 UUID. - config.link_id = "00000000-0000-1000-8000-000000000000".to_string(); - assert!(matches!(config.validate(), Err(ConfigError::InvalidLinkId))); - } - - #[test] - fn test_validate_accepts_default_v4_link_id() { - // default_link_id() generates a v4 UUID; validation must pass. - let config = ClientConfig::with_endpoint("http://localhost:1234"); - assert!(config.validate().is_ok()); + self.to_channel_internal(true).await } } diff --git a/data-plane/core/config/src/grpc/errors.rs b/data-plane/core/config/src/grpc/errors.rs index 2b5425598..762e817a1 100644 --- a/data-plane/core/config/src/grpc/errors.rs +++ b/data-plane/core/config/src/grpc/errors.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use crate::auth::ConfigAuthError; +#[cfg(feature = "native")] use slim_auth::errors::AuthError; use thiserror::Error; @@ -25,16 +26,48 @@ pub enum ConfigError { InvalidEndpointScheme, #[error("websocket transport requires endpoint scheme ws:// or wss://")] InvalidWebSocketEndpointScheme, + #[error("websocket client builder requires websocket transport")] + WebSocketClientUnsupportedTransport, + #[error("websocket server builder requires websocket transport")] + WebSocketServerUnsupportedTransport, + #[error("websocket transport TLS configuration is invalid")] + WebSocketTlsConfiguration, + #[error("websocket support is disabled at compile time")] + WebSocketFeatureDisabled, + #[error( + "browser websocket cannot set Authorization header; configure websocket_auth_query_param" + )] + WebSocketWasmAuthorizationHeaderUnsupported, + #[error("websocket wasm connection error: {0}")] + WebSocketWasmConnection(String), + #[error("websocket wasm client is only supported on wasm32 target")] + WebSocketWasmUnsupportedTarget, // Network / transport + #[cfg(feature = "native")] #[error("transport error")] TransportError(#[from] tonic::transport::Error), + #[cfg(not(feature = "native"))] + #[error("transport error")] + TransportError, + #[error("gRPC support is disabled at compile time")] + GrpcFeatureDisabled, #[error("gRPC channel builder does not support websocket transport")] GrpcChannelUnsupportedTransport, #[error("gRPC server builder does not support websocket transport")] GrpcServerUnsupportedTransport, #[error("bind error")] Bind(#[from] std::io::Error), + #[error("websocket connection error")] + WebSocketConnection(#[source] std::io::Error), + #[error("websocket handshake error")] + #[cfg(feature = "native")] + WebSocketHandshake(#[source] fastwebsockets::WebSocketError), + #[error("websocket handshake error: {0}")] + #[cfg(not(feature = "native"))] + WebSocketHandshake(String), + #[error("websocket request error")] + WebSocketRequest(#[source] http::Error), // Unix domain sockets #[error("unix domain sockets are unsupported on this platform")] @@ -65,6 +98,7 @@ pub enum ConfigError { TlsConfig(#[from] crate::tls::errors::ConfigError), // Authentication + #[cfg(feature = "native")] #[error("auth error")] AuthError(#[from] AuthError), #[error("auth config error")] diff --git a/data-plane/core/config/src/grpc/proxy.rs b/data-plane/core/config/src/grpc/proxy.rs index 09c58a759..05b475bfb 100644 --- a/data-plane/core/config/src/grpc/proxy.rs +++ b/data-plane/core/config/src/grpc/proxy.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; +#[cfg(feature = "native")] use hyper_util::client::proxy::matcher::{Intercept, Matcher}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -57,6 +58,7 @@ impl ProxyConfig { } /// Checks if the given host should bypass the proxy + #[cfg(feature = "native")] pub fn should_use_proxy(&self, uri: impl Into) -> Option { let uri = uri.into(); @@ -77,7 +79,7 @@ impl ProxyConfig { } } -#[cfg(test)] +#[cfg(all(test, feature = "native"))] mod test { use super::*; diff --git a/data-plane/core/config/src/grpc/server.rs b/data-plane/core/config/src/grpc/server.rs index 1ccd0150d..0b7b61228 100644 --- a/data-plane/core/config/src/grpc/server.rs +++ b/data-plane/core/config/src/grpc/server.rs @@ -1,445 +1,20 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 -use display_error_chain::ErrorChainExt; -use duration_string::DurationString; -use futures::FutureExt; -use futures::Stream; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +pub use crate::server::*; + use std::convert::Infallible; use std::future::Future; -#[cfg(target_family = "unix")] -use std::path::PathBuf; use std::pin::Pin; -use std::sync::Arc; -use std::{net::SocketAddr, str::FromStr, time::Duration}; -use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(target_family = "unix")] -use tokio::net::UnixListener; -#[cfg(target_family = "unix")] -use tokio_stream::wrappers::UnixListenerStream; + use tokio_util::sync::CancellationToken; -use tonic::transport::server::TcpIncoming; -use tower_http::BoxError; -use tracing::debug; use super::errors::ConfigError; -use crate::auth::ServerAuthenticator; -use crate::auth::basic::Config as BasicAuthenticationConfig; -use crate::auth::jwt::Config as JwtAuthenticationConfig; -#[cfg(not(target_family = "windows"))] -use crate::auth::spire::SpireConfig as SpireAuthConfig; -use crate::component::configuration::Configuration; -use crate::transport::TransportProtocol; -use slim_auth::metadata::MetadataMap; - -use crate::tls::{common::RustlsConfigLoader, server::TlsServerConfig as TLSSetting}; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] -pub struct KeepaliveServerParameters { - /// max_connection_idle sets the time after which an idle connection is closed. - #[serde(default = "default_max_connection_idle")] - #[schemars(with = "String")] - pub max_connection_idle: DurationString, - - /// max_connection_age sets the maximum amount of time a connection may exist before it will be closed. - #[serde(default = "default_max_connection_age")] - #[schemars(with = "String")] - pub max_connection_age: DurationString, - - /// max_connection_age_grace is an additional time given after MaxConnectionAge before closing the connection. - #[serde(default = "default_max_connection_age_grace")] - #[schemars(with = "String")] - pub max_connection_age_grace: DurationString, - - /// Time sets the frequency of the keepalive ping. - #[serde(default = "default_time")] - #[schemars(with = "String")] - pub time: DurationString, - - /// Timeout sets the amount of time the server waits for a keepalive ping ack. - #[serde(default = "default_timeout")] - #[schemars(with = "String")] - pub timeout: DurationString, -} - -/// Enum holding one configuration for the client. -#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] -#[serde(rename_all = "snake_case", tag = "type")] -pub enum AuthenticationConfig { - /// Basic authentication configuration. - Basic(BasicAuthenticationConfig), - /// JWT authentication configuration. - Jwt(JwtAuthenticationConfig), - /// SPIRE/SPIFFE authentication configuration. - #[cfg(not(target_family = "windows"))] - Spire(SpireAuthConfig), - /// None - #[default] - None, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] -pub struct ServerConfig { - /// Endpoint is the address to listen on. - pub endpoint: String, - - /// Transport protocol to use for dataplane communication. - #[serde(default)] - pub transport: TransportProtocol, - - /// Configures the protocol to use TLS. - #[serde(default, rename = "tls")] - pub tls_setting: TLSSetting, - - /// Use HTTP 2 only. - #[serde(default = "default_http2_only")] - pub http2_only: bool, - - /// Maximum size (in MiB) of messages accepted by the server. - pub max_frame_size: Option, - - /// MaxConcurrentStreams sets the limit on the number of concurrent streams to each ServerTransport. - pub max_concurrent_streams: Option, - - /// Max header list size - pub max_header_list_size: Option, - - /// ReadBufferSize for gRPC server. - // TODO(msardara): not implemented yet - pub read_buffer_size: Option, - - /// WriteBufferSize for gRPC server. - // TODO(msardara): not implemented yet - pub write_buffer_size: Option, - - /// Keepalive anchor for all the settings related to keepalive. - #[serde(default)] - pub keepalive: KeepaliveServerParameters, - /// Auth for this receiver. - #[serde(default)] - pub auth: AuthenticationConfig, - - /// Arbitrary user-provided metadata. - pub metadata: Option, -} - -/// Default values for KeepaliveServerParameters -impl Default for KeepaliveServerParameters { - fn default() -> Self { - Self { - max_connection_idle: default_max_connection_idle(), - max_connection_age: default_max_connection_age(), - max_connection_age_grace: default_max_connection_age_grace(), - time: default_time(), - timeout: default_timeout(), - } - } -} - -fn default_max_connection_idle() -> DurationString { - Duration::from_secs(3600).into() -} - -fn default_max_connection_age() -> DurationString { - Duration::from_secs(2 * 3600).into() -} - -fn default_max_connection_age_grace() -> DurationString { - Duration::from_secs(5 * 60).into() -} - -fn default_time() -> DurationString { - Duration::from_secs(2 * 60).into() -} - -fn default_timeout() -> DurationString { - Duration::from_secs(20).into() -} - -/// Default values for ServerConfig -impl Default for ServerConfig { - fn default() -> Self { - Self { - endpoint: String::new(), - transport: TransportProtocol::default(), - tls_setting: TLSSetting::default(), - http2_only: default_http2_only(), - max_frame_size: Some(4), - max_concurrent_streams: Some(100), - max_header_list_size: None, - read_buffer_size: Some(1024 * 1024), - write_buffer_size: Some(1024 * 1024), - keepalive: KeepaliveServerParameters::default(), - auth: AuthenticationConfig::default(), - metadata: None, - } - } -} - -fn default_http2_only() -> bool { - true -} - -/// Display implementation for ServerConfig -/// This is used to print the ServerConfig in a human-readable format. -impl std::fmt::Display for ServerConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "ServerConfig {{ endpoint: {}, transport: {:?}, tls_setting: {}, http2_only: {}, max_frame_size: {:?}, max_concurrent_streams: {:?}, max_header_list_size: {:?}, read_buffer_size: {:?}, write_buffer_size: {:?}, keepalive: {:?}, auth: {:?}, metadata: {:?} }}", - self.endpoint, - self.transport, - self.tls_setting, - self.http2_only, - self.max_frame_size, - self.max_concurrent_streams, - self.max_header_list_size, - self.read_buffer_size, - self.write_buffer_size, - self.keepalive, - self.auth, - self.metadata - ) - } -} - -#[cfg(test)] -mod metadata_tests { - use super::*; - - #[test] - fn server_config_with_metadata_roundtrip_yaml() { - let mut md = MetadataMap::default(); - md.insert("role", "ingress"); - md.insert("replicas", 3u64); - let mut nested = MetadataMap::default(); - nested.insert("inner", "v"); - md.insert("nested", nested); - - let cfg = ServerConfig { - endpoint: "127.0.0.1:50051".to_string(), - metadata: Some(md.clone()), - ..Default::default() - }; - - let s = serde_yaml::to_string(&cfg).expect("serialize"); - let deser: ServerConfig = serde_yaml::from_str(&s).expect("deserialize"); - assert_eq!(deser.metadata, Some(md)); - } -} - -impl Configuration for ServerConfig { - type Error = ConfigError; - - fn validate(&self) -> Result<(), Self::Error> { - // Validate the client configuration - self.tls_setting.validate()?; - - Ok(()) - } -} - -/// ServerFuture is a type alias for a boxed future that returns a Result<(), tonic::transport::Error>. +/// ServerFuture is a type alias for a boxed future that returns a tonic server result. type ServerFuture = Pin> + Send>>; -/// Convert ServerConfig to IncomingServerConfig -/// This function takes a ServerConfig and a service and returns a ServerFuture. -/// The ServerFuture is a boxed future that returns a Result<(), tonic::transport::Error>. -/// The ServerFuture is created by creating a new TcpIncoming and then creating a new Server. impl ServerConfig { - pub fn with_endpoint(endpoint: &str) -> Self { - Self { - endpoint: endpoint.to_string(), - ..Default::default() - } - } - - pub fn with_tls_settings(self, tls_setting: TLSSetting) -> Self { - Self { - tls_setting, - ..self - } - } - - pub fn with_transport(self, transport: TransportProtocol) -> Self { - Self { transport, ..self } - } - - pub fn with_http2_only(self, http2_only: bool) -> Self { - Self { http2_only, ..self } - } - - pub fn with_max_frame_size(self, max_frame_size: Option) -> Self { - Self { - max_frame_size, - ..self - } - } - - pub fn with_max_concurrent_streams(self, max_concurrent_streams: Option) -> Self { - Self { - max_concurrent_streams, - ..self - } - } - - pub fn with_max_header_list_size(self, max_header_list_size: Option) -> Self { - Self { - max_header_list_size, - ..self - } - } - - pub fn with_read_buffer_size(self, read_buffer_size: Option) -> Self { - Self { - read_buffer_size, - ..self - } - } - - pub fn with_write_buffer_size(self, write_buffer_size: Option) -> Self { - Self { - write_buffer_size, - ..self - } - } - - pub fn with_keepalive(self, keepalive: KeepaliveServerParameters) -> Self { - Self { keepalive, ..self } - } - - pub fn with_auth(self, auth: AuthenticationConfig) -> Self { - Self { auth, ..self } - } - - #[cfg(target_family = "unix")] - fn parse_unix_socket_path(endpoint: &str) -> Result { - let path = endpoint.strip_prefix("unix://").unwrap_or(endpoint); - - let without_query = path.split_once('?').map(|(p, _)| p).unwrap_or(path); - let path_part = without_query - .split_once('#') - .map(|(p, _)| p) - .unwrap_or(without_query); - - if path_part.is_empty() { - return Err(ConfigError::UnixSocketMissingPath); - } - - Ok(PathBuf::from(path_part)) - } - - fn create_server_builder(&self) -> tonic::transport::Server { - let builder: tonic::transport::Server = - tonic::transport::Server::builder().accept_http1(false); - - let builder = match self.max_concurrent_streams { - Some(max_concurrent_streams) => { - builder.concurrency_limit_per_connection(max_concurrent_streams as usize) - } - None => builder, - }; - - let builder = match self.max_frame_size { - Some(max_frame_size) => builder.max_frame_size(max_frame_size * 1024 * 1024), - None => builder, - }; - - let builder = match self.max_header_list_size { - Some(max_header_list_size) => builder.http2_max_header_list_size(max_header_list_size), - None => builder, - }; - - let builder = builder.http2_keepalive_interval(Some(self.keepalive.time.into())); - let builder = builder.http2_keepalive_timeout(Some(self.keepalive.timeout.into())); - - builder.max_connection_age(self.keepalive.max_connection_age.into()) - } - - async fn serve_with_incoming( - &self, - svc: &[S], - incoming: I, - ) -> Result - where - S: tower_service::Service< - http::Request, - Response = http::Response, - Error = Infallible, - > - + tonic::server::NamedService - + Clone - + Send - + 'static - + Sync, - S::Future: Send + 'static, - I: Stream> + Send + 'static, - IO: AsyncRead + AsyncWrite + tonic::transport::server::Connected + Unpin + Send + 'static, - IE: Into + Send + 'static, - { - let mut builder = self.create_server_builder(); - - match &self.auth { - AuthenticationConfig::Basic(basic) => { - let auth_layer = basic.get_server_layer()?; - - let mut builder = builder.layer(auth_layer); - - let mut router = builder.add_service(svc[0].clone()); - for s in svc.iter().skip(1) { - router = builder.add_service(s.clone()); - } - - Ok(router.serve_with_incoming(incoming).boxed()) - } - AuthenticationConfig::Jwt(jwt) => { - // Build the authentication layer and perform its async initialization - let mut auth_layer = , - >>::get_server_layer(jwt)?; - - auth_layer.initialize().await?; - - let mut builder = builder.layer(auth_layer); - - let mut router = builder.add_service(svc[0].clone()); - for s in svc.iter().skip(1) { - router = builder.add_service(s.clone()); - } - - Ok(router.serve_with_incoming(incoming).boxed()) - } - #[cfg(not(target_family = "windows"))] - AuthenticationConfig::Spire(spire) => { - let mut auth_layer = , - >>::get_server_layer(spire)?; - - auth_layer.initialize().await?; - - let mut builder = builder.layer(auth_layer); - - let mut router = builder.add_service(svc[0].clone()); - for s in svc.iter().skip(1) { - router = builder.add_service(s.clone()); - } - - Ok(router.serve_with_incoming(incoming).boxed()) - } - AuthenticationConfig::None => { - let mut router = builder.add_service(svc[0].clone()); - for s in svc.iter().skip(1) { - router = builder.add_service(s.clone()); - } - - Ok(router.serve_with_incoming(incoming).boxed()) - } - } - } - pub async fn to_server_future(&self, svc: &[S]) -> Result where S: tower_service::Service< @@ -454,61 +29,14 @@ impl ServerConfig { + Sync, S::Future: Send + 'static, { - if svc.is_empty() { - return Err(ConfigError::MissingServices); - } - - if self.transport == TransportProtocol::Websocket { - return Err(ConfigError::GrpcServerUnsupportedTransport); - } - - if self.endpoint.is_empty() { - return Err(ConfigError::MissingEndpoint); - } - - #[cfg(target_family = "unix")] - if self.endpoint.starts_with("unix://") { - if !self.tls_setting.insecure { - // For local Unix domain sockets we currently require insecure=true - return Err(ConfigError::UnixSocketTlsUnsupported); - } - - let socket_path = Self::parse_unix_socket_path(self.endpoint.as_str())?; - - // Best-effort cleanup of any stale socket file - let _ = std::fs::remove_file(&socket_path); - - let listener = UnixListener::bind(&socket_path)?; - let incoming = UnixListenerStream::new(listener); - - return self.serve_with_incoming(svc, incoming).await; - } - - #[cfg(not(target_family = "unix"))] - if self.endpoint.starts_with("unix://") { - return Err(ConfigError::UnixSocketUnsupported); - } - - let addr = SocketAddr::from_str(self.endpoint.as_str())?; - - // Async TLS configuration load (may involve SPIFFE operations) - let tls_config = self.tls_setting.load_rustls_config().await?; - let incoming = TcpIncoming::bind(addr)?; - - match tls_config { - Some(tls_config) => { - let incoming = tonic_tls::rustls::TlsIncoming::new(incoming, Arc::new(tls_config)); - self.serve_with_incoming(svc, incoming).await - } - None => self.serve_with_incoming(svc, incoming).await, - } + self.to_grpc_server_future(svc).await } pub async fn run_server( &self, svc: &[S], drain_rx: drain::Watch, - ) -> Result + ) -> Result where S: tower_service::Service< http::Request, @@ -522,216 +50,6 @@ impl ServerConfig { + Sync, S::Future: Send + 'static, { - debug!(%self, "server configured: setting it up"); - let server_future = self.to_server_future(svc).await?; - - // create a new cancellation token - let token = CancellationToken::new(); - let token_clone = token.clone(); - - // spawn server acceptor in a new task - tokio::spawn(async move { - debug!("starting server main loop"); - let shutdown = drain_rx.signaled(); - - tokio::select! { - res = server_future => { - match res { - Ok(_) => { - debug!("server shutdown"); - } - Err(e) => { - tracing::error!(error = %e.chain(), "server error"); - } - } - } - _ = shutdown => { - debug!("shutting down server"); - } - _ = token.cancelled() => { - debug!("cancellation token triggered: shutting down server"); - } - } - }); - - Ok(token_clone) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::testutils::{Empty, helloworld::greeter_server::GreeterServer}; - use crate::tls::common::TlsSource; - use serde_json; - - static TEST_DATA_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata/grpc"); - - #[test] - fn test_default_keepalive_server_parameters() { - let keepalive = KeepaliveServerParameters::default(); - assert_eq!(keepalive.max_connection_idle, default_max_connection_idle()); - assert_eq!(keepalive.max_connection_age, default_max_connection_age()); - assert_eq!( - keepalive.max_connection_age_grace, - default_max_connection_age_grace() - ); - assert_eq!(keepalive.time, default_time()); - assert_eq!(keepalive.timeout, default_timeout()); - } - - #[test] - fn test_default_server_config() { - let server_config = ServerConfig::default(); - assert_eq!(server_config.endpoint, String::new()); - assert_eq!(server_config.transport, TransportProtocol::Grpc); - assert_eq!(server_config.tls_setting, TLSSetting::default()); - assert_eq!(server_config.http2_only, default_http2_only()); - assert_eq!(server_config.max_frame_size, Some(4)); - assert_eq!(server_config.max_concurrent_streams, Some(100)); - assert_eq!(server_config.max_header_list_size, None); - assert_eq!(server_config.read_buffer_size, Some(1024 * 1024)); - assert_eq!(server_config.write_buffer_size, Some(1024 * 1024)); - assert_eq!( - server_config.keepalive, - KeepaliveServerParameters::default() - ); - assert_eq!(server_config.auth, AuthenticationConfig::None); - } - - #[tokio::test] - async fn test_to_incoming_server_config() { - let mut server_config = ServerConfig::default(); - let empty_service = Arc::new(Empty::new()); - - // no endpoint - should return an error - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) - .await; - // Make sure the error is a ConfigError::MissingEndpoint - assert!(ret.is_err_and(|e| { e.to_string().contains("missing grpc endpoint") })); - - // set the endpoint in the config. Now it shouhld fail because of the invalid endpoint - server_config.endpoint = "0.0.0.0:123456".to_string(); - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) - .await; - // Make sure we got an EndpointParse error - assert!(ret.is_err_and(|e| { matches!(e, ConfigError::EndpointParse(_)) })); - - // set a valid endpoint in the config. Now it should fail because of the missing cert/key files for tls - server_config.endpoint = "0.0.0.0:12345".to_string(); - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) - .await; - assert!(ret.is_err_and(|e| { matches!(e, ConfigError::TlsConfig(_)) })); - - // set the tls setting to insecure. Now it should return a server future - server_config.tls_setting.insecure = true; - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) - .await; - assert!(ret.is_ok()); - - // drop it, as we have a server listening on the port now - drop(ret.unwrap()); - - // Set insecure to false and configure certificate/key via TlsSource::File (updated API) - server_config.tls_setting.insecure = false; - server_config.tls_setting.config.source = TlsSource::File { - cert: format!("{}/server.crt", TEST_DATA_PATH), - key: format!("{}/server.key", TEST_DATA_PATH), - }; - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) - .await; - assert!(ret.is_ok()); - } - - #[tokio::test] - async fn test_to_server_future_rejects_websocket_transport() { - let empty_service = Arc::new(Empty::new()); - let server_config = ServerConfig::with_endpoint("0.0.0.0:12345") - .with_transport(TransportProtocol::Websocket); - let ret = server_config - .to_server_future(&[GreeterServer::from_arc(empty_service)]) - .await; - assert!(matches!( - ret, - Err(ConfigError::GrpcServerUnsupportedTransport) - )); - } - - #[test] - fn test_keepalive_server_parameters_valid_durations_deserialize() { - let json = r#"{ - "endpoint": "0.0.0.0:12345", - "keepalive": { - "max_connection_idle": "30m", - "max_connection_age": "1h30m", - "max_connection_age_grace": "15s", - "time": "5s", - "timeout": "2s" - } - }"#; - - let cfg: ServerConfig = serde_json::from_str(json).expect("deserialization should succeed"); - assert_eq!( - cfg.keepalive.max_connection_idle, - Duration::from_secs(30 * 60) - ); - assert_eq!( - cfg.keepalive.max_connection_age, - Duration::from_secs(90 * 60) - ); - assert_eq!( - cfg.keepalive.max_connection_age_grace, - Duration::from_secs(15) - ); - assert_eq!(cfg.keepalive.time, Duration::from_secs(5)); - assert_eq!(cfg.keepalive.timeout, Duration::from_secs(2)); - } - - #[test] - fn test_invalid_keepalive_duration_strings_fail_deserialize() { - let invalid_json_cases = [ - r#"{ "keepalive": { "time": "zz" } }"#, - r#"{ "keepalive": { "timeout": "-5s" } }"#, - r#"{ "keepalive": { "max_connection_age": "10x" } }"#, - ]; - for js in invalid_json_cases { - let res: Result = serde_json::from_str(js); - assert!(res.is_err(), "expected error for json: {}", js); - } - } - - #[test] - fn test_server_config_keepalive_roundtrip_duration_serialization() { - let keepalive = KeepaliveServerParameters { - max_connection_idle: Duration::from_secs(10).into(), - max_connection_age: Duration::from_secs(20).into(), - max_connection_age_grace: Duration::from_secs(30).into(), - time: Duration::from_secs(3).into(), - timeout: Duration::from_secs(1).into(), - }; - - let cfg = ServerConfig::with_endpoint("127.0.0.1:50000").with_keepalive(keepalive.clone()); - let serialized = serde_json::to_string(&cfg).expect("serialize"); - let deserialized: ServerConfig = serde_json::from_str(&serialized).expect("deserialize"); - - assert_eq!( - deserialized.keepalive.max_connection_idle, - Duration::from_secs(10) - ); - assert_eq!( - deserialized.keepalive.max_connection_age, - Duration::from_secs(20) - ); - assert_eq!( - deserialized.keepalive.max_connection_age_grace, - Duration::from_secs(30) - ); - assert_eq!(deserialized.keepalive.time, Duration::from_secs(3)); - assert_eq!(deserialized.keepalive.timeout, Duration::from_secs(1)); + self.run_grpc_server(svc, drain_rx).await } } diff --git a/data-plane/core/config/src/lib.rs b/data-plane/core/config/src/lib.rs index 71a7cc2d9..9e096b735 100644 --- a/data-plane/core/config/src/lib.rs +++ b/data-plane/core/config/src/lib.rs @@ -3,14 +3,18 @@ pub mod auth; pub mod backoff; +pub mod client; pub mod component; pub mod grpc; pub mod provider; +pub mod server; +#[cfg(feature = "native")] pub mod testutils; pub mod tls; pub mod transport; +pub mod websocket; mod opaque; -pub const CLIENT_CONFIG_SCHEMA_JSON: &str = include_str!("./grpc/schema/client-config.schema.json"); -pub const SERVER_CONFIG_SCHEMA_JSON: &str = include_str!("./grpc/schema/server-config.schema.json"); +pub const CLIENT_CONFIG_SCHEMA_JSON: &str = include_str!("./schema/client-config.schema.json"); +pub const SERVER_CONFIG_SCHEMA_JSON: &str = include_str!("./schema/server-config.schema.json"); diff --git a/data-plane/core/config/src/grpc/schema/client-config.schema.json b/data-plane/core/config/src/schema/client-config.schema.json similarity index 100% rename from data-plane/core/config/src/grpc/schema/client-config.schema.json rename to data-plane/core/config/src/schema/client-config.schema.json diff --git a/data-plane/core/config/src/grpc/schema/generate_schema.rs b/data-plane/core/config/src/schema/generate_schema.rs similarity index 80% rename from data-plane/core/config/src/grpc/schema/generate_schema.rs rename to data-plane/core/config/src/schema/generate_schema.rs index 2b3e654bd..6fdb2f33e 100644 --- a/data-plane/core/config/src/grpc/schema/generate_schema.rs +++ b/data-plane/core/config/src/schema/generate_schema.rs @@ -1,6 +1,6 @@ use schemars::{JsonSchema, schema_for}; -use slim_config::grpc::client::ClientConfig; -use slim_config::grpc::server::ServerConfig; +use slim_config::client::ClientConfig; +use slim_config::server::ServerConfig; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -10,7 +10,7 @@ fn write_schema(file_name: &str) { let schema_json = serde_json::to_string_pretty(&schema).unwrap(); let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - path.push(format!("src/grpc/schema/{}", file_name)); + path.push(format!("src/schema/{}", file_name)); let mut file = File::create(&path).unwrap(); file.write_all(schema_json.as_bytes()).unwrap(); diff --git a/data-plane/core/config/src/grpc/schema/server-config.schema.json b/data-plane/core/config/src/schema/server-config.schema.json similarity index 100% rename from data-plane/core/config/src/grpc/schema/server-config.schema.json rename to data-plane/core/config/src/schema/server-config.schema.json diff --git a/data-plane/core/config/src/server.rs b/data-plane/core/config/src/server.rs new file mode 100644 index 000000000..f4c72ba05 --- /dev/null +++ b/data-plane/core/config/src/server.rs @@ -0,0 +1,921 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(feature = "native")] +use display_error_chain::ErrorChainExt; +use duration_string::DurationString; +#[cfg(feature = "native")] +use futures::FutureExt; +#[cfg(feature = "native")] +use futures::Stream; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "native")] +use std::convert::Infallible; +#[cfg(feature = "native")] +use std::future::Future; +#[cfg(all(feature = "native", target_family = "unix"))] +use std::path::PathBuf; +#[cfg(feature = "native")] +use std::pin::Pin; +#[cfg(feature = "native")] +use std::sync::Arc; +use std::time::Duration; +#[cfg(feature = "native")] +use std::{net::SocketAddr, str::FromStr}; +#[cfg(feature = "native")] +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(all(feature = "native", target_family = "unix"))] +use tokio::net::UnixListener; +#[cfg(all(feature = "native", target_family = "unix"))] +use tokio_stream::wrappers::UnixListenerStream; +#[cfg(feature = "native")] +use tokio_util::sync::CancellationToken; +#[cfg(feature = "native")] +use tonic::transport::server::TcpIncoming; +#[cfg(feature = "native")] +use tower_http::BoxError; +#[cfg(feature = "native")] +use tracing::debug; + +#[cfg(feature = "native")] +use crate::auth::ServerAuthenticator; +use crate::auth::basic::Config as BasicAuthenticationConfig; +#[cfg(feature = "native")] +use crate::auth::jwt::Config as JwtAuthenticationConfig; +#[cfg(all(feature = "native", not(target_family = "windows")))] +use crate::auth::spire::SpireConfig as SpireAuthConfig; +use crate::component::configuration::Configuration; +use crate::grpc::errors::ConfigError; +use crate::transport::TransportProtocol; +#[cfg(feature = "native")] +use slim_auth::metadata::MetadataMap; +#[cfg(not(feature = "native"))] +type MetadataMap = std::collections::HashMap; + +#[cfg(feature = "native")] +use crate::tls::common::RustlsConfigLoader; +use crate::tls::server::TlsServerConfig as TLSSetting; + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] +pub struct KeepaliveServerParameters { + /// max_connection_idle sets the time after which an idle connection is closed. + #[serde(default = "default_max_connection_idle")] + #[schemars(with = "String")] + pub max_connection_idle: DurationString, + + /// max_connection_age sets the maximum amount of time a connection may exist before it will be closed. + #[serde(default = "default_max_connection_age")] + #[schemars(with = "String")] + pub max_connection_age: DurationString, + + /// max_connection_age_grace is an additional time given after MaxConnectionAge before closing the connection. + #[serde(default = "default_max_connection_age_grace")] + #[schemars(with = "String")] + pub max_connection_age_grace: DurationString, + + /// Time sets the frequency of the keepalive ping. + #[serde(default = "default_time")] + #[schemars(with = "String")] + pub time: DurationString, + + /// Timeout sets the amount of time the server waits for a keepalive ping ack. + #[serde(default = "default_timeout")] + #[schemars(with = "String")] + pub timeout: DurationString, +} + +/// Enum holding one configuration for the client. +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, JsonSchema)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum AuthenticationConfig { + /// Basic authentication configuration. + Basic(BasicAuthenticationConfig), + /// JWT authentication configuration. + #[cfg(feature = "native")] + Jwt(JwtAuthenticationConfig), + /// SPIRE/SPIFFE authentication configuration. + #[cfg(all(feature = "native", not(target_family = "windows")))] + Spire(SpireAuthConfig), + /// None + #[default] + None, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] +pub struct ServerConfig { + /// Endpoint is the address to listen on. + pub endpoint: String, + + /// Transport protocol to use for dataplane communication. + #[serde(default)] + pub transport: TransportProtocol, + + /// Configures the protocol to use TLS. + #[serde(default, rename = "tls")] + pub tls_setting: TLSSetting, + + /// Use HTTP 2 only. + #[serde(default = "default_http2_only")] + pub http2_only: bool, + + /// Maximum size (in MiB) of messages accepted by the server. + pub max_frame_size: Option, + + /// MaxConcurrentStreams sets the limit on the number of concurrent streams to each ServerTransport. + pub max_concurrent_streams: Option, + + /// Max header list size + pub max_header_list_size: Option, + + /// ReadBufferSize for gRPC server. + // TODO(msardara): not implemented yet + pub read_buffer_size: Option, + + /// WriteBufferSize for gRPC server. + // TODO(msardara): not implemented yet + pub write_buffer_size: Option, + + /// Keepalive anchor for all the settings related to keepalive. + #[serde(default)] + pub keepalive: KeepaliveServerParameters, + + /// Auth for this receiver. + #[serde(default)] + pub auth: AuthenticationConfig, + + /// Arbitrary user-provided metadata. + pub metadata: Option, +} + +/// Default values for KeepaliveServerParameters +impl Default for KeepaliveServerParameters { + fn default() -> Self { + Self { + max_connection_idle: default_max_connection_idle(), + max_connection_age: default_max_connection_age(), + max_connection_age_grace: default_max_connection_age_grace(), + time: default_time(), + timeout: default_timeout(), + } + } +} + +fn default_max_connection_idle() -> DurationString { + Duration::from_secs(3600).into() +} + +fn default_max_connection_age() -> DurationString { + Duration::from_secs(2 * 3600).into() +} + +fn default_max_connection_age_grace() -> DurationString { + Duration::from_secs(5 * 60).into() +} + +fn default_time() -> DurationString { + Duration::from_secs(2 * 60).into() +} + +fn default_timeout() -> DurationString { + Duration::from_secs(20).into() +} + +/// Default values for ServerConfig +impl Default for ServerConfig { + fn default() -> Self { + Self { + endpoint: String::new(), + transport: TransportProtocol::default(), + tls_setting: TLSSetting::default(), + http2_only: default_http2_only(), + max_frame_size: Some(4), + max_concurrent_streams: Some(100), + max_header_list_size: None, + read_buffer_size: Some(1024 * 1024), + write_buffer_size: Some(1024 * 1024), + keepalive: KeepaliveServerParameters::default(), + auth: AuthenticationConfig::default(), + metadata: None, + } + } +} + +fn default_http2_only() -> bool { + true +} + +/// Display implementation for ServerConfig +/// This is used to print the ServerConfig in a human-readable format. +impl std::fmt::Display for ServerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ServerConfig {{ endpoint: {}, transport: {:?}, tls_setting: {}, http2_only: {}, max_frame_size: {:?}, max_concurrent_streams: {:?}, max_header_list_size: {:?}, read_buffer_size: {:?}, write_buffer_size: {:?}, keepalive: {:?}, auth: {:?}, metadata: {:?} }}", + self.endpoint, + self.transport, + self.tls_setting, + self.http2_only, + self.max_frame_size, + self.max_concurrent_streams, + self.max_header_list_size, + self.read_buffer_size, + self.write_buffer_size, + self.keepalive, + self.auth, + self.metadata + ) + } +} + +#[cfg(test)] +mod metadata_tests { + use super::*; + + #[test] + fn server_config_with_metadata_roundtrip_yaml() { + let mut md = MetadataMap::default(); + md.insert("role", "ingress"); + md.insert("replicas", 3u64); + let mut nested = MetadataMap::default(); + nested.insert("inner", "v"); + md.insert("nested", nested); + + let cfg = ServerConfig { + endpoint: "127.0.0.1:50051".to_string(), + metadata: Some(md.clone()), + ..Default::default() + }; + + let s = serde_yaml::to_string(&cfg).expect("serialize"); + let deser: ServerConfig = serde_yaml::from_str(&s).expect("deserialize"); + assert_eq!(deser.metadata, Some(md)); + } +} + +impl Configuration for ServerConfig { + type Error = ConfigError; + + fn validate(&self) -> Result<(), Self::Error> { + // Validate the client configuration + self.tls_setting.validate()?; + + Ok(()) + } +} + +/// ServerFuture is a type alias for a boxed future that returns a Result<(), tonic::transport::Error>. +#[cfg(feature = "native")] +type ServerFuture = Pin> + Send>>; + +/// Convert ServerConfig to IncomingServerConfig +/// This function takes a ServerConfig and a service and returns a ServerFuture. +/// The ServerFuture is a boxed future that returns a Result<(), tonic::transport::Error>. +/// The ServerFuture is created by creating a new TcpIncoming and then creating a new Server. +impl ServerConfig { + pub fn with_endpoint(endpoint: &str) -> Self { + Self { + endpoint: endpoint.to_string(), + ..Default::default() + } + } + + pub fn with_tls_settings(self, tls_setting: TLSSetting) -> Self { + Self { + tls_setting, + ..self + } + } + + pub fn with_transport(self, transport: TransportProtocol) -> Self { + Self { transport, ..self } + } + + pub fn with_http2_only(self, http2_only: bool) -> Self { + Self { http2_only, ..self } + } + + pub fn with_max_frame_size(self, max_frame_size: Option) -> Self { + Self { + max_frame_size, + ..self + } + } + + pub fn with_max_concurrent_streams(self, max_concurrent_streams: Option) -> Self { + Self { + max_concurrent_streams, + ..self + } + } + + pub fn with_max_header_list_size(self, max_header_list_size: Option) -> Self { + Self { + max_header_list_size, + ..self + } + } + + pub fn with_read_buffer_size(self, read_buffer_size: Option) -> Self { + Self { + read_buffer_size, + ..self + } + } + + pub fn with_write_buffer_size(self, write_buffer_size: Option) -> Self { + Self { + write_buffer_size, + ..self + } + } + + pub fn with_keepalive(self, keepalive: KeepaliveServerParameters) -> Self { + Self { keepalive, ..self } + } + + pub fn with_auth(self, auth: AuthenticationConfig) -> Self { + Self { auth, ..self } + } + + #[cfg(all(feature = "native", target_family = "unix"))] + fn parse_unix_socket_path(endpoint: &str) -> Result { + let path = endpoint.strip_prefix("unix://").unwrap_or(endpoint); + + let without_query = path.split_once('?').map(|(p, _)| p).unwrap_or(path); + let path_part = without_query + .split_once('#') + .map(|(p, _)| p) + .unwrap_or(without_query); + + if path_part.is_empty() { + return Err(ConfigError::UnixSocketMissingPath); + } + + Ok(PathBuf::from(path_part)) + } + + #[cfg(feature = "native")] + fn create_server_builder(&self) -> tonic::transport::Server { + let builder: tonic::transport::Server = + tonic::transport::Server::builder().accept_http1(false); + + let builder = match self.max_concurrent_streams { + Some(max_concurrent_streams) => { + builder.concurrency_limit_per_connection(max_concurrent_streams as usize) + } + None => builder, + }; + + let builder = match self.max_frame_size { + Some(max_frame_size) => builder.max_frame_size(max_frame_size * 1024 * 1024), + None => builder, + }; + + let builder = match self.max_header_list_size { + Some(max_header_list_size) => builder.http2_max_header_list_size(max_header_list_size), + None => builder, + }; + + let builder = builder.http2_keepalive_interval(Some(self.keepalive.time.into())); + let builder = builder.http2_keepalive_timeout(Some(self.keepalive.timeout.into())); + + builder.max_connection_age(self.keepalive.max_connection_age.into()) + } + + #[cfg(feature = "native")] + async fn serve_with_incoming( + &self, + svc: &[S], + incoming: I, + ) -> Result + where + S: tower_service::Service< + http::Request, + Response = http::Response, + Error = Infallible, + > + + tonic::server::NamedService + + Clone + + Send + + 'static + + Sync, + S::Future: Send + 'static, + I: Stream> + Send + 'static, + IO: AsyncRead + AsyncWrite + tonic::transport::server::Connected + Unpin + Send + 'static, + IE: Into + Send + 'static, + { + let mut builder = self.create_server_builder(); + + match &self.auth { + AuthenticationConfig::Basic(basic) => { + let auth_layer = basic.get_server_layer()?; + + let mut builder = builder.layer(auth_layer); + + let mut router = builder.add_service(svc[0].clone()); + for s in svc.iter().skip(1) { + router = builder.add_service(s.clone()); + } + + Ok(router.serve_with_incoming(incoming).boxed()) + } + AuthenticationConfig::Jwt(jwt) => { + // Build the authentication layer and perform its async initialization + let mut auth_layer = , + >>::get_server_layer(jwt)?; + + auth_layer.initialize().await?; + + let mut builder = builder.layer(auth_layer); + + let mut router = builder.add_service(svc[0].clone()); + for s in svc.iter().skip(1) { + router = builder.add_service(s.clone()); + } + + Ok(router.serve_with_incoming(incoming).boxed()) + } + #[cfg(all(feature = "native", not(target_family = "windows")))] + AuthenticationConfig::Spire(spire) => { + let mut auth_layer = , + >>::get_server_layer(spire)?; + + auth_layer.initialize().await?; + + let mut builder = builder.layer(auth_layer); + + let mut router = builder.add_service(svc[0].clone()); + for s in svc.iter().skip(1) { + router = builder.add_service(s.clone()); + } + + Ok(router.serve_with_incoming(incoming).boxed()) + } + AuthenticationConfig::None => { + let mut router = builder.add_service(svc[0].clone()); + for s in svc.iter().skip(1) { + router = builder.add_service(s.clone()); + } + + Ok(router.serve_with_incoming(incoming).boxed()) + } + } + } + + #[cfg(feature = "native")] + pub(crate) async fn to_grpc_server_future( + &self, + svc: &[S], + ) -> Result + where + S: tower_service::Service< + http::Request, + Response = http::Response, + Error = Infallible, + > + + tonic::server::NamedService + + Clone + + Send + + 'static + + Sync, + S::Future: Send + 'static, + { + if svc.is_empty() { + return Err(ConfigError::MissingServices); + } + + if self.transport == TransportProtocol::Websocket { + return Err(ConfigError::GrpcServerUnsupportedTransport); + } + + if self.endpoint.is_empty() { + return Err(ConfigError::MissingEndpoint); + } + + #[cfg(target_family = "unix")] + if self.endpoint.starts_with("unix://") { + if !self.tls_setting.insecure { + // For local Unix domain sockets we currently require insecure=true + return Err(ConfigError::UnixSocketTlsUnsupported); + } + + let socket_path = Self::parse_unix_socket_path(self.endpoint.as_str())?; + + // Best-effort cleanup of any stale socket file + let _ = std::fs::remove_file(&socket_path); + + let listener = UnixListener::bind(&socket_path)?; + let incoming = UnixListenerStream::new(listener); + + return self.serve_with_incoming(svc, incoming).await; + } + + #[cfg(not(target_family = "unix"))] + if self.endpoint.starts_with("unix://") { + return Err(ConfigError::UnixSocketUnsupported); + } + + let addr = SocketAddr::from_str(self.endpoint.as_str())?; + + // Async TLS configuration load (may involve SPIFFE operations) + let tls_config = self.tls_setting.load_rustls_config().await?; + let incoming = TcpIncoming::bind(addr)?; + + match tls_config { + Some(tls_config) => { + let incoming = tonic_tls::rustls::TlsIncoming::new(incoming, Arc::new(tls_config)); + self.serve_with_incoming(svc, incoming).await + } + None => self.serve_with_incoming(svc, incoming).await, + } + } + + #[cfg(feature = "native")] + pub(crate) async fn run_grpc_server( + &self, + svc: &[S], + drain_rx: drain::Watch, + ) -> Result + where + S: tower_service::Service< + http::Request, + Response = http::Response, + Error = Infallible, + > + + tonic::server::NamedService + + Clone + + Send + + 'static + + Sync, + S::Future: Send + 'static, + { + debug!(%self, "server configured: setting it up"); + let server_future = self.to_grpc_server_future(svc).await?; + + // create a new cancellation token + let token = CancellationToken::new(); + let token_clone = token.clone(); + + // spawn server acceptor in a new task + tokio::spawn(async move { + debug!("starting server main loop"); + let shutdown = drain_rx.signaled(); + + tokio::select! { + res = server_future => { + match res { + Ok(_) => { + debug!("server shutdown"); + } + Err(e) => { + tracing::error!(error = %e.chain(), "server error"); + } + } + } + _ = shutdown => { + debug!("shutting down server"); + } + _ = token.cancelled() => { + debug!("cancellation token triggered: shutting down server"); + } + } + }); + + Ok(token_clone) + } +} + +#[cfg(all(test, feature = "native"))] +mod tests { + use super::*; + use crate::testutils::{Empty, helloworld::greeter_server::GreeterServer}; + use crate::tls::common::TlsSource; + use serde_json; + + static TEST_DATA_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata/grpc"); + + #[test] + fn test_default_keepalive_server_parameters() { + let keepalive = KeepaliveServerParameters::default(); + assert_eq!(keepalive.max_connection_idle, default_max_connection_idle()); + assert_eq!(keepalive.max_connection_age, default_max_connection_age()); + assert_eq!( + keepalive.max_connection_age_grace, + default_max_connection_age_grace() + ); + assert_eq!(keepalive.time, default_time()); + assert_eq!(keepalive.timeout, default_timeout()); + } + + #[test] + fn test_default_server_config() { + let server_config = ServerConfig::default(); + assert_eq!(server_config.endpoint, String::new()); + assert_eq!(server_config.transport, TransportProtocol::Grpc); + assert_eq!(server_config.tls_setting, TLSSetting::default()); + assert_eq!(server_config.http2_only, default_http2_only()); + assert_eq!(server_config.max_frame_size, Some(4)); + assert_eq!(server_config.max_concurrent_streams, Some(100)); + assert_eq!(server_config.max_header_list_size, None); + assert_eq!(server_config.read_buffer_size, Some(1024 * 1024)); + assert_eq!(server_config.write_buffer_size, Some(1024 * 1024)); + assert_eq!( + server_config.keepalive, + KeepaliveServerParameters::default() + ); + assert_eq!(server_config.auth, AuthenticationConfig::None); + } + + #[test] + fn test_with_http2_only() { + let default_value = default_http2_only(); + + let cfg = ServerConfig::default().with_http2_only(!default_value); + assert_eq!(cfg.http2_only, !default_value); + + let cfg = cfg.with_http2_only(default_value); + assert_eq!(cfg.http2_only, default_value); + + let endpoint = "127.0.0.1:50051"; + let cfg = ServerConfig::with_endpoint(endpoint).with_http2_only(false); + assert_eq!(cfg.endpoint, endpoint); + assert!(!cfg.http2_only); + assert_eq!(cfg.max_frame_size, ServerConfig::default().max_frame_size); + } + + #[test] + fn test_with_max_frame_size() { + let cfg = ServerConfig::default().with_max_frame_size(Some(8)); + assert_eq!(cfg.max_frame_size, Some(8)); + + let cfg = cfg.with_max_frame_size(None); + assert_eq!(cfg.max_frame_size, None); + + let cfg = ServerConfig::default().with_max_frame_size(Some(0)); + assert_eq!(cfg.max_frame_size, Some(0)); + + let cfg = ServerConfig::default().with_max_frame_size(Some(u32::MAX)); + assert_eq!(cfg.max_frame_size, Some(u32::MAX)); + + let original = ServerConfig::default(); + let updated = original.clone().with_max_frame_size(Some(16)); + assert_eq!(updated.max_frame_size, Some(16)); + assert_eq!(updated.endpoint, original.endpoint); + assert_eq!(updated.http2_only, original.http2_only); + assert_eq!( + updated.max_concurrent_streams, + original.max_concurrent_streams + ); + } + + #[test] + fn test_with_max_concurrent_streams() { + let cfg = ServerConfig::default().with_max_concurrent_streams(Some(250)); + assert_eq!(cfg.max_concurrent_streams, Some(250)); + + let cfg = cfg.with_max_concurrent_streams(None); + assert_eq!(cfg.max_concurrent_streams, None); + + let cfg = ServerConfig::default().with_max_concurrent_streams(Some(1)); + assert_eq!(cfg.max_concurrent_streams, Some(1)); + + let cfg = ServerConfig::default().with_max_concurrent_streams(Some(u32::MAX)); + assert_eq!(cfg.max_concurrent_streams, Some(u32::MAX)); + + let original = ServerConfig::default(); + let updated = original.clone().with_max_concurrent_streams(Some(500)); + assert_eq!(updated.max_concurrent_streams, Some(500)); + assert_eq!(updated.max_frame_size, original.max_frame_size); + assert_eq!(updated.http2_only, original.http2_only); + } + + #[test] + fn test_with_max_header_list_size() { + let cfg = ServerConfig::default().with_max_header_list_size(Some(8192)); + assert_eq!(cfg.max_header_list_size, Some(8192)); + + let cfg = cfg.with_max_header_list_size(None); + assert_eq!(cfg.max_header_list_size, None); + + let cfg = ServerConfig::default().with_max_header_list_size(Some(0)); + assert_eq!(cfg.max_header_list_size, Some(0)); + + let cfg = ServerConfig::default().with_max_header_list_size(Some(u32::MAX)); + assert_eq!(cfg.max_header_list_size, Some(u32::MAX)); + + let original = ServerConfig::default(); + let updated = original.clone().with_max_header_list_size(Some(4096)); + assert_eq!(updated.max_header_list_size, Some(4096)); + assert_eq!(updated.max_frame_size, original.max_frame_size); + assert_eq!( + updated.max_concurrent_streams, + original.max_concurrent_streams + ); + } + + #[test] + fn test_with_read_buffer_size() { + let cfg = ServerConfig::default().with_read_buffer_size(Some(2 * 1024 * 1024)); + assert_eq!(cfg.read_buffer_size, Some(2 * 1024 * 1024)); + + let cfg = cfg.with_read_buffer_size(None); + assert_eq!(cfg.read_buffer_size, None); + + let cfg = ServerConfig::default().with_read_buffer_size(Some(0)); + assert_eq!(cfg.read_buffer_size, Some(0)); + + let cfg = ServerConfig::default().with_read_buffer_size(Some(usize::MAX)); + assert_eq!(cfg.read_buffer_size, Some(usize::MAX)); + + let original = ServerConfig::default(); + let updated = original.clone().with_read_buffer_size(Some(64 * 1024)); + assert_eq!(updated.read_buffer_size, Some(64 * 1024)); + assert_eq!(updated.write_buffer_size, original.write_buffer_size); + assert_eq!(updated.http2_only, original.http2_only); + } + + #[test] + fn test_with_write_buffer_size() { + let cfg = ServerConfig::default().with_write_buffer_size(Some(2 * 1024 * 1024)); + assert_eq!(cfg.write_buffer_size, Some(2 * 1024 * 1024)); + + let cfg = cfg.with_write_buffer_size(None); + assert_eq!(cfg.write_buffer_size, None); + + let cfg = ServerConfig::default().with_write_buffer_size(Some(0)); + assert_eq!(cfg.write_buffer_size, Some(0)); + + let cfg = ServerConfig::default().with_write_buffer_size(Some(usize::MAX)); + assert_eq!(cfg.write_buffer_size, Some(usize::MAX)); + + let original = ServerConfig::default(); + let updated = original.clone().with_write_buffer_size(Some(64 * 1024)); + assert_eq!(updated.write_buffer_size, Some(64 * 1024)); + assert_eq!(updated.read_buffer_size, original.read_buffer_size); + assert_eq!(updated.http2_only, original.http2_only); + } + + #[test] + fn test_server_config_builders_chain_independently() { + let cfg = ServerConfig::with_endpoint("127.0.0.1:50051") + .with_http2_only(false) + .with_max_frame_size(Some(8)) + .with_max_concurrent_streams(Some(200)) + .with_max_header_list_size(Some(16384)) + .with_read_buffer_size(Some(2 * 1024 * 1024)) + .with_write_buffer_size(Some(4 * 1024 * 1024)); + + assert_eq!(cfg.endpoint, "127.0.0.1:50051"); + assert!(!cfg.http2_only); + assert_eq!(cfg.max_frame_size, Some(8)); + assert_eq!(cfg.max_concurrent_streams, Some(200)); + assert_eq!(cfg.max_header_list_size, Some(16384)); + assert_eq!(cfg.read_buffer_size, Some(2 * 1024 * 1024)); + assert_eq!(cfg.write_buffer_size, Some(4 * 1024 * 1024)); + + // Defaults preserved for unchanged fields. + assert_eq!(cfg.transport, TransportProtocol::default()); + assert_eq!(cfg.tls_setting, TLSSetting::default()); + assert_eq!(cfg.keepalive, KeepaliveServerParameters::default()); + assert_eq!(cfg.auth, AuthenticationConfig::default()); + } + + #[tokio::test] + async fn test_to_incoming_server_config() { + let mut server_config = ServerConfig::default(); + let empty_service = Arc::new(Empty::new()); + + // no endpoint - should return an error + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) + .await; + // Make sure the error is a ConfigError::MissingEndpoint + assert!(ret.is_err_and(|e| { e.to_string().contains("missing grpc endpoint") })); + + // set the endpoint in the config. Now it shouhld fail because of the invalid endpoint + server_config.endpoint = "0.0.0.0:123456".to_string(); + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) + .await; + // Make sure we got an EndpointParse error + assert!(ret.is_err_and(|e| { matches!(e, ConfigError::EndpointParse(_)) })); + + // set a valid endpoint in the config. Now it should fail because of the missing cert/key files for tls + server_config.endpoint = "0.0.0.0:12345".to_string(); + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) + .await; + assert!(ret.is_err_and(|e| { matches!(e, ConfigError::TlsConfig(_)) })); + + // set the tls setting to insecure. Now it should return a server future + server_config.tls_setting.insecure = true; + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) + .await; + assert!(ret.is_ok()); + + // drop it, as we have a server listening on the port now + drop(ret.unwrap()); + + // Set insecure to false and configure certificate/key via TlsSource::File (updated API) + server_config.tls_setting.insecure = false; + server_config.tls_setting.config.source = TlsSource::File { + cert: format!("{}/server.crt", TEST_DATA_PATH), + key: format!("{}/server.key", TEST_DATA_PATH), + }; + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service.clone())]) + .await; + assert!(ret.is_ok()); + } + + #[tokio::test] + async fn test_to_server_future_rejects_websocket_transport() { + let empty_service = Arc::new(Empty::new()); + let server_config = ServerConfig::with_endpoint("0.0.0.0:12345") + .with_transport(TransportProtocol::Websocket); + let ret = server_config + .to_server_future(&[GreeterServer::from_arc(empty_service)]) + .await; + assert!(matches!( + ret, + Err(ConfigError::GrpcServerUnsupportedTransport) + )); + } + + #[test] + fn test_keepalive_server_parameters_valid_durations_deserialize() { + let json = r#"{ + "endpoint": "0.0.0.0:12345", + "keepalive": { + "max_connection_idle": "30m", + "max_connection_age": "1h30m", + "max_connection_age_grace": "15s", + "time": "5s", + "timeout": "2s" + } + }"#; + + let cfg: ServerConfig = serde_json::from_str(json).expect("deserialization should succeed"); + assert_eq!( + cfg.keepalive.max_connection_idle, + Duration::from_secs(30 * 60) + ); + assert_eq!( + cfg.keepalive.max_connection_age, + Duration::from_secs(90 * 60) + ); + assert_eq!( + cfg.keepalive.max_connection_age_grace, + Duration::from_secs(15) + ); + assert_eq!(cfg.keepalive.time, Duration::from_secs(5)); + assert_eq!(cfg.keepalive.timeout, Duration::from_secs(2)); + } + + #[test] + fn test_invalid_keepalive_duration_strings_fail_deserialize() { + let invalid_json_cases = [ + r#"{ "keepalive": { "time": "zz" } }"#, + r#"{ "keepalive": { "timeout": "-5s" } }"#, + r#"{ "keepalive": { "max_connection_age": "10x" } }"#, + ]; + for js in invalid_json_cases { + let res: Result = serde_json::from_str(js); + assert!(res.is_err(), "expected error for json: {}", js); + } + } + + #[test] + fn test_server_config_keepalive_roundtrip_duration_serialization() { + let keepalive = KeepaliveServerParameters { + max_connection_idle: Duration::from_secs(10).into(), + max_connection_age: Duration::from_secs(20).into(), + max_connection_age_grace: Duration::from_secs(30).into(), + time: Duration::from_secs(3).into(), + timeout: Duration::from_secs(1).into(), + }; + + let cfg = ServerConfig::with_endpoint("127.0.0.1:50000").with_keepalive(keepalive.clone()); + let serialized = serde_json::to_string(&cfg).expect("serialize"); + let deserialized: ServerConfig = serde_json::from_str(&serialized).expect("deserialize"); + + assert_eq!( + deserialized.keepalive.max_connection_idle, + Duration::from_secs(10) + ); + assert_eq!( + deserialized.keepalive.max_connection_age, + Duration::from_secs(20) + ); + assert_eq!( + deserialized.keepalive.max_connection_age_grace, + Duration::from_secs(30) + ); + assert_eq!(deserialized.keepalive.time, Duration::from_secs(3)); + assert_eq!(deserialized.keepalive.timeout, Duration::from_secs(1)); + } +} diff --git a/data-plane/core/config/src/testutils.rs b/data-plane/core/config/src/testutils.rs index 96cc347d1..1b4f1f671 100644 --- a/data-plane/core/config/src/testutils.rs +++ b/data-plane/core/config/src/testutils.rs @@ -1,21 +1,26 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +#[cfg(feature = "native")] use tonic::{Request, Response, Status}; +#[cfg(feature = "native")] #[rustfmt::skip] pub mod helloworld; pub mod tower_service; +#[cfg(feature = "native")] #[derive(Default)] pub struct Empty {} +#[cfg(feature = "native")] impl Empty { pub fn new() -> Self { Self {} } } +#[cfg(feature = "native")] #[tonic::async_trait] impl helloworld::greeter_server::Greeter for Empty { async fn say_hello( diff --git a/data-plane/core/config/src/tls.rs b/data-plane/core/config/src/tls.rs index 287c4ab65..70773442b 100644 --- a/data-plane/core/config/src/tls.rs +++ b/data-plane/core/config/src/tls.rs @@ -1,11 +1,225 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +#[cfg(feature = "native")] pub mod client; +#[cfg(feature = "native")] pub mod common; +#[cfg(feature = "native")] pub mod errors; +#[cfg(feature = "native")] pub mod provider; +#[cfg(feature = "native")] pub mod root_store_builder; +#[cfg(feature = "native")] pub mod server; +#[cfg(feature = "native")] pub use root_store_builder::RootStoreBuilder; + +#[cfg(not(feature = "native"))] +pub mod errors { + use thiserror::Error; + + #[derive(Error, Debug)] + pub enum ConfigError { + #[error("TLS is unavailable for this build configuration")] + Unsupported, + } +} + +#[cfg(not(feature = "native"))] +pub mod client { + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + + use crate::component::configuration::Configuration; + use crate::tls::errors::ConfigError; + + #[derive(Debug, Serialize, Deserialize, PartialEq, Clone, JsonSchema)] + pub struct TlsClientConfig { + #[serde(default = "default_insecure")] + pub insecure: bool, + #[serde(default)] + pub insecure_skip_verify: bool, + } + + impl Default for TlsClientConfig { + fn default() -> Self { + Self { + insecure: default_insecure(), + insecure_skip_verify: false, + } + } + } + + fn default_insecure() -> bool { + false + } + + impl TlsClientConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn insecure() -> Self { + Self { + insecure: true, + ..Self::default() + } + } + + pub fn with_insecure_skip_verify(self, insecure_skip_verify: bool) -> Self { + Self { + insecure_skip_verify, + ..self + } + } + + pub fn with_insecure(self, insecure: bool) -> Self { + Self { insecure, ..self } + } + + pub fn with_ca_file(self, _ca_file: &str) -> Self { + self + } + + pub fn with_ca_pem(self, _ca_pem: &str) -> Self { + self + } + + pub fn with_include_system_ca_certs_pool(self, _include: bool) -> Self { + self + } + + pub fn with_cert_and_key_file(self, _cert_file: &str, _key_file: &str) -> Self { + self + } + + pub fn with_cert_and_key_pem(self, _cert_pem: &str, _key_pem: &str) -> Self { + self + } + + pub fn with_tls_version(self, _tls_version: &str) -> Self { + self + } + + pub fn with_reload_interval(self, _reload_interval: Option) -> Self { + self + } + } + + impl std::fmt::Display for TlsClientConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } + } + + impl Configuration for TlsClientConfig { + type Error = ConfigError; + + fn validate(&self) -> Result<(), Self::Error> { + Ok(()) + } + } +} + +#[cfg(not(feature = "native"))] +pub mod server { + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + + use crate::component::configuration::Configuration; + use crate::tls::errors::ConfigError; + + #[derive(Debug, Deserialize, Serialize, PartialEq, Clone, JsonSchema)] + pub struct TlsServerConfig { + #[serde(default = "default_insecure")] + pub insecure: bool, + #[serde(default = "default_reload_client_ca_file")] + pub reload_client_ca_file: bool, + } + + impl Default for TlsServerConfig { + fn default() -> Self { + Self { + insecure: default_insecure(), + reload_client_ca_file: default_reload_client_ca_file(), + } + } + } + + fn default_insecure() -> bool { + false + } + + fn default_reload_client_ca_file() -> bool { + false + } + + impl TlsServerConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn insecure() -> Self { + Self { + insecure: true, + ..Self::default() + } + } + + pub fn with_insecure(self, insecure: bool) -> Self { + Self { insecure, ..self } + } + + pub fn with_reload_client_ca_file(self, reload_client_ca_file: bool) -> Self { + Self { + reload_client_ca_file, + ..self + } + } + + pub fn with_ca_file(self, _ca_file: &str) -> Self { + self + } + + pub fn with_ca_pem(self, _ca_pem: &str) -> Self { + self + } + + pub fn with_include_system_ca_certs_pool(self, _include: bool) -> Self { + self + } + + pub fn with_cert_and_key_file(self, _cert_path: &str, _key_path: &str) -> Self { + self + } + + pub fn with_cert_and_key_pem(self, _cert_pem: &str, _key_pem: &str) -> Self { + self + } + + pub fn with_tls_version(self, _tls_version: &str) -> Self { + self + } + + pub fn with_reload_interval(self, _reload_interval: Option) -> Self { + self + } + } + + impl std::fmt::Display for TlsServerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } + } + + impl Configuration for TlsServerConfig { + type Error = ConfigError; + + fn validate(&self) -> Result<(), Self::Error> { + Ok(()) + } + } +} diff --git a/data-plane/core/config/src/tls/client.rs b/data-plane/core/config/src/tls/client.rs index 19212d6fd..5e5afbdb4 100644 --- a/data-plane/core/config/src/tls/client.rs +++ b/data-plane/core/config/src/tls/client.rs @@ -286,6 +286,9 @@ impl TlsClientConfig { return Ok(None); } + // Ensure rustls has a process-level crypto provider before any config builders run. + crate::tls::provider::initialize_crypto_provider(); + // Resolve TLS version let tls_version = match self.config.tls_version.as_str() { "tls1.2" => &TLS12, diff --git a/data-plane/core/config/src/tls/server.rs b/data-plane/core/config/src/tls/server.rs index e6acd3758..f16943843 100644 --- a/data-plane/core/config/src/tls/server.rs +++ b/data-plane/core/config/src/tls/server.rs @@ -229,6 +229,9 @@ impl TlsServerConfig { return Ok(None); } + // Ensure rustls has a process-level crypto provider before any config builders run. + crate::tls::provider::initialize_crypto_provider(); + // TLS version let tls_version = match self.config.tls_version.as_str() { "tls1.2" => &TLS12, diff --git a/data-plane/core/config/src/websocket.rs b/data-plane/core/config/src/websocket.rs new file mode 100644 index 000000000..3f49a587a --- /dev/null +++ b/data-plane/core/config/src/websocket.rs @@ -0,0 +1,22 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +// When both native and wasm features are enabled (e.g. via --all-features), +// native takes precedence via the not(feature = "native") guards below. + +#[cfg(feature = "native")] +#[path = "websocket/client.rs"] +pub mod client; +#[cfg(all(feature = "wasm", not(feature = "native")))] +#[path = "websocket/client_wasm.rs"] +pub mod client; + +#[cfg(feature = "native")] +#[path = "websocket/common.rs"] +pub mod common; +#[cfg(all(feature = "wasm", not(feature = "native")))] +#[path = "websocket/common_wasm.rs"] +pub mod common; + +#[cfg(feature = "native")] +pub mod server; diff --git a/data-plane/core/config/src/websocket/client.rs b/data-plane/core/config/src/websocket/client.rs new file mode 100644 index 000000000..d0dd293a7 --- /dev/null +++ b/data-plane/core/config/src/websocket/client.rs @@ -0,0 +1,312 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use std::future::Future; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; + +use bytes::Bytes; +use fastwebsockets::handshake; +use http::header::{AUTHORIZATION, CONNECTION, HOST, ORIGIN, UPGRADE}; +use http_body_util::Empty; +use hyper::Request; +use hyper::header::{HeaderName, HeaderValue}; +use tokio::net::TcpStream; +use tokio_rustls::TlsConnector; + +use crate::client::ClientConfig; +use crate::grpc::errors::ConfigError; +use crate::tls::common::RustlsConfigLoader; +use crate::transport::TransportProtocol; + +use super::common::{ + ClientHandshakeAuth, UpgradedWebSocket, WebSocketEndpoint, build_client_handshake_auth, +}; + +pub struct WebSocketClientChannel { + pub websocket: UpgradedWebSocket, + pub local_addr: Option, + pub remote_addr: Option, +} + +impl ClientConfig { + pub async fn to_websocket_channel(&self) -> Result { + if self.transport != TransportProtocol::Websocket { + return Err(ConfigError::WebSocketClientUnsupportedTransport); + } + + let endpoint = WebSocketEndpoint::parse(self.endpoint.as_str())?; + let auth = build_client_handshake_auth(self).await?; + + // TODO(hackeramitkumar): In query-param mode we should suppress Authorization headers + // for browser-compat behavior and rely only on the configured query parameter. + let query_param = self + .websocket_auth_query_param + .as_deref() + .zip(auth.bearer_token.as_deref()); + + let request_uri = endpoint.request_uri(query_param)?; + let request = build_handshake_request(self, &endpoint, request_uri, &auth)?; + + let stream = connect_tcp(self, &endpoint).await?; + let local_addr = stream.local_addr().ok(); + let remote_addr = stream.peer_addr().ok(); + + let websocket = if endpoint.secure { + let tls_config = self.tls_setting.load_rustls_config().await?; + let tls_config = tls_config.ok_or(ConfigError::WebSocketTlsConfiguration)?; + let connector = TlsConnector::from(Arc::new(tls_config)); + + let server_name = self + .server_name + .as_deref() + .unwrap_or(endpoint.host.as_str()); + let server_name = + tokio_rustls::rustls::pki_types::ServerName::try_from(server_name.to_string()) + .map_err(|_| ConfigError::WebSocketTlsConfiguration)?; + + let tls_stream = connector + .connect(server_name, stream) + .await + .map_err(ConfigError::WebSocketConnection)?; + + handshake::client(&SpawnExecutor, request, tls_stream) + .await + .map_err(ConfigError::WebSocketHandshake)? + .0 + } else { + handshake::client(&SpawnExecutor, request, stream) + .await + .map_err(ConfigError::WebSocketHandshake)? + .0 + }; + + Ok(WebSocketClientChannel { + websocket, + local_addr, + remote_addr, + }) + } +} + +fn build_handshake_request( + config: &ClientConfig, + endpoint: &WebSocketEndpoint, + uri: http::Uri, + auth: &ClientHandshakeAuth, +) -> Result>, ConfigError> { + let mut request = Request::builder() + .method("GET") + .uri(uri) + .header(HOST, endpoint.authority.as_str()) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header("Sec-WebSocket-Key", handshake::generate_key()) + .header("Sec-WebSocket-Version", "13") + .body(Empty::::new()) + .map_err(ConfigError::WebSocketRequest)?; + + let headers = request.headers_mut(); + + if let Some(origin) = config.origin.as_deref() { + headers.insert(ORIGIN, HeaderValue::from_str(origin)?); + } + + // TODO(hackeramitkumar): Skip this header when websocket_auth_query_param is active. + if let Some(auth_header) = auth.authorization_header.as_deref() { + headers.insert(AUTHORIZATION, HeaderValue::from_str(auth_header)?); + } + + for (name, value) in &config.headers { + headers.insert(HeaderName::from_str(name)?, HeaderValue::from_str(value)?); + } + + Ok(request) +} + +async fn connect_tcp( + config: &ClientConfig, + endpoint: &WebSocketEndpoint, +) -> Result { + let connect = TcpStream::connect(endpoint.socket_address()); + let timeout: std::time::Duration = config.connect_timeout.into(); + + if timeout.is_zero() { + return connect.await.map_err(ConfigError::WebSocketConnection); + } + + match tokio::time::timeout(timeout, connect).await { + Ok(result) => result.map_err(ConfigError::WebSocketConnection), + Err(_) => Err(ConfigError::WebSocketConnection(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "websocket connect timeout", + ))), + } +} + +struct SpawnExecutor; + +impl hyper::rt::Executor for SpawnExecutor +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::time::Duration; + + use super::*; + + #[test] + fn build_handshake_request_sets_required_and_optional_headers() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:8080/ws").expect("endpoint"); + let uri = endpoint.request_uri(None).expect("uri"); + + let mut headers = HashMap::new(); + headers.insert("x-test-header".to_string(), "x-value".to_string()); + let config = ClientConfig::with_endpoint("ws://localhost:8080/ws") + .with_transport(TransportProtocol::Websocket) + .with_origin("https://example.com") + .with_headers(headers); + + let auth = ClientHandshakeAuth { + authorization_header: Some("Bearer test-token".to_string()), + bearer_token: Some("test-token".to_string()), + }; + + let request = build_handshake_request(&config, &endpoint, uri, &auth).expect("request"); + assert_eq!(request.method(), http::Method::GET); + assert_eq!(request.uri().to_string(), "ws://localhost:8080/ws"); + + let headers = request.headers(); + assert_eq!( + headers.get(HOST).and_then(|v| v.to_str().ok()), + Some("localhost:8080") + ); + assert_eq!( + headers.get(UPGRADE).and_then(|v| v.to_str().ok()), + Some("websocket") + ); + assert_eq!( + headers.get(CONNECTION).and_then(|v| v.to_str().ok()), + Some("upgrade") + ); + assert_eq!( + headers.get(AUTHORIZATION).and_then(|v| v.to_str().ok()), + Some("Bearer test-token") + ); + assert_eq!( + headers.get(ORIGIN).and_then(|v| v.to_str().ok()), + Some("https://example.com") + ); + assert_eq!( + headers.get("x-test-header").and_then(|v| v.to_str().ok()), + Some("x-value") + ); + assert!(headers.contains_key("Sec-WebSocket-Key")); + assert_eq!( + headers + .get("Sec-WebSocket-Version") + .and_then(|v| v.to_str().ok()), + Some("13") + ); + } + + #[test] + fn build_handshake_request_rejects_invalid_origin_header() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:8080/ws").expect("endpoint"); + let uri = endpoint.request_uri(None).expect("uri"); + let config = ClientConfig::with_endpoint("ws://localhost:8080/ws") + .with_transport(TransportProtocol::Websocket) + .with_origin("https://example.com\ninvalid"); + + let err = build_handshake_request(&config, &endpoint, uri, &ClientHandshakeAuth::default()) + .expect_err("invalid origin must fail"); + assert!(matches!(err, ConfigError::HeaderValueParse(_))); + } + + #[test] + fn build_handshake_request_rejects_invalid_custom_header_name() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:8080/ws").expect("endpoint"); + let uri = endpoint.request_uri(None).expect("uri"); + let mut headers = HashMap::new(); + headers.insert("bad header".to_string(), "value".to_string()); + let config = ClientConfig::with_endpoint("ws://localhost:8080/ws") + .with_transport(TransportProtocol::Websocket) + .with_headers(headers); + + let err = build_handshake_request(&config, &endpoint, uri, &ClientHandshakeAuth::default()) + .expect_err("invalid header name must fail"); + assert!(matches!(err, ConfigError::HeaderNameParse(_))); + } + + #[test] + fn build_handshake_request_rejects_invalid_auth_header() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:8080/ws").expect("endpoint"); + let uri = endpoint.request_uri(None).expect("uri"); + let config = ClientConfig::with_endpoint("ws://localhost:8080/ws") + .with_transport(TransportProtocol::Websocket); + let auth = ClientHandshakeAuth { + authorization_header: Some("Bearer invalid\nvalue".to_string()), + bearer_token: None, + }; + + let err = build_handshake_request(&config, &endpoint, uri, &auth) + .expect_err("invalid auth header must fail"); + assert!(matches!(err, ConfigError::HeaderValueParse(_))); + } + + #[tokio::test] + async fn connect_tcp_with_zero_timeout_connects_successfully() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("listener bind"); + let port = listener.local_addr().expect("local addr").port(); + + let accept_task = tokio::spawn(async move { + let _ = listener.accept().await.expect("accept"); + }); + + let endpoint = + WebSocketEndpoint::parse(&format!("ws://127.0.0.1:{port}/ws")).expect("endpoint parse"); + let config = ClientConfig::with_endpoint(&format!("ws://127.0.0.1:{port}/ws")) + .with_transport(TransportProtocol::Websocket) + .with_connect_timeout(Duration::ZERO); + + let stream = connect_tcp(&config, &endpoint).await.expect("connect"); + assert!(stream.peer_addr().is_ok()); + drop(stream); + accept_task.await.expect("accept task"); + } + + #[tokio::test] + async fn to_websocket_channel_rejects_non_websocket_transport() { + let config = ClientConfig::with_endpoint("http://127.0.0.1:12345"); + let err = match config.to_websocket_channel().await { + Ok(_) => panic!("must reject grpc transport"), + Err(err) => err, + }; + assert!(matches!( + err, + ConfigError::WebSocketClientUnsupportedTransport + )); + } + + #[tokio::test] + async fn to_websocket_channel_rejects_invalid_endpoint_scheme() { + let config = ClientConfig::with_endpoint("http://127.0.0.1:12345") + .with_transport(TransportProtocol::Websocket); + let err = match config.to_websocket_channel().await { + Ok(_) => panic!("must reject non ws scheme"), + Err(err) => err, + }; + assert!(matches!(err, ConfigError::InvalidWebSocketEndpointScheme)); + } +} diff --git a/data-plane/core/config/src/websocket/client_wasm.rs b/data-plane/core/config/src/websocket/client_wasm.rs new file mode 100644 index 000000000..bcc0c1b87 --- /dev/null +++ b/data-plane/core/config/src/websocket/client_wasm.rs @@ -0,0 +1,66 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use std::net::SocketAddr; + +#[cfg(target_arch = "wasm32")] +use super::common::{WebSocketEndpoint, build_client_handshake_auth}; +use crate::client::ClientConfig; +use crate::grpc::errors::ConfigError; +#[cfg(target_arch = "wasm32")] +use crate::transport::TransportProtocol; + +#[cfg(target_arch = "wasm32")] +use gloo_net::websocket::futures::WebSocket; + +#[cfg(target_arch = "wasm32")] +pub struct WebSocketClientChannel { + pub websocket: WebSocket, + pub local_addr: Option, + pub remote_addr: Option, +} + +#[cfg(not(target_arch = "wasm32"))] +pub struct WebSocketClientChannel { + pub local_addr: Option, + pub remote_addr: Option, +} + +#[cfg(target_arch = "wasm32")] +impl ClientConfig { + pub async fn to_websocket_channel(&self) -> Result { + if self.transport != TransportProtocol::Websocket { + return Err(ConfigError::WebSocketClientUnsupportedTransport); + } + + let endpoint = WebSocketEndpoint::parse(self.endpoint.as_str())?; + let auth = build_client_handshake_auth(self).await?; + + if auth.authorization_header.is_some() && self.websocket_auth_query_param.is_none() { + return Err(ConfigError::WebSocketWasmAuthorizationHeaderUnsupported); + } + + let query_param = self + .websocket_auth_query_param + .as_deref() + .zip(auth.bearer_token.as_deref()); + let request_uri = endpoint.request_uri(query_param)?; + + let websocket = WebSocket::open(request_uri.to_string().as_str()) + .map_err(|err| ConfigError::WebSocketWasmConnection(err.to_string()))?; + + Ok(WebSocketClientChannel { + websocket, + local_addr: None, + remote_addr: None, + }) + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl ClientConfig { + pub async fn to_websocket_channel(&self) -> Result { + let _ = self; + Err(ConfigError::WebSocketWasmUnsupportedTarget) + } +} diff --git a/data-plane/core/config/src/websocket/common.rs b/data-plane/core/config/src/websocket/common.rs new file mode 100644 index 000000000..98f8aef9d --- /dev/null +++ b/data-plane/core/config/src/websocket/common.rs @@ -0,0 +1,566 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use std::str::FromStr; + +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use fastwebsockets::WebSocket; +use hyper::Request; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; +use slim_auth::traits::{TokenProvider, Verifier}; +use tracing::warn; + +use crate::client::{AuthenticationConfig as ClientAuthConfig, ClientConfig}; +use crate::grpc::errors::ConfigError; +use crate::server::AuthenticationConfig as ServerAuthConfig; + +pub type UpgradedWebSocket = WebSocket>; + +#[derive(Debug, Clone)] +pub struct WebSocketEndpoint { + pub uri: http::Uri, + pub secure: bool, + pub host: String, + pub authority: String, + pub port: u16, + pub path: String, +} + +#[derive(Debug, Clone, Default)] +pub struct ClientHandshakeAuth { + pub authorization_header: Option, + pub bearer_token: Option, +} + +impl WebSocketEndpoint { + pub fn parse(endpoint: &str) -> Result { + let uri = http::Uri::from_str(endpoint)?; + + let secure = match uri.scheme_str() { + Some("ws") => false, + Some("wss") => true, + _ => return Err(ConfigError::InvalidWebSocketEndpointScheme), + }; + + let authority = uri + .authority() + .ok_or(ConfigError::InvalidWebSocketEndpointScheme)? + .as_str() + .to_string(); + let host = uri + .host() + .ok_or(ConfigError::InvalidWebSocketEndpointScheme)? + .to_string(); + let port = uri.port_u16().unwrap_or(if secure { 443 } else { 80 }); + + let mut path = uri.path().to_string(); + if path.is_empty() { + path.push('/'); + } + + Ok(Self { + uri, + secure, + host, + authority, + port, + path, + }) + } + + pub fn socket_address(&self) -> String { + format!("{}:{}", self.host, self.port) + } + + pub fn request_uri(&self, query_param: Option<(&str, &str)>) -> Result { + let mut uri = format!( + "{}://{}{}", + if self.secure { "wss" } else { "ws" }, + self.authority, + self.path, + ); + + if let Some(existing_query) = self.uri.query() { + uri.push('?'); + uri.push_str(existing_query); + } + + if let Some((key, value)) = query_param + && !key.is_empty() + && !value.is_empty() + { + if self.uri.query().is_some() { + uri.push('&'); + } else { + uri.push('?'); + } + uri.push_str(key); + uri.push('='); + uri.push_str(value); + } + + http::Uri::from_str(&uri).map_err(ConfigError::from) + } +} + +pub async fn build_client_handshake_auth( + config: &ClientConfig, +) -> Result { + match &config.auth { + ClientAuthConfig::None => Ok(ClientHandshakeAuth::default()), + ClientAuthConfig::Basic(basic) => { + let encoded = BASE64_STANDARD.encode(format!( + "{}:{}", + basic.username(), + basic.password().as_str() + )); + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Basic {}", encoded)), + bearer_token: None, + }) + } + ClientAuthConfig::StaticJwt(static_jwt) => { + let mut provider = static_jwt.build_static_token_provider()?; + provider.initialize().await?; + let token = provider.get_token()?; + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Bearer {}", token)), + bearer_token: Some(token), + }) + } + ClientAuthConfig::Jwt(jwt) => { + let mut provider = jwt.get_provider()?; + provider.initialize().await?; + let token = provider.get_token()?; + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Bearer {}", token)), + bearer_token: Some(token), + }) + } + #[cfg(all(feature = "native", not(target_family = "windows")))] + ClientAuthConfig::Spire(spire) => { + let mut provider = spire.create_provider()?; + provider.initialize().await?; + let token = provider.get_token()?; + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Bearer {}", token)), + bearer_token: Some(token), + }) + } + } +} + +pub async fn authorize_server_handshake( + auth: &ServerAuthConfig, + request: &Request, +) -> bool { + // TODO(hackeramitkumar): Return structured rejection reasons so websocket server can + // report auth failures through existing connection-state reporting hooks. + match auth { + ServerAuthConfig::None => true, + ServerAuthConfig::Basic(basic) => { + let auth = match request + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + { + Some(value) => value, + None => return false, + }; + + let Some(encoded) = auth.strip_prefix("Basic ") else { + return false; + }; + + let decoded = match BASE64_STANDARD.decode(encoded.as_bytes()) { + Ok(bytes) => bytes, + Err(_) => return false, + }; + + let credentials = match std::str::from_utf8(&decoded) { + Ok(credentials) => credentials, + Err(_) => return false, + }; + + credentials == format!("{}:{}", basic.username(), basic.password().as_str()) + } + ServerAuthConfig::Jwt(jwt) => { + let token = request + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|header| header.strip_prefix("Bearer ")) + .map(str::to_string) + .or_else(|| extract_query_param(request.uri().query(), "token")); + + let Some(token) = token else { + return false; + }; + + let mut verifier = match jwt.get_verifier() { + Ok(verifier) => verifier, + Err(err) => { + warn!(error = %err, "failed to create websocket JWT verifier"); + return false; + } + }; + + if verifier.initialize().await.is_err() { + return false; + } + + verifier.verify(token).await.is_ok() + } + #[cfg(all(feature = "native", not(target_family = "windows")))] + ServerAuthConfig::Spire(spire) => { + let token = request + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|header| header.strip_prefix("Bearer ")) + .map(str::to_string) + .or_else(|| extract_query_param(request.uri().query(), "token")); + + let Some(token) = token else { + return false; + }; + + let mut verifier = match spire.create_provider() { + Ok(verifier) => verifier, + Err(err) => { + warn!(error = %err, "failed to create websocket SPIRE verifier"); + return false; + } + }; + + if verifier.initialize().await.is_err() { + return false; + } + + verifier.verify(token).await.is_ok() + } + } +} + +fn extract_query_param(query: Option<&str>, name: &str) -> Option { + let query = query?; + + for pair in query.split('&') { + let mut parts = pair.splitn(2, '='); + let key = parts.next().unwrap_or_default(); + let value = parts.next().unwrap_or_default(); + + if key == name && !value.is_empty() { + return Some(value.to_string()); + } + } + + None +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use http::header::AUTHORIZATION; + + use super::*; + use crate::auth::basic::Config as BasicConfig; + use crate::auth::jwt::{Claims, Config as JwtConfig, JwtKey}; + use crate::auth::static_jwt::Config as StaticJwtConfig; + use crate::client::AuthenticationConfig as ClientAuthenticationConfig; + use crate::server::AuthenticationConfig as ServerAuthenticationConfig; + use slim_auth::jwt::{Algorithm, Key, KeyData, KeyFormat}; + use slim_auth::traits::TokenProvider; + + fn unique_temp_file(prefix: &str) -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after unix epoch") + .as_nanos(); + std::env::temp_dir().join(format!("{prefix}_{nanos}.jwt")) + } + + fn hs256_key(secret: &str) -> Key { + Key { + algorithm: Algorithm::HS256, + format: KeyFormat::Pem, + key: KeyData::Data(secret.to_string()), + } + } + + fn default_claims() -> Claims { + Claims::default().with_subject("websocket-test-subject") + } + + #[test] + fn websocket_endpoint_parse_ws() { + let endpoint = WebSocketEndpoint::parse("ws://example.com:8080/socket") + .expect("ws endpoint should parse"); + assert!(!endpoint.secure); + assert_eq!(endpoint.host, "example.com"); + assert_eq!(endpoint.authority, "example.com:8080"); + assert_eq!(endpoint.port, 8080); + assert_eq!(endpoint.path, "/socket"); + assert_eq!(endpoint.socket_address(), "example.com:8080"); + } + + #[test] + fn websocket_endpoint_parse_wss_default_port() { + let endpoint = + WebSocketEndpoint::parse("wss://example.com/").expect("wss endpoint should parse"); + assert!(endpoint.secure); + assert_eq!(endpoint.port, 443); + assert_eq!(endpoint.path, "/"); + assert_eq!(endpoint.socket_address(), "example.com:443"); + } + + #[test] + fn websocket_endpoint_parse_rejects_non_ws_scheme() { + let err = WebSocketEndpoint::parse("http://example.com").expect_err("must fail"); + assert!(matches!(err, ConfigError::InvalidWebSocketEndpointScheme)); + } + + #[test] + fn websocket_endpoint_request_uri_appends_query_param() { + let endpoint = + WebSocketEndpoint::parse("ws://localhost:9000/ws?existing=1").expect("must parse"); + let uri = endpoint + .request_uri(Some(("token", "abc"))) + .expect("request uri should build"); + assert_eq!( + uri.to_string(), + "ws://localhost:9000/ws?existing=1&token=abc" + ); + } + + #[test] + fn websocket_endpoint_request_uri_ignores_empty_query_param() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:9000/ws").expect("must parse"); + let uri = endpoint + .request_uri(Some(("token", ""))) + .expect("request uri should build"); + assert_eq!(uri.to_string(), "ws://localhost:9000/ws"); + } + + #[test] + fn websocket_endpoint_request_uri_rejects_invalid_query_value() { + let endpoint = WebSocketEndpoint::parse("ws://localhost:9000/ws").expect("must parse"); + let err = endpoint + .request_uri(Some(("token", "a b"))) + .expect_err("must fail invalid URI"); + assert!(matches!(err, ConfigError::UriParse(_))); + } + + #[test] + fn extract_query_param_returns_expected_value() { + let token = extract_query_param(Some("foo=bar&token=abc123"), "token"); + assert_eq!(token, Some("abc123".to_string())); + } + + #[test] + fn extract_query_param_returns_none_for_missing_or_empty() { + assert_eq!(extract_query_param(Some("foo=bar"), "token"), None); + assert_eq!(extract_query_param(Some("token="), "token"), None); + assert_eq!(extract_query_param(None, "token"), None); + } + + #[tokio::test] + async fn build_client_handshake_auth_none() { + let cfg = ClientConfig::with_endpoint("ws://localhost:46357"); + let auth = build_client_handshake_auth(&cfg) + .await + .expect("none auth should succeed"); + assert_eq!(auth.authorization_header, None); + assert_eq!(auth.bearer_token, None); + } + + #[tokio::test] + async fn build_client_handshake_auth_basic() { + // codeql[rust/hard-coded-cryptographic-value] + let cfg = ClientConfig::with_endpoint("ws://localhost:46357").with_auth( + ClientAuthenticationConfig::Basic(BasicConfig::new("alice", "secret")), + ); + let auth = build_client_handshake_auth(&cfg) + .await + .expect("basic auth should succeed"); + assert_eq!(auth.bearer_token, None); + assert_eq!( + auth.authorization_header, + Some("Basic YWxpY2U6c2VjcmV0".into()) + ); + } + + #[tokio::test] + async fn build_client_handshake_auth_static_jwt() { + let path = unique_temp_file("ws_static_token"); + std::fs::write(&path, "STATIC_TOKEN").expect("must write token file"); + let cfg = ClientConfig::with_endpoint("ws://localhost:46357").with_auth( + ClientAuthenticationConfig::StaticJwt(StaticJwtConfig::with_file( + path.to_string_lossy().to_string(), + )), + ); + + let auth = build_client_handshake_auth(&cfg) + .await + .expect("static jwt auth should succeed"); + assert_eq!(auth.bearer_token.as_deref(), Some("STATIC_TOKEN")); + assert_eq!( + auth.authorization_header.as_deref(), + Some("Bearer STATIC_TOKEN") + ); + + let _ = std::fs::remove_file(path); + } + + #[tokio::test] + async fn build_client_handshake_auth_jwt() { + let cfg = ClientConfig::with_endpoint("ws://localhost:46357").with_auth( + ClientAuthenticationConfig::Jwt(JwtConfig::new( + default_claims(), + Duration::from_secs(60), + JwtKey::Encoding(hs256_key("shared-secret")), + )), + ); + + let auth = build_client_handshake_auth(&cfg) + .await + .expect("jwt auth should succeed"); + let token = auth + .bearer_token + .as_ref() + .expect("jwt token should be present"); + assert!(!token.is_empty()); + assert_eq!( + auth.authorization_header.as_deref(), + Some(format!("Bearer {token}").as_str()) + ); + } + + #[tokio::test] + async fn authorize_server_handshake_none_allows_request() { + let req = Request::builder().uri("/ws").body(()).expect("request"); + let allowed = authorize_server_handshake(&ServerAuthenticationConfig::None, &req).await; + assert!(allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_basic_accepts_valid_credentials() { + let req = Request::builder() + .uri("/ws") + .header(AUTHORIZATION, "Basic YWxpY2U6c2VjcmV0") + .body(()) + .expect("request"); + // codeql[rust/hard-coded-cryptographic-value] + let auth = ServerAuthenticationConfig::Basic(BasicConfig::new("alice", "secret")); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_basic_rejects_bad_header() { + let req = Request::builder() + .uri("/ws") + .header(AUTHORIZATION, "Bearer token") + .body(()) + .expect("request"); + // codeql[rust/hard-coded-cryptographic-value] + let auth = ServerAuthenticationConfig::Basic(BasicConfig::new("alice", "secret")); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(!allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_basic_rejects_bad_base64() { + let req = Request::builder() + .uri("/ws") + .header(AUTHORIZATION, "Basic !!!not-base64!!!") + .body(()) + .expect("request"); + // codeql[rust/hard-coded-cryptographic-value] + let auth = ServerAuthenticationConfig::Basic(BasicConfig::new("alice", "secret")); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(!allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_jwt_rejects_when_token_missing() { + let auth = ServerAuthenticationConfig::Jwt(JwtConfig::new( + default_claims(), + Duration::from_secs(60), + JwtKey::Decoding(hs256_key("shared-secret")), + )); + let req = Request::builder().uri("/ws").body(()).expect("request"); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(!allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_jwt_rejects_invalid_verifier_config() { + let auth = ServerAuthenticationConfig::Jwt(JwtConfig::new( + default_claims(), + Duration::from_secs(60), + JwtKey::Encoding(hs256_key("shared-secret")), + )); + let req = Request::builder() + .uri("/ws") + .header(AUTHORIZATION, "Bearer test-token") + .body(()) + .expect("request"); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(!allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_jwt_accepts_valid_bearer_header() { + let claims = default_claims(); + let signer_cfg = JwtConfig::new( + claims.clone(), + Duration::from_secs(60), + JwtKey::Encoding(hs256_key("shared-secret")), + ); + let mut signer = signer_cfg.get_provider().expect("signer"); + signer.initialize().await.expect("signer init"); + let token = signer.get_token().expect("token"); + + let auth = ServerAuthenticationConfig::Jwt(JwtConfig::new( + claims, + Duration::from_secs(60), + JwtKey::Decoding(hs256_key("shared-secret")), + )); + let req = Request::builder() + .uri("/ws") + .header(AUTHORIZATION, format!("Bearer {token}")) + .body(()) + .expect("request"); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(allowed); + } + + #[tokio::test] + async fn authorize_server_handshake_jwt_accepts_query_param_token() { + let claims = default_claims(); + let signer_cfg = JwtConfig::new( + claims.clone(), + Duration::from_secs(60), + JwtKey::Encoding(hs256_key("shared-secret")), + ); + let mut signer = signer_cfg.get_provider().expect("signer"); + signer.initialize().await.expect("signer init"); + let token = signer.get_token().expect("token"); + + let auth = ServerAuthenticationConfig::Jwt(JwtConfig::new( + claims, + Duration::from_secs(60), + JwtKey::Decoding(hs256_key("shared-secret")), + )); + let req = Request::builder() + .uri(format!("/ws?token={token}")) + .body(()) + .expect("request"); + let allowed = authorize_server_handshake(&auth, &req).await; + assert!(allowed); + } +} diff --git a/data-plane/core/config/src/websocket/common_wasm.rs b/data-plane/core/config/src/websocket/common_wasm.rs new file mode 100644 index 000000000..1a6575a6b --- /dev/null +++ b/data-plane/core/config/src/websocket/common_wasm.rs @@ -0,0 +1,127 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use std::str::FromStr; + +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; + +use crate::client::{AuthenticationConfig as ClientAuthConfig, ClientConfig}; +use crate::grpc::errors::ConfigError; + +#[derive(Debug, Clone)] +pub struct WebSocketEndpoint { + pub uri: http::Uri, + pub secure: bool, + pub authority: String, + pub path: String, +} + +#[derive(Debug, Clone, Default)] +pub struct ClientHandshakeAuth { + pub authorization_header: Option, + pub bearer_token: Option, +} + +impl WebSocketEndpoint { + pub fn parse(endpoint: &str) -> Result { + let uri = http::Uri::from_str(endpoint)?; + + let secure = match uri.scheme_str() { + Some("ws") => false, + Some("wss") => true, + _ => return Err(ConfigError::InvalidWebSocketEndpointScheme), + }; + + let authority = uri + .authority() + .ok_or(ConfigError::InvalidWebSocketEndpointScheme)? + .as_str() + .to_string(); + + let mut path = uri.path().to_string(); + if path.is_empty() { + path.push('/'); + } + + Ok(Self { + uri, + secure, + authority, + path, + }) + } + + pub fn request_uri(&self, query_param: Option<(&str, &str)>) -> Result { + let mut uri = format!( + "{}://{}{}", + if self.secure { "wss" } else { "ws" }, + self.authority, + self.path, + ); + + if let Some(existing_query) = self.uri.query() { + uri.push('?'); + uri.push_str(existing_query); + } + + if let Some((key, value)) = query_param + && !key.is_empty() + && !value.is_empty() + { + if self.uri.query().is_some() { + uri.push('&'); + } else { + uri.push('?'); + } + uri.push_str(key); + uri.push('='); + uri.push_str(value); + } + + http::Uri::from_str(&uri).map_err(ConfigError::from) + } +} + +pub async fn build_client_handshake_auth( + config: &ClientConfig, +) -> Result { + match &config.auth { + ClientAuthConfig::None => Ok(ClientHandshakeAuth::default()), + ClientAuthConfig::Basic(basic) => { + let encoded = BASE64_STANDARD.encode(format!( + "{}:{}", + basic.username(), + basic.password().as_str() + )); + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Basic {}", encoded)), + bearer_token: None, + }) + } + #[cfg(feature = "native")] + ClientAuthConfig::StaticJwt(static_jwt) => { + use slim_auth::traits::TokenProvider; + + let mut provider = static_jwt.build_static_token_provider()?; + provider.initialize().await?; + let token = provider.get_token()?; + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Bearer {}", token)), + bearer_token: Some(token), + }) + } + #[cfg(feature = "native")] + ClientAuthConfig::Jwt(jwt) => { + use slim_auth::traits::TokenProvider; + + let mut provider = jwt.get_provider()?; + provider.initialize().await?; + let token = provider.get_token()?; + Ok(ClientHandshakeAuth { + authorization_header: Some(format!("Bearer {}", token)), + bearer_token: Some(token), + }) + } + } +} diff --git a/data-plane/core/config/src/websocket/server.rs b/data-plane/core/config/src/websocket/server.rs new file mode 100644 index 000000000..67f571a59 --- /dev/null +++ b/data-plane/core/config/src/websocket/server.rs @@ -0,0 +1,414 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use std::convert::Infallible; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use bytes::Bytes; +use fastwebsockets::upgrade; +use http_body_util::Empty; +use hyper::Request; +use hyper::Response; +use hyper::StatusCode; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use tokio_util::sync::CancellationToken; +use tracing::{debug, warn}; + +use crate::grpc::errors::ConfigError; +use crate::server::ServerConfig; +use crate::tls::common::RustlsConfigLoader; +use crate::transport::TransportProtocol; + +use super::common::{UpgradedWebSocket, WebSocketEndpoint, authorize_server_handshake}; + +pub struct AcceptedWebSocketConnection { + pub websocket: UpgradedWebSocket, + pub remote_addr: Option, + pub local_addr: Option, +} + +pub type OnAcceptedWebSocket = Arc< + dyn Fn(AcceptedWebSocketConnection) -> Pin + Send>> + Send + Sync, +>; + +impl ServerConfig { + pub async fn run_websocket_server( + &self, + drain_rx: drain::Watch, + on_accepted: OnAcceptedWebSocket, + ) -> Result { + if self.transport != TransportProtocol::Websocket { + return Err(ConfigError::WebSocketServerUnsupportedTransport); + } + + let endpoint = WebSocketEndpoint::parse(self.endpoint.as_str())?; + let listener = TcpListener::bind(endpoint.socket_address()).await?; + + let tls_config = self.tls_setting.load_rustls_config().await?; + let tls_acceptor = match (endpoint.secure, tls_config) { + (true, Some(config)) => Some(TlsAcceptor::from(Arc::new(config))), + (true, None) => return Err(ConfigError::WebSocketTlsConfiguration), + (false, Some(_)) => return Err(ConfigError::WebSocketTlsConfiguration), + (false, None) => None, + }; + + let auth = self.auth.clone(); + let expected_path = endpoint.path.clone(); + + let cancellation_token = CancellationToken::new(); + let cancel_clone = cancellation_token.clone(); + + tokio::spawn(async move { + let mut drain_signal = std::pin::pin!(drain_rx.signaled()); + + loop { + tokio::select! { + _ = &mut drain_signal => { + debug!("websocket server shutting down on drain"); + break; + } + _ = cancel_clone.cancelled() => { + debug!("websocket server shutting down on cancellation token"); + break; + } + accepted = listener.accept() => { + let (stream, remote_addr) = match accepted { + Ok(val) => val, + Err(err) => { + warn!(error = %err, "websocket accept error"); + continue; + } + }; + + let local_addr = stream.local_addr().ok(); + let auth = auth.clone(); + let expected_path = expected_path.clone(); + let on_accepted = on_accepted.clone(); + + if let Some(acceptor) = tls_acceptor.clone() { + tokio::spawn(async move { + let stream = match acceptor.accept(stream).await { + Ok(stream) => stream, + Err(err) => { + warn!(error = %err, "websocket TLS accept error"); + return; + } + }; + + serve_connection( + stream, + auth, + expected_path, + on_accepted, + remote_addr, + local_addr, + ) + .await; + }); + } else { + tokio::spawn(async move { + serve_connection( + stream, + auth, + expected_path, + on_accepted, + remote_addr, + local_addr, + ) + .await; + }); + } + } + } + } + }); + + Ok(cancellation_token) + } +} + +async fn serve_connection( + stream: S, + auth: crate::server::AuthenticationConfig, + expected_path: String, + on_accepted: OnAcceptedWebSocket, + remote_addr: SocketAddr, + local_addr: Option, +) where + S: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let io = TokioIo::new(stream); + let service = service_fn(move |mut request: Request| { + let auth = auth.clone(); + let expected_path = expected_path.clone(); + let on_accepted = on_accepted.clone(); + + async move { + if request.uri().path() != expected_path { + return Ok::>, Infallible>(response_with_status( + StatusCode::NOT_FOUND, + )); + } + + if !upgrade::is_upgrade_request(&request) { + return Ok::>, Infallible>(response_with_status( + StatusCode::BAD_REQUEST, + )); + } + + if !authorize_server_handshake(&auth, &request).await { + return Ok::>, Infallible>(response_with_status( + StatusCode::UNAUTHORIZED, + )); + } + + match upgrade::upgrade(&mut request) { + Ok((response, future)) => { + tokio::spawn(async move { + match future.await { + Ok(websocket) => { + on_accepted(AcceptedWebSocketConnection { + websocket, + remote_addr: Some(remote_addr), + local_addr, + }) + .await; + } + Err(err) => { + warn!(error = %err, "websocket upgrade error"); + } + } + }); + + Ok::>, Infallible>(response) + } + Err(err) => { + warn!(error = %err, "websocket upgrade rejected"); + Ok::>, Infallible>(response_with_status( + StatusCode::BAD_REQUEST, + )) + } + } + } + }); + + let connection = http1::Builder::new() + .serve_connection(io, service) + .with_upgrades(); + + if let Err(err) = connection.await { + debug!(error = %err, "websocket HTTP connection closed with error"); + } +} + +fn response_with_status(status: StatusCode) -> Response> { + Response::builder() + .status(status) + .body(Empty::new()) + .expect("valid websocket HTTP response") +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + use crate::auth::basic::Config as BasicAuthConfig; + use crate::auth::jwt::{Claims, Config as JwtConfig, JwtKey}; + use crate::server::AuthenticationConfig; + use crate::tls::server::TlsServerConfig; + use slim_auth::jwt::{Algorithm, Key, KeyData, KeyFormat}; + + fn hs256_key(secret: &str) -> Key { + Key { + algorithm: Algorithm::HS256, + format: KeyFormat::Pem, + key: KeyData::Data(secret.to_string()), + } + } + + fn websocket_upgrade_request(path: &str, auth_header: Option<&str>) -> String { + let mut request = format!( + "GET {path} HTTP/1.1\r\n\ + Host: localhost\r\n\ + Connection: Upgrade\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n" + ); + if let Some(value) = auth_header { + request.push_str(&format!("Authorization: {value}\r\n")); + } + request.push_str("\r\n"); + request + } + + async fn run_serve_connection_case( + auth: AuthenticationConfig, + expected_path: &str, + raw_request: &str, + ) -> String { + let (mut client_io, server_io) = tokio::io::duplex(8 * 1024); + let on_accepted: OnAcceptedWebSocket = Arc::new(|_accepted| Box::pin(async move {})); + let remote_addr = SocketAddr::from(([127, 0, 0, 1], 51000)); + + let mut serve_task = tokio::spawn(serve_connection( + server_io, + auth, + expected_path.to_string(), + on_accepted, + remote_addr, + None, + )); + + client_io + .write_all(raw_request.as_bytes()) + .await + .expect("write request"); + + let mut buffer = [0_u8; 2048]; + let n = tokio::time::timeout(Duration::from_secs(2), client_io.read(&mut buffer)) + .await + .expect("timed out waiting for response") + .expect("read response"); + let response = String::from_utf8_lossy(&buffer[..n]).to_string(); + + drop(client_io); + if tokio::time::timeout(Duration::from_secs(2), &mut serve_task) + .await + .is_err() + { + serve_task.abort(); + } + + response + } + + #[test] + fn response_with_status_sets_http_status() { + let response = response_with_status(StatusCode::UNAUTHORIZED); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn run_websocket_server_rejects_non_websocket_transport() { + let config = ServerConfig::with_endpoint("127.0.0.1:0"); + let (_signal, watch) = drain::channel(); + let on_accepted: OnAcceptedWebSocket = Arc::new(|_accepted| Box::pin(async move {})); + + let err = config + .run_websocket_server(watch, on_accepted) + .await + .expect_err("must reject grpc transport"); + assert!(matches!( + err, + ConfigError::WebSocketServerUnsupportedTransport + )); + } + + #[tokio::test] + async fn run_websocket_server_rejects_wss_without_tls_config() { + let config = ServerConfig::with_endpoint("wss://127.0.0.1:0") + .with_transport(TransportProtocol::Websocket) + .with_tls_settings(TlsServerConfig::insecure()); + let (_signal, watch) = drain::channel(); + let on_accepted: OnAcceptedWebSocket = Arc::new(|_accepted| Box::pin(async move {})); + + let err = config + .run_websocket_server(watch, on_accepted) + .await + .expect_err("must reject missing tls config for wss"); + assert!(matches!(err, ConfigError::WebSocketTlsConfiguration)); + } + + #[tokio::test] + async fn run_websocket_server_rejects_ws_with_tls_configured() { + let testdata_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/testdata/grpc"); + let tls = TlsServerConfig::new().with_cert_and_key_file( + &format!("{testdata_dir}/server.crt"), + &format!("{testdata_dir}/server.key"), + ); + let config = ServerConfig::with_endpoint("ws://127.0.0.1:0") + .with_transport(TransportProtocol::Websocket) + .with_tls_settings(tls); + let (_signal, watch) = drain::channel(); + let on_accepted: OnAcceptedWebSocket = Arc::new(|_accepted| Box::pin(async move {})); + + let err = config + .run_websocket_server(watch, on_accepted) + .await + .expect_err("must reject tls for ws"); + assert!(matches!(err, ConfigError::WebSocketTlsConfiguration)); + } + + #[tokio::test] + async fn serve_connection_returns_not_found_for_wrong_path() { + let response = run_serve_connection_case( + AuthenticationConfig::None, + "/ws", + "GET /other HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .await; + assert!(response.starts_with("HTTP/1.1 404")); + } + + #[tokio::test] + async fn serve_connection_returns_bad_request_for_non_upgrade_request() { + let response = run_serve_connection_case( + AuthenticationConfig::None, + "/ws", + "GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .await; + assert!(response.starts_with("HTTP/1.1 400")); + } + + #[tokio::test] + async fn serve_connection_returns_unauthorized_for_invalid_basic_credentials() { + // codeql[rust/hard-coded-cryptographic-value] + let response = run_serve_connection_case( + AuthenticationConfig::Basic(BasicAuthConfig::new("alice", "secret")), + "/ws", + &websocket_upgrade_request("/ws", Some("Basic YWxpY2U6d3Jvbmc=")), + ) + .await; + assert!(response.starts_with("HTTP/1.1 401")); + } + + #[tokio::test] + async fn serve_connection_returns_unauthorized_for_invalid_jwt_config() { + let response = run_serve_connection_case( + AuthenticationConfig::Jwt(JwtConfig::new( + Claims::default(), + Duration::from_secs(60), + JwtKey::Encoding(hs256_key("secret")), + )), + "/ws", + &websocket_upgrade_request("/ws", Some("Bearer token")), + ) + .await; + assert!(response.starts_with("HTTP/1.1 401")); + } + + #[tokio::test] + async fn serve_connection_returns_switching_protocols_for_valid_upgrade() { + let response = run_serve_connection_case( + AuthenticationConfig::None, + "/ws", + &websocket_upgrade_request("/ws", None), + ) + .await; + assert!(response.starts_with("HTTP/1.1 101")); + } +} diff --git a/data-plane/core/config/tests/e2e.rs b/data-plane/core/config/tests/e2e.rs index 2689c7591..63bca7ca2 100644 --- a/data-plane/core/config/tests/e2e.rs +++ b/data-plane/core/config/tests/e2e.rs @@ -4,7 +4,7 @@ use tonic::{Request, Response, Status, metadata::KeyAndValueRef}; use tracing::info; -use slim_config::grpc::client::ClientConfig; +use slim_config::client::ClientConfig; use slim_config::testutils::helloworld::greeter_server::Greeter; use slim_config::testutils::helloworld::{HelloReply, HelloRequest}; @@ -93,10 +93,10 @@ mod tests { // use slim_config_grpc::headers_middleware::SetRequestHeader; use slim_auth::jwt::{Key, KeyData}; use slim_auth::traits::Signer; - use slim_config::grpc::{client::ClientConfig, server::ServerConfig}; use slim_config::testutils::helloworld::HelloRequest; use slim_config::testutils::helloworld::greeter_client::GreeterClient; use slim_config::testutils::helloworld::greeter_server::GreeterServer; + use slim_config::{client::ClientConfig, server::ServerConfig}; use slim_testing::utils::setup_test_jwt_resolver; #[cfg(unix)] use { @@ -146,7 +146,7 @@ mod tests { }); // create a client using the channel - let channel = match client_config.to_channel().await { + let channel = match client_config.to_grpc_channel().await { Ok(ch) => ch, Err(e) => return Err(Box::new(e)), }; @@ -199,7 +199,7 @@ mod tests { yield_now().await; // Use the config-driven client to connect over the Unix socket - let channel = client_config.to_channel().await?; + let channel = client_config.to_grpc_channel().await?; let mut client = GreeterClient::new(channel); let request = tonic::Request::new(HelloRequest { @@ -314,7 +314,7 @@ mod tests { // create a new client with wrong credentials let channel = client_config .with_auth(auth_wrong_client_config) - .to_channel() + .to_grpc_channel() .await?; let mut client = GreeterClient::new(channel); diff --git a/data-plane/core/controller/Cargo.toml b/data-plane/core/controller/Cargo.toml index 0b2a6f586..e6f840933 100644 --- a/data-plane/core/controller/Cargo.toml +++ b/data-plane/core/controller/Cargo.toml @@ -10,18 +10,44 @@ include = ["src/**", "build.rs", "proto/**"] [lib] name = "slim_controller" +[features] +default = ["native"] +native = [ + "agntcy-slim-auth/native", + "agntcy-slim-config/native", + "agntcy-slim-datapath/native", + "agntcy-slim-session/native", + "agntcy-slim-signal/native", + "agntcy-slim-tracing/native", + "display-error-chain", + "drain", + "h2", + "tokio", + "tokio-stream", + "tokio-util", + "tonic", + "tonic-prost", +] +wasm = [ + "agntcy-slim-auth/wasm", + "agntcy-slim-config/wasm", + "agntcy-slim-datapath/wasm", + "agntcy-slim-session/wasm", + "agntcy-slim-tracing/wasm", +] + [dependencies] agntcy-slim-auth = { workspace = true } agntcy-slim-config = { workspace = true } agntcy-slim-datapath = { workspace = true } agntcy-slim-session = { workspace = true } -agntcy-slim-signal = { workspace = true } +agntcy-slim-signal = { workspace = true, optional = true } agntcy-slim-tracing = { workspace = true } agntcy-slim-version = { workspace = true } -display-error-chain = { workspace = true } -drain = { workspace = true } +display-error-chain = { workspace = true, optional = true } +drain = { workspace = true, optional = true } duration-string = { workspace = true } -h2 = { workspace = true } +h2 = { workspace = true, optional = true } parking_lot = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } @@ -29,11 +55,11 @@ rand = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } -tokio-stream = { workspace = true } -tokio-util = { workspace = true } -tonic = { workspace = true } -tonic-prost = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"], optional = true } +tokio-stream = { workspace = true, optional = true } +tokio-util = { workspace = true, optional = true } +tonic = { workspace = true, optional = true } +tonic-prost = { workspace = true, optional = true } tracing = { workspace = true } uuid = { workspace = true, features = ["v4"] } diff --git a/data-plane/core/controller/build.rs b/data-plane/core/controller/build.rs index b2c9ccdf5..ba95b3b08 100644 --- a/data-plane/core/controller/build.rs +++ b/data-plane/core/controller/build.rs @@ -33,6 +33,8 @@ fn main() { tonic_prost_build::configure() .out_dir("src/api/gen") + .server_mod_attribute("controller.proto.v1", "#[cfg(feature = \"native\")]") + .client_mod_attribute("controller.proto.v1", "#[cfg(feature = \"native\")]") .compile_protos( &[proto_file.to_str().unwrap()], &[std::path::Path::new(&manifest_dir) diff --git a/data-plane/core/controller/src/api/gen/controller.proto.v1.rs b/data-plane/core/controller/src/api/gen/controller.proto.v1.rs index 5bade8b9d..1cddf364c 100644 --- a/data-plane/core/controller/src/api/gen/controller.proto.v1.rs +++ b/data-plane/core/controller/src/api/gen/controller.proto.v1.rs @@ -337,6 +337,7 @@ impl ConnectionDirection { } } /// Generated client implementations. +#[cfg(feature = "native")] pub mod controller_service_client { #![allow( unused_variables, @@ -459,6 +460,7 @@ pub mod controller_service_client { } } /// Generated server implementations. +#[cfg(feature = "native")] pub mod controller_service_server { #![allow( unused_variables, diff --git a/data-plane/core/controller/src/config.rs b/data-plane/core/controller/src/config.rs index 6d95a99b3..792108647 100644 --- a/data-plane/core/controller/src/config.rs +++ b/data-plane/core/controller/src/config.rs @@ -7,10 +7,10 @@ use serde::Deserialize; use slim_auth::auth_provider::{AuthProvider, AuthVerifier}; use slim_config::auth::identity::{IdentityProviderConfig, IdentityVerifierConfig}; +use slim_config::client::ClientConfig; use slim_config::component::configuration::Configuration; use slim_config::component::id::ID; -use slim_config::grpc::client::ClientConfig; -use slim_config::grpc::server::ServerConfig; +use slim_config::server::ServerConfig; use slim_datapath::message_processing::MessageProcessor; use crate::errors::ControllerError; @@ -190,9 +190,9 @@ mod tests { use super::*; use slim_config::auth::jwt::Config as JwtConfig; use slim_config::auth::static_jwt::Config as StaticJwtConfig; + use slim_config::client::ClientConfig; use slim_config::component::id::{ID, Kind}; - use slim_config::grpc::client::ClientConfig; - use slim_config::grpc::server::ServerConfig; + use slim_config::server::ServerConfig; use slim_datapath::message_processing::MessageProcessor; use slim_testing::utils::TEST_VALID_SECRET; use std::sync::Arc; diff --git a/data-plane/core/controller/src/errors.rs b/data-plane/core/controller/src/errors.rs index 11dcbb35a..a2056f3b1 100644 --- a/data-plane/core/controller/src/errors.rs +++ b/data-plane/core/controller/src/errors.rs @@ -2,14 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 use slim_auth::errors::AuthError; +#[cfg(feature = "native")] use slim_config::grpc::errors::ConfigError; use slim_datapath::errors::DataPathError; use thiserror::Error; +#[cfg(feature = "native")] use tonic::Status; #[derive(Error, Debug)] pub enum ControllerError { // Configuration / setup + #[cfg(feature = "native")] #[error("configuration error")] ConfigError(#[from] ConfigError), @@ -20,6 +23,7 @@ pub enum ControllerError { AlreadyStopped, #[error("timeout waiting for shutdown to complete")] ShutdownTimeout, + #[cfg(feature = "native")] #[error("grpc error")] GrpcError(#[from] Status), diff --git a/data-plane/core/controller/src/lib.rs b/data-plane/core/controller/src/lib.rs index 47bbc5e56..9d992e909 100644 --- a/data-plane/core/controller/src/lib.rs +++ b/data-plane/core/controller/src/lib.rs @@ -2,6 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 pub mod api; +#[cfg(feature = "native")] pub mod config; pub mod errors; +#[cfg(feature = "native")] pub mod service; diff --git a/data-plane/core/controller/src/service.rs b/data-plane/core/controller/src/service.rs index 561fd9a7b..3042a6a44 100644 --- a/data-plane/core/controller/src/service.rs +++ b/data-plane/core/controller/src/service.rs @@ -10,7 +10,7 @@ use std::vec; use display_error_chain::ErrorChainExt; use slim_config::component::id::ID; -use slim_config::grpc::server::ServerConfig; +use slim_config::server::ServerConfig; use slim_session::SessionMessage; use slim_session::subscription_manager::SubscriptionManager; use tokio::sync::mpsc; @@ -35,7 +35,7 @@ use crate::errors::ControllerError; use prost_types::Struct; use slim_auth::auth_provider::{AuthProvider, AuthVerifier}; use slim_auth::traits::TokenProvider; -use slim_config::grpc::client::ClientConfig; +use slim_config::client::ClientConfig; use slim_datapath::api::{ CommandPayload, Content, MessageType::Link as LinkType, MessageType::Publish, MessageType::Subscribe, MessageType::SubscriptionAck as SubscriptionAckType, @@ -1885,7 +1885,7 @@ impl ControllerService { ) -> Result>, ControllerError> { info!(%config.endpoint, "connecting to control plane"); - let channel = config.to_channel().await?; + let channel = config.to_grpc_channel().await?; let mut client = ControllerServiceClient::new(channel.clone()); let (tx, rx) = mpsc::channel::>(128); diff --git a/data-plane/core/datapath/Cargo.toml b/data-plane/core/datapath/Cargo.toml index c73d397ab..2733eeff6 100644 --- a/data-plane/core/datapath/Cargo.toml +++ b/data-plane/core/datapath/Cargo.toml @@ -5,37 +5,88 @@ edition = { workspace = true } license = { workspace = true } description = "Core data plane functionality for SLIM" +[package.metadata.cargo-machete] +ignored = ["getrandom"] + [lib] name = "slim_datapath" [features] -default = [] +default = ["native"] +native = [ + "dep:agntcy-slim-config", + "agntcy-slim-config/native", + "dep:agntcy-slim-tracing", + "agntcy-slim-tracing/native", + "dep:display-error-chain", + "dep:drain", + "dep:fastwebsockets", + "dep:h2", + "dep:opentelemetry", + "dep:tokio", + "dep:tokio-stream", + "dep:tokio-util", + "dep:tonic", + "dep:tonic-prost", + "dep:tracing-opentelemetry", +] otel_tracing = ["dep:opentelemetry", "dep:tracing-opentelemetry"] +wasm = [ + "dep:agntcy-slim-config", + "agntcy-slim-config/wasm", + "dep:agntcy-slim-tracing", + "agntcy-slim-tracing/wasm", + "dep:display-error-chain", + "dep:getrandom", + "dep:tokio_with_wasm", + "dep:tokio-stream", + "dep:wasm-bindgen-futures", + "uuid/js", +] [dependencies] -agntcy-slim-config = { workspace = true } -agntcy-slim-tracing = { workspace = true } + +# Optional deps gated by the `native` / `wasm` / `otel_tracing` features above. +# `native` and `wasm` are mutually exclusive; cargo picks the right flavor of +# `agntcy-slim-config` / `agntcy-slim-tracing` via their own feature lists. +agntcy-slim-config = { workspace = true, optional = true, default-features = false } +agntcy-slim-tracing = { workspace = true, optional = true, default-features = false } + +# Always-available core deps (used by `messages`, `tables`, `api` regardless of +# the active runtime feature). agntcy-slim-version = { workspace = true } -display-error-chain = { workspace = true } -drain = { workspace = true } -h2 = { workspace = true } +display-error-chain = { workspace = true, optional = true } +drain = { workspace = true, optional = true } +fastwebsockets = { workspace = true, optional = true } +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } +h2 = { workspace = true, optional = true } +http = { workspace = true } opentelemetry = { workspace = true, optional = true } parking_lot = { workspace = true } prost = { workspace = true } rand = { workspace = true } semver = { workspace = true } -serde = { workspace = true } +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } -tokio-stream = { workspace = true } -tokio-util = { workspace = true } -tonic = { workspace = true } -tonic-prost = { workspace = true } +tokio = { workspace = true, optional = true } +tokio-stream = { workspace = true, optional = true } +tokio-util = { workspace = true, optional = true } +tokio_with_wasm = { version = "0.9", features = ["rt", "sync", "time", "macros"], optional = true } +tonic = { workspace = true, optional = true } +tonic-prost = { workspace = true, optional = true } tracing = { workspace = true } tracing-opentelemetry = { workspace = true, optional = true } twox-hash = { workspace = true } uuid = { workspace = true, features = ["v4"] } +wasm-bindgen-futures = { version = "0.4", optional = true } + +# Browser-only websocket transport for the wasm build. `gloo-net` only compiles +# for `wasm32-*` targets, so it is added unconditionally for that target and +# omitted everywhere else; it is gated on the `wasm` feature inside `lib.rs`. +[target.'cfg(target_arch = "wasm32")'.dependencies] +futures = { workspace = true } +gloo-net = { version = "0.6" } [build-dependencies] protoc-bin-vendored = { workspace = true } diff --git a/data-plane/core/datapath/build.rs b/data-plane/core/datapath/build.rs index 079cf1bb7..9a60fe58e 100644 --- a/data-plane/core/datapath/build.rs +++ b/data-plane/core/datapath/build.rs @@ -13,6 +13,10 @@ fn main() { tonic_prost_build::configure() .out_dir("src/api/gen") + // Gate the tonic gRPC service modules behind the "native" feature so the + // generated proto file compiles on wasm32 targets where tonic is unavailable. + .server_mod_attribute("dataplane.proto.v1", "#[cfg(feature = \"native\")]") + .client_mod_attribute("dataplane.proto.v1", "#[cfg(feature = \"native\")]") .compile_protos(&["proto/v1/data_plane.proto"], &["proto/v1"]) .unwrap(); } diff --git a/data-plane/core/datapath/src/api.rs b/data-plane/core/datapath/src/api.rs index 08302aaec..dff3c6711 100644 --- a/data-plane/core/datapath/src/api.rs +++ b/data-plane/core/datapath/src/api.rs @@ -4,6 +4,7 @@ //! gRPC bindings for data plane service. pub(crate) mod proto; +// Proto message types — always available (pure prost, no tonic needed) pub use proto::dataplane::v1::ApplicationPayload; pub use proto::dataplane::v1::CommandPayload; pub use proto::dataplane::v1::Content; @@ -32,8 +33,6 @@ pub use proto::dataplane::v1::SlimHeader; pub use proto::dataplane::v1::Subscribe as ProtoSubscribe; pub use proto::dataplane::v1::SubscriptionAck as ProtoSubscriptionAck; pub use proto::dataplane::v1::Unsubscribe as ProtoUnsubscribe; -pub use proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient; -pub use proto::dataplane::v1::data_plane_service_server::DataPlaneServiceServer; pub use proto::dataplane::v1::link::LinkType as ProtoLinkType; pub use proto::dataplane::v1::message::MessageType; pub use proto::dataplane::v1::message::MessageType::Link as ProtoLinkMessageType; @@ -41,3 +40,9 @@ pub use proto::dataplane::v1::message::MessageType::Publish as ProtoPublishType; pub use proto::dataplane::v1::message::MessageType::Subscribe as ProtoSubscribeType; pub use proto::dataplane::v1::message::MessageType::SubscriptionAck as ProtoSubscriptionAckType; pub use proto::dataplane::v1::message::MessageType::Unsubscribe as ProtoUnsubscribeType; + +// gRPC service types — native only (depend on tonic) +#[cfg(feature = "native")] +pub use proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient; +#[cfg(feature = "native")] +pub use proto::dataplane::v1::data_plane_service_server::DataPlaneServiceServer; diff --git a/data-plane/core/datapath/src/api/gen/dataplane.proto.v1.rs b/data-plane/core/datapath/src/api/gen/dataplane.proto.v1.rs index 81d8a949f..0e8b82918 100644 --- a/data-plane/core/datapath/src/api/gen/dataplane.proto.v1.rs +++ b/data-plane/core/datapath/src/api/gen/dataplane.proto.v1.rs @@ -466,6 +466,7 @@ impl SessionMessageType { } } /// Generated client implementations. +#[cfg(feature = "native")] pub mod data_plane_service_client { #![allow( unused_variables, @@ -585,6 +586,7 @@ pub mod data_plane_service_client { } } /// Generated server implementations. +#[cfg(feature = "native")] pub mod data_plane_service_server { #![allow( unused_variables, diff --git a/data-plane/core/datapath/src/api/proto.rs b/data-plane/core/datapath/src/api/proto.rs index 289b97670..7101d6ac5 100644 --- a/data-plane/core/datapath/src/api/proto.rs +++ b/data-plane/core/datapath/src/api/proto.rs @@ -3,6 +3,9 @@ pub mod dataplane { pub mod v1 { + // The generated proto file contains both prost message types (always available) + // and tonic gRPC service stubs (native-only). We include the full file and + // rely on the cfg-gated re-exports in api.rs to control visibility. include!("gen/dataplane.proto.v1.rs"); } } diff --git a/data-plane/core/datapath/src/connection.rs b/data-plane/core/datapath/src/connection.rs index bd4c78268..c81db1149 100644 --- a/data-plane/core/datapath/src/connection.rs +++ b/data-plane/core/datapath/src/connection.rs @@ -1,15 +1,15 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +use crate::Status; use crate::api::proto::dataplane::v1::Message; +use crate::runtime::CancellationToken; use parking_lot::RwLock; use semver::Version; -use slim_config::grpc::client::{ClientConfig, is_valid_uuid_v4}; +use slim_config::client::{ClientConfig, is_valid_uuid_v4}; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::mpsc; -use tokio_util::sync::CancellationToken; -use tonic::Status; /// Negotiation state shared between link negotiation fields. /// Kept under one lock so that the check-and-set is atomic. @@ -155,6 +155,7 @@ impl Connection { /// Set the link identifier at construction time so it is available the moment the /// connection enters the table, before the negotiation message is sent. + #[cfg(feature = "native")] pub(crate) fn with_link_id(self, link_id: String) -> Self { self.negotiation.write().link_id = Some(link_id); self diff --git a/data-plane/core/datapath/src/errors.rs b/data-plane/core/datapath/src/errors.rs index f74a51baf..d5ee3b574 100644 --- a/data-plane/core/datapath/src/errors.rs +++ b/data-plane/core/datapath/src/errors.rs @@ -4,7 +4,6 @@ use crate::api::ProtoSessionMessageType; use crate::api::proto::dataplane::v1::Message; use crate::messages::{Name, utils::MessageError}; -use slim_config::grpc::errors::ConfigError; use thiserror::Error; /// DataPath and subscription table errors merged into a single enum. @@ -15,6 +14,7 @@ pub enum DataPathError { ConnectionError, #[error("disconnection error")] DisconnectionError(u64), + #[cfg(feature = "native")] #[error("grpc error")] GrpcError(#[from] tonic::Status), @@ -55,8 +55,9 @@ pub enum DataPathError { }, // Configuration error + #[cfg(any(feature = "native", feature = "wasm"))] #[error("configuration error")] - ConfigurationError(#[from] ConfigError), + ConfigurationError(#[from] slim_config::grpc::errors::ConfigError), // Remote subscription ACK errors #[error("remote subscription ack timed out after {0} retries")] diff --git a/data-plane/core/datapath/src/lib.rs b/data-plane/core/datapath/src/lib.rs index e0227ca47..2ed47e23d 100644 --- a/data-plane/core/datapath/src/lib.rs +++ b/data-plane/core/datapath/src/lib.rs @@ -1,17 +1,82 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +// On wasm, alias tokio_with_wasm as tokio so all data plane code can keep +// using `tokio::*` paths uniformly across native and wasm builds (matching +// the pattern used by `slim-session`). +#[cfg(all(feature = "wasm", not(feature = "native")))] +extern crate tokio_with_wasm as tokio; + pub mod api; pub mod errors; -pub mod message_processing; pub mod messages; pub mod tables; +#[cfg(any(feature = "native", feature = "wasm"))] +pub mod message_processing; + +#[cfg(any(feature = "native", feature = "wasm"))] mod connection; +#[cfg(any(feature = "native", feature = "wasm"))] mod forwarder; #[cfg(feature = "otel_tracing")] mod otel_tracing; +#[cfg(any(feature = "native", feature = "wasm"))] mod recovery; +#[cfg(any(feature = "native", feature = "wasm"))] +pub mod runtime; +#[cfg(any(feature = "native", feature = "wasm"))] pub(crate) mod subscription_ack; +#[cfg(any(feature = "native", feature = "wasm"))] +mod websocket; +#[cfg(feature = "native")] pub use tonic::Status; + +/// Lightweight Status type for WASM builds where tonic is not available. +#[cfg(not(feature = "native"))] +#[derive(Debug, Clone)] +pub struct Status { + code: u32, + message: String, +} + +#[cfg(not(feature = "native"))] +impl Status { + pub fn new(code: u32, message: impl Into) -> Self { + Self { + code, + message: message.into(), + } + } + + pub fn internal(message: impl Into) -> Self { + Self::new(13, message) + } + + pub fn unavailable(message: impl Into) -> Self { + Self::new(14, message) + } + + pub fn invalid_argument(message: impl Into) -> Self { + Self::new(3, message) + } + + pub fn code(&self) -> u32 { + self.code + } + + pub fn message(&self) -> &str { + &self.message + } +} + +#[cfg(not(feature = "native"))] +impl std::fmt::Display for Status { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Status(code={}, message={})", self.code, self.message) + } +} + +#[cfg(not(feature = "native"))] +impl std::error::Error for Status {} diff --git a/data-plane/core/datapath/src/message_processing.rs b/data-plane/core/datapath/src/message_processing.rs index 9a75b7913..7a966e7a5 100644 --- a/data-plane/core/datapath/src/message_processing.rs +++ b/data-plane/core/datapath/src/message_processing.rs @@ -3,25 +3,35 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; -use std::{pin::Pin, sync::Arc}; +use std::sync::Arc; +#[cfg(feature = "native")] +use std::pin::Pin; + +#[cfg(feature = "native")] use crate::api::DataPlaneServiceServer; use display_error_chain::ErrorChainExt; use parking_lot::RwLock; +use slim_config::client::ClientConfig; +use slim_config::client::TransportChannel; use slim_config::component::configuration::Configuration; -use slim_config::grpc::client::ClientConfig; -use slim_config::grpc::server::ServerConfig; +#[cfg(feature = "native")] +use slim_config::server::ServerConfig; +use slim_config::transport::TransportProtocol; +#[cfg(feature = "native")] +use slim_config::websocket::server as websocket_server; use tokio::sync::mpsc::{self, Sender}; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::{Stream, StreamExt}; -use tokio_util::sync::CancellationToken; - -use tonic::{Request, Response, Status}; -use tracing::{Instrument, debug, error, info}; +use crate::Status; #[cfg(feature = "otel_tracing")] use crate::otel_tracing; +use crate::runtime::CancellationToken; +#[cfg(feature = "native")] +use tonic::{Request, Response}; +use tracing::{Instrument, debug, error, info}; use crate::api::ProtoMessage; use crate::api::ProtoPublishType as PublishType; @@ -34,7 +44,9 @@ use crate::api::{ }; use semver; +#[cfg(feature = "native")] use crate::api::proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient; +#[cfg(feature = "native")] use crate::api::proto::dataplane::v1::data_plane_service_server::DataPlaneService; use crate::connection::{Channel, Connection, Type as ConnectionType}; use crate::errors::{DataPathError, MessageContext}; @@ -45,6 +57,7 @@ use crate::recovery::RecoveryTable; use crate::tables::connection_table::ConnectionTable; use crate::tables::remote_subscription_table::SubscriptionInfo; use crate::tables::subscription_table::SubscriptionTableImpl; +use crate::websocket; fn local_version() -> &'static str { slim_version::version() @@ -56,10 +69,10 @@ struct MessageProcessorInternal { forwarder: Forwarder, /// Drain signal to gracefully close all pending tasks - drain_signal: parking_lot::RwLock>, + drain_signal: parking_lot::RwLock>, ///Drain watch to receive drain signal - drain_watch: parking_lot::RwLock>, + drain_watch: parking_lot::RwLock>, /// Tx channel towards control plane tx_control_plane: RwLock>>>, @@ -91,7 +104,7 @@ impl MessageProcessor { } pub fn new_with_options(service_id: String, recovery_ttl: Option) -> Self { - let (signal, watch) = drain::channel(); + let (signal, watch) = crate::runtime::drain_channel(); let recovery_table = match recovery_ttl { Some(ttl) => RecoveryTable::new(ttl), None => RecoveryTable::default(), @@ -116,11 +129,23 @@ impl MessageProcessor { /// Run a data plane gRPC server using this message processor's drain watch. /// Returns a cancellation token that can be used to stop the server task. + #[cfg(feature = "native")] pub async fn run_server( &self, config: &ServerConfig, ) -> Result { debug!(%config, "starting dataplane server"); + match config.transport { + TransportProtocol::Grpc => self.run_grpc_server(config).await, + TransportProtocol::Websocket => self.run_websocket_server(config).await, + } + } + + #[cfg(feature = "native")] + async fn run_grpc_server( + &self, + config: &ServerConfig, + ) -> Result { let watch = self.get_drain_watch()?; // Wrap self in an Arc since the server builder expects an Arc let svc = Arc::new(self.clone()); @@ -131,6 +156,67 @@ impl MessageProcessor { Ok(res) } + #[cfg(feature = "native")] + async fn run_websocket_server( + &self, + config: &ServerConfig, + ) -> Result { + let watch = self.get_drain_watch()?; + let processor = self.clone(); + + let on_accepted: websocket_server::OnAcceptedWebSocket = Arc::new(move |accepted| { + let processor = processor.clone(); + Box::pin(async move { + let cancellation_token = CancellationToken::new(); + let streams = websocket::spawn_transport_tasks( + accepted.websocket, + cancellation_token.clone(), + ); + + let connection = + Connection::new(ConnectionType::Remote, Channel::Client(streams.outbound)) + .with_remote_addr(accepted.remote_addr) + .with_local_addr(accepted.local_addr) + .with_cancellation_token(Some(cancellation_token.clone())); + + debug!( + remote = ?connection.remote_addr(), + local = ?connection.local_addr(), + "new websocket connection received from remote", + ); + info!(telemetry = true, counter.num_active_connections = 1); + + let conn_index = match processor + .forwarder() + .on_connection_established(connection, None) + { + Some(index) => index, + None => { + error!("failed to add websocket connection to table"); + cancellation_token.cancel(); + return; + } + }; + + if let Err(err) = processor.process_stream( + streams.inbound, + conn_index, + None, + cancellation_token, + false, + false, + ) { + error!(error = %err.chain(), "error starting websocket processing stream"); + } + }) + }); + + config + .run_websocket_server(watch, on_accepted) + .await + .map_err(Into::into) + } + pub async fn shutdown(&self) -> Result<(), DataPathError> { // Take the drain signal let signal = self @@ -167,7 +253,7 @@ impl MessageProcessor { self.internal.sub_ack_manager.remove(subscription_id); } - fn get_drain_watch(&self) -> Result { + fn get_drain_watch(&self) -> Result { self.internal .drain_watch .read() @@ -223,6 +309,30 @@ impl MessageProcessor { existing_conn_index: Option, ) -> Result<(JoinHandle<()>, u64), DataPathError> { client_config.validate()?; + + match client_config.transport { + #[cfg(feature = "native")] + TransportProtocol::Grpc => { + self.try_to_connect_grpc(client_config, local, remote, existing_conn_index) + .await + } + #[cfg(not(feature = "native"))] + TransportProtocol::Grpc => Err(DataPathError::ConnectionError), + TransportProtocol::Websocket => { + self.try_to_connect_websocket(client_config, local, remote, existing_conn_index) + .await + } + } + } + + #[cfg(feature = "native")] + async fn try_to_connect_grpc( + &self, + client_config: ClientConfig, + local: Option, + remote: Option, + existing_conn_index: Option, + ) -> Result<(JoinHandle<()>, u64), DataPathError> { let mut watch = std::pin::pin!(self.get_drain_watch()?.signaled()); let channel = tokio::select! { @@ -234,6 +344,11 @@ impl MessageProcessor { } }; + let channel = match channel { + TransportChannel::Grpc(channel) => channel, + TransportChannel::Websocket(_) => return Err(DataPathError::ConnectionError), + }; + let mut client = DataPlaneServiceClient::new(channel); let (tx, rx) = mpsc::channel(128); @@ -297,6 +412,72 @@ impl MessageProcessor { Ok((handle, conn_index)) } + async fn try_to_connect_websocket( + &self, + client_config: ClientConfig, + local: Option, + remote: Option, + existing_conn_index: Option, + ) -> Result<(JoinHandle<()>, u64), DataPathError> { + let mut watch = std::pin::pin!(self.get_drain_watch()?.signaled()); + + let channel = tokio::select! { + _ = &mut watch => { + return Err(DataPathError::ShuttingDownError); + } + res = client_config.to_channel() => { + res? + } + }; + + #[cfg(feature = "native")] + let connection = match channel { + TransportChannel::Websocket(channel) => *channel, + TransportChannel::Grpc(_) => return Err(DataPathError::ConnectionError), + }; + #[cfg(not(feature = "native"))] + let connection = match channel { + TransportChannel::Websocket(channel) => *channel, + }; + + let cancellation_token = CancellationToken::new(); + let streams = + websocket::spawn_transport_tasks(connection.websocket, cancellation_token.clone()); + let connection = Connection::new(ConnectionType::Remote, Channel::Client(streams.outbound)) + .with_local_addr(local.or(connection.local_addr)) + .with_remote_addr(remote.or(connection.remote_addr)) + .with_config_data(Some(client_config.clone())) + .with_cancellation_token(Some(cancellation_token.clone())); + + debug!( + remote = ?connection.remote_addr(), + local = ?connection.local_addr(), + "new websocket connection initiated locally", + ); + + let conn_index = self + .forwarder() + .on_connection_established(connection, existing_conn_index) + .ok_or(DataPathError::ConnectionTableAddError)?; + + debug!( + %conn_index, + is_local = false, + "new websocket connection index", + ); + + let handle = self.process_stream( + streams.inbound, + conn_index, + Some(client_config), + cancellation_token, + false, + false, + )?; + + Ok((handle, conn_index)) + } + pub async fn connect( &self, client_config: ClientConfig, @@ -307,6 +488,51 @@ impl MessageProcessor { .await } + /// Browser-only: register an already-opened `gloo_net` WebSocket as a + /// remote connection. Mirrors what `try_to_connect_websocket` does on + /// native, but leaves the handshake (auth, URL building) to the caller. + /// + /// This is the bridge that lets `slim-wasm` keep its existing + /// `SharedSecret`-based query-param token flow while still benefiting + /// from the data plane's forwarder, connection table, and per-connection + /// process_stream loop. After this call, publishing/subscribing through + /// the session layer goes through `MessageProcessor::send_msg`, which + /// fans out across all registered connections via the subscription + /// table — exactly like the native multi-connection case. + #[cfg(all(feature = "wasm", not(feature = "native")))] + pub fn register_websocket( + &self, + websocket: gloo_net::websocket::futures::WebSocket, + client_config: Option, + ) -> Result<(JoinHandle<()>, u64), DataPathError> { + let cancellation_token = CancellationToken::new(); + let streams = websocket::spawn_transport_tasks(websocket, cancellation_token.clone()); + + let connection = Connection::new(ConnectionType::Remote, Channel::Client(streams.outbound)) + .with_config_data(client_config.clone()) + .with_cancellation_token(Some(cancellation_token.clone())); + + debug!("new wasm websocket connection registered"); + + let conn_index = self + .forwarder() + .on_connection_established(connection, None) + .ok_or(DataPathError::ConnectionTableAddError)?; + + debug!(%conn_index, "new wasm websocket connection index"); + + let handle = self.process_stream( + streams.inbound, + conn_index, + client_config, + cancellation_token, + false, + false, + )?; + + Ok((handle, conn_index)) + } + pub fn disconnect(&self, conn: u64) -> Result { let connection = match self.forwarder().get_connection(conn) { Some(c) => c, @@ -407,7 +633,12 @@ impl MessageProcessor { let connection = self.forwarder().get_connection(out_conn); match connection { Some(conn) => { + // Telemetry context (when enabled via the `otel_tracing` feature) is + // already prepared by the caller (`send_msg` / `match_and_forward_msg`), + // so this hot-path stays free of OTEL work. `clear_slim_header()` is + // a no-op for Link / SubscriptionAck messages. msg.clear_slim_header(); + match conn.channel() { Channel::Server(s) => { s.send(Ok(msg)) @@ -970,7 +1201,10 @@ impl MessageProcessor { let error_message = payload.to_json_string(); // create Status error + #[cfg(feature = "native")] let status = Status::new(tonic::Code::Internal, error_message); + #[cfg(not(feature = "native"))] + let status = Status::internal(error_message); if tx.send(Err(status)).await.is_err() { debug!(error = %err.chain(), "unable to notify the error to the local app"); @@ -1123,11 +1357,18 @@ impl MessageProcessor { } } Err(e) => { - if let Some(io_err) = MessageProcessor::match_for_io_error(&e) { - if io_err.kind() == std::io::ErrorKind::BrokenPipe { - info!(%conn_index, "connection closed by peer"); + #[cfg(feature = "native")] + { + if let Some(io_err) = MessageProcessor::match_for_io_error(&e) { + if io_err.kind() == std::io::ErrorKind::BrokenPipe { + info!(%conn_index, "connection closed by peer"); + } + } else { + error!(error = %e.chain(), "error receiving messages"); } - } else { + } + #[cfg(not(feature = "native"))] + { error!(error = %e.chain(), "error receiving messages"); } break; @@ -1258,6 +1499,7 @@ impl MessageProcessor { Ok(handle) } + #[cfg(feature = "native")] fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> { let mut err: &(dyn std::error::Error + 'static) = err_status; @@ -1287,6 +1529,7 @@ impl MessageProcessor { } } +#[cfg(feature = "native")] #[tonic::async_trait] impl DataPlaneService for MessageProcessor { type OpenChannelStream = Pin> + Send + 'static>>; diff --git a/data-plane/core/datapath/src/recovery.rs b/data-plane/core/datapath/src/recovery.rs index de535dd3c..ff4bde2ef 100644 --- a/data-plane/core/datapath/src/recovery.rs +++ b/data-plane/core/datapath/src/recovery.rs @@ -91,8 +91,12 @@ impl RecoveryTable { /// /// `on_expire` is only called when the entry is still present at expiry time, i.e. recovery /// has not already consumed it via [`RecoveryTable::take`]. - pub(crate) fn spawn_ttl_task(&self, link_id: String, drain: drain::Watch, on_expire: F) - where + pub(crate) fn spawn_ttl_task( + &self, + link_id: String, + drain: crate::runtime::DrainWatch, + on_expire: F, + ) where F: FnOnce(RecoveryEntry) -> Fut + Send + 'static, Fut: Future + Send + 'static, { @@ -226,7 +230,7 @@ mod tests { t.store("link-1".into(), local, remote); let (tx, rx) = oneshot::channel::<()>(); - let (_signal, watch) = drain::channel(); + let (_signal, watch) = crate::runtime::drain_channel(); t.spawn_ttl_task("link-1".into(), watch, move |_entry| async move { let _ = tx.send(()); @@ -247,7 +251,7 @@ mod tests { let fired = Arc::new(AtomicBool::new(false)); let fired_clone = fired.clone(); - let (_signal, watch) = drain::channel(); + let (_signal, watch) = crate::runtime::drain_channel(); t.spawn_ttl_task("link-1".into(), watch, move |_entry| { let f = fired_clone.clone(); @@ -278,7 +282,7 @@ mod tests { let fired = Arc::new(AtomicBool::new(false)); let fired_clone = fired.clone(); - let (signal, watch) = drain::channel(); + let (signal, watch) = crate::runtime::drain_channel(); t.spawn_ttl_task("link-1".into(), watch, move |_entry| { let f = fired_clone.clone(); diff --git a/data-plane/core/datapath/src/runtime.rs b/data-plane/core/datapath/src/runtime.rs new file mode 100644 index 000000000..4e241cc55 --- /dev/null +++ b/data-plane/core/datapath/src/runtime.rs @@ -0,0 +1,204 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! Platform abstractions used by the data plane. +//! +//! On native builds, [`CancellationToken`] is re-exported from +//! `tokio_util::sync` and the `drain` crate is used directly for graceful +//! shutdown. On wasm builds, `tokio_util` is unavailable and the `drain` crate +//! pulls in non-portable platform code, so we provide: +//! +//! * A small `CancellationToken` shim mirroring the `tokio_util` API surface +//! used by the data plane. +//! * A drain compatibility layer (`Signal`, `Watch`, `channel`) backed by the +//! cancellation token. The wasm shim is API-compatible with `drain` for the +//! subset that the data plane uses (`Watch::signaled`, `Signal::drain`), +//! but does not wait for outstanding watchers — it simply cancels. + +#[cfg(feature = "native")] +pub use tokio_util::sync::CancellationToken; + +#[cfg(feature = "native")] +pub use drain::{Signal as DrainSignal, Watch as DrainWatch}; + +#[cfg(feature = "native")] +pub fn drain_channel() -> (DrainSignal, DrainWatch) { + drain::channel() +} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +mod cancellation { + use parking_lot::Mutex; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::task::Waker; + + /// Browser-friendly subset of the `tokio_util::sync::CancellationToken` + /// API used by the data plane. WASM is single-threaded so the wakers + /// stored here are only ever polled on the JS event loop. + #[derive(Clone)] + pub struct CancellationToken { + cancelled: Arc, + wakers: Arc>>, + children: Arc>>, + } + + impl std::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CancellationToken") + .field("cancelled", &self.is_cancelled()) + .finish() + } + } + + impl CancellationToken { + pub fn new() -> Self { + Self { + cancelled: Arc::new(AtomicBool::new(false)), + wakers: Arc::new(Mutex::new(Vec::new())), + children: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn cancel(&self) { + self.cancelled.store(true, Ordering::Release); + let wakers = std::mem::take(&mut *self.wakers.lock()); + for waker in wakers { + waker.wake(); + } + let children = std::mem::take(&mut *self.children.lock()); + for child in children { + child.cancel(); + } + } + + pub fn is_cancelled(&self) -> bool { + self.cancelled.load(Ordering::Acquire) + } + + pub fn child_token(&self) -> Self { + let child = Self::new(); + let mut children = self.children.lock(); + if self.is_cancelled() { + drop(children); + child.cancel(); + return child; + } + children.push(child.clone()); + child + } + + fn register_waker(&self, waker: &Waker) { + if self.is_cancelled() { + waker.wake_by_ref(); + return; + } + let mut wakers = self.wakers.lock(); + if !wakers.iter().any(|w| w.will_wake(waker)) { + wakers.push(waker.clone()); + } + } + + pub async fn cancelled(&self) { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + struct CancelledFuture<'a> { + token: &'a CancellationToken, + } + + impl Future for CancelledFuture<'_> { + type Output = (); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.token.is_cancelled() { + Poll::Ready(()) + } else { + self.token.register_waker(cx.waker()); + if self.token.is_cancelled() { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + } + + CancelledFuture { token: self }.await + } + } + + impl Default for CancellationToken { + fn default() -> Self { + Self::new() + } + } +} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub use cancellation::CancellationToken; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +mod drain_shim { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use super::CancellationToken; + + /// Browser stand-in for `drain::Signal`. Cancelling this signal cancels + /// the underlying token; the `drain` future returns immediately because + /// there are no native worker threads to wait on. + #[derive(Debug)] + pub struct DrainSignal { + token: CancellationToken, + } + + impl DrainSignal { + pub async fn drain(self) { + self.token.cancel(); + } + } + + /// Browser stand-in for `drain::Watch`. `signaled()` resolves once the + /// associated `DrainSignal` has been drained. + #[derive(Debug, Clone)] + pub struct DrainWatch { + token: CancellationToken, + } + + impl DrainWatch { + pub fn signaled(self) -> DrainSignaled { + DrainSignaled { token: self.token } + } + } + + /// Future returned by `DrainWatch::signaled()`. + pub struct DrainSignaled { + token: CancellationToken, + } + + impl Future for DrainSignaled { + type Output = (); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // Reuse the cancellation future logic by polling a fresh + // `cancelled()` future on each call. Because WASM is single + // threaded, this allocation is cheap. + let mut fut = Box::pin(self.token.cancelled()); + fut.as_mut().poll(cx) + } + } + + pub fn drain_channel() -> (DrainSignal, DrainWatch) { + let token = CancellationToken::new(); + ( + DrainSignal { + token: token.clone(), + }, + DrainWatch { token }, + ) + } +} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub use drain_shim::{DrainSignal, DrainWatch, drain_channel}; diff --git a/data-plane/core/datapath/src/websocket.rs b/data-plane/core/datapath/src/websocket.rs new file mode 100644 index 000000000..daef8ade8 --- /dev/null +++ b/data-plane/core/datapath/src/websocket.rs @@ -0,0 +1,14 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(feature = "native")] +pub(crate) mod stream; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub(crate) mod stream_wasm; + +#[cfg(feature = "native")] +pub(crate) use stream::spawn_transport_tasks; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub(crate) use stream_wasm::spawn_transport_tasks; diff --git a/data-plane/core/datapath/src/websocket/stream.rs b/data-plane/core/datapath/src/websocket/stream.rs new file mode 100644 index 000000000..2eb4cd2d3 --- /dev/null +++ b/data-plane/core/datapath/src/websocket/stream.rs @@ -0,0 +1,126 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use fastwebsockets::{FragmentCollectorRead, Frame, OpCode, WebSocketError}; +use prost::Message as ProstMessage; +use slim_config::websocket::common::UpgradedWebSocket; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tonic::Status; +use tracing::{debug, warn}; + +use crate::api::proto::dataplane::v1::Message; + +pub(crate) struct WebSocketStreams { + pub(crate) inbound: ReceiverStream>, + pub(crate) outbound: mpsc::Sender, +} + +pub(crate) fn spawn_transport_tasks( + websocket: UpgradedWebSocket, + cancellation_token: CancellationToken, +) -> WebSocketStreams { + let (mut read_half, mut write_half) = websocket.split(tokio::io::split); + read_half.set_auto_close(false); + read_half.set_auto_pong(false); + + let mut reader = FragmentCollectorRead::new(read_half); + + let (tx_inbound, rx_inbound) = mpsc::channel::>(128); + let (tx_outbound, mut rx_outbound) = mpsc::channel::(128); + + let read_cancel = cancellation_token.clone(); + tokio::spawn(async move { + let mut noop_send = |_frame: Frame<'_>| async move { Result::<(), WebSocketError>::Ok(()) }; + + loop { + tokio::select! { + _ = read_cancel.cancelled() => { + break; + } + frame = reader.read_frame::<_, WebSocketError>(&mut noop_send) => { + let frame = match frame { + Ok(frame) => frame, + Err(WebSocketError::ConnectionClosed | WebSocketError::UnexpectedEOF) => { + debug!("websocket read loop closed by peer"); + break; + } + Err(err) => { + let _ = tx_inbound + .send(Err(Status::unavailable(format!( + "websocket read error: {err}", + )))) + .await; + break; + } + }; + + match frame.opcode { + OpCode::Binary => match Message::decode(&frame.payload[..]) { + Ok(msg) => { + if tx_inbound.send(Ok(msg)).await.is_err() { + break; + } + } + Err(err) => { + let _ = tx_inbound + .send(Err(Status::invalid_argument(format!( + "invalid protobuf payload in websocket frame: {err}", + )))) + .await; + } + }, + OpCode::Close => break, + OpCode::Text => { + warn!("ignoring text websocket frame, expected binary protobuf frame"); + } + OpCode::Ping | OpCode::Pong | OpCode::Continuation => { + // Control and continuation frames are handled by fastwebsockets; + // only complete binary payloads are forwarded to datapath processing. + } + } + } + } + } + + read_cancel.cancel(); + }); + + let write_cancel = cancellation_token.clone(); + tokio::spawn(async move { + loop { + tokio::select! { + _ = write_cancel.cancelled() => { + let _ = write_half.write_frame(Frame::close(1000, &[])).await; + let _ = write_half.flush().await; + break; + } + maybe_msg = rx_outbound.recv() => { + let msg = match maybe_msg { + Some(msg) => msg, + None => break, + }; + + let payload = msg.encode_to_vec(); + if let Err(err) = write_half.write_frame(Frame::binary(payload.into())).await { + warn!(error = %err, "websocket write error"); + break; + } + + if let Err(err) = write_half.flush().await { + warn!(error = %err, "websocket flush error"); + break; + } + } + } + } + + write_cancel.cancel(); + }); + + WebSocketStreams { + inbound: ReceiverStream::new(rx_inbound), + outbound: tx_outbound, + } +} diff --git a/data-plane/core/datapath/src/websocket/stream_wasm.rs b/data-plane/core/datapath/src/websocket/stream_wasm.rs new file mode 100644 index 000000000..36c2bf163 --- /dev/null +++ b/data-plane/core/datapath/src/websocket/stream_wasm.rs @@ -0,0 +1,118 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! Browser-side websocket transport tasks. +//! +//! Mirrors `websocket/stream.rs` but uses [`gloo_net::websocket::futures::WebSocket`] +//! and `wasm_bindgen_futures::spawn_local` so the data plane's +//! `MessageProcessor::try_to_connect_websocket` can drive a websocket +//! connection from `wasm32-unknown-unknown` builds. + +use futures::sink::SinkExt; +use futures::stream::StreamExt; +use prost::Message as ProstMessage; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{debug, warn}; + +use crate::Status; +use crate::api::proto::dataplane::v1::Message; +use crate::runtime::CancellationToken; + +pub(crate) struct WebSocketStreams { + pub(crate) inbound: ReceiverStream>, + pub(crate) outbound: mpsc::Sender, +} + +/// Spawn read/write loops over a browser `WebSocket`. The returned channels +/// look identical to the native flavor in `stream.rs`, so +/// [`crate::message_processing::MessageProcessor`] can plug them into its +/// [`crate::connection::Connection`] and `process_stream` machinery without +/// caring which transport produced them. +pub(crate) fn spawn_transport_tasks( + websocket: gloo_net::websocket::futures::WebSocket, + cancellation_token: CancellationToken, +) -> WebSocketStreams { + let (ws_sink, ws_stream) = websocket.split(); + + let (tx_inbound, rx_inbound) = mpsc::channel::>(128); + let (tx_outbound, rx_outbound) = mpsc::channel::(128); + + let read_cancel = cancellation_token.clone(); + wasm_bindgen_futures::spawn_local(async move { + let mut stream = ws_stream; + loop { + let next = tokio::select! { + _ = read_cancel.cancelled() => break, + next = stream.next() => next, + }; + + match next { + None => { + debug!("websocket read loop closed by peer"); + break; + } + Some(Ok(gloo_net::websocket::Message::Bytes(bytes))) => { + match Message::decode(&bytes[..]) { + Ok(msg) => { + if tx_inbound.send(Ok(msg)).await.is_err() { + break; + } + } + Err(err) => { + let _ = tx_inbound + .send(Err(Status::invalid_argument(format!( + "invalid protobuf payload in websocket frame: {err}", + )))) + .await; + } + } + } + Some(Ok(gloo_net::websocket::Message::Text(_))) => { + warn!("ignoring text websocket frame, expected binary protobuf frame"); + } + Some(Err(err)) => { + let _ = tx_inbound + .send(Err(Status::unavailable(format!( + "websocket read error: {err}", + )))) + .await; + break; + } + } + } + read_cancel.cancel(); + }); + + let write_cancel = cancellation_token.clone(); + wasm_bindgen_futures::spawn_local(async move { + let mut sink = ws_sink; + let mut rx = rx_outbound; + loop { + let next = tokio::select! { + _ = write_cancel.cancelled() => None, + next = rx.recv() => next, + }; + + let msg = match next { + Some(msg) => msg, + None => break, + }; + + let payload = msg.encode_to_vec(); + if let Err(err) = sink + .send(gloo_net::websocket::Message::Bytes(payload)) + .await + { + warn!(error = %err, "websocket write error"); + break; + } + } + write_cancel.cancel(); + }); + + WebSocketStreams { + inbound: ReceiverStream::new(rx_inbound), + outbound: tx_outbound, + } +} diff --git a/data-plane/core/datapath/tests/data_path_test.rs b/data-plane/core/datapath/tests/data_path_test.rs index 208861460..807f6be1e 100644 --- a/data-plane/core/datapath/tests/data_path_test.rs +++ b/data-plane/core/datapath/tests/data_path_test.rs @@ -9,11 +9,64 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use slim_config::grpc::{client::ClientConfig, server::ServerConfig}; + use slim_config::tls::client::TlsClientConfig; + use slim_config::tls::server::TlsServerConfig; + use slim_config::transport::TransportProtocol; + use slim_config::{client::ClientConfig, server::ServerConfig}; use slim_datapath::api::{DataPlaneServiceServer, ProtoMessage as Message}; use slim_datapath::errors::DataPathError; use slim_datapath::message_processing::MessageProcessor; + async fn run_transport_roundtrip( + server_conf: ServerConfig, + client_conf: ClientConfig, + connect_addr: Option, + transport_label: &str, + ) -> u64 { + let processor = MessageProcessor::new(); + let server_token = processor + .run_server(&server_conf) + .await + .unwrap_or_else(|e| panic!("failed to start {transport_label} dataplane server: {e}")); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let (_handle, conn_index) = processor + .connect(client_conf, None, connect_addr) + .await + .unwrap_or_else(|e| panic!("failed to connect {transport_label} client: {e}")); + + for _ in 0..3 { + processor + .send_msg(make_message("org", "namespace", "type"), conn_index) + .await + .unwrap_or_else(|e| { + panic!("failed to send message over {transport_label} client transport: {e}") + }); + } + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + for _ in 0..3 { + processor + .send_msg(make_message("org", "namespace", "type"), 0) + .await + .unwrap_or_else(|e| { + panic!("failed to send message over {transport_label} server transport: {e}") + }); + } + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let _ = processor.disconnect(conn_index); + server_token.cancel(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + processor + .shutdown() + .await + .unwrap_or_else(|e| panic!("failed to shutdown {transport_label} processor: {e}")); + + conn_index + } + #[tokio::test] #[traced_test] async fn test_connection() { @@ -206,6 +259,81 @@ mod tests { ); } + #[tokio::test] + #[traced_test] + async fn test_transport_roundtrip_grpc() { + let server_conf = ServerConfig::with_endpoint("127.0.0.1:51060") + .with_tls_settings(TlsServerConfig::insecure()); + + let client_conf = ClientConfig::with_endpoint("http://127.0.0.1:51060") + .with_tls_setting(TlsClientConfig::insecure()); + + let conn_index = run_transport_roundtrip( + server_conf, + client_conf, + Some(SocketAddr::from(([127, 0, 0, 1], 51060))), + "grpc", + ) + .await; + + assert!(logs_contain( + "received message from connection conn_index=0" + )); + let expected = format!("received message from connection conn_index={conn_index}"); + assert!(logs_contain(expected.as_str())); + } + + #[tokio::test] + #[traced_test] + async fn test_websocket_connection_ws() { + let server_conf = ServerConfig::with_endpoint("ws://127.0.0.1:51061") + .with_transport(TransportProtocol::Websocket) + .with_tls_settings(TlsServerConfig::insecure()); + + let client_conf = ClientConfig::with_endpoint("ws://127.0.0.1:51061") + .with_transport(TransportProtocol::Websocket) + .with_tls_setting(TlsClientConfig::insecure()); + let conn_index = + run_transport_roundtrip(server_conf, client_conf, None, "websocket ws").await; + + assert!(logs_contain( + "received message from connection conn_index=0" + )); + let expected = format!("received message from connection conn_index={conn_index}"); + assert!(logs_contain(expected.as_str())); + } + + #[tokio::test] + #[traced_test] + async fn test_websocket_connection_wss() { + let grpc_tls_testdata = format!("{}/../config/testdata/grpc", env!("CARGO_MANIFEST_DIR")); + + let server_tls = TlsServerConfig::new().with_cert_and_key_file( + &format!("{}/server.crt", grpc_tls_testdata), + &format!("{}/server.key", grpc_tls_testdata), + ); + + let server_conf = ServerConfig::with_endpoint("wss://127.0.0.1:51062") + .with_transport(TransportProtocol::Websocket) + .with_tls_settings(server_tls); + + let client_tls = + TlsClientConfig::new().with_ca_file(&format!("{}/ca.crt", grpc_tls_testdata)); + + let client_conf = ClientConfig::with_endpoint("wss://127.0.0.1:51062") + .with_transport(TransportProtocol::Websocket) + .with_server_name("example1") + .with_tls_setting(client_tls); + let conn_index = + run_transport_roundtrip(server_conf, client_conf, None, "websocket wss").await; + + assert!(logs_contain( + "received message from connection conn_index=0" + )); + let expected = format!("received message from connection conn_index={conn_index}"); + assert!(logs_contain(expected.as_str())); + } + fn make_message(org: &str, ns: &str, name: &str) -> Message { let source = Name::from_strings([org, ns, name]).with_id(0); let name = Name::from_strings([org, ns, name]).with_id(1); diff --git a/data-plane/core/mls/Cargo.toml b/data-plane/core/mls/Cargo.toml index c23250bb4..5f6dea095 100644 --- a/data-plane/core/mls/Cargo.toml +++ b/data-plane/core/mls/Cargo.toml @@ -5,21 +5,55 @@ license = { workspace = true } version = "0.1.15" description = "Messaging Layer Security for SLIM data plane." +[package.metadata.cargo-machete] +# `async-trait` is invoked indirectly by the `maybe_async::must_be_async` +# macro expansion in `identity_provider.rs` when building with the +# `mls_build_async` cfg (the wasm path), so cargo-machete cannot see it. +ignored = ["getrandom", "async-trait"] + [lib] name = "slim_mls" +[features] +default = ["native"] +native = [ + "agntcy-slim-auth/native", + "agntcy-slim-datapath/native", + "dep:mls-rs-crypto-awslc", +] +wasm = [ + "agntcy-slim-auth/wasm", + "agntcy-slim-datapath/wasm", + "dep:mls-rs-crypto-webcrypto", + "dep:getrandom", +] +# Opt-in to the MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ciphersuite +# (Curve25519 + Ed25519). Only meaningful with the `native` backend because +# browser WebCrypto (used by the `wasm` backend) does not support +# Curve25519. By default the crate uses P256_AES128 so that native and +# WASM peers can interoperate in the same MLS group. +curve25519 = [] + [dependencies] -agntcy-slim-auth = { workspace = true } -agntcy-slim-datapath = { workspace = true } +agntcy-slim-auth = { workspace = true, default-features = false } +agntcy-slim-datapath = { workspace = true, default-features = false } agntcy-slim-version = { workspace = true } +async-trait = { workspace = true } base64 = { workspace = true } +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } hex = "0.4" +maybe-async = "0.2.10" mls-rs = { workspace = true } mls-rs-core = { workspace = true } -mls-rs-crypto-awslc = { workspace = true } +mls-rs-crypto-awslc = { workspace = true, optional = true } +mls-rs-crypto-webcrypto = { version = "0.14", optional = true } serde_json = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } [dev-dependencies] tempfile = "3.3" + +[lints.rust.unexpected_cfgs] +level = "warn" +check-cfg = ["cfg(mls_build_async)"] diff --git a/data-plane/core/mls/src/crypto.rs b/data-plane/core/mls/src/crypto.rs new file mode 100644 index 000000000..44b752070 --- /dev/null +++ b/data-plane/core/mls/src/crypto.rs @@ -0,0 +1,17 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! Crypto provider selection based on compilation target and features. +//! +//! - `native` feature (default): Uses AWS-LC crypto provider for native targets +//! - `wasm` feature: Uses WebCrypto provider for browser/WASM targets + +#[cfg(feature = "native")] +pub use mls_rs_crypto_awslc::AwsLcCryptoProvider as CryptoProviderImpl; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub use mls_rs_crypto_webcrypto::WebCryptoProvider as CryptoProviderImpl; + +pub fn default_crypto_provider() -> CryptoProviderImpl { + CryptoProviderImpl::default() +} diff --git a/data-plane/core/mls/src/identity_provider.rs b/data-plane/core/mls/src/identity_provider.rs index 7b2d9175f..7f82a4e98 100644 --- a/data-plane/core/mls/src/identity_provider.rs +++ b/data-plane/core/mls/src/identity_provider.rs @@ -74,13 +74,22 @@ where } } +// The `IdentityProvider` trait declared in `mls-rs-core` is wrapped by +// `maybe_async` so it is sync on native (no `mls_build_async`) and async on +// wasm (where `mls_build_async` is set in `data-plane/.cargo/config.toml` +// because `mls-rs-crypto-webcrypto` is async-only). We mirror the same +// `cfg_attr` pair on this impl so it picks up the matching shape on both +// sides; the methods themselves are written `async fn` and `must_be_sync` +// strips that on native. +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +#[cfg_attr(mls_build_async, maybe_async::must_be_async)] impl IdentityProvider for SlimIdentityProvider where V: Verifier + Send + Sync + Clone + 'static, { type Error = MlsError; - fn validate_member( + async fn validate_member( &self, signing_identity: &SigningIdentity, _timestamp: Option, @@ -102,7 +111,7 @@ where Ok(()) } - fn validate_external_sender( + async fn validate_external_sender( &self, _signing_identity: &SigningIdentity, _timestamp: Option, @@ -112,7 +121,7 @@ where Err(MlsError::ExternalCommitNotSupported) } - fn identity( + async fn identity( &self, signing_identity: &SigningIdentity, _extensions: &ExtensionList, @@ -122,7 +131,7 @@ where Ok(identity_claims.subject.into_bytes()) } - fn valid_successor( + async fn valid_successor( &self, predecessor: &SigningIdentity, successor: &SigningIdentity, diff --git a/data-plane/core/mls/src/lib.rs b/data-plane/core/mls/src/lib.rs index 6d684b43c..7fc7b501b 100644 --- a/data-plane/core/mls/src/lib.rs +++ b/data-plane/core/mls/src/lib.rs @@ -1,6 +1,7 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +pub mod crypto; pub mod errors; pub mod identity_provider; pub mod mls; diff --git a/data-plane/core/mls/src/mls.rs b/data-plane/core/mls/src/mls.rs index 37f8834ec..b5b761c03 100644 --- a/data-plane/core/mls/src/mls.rs +++ b/data-plane/core/mls/src/mls.rs @@ -8,10 +8,14 @@ use mls_rs::{ group::ReceivedMessage, identity::{SigningIdentity, basic::BasicCredential}, }; -#[cfg(test)] +// `CipherSuiteProvider` / `CryptoProvider` are only referenced by +// `generate_key_pair`, which is itself feature/test gated. Importing them +// unconditionally would produce an `unuse d_imports` warning under the +// `curve25519` feature. +#[cfg(any(test, not(all(feature = "native", feature = "curve25519"))))] use mls_rs::{CipherSuiteProvider, CryptoProvider}; -use mls_rs_crypto_awslc::AwsLcCryptoProvider; +use crate::crypto::CryptoProviderImpl; use std::collections::HashSet; use tracing::debug; @@ -20,7 +24,19 @@ use slim_auth::traits::{TokenProvider, Verifier}; use crate::errors::MlsError; use crate::identity_provider::SlimIdentityProvider; +// Default cipher suite is P256_AES128 (NIST P-256 + ECDSA-P256 + AES-128-GCM) +// so that native and WASM peers can interoperate in the same MLS group: +// browser WebCrypto (used by the `wasm` backend) does not support +// Curve25519, but it does support P-256. +// +// Operators that do not need browser interoperability can opt in to the +// stronger CURVE25519_AES128 ciphersuite by enabling the crate's +// `curve25519` feature. This only takes effect on the `native` backend +// because the `wasm` backend cannot service it at runtime. +#[cfg(all(feature = "native", feature = "curve25519"))] const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128; +#[cfg(not(all(feature = "native", feature = "curve25519")))] +const CIPHERSUITE: CipherSuite = CipherSuite::P256_AES128; pub type CommitMsg = Vec; pub type WelcomeMsg = Vec; @@ -55,7 +71,7 @@ where mls_rs::client_builder::WithIdentityProvider< SlimIdentityProvider, mls_rs::client_builder::WithCryptoProvider< - AwsLcCryptoProvider, + CryptoProviderImpl, mls_rs::client_builder::BaseConfig, >, >, @@ -66,7 +82,7 @@ where mls_rs::client_builder::WithIdentityProvider< SlimIdentityProvider, mls_rs::client_builder::WithCryptoProvider< - AwsLcCryptoProvider, + CryptoProviderImpl, mls_rs::client_builder::BaseConfig, >, >, @@ -116,6 +132,12 @@ where /// Creates a signing identity from the keys stored in the identity provider. /// The provider must have had its MLS keys generated (done automatically at construction). + /// + /// This path is only used when the crate is compiled with the + /// `curve25519` feature on the `native` backend: in that case the + /// auth provider hands us real Ed25519 key material and we can wrap + /// it directly into an MLS signing identity. + #[cfg(all(feature = "native", feature = "curve25519"))] fn create_signing_identity( &mut self, is_rotation: bool, @@ -144,19 +166,76 @@ where Ok((private_key, signing_identity)) } - #[cfg(test)] - fn generate_key_pair() -> Result<(SignatureSecretKey, SignaturePublicKey), MlsError> { - let crypto_provider = AwsLcCryptoProvider::default(); + /// Generate a fresh signature key pair for the active MLS ciphersuite. + /// + /// Always defers to the MLS crypto provider so the produced bytes are + /// guaranteed to be valid for the negotiated ciphersuite (P-256 by + /// default, or Curve25519 when the `curve25519` feature is enabled). + /// + /// Only compiled when the default P256 path is active or when running + /// the unit tests; the opt-in `curve25519` feature uses the auth + /// provider's Ed25519 keys directly and never calls this helper. + #[cfg(any(test, not(all(feature = "native", feature = "curve25519"))))] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn generate_key_pair() -> Result<(SignatureSecretKey, SignaturePublicKey), MlsError> { + let crypto_provider = crate::crypto::default_crypto_provider(); let cipher_suite_provider = crypto_provider .cipher_suite_provider(CIPHERSUITE) .ok_or(MlsError::CiphersuiteUnavailable)?; cipher_suite_provider .signature_key_generate() + .await .map_err(MlsError::crypto_provider) } - pub fn initialize(&mut self) -> Result<(), MlsError> { + /// Generate a fresh ciphersuite-correct key pair via the MLS crypto + /// provider, push it into the identity provider so the next token + /// embeds the matching public key, and assemble the corresponding + /// `SigningIdentity`. + /// + /// Used by the default P256 path on both `native` and `wasm`. The + /// auth provider's `rotate_signature_keys` only produces Ed25519 + /// material (or random bytes on WASM) and so cannot be used as the + /// source of truth when the negotiated ciphersuite is P-256. + #[cfg(not(all(feature = "native", feature = "curve25519")))] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn install_generated_signing_identity( + &mut self, + is_rotation: bool, + ) -> Result<(SignatureSecretKey, SigningIdentity), MlsError> { + let (priv_key, pub_key) = Self::generate_key_pair().await?; + + // Push the ciphersuite-correct keys into the identity provider so + // that get_token() embeds the matching public key in the + // credential. This is required: SlimIdentityProvider's + // validate_member() checks that the signing identity's public key + // equals the one bound to the token. + self.identity_provider + .set_signature_keys(priv_key.as_bytes().to_vec(), pub_key.as_bytes().to_vec())?; + + let token = self.identity_provider.get_token()?; + let basic_cred = BasicCredential::new(token.as_bytes().to_vec()); + let signing_identity = SigningIdentity::new(basic_cred.into_credential(), pub_key.clone()); + + if let Some(stored) = self.stored_identity.as_mut() { + stored.last_credential = Some(token); + stored.public_key_bytes = pub_key.as_bytes().to_vec(); + stored.private_key_bytes = priv_key.as_bytes().to_vec(); + + if is_rotation { + stored.credential_version = stored.credential_version.saturating_add(1); + } + } + + Ok((priv_key, signing_identity)) + } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn initialize(&mut self) -> Result<(), MlsError> { debug!("Initializing MLS"); // Generate fresh MLS signature keys before first use. This ensures that @@ -183,10 +262,21 @@ where self.stored_identity = Some(stored_identity); - // Create signing identity using keys provided by the identity provider + // Default (P256) path on either backend: the auth provider's keys + // are not valid for the negotiated ciphersuite, so we generate a + // fresh P-256 pair via the MLS crypto provider and install it. + #[cfg(not(all(feature = "native", feature = "curve25519")))] + let (private_key, signing_identity) = + self.install_generated_signing_identity(false).await?; + + // Opt-in Curve25519 path on `native`: the auth provider supplies + // real Ed25519 keys via rotate_signature_keys(), so use them + // directly. + #[cfg(all(feature = "native", feature = "curve25519"))] let (private_key, signing_identity) = self.create_signing_identity(false)?; - let crypto_provider = AwsLcCryptoProvider::default(); + let crypto_provider = crate::crypto::default_crypto_provider(); + let identity_provider = SlimIdentityProvider::new(self.identity_verifier.clone()); let client = Client::builder() @@ -200,15 +290,20 @@ where Ok(()) } - pub fn create_group(&mut self) -> Result, MlsError> { - debug!("Creating new MLS group"); + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn create_group(&mut self) -> Result, MlsError> { + tracing::info!("Creating new MLS group"); let client = self.client.as_ref().ok_or(MlsError::ClientNotInitialized)?; - let group = client.create_group(ExtensionList::default(), Default::default(), None)?; + tracing::info!("calling mls-rs client.create_group"); + let group = client + .create_group(ExtensionList::default(), Default::default(), None) + .await?; let group_id = group.group_id().to_vec(); self.group = Some(group); - debug!( + tracing::info!( id = ?hex::encode(&group_id), "MLS group created successfully", ); @@ -216,19 +311,27 @@ where Ok(group_id) } - pub fn generate_key_package(&self) -> Result { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn generate_key_package(&self) -> Result { debug!("Generating key package"); let client = self.client.as_ref().ok_or(MlsError::ClientNotInitialized)?; - let key_package = - client.generate_key_package_message(Default::default(), Default::default(), None)?; + let key_package = client + .generate_key_package_message(Default::default(), Default::default(), None) + .await?; let ret = key_package.to_bytes()?; Ok(ret) } - pub fn add_member(&mut self, key_package_bytes: &[u8]) -> Result { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn add_member( + &mut self, + key_package_bytes: &[u8], + ) -> Result { debug!("Adding member to the MLS group"); let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; let key_package = MlsMessage::from_bytes(key_package_bytes)?; @@ -240,12 +343,14 @@ where let old_roster = group.roster().members(); let mut ids = HashSet::new(); for m in old_roster { - let identifier = identity_provider.identity(&m.signing_identity, &m.extensions)?; + let identifier = identity_provider + .identity(&m.signing_identity, &m.extensions) + .await?; ids.insert(identifier); } let commit = group.commit_builder().add_member(key_package)?; - let commit = commit.build()?; + let commit = commit.build().await?; // create the commit message to broadcast in the group let commit_msg = commit.commit_message.to_bytes()?; @@ -255,15 +360,17 @@ where .welcome_messages .first() .ok_or(MlsError::NoWelcomeMessage) - .and_then(|w| w.to_bytes().map_err(MlsError::from))?; + .map(|w| w.to_bytes().map_err(MlsError::from))??; // apply the commit locally - group.apply_pending_commit()?; + group.apply_pending_commit().await?; let new_roster = group.roster().members(); let mut new_id = vec![]; for m in new_roster { - let identifier = identity_provider.identity(&m.signing_identity, &m.extensions)?; + let identifier = identity_provider + .identity(&m.signing_identity, &m.extensions) + .await?; if !ids.contains(&identifier) { new_id = identifier; break; @@ -278,38 +385,44 @@ where Ok(ret) } - pub fn remove_member(&mut self, identity: &[u8]) -> Result { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn remove_member(&mut self, identity: &[u8]) -> Result { debug!("Removing member from the MLS group"); let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; - let m = group.member_with_identity(identity)?; + let m = group.member_with_identity(identity).await?; let commit = group.commit_builder().remove_member(m.index)?; - let commit = commit.build()?; + let commit = commit.build().await?; let commit_msg = commit.commit_message.to_bytes()?; - group.apply_pending_commit()?; + group.apply_pending_commit().await?; Ok(commit_msg) } - pub fn process_commit(&mut self, commit_message: &[u8]) -> Result<(), MlsError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn process_commit(&mut self, commit_message: &[u8]) -> Result<(), MlsError> { let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; let commit = MlsMessage::from_bytes(commit_message)?; // process an incoming commit message - group.process_incoming_message(commit)?; + group.process_incoming_message(commit).await?; Ok(()) } - pub fn process_welcome(&mut self, welcome_message: &[u8]) -> Result, MlsError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn process_welcome(&mut self, welcome_message: &[u8]) -> Result, MlsError> { debug!("Processing welcome message and joining MLS group"); let client = self.client.as_ref().ok_or(MlsError::ClientNotInitialized)?; // process the welcome message and connect to the group let welcome = MlsMessage::from_bytes(welcome_message)?; - let (group, _) = client.join_group(None, &welcome, None)?; + let (group, _) = client.join_group(None, &welcome, None).await?; let group_id = group.group_id().to_vec(); self.group = Some(group); @@ -321,7 +434,9 @@ where Ok(group_id) } - pub fn process_proposal( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn process_proposal( &mut self, proposal_message: &[u8], create_commit: bool, @@ -329,7 +444,7 @@ where let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; let proposal = MlsMessage::from_bytes(proposal_message)?; - group.process_incoming_message(proposal)?; + group.process_incoming_message(proposal).await?; if !create_commit { debug!("process proposal but do not create commit. return empty commit"); @@ -337,49 +452,57 @@ where } // create commit message from proposal - let commit = group.commit_builder().build()?; + let commit = group.commit_builder().build().await?; // apply the commit locally - group.apply_pending_commit()?; + group.apply_pending_commit().await?; // return the commit message let commit_msg = commit.commit_message.to_bytes()?; Ok(commit_msg) } - pub fn process_local_pending_proposal(&mut self) -> Result { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn process_local_pending_proposal(&mut self) -> Result { let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; // create commit message from proposal - let commit = group.commit_builder().build()?; + let commit = group.commit_builder().build().await?; // apply the commit locally - group.apply_pending_commit()?; + group.apply_pending_commit().await?; // return the commit message let commit_msg = commit.commit_message.to_bytes()?; Ok(commit_msg) } - pub fn encrypt_message(&mut self, message: &[u8]) -> Result, MlsError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn encrypt_message(&mut self, message: &[u8]) -> Result, MlsError> { debug!("Encrypting MLS message"); let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; - let encrypted_msg = group.encrypt_application_message(message, Default::default())?; + let encrypted_msg = group + .encrypt_application_message(message, Default::default()) + .await?; let msg = encrypted_msg.to_bytes()?; Ok(msg) } - pub fn decrypt_message(&mut self, encrypted_message: &[u8]) -> Result, MlsError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn decrypt_message(&mut self, encrypted_message: &[u8]) -> Result, MlsError> { debug!("Decrypting MLS message"); let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; let message = MlsMessage::from_bytes(encrypted_message)?; - match group.process_incoming_message(message)? { + match group.process_incoming_message(message).await? { ReceivedMessage::ApplicationMessage(app_msg) => Ok(app_msg.data().to_vec()), _ => Err(MlsError::verification_failed( "Message was not an application message", @@ -387,9 +510,11 @@ where } } - pub fn write_to_storage(&mut self) -> Result<(), MlsError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn write_to_storage(&mut self) -> Result<(), MlsError> { let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; - group.write_to_storage()?; + group.write_to_storage().await?; Ok(()) } @@ -401,21 +526,31 @@ where self.group.as_ref().map(|g| g.current_epoch()) } - pub fn create_rotation_proposal(&mut self) -> Result { - // Ask the identity provider to generate new keys internally - self.identity_provider.rotate_signature_keys()?; - - // Create signing identity with token containing the new public key - let (new_private_key, new_signing_identity) = self.create_signing_identity(true)?; + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn create_rotation_proposal(&mut self) -> Result { + // Default (P256) path: generate the new key pair via the MLS + // crypto provider so it is valid for the negotiated ciphersuite, + // and install it into the identity provider so the rotated token + // embeds the matching public key. + #[cfg(not(all(feature = "native", feature = "curve25519")))] + let (new_private_key, new_signing_identity) = + self.install_generated_signing_identity(true).await?; + + // Opt-in Curve25519 path on `native`: ask the auth provider to + // rotate its Ed25519 keys internally and read them back. + #[cfg(all(feature = "native", feature = "curve25519"))] + let (new_private_key, new_signing_identity) = { + self.identity_provider.rotate_signature_keys()?; + self.create_signing_identity(true)? + }; // Now get mutable reference to group after creating signing identity let group = self.group.as_mut().ok_or(MlsError::GroupNotExists)?; - let update_proposal = group.propose_update_with_identity( - new_private_key.clone(), - new_signing_identity, - vec![], - )?; + let update_proposal = group + .propose_update_with_identity(new_private_key.clone(), new_signing_identity, vec![]) + .await?; debug!( "Created credential rotation proposal, stored new keys and incremented credential version" @@ -450,6 +585,62 @@ mod tests { const SHARED_SECRET: &str = "kjandjansdiasb8udaijdniasdaindasndasndasndasndasndasndasndas"; + /// The default ciphersuite must be P256_AES128 so that native and WASM + /// peers can join the same MLS group. Operators that explicitly opt in + /// to the `curve25519` feature get the legacy CURVE25519_AES128 suite, + /// which is incompatible with browser WebCrypto. + #[test] + fn test_default_ciphersuite_is_p256() { + #[cfg(not(all(feature = "native", feature = "curve25519")))] + assert_eq!( + CIPHERSUITE, + CipherSuite::P256_AES128, + "Default ciphersuite must be P256_AES128 for browser interop", + ); + + #[cfg(all(feature = "native", feature = "curve25519"))] + assert_eq!( + CIPHERSUITE, + CipherSuite::CURVE25519_AES128, + "`curve25519` feature must select CURVE25519_AES128", + ); + } + + /// `generate_key_pair` must yield bytes that decode back into the + /// active ciphersuite's public-key format. In particular this catches + /// the case where the cipher constant and the key-generation provider + /// drift apart. + #[test] + fn test_generate_key_pair_matches_active_ciphersuite() { + let (priv_key, pub_key) = + Mls::::generate_key_pair().expect("key gen"); + + assert!( + !priv_key.as_bytes().is_empty(), + "private key must not be empty" + ); + assert!( + !pub_key.as_bytes().is_empty(), + "public key must not be empty" + ); + + // Sanity bound on key sizes: + // - P-256 SEC1 uncompressed pubkey is 65 bytes, secret is 32 bytes + // - X25519/Ed25519 pubkeys/secrets are 32 bytes + // We don't pin to an exact size to stay resilient to representation + // changes in `mls-rs`, but the keys must fit within reasonable bounds. + assert!( + pub_key.as_bytes().len() <= 128, + "public key length {} unexpectedly large", + pub_key.as_bytes().len() + ); + assert!( + priv_key.as_bytes().len() <= 128, + "private key length {} unexpectedly large", + priv_key.as_bytes().len() + ); + } + #[test] fn test_mls_creation() -> Result<(), Box> { let mut mls = Mls::new( @@ -918,7 +1109,7 @@ mod tests { // Build MLS client with mismatched private key let verifier = SharedSecret::new("alice", SHARED_SECRET).unwrap(); - let crypto_provider = AwsLcCryptoProvider::default(); + let crypto_provider = crate::crypto::default_crypto_provider(); let identity_provider = SlimIdentityProvider::new(verifier.clone()); let client = Client::builder() .identity_provider(identity_provider) diff --git a/data-plane/core/service/Cargo.toml b/data-plane/core/service/Cargo.toml index 0fffecaab..366817ae8 100644 --- a/data-plane/core/service/Cargo.toml +++ b/data-plane/core/service/Cargo.toml @@ -10,7 +10,26 @@ name = "slim_service" crate-type = ["lib"] [features] -default = ["session"] +default = ["native", "session"] +native = [ + "agntcy-slim-auth/native", + "agntcy-slim-config/native", + "agntcy-slim-controller/native", + "agntcy-slim-datapath/native", + "agntcy-slim-mls/native", + "agntcy-slim-session?/native", + "display-error-chain", + "tokio", + "tokio-util", +] +wasm = [ + "agntcy-slim-auth/wasm", + "agntcy-slim-config/wasm", + "agntcy-slim-controller/wasm", + "agntcy-slim-datapath/wasm", + "agntcy-slim-mls/wasm", + "agntcy-slim-session?/wasm", +] session = ["agntcy-slim-session"] [dependencies] @@ -22,13 +41,13 @@ agntcy-slim-mls = { workspace = true } agntcy-slim-session = { workspace = true, optional = true } agntcy-slim-version = { workspace = true } async-trait = { workspace = true } -display-error-chain = { workspace = true } +display-error-chain = { workspace = true, optional = true } futures = { workspace = true } parking_lot = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } -tokio-util = { workspace = true } +tokio = { workspace = true, optional = true } +tokio-util = { workspace = true, optional = true } tracing = { workspace = true } [dev-dependencies] diff --git a/data-plane/core/service/src/errors.rs b/data-plane/core/service/src/errors.rs index 53f73d45c..c64fd3eea 100644 --- a/data-plane/core/service/src/errors.rs +++ b/data-plane/core/service/src/errors.rs @@ -16,6 +16,7 @@ pub enum ServiceError { // Configuration / setup #[error("no server or client configured")] NoServerOrClientConfigured, + #[cfg(feature = "native")] #[error("grpc configuration error")] GrpcConfigError(#[from] slim_config::grpc::errors::ConfigError), #[error("invalid configuration: {0}")] diff --git a/data-plane/core/service/src/lib.rs b/data-plane/core/service/src/lib.rs index 425c0634f..c582768ed 100644 --- a/data-plane/core/service/src/lib.rs +++ b/data-plane/core/service/src/lib.rs @@ -17,15 +17,14 @@ //! use slim_service::Service; //! use slim_config::component::ComponentBuilder; //! use slim_auth::shared_secret::SharedSecret; -//! use slim_auth::testutils::TEST_VALID_SECRET; //! use slim_datapath::messages::Name; //! //! // Create service instance (handles message processing) //! let service = Service::builder().build("svc-0".to_string()).expect("Failed to create service"); //! //! // Create authentication components -//! let provider = SharedSecret::new("myapp", TEST_VALID_SECRET)?; -//! let verifier = SharedSecret::new("myapp", TEST_VALID_SECRET)?; +//! let provider = SharedSecret::new("myapp", "shared-secret-value-0123456789abcdef").unwrap(); +//! let verifier = SharedSecret::new("myapp", "shared-secret-value-0123456789abcdef").unwrap(); //! //! // Create an app for messaging //! let app_name = Name::from_strings(["org", "ns", "app"]); @@ -34,10 +33,11 @@ //! ``` pub mod errors; +#[cfg(feature = "native")] #[macro_use] pub mod service; -#[cfg(feature = "session")] +#[cfg(all(feature = "native", feature = "session"))] pub mod app; // Third-party crates @@ -47,4 +47,5 @@ pub use slim_datapath::messages::utils::SlimHeaderFlags; pub use errors::ServiceError; #[cfg(feature = "session")] pub use errors::SubscriptionAckError; +#[cfg(feature = "native")] pub use service::{KIND, Service, ServiceBuilder, ServiceConfiguration}; diff --git a/data-plane/core/service/src/service.rs b/data-plane/core/service/src/service.rs index dd2c2d04d..dd8fa53ca 100644 --- a/data-plane/core/service/src/service.rs +++ b/data-plane/core/service/src/service.rs @@ -12,11 +12,11 @@ use serde::Deserialize; use tokio_util::sync::CancellationToken; use tracing::{debug, info}; +use slim_config::client::ClientConfig; use slim_config::component::configuration::Configuration; use slim_config::component::id::{ID, Kind}; use slim_config::component::{Component, ComponentBuilder}; -use slim_config::grpc::client::ClientConfig; -use slim_config::grpc::server::ServerConfig; +use slim_config::server::ServerConfig; use slim_controller::config::Config as ControllerConfig; use slim_controller::config::Config as DataplaneConfig; use slim_controller::service::ControlPlane; @@ -563,7 +563,7 @@ mod tests { use super::*; use slim_auth::shared_secret::SharedSecret; - use slim_config::grpc::server::ServerConfig; + use slim_config::server::ServerConfig; use slim_config::tls::server::TlsServerConfig; use slim_datapath::api::MessageType; use slim_datapath::messages::Name; @@ -724,7 +724,7 @@ mod tests { // build client configuration and connect let mut client_conf = - slim_config::grpc::client::ClientConfig::with_endpoint("http://0.0.0.0:12346"); + slim_config::client::ClientConfig::with_endpoint("http://0.0.0.0:12346"); client_conf.tls_setting.insecure = true; let conn_id = service .connect(&client_conf) diff --git a/data-plane/core/session/Cargo.toml b/data-plane/core/session/Cargo.toml index c75d00581..d8a4bd5ec 100644 --- a/data-plane/core/session/Cargo.toml +++ b/data-plane/core/session/Cargo.toml @@ -5,27 +5,54 @@ license = { workspace = true } version = "0.1.12" description = "SLIM session internal implementation." +[package.metadata.cargo-machete] +ignored = ["getrandom"] + [lib] name = "slim_session" +[features] +default = ["native"] +native = [ + "agntcy-slim-auth/native", + "agntcy-slim-datapath/native", + "agntcy-slim-mls/native", + "dep:tokio", + "dep:tokio-util", +] +wasm = [ + "agntcy-slim-auth/wasm", + "agntcy-slim-datapath/wasm", + "agntcy-slim-mls/wasm", + "dep:getrandom", + "dep:tokio_with_wasm", +] + [dependencies] -agntcy-slim-auth = { workspace = true } -agntcy-slim-datapath = { workspace = true } -agntcy-slim-mls = { workspace = true } +agntcy-slim-auth = { workspace = true, default-features = false } +agntcy-slim-datapath = { workspace = true, default-features = false } +agntcy-slim-mls = { workspace = true, default-features = false } agntcy-slim-version = { workspace = true } async-trait = { workspace = true } display-error-chain = { workspace = true } futures = { workspace = true } futures-timer = { workspace = true } +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } +maybe-async = "0.2.10" parking_lot = { workspace = true } rand = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } -tokio-util = { workspace = true } -tonic = { workspace = true } +tokio = { workspace = true, optional = true } +tokio-util = { workspace = true, optional = true } +tokio_with_wasm = { version = "0.9", features = ["rt", "sync", "time", "macros"], optional = true } tracing = { workspace = true } +web-time = "1" [dev-dependencies] agntcy-slim-testing = { workspace = true } tracing-test = { workspace = true } + +[lints.rust.unexpected_cfgs] +level = "warn" +check-cfg = ["cfg(mls_build_async)"] diff --git a/data-plane/core/session/src/common.rs b/data-plane/core/session/src/common.rs index b2e3aa645..d5548c9bb 100644 --- a/data-plane/core/session/src/common.rs +++ b/data-plane/core/session/src/common.rs @@ -3,13 +3,12 @@ use std::time::Duration; -// Third-party crates -use tonic::Status; - use slim_datapath::{ + Status, api::{ProtoMessage as Message, ProtoSessionMessageType, ProtoSessionType}, messages::{Name, utils::MessageError}, }; +use tokio::sync::{mpsc, oneshot}; // Local crate use crate::SessionError; @@ -21,12 +20,11 @@ pub const SESSION_RANGE: std::ops::Range = 0..(u32::MAX - 1000); pub const SESSION_UNSPECIFIED: u32 = u32::MAX; /// Channel used in the path service -> app -pub(crate) type AppChannelSender = - tokio::sync::mpsc::UnboundedSender>; +pub(crate) type AppChannelSender = mpsc::UnboundedSender>; /// Channel used in the path app -> service -pub type AppChannelReceiver = tokio::sync::mpsc::UnboundedReceiver>; +pub type AppChannelReceiver = mpsc::UnboundedReceiver>; /// Channel used in the path service -> slim -pub type SlimChannelSender = tokio::sync::mpsc::Sender>; +pub type SlimChannelSender = mpsc::Sender>; /// The state of a session #[derive(Clone, PartialEq, Debug)] @@ -85,7 +83,7 @@ pub enum SessionMessage { message: Message, direction: MessageDirection, /// Optional channel to signal when message processing is complete - ack_tx: Option>>, + ack_tx: Option>>, }, /// Error occurred during message processing MessageError { error: SessionError }, @@ -114,7 +112,5 @@ pub enum SessionMessage { /// to notify that the session can be removed safely DeleteSession { session_id: u32 }, /// Query the participants list from the handler - GetParticipantsList { - tx: tokio::sync::oneshot::Sender>, - }, + GetParticipantsList { tx: oneshot::Sender> }, } diff --git a/data-plane/core/session/src/context.rs b/data-plane/core/session/src/context.rs index 422b5fd38..18a0b4c23 100644 --- a/data-plane/core/session/src/context.rs +++ b/data-plane/core/session/src/context.rs @@ -96,7 +96,8 @@ mod tests { #[derive(Clone, Default)] struct DummyProvider; - #[async_trait] + #[cfg_attr(feature = "native", async_trait)] + #[cfg_attr(feature = "wasm", async_trait(?Send))] impl TokenProvider for DummyProvider { async fn initialize(&mut self) -> Result<(), AuthError> { Ok(()) @@ -111,7 +112,8 @@ mod tests { } #[derive(Clone, Default)] struct DummyVerifier; - #[async_trait] + #[cfg_attr(feature = "native", async_trait)] + #[cfg_attr(feature = "wasm", async_trait(?Send))] impl Verifier for DummyVerifier { async fn initialize(&mut self) -> Result<(), AuthError> { Ok(()) diff --git a/data-plane/core/session/src/errors.rs b/data-plane/core/session/src/errors.rs index 622fd6a4f..0b658eb65 100644 --- a/data-plane/core/session/src/errors.rs +++ b/data-plane/core/session/src/errors.rs @@ -1,6 +1,7 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +use slim_datapath::Status; use slim_datapath::errors::{ErrorPayload, MessageContext}; use slim_datapath::messages::Name; // Third-party crates @@ -11,7 +12,6 @@ use slim_auth::errors::AuthError; use slim_datapath::api::{ProtoMessage, ProtoSessionMessageType, ProtoSessionType}; use slim_datapath::messages::utils::MessageError; use slim_mls::errors::MlsError; -use tonic::Status; use crate::SessionMessage; use crate::subscription_manager::SubscriptionAckError; diff --git a/data-plane/core/session/src/interceptor.rs b/data-plane/core/session/src/interceptor.rs index 21832549f..ae58a69da 100644 --- a/data-plane/core/session/src/interceptor.rs +++ b/data-plane/core/session/src/interceptor.rs @@ -11,7 +11,8 @@ use slim_datapath::api::ProtoMessage as Message; // Local crate use crate::errors::SessionError; -#[async_trait::async_trait] +#[cfg_attr(feature = "native", async_trait::async_trait)] +#[cfg_attr(feature = "wasm", async_trait::async_trait(?Send))] pub trait SessionInterceptor { // interceptor to be executed when a message is received from the app async fn on_msg_from_app(&self, msg: &mut Message) -> Result<(), SessionError>; @@ -19,7 +20,8 @@ pub trait SessionInterceptor { async fn on_msg_from_slim(&self, msg: &mut Message) -> Result<(), SessionError>; } -#[async_trait::async_trait] +#[cfg_attr(feature = "native", async_trait::async_trait)] +#[cfg_attr(feature = "wasm", async_trait::async_trait(?Send))] pub trait SessionInterceptorProvider { /// add an interceptor to the session fn add_interceptor(&self, interceptor: Arc); @@ -75,7 +77,8 @@ where /// from the app, and verify the identity when a message is received from slim. /// If the identity is not found in the message metadata, it will return an error. /// If the identity verification fails, it will return an error as well. -#[async_trait::async_trait] +#[cfg_attr(feature = "native", async_trait::async_trait)] +#[cfg_attr(feature = "wasm", async_trait::async_trait(?Send))] impl SessionInterceptor for IdentityInterceptor where P: TokenProvider + Send + Sync + Clone + 'static, diff --git a/data-plane/core/session/src/lib.rs b/data-plane/core/session/src/lib.rs index 71d0270ce..2ba563125 100644 --- a/data-plane/core/session/src/lib.rs +++ b/data-plane/core/session/src/lib.rs @@ -1,6 +1,12 @@ // Copyright AGNTCY Contributors (https://github.com/agntcy) // SPDX-License-Identifier: Apache-2.0 +// On wasm, alias tokio_with_wasm as tokio so all code can use tokio:: paths uniformly. +#[cfg(all(feature = "wasm", not(feature = "native")))] +extern crate tokio_with_wasm as tokio; + +pub mod runtime; + mod common; pub mod completion_handle; pub mod context; diff --git a/data-plane/core/session/src/mls_state.rs b/data-plane/core/session/src/mls_state.rs index 91a3ad5a3..e3d5ba1bc 100644 --- a/data-plane/core/session/src/mls_state.rs +++ b/data-plane/core/session/src/mls_state.rs @@ -46,8 +46,10 @@ where P: TokenProvider + Send + Sync + Clone + 'static, V: Verifier + Send + Sync + Clone + 'static, { - pub(crate) fn new(mut mls: Mls) -> Result { - mls.initialize()?; + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn new(mut mls: Mls) -> Result { + mls.initialize().await?; Ok(MlsState { mls, @@ -57,12 +59,19 @@ where }) } - pub(crate) fn generate_key_package(&mut self) -> Result { - let ret = self.mls.generate_key_package()?; + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn generate_key_package(&mut self) -> Result { + let ret = self.mls.generate_key_package().await?; Ok(ret) } - pub(crate) fn process_welcome_message(&mut self, msg: &Message) -> Result<(), SessionError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn process_welcome_message( + &mut self, + msg: &Message, + ) -> Result<(), SessionError> { if self.last_mls_msg_id != 0 { debug!("Welcome message already received, drop"); // we already got a welcome message, ignore this one @@ -77,12 +86,14 @@ where self.last_mls_msg_id = mls_payload.commit_id; let welcome = &mls_payload.mls_content; - self.group = self.mls.process_welcome(welcome)?; + self.group = self.mls.process_welcome(welcome).await?; Ok(()) } - pub(crate) fn process_control_message( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn process_control_message( &mut self, msg: Message, local_name: &Name, @@ -105,18 +116,18 @@ where // base on the message type, process it match msg.get_session_header().session_message_type() { ProtoSessionMessageType::GroupProposal => { - self.process_proposal_message(msg, local_name)?; + self.process_proposal_message(msg, local_name).await?; } ProtoSessionMessageType::GroupAdd => { let payload = msg.extract_group_add()?; let mls_payload = payload.mls.as_ref().ok_or(MlsError::NoGroupAddPayload)?; - self.process_commit_message(mls_payload)?; + self.process_commit_message(mls_payload).await?; } ProtoSessionMessageType::GroupRemove => { let payload = msg.extract_group_remove()?; let mls_payload = payload.mls.as_ref().ok_or(MlsError::NoGroupRemovePayload)?; - self.process_commit_message(mls_payload)?; + self.process_commit_message(mls_payload).await?; } _type => { error!(?_type, "unknown control message type, drop it"); @@ -130,16 +141,23 @@ where Ok(true) } - fn process_commit_message(&mut self, mls_payload: &MlsPayload) -> Result<(), SessionError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn process_commit_message( + &mut self, + mls_payload: &MlsPayload, + ) -> Result<(), SessionError> { trace!(id = %mls_payload.commit_id, "processing stored commit",); // process the commit message - self.mls.process_commit(&mls_payload.mls_content)?; + self.mls.process_commit(&mls_payload.mls_content).await?; Ok(()) } - fn process_proposal_message( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn process_proposal_message( &mut self, proposal: Message, local_name: &Name, @@ -159,7 +177,9 @@ where return Ok(()); } - self.mls.process_proposal(&payload.mls_proposal, false)?; + self.mls + .process_proposal(&payload.mls_proposal, false) + .await?; Ok(()) } @@ -253,7 +273,9 @@ where /// # Returns /// * `Ok(())` if processing succeeds /// * `Err(SessionError)` if processing fails or message format is invalid - pub fn process_message( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub async fn process_message( &mut self, msg: &mut Message, direction: MessageDirection, @@ -261,11 +283,11 @@ where match direction { MessageDirection::South => { // Encrypting message going to SLIM - self.encrypt_message(msg) + self.encrypt_message(msg).await } MessageDirection::North => { // Decrypting message coming from SLIM - self.decrypt_message(msg) + self.decrypt_message(msg).await } } } @@ -278,7 +300,9 @@ where /// # Returns /// * `Ok(())` if encryption succeeds /// * `Err(SessionError)` if encryption fails or message format is invalid - fn encrypt_message(&mut self, msg: &mut Message) -> Result<(), SessionError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn encrypt_message(&mut self, msg: &mut Message) -> Result<(), SessionError> { if !Self::should_process_message(msg) { return Ok(()); } @@ -286,7 +310,7 @@ where let payload = msg.get_payload().unwrap().as_application_payload()?; debug!("Encrypting message for group member"); - let encrypted_payload = self.mls.encrypt_message(&payload.blob)?; + let encrypted_payload = self.mls.encrypt_message(&payload.blob).await?; msg.set_payload( ApplicationPayload::new(&payload.payload_type, encrypted_payload.to_vec()).as_content(), @@ -303,7 +327,9 @@ where /// # Returns /// * `Ok(())` if decryption succeeds /// * `Err(SessionError)` if decryption fails or message format is invalid - fn decrypt_message(&mut self, msg: &mut Message) -> Result<(), SessionError> { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + async fn decrypt_message(&mut self, msg: &mut Message) -> Result<(), SessionError> { if !Self::should_process_message(msg) { return Ok(()); } @@ -311,7 +337,7 @@ where let payload = msg.get_payload().unwrap().as_application_payload()?; debug!("Decrypting message for group member"); - let decrypted_payload = self.mls.decrypt_message(&payload.blob)?; + let decrypted_payload = self.mls.decrypt_message(&payload.blob).await?; msg.set_payload( ApplicationPayload::new(&payload.payload_type, decrypted_payload.to_vec()).as_content(), @@ -350,19 +376,25 @@ where } } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] pub(crate) async fn init_moderator(&mut self) -> Result<(), SessionError> { - self.common.mls.create_group()?; + tracing::info!("MLS moderator init: calling create_group"); + self.common.mls.create_group().await?; + tracing::info!("MLS moderator init: create_group succeeded"); Ok(()) } - pub(crate) fn add_participant( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn add_participant( &mut self, msg: &Message, ) -> Result<(CommitMsg, WelcomeMsg), SessionError> { let payload = msg.extract_join_reply()?; // Propagate MlsError directly (will become SessionError::MlsOp via #[from]) - let ret = self.common.mls.add_member(payload.key_package())?; + let ret = self.common.mls.add_member(payload.key_package()).await?; // add participant to the list self.participants @@ -371,7 +403,12 @@ where Ok((ret.commit_message, ret.welcome_message)) } - pub(crate) fn remove_participant(&mut self, msg: &Message) -> Result { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn remove_participant( + &mut self, + msg: &Message, + ) -> Result { debug!("Remove participant from the MLS group"); let name = msg.get_dst(); let id = match self.participants.get(&name) { @@ -382,7 +419,7 @@ where } }; - let ret = self.common.mls.remove_member(id)?; + let ret = self.common.mls.remove_member(id).await?; // remove the participant from the list self.participants.remove(&name); @@ -391,18 +428,24 @@ where } #[allow(dead_code)] - pub(crate) fn process_proposal_message( + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn process_proposal_message( &mut self, proposal: &ProposalMsg, ) -> Result { - let commit = self.common.mls.process_proposal(proposal, true)?; + let commit = self.common.mls.process_proposal(proposal, true).await?; Ok(commit) } #[allow(dead_code)] - pub(crate) fn process_local_pending_proposal(&mut self) -> Result { - let commit = self.common.mls.process_local_pending_proposal()?; + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + #[cfg_attr(mls_build_async, maybe_async::must_be_async)] + pub(crate) async fn process_local_pending_proposal( + &mut self, + ) -> Result { + let commit = self.common.mls.process_local_pending_proposal().await?; Ok(commit) } diff --git a/data-plane/core/session/src/runtime.rs b/data-plane/core/session/src/runtime.rs new file mode 100644 index 000000000..4937c7bb5 --- /dev/null +++ b/data-plane/core/session/src/runtime.rs @@ -0,0 +1,124 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! Minimal platform abstractions that tokio / tokio_with_wasm do not cover. +//! +//! Everything else (channels, spawn, sleep, select!, JoinHandle) is available +//! directly via `tokio::*` on both native and wasm thanks to the +//! `extern crate tokio_with_wasm as tokio` alias in lib.rs. + +// ── CancellationToken ── +// tokio_util::sync::CancellationToken is not part of tokio_with_wasm, +// so we provide a unified re-export / implementation here. + +#[cfg(feature = "native")] +pub use tokio_util::sync::CancellationToken; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +mod cancellation { + use parking_lot::Mutex; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::task::Waker; + + #[derive(Clone)] + pub struct CancellationToken { + cancelled: Arc, + wakers: Arc>>, + children: Arc>>, + } + + impl std::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CancellationToken") + .field("cancelled", &self.is_cancelled()) + .finish() + } + } + + impl CancellationToken { + pub fn new() -> Self { + Self { + cancelled: Arc::new(AtomicBool::new(false)), + wakers: Arc::new(Mutex::new(Vec::new())), + children: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn cancel(&self) { + self.cancelled.store(true, Ordering::Release); + let wakers = std::mem::take(&mut *self.wakers.lock()); + for waker in wakers { + waker.wake(); + } + let children = std::mem::take(&mut *self.children.lock()); + for child in children { + child.cancel(); + } + } + + pub fn is_cancelled(&self) -> bool { + self.cancelled.load(Ordering::Acquire) + } + + pub fn child_token(&self) -> Self { + let child = Self::new(); + let mut children = self.children.lock(); + if self.is_cancelled() { + drop(children); + child.cancel(); + return child; + } + children.push(child.clone()); + child + } + + fn register_waker(&self, waker: &Waker) { + if self.is_cancelled() { + waker.wake_by_ref(); + return; + } + let mut wakers = self.wakers.lock(); + if !wakers.iter().any(|w| w.will_wake(waker)) { + wakers.push(waker.clone()); + } + } + + pub async fn cancelled(&self) { + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + struct CancelledFuture<'a> { + token: &'a CancellationToken, + } + + impl Future for CancelledFuture<'_> { + type Output = (); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.token.is_cancelled() { + Poll::Ready(()) + } else { + self.token.register_waker(cx.waker()); + if self.token.is_cancelled() { + Poll::Ready(()) + } else { + Poll::Pending + } + } + } + } + + CancelledFuture { token: self }.await + } + } + + impl Default for CancellationToken { + fn default() -> Self { + Self::new() + } + } +} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub use cancellation::CancellationToken; diff --git a/data-plane/core/session/src/session.rs b/data-plane/core/session/src/session.rs index 7605ce9c2..a820edd5d 100644 --- a/data-plane/core/session/src/session.rs +++ b/data-plane/core/session/src/session.rs @@ -7,7 +7,7 @@ use slim_datapath::{ messages::Name, }; -use tokio::sync::mpsc::{self}; +use tokio::sync::mpsc; use tracing::debug; use crate::{ @@ -206,7 +206,8 @@ impl Session { /// Implementation of MessageHandler trait for Session /// This allows Session to be used as a layer in the generic layer system -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl MessageHandler for Session { async fn init(&mut self) -> Result<(), SessionError> { // Session is the innermost layer, no initialization needed diff --git a/data-plane/core/session/src/session_builder.rs b/data-plane/core/session/src/session_builder.rs index 152bd147f..0275cc109 100644 --- a/data-plane/core/session/src/session_builder.rs +++ b/data-plane/core/session/src/session_builder.rs @@ -5,6 +5,7 @@ use std::marker::PhantomData; use slim_auth::traits::{TokenProvider, Verifier}; use slim_datapath::messages::Name; +use tokio::sync::mpsc::{self, Receiver, Sender}; use crate::{ Direction, @@ -123,7 +124,7 @@ where identity_provider: Option

, identity_verifier: Option, tx: Option, - tx_to_session_layer: Option>>, + tx_to_session_layer: Option>>, graceful_shutdown_timeout: Option, direction: Direction, subscription_manager: Option, @@ -195,7 +196,7 @@ where pub fn with_tx_to_session_layer( mut self, - tx_to_session_layer: tokio::sync::mpsc::Sender>, + tx_to_session_layer: Sender>, ) -> Self { self.tx_to_session_layer = Some(tx_to_session_layer); self @@ -363,8 +364,8 @@ where ) -> Result< ( W, - tokio::sync::mpsc::Sender, - tokio::sync::mpsc::Receiver, + Sender, + Receiver, SessionSettings, ), SessionError, @@ -372,7 +373,7 @@ where where W: MessageHandler, { - let (tx_session, rx_session) = tokio::sync::mpsc::channel(256); + let (tx_session, rx_session) = mpsc::channel(256); // Create the base Session layer let inner = crate::session::Session::new( diff --git a/data-plane/core/session/src/session_controller.rs b/data-plane/core/session/src/session_controller.rs index 81a0579c5..9e9d536dd 100644 --- a/data-plane/core/session/src/session_controller.rs +++ b/data-plane/core/session/src/session_controller.rs @@ -2,13 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 // Standard library imports -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, pin::pin, time::Duration}; use display_error_chain::ErrorChainExt; use parking_lot::Mutex; -use tokio::sync::{self, oneshot}; // Third-party crates -use tokio_util::sync::CancellationToken; +use crate::runtime::CancellationToken; +use tokio::sync::{mpsc, oneshot}; use tracing::{Instrument, debug}; use slim_auth::traits::{TokenProvider, Verifier}; @@ -46,7 +46,7 @@ pub struct SessionController { pub(crate) config: SessionConfig, /// channel to send messages to the processing loop - tx_controller: sync::mpsc::Sender, + tx_controller: mpsc::Sender, /// use in drop implementation to gracefully close the processing loop pub(crate) cancellation_token: CancellationToken, @@ -73,8 +73,8 @@ impl SessionController { destination: Name, config: SessionConfig, settings: SessionSettings, - tx: sync::mpsc::Sender, - rx: sync::mpsc::Receiver, + tx: mpsc::Sender, + rx: mpsc::Receiver, inner: I, ) -> Self where @@ -87,7 +87,7 @@ impl SessionController { let cancellation_token = CancellationToken::new(); // setup tracing context - let span = tracing::debug_span!( + let span = tracing::info_span!( parent: None, "session_controller_processing_loop", session_id = id, @@ -114,7 +114,7 @@ impl SessionController { /// Internal processing loop that handles messages with mutable access fn enter_draining_state( - shutdown_deadline: &mut std::pin::Pin<&mut tokio::time::Sleep>, + shutdown_at: &mut Option, settings: &SessionSettings, ) where P: slim_auth::traits::TokenProvider + Send + Sync + Clone + 'static, @@ -124,14 +124,12 @@ impl SessionController { let shutdown_timeout = settings .graceful_shutdown_timeout .unwrap_or(Duration::from_secs(60)); - shutdown_deadline - .as_mut() - .reset(tokio::time::Instant::now() + shutdown_timeout); + *shutdown_at = Some(web_time::Instant::now() + shutdown_timeout); } async fn processing_loop( mut inner: impl MessageHandler + 'static, - mut rx: sync::mpsc::Receiver, + mut rx: mpsc::Receiver, cancellation_token: CancellationToken, settings: SessionSettings, ) where @@ -139,22 +137,41 @@ impl SessionController { V: slim_auth::traits::Verifier + Send + Sync + Clone + 'static, M: crate::subscription_manager::SubscriptionOps, { - // Start with an infinite timeout (will be updated on graceful shutdown) - let mut shutdown_deadline = std::pin::pin!(tokio::time::sleep(Duration::MAX)); + // Set when draining starts; each loop iteration builds a fresh sleep to the deadline + // so this works on native tokio (`Sleep::reset`) and on wasm (`tokio_with_wasm` sleep). + let mut shutdown_at: Option = None; // Init the inner components if let Err(e) = inner.init().await { tracing::error!(error = %e.chain(), "error during initialization of session"); } + // `cancelled()` stays permanently ready after the token is cancelled; without this flag, + // handlers that keep reporting `Active` (e.g. default `processing_state`) would re-enter + // the cancellation branch every loop iteration. + let mut controller_cancel_handled = false; + loop { + let shutdown_deadline_snapshot = shutdown_at; + let mut shutdown_deadline_fut = pin!(async move { + match shutdown_deadline_snapshot { + Some(deadline) => { + let now = web_time::Instant::now(); + if let Some(d) = deadline.checked_duration_since(now) { + tokio::time::sleep(d).await; + } + } + None => std::future::pending::<()>().await, + } + }); + tokio::select! { - _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active => { - // Update the timeout to the configured grace period + _ = cancellation_token.cancelled(), if inner.processing_state() == ProcessingState::Active && !controller_cancel_handled => { + controller_cancel_handled = true; + let shutdown_timeout = settings.graceful_shutdown_timeout - .unwrap_or(Duration::from_secs(60)); // Default 60 seconds if not configured + .unwrap_or(Duration::from_secs(60)); - // Finish any ongoing processing before starting drain debug!("consuming pending messages before entering draining state"); while let Ok(msg) = rx.try_recv() { if let Err(e) = inner.on_message(msg).await { @@ -163,19 +180,18 @@ impl SessionController { } } - // Send drain to message to the inner to notify the beginning of the drain if let Err(e) = inner.on_message(SessionMessage::StartDrain { grace_period: shutdown_timeout }).await { - tracing::error!(error = %e.chain(), "error during start drain"); + tracing::error!(error = %e.chain(), "error during start drain"); break; } - Self::enter_draining_state(&mut shutdown_deadline, &settings); + Self::enter_draining_state(&mut shutdown_at, &settings); debug!("cancellation requested, entering draining state"); } - _ = &mut shutdown_deadline => { + _ = shutdown_deadline_fut.as_mut(), if shutdown_at.is_some() => { debug!("graceful shutdown timeout reached, forcing exit"); break; } @@ -215,7 +231,7 @@ impl SessionController { // start (or reset) the graceful shutdown deadline just like on cancellation. if !draining && inner.processing_state() == ProcessingState::Draining { debug!("internal component requested draining, entering draining state"); - Self::enter_draining_state(&mut shutdown_deadline, &settings); + Self::enter_draining_state(&mut shutdown_at, &settings); } } } @@ -605,9 +621,7 @@ where } async fn await_subscription_ack( - rx: tokio::sync::oneshot::Receiver< - Result<(), crate::subscription_manager::SubscriptionAckError>, - >, + rx: oneshot::Receiver>, ) -> Result<(), SessionError> { crate::subscription_manager::SubscriptionManager::await_ack(rx) .await @@ -772,6 +786,7 @@ where #[cfg(test)] mod tests { use super::*; + use tokio::sync::mpsc; // Test: internal draining transition triggered by a leave request. // This test sends a LeaveRequest into a multicast participant session and then @@ -868,8 +883,8 @@ mod tests { self, ) -> ( SessionController, - tokio::sync::mpsc::Receiver>, - tokio::sync::mpsc::UnboundedReceiver>, + mpsc::Receiver>, + mpsc::UnboundedReceiver>, ) { let config = SessionConfig { session_type: self.session_type, @@ -880,9 +895,9 @@ mod tests { metadata: self.metadata, }; - let (tx_slim, rx_slim) = tokio::sync::mpsc::channel(10); - let (tx_app, rx_app) = tokio::sync::mpsc::unbounded_channel(); - let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10); + let (tx_slim, rx_slim) = mpsc::channel(10); + let (tx_app, rx_app) = mpsc::unbounded_channel(); + let (tx_session_layer, _rx_session_layer) = mpsc::channel(10); let tx = SessionTransmitter::new(tx_slim, tx_app); @@ -1262,10 +1277,9 @@ mod tests { let participant_name = Name::from_strings(["org", "ns", "participant"]); let participant_name_id = Name::from_strings(["org", "ns", "participant"]).with_id(1); // create a SessionModerator - let (tx_slim_moderator, mut rx_slim_moderator) = tokio::sync::mpsc::channel(10); - let (tx_app_moderator, _rx_app_moderator) = tokio::sync::mpsc::unbounded_channel(); - let (tx_session_layer_moderator, _rx_session_layer_moderator) = - tokio::sync::mpsc::channel(10); + let (tx_slim_moderator, mut rx_slim_moderator) = mpsc::channel(10); + let (tx_app_moderator, _rx_app_moderator) = mpsc::unbounded_channel(); + let (tx_session_layer_moderator, _rx_session_layer_moderator) = mpsc::channel(10); let tx_moderator = SessionTransmitter::new(tx_slim_moderator.clone(), tx_app_moderator.clone()); @@ -1296,10 +1310,9 @@ mod tests { .unwrap(); // create a SessionParticipant - let (tx_slim_participant, mut rx_slim_participant) = tokio::sync::mpsc::channel(10); - let (tx_app_participant, mut rx_app_participant) = tokio::sync::mpsc::unbounded_channel(); - let (tx_session_layer_participant, _rx_session_layer_participant) = - tokio::sync::mpsc::channel(10); + let (tx_slim_participant, mut rx_slim_participant) = mpsc::channel(10); + let (tx_app_participant, mut rx_app_participant) = mpsc::unbounded_channel(); + let (tx_session_layer_participant, _rx_session_layer_participant) = mpsc::channel(10); let tx_participant = SessionTransmitter::new(tx_slim_participant.clone(), tx_app_participant.clone()); @@ -1713,7 +1726,6 @@ mod tests { #[tokio::test] async fn test_internal_draining_via_processing_state_switch() { use super::*; - use tokio::sync::mpsc; use tracing::debug; // Custom handler that flips processing_state to Draining after first normal message @@ -1938,10 +1950,10 @@ mod tests { fn create_test_settings( graceful_shutdown_timeout: Option, ) -> SessionSettings { - let (tx_slim, _rx_slim) = tokio::sync::mpsc::channel(10); - let (tx_app, _rx_app) = tokio::sync::mpsc::unbounded_channel(); - let (tx_session, _rx_session) = tokio::sync::mpsc::channel(10); - let (tx_session_layer, _rx_session_layer) = tokio::sync::mpsc::channel(10); + let (tx_slim, _rx_slim) = mpsc::channel(10); + let (tx_app, _rx_app) = mpsc::unbounded_channel(); + let (tx_session, _rx_session) = mpsc::channel(10); + let (tx_session_layer, _rx_session_layer) = mpsc::channel(10); let subscription_manager = crate::subscription_manager::SubscriptionManager::new(tx_slim.clone()); @@ -1999,7 +2011,7 @@ mod tests { /// Helper to spawn a processing loop and return the task handle fn spawn_processing_loop( handler: DrainableHandler, - rx: tokio::sync::mpsc::Receiver, + rx: mpsc::Receiver, cancellation_token: CancellationToken, settings: SessionSettings, ) -> tokio::task::JoinHandle<()> { @@ -2014,7 +2026,7 @@ mod tests { let messages_received = handler.messages_received.clone(); let shutdown_called = handler.shutdown_called.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); @@ -2065,7 +2077,7 @@ mod tests { let messages_received = handler.messages_received.clone(); let shutdown_called = handler.shutdown_called.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); @@ -2103,7 +2115,7 @@ mod tests { let messages_received = handler.messages_received.clone(); let shutdown_called = handler.shutdown_called.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); @@ -2148,7 +2160,7 @@ mod tests { // Test that the timeout fires when draining takes too long with needs_drain=true let handler = DrainableHandler::new().with_needs_drain(true); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); @@ -2206,7 +2218,7 @@ mod tests { let messages_received = handler.messages_received.clone(); let shutdown_called = handler.shutdown_called.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); @@ -2237,7 +2249,7 @@ mod tests { let handler = DrainableHandler::new(); let messages_received = handler.messages_received.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(10); + let (tx, rx) = mpsc::channel(10); let cancellation_token = CancellationToken::new(); let token_clone = cancellation_token.clone(); diff --git a/data-plane/core/session/src/session_layer.rs b/data-plane/core/session/src/session_layer.rs index d8b48f480..f80c11471 100644 --- a/data-plane/core/session/src/session_layer.rs +++ b/data-plane/core/session/src/session_layer.rs @@ -5,8 +5,8 @@ use std::collections::HashMap; use std::sync::Arc; -use display_error_chain::ErrorChainExt; // Third-party crates +use display_error_chain::ErrorChainExt; use parking_lot::RwLock as SyncRwLock; use rand::Rng; @@ -359,27 +359,23 @@ where tokio::spawn(async move { loop { - tokio::select! { - next = rx_session.recv() => { - match next { - Some(Ok(SessionMessage::DeleteSession { session_id })) => { - debug!(%session_id, "received closing signal, cancel session from the pool"); - if pool_clone.write().remove(&session_id).is_none() { - warn!(%session_id, "requested to delete unknown session"); - } - } - Some(Ok(m)) => { - error!(?m, "received unexpected message"); - } - Some(Err(e)) => { - warn!(error = %e.chain(), "error from session"); - } - None => { - // All senders dropped; exit loop. - break; - } + match rx_session.recv().await { + Some(Ok(SessionMessage::DeleteSession { session_id })) => { + debug!(%session_id, "received closing signal, cancel session from the pool"); + if pool_clone.write().remove(&session_id).is_none() { + warn!(%session_id, "requested to delete unknown session"); } } + Some(Ok(m)) => { + error!(?m, "received unexpected message"); + } + Some(Err(e)) => { + warn!(error = %e.chain(), "error from session"); + } + None => { + // All senders dropped; exit loop. + break; + } } } }.instrument(sessions_span)); @@ -485,16 +481,21 @@ where /// corresponding session #[tracing::instrument(skip_all, fields(service_id = %self.service_id))] pub async fn handle_message_from_slim(&self, mut message: Message) -> Result<(), SessionError> { - tracing::trace!( + tracing::info!( msg_type = %message.get_session_message_type().as_str_name(), session_id = %message.get_id(), - "received message from SLIM", + "handle_message_from_slim: received", ); // Pass message to interceptors in the transmitter self.transmitter .on_msg_from_slim_interceptors(&mut message) .await?; + tracing::info!( + msg_type = %message.get_session_message_type().as_str_name(), + session_id = %message.get_id(), + "handle_message_from_slim: interceptors passed", + ); let (id, session_type, session_message_type) = { // get the session type and the session id from the message @@ -522,6 +523,11 @@ where // check if we have a session for the given session ID let session_controller = self.pool.read().get(&id).cloned(); if let Some(controller) = session_controller { + tracing::info!( + %id, + msg_type = %session_message_type.as_str_name(), + "handle_message_from_slim: forwarding to session controller", + ); // pass the message to the session controller.on_message_from_slim(message).await?; diff --git a/data-plane/core/session/src/session_moderator.rs b/data-plane/core/session/src/session_moderator.rs index 0298261bd..075496903 100644 --- a/data-plane/core/session/src/session_moderator.rs +++ b/data-plane/core/session/src/session_moderator.rs @@ -7,7 +7,6 @@ use std::{ }; use async_trait::async_trait; -use display_error_chain::ErrorChainExt; use slim_auth::traits::{TokenProvider, Verifier}; use slim_datapath::{ api::{ @@ -21,8 +20,9 @@ use slim_datapath::{ }; use tokio::sync::oneshot; +use display_error_chain::ErrorChainExt; use slim_mls::mls::Mls; -use tracing::debug; +use tracing::{debug, info}; use crate::{ common::{MessageDirection, SessionMessage}, @@ -99,7 +99,8 @@ where /// Implementation of MessageHandler trait for SessionModerator /// This allows the moderator to be used as a layer in the generic layer system -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl MessageHandler for SessionModerator where P: TokenProvider + Send + Sync + Clone + 'static, @@ -110,11 +111,18 @@ where async fn init(&mut self) -> Result<(), SessionError> { // Initialize MLS self.mls_state = if self.common.settings.config.mls_enabled { - let mls_state = MlsState::new(Mls::new( + let mls = Mls::new( self.common.settings.identity_provider.clone(), self.common.settings.identity_verifier.clone(), - )) - .expect("failed to create MLS state"); + ); + + #[cfg(feature = "native")] + let mls_state = MlsState::new(mls).expect("failed to create MLS state"); + + #[cfg(all(feature = "wasm", not(feature = "native")))] + let mls_state = MlsState::new(mls) + .await + .expect("failed to create MLS state"); Some(MlsModeratorState::new(mls_state)) } else { @@ -153,7 +161,13 @@ where // Apply MLS encryption/decryption if enabled if let Some(mls_state) = &mut self.mls_state { + #[cfg(not(mls_build_async))] mls_state.common.process_message(&mut message, direction)?; + #[cfg(mls_build_async)] + mls_state + .common + .process_message(&mut message, direction) + .await?; } self.inner @@ -439,9 +453,15 @@ where // Compute MLS payload if needed let mls_payload = match self.mls_state.as_mut() { Some(state) => { + #[cfg(not(mls_build_async))] let mls_content = state .remove_participant(msg) .map_err(|e| self.handle_task_error(e))?; + #[cfg(mls_build_async)] + let mls_content = state + .remove_participant(msg) + .await + .map_err(|e| self.handle_task_error(e))?; let commit_id = self.mls_state.as_mut().unwrap().get_next_mls_mgs_id(); Some(MlsPayload { commit_id, @@ -656,13 +676,15 @@ where } async fn on_discovery_reply(&mut self, msg: Message) -> Result<(), SessionError> { - debug!( + info!( source = %msg.get_source(), id = msg.get_id(), - "discovery reply", + session_id = self.common.settings.id, + "discovery reply received", ); // update sender status to stop timers self.common.sender.on_message(&msg).await?; + info!(session_id = self.common.settings.id, "sender timer stopped"); // evolve the current task state // the discovery phase is completed @@ -670,14 +692,26 @@ where .as_mut() .unwrap() .discovery_complete(msg.get_id())?; + info!( + session_id = self.common.settings.id, + "discovery phase complete", + ); // join the channel if needed + info!( + session_id = self.common.settings.id, + mls = self.mls_state.is_some(), + "calling join", + ); self.join(msg.get_source(), msg.get_incoming_conn()).await?; + info!(session_id = self.common.settings.id, "join completed"); // set a route to the remote participant + info!(session_id = self.common.settings.id, route_to = %msg.get_source(), "adding route"); self.common .add_route(msg.get_source(), msg.get_incoming_conn()) .await?; + info!(session_id = self.common.settings.id, "route added"); // if this is a multicast session we need to add a route for the channel // on the connection from where we received the message. This has to be done @@ -761,7 +795,10 @@ where // get mls data if MLS is enabled let (commit, welcome) = if let Some(mls_state) = &mut self.mls_state { + #[cfg(not(mls_build_async))] let (commit_payload, welcome_payload) = mls_state.add_participant(&msg)?; + #[cfg(mls_build_async)] + let (commit_payload, welcome_payload) = mls_state.add_participant(&msg).await?; // get the id of the commit, the welcome message has a random id let commit_id = self.mls_state.as_mut().unwrap().get_next_mls_mgs_id(); @@ -1263,6 +1300,7 @@ where async fn join(&mut self, remote: Name, conn: u64) -> Result<(), SessionError> { if self.subscribed { + info!("join: already subscribed, skipping"); return Ok(()); } @@ -1272,16 +1310,24 @@ where // if this is a point to point connection set the remote name so that we // can add also the right id to the message destination name if self.common.settings.config.session_type == ProtoSessionType::PointToPoint { + info!(remote = %remote, "join: setting P2P destination"); self.common.settings.destination = remote; } else { // if this is a multicast session we need to subscribe for the channel name let destination = self.common.settings.destination.clone(); + info!(destination = %destination, "join: subscribing to multicast channel"); self.common.add_subscription(destination, conn).await?; + info!("join: multicast subscription complete"); } // create mls group if needed if let Some(mls) = self.mls_state.as_mut() { + info!("join: creating MLS group (init_moderator)"); + #[cfg(not(mls_build_async))] + mls.init_moderator()?; + #[cfg(mls_build_async)] mls.init_moderator().await?; + info!("join: MLS group created successfully"); } // add ourself to the participants diff --git a/data-plane/core/session/src/session_participant.rs b/data-plane/core/session/src/session_participant.rs index 477a737e4..360c0608f 100644 --- a/data-plane/core/session/src/session_participant.rs +++ b/data-plane/core/session/src/session_participant.rs @@ -78,7 +78,8 @@ where /// Implementation of MessageHandler trait for SessionParticipant /// This allows the participant to be used as a layer in the generic layer system -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl MessageHandler for SessionParticipant where P: TokenProvider + Send + Sync + Clone + 'static, @@ -89,11 +90,18 @@ where async fn init(&mut self) -> Result<(), SessionError> { // Initialize MLS self.mls_state = if self.common.settings.config.mls_enabled { - let mls_state = MlsState::new(Mls::new( + let mls = Mls::new( self.common.settings.identity_provider.clone(), self.common.settings.identity_verifier.clone(), - )) - .expect("failed to create MLS state"); + ); + + #[cfg(feature = "native")] + let mls_state = MlsState::new(mls).expect("failed to create MLS state"); + + #[cfg(all(feature = "wasm", not(feature = "native")))] + let mls_state = MlsState::new(mls) + .await + .expect("failed to create MLS state"); Some(mls_state) } else { @@ -120,7 +128,10 @@ where } else { // Apply MLS encryption/decryption if enabled if let Some(mls_state) = &mut self.mls_state { + #[cfg(not(mls_build_async))] mls_state.process_message(&mut message, direction)?; + #[cfg(mls_build_async)] + mls_state.process_message(&mut message, direction).await?; } self.inner @@ -352,7 +363,10 @@ where let payload = if let Some(mls_state) = &mut self.mls_state { debug!("mls enabled, create the package key"); + #[cfg(not(mls_build_async))] let key = mls_state.generate_key_package()?; + #[cfg(mls_build_async)] + let key = mls_state.generate_key_package().await?; Some(key) } else { None @@ -380,7 +394,10 @@ where ); if let Some(mls_state) = &mut self.mls_state { + #[cfg(not(mls_build_async))] mls_state.process_welcome_message(&msg)?; + #[cfg(mls_build_async)] + mls_state.process_welcome_message(&msg).await?; } self.join(&msg).await?; @@ -432,8 +449,13 @@ where if let Some(mls_state) = &mut self.mls_state { debug!("process mls control update"); + #[cfg(not(mls_build_async))] let ret = mls_state.process_control_message(msg.clone(), &self.common.settings.source)?; + #[cfg(mls_build_async)] + let ret = mls_state + .process_control_message(msg.clone(), &self.common.settings.source) + .await?; if !ret { debug!( diff --git a/data-plane/core/session/src/session_settings.rs b/data-plane/core/session/src/session_settings.rs index 19a41ea1a..aa161655e 100644 --- a/data-plane/core/session/src/session_settings.rs +++ b/data-plane/core/session/src/session_settings.rs @@ -3,6 +3,7 @@ use slim_auth::traits::{TokenProvider, Verifier}; use slim_datapath::messages::Name; +use tokio::sync::mpsc::Sender; use crate::{ SessionError, @@ -46,10 +47,10 @@ where pub(crate) tx: SessionTransmitter, /// Tx channel for sending messages to session queue - pub(crate) tx_session: tokio::sync::mpsc::Sender, + pub(crate) tx_session: Sender, /// Channel to send messages to the session layer - pub(crate) tx_to_session_layer: tokio::sync::mpsc::Sender>, + pub(crate) tx_to_session_layer: Sender>, /// Identity token provider pub(crate) identity_provider: P, diff --git a/data-plane/core/session/src/subscription_manager.rs b/data-plane/core/session/src/subscription_manager.rs index f00ab3d7e..274d1ff50 100644 --- a/data-plane/core/session/src/subscription_manager.rs +++ b/data-plane/core/session/src/subscription_manager.rs @@ -9,6 +9,7 @@ use std::time::Duration; use async_trait::async_trait; use futures::future::Either; +#[cfg(feature = "native")] use futures_timer::Delay; use parking_lot::Mutex; use thiserror::Error; @@ -438,14 +439,27 @@ impl SubscriptionManager { /// Await a previously registered ACK receiver, with a deadline of [`ACK_TIMEOUT`]. /// - /// Uses [`futures_timer::Delay`] rather than `tokio::time::timeout` so that - /// this function works correctly outside a Tokio runtime with the time driver - /// enabled (e.g. when called from UniFFI async bindings). + /// On native this uses [`futures_timer::Delay`] (rather than + /// `tokio::time::timeout`) so the function works correctly outside a + /// Tokio runtime with the time driver enabled — for example when called + /// from UniFFI async bindings. + /// + /// On wasm we cannot use `futures_timer::Delay` because its `Delay::new` + /// calls `std::time::Instant::now()`, which panics on + /// `wasm32-unknown-unknown` ("time not implemented on this platform"). + /// `tokio_with_wasm`'s `tokio::time::sleep` is wired to `setTimeout` + /// under the hood, so it does not require a Tokio runtime and works + /// correctly in the browser. pub async fn await_ack( ack_rx: oneshot::Receiver>, ) -> Result<(), SubscriptionAckError> { futures::pin_mut!(ack_rx); + + #[cfg(feature = "native")] let delay = Delay::new(ACK_TIMEOUT); + #[cfg(all(feature = "wasm", not(feature = "native")))] + let delay = tokio::time::sleep(ACK_TIMEOUT); + futures::pin_mut!(delay); match futures::future::select(ack_rx, delay).await { diff --git a/data-plane/core/session/src/timer.rs b/data-plane/core/session/src/timer.rs index 3af7bc967..327da2f1b 100644 --- a/data-plane/core/session/src/timer.rs +++ b/data-plane/core/session/src/timer.rs @@ -3,14 +3,16 @@ // Standard library imports use std::sync::Arc; +use std::time::Duration; // Third-party crates -use tokio::time::{self, Duration}; -use tokio_util::sync::CancellationToken; -use tonic::async_trait; +use crate::runtime::CancellationToken; +use async_trait::async_trait; +use tokio::time; use tracing::trace; -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] pub trait TimerObserver { async fn on_timeout(&self, timer_id: u32, timeouts: u32); async fn on_failure(&self, timer_id: u32, timeouts: u32); @@ -128,7 +130,7 @@ impl Timer { tokio::pin!(timer); tokio::select! { - _ = timer.as_mut() => { + _ = &mut timer => { timeouts += 1; match max_retries { Some(max) => { @@ -142,11 +144,11 @@ impl Timer { None => observer.on_timeout(timer_id, timeouts).await } retry += 1; - }, + } _ = cancellation_token.cancelled() => { observer.on_stop(timer_id).await; break; - }, + } } } }); @@ -176,6 +178,7 @@ impl Drop for Timer { // tests #[cfg(test)] mod tests { + use tokio::time; use tracing::debug; use tracing_test::traced_test; @@ -185,7 +188,8 @@ mod tests { id: u32, } - #[async_trait] + #[cfg_attr(feature = "native", async_trait)] + #[cfg_attr(feature = "wasm", async_trait(?Send))] impl TimerObserver for Observer { async fn on_timeout(&self, timer_id: u32, timeouts: u32) { debug!( diff --git a/data-plane/core/session/src/timer_factory.rs b/data-plane/core/session/src/timer_factory.rs index 2478da488..3fc279170 100644 --- a/data-plane/core/session/src/timer_factory.rs +++ b/data-plane/core/session/src/timer_factory.rs @@ -3,9 +3,9 @@ use std::{sync::Arc, time::Duration}; +use async_trait::async_trait; use slim_datapath::{api::ProtoSessionMessageType, messages::Name}; use tokio::sync::mpsc::Sender; -use tonic::async_trait; use tracing::debug; use crate::{ @@ -19,7 +19,8 @@ struct ReliableTimerObserver { name: Option, } -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] impl TimerObserver for ReliableTimerObserver { async fn on_timeout(&self, message_id: u32, timeouts: u32) { if let Err(e) = self diff --git a/data-plane/core/session/src/traits.rs b/data-plane/core/session/src/traits.rs index d93a4cacf..f25c2aba6 100644 --- a/data-plane/core/session/src/traits.rs +++ b/data-plane/core/session/src/traits.rs @@ -10,6 +10,27 @@ use slim_datapath::api::ProtoMessage as Message; use super::SessionInterceptorProvider; use crate::{common::SessionMessage, errors::SessionError}; +// Conditional Send/Sync bounds for cross-platform support +#[cfg(feature = "native")] +pub trait MaybeSend: Send {} +#[cfg(feature = "native")] +impl MaybeSend for T {} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub trait MaybeSend {} +#[cfg(all(feature = "wasm", not(feature = "native")))] +impl MaybeSend for T {} + +#[cfg(feature = "native")] +pub trait MaybeSync: Sync {} +#[cfg(feature = "native")] +impl MaybeSync for T {} + +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub trait MaybeSync {} +#[cfg(all(feature = "wasm", not(feature = "native")))] +impl MaybeSync for T {} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ProcessingState { Active, @@ -17,7 +38,8 @@ pub enum ProcessingState { } /// Session transmitter trait -#[async_trait] +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] pub trait Transmitter: SessionInterceptorProvider { async fn send_to_slim(&self, message: Result) -> Result<(), SessionError>; @@ -29,8 +51,9 @@ pub trait Transmitter: SessionInterceptorProvider { /// /// Each layer implements this trait and can hold an inner layer. /// The layer decides whether to pass messages to its inner layer or handle them itself (or both). -#[async_trait] -pub trait MessageHandler: Send + Sync { +#[cfg_attr(feature = "native", async_trait)] +#[cfg_attr(feature = "wasm", async_trait(?Send))] +pub trait MessageHandler: MaybeSend + MaybeSync { /// Init the layer. async fn init(&mut self) -> Result<(), SessionError>; diff --git a/data-plane/core/session/src/transmitter.rs b/data-plane/core/session/src/transmitter.rs index 867819fc2..276ed2677 100644 --- a/data-plane/core/session/src/transmitter.rs +++ b/data-plane/core/session/src/transmitter.rs @@ -52,7 +52,8 @@ impl SessionInterceptorProvider for SessionTransmitter { } } -#[async_trait::async_trait] +#[cfg_attr(feature = "native", async_trait::async_trait)] +#[cfg_attr(feature = "wasm", async_trait::async_trait(?Send))] impl Transmitter for SessionTransmitter { async fn send_to_app( &self, @@ -120,7 +121,8 @@ impl SessionInterceptorProvider for AppTransmitter { } } -#[async_trait::async_trait] +#[cfg_attr(feature = "native", async_trait::async_trait)] +#[cfg_attr(feature = "wasm", async_trait::async_trait(?Send))] impl Transmitter for AppTransmitter { async fn send_to_app( &self, @@ -179,7 +181,8 @@ mod tests { pub slim_calls: Arc>, } - #[async_trait] + #[cfg_attr(feature = "native", async_trait)] + #[cfg_attr(feature = "wasm", async_trait(?Send))] impl SessionInterceptor for RecordingInterceptor { async fn on_msg_from_app(&self, msg: &mut Message) -> Result<(), SessionError> { *self.app_calls.write() += 1; diff --git a/data-plane/core/signal/Cargo.toml b/data-plane/core/signal/Cargo.toml index d30fc7443..531d8aed0 100644 --- a/data-plane/core/signal/Cargo.toml +++ b/data-plane/core/signal/Cargo.toml @@ -8,7 +8,12 @@ description = "Small library to handle OS signals." [lib] name = "slim_signal" +[features] +default = ["native"] +native = ["tokio"] +wasm = [] + [dependencies] agntcy-slim-version = { workspace = true } -tokio = { version = "1", features = ["macros", "signal"] } +tokio = { version = "1", features = ["macros", "signal"], optional = true } tracing = { workspace = true } diff --git a/data-plane/core/signal/src/lib.rs b/data-plane/core/signal/src/lib.rs index 82e385105..fe5d1328e 100644 --- a/data-plane/core/signal/src/lib.rs +++ b/data-plane/core/signal/src/lib.rs @@ -5,7 +5,7 @@ pub async fn shutdown() { imp::shutdown().await } -#[cfg(unix)] +#[cfg(all(feature = "native", unix))] mod imp { use tokio::signal::unix::{SignalKind, signal}; use tracing::info; @@ -33,7 +33,7 @@ mod imp { } } -#[cfg(not(unix))] +#[cfg(all(feature = "native", not(unix)))] mod imp { use tracing::info; @@ -48,3 +48,12 @@ mod imp { ); } } + +#[cfg(all(feature = "wasm", not(feature = "native")))] +mod imp { + /// In WASM there are no OS signals. This future will pend forever. + /// Callers should use their own cancellation mechanism (e.g. CancellationToken). + pub(super) async fn shutdown() { + std::future::pending::<()>().await + } +} diff --git a/data-plane/core/slim-wasm/Cargo.toml b/data-plane/core/slim-wasm/Cargo.toml new file mode 100644 index 000000000..f1de3c99c --- /dev/null +++ b/data-plane/core/slim-wasm/Cargo.toml @@ -0,0 +1,33 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "agntcy-slim-wasm" +edition = { workspace = true } +license = { workspace = true } +version = "0.1.0" +description = "WASM/browser entry point for SLIM data plane." + +[lib] +name = "slim_wasm" +crate-type = ["cdylib", "rlib"] + +[dependencies] +agntcy-slim-auth = { workspace = true, default-features = false, features = ["wasm"] } +agntcy-slim-config = { workspace = true, default-features = false, features = ["wasm"] } +agntcy-slim-datapath = { workspace = true, default-features = false, features = ["wasm"] } +agntcy-slim-mls = { workspace = true, default-features = false, features = ["wasm"] } +agntcy-slim-session = { workspace = true, default-features = false, features = ["wasm"] } +agntcy-slim-tracing = { workspace = true, default-features = false, features = ["wasm"] } +console_error_panic_hook = "0.1" + +js-sys = "0.3" +parking_lot = { workspace = true } +tokio_with_wasm = { version = "0.9", features = ["rt", "sync", "time", "macros"] } +tracing = { workspace = true } +url = { workspace = true } +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +gloo-net = "0.6" diff --git a/data-plane/core/slim-wasm/src/lib.rs b/data-plane/core/slim-wasm/src/lib.rs new file mode 100644 index 000000000..d79cb78aa --- /dev/null +++ b/data-plane/core/slim-wasm/src/lib.rs @@ -0,0 +1,942 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! WASM/Browser entry point for the SLIM data plane. +//! +//! Provides a JavaScript-callable API via `wasm-bindgen` for connecting to a +//! SLIM data-plane instance from a web browser over WebSocket. +//! +//! ## Architecture +//! +//! Unlike the native data plane (which runs an in-process forwarder serving +//! many local apps and many remote peers), the browser is *just an app*: it +//! drives a single embedded `MessageProcessor` whose only purpose is to +//! multiplex one local connection (the JS-facing app) onto one or more +//! outgoing WebSocket connections to remote SLIM nodes. The session layer +//! talks to the embedded data plane through the same `(tx_slim, rx_slim)` +//! channel pair used on native, so subscription routing, link negotiation, +//! and remote subscription acks all just work. +//! +//! ## Quick Start (JavaScript) +//! +//! ```js +//! import init, { initTracing, SlimClient } from "slim_wasm"; +//! +//! await init(); +//! initTracing(); +//! +//! const client = await SlimClient.connect( +//! "ws://localhost:46357", "my-shared-secret-at-least-32-bytes!!", "org", "ns", "app" +//! ); +//! +//! // Subscribe to receive messages addressed to this name +//! await client.subscribe("org", "ns", "app"); +//! +//! // Start listening for events (messages, new sessions) +//! client.listen( +//! (msg) => console.log("message:", msg), +//! (session) => console.log("new session:", session), +//! ); +//! +//! // Create a session and publish +//! const sessionId = await client.createSession("org", "ns", "remote-app", "point-to-point"); +//! await client.publish(sessionId, new TextEncoder().encode("hello"), "text/plain"); +//! +//! // Connect to additional SLIM nodes — the browser data plane will route +//! // messages by destination Name across all active connections. +//! const conn2 = await client.addConnection("ws://other-slim:46357"); +//! await client.subscribe("org", "ns", "another-name", conn2); +//! ``` + +use wasm_bindgen::prelude::*; + +/// Initialize SLIM tracing (console-based logging for browser dev tools). +/// Call this once before using other SLIM APIs. +#[wasm_bindgen(js_name = "initTracing")] +pub fn init_tracing() { + console_error_panic_hook::set_once(); + + use slim_tracing::TracingConfiguration; + let config = TracingConfiguration::default(); + let _ = config.setup_tracing_subscriber(); +} + +// ── WASM-only implementation ── + +#[cfg(target_arch = "wasm32")] +mod wasm_impl { + use std::collections::HashMap; + use std::sync::Arc; + + use wasm_bindgen::prelude::*; + + use slim_auth::shared_secret::SharedSecret; + use slim_auth::traits::TokenProvider; + use slim_datapath::api::{ + MessageType, ProtoMessage, ProtoSessionMessageType, ProtoSessionType, + }; + use slim_datapath::message_processing::MessageProcessor; + use slim_datapath::messages::Name; + use slim_datapath::messages::utils::SlimHeaderFlags; + use slim_datapath::runtime::CancellationToken; + use slim_session::interceptor::{IdentityInterceptor, SessionInterceptorProvider}; + use slim_session::notification::Notification; + use slim_session::session_controller::SessionController; + use slim_session::subscription_manager::{SubscriptionManager, SubscriptionOps}; + use slim_session::transmitter::AppTransmitter; + use slim_session::{Direction, SessionConfig, SessionError, SessionLayer}; + use tokio_with_wasm::sync::mpsc; + + /// Wrapper for `js_sys::Function` that satisfies `Send + Sync` bounds. + /// + /// SAFETY: This is only used on `wasm32-unknown-unknown` which is single-threaded. + /// The `Send`/`Sync` markers exist solely to satisfy trait bounds on `Arc>` + /// wrappers required by the session layer infrastructure. + #[derive(Clone)] + struct JsCallback(js_sys::Function); + // SAFETY: WASM is single-threaded; no data races are possible. + unsafe impl Send for JsCallback {} + unsafe impl Sync for JsCallback {} + + /// A SLIM client that connects to one or more data-plane instances over + /// WebSocket. Provides subscribe/unsubscribe, session creation, message + /// publishing, and event listening for the browser. + #[wasm_bindgen] + pub struct SlimClient { + app_name: Name, + message_processor: Arc, + session_layer: Arc>, + sessions: Arc>>>, + #[allow(clippy::type_complexity)] + notification_rx: + Arc>>>>, + subscription_manager: SubscriptionManager, + on_message_cb: Arc>>, + cancel_token: CancellationToken, + /// Connection IDs of each remote SLIM node we've opened a websocket to. + /// Used as the default forward target for `subscribe` when the caller + /// does not specify a connection explicitly. + remote_conn_ids: Arc>>, + /// JWT-style token derived from the shared secret. Reused for each + /// new outgoing websocket connection (sent as a `?token=` query + /// parameter, matching the legacy slim-wasm contract). + auth_token: String, + } + + #[wasm_bindgen] + impl SlimClient { + /// Create a SLIM client and connect it to a first SLIM data-plane + /// instance via WebSocket. + /// + /// # Arguments + /// * `endpoint` - WebSocket URL (e.g. `ws://localhost:46357`) + /// * `shared_secret` - Shared secret for HMAC authentication (>= 32 bytes) + /// * `org` - Organization component of the local app name + /// * `ns` - Namespace component + /// * `app` - Application component + #[wasm_bindgen] + pub async fn connect( + endpoint: &str, + shared_secret: &str, + org: &str, + ns: &str, + app: &str, + ) -> Result { + let client = SlimClient::new_internal(shared_secret, org, ns, app)?; + client.add_connection(endpoint).await?; + Ok(client) + } + + /// Internal constructor: builds the embedded MessageProcessor, the + /// local connection that backs the App, and the session layer. No + /// websocket has been opened yet at this point. + fn new_internal( + shared_secret: &str, + org: &str, + ns: &str, + app: &str, + ) -> Result { + // Auth + let auth = SharedSecret::new(app, shared_secret) + .map_err(|e| JsError::new(&format!("auth error: {e}")))?; + + // Generate app ID from identity token ID (mirrors native App::new_with_direction) + let app_id = { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let token_id = auth + .get_id() + .map_err(|e| JsError::new(&format!("get_id error: {e}")))?; + let mut hasher = DefaultHasher::new(); + token_id.hash(&mut hasher); + hasher.finish() + }; + let app_name = Name::from_strings([org, ns, app]).with_id(app_id); + let auth_verifier = SharedSecret::new(app, shared_secret) + .map_err(|e| JsError::new(&format!("auth error: {e}")))?; + let auth_token = auth + .get_token() + .map_err(|e| JsError::new(&format!("token error: {e}")))?; + + // Create the embedded data plane. No recovery TTL: a browser tab + // doesn't outlive its connections in any meaningful way. + let message_processor = Arc::new(MessageProcessor::new_with_options( + format!("slim-wasm/{app_name}"), + Some(std::time::Duration::ZERO), + )); + + // Register the App's local connection on the data plane. From the + // forwarder's perspective, this is the same kind of "local app + // connection" that slim-service::App uses on native. + let (conn_id, tx_slim, rx_slim) = message_processor + .register_local_connection(false) + .map_err(|e| JsError::new(&format!("register_local_connection error: {e}")))?; + + // Channel used by SessionLayer to deliver session notifications + // (NewSession, NewMessage) up to the JS callback. + let (tx_app, notification_rx) = mpsc::channel(64); + + // Transmitter wires session interceptors (e.g. IdentityInterceptor) + // into outbound messages. + let transmitter = AppTransmitter { + slim_tx: tx_slim.clone(), + app_tx: tx_app.clone(), + interceptors: Arc::new(parking_lot::RwLock::new(vec![])), + }; + let identity_interceptor = Arc::new(IdentityInterceptor::new( + auth.clone(), + auth_verifier.clone(), + )); + transmitter.add_interceptor(identity_interceptor); + + let session_layer = Arc::new(SessionLayer::new( + app_name.clone(), + auth.clone(), + auth_verifier, + conn_id, + tx_slim, + tx_app, + transmitter, + Direction::Bidirectional, + String::new(), + )); + + let subscription_manager = session_layer.subscription_manager(); + + let cancel_token = CancellationToken::new(); + + // Start the local-app message loop: read from rx_slim, dispatch + // SubscriptionAcks to the subscription manager, and forward + // Publish messages into the SessionLayer. This mirrors + // `slim_service::App::process_messages` for the native case, + // including its initial *local-only* self-subscription so the + // embedded data plane has a route from any inbound publishes + // (addressed to this app's name) into the App layer. + spawn_app_message_loop( + app_name.clone(), + rx_slim, + session_layer.clone(), + subscription_manager.clone(), + cancel_token.child_token(), + ); + + Ok(SlimClient { + app_name, + message_processor, + session_layer, + sessions: Arc::new(parking_lot::Mutex::new(HashMap::new())), + notification_rx: Arc::new(parking_lot::Mutex::new(Some(notification_rx))), + subscription_manager, + on_message_cb: Arc::new(parking_lot::Mutex::new(None)), + cancel_token, + remote_conn_ids: Arc::new(parking_lot::Mutex::new(Vec::new())), + auth_token, + }) + } + + /// Open a WebSocket to an additional SLIM data-plane instance and + /// register it with the embedded data plane. Returns the new + /// connection ID, which can be passed to [`subscribe`] or + /// [`unsubscribe`] to scope a subscription to a specific node. + /// + /// Calling `addConnection` multiple times lets a single browser tab + /// fan out across several SLIM nodes; the embedded data plane + /// forwards messages by destination Name using the subscription + /// table — no JS-side routing logic required. + #[wasm_bindgen(js_name = "addConnection")] + pub async fn add_connection(&self, endpoint: &str) -> Result { + // Build URL with shared-secret token query param. We do this + // ourselves rather than via slim_config's auth flow because + // SharedSecret/JWT auth is not yet exposed through the wasm + // build of slim-config. + let mut url = url::Url::parse(endpoint) + .map_err(|e| JsError::new(&format!("invalid URL: {e}")))?; + url.query_pairs_mut().append_pair("token", &self.auth_token); + + let ws = gloo_net::websocket::futures::WebSocket::open(url.as_str()) + .map_err(|e| JsError::new(&format!("websocket error: {e}")))?; + + let (_handle, conn_id) = self + .message_processor + .register_websocket(ws, None) + .map_err(|e| JsError::new(&format!("register_websocket error: {e}")))?; + + self.remote_conn_ids.lock().push(conn_id); + + tracing::info!( + %endpoint, %conn_id, + "registered new remote SLIM websocket connection", + ); + + Ok(conn_id as u32) + } + + /// Subscribe to receive messages for the given name (`org`/`ns`/`name`). + /// + /// If `conn_id` is provided, the subscription is forwarded to that + /// specific remote SLIM connection. If `conn_id` is `None`, the + /// most recently added remote connection is used (matching the + /// single-connection legacy behavior). + #[wasm_bindgen] + pub async fn subscribe( + &self, + org: &str, + ns: &str, + name: &str, + conn_id: Option, + ) -> Result<(), JsError> { + let sub_name = Name::from_strings([org, ns, name]).with_id(self.session_layer.app_id()); + + let target_conn = self.resolve_conn(conn_id); + + let (subscription_id, ack_rx) = self + .subscription_manager + .subscribe(&self.app_name, &sub_name, target_conn) + .await + .map_err(|e| JsError::new(&format!("subscribe error: {e}")))?; + + SubscriptionManager::await_ack(ack_rx) + .await + .map_err(|e| JsError::new(&format!("subscription rejected: {e}")))?; + + self.session_layer.add_app_name(sub_name, subscription_id); + tracing::info!("subscribed to {org}/{ns}/{name} (conn={target_conn:?})"); + Ok(()) + } + + /// Unsubscribe from a name. + #[wasm_bindgen] + pub async fn unsubscribe( + &self, + org: &str, + ns: &str, + name: &str, + conn_id: Option, + ) -> Result<(), JsError> { + let unsub_name = + Name::from_strings([org, ns, name]).with_id(self.session_layer.app_id()); + + let subscription_id = self.session_layer.remove_app_name(&unsub_name).unwrap_or(0); + let target_conn = self.resolve_conn(conn_id); + + let ack_rx = self + .subscription_manager + .unsubscribe(&self.app_name, &unsub_name, subscription_id, target_conn) + .await + .map_err(|e| JsError::new(&format!("unsubscribe error: {e}")))?; + + SubscriptionManager::await_ack(ack_rx) + .await + .map_err(|e| JsError::new(&format!("unsubscription rejected: {e}")))?; + + tracing::info!("unsubscribed from {org}/{ns}/{name}"); + Ok(()) + } + + /// Pick a default forwarding target when the caller hasn't specified + /// one: the most recently added remote connection. + fn resolve_conn(&self, conn_id: Option) -> Option { + match conn_id { + Some(id) => Some(id as u64), + None => self.remote_conn_ids.lock().last().copied(), + } + } + + /// Install a forwarding route in the embedded data plane: tells the + /// browser-side data plane that messages destined for `org/ns/name` + /// should be forwarded onto a specific remote SLIM connection. + /// + /// Without a route, the local data plane has no entry for arbitrary + /// remote names so a publish (e.g. an `inviteParticipant` discovery + /// request, or a `publish` to a peer the browser hasn't subscribed + /// to) returns `NoMatch` and never reaches the wire. + /// + /// Mirrors `slim_service::App::set_route` on native: emits a + /// `Subscribe` with `recv_from=conn_id`, which registers the + /// destination at that connection in the local subscription table. + /// If `conn_id` is omitted, the most recently added remote + /// connection is used. + #[wasm_bindgen(js_name = "setRoute")] + pub async fn set_route( + &self, + org: &str, + ns: &str, + name: &str, + conn_id: Option, + ) -> Result<(), JsError> { + let target = self.resolve_conn(conn_id).ok_or_else(|| { + JsError::new("no remote connection available; call addConnection first") + })?; + let dest = Name::from_strings([org, ns, name]); + let mut msg = ProtoMessage::builder() + .source(self.app_name.clone()) + .destination(dest) + .flags(SlimHeaderFlags::default().with_recv_from(target)) + .build_subscribe() + .map_err(|e| JsError::new(&format!("set_route build error: {e}")))?; + + let identity = self + .session_layer + .get_identity_token() + .map_err(|e| JsError::new(&format!("identity error: {e}")))?; + msg.get_slim_header_mut().set_identity(identity); + + self.session_layer + .tx_slim() + .send(Ok(msg)) + .await + .map_err(|e| JsError::new(&format!("set_route send error: {e}")))?; + + tracing::info!("set route to {org}/{ns}/{name} via conn={target}"); + Ok(()) + } + + /// Remove a previously-installed forwarding route. Mirrors + /// `slim_service::App::remove_route` on native (sends an + /// `Unsubscribe` with `recv_from=conn_id`). + #[wasm_bindgen(js_name = "removeRoute")] + pub async fn remove_route( + &self, + org: &str, + ns: &str, + name: &str, + conn_id: Option, + ) -> Result<(), JsError> { + let target = self + .resolve_conn(conn_id) + .ok_or_else(|| JsError::new("no remote connection available"))?; + let dest = Name::from_strings([org, ns, name]); + let mut msg = ProtoMessage::builder() + .source(self.app_name.clone()) + .destination(dest) + .flags(SlimHeaderFlags::default().with_recv_from(target)) + .build_unsubscribe() + .map_err(|e| JsError::new(&format!("remove_route build error: {e}")))?; + + let identity = self + .session_layer + .get_identity_token() + .map_err(|e| JsError::new(&format!("identity error: {e}")))?; + msg.get_slim_header_mut().set_identity(identity); + + self.session_layer + .tx_slim() + .send(Ok(msg)) + .await + .map_err(|e| JsError::new(&format!("remove_route send error: {e}")))?; + + tracing::info!("removed route to {org}/{ns}/{name} via conn={target}"); + Ok(()) + } + + /// Create a session to a remote application. + /// + /// `session_type` must be `"point-to-point"` (or `"p2p"`) or `"multicast"`. + /// Set `mls_enabled` to `true` for end-to-end encrypted group sessions (MLS). + /// Returns the session ID on success. + #[wasm_bindgen(js_name = "createSession")] + pub async fn create_session( + &self, + dest_org: &str, + dest_ns: &str, + dest_app: &str, + session_type: &str, + mls_enabled: Option, + ) -> Result { + let destination = Name::from_strings([dest_org, dest_ns, dest_app]); + + let proto_session_type = match session_type { + "point-to-point" | "p2p" => ProtoSessionType::PointToPoint, + "multicast" => ProtoSessionType::Multicast, + _ => { + return Err(JsError::new(&format!( + "unknown session type: {session_type}. Use 'point-to-point' or 'multicast'" + ))); + } + }; + + let config = SessionConfig { + session_type: proto_session_type, + max_retries: Some(10), + interval: Some(std::time::Duration::from_secs(2)), + mls_enabled: mls_enabled.unwrap_or(false), + initiator: true, + metadata: HashMap::new(), + }; + + let (session_ctx, init_ack) = self + .session_layer + .create_session(config, self.app_name.clone(), destination, None) + .await + .map_err(|e| JsError::new(&format!("create session error: {e}")))?; + + let session_id = session_ctx.session_id(); + let session_arc = session_ctx + .session_arc() + .ok_or_else(|| JsError::new("session already dropped"))?; + + self.sessions.lock().insert(session_id, session_arc.clone()); + + // Spawn a receiver to forward session messages to the JS callback + let on_msg = self.on_message_cb.clone(); + let (_, mut rx) = session_ctx.into_parts(); + let sid = session_id; + wasm_bindgen_futures::spawn_local(async move { + tracing::info!(session_id = sid, "session receiver started"); + while let Some(msg_result) = rx.recv().await { + match msg_result { + Ok(msg) => { + tracing::info!( + session_id = sid, + "session receiver got message, invoking JS callback", + ); + let cb = on_msg.lock().clone(); + if let Some(callback) = cb { + let obj = build_message_js_object(&msg); + if let Err(e) = callback.0.call1(&JsValue::NULL, &obj) { + tracing::error!( + session_id = sid, + error = ?e, + "JS on_message callback threw", + ); + } + } else { + tracing::warn!( + session_id = sid, + "no on_message callback registered", + ); + } + } + Err(e) => { + tracing::warn!( + session_id = sid, + error = ?e, + "session receiver got error", + ); + } + } + } + tracing::info!(session_id = sid, "session receiver ended"); + }); + + // Wait for session establishment (discovery handshake) with timeout + tracing::info!(session_id, "waiting for session init_ack (30s timeout)"); + let init_result: Result, _> = tokio_with_wasm::select! { + result = init_ack => Ok(result), + _ = tokio_with_wasm::time::sleep(std::time::Duration::from_secs(30)) => { + Err("timeout") + } + }; + match init_result { + Ok(inner) => { + inner.map_err(|e| JsError::new(&format!("session init error: {e}")))?; + } + Err(_) => { + return Err(JsError::new(&format!( + "session {session_id} init timed out after 30s" + ))); + } + } + + tracing::info!("session {session_id} created"); + Ok(session_id) + } + + /// Publish a message through an established session. + /// + /// # Arguments + /// * `session_id` - Session ID returned by `createSession()` + /// * `payload` - Message payload bytes + /// * `payload_type` - Optional MIME type (defaults to `"msg"`) + #[wasm_bindgen] + pub async fn publish( + &self, + session_id: u32, + payload: &[u8], + payload_type: Option, + ) -> Result<(), JsError> { + let session = self + .sessions + .lock() + .get(&session_id) + .cloned() + .ok_or_else(|| JsError::new(&format!("session {session_id} not found")))?; + + let dest = session.dst().clone(); + let ack = session + .publish(&dest, payload.to_vec(), payload_type, None) + .await + .map_err(|e| JsError::new(&format!("publish error: {e}")))?; + + ack.await + .map_err(|e| JsError::new(&format!("publish ack error: {e}")))?; + + Ok(()) + } + + /// Start listening for incoming messages and session notifications. + /// + /// `on_message` is called with a JS object: + /// ```js + /// { source: string, payload: Uint8Array, payloadType: string, sessionId: number } + /// ``` + /// + /// `on_session` is called with a JS object: + /// ```js + /// { sessionId: number, source: string, destination: string } + /// ``` + /// + /// Can only be called once. Call before `createSession()` to receive + /// messages on sessions created by the remote side. + #[wasm_bindgen] + pub fn listen( + &self, + on_message: js_sys::Function, + on_session: js_sys::Function, + ) -> Result<(), JsError> { + // Store callback for use by session receivers + *self.on_message_cb.lock() = Some(JsCallback(on_message.clone())); + + let mut rx = self + .notification_rx + .lock() + .take() + .ok_or_else(|| JsError::new("listen() has already been called"))?; + + let sessions = self.sessions.clone(); + let on_msg_store = self.on_message_cb.clone(); + let cancel = self.cancel_token.child_token(); + let on_message = JsCallback(on_message); + let on_session = JsCallback(on_session); + + wasm_bindgen_futures::spawn_local(async move { + loop { + let next = tokio_with_wasm::select! { + m = rx.recv() => m, + _ = cancel.cancelled() => None, + }; + match next { + None => break, + Some(Ok(notification)) => match notification { + Notification::NewMessage(msg) => { + let obj = build_message_js_object(&msg); + on_message.0.call1(&JsValue::NULL, &obj).ok(); + } + Notification::NewSession(ctx) => { + let session_id = ctx.session_id(); + let session_arc = ctx.session_arc(); + + // Build JS notification + let obj = js_sys::Object::new(); + js_sys::Reflect::set( + &obj, + &"sessionId".into(), + &JsValue::from(session_id), + ) + .ok(); + + if let Some(ref sa) = session_arc { + let src = format!("{}", sa.source()); + let dst = format!("{}", sa.dst()); + js_sys::Reflect::set(&obj, &"source".into(), &src.into()).ok(); + js_sys::Reflect::set(&obj, &"destination".into(), &dst.into()) + .ok(); + sessions.lock().insert(session_id, sa.clone()); + } + + // Spawn receiver for incoming session messages + let on_msg_clone = on_msg_store.clone(); + let (_, mut srx) = ctx.into_parts(); + wasm_bindgen_futures::spawn_local(async move { + while let Some(msg_result) = srx.recv().await { + if let Ok(msg) = msg_result { + let cb = on_msg_clone.lock().clone(); + if let Some(callback) = cb { + let obj = build_message_js_object(&msg); + callback.0.call1(&JsValue::NULL, &obj).ok(); + } + } + } + }); + + on_session.0.call1(&JsValue::NULL, &obj.into()).ok(); + } + }, + Some(Err(e)) => { + tracing::warn!("notification error: {e}"); + } + } + } + tracing::info!("notification loop ended"); + }); + + Ok(()) + } + + /// Invite a participant to a multicast (group) session. + /// + /// Only the session initiator (moderator) can invite participants. + /// The participant must be subscribed to the given name. + #[wasm_bindgen(js_name = "inviteParticipant")] + pub async fn invite_participant( + &self, + session_id: u32, + org: &str, + ns: &str, + name: &str, + ) -> Result<(), JsError> { + let session = self + .sessions + .lock() + .get(&session_id) + .cloned() + .ok_or_else(|| JsError::new(&format!("session {session_id} not found")))?; + + let dest = Name::from_strings([org, ns, name]); + let ack = session + .invite_participant(&dest) + .await + .map_err(|e| JsError::new(&format!("invite error: {e}")))?; + + ack.await + .map_err(|e| JsError::new(&format!("invite ack error: {e}")))?; + + tracing::info!("invited {org}/{ns}/{name} to session {session_id}"); + Ok(()) + } + + /// Remove a participant from a multicast (group) session. + /// + /// Only the session initiator (moderator) can remove participants. + #[wasm_bindgen(js_name = "removeParticipant")] + pub async fn remove_participant( + &self, + session_id: u32, + org: &str, + ns: &str, + name: &str, + ) -> Result<(), JsError> { + let session = self + .sessions + .lock() + .get(&session_id) + .cloned() + .ok_or_else(|| JsError::new(&format!("session {session_id} not found")))?; + + let dest = Name::from_strings([org, ns, name]); + let ack = session + .remove_participant(&dest) + .await + .map_err(|e| JsError::new(&format!("remove error: {e}")))?; + + ack.await + .map_err(|e| JsError::new(&format!("remove ack error: {e}")))?; + + tracing::info!("removed {org}/{ns}/{name} from session {session_id}"); + Ok(()) + } + + /// Get the list of participants in a session. + #[wasm_bindgen(js_name = "participantsList")] + pub async fn participants_list(&self, session_id: u32) -> Result, JsError> { + let session = self + .sessions + .lock() + .get(&session_id) + .cloned() + .ok_or_else(|| JsError::new(&format!("session {session_id} not found")))?; + + let names = session + .participants_list() + .await + .map_err(|e| JsError::new(&format!("participants list error: {e}")))?; + + Ok(names.into_iter().map(|n| format!("{n}")).collect()) + } + + /// Delete a session by ID. + #[wasm_bindgen(js_name = "deleteSession")] + pub fn delete_session(&self, session_id: u32) -> Result<(), JsError> { + self.sessions.lock().remove(&session_id); + let _ = self.session_layer.remove_session(session_id); + tracing::info!("session {session_id} deleted"); + Ok(()) + } + + /// Disconnect from all SLIM data planes by cancelling background tasks. + /// The data plane closes each registered websocket cleanly via its + /// per-connection cancellation token. + #[wasm_bindgen] + pub fn disconnect(&self) { + self.cancel_token.cancel(); + for conn_id in self.remote_conn_ids.lock().drain(..) { + let _ = self.message_processor.disconnect(conn_id); + } + tracing::info!("disconnected from SLIM"); + } + + /// Get the local app name. + #[wasm_bindgen(getter, js_name = "appName")] + pub fn app_name(&self) -> String { + format!("{}", self.app_name) + } + + /// Get IDs of all active sessions. + #[wasm_bindgen(js_name = "sessionIds")] + pub fn session_ids(&self) -> Vec { + self.sessions.lock().keys().copied().collect() + } + + /// Get the connection IDs of all currently registered remote + /// SLIM nodes (in the order they were added). + #[wasm_bindgen(js_name = "connectionIds")] + pub fn connection_ids(&self) -> Vec { + self.remote_conn_ids + .lock() + .iter() + .map(|&id| id as u32) + .collect() + } + } + + /// Pump messages from the data plane → local app and dispatch them to + /// the SessionLayer / SubscriptionManager. Equivalent to + /// `slim_service::App::process_messages` for the native build. + /// + /// Mirrors the native loop's startup behaviour: before entering the + /// dispatch loop we kick off a *local-only* self-subscription + /// (`subscribe(app_name, app_name, None)` — no `forward_to`) so the + /// embedded data plane registers a route from this name to the local + /// app connection. That route is what lets inbound publishes from the + /// gateway eventually land in `rx_slim` and surface to the App layer. + /// It does *not* register the name on the gateway — the user is still + /// expected to call `SlimClient::subscribe(...)` for that, exactly as + /// `sdk-mock` calls `app.subscribe(app.app_name(), Some(conn_id))` + /// after `Service::connect`. + fn spawn_app_message_loop( + app_name: Name, + mut rx: mpsc::Receiver>, + session_layer: Arc>, + subscription_manager: SubscriptionManager, + cancel: CancellationToken, + ) { + wasm_bindgen_futures::spawn_local(async move { + // Initiate the local self-subscription so the ACK is tracked + // and resolved through the normal loop machinery below. This + // mirrors `slim_service::App::process_messages`'s startup. + let (_init_sub_id, init_ack_rx) = subscription_manager + .subscribe(&app_name, &app_name, None) + .await + .expect("error sending initial self-subscription"); + let mut init_ack_future = std::pin::pin!(init_ack_rx); + let mut init_ack_done = false; + + loop { + tokio_with_wasm::select! { + next = rx.recv() => { + match next { + None => { + tracing::info!("app message loop ended (slim channel closed)"); + break; + } + Some(Ok(msg)) => { + match msg.message_type.as_ref() { + Some(MessageType::Publish(_)) => { + if let Err(e) = + session_layer.handle_message_from_slim(msg).await + && !matches!(e, SessionError::SubscriptionNotFound(_)) + { + tracing::warn!(error = ?e, "handle_message_from_slim error"); + } + } + Some(MessageType::SubscriptionAck(ack)) => { + subscription_manager.resolve_ack(ack); + } + // Subscribe / Unsubscribe / Link / None: not + // expected on the app-facing channel; the data + // plane's `process_stream` already handled them. + _ => {} + } + } + Some(Err(e)) => { + tracing::warn!(error = %e, "received error from SLIM"); + } + } + } + result = &mut init_ack_future, if !init_ack_done => { + init_ack_done = true; + match result { + Ok(Ok(())) => tracing::debug!(%app_name, "initial self-subscription confirmed"), + Ok(Err(e)) => tracing::error!(%app_name, error = %e, "initial self-subscription failed"), + Err(_) => tracing::error!(%app_name, "initial self-subscription ack channel closed"), + } + } + _ = cancel.cancelled() => { + break; + } + } + } + }); + } + + /// Convert a ProtoMessage into a JS object for the on_message callback. + fn build_message_js_object(msg: &ProtoMessage) -> JsValue { + let obj = js_sys::Object::new(); + + let source = format!("{}", msg.get_source()); + js_sys::Reflect::set(&obj, &"source".into(), &source.into()).ok(); + + let session_id = msg.get_session_header().session_id; + js_sys::Reflect::set(&obj, &"sessionId".into(), &JsValue::from(session_id)).ok(); + + if msg.is_publish() { + if let Some(content) = msg.get_payload() { + if let Ok(app_payload) = content.as_application_payload() { + let arr = js_sys::Uint8Array::from(&app_payload.blob[..]); + js_sys::Reflect::set(&obj, &"payload".into(), &arr).ok(); + js_sys::Reflect::set( + &obj, + &"payloadType".into(), + &app_payload.payload_type.clone().into(), + ) + .ok(); + } + } + } + + // Session message type (Msg, JoinRequest, etc.) — useful for filtering on JS side. + let smt: ProtoSessionMessageType = msg + .get_session_header() + .session_message_type + .try_into() + .unwrap_or_default(); + js_sys::Reflect::set( + &obj, + &"sessionMessageType".into(), + &JsValue::from_str(smt.as_str_name()), + ) + .ok(); + + obj.into() + } +} diff --git a/data-plane/core/slim/Cargo.toml b/data-plane/core/slim/Cargo.toml index 13d1dad29..9f5d84754 100644 --- a/data-plane/core/slim/Cargo.toml +++ b/data-plane/core/slim/Cargo.toml @@ -17,10 +17,10 @@ default = ["multicore"] multicore = ["tokio/rt-multi-thread", "num_cpus"] [dependencies] -agntcy-slim-config = { workspace = true } -agntcy-slim-service = { workspace = true } -agntcy-slim-signal = { workspace = true } -agntcy-slim-tracing = { workspace = true } +agntcy-slim-config = { workspace = true, features = ["native"] } +agntcy-slim-service = { workspace = true, features = ["native"] } +agntcy-slim-signal = { workspace = true, features = ["native"] } +agntcy-slim-tracing = { workspace = true, features = ["native"] } agntcy-slim-version = { workspace = true } anyhow = { workspace = true } clap = { workspace = true } diff --git a/data-plane/core/tracing/Cargo.toml b/data-plane/core/tracing/Cargo.toml index 01d51a639..237622ebc 100644 --- a/data-plane/core/tracing/Cargo.toml +++ b/data-plane/core/tracing/Cargo.toml @@ -5,22 +5,48 @@ license = { workspace = true } version = "0.3.9" description = "Observability for SLIM data plane: logs, traces and metrics infrastructure." +[package.metadata.cargo-machete] +ignored = ["getrandom"] + [lib] name = "slim_tracing" +[features] +default = ["native"] +native = [ + "dep:agntcy-slim-config", + "dep:opentelemetry", + "dep:opentelemetry-otlp", + "dep:opentelemetry-semantic-conventions", + "dep:opentelemetry-stdout", + "dep:opentelemetry_sdk", + "dep:tracing-opentelemetry", + "dep:tracing-subscriber", +] +wasm = ["dep:tracing-subscriber", "dep:getrandom", "dep:web-sys", "uuid/js"] [dependencies] -agntcy-slim-config = { workspace = true } + +# Native-only dependencies (OpenTelemetry stack) +agntcy-slim-config = { workspace = true, optional = true, features = ["native"] } agntcy-slim-version = { workspace = true } + +# WASM support for uuid randomness +getrandom = { version = "0.3", features = ["wasm_js"], optional = true } once_cell = { workspace = true } -opentelemetry = { workspace = true } -opentelemetry-otlp = { workspace = true } -opentelemetry-semantic-conventions = { workspace = true } -opentelemetry-stdout = { workspace = true } -opentelemetry_sdk = { workspace = true } -serde = { workspace = true } +opentelemetry = { workspace = true, optional = true } +opentelemetry-otlp = { workspace = true, optional = true } +opentelemetry-semantic-conventions = { workspace = true, optional = true } +opentelemetry-stdout = { workspace = true, optional = true } +opentelemetry_sdk = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } thiserror = { workspace = true } tracing = { workspace = true } -tracing-opentelemetry = { workspace = true } -tracing-subscriber = { workspace = true } +tracing-opentelemetry = { workspace = true, optional = true } + +# Shared optional (used by both native and wasm) +tracing-subscriber = { workspace = true, optional = true, features = ["env-filter"] } uuid = { workspace = true } + +# WASM browser console output +web-sys = { version = "0.3", features = ["console"], optional = true } diff --git a/data-plane/core/tracing/src/lib.rs b/data-plane/core/tracing/src/lib.rs index 5af09bb67..49cc56a0b 100644 --- a/data-plane/core/tracing/src/lib.rs +++ b/data-plane/core/tracing/src/lib.rs @@ -3,652 +3,12 @@ pub mod utils; -use opentelemetry::{KeyValue, global, trace::TracerProvider as _}; -use opentelemetry_otlp::{ExporterBuildError, WithExportConfig}; -use opentelemetry_sdk::{ - Resource, - metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider}, - trace::{RandomIdGenerator, Sampler, SdkTracerProvider}, -}; -use opentelemetry_semantic_conventions::attribute::{ - DEPLOYMENT_ENVIRONMENT_NAME, SERVICE_NAME, SERVICE_VERSION, -}; -use serde::Deserialize; -use thiserror::Error; -use tracing::Level; -use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer}; -use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; - -use slim_config::{ - grpc::{client::ClientConfig, errors::ConfigError as GrpcConfigError}, - tls::client::TlsClientConfig, -}; - -const OTEL_EXPORTER_OTLP_ENDPOINT: &str = "http://localhost:4317"; - -#[derive(Error, Debug)] -pub enum ConfigError { - // gRPC / remote configuration - #[error("error loading GRPC config")] - GRPCError(#[from] GrpcConfigError), - - #[error("error building exporter")] - OpenTelemetryInitError(#[from] ExporterBuildError), - - // Filter parsing / directives - #[error("error parsing filter directives")] - FilterParseError(#[from] tracing_subscriber::filter::ParseError), - - // Tracing subscriber initialization - #[error("error setting up tracing subscriber")] - TracingSetupError(#[from] tracing_subscriber::util::TryInitError), -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct TracingConfiguration { - #[serde(default = "default_log_level")] - log_level: String, - - #[serde(default = "default_display_thread_names")] - display_thread_names: bool, - - #[serde(default = "default_display_thread_ids")] - display_thread_ids: bool, - - #[serde(default = "default_filter")] - filters: Vec, - - #[serde(default)] - opentelemetry: OpenTelemetryConfig, -} - -// default implementation for TracingConfiguration -impl Default for TracingConfiguration { - fn default() -> Self { - TracingConfiguration { - log_level: default_log_level(), - display_thread_names: default_display_thread_names(), - display_thread_ids: default_display_thread_ids(), - filters: default_filter(), - opentelemetry: OpenTelemetryConfig::default(), - } - } -} - -#[derive(Clone, Debug, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct OpenTelemetryConfig { - #[serde(default)] - enabled: bool, - - #[serde(default)] - grpc: ClientConfig, - - #[serde(default = "default_service_name")] - service_name: String, - - #[serde(default = "default_service_version")] - service_version: String, - - #[serde(default = "default_environment")] - environment: String, - - #[serde(default = "default_metrics_interval")] - metrics_interval_secs: u64, -} - -impl OpenTelemetryConfig { - /// Sets whether OpenTelemetry tracing and metrics are enabled. - /// - /// # Arguments - /// - /// * `enabled` - A boolean indicating whether OpenTelemetry should be enabled - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_enabled(mut self, enabled: bool) -> Self { - self.enabled = enabled; - self - } - - /// Sets the gRPC configuration for OpenTelemetry export. - /// - /// # Arguments - /// - /// * `grpc_config` - The gRPC client configuration to use for OpenTelemetry export - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_grpc_config(mut self, grpc_config: ClientConfig) -> Self { - self.grpc = grpc_config; - self - } - - /// Sets the service name for OpenTelemetry traces and metrics. - /// - /// # Arguments - /// - /// * `service_name` - The name of the service to be used in OpenTelemetry data - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_service_name(mut self, service_name: String) -> Self { - self.service_name = service_name; - self - } - - /// Sets the service version for OpenTelemetry traces and metrics. - /// - /// # Arguments - /// - /// * `service_version` - The version of the service to be used in OpenTelemetry data - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_service_version(mut self, service_version: String) -> Self { - self.service_version = service_version; - self - } - - /// Sets the deployment environment for OpenTelemetry traces and metrics. - /// - /// # Arguments - /// - /// * `environment` - The deployment environment (e.g., "development", "production") - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_environment(mut self, environment: String) -> Self { - self.environment = environment; - self - } - - /// Sets the interval in seconds between metric exports. - /// - /// # Arguments - /// - /// * `metrics_interval_secs` - The interval in seconds between metric exports - /// - /// # Returns - /// - /// Returns `self` for method chaining - pub fn with_metrics_interval_secs(mut self, metrics_interval_secs: u64) -> Self { - self.metrics_interval_secs = metrics_interval_secs; - self - } - - /// Returns whether OpenTelemetry tracing and metrics are enabled. - /// - /// # Returns - /// - /// Returns a boolean indicating whether OpenTelemetry is enabled - pub fn enabled(&self) -> bool { - self.enabled - } - - /// Returns the gRPC configuration for OpenTelemetry export. - /// - /// # Returns - /// - /// Returns a reference to the gRPC client configuration - pub fn grpc_config(&self) -> &ClientConfig { - &self.grpc - } - - /// Returns the service name used in OpenTelemetry data. - /// - /// # Returns - /// - /// Returns a reference to the service name string - pub fn service_name(&self) -> &str { - &self.service_name - } - - /// Returns the service version used in OpenTelemetry data. - /// - /// # Returns - /// - /// Returns a reference to the service version string - pub fn service_version(&self) -> &str { - &self.service_version - } - - /// Returns the deployment environment used in OpenTelemetry data. - /// - /// # Returns - /// - /// Returns a reference to the environment string - pub fn environment(&self) -> &str { - &self.environment - } - - /// Returns the interval in seconds between metric exports. - /// - /// # Returns - /// - /// Returns the metrics interval in seconds - pub fn metrics_interval_secs(&self) -> u64 { - self.metrics_interval_secs - } -} - -// default implementation for OpenTelemetryConfig -impl Default for OpenTelemetryConfig { - fn default() -> Self { - OpenTelemetryConfig { - enabled: false, - grpc: ClientConfig::with_endpoint(OTEL_EXPORTER_OTLP_ENDPOINT) - .with_tls_setting(TlsClientConfig::new().with_insecure(true)), - service_name: default_service_name(), - service_version: default_service_version(), - environment: default_environment(), - metrics_interval_secs: default_metrics_interval(), - } - } -} - -fn default_log_level() -> String { - "info".to_string() -} - -fn default_display_thread_names() -> bool { - true -} - -fn default_display_thread_ids() -> bool { - false -} - -fn default_filter() -> Vec { - // Only module names here. Their effective level will be the configured `log_level`. - vec![ - "slim_datapath".to_string(), - "slim_service".to_string(), - "slim_controller".to_string(), - "slim_auth".to_string(), - "slim_config".to_string(), - "slim_mls".to_string(), - "slim_session".to_string(), - "slim_signal".to_string(), - "slim_tracing".to_string(), - "_slim_bindings".to_string(), - "slim_testing".to_string(), - "slim".to_string(), - "slim_examples".to_string(), - "sdk_mock".to_string(), - "client".to_string(), - ] -} - -fn default_service_name() -> String { - "slim-data-plane".to_string() -} - -fn default_service_version() -> String { - "v0.1.0".to_string() -} - -fn default_environment() -> String { - "development".to_string() -} - -fn default_metrics_interval() -> u64 { - 30 // default to 30 seconds -} - -// function to convert string tracing level to tracing::Level -fn resolve_level(level: &str) -> tracing::Level { - let level = level.to_lowercase(); - match level.as_str() { - "trace" => Level::TRACE, - "debug" => Level::DEBUG, - "info" => Level::INFO, - "warn" => Level::WARN, - "error" => Level::ERROR, - _ => Level::INFO, // default level - } -} - -pub struct OtelGuard { - tracer_provider: Option, - meter_provider: Option, -} - -impl Drop for OtelGuard { - fn drop(&mut self) { - if let Some(tracer) = self.tracer_provider.take() - && let Err(err) = tracer.shutdown() - { - eprintln!("Error shutting down tracer provider: {err:?}"); - } - - if let Some(meter) = self.meter_provider.take() - && let Err(err) = meter.shutdown() - { - eprintln!("Error shutting down meter provider: {err:?}"); - } - } -} - -impl TracingConfiguration { - pub fn with_log_level(self, log_level: String) -> Self { - TracingConfiguration { log_level, ..self } - } - - pub fn with_display_thread_names(self, display_thread_names: bool) -> Self { - TracingConfiguration { - display_thread_names, - ..self - } - } - - pub fn with_display_thread_ids(self, display_thread_ids: bool) -> Self { - TracingConfiguration { - display_thread_ids, - ..self - } - } - - pub fn with_filter(self, filter: Vec) -> Self { - TracingConfiguration { - filters: filter, - ..self - } - } - - pub fn with_opentelemetry_config(mut self, config: OpenTelemetryConfig) -> Self { - self.opentelemetry = config; - self - } - - pub fn enable_opentelemetry(mut self) -> Self { - self.opentelemetry.enabled = true; - self - } - - pub fn with_metrics_interval(mut self, interval_secs: u64) -> Self { - self.opentelemetry.metrics_interval_secs = interval_secs; - self - } - - pub fn log_level(&self) -> &str { - &self.log_level - } - - pub fn display_thread_names(&self) -> bool { - self.display_thread_names - } - - pub fn display_thread_ids(&self) -> bool { - self.display_thread_ids - } - - pub fn filter(&self) -> &Vec { - &self.filters - } - - /// Set up a subscriber - pub fn setup_tracing_subscriber(&self) -> Result { - let fmt_layer = fmt::layer() - .with_thread_ids(self.display_thread_ids) - .with_thread_names(self.display_thread_names) - .with_line_number(true) - .with_filter(tracing_subscriber::filter::filter_fn( - |metadata: &tracing::Metadata| { - !metadata - .fields() - .iter() - .any(|field| field.name() == "telemetry") - }, - )); - - // Build the EnvFilter with correct precedence: - // 1. Environment variable (RUST_LOG) overrides everything (both modules & levels) - // 2. User-provided filter (if differs from the default) overrides default (modules & levels) - // 3. Default filter modules use the configured `log_level` - // - // Additionally, environment variable has highest priority. - let level_filter = if let Ok(env_value) = std::env::var("RUST_LOG") { - // Highest priority: environment. - // If env_value has no global directive (a bare level), and consists only of module=level - // directives, then append a global "off" so that unspecified modules are silenced. - // Examples: - // slim=debug -> slim=debug,off - // slim=debug,slim_auth=trace -> slim=debug,slim_auth=trace,off - // debug -> debug (keep global) - // info,slim=debug -> info,slim=debug (keep global) - let needs_global_off = { - let tokens: Vec<&str> = env_value - .split(',') - .map(|t| t.trim()) - .filter(|t| !t.is_empty()) - .collect(); - let bare_level_present = tokens - .iter() - .any(|t| matches!(*t, "trace" | "debug" | "info" | "warn" | "error" | "off")); - !bare_level_present - }; - let augmented = if needs_global_off { - format!("{env_value},off") - } else { - env_value - }; - EnvFilter::new(augmented) - } else { - let is_default = self.filters == default_filter(); - - // Always set a fallback directive using the configured log_level. - let builder = - EnvFilter::builder().with_default_directive(resolve_level(self.log_level()).into()); - - let filter_string = if is_default { - // Apply the configured log_level to each default module. - self.filters - .iter() - .map(|m| { - // In case a module accidentally already contains a level (e.g. "foo=debug"), - // keep only the part before '=' to enforce overriding with `log_level`. - let module = m.split('=').next().unwrap_or(m); - format!("{module}={}", self.log_level()) - }) - .collect::>() - .join(",") - } else { - // Custom filter provided: treat entries as authoritative. - // They may include levels (module=level) or just modules. - // For entries without explicit level, append the configured log_level. - self.filters - .iter() - .map(|d| { - if d.contains('=') { - d.clone() - } else { - format!("{d}={}", self.log_level()) - } - }) - .collect::>() - .join(",") - }; - - builder.parse_lossy(filter_string) - }; - - if self.opentelemetry.enabled { - // TODO(msardara): derive a tonic channel directly when opentelemetry-otlp - // upgrades to tonic version 0.13.0 - let endpoint = self.opentelemetry.grpc.endpoint.clone(); - - // resource - let resource = Resource::builder() - .with_attributes([ - KeyValue::new(SERVICE_NAME, self.opentelemetry.service_name.clone()), - KeyValue::new(SERVICE_VERSION, self.opentelemetry.service_version.clone()), - KeyValue::new( - DEPLOYMENT_ENVIRONMENT_NAME, - self.opentelemetry.environment.clone(), - ), - ]) - .build(); - - // init tracer provider - let exporter = opentelemetry_otlp::SpanExporter::builder() - .with_tonic() - .with_endpoint(&endpoint) - .build()?; - - let tracer_provider = SdkTracerProvider::builder() - // TODO(zkacsand): customize sampling strategy - .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased( - 1.0, - )))) - .with_id_generator(RandomIdGenerator::default()) - .with_resource(resource.clone()) - .with_batch_exporter(exporter) - .build(); - - let exporter = opentelemetry_otlp::MetricExporter::builder() - .with_tonic() - .with_endpoint(&endpoint) - .with_temporality(opentelemetry_sdk::metrics::Temporality::default()) - .build()?; - - let reader = PeriodicReader::builder(exporter) - .with_interval(std::time::Duration::from_secs( - self.opentelemetry.metrics_interval_secs, - )) - .build(); - - let stdout_reader = - PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build(); - - let meter_provider = MeterProviderBuilder::default() - .with_resource(resource.clone()) - .with_reader(reader) - .with_reader(stdout_reader) - .build(); - - // set global meter provider - global::set_meter_provider(meter_provider.clone()); - - // Sst up the trace context propagator - let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new(); - global::set_text_map_propagator(propagator); - - let tracer = tracer_provider.tracer("tracing-otel-subscriber"); - - // Construct the subscriber with OpenTelemetry - tracing_subscriber::registry() - .with(level_filter) - .with(fmt_layer) - .with(MetricsLayer::new(meter_provider.clone())) - .with(OpenTelemetryLayer::new(tracer)) - .try_init()?; - - Ok(OtelGuard { - tracer_provider: Some(tracer_provider), - meter_provider: Some(meter_provider), - }) - } else { - // Basic subscriber without OpenTelemetry - tracing_subscriber::registry() - .with(level_filter) - .with(fmt_layer) - .try_init()?; - - Ok(OtelGuard { - tracer_provider: None, - meter_provider: None, - }) - } - } -} - -// tests -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_tracing_configuration() { - let config = TracingConfiguration::default(); - assert_eq!(config.log_level, default_log_level()); - assert_eq!(config.display_thread_names, default_display_thread_names()); - assert_eq!(config.display_thread_ids, default_display_thread_ids()); - assert_eq!(config.filters, default_filter()); - } - - #[test] - fn test_resolve_level() { - assert_eq!(resolve_level("trace"), Level::TRACE); - assert_eq!(resolve_level("debug"), Level::DEBUG); - assert_eq!(resolve_level("info"), Level::INFO); - assert_eq!(resolve_level("warn"), Level::WARN); - assert_eq!(resolve_level("error"), Level::ERROR); - assert_eq!(resolve_level("invalid"), Level::INFO); - } - - #[test] - fn test_tracing_configuration_builder_methods() { - let config = TracingConfiguration::default() - .with_log_level("debug".to_string()) - .with_display_thread_names(false) - .with_display_thread_ids(true) - .with_filter(vec!["debug".to_string()]); - - assert_eq!(config.log_level(), "debug"); - assert!(!config.display_thread_names()); - assert!(config.display_thread_ids()); - assert_eq!(config.filter(), &vec!["debug".to_string()]); - } - - #[test] - fn test_opentelemetry_config_default() { - let config = OpenTelemetryConfig::default(); - assert!(!config.enabled()); - assert_eq!(config.service_name(), default_service_name()); - assert_eq!(config.grpc_config().endpoint, OTEL_EXPORTER_OTLP_ENDPOINT); - assert_eq!(config.service_version(), default_service_version()); - assert_eq!(config.environment(), default_environment()); - assert_eq!(config.metrics_interval_secs(), default_metrics_interval()); - } - - #[test] - fn test_tracing_configuration_with_opentelemetry() { - let otel_config = OpenTelemetryConfig::default() - .with_enabled(true) - .with_service_name("test-service".to_string()) - .with_service_version("1.0.0".to_string()); - - let config = TracingConfiguration::default().with_opentelemetry_config(otel_config); - - assert!(config.opentelemetry.enabled()); - assert_eq!(config.opentelemetry.service_name(), "test-service"); - assert_eq!(config.opentelemetry.service_version(), "1.0.0"); - } - - #[test] - fn test_enable_opentelemetry() { - let config = TracingConfiguration::default().enable_opentelemetry(); - assert!(config.opentelemetry.enabled()); - } - - #[test] - fn test_with_metrics_interval() { - let config = TracingConfiguration::default().with_metrics_interval(60); - assert_eq!(config.opentelemetry.metrics_interval_secs(), 60); - } - - #[test] - fn test_otel_guard_drop() { - // This test verifies that OtelGuard can be created and dropped without panicking - let config = TracingConfiguration::default(); - let guard = config.setup_tracing_subscriber().unwrap(); - drop(guard); // Should not panic - } -} +#[cfg(feature = "native")] +mod native; +#[cfg(feature = "native")] +pub use native::*; + +#[cfg(all(feature = "wasm", not(feature = "native")))] +mod wasm; +#[cfg(all(feature = "wasm", not(feature = "native")))] +pub use wasm::*; diff --git a/data-plane/core/tracing/src/native.rs b/data-plane/core/tracing/src/native.rs new file mode 100644 index 000000000..fcf90fc6a --- /dev/null +++ b/data-plane/core/tracing/src/native.rs @@ -0,0 +1,496 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +use opentelemetry::{KeyValue, global, trace::TracerProvider as _}; +use opentelemetry_otlp::{ExporterBuildError, WithExportConfig}; +use opentelemetry_sdk::{ + Resource, + metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider}, + trace::{RandomIdGenerator, Sampler, SdkTracerProvider}, +}; +use opentelemetry_semantic_conventions::attribute::{ + DEPLOYMENT_ENVIRONMENT_NAME, SERVICE_NAME, SERVICE_VERSION, +}; +use serde::Deserialize; +use thiserror::Error; +use tracing::Level; +use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer}; +use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; + +use slim_config::{ + grpc::{client::ClientConfig, errors::ConfigError as GrpcConfigError}, + tls::client::TlsClientConfig, +}; + +const OTEL_EXPORTER_OTLP_ENDPOINT: &str = "http://localhost:4317"; + +#[derive(Error, Debug)] +pub enum ConfigError { + #[error("error loading GRPC config")] + GRPCError(#[from] GrpcConfigError), + + #[error("error building exporter")] + OpenTelemetryInitError(#[from] ExporterBuildError), + + #[error("error parsing filter directives")] + FilterParseError(#[from] tracing_subscriber::filter::ParseError), + + #[error("error setting up tracing subscriber")] + TracingSetupError(#[from] tracing_subscriber::util::TryInitError), +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct TracingConfiguration { + #[serde(default = "default_log_level")] + log_level: String, + + #[serde(default = "default_display_thread_names")] + display_thread_names: bool, + + #[serde(default = "default_display_thread_ids")] + display_thread_ids: bool, + + #[serde(default = "default_filter")] + filters: Vec, + + #[serde(default)] + opentelemetry: OpenTelemetryConfig, +} + +impl Default for TracingConfiguration { + fn default() -> Self { + TracingConfiguration { + log_level: default_log_level(), + display_thread_names: default_display_thread_names(), + display_thread_ids: default_display_thread_ids(), + filters: default_filter(), + opentelemetry: OpenTelemetryConfig::default(), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct OpenTelemetryConfig { + #[serde(default)] + enabled: bool, + + #[serde(default)] + grpc: ClientConfig, + + #[serde(default = "default_service_name")] + service_name: String, + + #[serde(default = "default_service_version")] + service_version: String, + + #[serde(default = "default_environment")] + environment: String, + + #[serde(default = "default_metrics_interval")] + metrics_interval_secs: u64, +} + +impl OpenTelemetryConfig { + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + pub fn with_grpc_config(mut self, grpc_config: ClientConfig) -> Self { + self.grpc = grpc_config; + self + } + pub fn with_service_name(mut self, service_name: String) -> Self { + self.service_name = service_name; + self + } + pub fn with_service_version(mut self, service_version: String) -> Self { + self.service_version = service_version; + self + } + pub fn with_environment(mut self, environment: String) -> Self { + self.environment = environment; + self + } + pub fn with_metrics_interval_secs(mut self, metrics_interval_secs: u64) -> Self { + self.metrics_interval_secs = metrics_interval_secs; + self + } + pub fn enabled(&self) -> bool { + self.enabled + } + pub fn grpc_config(&self) -> &ClientConfig { + &self.grpc + } + pub fn service_name(&self) -> &str { + &self.service_name + } + pub fn service_version(&self) -> &str { + &self.service_version + } + pub fn environment(&self) -> &str { + &self.environment + } + pub fn metrics_interval_secs(&self) -> u64 { + self.metrics_interval_secs + } +} + +impl Default for OpenTelemetryConfig { + fn default() -> Self { + OpenTelemetryConfig { + enabled: false, + grpc: ClientConfig::with_endpoint(OTEL_EXPORTER_OTLP_ENDPOINT) + .with_tls_setting(TlsClientConfig::new().with_insecure(true)), + service_name: default_service_name(), + service_version: default_service_version(), + environment: default_environment(), + metrics_interval_secs: default_metrics_interval(), + } + } +} + +fn default_log_level() -> String { + "info".to_string() +} +fn default_display_thread_names() -> bool { + true +} +fn default_display_thread_ids() -> bool { + false +} +fn default_filter() -> Vec { + vec![ + "slim_datapath".to_string(), + "slim_service".to_string(), + "slim_controller".to_string(), + "slim_auth".to_string(), + "slim_config".to_string(), + "slim_mls".to_string(), + "slim_session".to_string(), + "slim_signal".to_string(), + "slim_tracing".to_string(), + "_slim_bindings".to_string(), + "slim_testing".to_string(), + "slim".to_string(), + "slim_examples".to_string(), + "sdk_mock".to_string(), + "client".to_string(), + ] +} +fn default_service_name() -> String { + "slim-data-plane".to_string() +} +fn default_service_version() -> String { + "v0.1.0".to_string() +} +fn default_environment() -> String { + "development".to_string() +} +fn default_metrics_interval() -> u64 { + 30 +} + +fn resolve_level(level: &str) -> tracing::Level { + match level.to_lowercase().as_str() { + "trace" => Level::TRACE, + "debug" => Level::DEBUG, + "info" => Level::INFO, + "warn" => Level::WARN, + "error" => Level::ERROR, + _ => Level::INFO, + } +} + +pub struct OtelGuard { + tracer_provider: Option, + meter_provider: Option, +} + +impl Drop for OtelGuard { + fn drop(&mut self) { + if let Some(tracer) = self.tracer_provider.take() + && let Err(err) = tracer.shutdown() + { + eprintln!("Error shutting down tracer provider: {err:?}"); + } + if let Some(meter) = self.meter_provider.take() + && let Err(err) = meter.shutdown() + { + eprintln!("Error shutting down meter provider: {err:?}"); + } + } +} + +impl TracingConfiguration { + pub fn with_log_level(self, log_level: String) -> Self { + TracingConfiguration { log_level, ..self } + } + + pub fn with_display_thread_names(self, display_thread_names: bool) -> Self { + TracingConfiguration { + display_thread_names, + ..self + } + } + + pub fn with_display_thread_ids(self, display_thread_ids: bool) -> Self { + TracingConfiguration { + display_thread_ids, + ..self + } + } + + pub fn with_filter(self, filter: Vec) -> Self { + TracingConfiguration { + filters: filter, + ..self + } + } + + pub fn with_opentelemetry_config(mut self, config: OpenTelemetryConfig) -> Self { + self.opentelemetry = config; + self + } + + pub fn enable_opentelemetry(mut self) -> Self { + self.opentelemetry.enabled = true; + self + } + + pub fn with_metrics_interval(mut self, interval_secs: u64) -> Self { + self.opentelemetry.metrics_interval_secs = interval_secs; + self + } + + pub fn log_level(&self) -> &str { + &self.log_level + } + pub fn display_thread_names(&self) -> bool { + self.display_thread_names + } + pub fn display_thread_ids(&self) -> bool { + self.display_thread_ids + } + pub fn filter(&self) -> &Vec { + &self.filters + } + + pub fn setup_tracing_subscriber(&self) -> Result { + let fmt_layer = fmt::layer() + .with_thread_ids(self.display_thread_ids) + .with_thread_names(self.display_thread_names) + .with_line_number(true) + .with_filter(tracing_subscriber::filter::filter_fn( + |metadata: &tracing::Metadata| { + !metadata + .fields() + .iter() + .any(|field| field.name() == "telemetry") + }, + )); + + let level_filter = if let Ok(env_value) = std::env::var("RUST_LOG") { + let needs_global_off = { + let tokens: Vec<&str> = env_value + .split(',') + .map(|t| t.trim()) + .filter(|t| !t.is_empty()) + .collect(); + !tokens + .iter() + .any(|t| matches!(*t, "trace" | "debug" | "info" | "warn" | "error" | "off")) + }; + let augmented = if needs_global_off { + format!("{env_value},off") + } else { + env_value + }; + EnvFilter::new(augmented) + } else { + let is_default = self.filters == default_filter(); + let builder = + EnvFilter::builder().with_default_directive(resolve_level(self.log_level()).into()); + let filter_string = if is_default { + self.filters + .iter() + .map(|m| { + let module = m.split('=').next().unwrap_or(m); + format!("{module}={}", self.log_level()) + }) + .collect::>() + .join(",") + } else { + self.filters + .iter() + .map(|d| { + if d.contains('=') { + d.clone() + } else { + format!("{d}={}", self.log_level()) + } + }) + .collect::>() + .join(",") + }; + builder.parse_lossy(filter_string) + }; + + if self.opentelemetry.enabled { + let endpoint = self.opentelemetry.grpc.endpoint.clone(); + let resource = Resource::builder() + .with_attributes([ + KeyValue::new(SERVICE_NAME, self.opentelemetry.service_name.clone()), + KeyValue::new(SERVICE_VERSION, self.opentelemetry.service_version.clone()), + KeyValue::new( + DEPLOYMENT_ENVIRONMENT_NAME, + self.opentelemetry.environment.clone(), + ), + ]) + .build(); + + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_tonic() + .with_endpoint(&endpoint) + .build()?; + + let tracer_provider = SdkTracerProvider::builder() + .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased( + 1.0, + )))) + .with_id_generator(RandomIdGenerator::default()) + .with_resource(resource.clone()) + .with_batch_exporter(exporter) + .build(); + + let exporter = opentelemetry_otlp::MetricExporter::builder() + .with_tonic() + .with_endpoint(&endpoint) + .with_temporality(opentelemetry_sdk::metrics::Temporality::default()) + .build()?; + + let reader = PeriodicReader::builder(exporter) + .with_interval(std::time::Duration::from_secs( + self.opentelemetry.metrics_interval_secs, + )) + .build(); + + let stdout_reader = + PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build(); + + let meter_provider = MeterProviderBuilder::default() + .with_resource(resource.clone()) + .with_reader(reader) + .with_reader(stdout_reader) + .build(); + + global::set_meter_provider(meter_provider.clone()); + let propagator = opentelemetry_sdk::propagation::TraceContextPropagator::new(); + global::set_text_map_propagator(propagator); + let tracer = tracer_provider.tracer("tracing-otel-subscriber"); + + tracing_subscriber::registry() + .with(level_filter) + .with(fmt_layer) + .with(MetricsLayer::new(meter_provider.clone())) + .with(OpenTelemetryLayer::new(tracer)) + .try_init()?; + + Ok(OtelGuard { + tracer_provider: Some(tracer_provider), + meter_provider: Some(meter_provider), + }) + } else { + tracing_subscriber::registry() + .with(level_filter) + .with(fmt_layer) + .try_init()?; + + Ok(OtelGuard { + tracer_provider: None, + meter_provider: None, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_tracing_configuration() { + let config = TracingConfiguration::default(); + assert_eq!(config.log_level, default_log_level()); + assert_eq!(config.display_thread_names, default_display_thread_names()); + assert_eq!(config.display_thread_ids, default_display_thread_ids()); + assert_eq!(config.filters, default_filter()); + } + + #[test] + fn test_resolve_level() { + assert_eq!(resolve_level("trace"), Level::TRACE); + assert_eq!(resolve_level("debug"), Level::DEBUG); + assert_eq!(resolve_level("info"), Level::INFO); + assert_eq!(resolve_level("warn"), Level::WARN); + assert_eq!(resolve_level("error"), Level::ERROR); + assert_eq!(resolve_level("invalid"), Level::INFO); + } + + #[test] + fn test_tracing_configuration_builder_methods() { + let config = TracingConfiguration::default() + .with_log_level("debug".to_string()) + .with_display_thread_names(false) + .with_display_thread_ids(true) + .with_filter(vec!["debug".to_string()]); + + assert_eq!(config.log_level(), "debug"); + assert!(!config.display_thread_names()); + assert!(config.display_thread_ids()); + assert_eq!(config.filter(), &vec!["debug".to_string()]); + } + + #[test] + fn test_opentelemetry_config_default() { + let config = OpenTelemetryConfig::default(); + assert!(!config.enabled()); + assert_eq!(config.service_name(), default_service_name()); + assert_eq!(config.grpc_config().endpoint, OTEL_EXPORTER_OTLP_ENDPOINT); + assert_eq!(config.service_version(), default_service_version()); + assert_eq!(config.environment(), default_environment()); + assert_eq!(config.metrics_interval_secs(), default_metrics_interval()); + } + + #[test] + fn test_tracing_configuration_with_opentelemetry() { + let otel_config = OpenTelemetryConfig::default() + .with_enabled(true) + .with_service_name("test-service".to_string()) + .with_service_version("1.0.0".to_string()); + let config = TracingConfiguration::default().with_opentelemetry_config(otel_config); + assert!(config.opentelemetry.enabled()); + assert_eq!(config.opentelemetry.service_name(), "test-service"); + assert_eq!(config.opentelemetry.service_version(), "1.0.0"); + } + + #[test] + fn test_enable_opentelemetry() { + let config = TracingConfiguration::default().enable_opentelemetry(); + assert!(config.opentelemetry.enabled()); + } + + #[test] + fn test_with_metrics_interval() { + let config = TracingConfiguration::default().with_metrics_interval(60); + assert_eq!(config.opentelemetry.metrics_interval_secs(), 60); + } + + #[test] + fn test_otel_guard_drop() { + let config = TracingConfiguration::default(); + let guard = config.setup_tracing_subscriber().unwrap(); + drop(guard); + } +} diff --git a/data-plane/core/tracing/src/utils.rs b/data-plane/core/tracing/src/utils.rs index 5a5bbdf4e..b1efe9fa0 100644 --- a/data-plane/core/tracing/src/utils.rs +++ b/data-plane/core/tracing/src/utils.rs @@ -4,5 +4,9 @@ use once_cell::sync::Lazy; use uuid::Uuid; +#[cfg(not(target_arch = "wasm32"))] pub static INSTANCE_ID: Lazy = Lazy::new(|| std::env::var("SLIM_INSTANCE_ID").unwrap_or_else(|_| Uuid::new_v4().to_string())); + +#[cfg(target_arch = "wasm32")] +pub static INSTANCE_ID: Lazy = Lazy::new(|| Uuid::new_v4().to_string()); diff --git a/data-plane/core/tracing/src/wasm.rs b/data-plane/core/tracing/src/wasm.rs new file mode 100644 index 000000000..ce33f5223 --- /dev/null +++ b/data-plane/core/tracing/src/wasm.rs @@ -0,0 +1,153 @@ +// Copyright AGNTCY Contributors (https://github.com/agntcy) +// SPDX-License-Identifier: Apache-2.0 + +//! WASM-compatible tracing configuration. +//! Routes log output to the browser's `console.log` via `web-sys`. + +use serde::Deserialize; +use std::io::{self, Write}; +use thiserror::Error; +use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; + +#[derive(Error, Debug)] +pub enum ConfigError { + #[error("error parsing filter directives")] + FilterParseError(#[from] tracing_subscriber::filter::ParseError), + + #[error("error setting up tracing subscriber")] + TracingSetupError(#[from] tracing_subscriber::util::TryInitError), +} + +/// A writer that buffers a single log line and flushes it to `console.log`. +struct ConsoleWriter(Vec); + +impl Write for ConsoleWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + let msg = String::from_utf8_lossy(&self.0); + let msg = msg.trim_end(); // remove trailing newline + if !msg.is_empty() { + web_sys::console::log_1(&msg.into()); + } + self.0.clear(); + Ok(()) + } +} + +impl Drop for ConsoleWriter { + fn drop(&mut self) { + let _ = self.flush(); + } +} + +/// MakeWriter that produces ConsoleWriter instances. +struct ConsoleMakeWriter; + +impl<'a> fmt::MakeWriter<'a> for ConsoleMakeWriter { + type Writer = ConsoleWriter; + + fn make_writer(&'a self) -> Self::Writer { + ConsoleWriter(Vec::with_capacity(256)) + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct TracingConfiguration { + #[serde(default = "default_log_level")] + log_level: String, + + #[serde(default = "default_filter")] + filters: Vec, +} + +impl Default for TracingConfiguration { + fn default() -> Self { + TracingConfiguration { + log_level: default_log_level(), + filters: default_filter(), + } + } +} + +fn default_log_level() -> String { + "info".to_string() +} + +fn default_filter() -> Vec { + vec![ + "slim_datapath".to_string(), + "slim_service".to_string(), + "slim_controller".to_string(), + "slim_auth".to_string(), + "slim_config".to_string(), + "slim_mls".to_string(), + "slim_session".to_string(), + "slim_signal".to_string(), + "slim_tracing".to_string(), + "_slim_bindings".to_string(), + "slim".to_string(), + ] +} + +/// Guard type for compatibility with the native API. +/// In WASM there is nothing to shut down. +pub struct OtelGuard; + +impl TracingConfiguration { + pub fn with_log_level(self, log_level: String) -> Self { + TracingConfiguration { log_level, ..self } + } + + pub fn with_filter(self, filter: Vec) -> Self { + TracingConfiguration { + filters: filter, + ..self + } + } + + pub fn log_level(&self) -> &str { + &self.log_level + } + pub fn filter(&self) -> &Vec { + &self.filters + } + + pub fn setup_tracing_subscriber(&self) -> Result { + let fmt_layer = fmt::layer() + .with_writer(ConsoleMakeWriter) + .with_ansi(false) + .without_time() + .with_line_number(true) + .with_filter(tracing_subscriber::filter::filter_fn( + |metadata: &tracing::Metadata| { + !metadata + .fields() + .iter() + .any(|field| field.name() == "telemetry") + }, + )); + + let mut env_filter = EnvFilter::try_new(self.log_level.trim())?; + for f in &self.filters { + let directive_string = if f.contains('=') { + f.clone() + } else { + // Bare module names match native: treat as `module=`. + format!("{f}={}", self.log_level) + }; + let directive = directive_string.parse()?; + env_filter = env_filter.add_directive(directive); + } + + tracing_subscriber::registry() + .with(env_filter) + .with(fmt_layer) + .try_init()?; + + Ok(OtelGuard) + } +} diff --git a/data-plane/examples/Cargo.toml b/data-plane/examples/Cargo.toml index 1bdf4c06b..d6d750432 100644 --- a/data-plane/examples/Cargo.toml +++ b/data-plane/examples/Cargo.toml @@ -14,14 +14,14 @@ path = "src/client/main.rs" [dependencies] agntcy-slim = { workspace = true } -agntcy-slim-auth = { workspace = true } -agntcy-slim-config = { workspace = true } -agntcy-slim-controller = { workspace = true } -agntcy-slim-datapath = { workspace = true } -agntcy-slim-mls = { workspace = true } -agntcy-slim-service = { workspace = true, features = ["session"] } -agntcy-slim-session = { workspace = true } -agntcy-slim-signal = { workspace = true } +agntcy-slim-auth = { workspace = true, features = ["native"] } +agntcy-slim-config = { workspace = true, features = ["native"] } +agntcy-slim-controller = { workspace = true, features = ["native"] } +agntcy-slim-datapath = { workspace = true, features = ["native"] } +agntcy-slim-mls = { workspace = true, features = ["native"] } +agntcy-slim-service = { workspace = true, features = ["native", "session"] } +agntcy-slim-session = { workspace = true, features = ["native"] } +agntcy-slim-signal = { workspace = true, features = ["native"] } agntcy-slim-testing = { workspace = true } anyhow = { workspace = true } clap = { workspace = true } diff --git a/data-plane/examples/README.md b/data-plane/examples/README.md index 0bac701d7..b75cc6ceb 100644 --- a/data-plane/examples/README.md +++ b/data-plane/examples/README.md @@ -1,6 +1,6 @@ -# Run this example +# Run Examples -# Docker Compose Quick Start +## Docker Compose Quick Start To run all services (slim server, mock app server, and mock app client) using Docker Compose: @@ -22,13 +22,15 @@ To stop and remove all containers: docker compose down ``` -# Manual (Local) Run +## Manual (Local) Run + +### gRPC transport You can also run each service locally using Taskfile: 1. In the first terminal, run the slim service: ```sh - task run:server + task run:slim ``` 2. In the second terminal, run the mock-app server: ```sh @@ -38,3 +40,33 @@ You can also run each service locally using Taskfile: ```sh task run:mock-app:client ``` + +### WebSocket transport (`ws://`) + +1. In the first terminal, run the slim service: + ```sh + task run:slim:websocket + ``` +2. In the second terminal, run the mock-app server: + ```sh + task run:mock-app:server-websocket + ``` +3. In the third terminal, run the mock-app client: + ```sh + task run:mock-app:client-websocket + ``` + +### Secure WebSocket transport (`wss://`) + +1. In the first terminal, run the slim service: + ```sh + task run:slim:websocket:wss + ``` +2. In the second terminal, run the mock-app server: + ```sh + task run:mock-app:server-websocket-wss + ``` +3. In the third terminal, run the mock-app client: + ```sh + task run:mock-app:client-websocket-wss + ``` diff --git a/data-plane/examples/Taskfile.yaml b/data-plane/examples/Taskfile.yaml index 216f87847..e789b4075 100644 --- a/data-plane/examples/Taskfile.yaml +++ b/data-plane/examples/Taskfile.yaml @@ -21,6 +21,18 @@ tasks: cmds: - cargo run --bin slim -- --config ./config/base/server-config.yaml + run:slim:websocket: + desc: "Run SLIM gateway with WebSocket transport (ws://)" + dir: ".." + cmds: + - cargo run --bin slim -- --config ./config/websocket/server-config.yaml + + run:slim:websocket:wss: + desc: "Run SLIM gateway with secure WebSocket transport (wss://)" + dir: ".." + cmds: + - cargo run --bin slim -- --config ./config/websocket/server-config-wss.yaml + run:slim:unix: desc: "Run SLIM gateway with Unix domain socket gRPC" dir: ".." @@ -33,12 +45,24 @@ tasks: cmds: - cargo run --bin sdk-mock -- --config ../../../config/base/client-config.yaml --local-name server --remote-name client + run:mock-app:server-websocket: + desc: "Run the mock app server over WebSocket transport (ws://)" + dir: src/sdk-mock + cmds: + - cargo run --bin sdk-mock -- --config ../../../config/websocket/client-config.yaml --local-name server --remote-name client + run:mock-app:server-unix: desc: "Run the mock app server" dir: src/sdk-mock cmds: - cargo run --bin sdk-mock -- --config ../../../config/unix/client-config.yaml --local-name server --remote-name client + run:mock-app:server-websocket-wss: + desc: "Run the mock app server over secure WebSocket transport (wss://)" + dir: ".." + cmds: + - cargo run --manifest-path ./examples/Cargo.toml --bin sdk-mock -- --config ./config/websocket/client-config-wss.yaml --local-name server --remote-name client + run:mock-app:server-mls: desc: "Run the mock app server" dir: src/sdk-mock @@ -51,12 +75,24 @@ tasks: cmds: - cargo run --bin sdk-mock -- --config ../../../config/base/client-config.yaml --local-name client --remote-name server --message "hey there!" + run:mock-app:client-websocket: + desc: "Run the mock app client over WebSocket transport (ws://)" + dir: src/sdk-mock + cmds: + - cargo run --bin sdk-mock -- --config ../../../config/websocket/client-config.yaml --local-name client --remote-name server --message "hey there!" + run:mock-app:client-unix: desc: "Run the mock app client using Unix domain sockets" dir: src/sdk-mock cmds: - cargo run --bin sdk-mock -- --config ../../../config/unix/client-config.yaml --local-name client --remote-name server --message "hey there!" + run:mock-app:client-websocket-wss: + desc: "Run the mock app client over secure WebSocket transport (wss://)" + dir: ".." + cmds: + - cargo run --manifest-path ./examples/Cargo.toml --bin sdk-mock -- --config ./config/websocket/client-config-wss.yaml --local-name client --remote-name server --message "hey there!" + run:mock-app:client-mls: desc: "Run the mock app client" dir: src/sdk-mock @@ -86,3 +122,32 @@ tasks: dir: "." cmds: - docker compose down + + run:browser:serve: + desc: "Serve the browser demo on http://localhost:8080" + dir: ".." + cmds: + - echo "Serving browser demo at http://localhost:8080/examples/browser/" + - python3 -m http.server 8080 + + run:browser: + desc: "Run browser demo (builds WASM, starts gateway, opens browser)" + dir: ".." + cmds: + - echo "Building WASM package..." + - wasm-pack build core/slim-wasm --target web --out-dir ../../pkg + - echo "" + - echo "═══════════════════════════════════════════════════════" + - echo " SLIM Browser Demo — Setup Instructions" + - echo "═══════════════════════════════════════════════════════" + - echo "" + - echo "1. Start the SLIM gateway (in a separate terminal):" + - echo " cd data-plane && cargo run --bin slim -- --config ./config/websocket/server-config.yaml" + - echo "" + - echo "2. (Optional) Start a native peer to chat with (in another terminal):" + - echo " cd data-plane/examples && task run:mock-app:server-websocket" + - echo "" + - echo "3. Open http://localhost:8080/examples/browser/ in your browser" + - echo "" + - echo "Starting HTTP server..." + - python3 -m http.server 8080 diff --git a/data-plane/examples/browser/index.html b/data-plane/examples/browser/index.html new file mode 100644 index 000000000..59accd765 --- /dev/null +++ b/data-plane/examples/browser/index.html @@ -0,0 +1,454 @@ + + + + + + + + SLIM Data-Plane — Browser Demo + + + +

SLIM Data-Plane

+

WebAssembly Browser Demo — Full Features

+ + +
+

Connection

+
+ + Disconnected +
+ + + + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+ + + + + + + + +
+

Event Log

+
+ +
+
+
+ + + + diff --git a/data-plane/examples/src/sdk-mock/args.rs b/data-plane/examples/src/sdk-mock/args.rs index 890563014..06a3832b6 100644 --- a/data-plane/examples/src/sdk-mock/args.rs +++ b/data-plane/examples/src/sdk-mock/args.rs @@ -16,10 +16,10 @@ pub struct Args { #[arg(short, long, value_name = "LOCAL_NAME")] local_name: String, - /// Set the topic to subscribe to. - #[clap(long, env, required = true)] + /// Set the remote app name. If omitted, sessions are accepted from any peer. + #[clap(long, env)] #[arg(short, long, value_name = "REMOTE_NAME")] - remote_name: String, + remote_name: Option, /// Set the message to publish. If not set, the program will subscribe to the topic. #[clap(long, env)] @@ -41,8 +41,8 @@ impl Args { &self.local_name } - pub fn remote_name(&self) -> &str { - &self.remote_name + pub fn remote_name(&self) -> Option<&str> { + self.remote_name.as_deref() } pub fn message(&self) -> Option<&str> { diff --git a/data-plane/examples/src/sdk-mock/main.rs b/data-plane/examples/src/sdk-mock/main.rs index edfc78645..902a1b2e7 100644 --- a/data-plane/examples/src/sdk-mock/main.rs +++ b/data-plane/examples/src/sdk-mock/main.rs @@ -121,7 +121,7 @@ async fn main() { .expect("invalid tracing configuration") .setup_tracing_subscriber(); - info!(%config_file, %local_name, %remote_name, "starting client"); + info!(%config_file, %local_name, remote_name = ?remote_name, "starting client"); // get service let id = slim_config::component::id::ID::new_with_str("slim/0").unwrap(); @@ -150,10 +150,15 @@ async fn main() { app.subscribe(app.app_name(), Some(conn_id)).await.unwrap(); - // Set a route for the remote app - let remote_app_name = Name::from_strings(["org", "default", remote_name]); - info!(remote_app = %remote_app_name, "allowing messages to remote app"); - app.set_route(&remote_app_name, conn_id).await.unwrap(); + // Set a route for the remote app (only if --remote-name was provided) + let remote_app_name = if let Some(rn) = remote_name { + let name = Name::from_strings(["org", "default", rn]); + info!(remote_app = %name, "allowing messages to remote app"); + app.set_route(&name, conn_id).await.unwrap(); + Some(name) + } else { + None + }; // wait for the connection to be established tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -171,8 +176,12 @@ async fn main() { initiator: true, metadata: HashMap::new(), }; + let dest = remote_app_name + .clone() + .expect("--remote-name is required when sending a message"); + let session_ctx = app - .create_session(config, remote_app_name.clone(), None) + .create_session(config, dest.clone(), None) .await .expect("error creating p2p session"); @@ -183,16 +192,16 @@ async fn main() { let session = session_ctx.session_arc().unwrap(); // Spawn receiver to handle incoming messages - spawn_session_receiver(session_ctx, local_name.to_string(), remote_app_name.clone()); + spawn_session_receiver(session_ctx, local_name.to_string(), dest.clone()); // Await session initialization init_ack.await.expect("error initializing p2p session"); - info!(destination = %remote_app_name, "Sending message"); + info!(destination = %dest, "Sending message"); // publish message using session context session - .publish(&remote_app_name, msg.into(), None, None) + .publish(&dest, msg.into(), None, None) .await .unwrap(); @@ -219,11 +228,15 @@ async fn main() { // Get session before spawning receiver let session_arc = session.session_arc().unwrap(); + // Derive the peer name from the session destination (the remote peer) + let peer = session_arc.dst().clone(); + info!(peer = %peer, "accepted incoming session"); + // Use the extracted spawn_session_receiver function spawn_session_receiver( session, local_name.to_string(), - remote_app_name.clone(), + peer, ); // Save the session diff --git a/data-plane/slimctl/Cargo.toml b/data-plane/slimctl/Cargo.toml index b1b9c43b2..2b62fc438 100644 --- a/data-plane/slimctl/Cargo.toml +++ b/data-plane/slimctl/Cargo.toml @@ -23,11 +23,11 @@ path = "src/main.rs" [dependencies] # Local workspace agntcy-slim = { workspace = true } -agntcy-slim-config = { workspace = true } -agntcy-slim-datapath = { workspace = true } -agntcy-slim-service = { workspace = true } -agntcy-slim-signal = { workspace = true } -agntcy-slim-tracing = { workspace = true } +agntcy-slim-config = { workspace = true, features = ["native"] } +agntcy-slim-datapath = { workspace = true, features = ["native"] } +agntcy-slim-service = { workspace = true, features = ["native"] } +agntcy-slim-signal = { workspace = true, features = ["native"] } +agntcy-slim-tracing = { workspace = true, features = ["native"] } agntcy-slim-version = { workspace = true } # Error handling diff --git a/data-plane/slimctl/src/client.rs b/data-plane/slimctl/src/client.rs index 55defc7cb..876564d09 100644 --- a/data-plane/slimctl/src/client.rs +++ b/data-plane/slimctl/src/client.rs @@ -5,6 +5,7 @@ use anyhow::{Context, Result, bail}; use tonic::codegen::{Body, Bytes, StdError}; use slim_config::auth::basic::Config as BasicAuthConfig; +use slim_config::client::TransportChannel; use slim_config::grpc::client::{AuthenticationConfig, BackoffConfig, ClientConfig}; use slim_config::tls::client::TlsClientConfig; @@ -92,6 +93,12 @@ pub async fn get_control_plane_client( .to_channel() .await .context("failed to connect to server")?; + let channel = match channel { + TransportChannel::Grpc(channel) => channel, + TransportChannel::Websocket(_) => { + bail!("slimctl only supports gRPC control-plane transport") + } + }; Ok(ControlPlaneServiceClient::new(channel)) } @@ -114,6 +121,12 @@ pub async fn get_controller_client( .to_channel() .await .context("failed to connect to server")?; + let channel = match channel { + TransportChannel::Grpc(channel) => channel, + TransportChannel::Websocket(_) => { + bail!("slimctl only supports gRPC control-plane transport") + } + }; Ok(ControllerServiceClient::new(channel)) } diff --git a/data-plane/testing/Cargo.toml b/data-plane/testing/Cargo.toml index 20e2f0ef0..a2ff48593 100644 --- a/data-plane/testing/Cargo.toml +++ b/data-plane/testing/Cargo.toml @@ -26,12 +26,12 @@ path = "src/bin/stress_publish.rs" [dependencies] agntcy-slim = { workspace = true } -agntcy-slim-auth = { workspace = true } -agntcy-slim-config = { workspace = true } -agntcy-slim-datapath = { workspace = true } -agntcy-slim-service = { workspace = true, features = ["session"] } -agntcy-slim-session = { workspace = true } -agntcy-slim-tracing = { workspace = true } +agntcy-slim-auth = { workspace = true, features = ["native"] } +agntcy-slim-config = { workspace = true, features = ["native"] } +agntcy-slim-datapath = { workspace = true, features = ["native"] } +agntcy-slim-service = { workspace = true, features = ["native", "session"] } +agntcy-slim-session = { workspace = true, features = ["native"] } +agntcy-slim-tracing = { workspace = true, features = ["native"] } aws-lc-rs = { workspace = true } base64 = { workspace = true } bollard = { version = "0.17" }