Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/authschemes/auth_bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func MakeAuthBridge(db *database.Connection, sessionStore *session.Store, authSc

// CreateNewUser allows new users to be registered into the system, if they do not already exist.
// Note that slug must be unique
func (ah AShirtAuthBridge) CreateNewUser(profile UserProfile) (*dtos.CreateUserOutput, error) {
return services.CreateUser(ah.db, profile.ToCreateUserInput())
func (ah AShirtAuthBridge) CreateNewUser(ctx context.Context, profile UserProfile) (*dtos.CreateUserOutput, error) {
return services.CreateUser(ctx, ah.db, profile.ToCreateUserInput())
}

// SetAuthSchemeSession sets authscheme specific session data to the current user session. Session data should
Expand Down
7 changes: 4 additions & 3 deletions backend/authschemes/auth_bridge_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package authschemes_test

import (
"context"
"encoding/gob"
"net/http"
"net/http/httptest"
Expand All @@ -21,7 +22,7 @@ import (
func TestCreateNewUser(t *testing.T) {
db, _, bridge := initBridgeTest(t)

newUser, err := bridge.CreateNewUser(authschemes.UserProfile{
newUser, err := bridge.CreateNewUser(context.Background(), authschemes.UserProfile{
FirstName: "Alice",
LastName: "Defaultuser",
Email: "alice@example.com",
Expand All @@ -39,7 +40,7 @@ func TestCreateNewUser(t *testing.T) {
require.Equal(t, "slug", user.Slug)

// Creating a user with a slug that already exists appends a random number to the slug
newUser, err = bridge.CreateNewUser(authschemes.UserProfile{
newUser, err = bridge.CreateNewUser(context.Background(), authschemes.UserProfile{
FirstName: "Bob",
LastName: "Snooper",
Email: "bob@example.com",
Expand Down Expand Up @@ -260,7 +261,7 @@ func initBridgeTest(t *testing.T) (*database.Connection, *session.Store, authsch
}

func createDummyUser(t *testing.T, bridge authschemes.AShirtAuthBridge, extra string) int64 {
newUser, err := bridge.CreateNewUser(authschemes.UserProfile{
newUser, err := bridge.CreateNewUser(context.Background(), authschemes.UserProfile{
FirstName: "Dummy",
LastName: "User",
Email: "email+" + extra + "@example.com",
Expand Down
2 changes: 1 addition & 1 deletion backend/authschemes/localauth/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func registerNewUser(ctx context.Context, bridge authschemes.AShirtAuthBridge, i
return backend.WrapError("Unable to generate encrypted password", err)
}

userResult, err := bridge.CreateNewUser(authschemes.UserProfile{
userResult, err := bridge.CreateNewUser(ctx, authschemes.UserProfile{
FirstName: info.FirstName,
LastName: info.LastName,
Slug: strings.ToLower(info.FirstName + "." + info.LastName),
Expand Down
2 changes: 1 addition & 1 deletion backend/authschemes/oidcauth/oidc_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (o OIDCAuth) handleCallback(w http.ResponseWriter, r *http.Request, bridge
return o.authFailure(w, r, backend.WrapError("Registration is disabled", err), "/autherror/registrationdisabled")
}

userResult, err := bridge.CreateNewUser(*userProfile)
userResult, err := bridge.CreateNewUser(r.Context(), *userProfile)
if err != nil {
return o.authFailure(w, r, backend.WrapError("Create new "+authName+" user failed ["+userProfile.Slug+"]", err), "/autherror/incomplete")
}
Expand Down
2 changes: 1 addition & 1 deletion backend/authschemes/recoveryauth/recovery_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (p RecoveryAuthScheme) BindRoutes(r chi.Router, bridge authschemes.AShirtAu
if dr.Error != nil {
return nil, dr.Error
}
userResult, err := bridge.CreateNewUser(profile)
userResult, err := bridge.CreateNewUser(r.Context(), profile)
if err != nil {
return nil, backend.WrapError("Unable to create new user", err)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/authschemes/webauthn/webauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (a WebAuthn) BindRoutes(r chi.Router, bridge authschemes.AShirtAuthBridge)
Slug: strings.ToLower(data.UserData.FirstName + "." + data.UserData.LastName),
Email: data.UserData.Email,
}
userResult, err := bridge.CreateNewUser(userProfile)
userResult, err := bridge.CreateNewUser(r.Context(), userProfile)
if err != nil {
return nil, backend.WrapError("Unable to create user", err)
}
Expand Down
11 changes: 6 additions & 5 deletions backend/bin/api/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"errors"
"net/http"

Expand All @@ -17,17 +18,17 @@ func main() {
err := config.LoadAPIConfig()
logger := logging.SetupStdoutLogging()
if err != nil {
logging.Fatal(logger, "Unable to start due to configuration error", "error", err, "action", "exiting")
logging.Fatal(context.Background(), logger,"Unable to start due to configuration error", "error", err, "action", "exiting")
}

db, err := database.NewConnection(config.DBUri(), "/migrations")
if err != nil {
logging.Fatal(logger, "Unable to connect to database", "error", err, "action", "exiting")
logging.Fatal(context.Background(), logger,"Unable to connect to database", "error", err, "action", "exiting")
}

logger.Info("checking database schema")
if err := db.CheckSchema(); err != nil {
logging.Fatal(logger, "schema read error", "error", err)
logging.Fatal(context.Background(), logger,"schema read error", "error", err)
}

contentStore, err := confighelpers.ChooseContentStoreType(config.AllStoreConfig())
Expand All @@ -36,7 +37,7 @@ func main() {
contentStore, err = confighelpers.DefaultS3Store()
}
if err != nil {
logging.Fatal(logger, "store setup error", "error", err)
logging.Fatal(context.Background(), logger,"store setup error", "error", err)
}

s := chi.NewRouter()
Expand All @@ -49,5 +50,5 @@ func main() {

logger.Info("starting API server", "port", config.Port())
serveErr := http.ListenAndServe(":"+config.Port(), s)
logging.Fatal(logger, "server shutting down", "err", serveErr)
logging.Fatal(context.Background(), logger,"server shutting down", "err", serveErr)
}
11 changes: 6 additions & 5 deletions backend/bin/web/web.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -32,17 +33,17 @@ func main() {
err := config.LoadWebConfig()
logger := logging.SetupStdoutLogging()
if err != nil {
logging.Fatal(logger, "Unable to start due to configuration error", "error", err, "action", "exiting")
logging.Fatal(context.Background(), logger,"Unable to start due to configuration error", "error", err, "action", "exiting")
}

db, err := database.NewConnection(config.DBUri(), "/migrations")
if err != nil {
logging.Fatal(logger, "Unable to connect to database", "error", err, "action", "exiting")
logging.Fatal(context.Background(), logger,"Unable to connect to database", "error", err, "action", "exiting")
}

logger.Info("checking database schema")
if err := db.CheckSchema(); err != nil {
logging.Fatal(logger, "schema read error", "error", err)
logging.Fatal(context.Background(), logger,"schema read error", "error", err)
}

contentStore, err := confighelpers.ChooseContentStoreType(config.AllStoreConfig())
Expand All @@ -51,7 +52,7 @@ func main() {
contentStore, err = confighelpers.DefaultS3Store()
}
if err != nil {
logging.Fatal(logger, "store setup error", "error", err)
logging.Fatal(context.Background(), logger,"store setup error", "error", err)
}
logger.Info("Using Storage", "type", contentStore.Name())

Expand Down Expand Up @@ -106,7 +107,7 @@ func main() {

logger.Info("starting Web server", "port", config.Port())
serveErr := http.ListenAndServe(":"+config.Port(), r)
logging.Fatal(logger, "server shutting down", "err", serveErr)
logging.Fatal(context.Background(), logger,"server shutting down", "err", serveErr)
}

func handleAuthType(cfg config.AuthInstanceConfig) (authschemes.AuthScheme, error) {
Expand Down
10 changes: 5 additions & 5 deletions backend/contentstore/s3presigner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package contentstore

import (
"context"
"log"
"log/slog"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -12,21 +12,21 @@ import (

type Presigner struct {
PresignClient *s3.PresignClient
Logger *slog.Logger
}

func (presigner Presigner) GetObject(
bucketName string, objectKey string, minutes time.Duration) (*v4.PresignedHTTPRequest, error) {
ctx context.Context, bucketName string, objectKey string, minutes time.Duration) (*v4.PresignedHTTPRequest, error) {
contentType := "image/jpeg"
request, err := presigner.PresignClient.PresignGetObject(context.TODO(), &s3.GetObjectInput{
request, err := presigner.PresignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucketName),
Key: aws.String(objectKey),
ResponseContentType: aws.String(contentType),
}, func(opts *s3.PresignOptions) {
opts.Expires = minutes
})
if err != nil {
log.Printf("Couldn't get a presigned request to get %v:%v. Here's why: %v\n",
bucketName, objectKey, err)
presigner.Logger.ErrorContext(ctx, "Couldn't get a presigned request", "bucket", bucketName, "key", objectKey, "error", err)
}
return request, err
}
7 changes: 4 additions & 3 deletions backend/contentstore/s3store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package contentstore
import (
"context"
"io"
"log/slog"
"time"

"github.com/ashirt-ops/ashirt-server/backend"
Expand Down Expand Up @@ -82,11 +83,11 @@ type URLData struct {
ExpirationTime time.Time `json:"expirationTime"`
}

func (s *S3Store) SendURLData(key string) (*URLData, error) {
func (s *S3Store) SendURLData(ctx context.Context, key string) (*URLData, error) {
minutes := time.Minute * time.Duration(30)
presignClient := s3.NewPresignClient(s.s3Client)
presigner := Presigner{PresignClient: presignClient}
presignedGetRequest, err := presigner.GetObject(s.bucketName, key, minutes)
presigner := Presigner{PresignClient: presignClient, Logger: slog.Default()}
presignedGetRequest, err := presigner.GetObject(ctx, s.bucketName, key, minutes)
if err != nil {
return nil, backend.WrapError("Unable to get presigned URL", err)
}
Expand Down
3 changes: 2 additions & 1 deletion backend/database/helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package database

import (
"context"
"database/sql"
"fmt"
"strings"
Expand Down Expand Up @@ -142,7 +143,7 @@ func (c *Connection) execSquirrel(sQuery squirrel.Sqlizer) (sql.Result, error) {
}

func logQuery(query string, values []interface{}) {
logging.SystemLog("executing query", "query", query, "values", fmt.Sprintf("%v", values))
logging.SystemLog(context.Background(), "executing query", "query", query, "values", fmt.Sprintf("%v", values))
}

// IsEmptyResultSetError returns true if the passed error is a database error resulting
Expand Down
12 changes: 6 additions & 6 deletions backend/enhancementservices/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func SendServiceWorkerEvent(db *database.Connection, input SendServiceWorkerEven
payloads, _ = input.Builder(tx)
})
if err != nil {
input.Logger.Error("Unable to execute service workers", "error", err.Error())
input.Logger.ErrorContext(workerContext, "Unable to execute service workers", "error", err.Error())
return
}

Expand All @@ -72,9 +72,9 @@ func SendServiceWorkerEvent(db *database.Connection, input SendServiceWorkerEven
)

if err != nil {
logger.Error("Unable to run worker", "error", err)
logger.ErrorContext(workerContext, "Unable to run worker", "error", err)
} else {
logger.Info("Worker completed")
logger.InfoContext(workerContext, "Worker completed")
}
}()
}
Expand Down Expand Up @@ -103,7 +103,7 @@ func SendEvidenceCreatedEvent(db *database.Connection, reqLogger *slog.Logger, o
helpers.Map(workersToRun, getServiceWorkerName))
})
if err != nil {
reqLogger.Error("Unable to execute service workers", "error", err.Error())
reqLogger.ErrorContext(workerContext, "Unable to execute service workers", "error", err.Error())
return
}

Expand All @@ -124,9 +124,9 @@ func SendEvidenceCreatedEvent(db *database.Connection, reqLogger *slog.Logger, o
)

if err != nil {
logger.Error("Unable to run worker", "error", err)
logger.ErrorContext(workerContext, "Unable to run worker", "error", err)
} else {
logger.Info("Worker completed")
logger.InfoContext(workerContext, "Worker completed")
}
}()
}
Expand Down
12 changes: 6 additions & 6 deletions backend/logging/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ func AddRequestLogger(ctx context.Context, baseLogger *slog.Logger) (context.Con

// Fatal is an effective copy of go's log.Fatal, but using the logger provided, rather than
// using go's native logging. After writing the message, the code will exit with code 1
func Fatal(logger *slog.Logger, msg string, keyvals ...interface{}) {
logger.Error(msg, keyvals...)
func Fatal(ctx context.Context, logger *slog.Logger, msg string, keyvals ...interface{}) {
logger.ErrorContext(ctx, msg, keyvals...)
os.Exit(1)
}

func LogWithoutAuth(msg string, keyvals ...interface{}) {
func LogWithoutAuth(ctx context.Context, msg string, keyvals ...interface{}) {
if systemLogger != nil {
systemLogger.Info(msg, keyvals...)
systemLogger.InfoContext(ctx, msg, keyvals...)
}
}

// SystemLog provides a system-level logger, which is not tied to any request.
// this should be used in situations where either a context is not handy, but logging is important
// or for events that are not tied to a request (e.g. losing database connection)
func SystemLog(msg string, keyvals ...interface{}) {
func SystemLog(ctx context.Context, msg string, keyvals ...interface{}) {
if systemLogger != nil {
systemLogger.Info(msg, keyvals...)
systemLogger.InfoContext(ctx, msg, keyvals...)
}
}
8 changes: 4 additions & 4 deletions backend/server/middleware/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func AuthenticateAppAndInjectCtx(db *database.Connection) MiddlewareFunc {

userData, err := authenticateAPI(db, r, body)
if err != nil {
logging.LogWithoutAuth(
logging.LogWithoutAuth(r.Context(),
"Unable to build user policy",
"error", err.Error(),
)
Expand Down Expand Up @@ -136,7 +136,7 @@ func buildPolicyForUser(ctx context.Context, db *database.Connection, userID int

var groupRoles []models.UserGroupOperationPermission

err := db.WithTx(context.Background(), func(tx *database.Transactable) {
err := db.WithTx(ctx, func(tx *database.Transactable) {
tx.Select(&roles, sq.Select("operation_id", "role").
From("user_operation_permissions").
Where(sq.Eq{"user_id": userID}))
Expand All @@ -152,7 +152,7 @@ func buildPolicyForUser(ctx context.Context, db *database.Connection, userID int
})

if err != nil {
logging.ReqLogger(ctx).Error("Unable to build user policy", "error", err.Error())
logging.ReqLogger(ctx).ErrorContext(ctx, "Unable to build user policy", "error", err.Error())
return &policy.Deny{}
}
roleMap := make(map[int64]policy.OperationRole)
Expand Down Expand Up @@ -223,7 +223,7 @@ func cloneBody(r *http.Request) (io.Reader, func(), error) {
r.Body.Close()
err := os.Remove(bodyTmpFile.Name())
if err != nil {
logging.LogWithoutAuth(
logging.LogWithoutAuth(r.Context(),
"Unable to remove tmp file",
"error", err,
"tmpFile", bodyTmpFile.Name(),
Expand Down
4 changes: 2 additions & 2 deletions backend/server/middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ func LogRequests(baseLogger *slog.Logger) MiddlewareFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ctx, logger := logging.AddRequestLogger(r.Context(), baseLogger)
logger.Info("Incoming request", "method", r.Method, "url", r.URL, "from", r.RemoteAddr)
logger.InfoContext(ctx, "Incoming request", "method", r.Method, "url", r.URL, "from", r.RemoteAddr)
ww := &responseWriterWrapper{w, 0, 200}

next.ServeHTTP(ww, r.WithContext(ctx))
logger.Info("Request Completed", "status", ww.status, "sizeInBytes", ww.size, "duration", time.Since(start))
logger.InfoContext(ctx, "Request Completed", "status", ww.status, "sizeInBytes", ww.size, "duration", time.Since(start))
})
}
}
Loading
Loading