Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 96 additions & 58 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os/exec"
"path"
"strings"
"sync"
"time"

"al.essio.dev/pkg/shellescape"
Expand Down Expand Up @@ -113,11 +114,11 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command {
sshCmd.Flags().
StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be forwarded to the given host and port, or Unix socket, on the remote side.")
"host are to be forwarded to the given host, service name, and port, or Unix socket, on the remote side.")
sshCmd.Flags().
StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{},
"Specifies that connections to the given TCP port or Unix socket on the local (client) "+
"host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.")
"Specifies that connections to the given TCP port or Unix socket on the remote side "+
"are to be reverse forwarded to the given local host, service name, and port, or Unix socket.")
Comment thread
davzucky marked this conversation as resolved.
sshCmd.Flags().
StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{},
"Specifies which local env variables shall be sent to the container.")
Expand Down Expand Up @@ -382,16 +383,17 @@ func (cmd *SSHCmd) jumpContainer(
}

func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) {
timeout := time.Duration(0)
if cmd.ForwardPortsTimeout != "" {
timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout)
if err != nil {
return timeout, fmt.Errorf("parse forward ports timeout: %w", err)
}
if cmd.ForwardPortsTimeout == "" {
return 0, nil
}

log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout)
timeout, err := time.ParseDuration(cmd.ForwardPortsTimeout)
if err != nil {
return 0, fmt.Errorf("parse forward ports timeout: %w", err)
}

log.Infof("Using port forwarding timeout of %s", cmd.ForwardPortsTimeout)

return timeout, nil
}

Expand All @@ -400,89 +402,125 @@ func (cmd *SSHCmd) reverseForwardPorts(
containerClient *ssh.Client,
log log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ReverseForwardPorts,
logTemplate: "Reverse forwarding remote %s/%s to local %s/%s",
forwardFn: devssh.ReversePortForward,
}, log)
}

errChan := make(chan error, len(cmd.ReverseForwardPorts))
for _, portMapping := range cmd.ReverseForwardPorts {
func (cmd *SSHCmd) forwardPorts(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
) error {
return cmd.runPortForwards(ctx, containerClient, portForwardConfig{
mappings: cmd.ForwardPorts,
logTemplate: "Forwarding local %s/%s to remote %s/%s",
forwardFn: devssh.PortForward,
}, log)
}

type portForwardFunc func(
context.Context,
*ssh.Client,
string,
string,
string,
string,
time.Duration,
log.Logger,
) error

type portForwardConfig struct {
mappings []string
logTemplate string
forwardFn portForwardFunc
}

type parsedPortForward struct {
spec string
mapping port.Mapping
}

func parsePortForwards(mappings []string) ([]parsedPortForward, error) {
parsedMappings := make([]parsedPortForward, 0, len(mappings))
for _, portMapping := range mappings {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
return nil, fmt.Errorf("parse port mapping: %w", err)
}

// start the forwarding
log.Infof(
"Reverse forwarding local %s/%s to remote %s/%s",
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.ReversePortForward(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
)
if !errors.Is(io.EOF, err) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
parsedMappings = append(parsedMappings, parsedPortForward{
spec: portMapping,
mapping: mapping,
})
}

return <-errChan
return parsedMappings, nil
}

func (cmd *SSHCmd) forwardPorts(
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}

errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping

// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()

err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}

go func() {
waitGroup.Wait()
close(errChan)
}()

for err := range errChan {
if err != nil {
return err
}
}

return <-errChan
return nil
}
Comment on lines +463 to 524
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Cancel sibling forwarders before returning the first runtime error.

If one forwarding goroutine fails, Line 519 returns immediately while the other forwarders keep using the caller’s ctx and may keep listeners open. Use a child context, cancel it on the first real error, then drain the channel until the WaitGroup closes it.

