Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ test:
test-coverage:
@mkdir -p .coverage
@go test `go list ./... | grep -Ev "diodepb|examples|internal"` -race -cover -json -coverprofile=.coverage/cover.out.tmp ./... | tparse -format=markdown > .coverage/test-report.md
Comment thread
leoparente marked this conversation as resolved.
@cat .coverage/cover.out.tmp > .coverage/cover.out
@grep -v "diode/ingester.go" .coverage/cover.out.tmp > .coverage/cover.out
@go tool cover -func=.coverage/cover.out | grep total | awk '{print substr($$3, 1, length($$3)-1)}' > .coverage/coverage.txt

.PHONY: codegen
codegen:
@go run internal/cmd/codegen/main.go | gofmt > ./diode/ingester.go
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ go get github.com/netboxlabs/diode-sdk-go

### Environment variables

* `DIODE_API_KEY` - API key for the Diode service
* `DIODE_SDK_LOG_LEVEL` - Log level for the SDK (default: `INFO`)
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication

### Example

Expand Down
240 changes: 209 additions & 31 deletions diode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,26 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"regexp"
"runtime"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/netboxlabs/diode-sdk-go/diode/v1/diodepb"
Expand All @@ -31,15 +36,19 @@ const (
// SDKVersion is the version of the Diode SDK
SDKVersion = "0.2.0"

// DiodeAPIKeyEnvVarName is the environment variable name for the Diode API key
DiodeAPIKeyEnvVarName = "DIODE_API_KEY"
// DiodeClientIDEnvVarName is the environment variable name for the Diode Client ID
DiodeClientIDEnvVarName = "DIODE_CLIENT_ID"

// DiodeClientSecretEnvVarName is the environment variable name for the Diode Client Secret
DiodeClientSecretEnvVarName = "DIODE_CLIENT_SECRET"

// DiodeSDKLogLevelEnvVarName is the environment variable name for the Diode SDK log level
DiodeSDKLogLevelEnvVarName = "DIODE_SDK_LOG_LEVEL"

defaultStreamName = "latest"
// DiodeMaxAuthRetriesEnvVarName is the environment variable name for the maximum number of authentication retries
DiodeMaxAuthRetriesEnvVarName = "DIODE_MAX_AUTH_RETRIES"

authAPIKeyName = "diode-api-key"
defaultStreamName = "latest"
)

var allowedSchemesRe = regexp.MustCompile(`grpc|grpcs`)
Expand Down Expand Up @@ -76,17 +85,46 @@ func parseTarget(target string) (string, string, bool, error) {
return authority, path, tlsVerify, nil
}

