diff --git a/README.rst b/README.rst index 505ff09197..1871f23c36 100644 --- a/README.rst +++ b/README.rst @@ -558,6 +558,17 @@ To easily deploy Vault locally: (DO NOT DO THIS FOR PRODUCTION!!!) $ sops encrypt --verbose prod/raw.yaml > prod/encrypted.yaml +Restricting HC Vault servers that SOPS can talk to +************************************************** + +If you want to restrict which HC Vault servers SOPS is allowed to talk to, you can set the ``SOPS_HC_VAULT_ALLOWLIST`` environment variable. +When set to ``all`` (the default value), there is no restriction. +When set to ``none``, SOPS will not allow any access to HC Vault servers for decryption or encryption. + +When set to any other value, this value will be interpreted as a comma-separated list of strings. +If SOPS attempts to contact a vault URL that starts with one of these strings, SOPS will attempt to contact that URL. +If there is no matching prefix in ``SOPS_HC_VAULT_ALLOWLIST``, SOPS will not contact that URL. + Encrypting using HuaweiCloud KMS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/hcvault/keysource.go b/hcvault/keysource.go index 67706e71e3..c60e11cfcb 100644 --- a/hcvault/keysource.go +++ b/hcvault/keysource.go @@ -26,8 +26,75 @@ import ( const ( // KeyTypeIdentifier is the string used to identify a Vault MasterKey. KeyTypeIdentifier = "hc_vault" + // SopsHCVaultAllowlist can be set as an environment variable with a string list + // of age keys as value. + SopsHCVaultAllowlist = "SOPS_HC_VAULT_ALLOWLIST" + // Special value for allowlist that allows all hosts + AllowlistAllHosts = "all" + // Special value for allowlist that allows no hosts + AllowlistNoHosts = "none" + // Default value of allowlist. Should eventually be changed to "none". + AllowlistDefault = AllowlistAllHosts ) +type allowList struct { + All bool + URIs []string +} + +func (al *allowList) Allows(address string) bool { + if al.All { + return true + } + if !strings.HasSuffix(address, "/") { + address = address + "/" + } + for _, uri := range al.URIs { + if strings.HasPrefix(address, uri) { + return true + } + } + return false +} + +func parseAllowlistString(allowlistStr string) (allowList, error) { + switch allowlistStr { + case AllowlistAllHosts: + return allowList{ + All: true, + URIs: nil, + }, nil + case AllowlistNoHosts: + return allowList{ + All: false, + URIs: nil, + }, nil + } + uris := strings.Split(allowlistStr, ",") + for idx, uri := range uris { + uri = strings.Trim(uri, " ") + if uri == "" { + return allowList{}, fmt.Errorf("%s's entry %d is empty", SopsHCVaultAllowlist, idx+1) + } + if !strings.HasSuffix(uri, "/") { + uri = uri + "/" + } + uris[idx] = uri + } + return allowList{ + All: false, + URIs: uris, + }, nil +} + +func getAllowlist() (allowList, error) { + var allowlistStr = AllowlistDefault + if allowlist, ok := os.LookupEnv(SopsHCVaultAllowlist); ok && len(allowlist) > 0 { + allowlistStr = allowlist + } + return parseAllowlistString(allowlistStr) +} + func init() { log = logging.NewLogger("VAULT_TRANSIT") } @@ -331,6 +398,14 @@ func vaultClient(address, token string, hc *http.Client) (*api.Client, error) { cfg := api.DefaultConfig() cfg.Address = address + allowlist, err := getAllowlist() + if err != nil { + return nil, err + } + if !allowlist.Allows(address) { + return nil, fmt.Errorf("Allowlist does not allow %s", address) + } + if hc != nil { cfg.HttpClient = hc } diff --git a/hcvault/keysource_test.go b/hcvault/keysource_test.go index 02d5a13e2e..8059ed9ae8 100644 --- a/hcvault/keysource_test.go +++ b/hcvault/keysource_test.go @@ -27,6 +27,8 @@ var ( // testVaultAddress is the HTTP/S address of the Vault server, it is set // by TestMain after booting it. testVaultAddress string + // Whether to skip all Docker-based tests. + testSkipDocker = false ) // TestMain initializes a Vault server using Docker, writes the HTTP address to @@ -34,6 +36,11 @@ var ( // Vault Transit on the testEnginePath. It then runs all the tests, which can // make use of the various `test*` variables. func TestMain(m *testing.M) { + if testSkipDocker { + os.Exit(m.Run()) + return + } + // Uses a sensible default on Windows (TCP/HTTP) and Linux/MacOS (socket) pool, err := dockertest.NewPool("") if err != nil { @@ -179,6 +186,10 @@ func TestNewMasterKeyFromURI(t *testing.T) { } func TestMasterKey_Encrypt(t *testing.T) { + if testSkipDocker { + return + } + key := NewMasterKey(testVaultAddress, testEnginePath, "encrypt") (Token(testVaultToken)).ApplyToMasterKey(key) assert.NoError(t, createVaultKey(key)) @@ -207,6 +218,10 @@ func TestMasterKey_Encrypt(t *testing.T) { } func TestMasterKey_EncryptIfNeeded(t *testing.T) { + if testSkipDocker { + return + } + key := NewMasterKey(testVaultAddress, testEnginePath, "encrypt-if-needed") (Token(testVaultToken)).ApplyToMasterKey(key) assert.NoError(t, createVaultKey(key)) @@ -226,6 +241,10 @@ func TestMasterKey_EncryptedDataKey(t *testing.T) { } func TestMasterKey_Decrypt(t *testing.T) { + if testSkipDocker { + return + } + key := NewMasterKey(testVaultAddress, testEnginePath, "decrypt") (Token(testVaultToken)).ApplyToMasterKey(key) assert.NoError(t, createVaultKey(key)) @@ -254,6 +273,10 @@ func TestMasterKey_Decrypt(t *testing.T) { } func TestMasterKey_EncryptDecrypt_RoundTrip(t *testing.T) { + if testSkipDocker { + return + } + token := Token(testVaultToken) encryptKey := NewMasterKey(testVaultAddress, testEnginePath, "roundtrip") @@ -519,3 +542,144 @@ func createVaultKey(key *MasterKey) error { _, err = client.Logical().Read(p) return err } + +func TestAllowlistParse(t *testing.T) { + t.Run("success", func(t *testing.T) { + al, err := parseAllowlistString("all") + assert.NoError(t, err) + assert.Equal(t, allowList{ + All: true, + URIs: nil, + }, al) + + al, err = parseAllowlistString("none") + assert.NoError(t, err) + assert.Equal(t, allowList{ + All: false, + URIs: nil, + }, al) + + al, err = parseAllowlistString("non") + assert.NoError(t, err) + assert.Equal(t, allowList{ + All: false, + URIs: []string{ + "non/", + }, + }, al) + + al, err = parseAllowlistString("foo,bar/,baz") + assert.NoError(t, err) + assert.Equal(t, allowList{ + All: false, + URIs: []string{ + "foo/", + "bar/", + "baz/", + }, + }, al) + + al, err = parseAllowlistString(" foo/ , bar, baz ") + assert.NoError(t, err) + assert.Equal(t, allowList{ + All: false, + URIs: []string{ + "foo/", + "bar/", + "baz/", + }, + }, al) + }) + + t.Run("error", func(t *testing.T) { + al, err := parseAllowlistString("") + assert.Error(t, err) + assert.Equal(t, "SOPS_HC_VAULT_ALLOWLIST's entry 1 is empty", err.Error()) + assert.Equal(t, allowList{ + All: false, + URIs: nil, + }, al) + + al, err = parseAllowlistString(",") + assert.Error(t, err) + assert.Equal(t, "SOPS_HC_VAULT_ALLOWLIST's entry 1 is empty", err.Error()) + assert.Equal(t, allowList{ + All: false, + URIs: nil, + }, al) + + al, err = parseAllowlistString(",a") + assert.Error(t, err) + assert.Equal(t, "SOPS_HC_VAULT_ALLOWLIST's entry 1 is empty", err.Error()) + assert.Equal(t, allowList{ + All: false, + URIs: nil, + }, al) + + al, err = parseAllowlistString("a,") + assert.Error(t, err) + assert.Equal(t, "SOPS_HC_VAULT_ALLOWLIST's entry 2 is empty", err.Error()) + assert.Equal(t, allowList{ + All: false, + URIs: nil, + }, al) + }) +} + +func TestAllowlistAllow(t *testing.T) { + al, _ := parseAllowlistString("all") + assert.Equal(t, al.Allows(""), true) + assert.Equal(t, al.Allows("foo"), true) + assert.Equal(t, al.Allows("bar"), true) + assert.Equal(t, al.Allows("http://example.com"), true) + assert.Equal(t, al.Allows("http://example.com/"), true) + assert.Equal(t, al.Allows("https://example.com/foo"), true) + + al, _ = parseAllowlistString("none") + assert.Equal(t, al.Allows(""), false) + assert.Equal(t, al.Allows("foo"), false) + assert.Equal(t, al.Allows("bar"), false) + assert.Equal(t, al.Allows("http://example.com"), false) + assert.Equal(t, al.Allows("http://example.com/"), false) + assert.Equal(t, al.Allows("https://example.com/foo"), false) + + al, _ = parseAllowlistString("http://example.com") + assert.Equal(t, al.Allows("http://example.co"), false) + assert.Equal(t, al.Allows("http://example.com"), true) + assert.Equal(t, al.Allows("http://example.comm"), false) + assert.Equal(t, al.Allows("http://example.com:80"), false) + assert.Equal(t, al.Allows("http://example.com/"), true) + assert.Equal(t, al.Allows("http://example.com/foo"), true) + assert.Equal(t, al.Allows("http://fiz@example.com/"), false) + assert.Equal(t, al.Allows("http://example.com:123/"), false) + assert.Equal(t, al.Allows("https://example.com"), false) + assert.Equal(t, al.Allows("https://example.com/"), false) + assert.Equal(t, al.Allows(""), false) + + al, _ = parseAllowlistString("http://example.com, https://example.org/bar/,http://foo:80") + assert.Equal(t, al.Allows("http://example.com"), true) + assert.Equal(t, al.Allows("http://example.com/"), true) + assert.Equal(t, al.Allows("http://example.com/foo"), true) + assert.Equal(t, al.Allows("http://fiz@example.com/"), false) + assert.Equal(t, al.Allows("http://example.com:123/"), false) + assert.Equal(t, al.Allows("https://example.com"), false) + assert.Equal(t, al.Allows("https://example.com/"), false) + assert.Equal(t, al.Allows("http://example.org"), false) + assert.Equal(t, al.Allows("http://example.org/"), false) + assert.Equal(t, al.Allows("http://example.org/foo"), false) + assert.Equal(t, al.Allows("http://fiz@example.org/"), false) + assert.Equal(t, al.Allows("http://example.org:123/"), false) + assert.Equal(t, al.Allows("https://example.org"), false) + assert.Equal(t, al.Allows("https://example.org/"), false) + assert.Equal(t, al.Allows("https://example.org/bar"), true) + assert.Equal(t, al.Allows("https://example.org/barr"), false) + assert.Equal(t, al.Allows("https://example.org/bar/"), true) + assert.Equal(t, al.Allows("https://example.org/bar/baz"), true) + assert.Equal(t, al.Allows("http://foo"), false) + assert.Equal(t, al.Allows("http://foo/"), false) + assert.Equal(t, al.Allows("http://foo:80"), true) + assert.Equal(t, al.Allows("http://foo:80/"), true) + assert.Equal(t, al.Allows("http://foo:8080"), false) + assert.Equal(t, al.Allows("http://foo:8080/"), false) + assert.Equal(t, al.Allows(""), false) +}