Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ followed by a single string error description.
Commands are:

- `echo <text>`: integration replies with `ok <text>`
- `auth <email>`: integration replies with `ok <auth URL>`
- `auth <email> [<state>]`: integration replies with `ok <auth URL>`
- `verify <id_token>`: integration replies with `ok <email>`
- `clear-cache`: integration replies with `ok`

Here's an example flow, where we illustrate tabs with `||`, commands from the
test suite with `>>`, and responses from the integration with `<<`:
Expand Down
100 changes: 92 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"flag"
"fmt"
Expand Down Expand Up @@ -357,9 +358,9 @@ func main() {
email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
now := time.Now().Unix()
proc.writeLine("verify", sgn.sign(sgn.priv, &header{
proc.writeLine("verify", sgn.sign(&header{
KID: "bad key",
Alg: alg,
Alg: "RS256",
}, &payload{
Iss: srv.origin,
Aud: clientID,
Expand All @@ -373,12 +374,20 @@ func main() {
})

test("bad signature", func() {
fakeKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatal("rsa.GenerateKey error:", err)
}

oldKey := sgn.rsaKey
sgn.rsaKey = fakeKey

email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
now := time.Now().Unix()
proc.writeLine("verify", sgn.sign(sgn.fake, &header{
KID: kid,
Alg: alg,
proc.writeLine("verify", sgn.sign(&header{
KID: rsaKID,
Alg: "RS256",
}, &payload{
Iss: srv.origin,
Aud: clientID,
Expand All @@ -389,15 +398,17 @@ func main() {
}))
proc.expect("err", "rejects token")
}

sgn.rsaKey = oldKey
})

test("token cannot change alg from jwk", func() {
email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
now := time.Now().Unix()
proc.writeLine("verify", sgn.sign(sgn.priv, &header{
KID: kid,
Alg: "RS384",
proc.writeLine("verify", sgn.sign(&header{
KID: rsaKID,
Alg: "RS512",
}, &payload{
Iss: srv.origin,
Aud: clientID,
Expand All @@ -410,11 +421,84 @@ func main() {
}
})

test("token cannot change alg from request start", func() {
email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
oldAlg := sgn.alg
sgn.alg = "EdDSA"

now := time.Now().Unix()
proc.writeLine("verify", sgn.simple(&payload{
Iss: srv.origin,
Aud: clientID,
Exp: now + 5,
Iat: now,
Email: email,
Nonce: nonce,
}))
proc.expect("err", "rejects token")

sgn.alg = oldAlg
}
})

test("caching", func() {
assertEq(srv.numConfigRequests, 1, "discovery requested just once")
assertEq(srv.numKeysRequests, 1, "keys requested just once")
})

test("ed25519", func() {
oldAlg := sgn.alg
sgn.alg = "EdDSA"
proc.writeLine("clear-cache")
proc.expect("ok", "clears caches")

email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
now := time.Now().Unix()
proc.writeLine("verify", sgn.simple(&payload{
Iss: srv.origin,
Aud: clientID + "?id_token_signed_response_alg=EdDSA",
Exp: now + 5,
Iat: now,
Email: email,
Nonce: nonce,
}))
proc.expect("ok", "accepts token")
}

sgn.alg = oldAlg
proc.writeLine("clear-cache")
proc.expect("ok", "clears caches")
})

test("handles changing provider config", func() {
email := "john@example.com"
if nonce := quickStart(email); nonce != "" {
now := time.Now().Unix()
token := sgn.simple(&payload{
Iss: srv.origin,
Aud: clientID,
Exp: now + 5,
Iat: now,
Email: email,
Nonce: nonce,
})

oldAlg := sgn.alg
sgn.alg = "EdDSA"
proc.writeLine("clear-cache")
proc.expect("ok", "clears caches")

proc.writeLine("verify", token)
proc.expect("ok", "accepts token")

sgn.alg = oldAlg
proc.writeLine("clear-cache")
proc.expect("ok", "clears caches")
}
})

proc.stop()
if !allOk {
os.Exit(1)
Expand Down
10 changes: 6 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ const authEndpoint = "http://imaginary-server.test/fake-auth-route"
var srv *server

type discoveryDoc struct {
JWKsURI string `json:"jwks_uri"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
JWKsURI string `json:"jwks_uri"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
}

type server struct {
Expand All @@ -31,8 +32,9 @@ func initServer(keys jwk.Set) {
http.HandleFunc("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
srv.numConfigRequests++
body, err := json.Marshal(&discoveryDoc{
JWKsURI: fmt.Sprintf("%s/test-keys", origin),
AuthorizationEndpoint: authEndpoint,
JWKsURI: fmt.Sprintf("%s/test-keys", origin),
AuthorizationEndpoint: authEndpoint,
IDTokenSigningAlgValuesSupported: []string{sgn.alg},
})
if err != nil {
log.Fatal("json.Marshal error:", err)
Expand Down
82 changes: 53 additions & 29 deletions signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
Expand All @@ -13,16 +14,18 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
)

const kid = "test key"
const alg = "RS256"
const rsaKID = "test key RSA"
const ed25519KID = "test key Ed25519"

var sgn *signer

type signer struct {
priv *rsa.PrivateKey
fake *rsa.PrivateKey
key jwk.Key
keySet jwk.Set
alg string
rsaKey *rsa.PrivateKey
ed25519Key ed25519.PrivateKey
rsaJwk jwk.Key
ed25519Jwk jwk.Key
keySet jwk.Set
}

type header struct {
Expand All @@ -31,42 +34,61 @@ type header struct {
}

func initSigner() {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatal("rsa.GenerateKey error:", err)
}

fake, err := rsa.GenerateKey(rand.Reader, 2048)
ed25519Pub, ed25519Priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
log.Fatal("rsa.GenerateKey error:", err)
log.Fatal("ed25519.GenerateKey error:", err)
}

key, err := jwk.FromRaw(priv.PublicKey)
rsaJwk, err := jwk.FromRaw(rsaKey.PublicKey)
if err != nil {
log.Fatal("jwk.New error:", err)
}
if err := rsaJwk.Set(jwk.KeyIDKey, rsaKID); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}
if err := rsaJwk.Set(jwk.AlgorithmKey, "RS256"); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}
if err := rsaJwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}

if err := key.Set(jwk.KeyIDKey, kid); err != nil {
ed25519Jwk, err := jwk.FromRaw(ed25519Pub)
if err != nil {
log.Fatal("jwk.New error:", err)
}
if err := ed25519Jwk.Set(jwk.KeyIDKey, ed25519KID); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}
if err := ed25519Jwk.Set(jwk.AlgorithmKey, "EdDSA"); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}
if err := key.Set(jwk.AlgorithmKey, alg); err != nil {
if err := ed25519Jwk.Set(jwk.KeyUsageKey, "sig"); err != nil {
log.Fatal("jwk.Key.Set error:", err)
}

keySet := jwk.NewSet()
keySet.AddKey(key)
keySet.AddKey(rsaJwk)
keySet.AddKey(ed25519Jwk)

log.Print("generated server RSA key")

sgn = &signer{
priv: priv,
fake: fake,
key: key,
keySet: keySet,
alg: "RS256",
rsaKey: rsaKey,
ed25519Key: ed25519Priv,
rsaJwk: rsaJwk,
ed25519Jwk: ed25519Jwk,
keySet: keySet,
}
}

func (sgn *signer) sign(key *rsa.PrivateKey, hdr *header, pl interface{}) string {
func (sgn *signer) sign(hdr *header, pl interface{}) string {
hdrJSON, err := json.Marshal(hdr)
if err != nil {
log.Fatal("json.Marshal error:", err)
Expand All @@ -83,17 +105,14 @@ func (sgn *signer) sign(key *rsa.PrivateKey, hdr *header, pl interface{}) string

var sign []byte
switch hdr.Alg {
case "none":
// leave sign and err set to nil
case "RS256":
hash := sha256.Sum256([]byte(signed))
sign, err = rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hash[:])
case "RS384":
hash := sha512.Sum384([]byte(signed))
sign, err = rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA384, hash[:])
sign, err = rsa.SignPKCS1v15(rand.Reader, sgn.rsaKey, crypto.SHA256, hash[:])
case "RS512":
hash := sha512.Sum512([]byte(signed))
sign, err = rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA512, hash[:])
sign, err = rsa.SignPKCS1v15(rand.Reader, sgn.rsaKey, crypto.SHA512, hash[:])
case "EdDSA":
sign = ed25519.Sign(sgn.ed25519Key, []byte(signed))
default:
log.Fatalf("alg '%s' not supported by signer", hdr.Alg)
}
Expand All @@ -106,9 +125,14 @@ func (sgn *signer) sign(key *rsa.PrivateKey, hdr *header, pl interface{}) string
}

func (sgn *signer) simple(pl interface{}) string {
hdr := &header{
KID: kid,
Alg: alg,
hdr := &header{Alg: sgn.alg}
switch hdr.Alg {
case "RS256":
hdr.KID = rsaKID
case "EdDSA":
hdr.KID = ed25519KID
default:
log.Fatalf("alg '%s' not supported by signer", hdr.Alg)
}
return sgn.sign(sgn.priv, hdr, pl)
return sgn.sign(hdr, pl)
}
5 changes: 4 additions & 1 deletion subprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ func (proc *subprocess) expect(res, descr string) string {
log.Print(cmd)
return ""
}
return cmd[1]
if len(cmd) > 1 {
return cmd[1]
}
return ""
}

func (proc *subprocess) stop() {
Expand Down