// getAPIKey returns the API key either from provided value or environment variable
func getAPIKey(apiKey string) (string, error) {
if apiKey == "" {
apiKey = os.Getenv(DiodeAPIKeyEnvVarName)
// getClientID returns the client ID either from provided value or environment variable
func getClientID(clientID string) (string, error) {
if clientID == "" {
clientID = os.Getenv(DiodeClientIDEnvVarName)
}

if clientID == "" {
return "", fmt.Errorf("client_id param or %s environment variable required", DiodeClientIDEnvVarName)
}

return clientID, nil
}

// getClientSecret returns the client secret either from provided value or environment variable
func getClientSecret(clientSecret string) (string, error) {
if clientSecret == "" {
clientSecret = os.Getenv(DiodeClientSecretEnvVarName)
}

if apiKey == "" {
return "", fmt.Errorf("api_key param or %s environment variable required", DiodeAPIKeyEnvVarName)
if clientSecret == "" {
return "", fmt.Errorf("client_secret param or %s environment variable required", DiodeClientSecretEnvVarName)
}

return apiKey, nil
return clientSecret, nil
}

// getAuthRetries returns the maximum number of authentication retries
func getAuthRetries(maxAuthRetries int) (int, error) {
maxAuthRetriesStr := os.Getenv(DiodeMaxAuthRetriesEnvVarName)
if maxAuthRetriesStr != "" {
retries, err := strconv.Atoi(maxAuthRetriesStr)
if err != nil {
return 0, fmt.Errorf("invalid value for %s: %w", DiodeMaxAuthRetriesEnvVarName, err)
}
maxAuthRetries = retries
}
if maxAuthRetries <= 0 {
return 0, fmt.Errorf("max_auth_retries param or %s environment variable must be greater than 0", DiodeMaxAuthRetriesEnvVarName)
}
return maxAuthRetries, nil
}

// Client is an interface that defines the methods available from Diode API
Expand Down Expand Up @@ -115,8 +153,14 @@ type GRPCClient struct {
// Producer's application version
appVersion string

// An API key for the Diode API
apiKey string
// The client ID for the API
clientID string

// The client secret for the API
clientSecret string

// The maximum number of authentication retries
maxAuthRetries int

// GRPC target
target string
Expand All @@ -140,13 +184,108 @@ type GRPCClient struct {
// ClientOption is a functional option for the GRPCClient
type ClientOption func(*GRPCClient)

// WithAPIKey sets the API key for the client
func WithAPIKey(apiKey string) ClientOption {
// WithClientID sets the client ID for the GRPCClient
func WithClientID(clientID string) ClientOption {
return func(c *GRPCClient) {
c.apiKey = apiKey
c.clientID = clientID
}
}

// WithClientSecret sets the client secret for the GRPCClient
func WithClientSecret(clientSecret string) ClientOption {
return func(c *GRPCClient) {
c.clientSecret = clientSecret
}
}

// authenticate fetches an OAuth2 token using client credentials and updates the metadata with the token.
func (g *GRPCClient) authenticate() error {
authClient := newDiodeAuthentication(g.target, g.path, g.tlsVerify, g.clientID, g.clientSecret)
accessToken, err := authClient.authenticate(g.logger)
if err != nil {
return fmt.Errorf("authentication failed: %w", err)
}

// Update metadata with the new authorization token
g.metadata.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
return nil
}

// DiodeAuthentication handles OAuth2 authentication for the Diode API.
type diodeAuthentication struct {
target string
path string
tlsVerify bool
clientID string
clientSecret string
}

// NewDiodeAuthentication creates a new instance of DiodeAuthentication.
func newDiodeAuthentication(target string, path string, tlsVerify bool, clientID, clientSecret string) *diodeAuthentication {
return &diodeAuthentication{
target: target,
path: path,
tlsVerify: tlsVerify,
clientID: clientID,
clientSecret: clientSecret,
}
}

// Authenticate requests an OAuth2 token using client credentials and returns it.
func (d *diodeAuthentication) authenticate(logger *slog.Logger) (string, error) {
scheme := "http"
if d.tlsVerify {
scheme = "https"
}
authURL := fmt.Sprintf("%s://%s/auth/token", scheme, d.target)
if d.path != "" {
authURL = fmt.Sprintf("%s://%s%s/auth/token", scheme, d.target, d.path)
}
data := url.Values{}
data.Set("grant_type", "client_credentials")
data.Set("client_id", d.clientID)
data.Set("client_secret", d.clientSecret)

req, err := http.NewRequest(http.MethodPost, authURL, strings.NewReader(data.Encode()))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

client := &http.Client{}
if !d.tlsVerify {
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}

resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
logger.Error("failed to close response body", "error", err)
}
}()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("authentication failed: %s", resp.Status)
}

var result struct {
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}

if result.AccessToken == "" {
return "", errors.New("access token not found in response")
}

return result.AccessToken, nil
}

// NewClient creates a new diode client based on gRPC
func NewClient(target string, appName string, appVersion string, opts ...ClientOption) (Client, error) {
logger := newLogger()
Expand Down Expand Up @@ -191,31 +330,48 @@ func NewClient(target string, appName string, appVersion string, opts ...ClientO
goVersion := runtime.Version()

c := &GRPCClient{
logger: logger,
conn: conn,
client: diodepb.NewIngesterServiceClient(conn),
appName: appName,
appVersion: appVersion,
target: target,
path: path,
tlsVerify: tlsVerify,
platform: platform,
goVersion: goVersion,
logger: logger,
conn: conn,
client: diodepb.NewIngesterServiceClient(conn),
appName: appName,
appVersion: appVersion,
target: target,
path: path,
tlsVerify: tlsVerify,
platform: platform,
goVersion: goVersion,
maxAuthRetries: 3,
Comment thread
leoparente marked this conversation as resolved.
}

var apiKey string
var clientID string
var clientSecret string

for _, o := range opts {
o(c)
}

apiKey, err = getAPIKey(c.apiKey)
c.metadata = metadata.Pairs("platform", platform, "go-version", goVersion)

c.maxAuthRetries, err = getAuthRetries(c.maxAuthRetries)
if err != nil {
return nil, err
}

c.apiKey = apiKey
c.metadata = metadata.Pairs(authAPIKeyName, c.apiKey, "platform", platform, "go-version", goVersion)
clientID, err = getClientID(c.clientID)
if err != nil {
return nil, err
}
clientSecret, err = getClientSecret(c.clientSecret)
if err != nil {
return nil, err
}

c.clientID = clientID
c.clientSecret = clientSecret

if err = c.authenticate(); err != nil {
return nil, err
}

return c, nil
}
Expand Down Expand Up @@ -246,7 +402,29 @@ func (g *GRPCClient) Ingest(ctx context.Context, entities []Entity) (*diodepb.In

ctx = metadata.NewOutgoingContext(ctx, g.metadata)

return g.client.Ingest(ctx, req)
var err error
var res *diodepb.IngestResponse

attempt := 0
for {
res, err = g.client.Ingest(ctx, req)
if err != nil {
if status.Code(err) == codes.Unauthenticated {
attempt++
if attempt >= g.maxAuthRetries {
return nil, fmt.Errorf("authentication failed after %d attempts: %w", attempt, err)
}
g.logger.Debug("Authentication failed, retrying...", "attempt", attempt)
if err := g.authenticate(); err != nil {
g.logger.Error("Failed to re-authenticate", "error", err)
}
continue
}
return nil, err
}
break
}
return res, nil
}

// convertEntitiesToProto converts entities to proto entities
Expand Down
Loading
Loading