Proposed cleanup on first error
 func (cmd *SSHCmd) runPortForwards(
 	ctx context.Context,
 	containerClient *ssh.Client,
 	config portForwardConfig,
 	logger log.Logger,
 ) error {
 	timeout, err := cmd.forwardTimeout(logger)
 	if err != nil {
 		return fmt.Errorf("parse forward ports timeout: %w", err)
 	}

 	parsedMappings, err := parsePortForwards(config.mappings)
 	if err != nil {
 		return err
 	}
+
+	forwardCtx, cancel := context.WithCancel(ctx)
+	defer cancel()

 	errChan := make(chan error, len(parsedMappings))
 	var waitGroup sync.WaitGroup
 	for _, parsedMapping := range parsedMappings {
 		portMapping, mapping := parsedMapping.spec, parsedMapping.mapping

@@
 			defer waitGroup.Done()

 			err := config.forwardFn(
-				ctx,
+				forwardCtx,
 				containerClient,
 				mapping.Host.Protocol,
 				mapping.Host.Address,
 				mapping.Container.Protocol,
@@
 		}(portMapping, mapping)
 	}

 	go func() {
 		waitGroup.Wait()
 		close(errChan)
 	}()

+	var firstErr error
 	for err := range errChan {
-		if err != nil {
-			return err
+		if err != nil && firstErr == nil {
+			firstErr = err
+			cancel()
 		}
 	}

-	return nil
+	return firstErr
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
log log.Logger,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(log)
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
errChan := make(chan error, len(cmd.ForwardPorts))
for _, portMapping := range cmd.ForwardPorts {
mapping, err := port.ParsePortSpec(portMapping)
if err != nil {
return fmt.Errorf("parse port mapping: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
log.Infof(
"Forwarding local %s/%s to remote %s/%s",
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
go func(portMapping string) {
err := devssh.PortForward(
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
ctx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
log,
logger,
)
if !errors.Is(io.EOF, err) {
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping)
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
for err := range errChan {
if err != nil {
return err
}
}
return <-errChan
return nil
}
func (cmd *SSHCmd) runPortForwards(
ctx context.Context,
containerClient *ssh.Client,
config portForwardConfig,
logger log.Logger,
) error {
timeout, err := cmd.forwardTimeout(logger)
if err != nil {
return fmt.Errorf("parse forward ports timeout: %w", err)
}
parsedMappings, err := parsePortForwards(config.mappings)
if err != nil {
return err
}
forwardCtx, cancel := context.WithCancel(ctx)
defer cancel()
errChan := make(chan error, len(parsedMappings))
var waitGroup sync.WaitGroup
for _, parsedMapping := range parsedMappings {
portMapping, mapping := parsedMapping.spec, parsedMapping.mapping
// start the forwarding
logger.Infof(
config.logTemplate,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
)
waitGroup.Add(1)
go func(portMapping string, mapping port.Mapping) {
defer waitGroup.Done()
err := config.forwardFn(
forwardCtx,
containerClient,
mapping.Host.Protocol,
mapping.Host.Address,
mapping.Container.Protocol,
mapping.Container.Address,
timeout,
logger,
)
if err != nil && !errors.Is(err, io.EOF) {
errChan <- fmt.Errorf("error forwarding %s: %w", portMapping, err)
}
}(portMapping, mapping)
}
go func() {
waitGroup.Wait()
close(errChan)
}()
var firstErr error
for err := range errChan {
if err != nil && firstErr == nil {
firstErr = err
cancel()
}
}
return firstErr
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/ssh.go` around lines 463 - 524, Wrap the incoming ctx in a cancellable
child (ctx, cancel := context.WithCancel(ctx)) at the top of runPortForwards and
pass that child ctx into config.forwardFn so all forwarder goroutines see
cancellation; in the per-goroutine error handling, when you detect a real error
(err != nil && !errors.Is(err, io.EOF)) send the formatted error to errChan and
call cancel() to cancel sibling forwarders; in the receiving loop, capture the
first error but continue draining errChan until it is closed by the waitGroup
goroutine, then return the first non-nil error (ensuring cancel is
deferred/called to release resources when leaving).


func (cmd *SSHCmd) startTunnel(
Expand Down
Loading