diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go index 9f70eb5fdb..9b6b18c180 100644 --- a/pkg/lisafs/connection.go +++ b/pkg/lisafs/connection.go @@ -45,15 +45,22 @@ type Connection struct { // associated with it for its entire lifetime. server *Server + // impl is the implementation that owns this connection's mount behavior and + // advertised protocol configuration. + impl ConnectionImpl + // mountPath is the path to a file inside the server that is served to this // connection as its root FD. IOW, this connection is mounted at this path. // mountPath is trusted because it is configured by the server (trusted) as // per the user's sandbox configuration. mountPath is immutable. mountPath string - // maxMessageSize is the cached value of server.impl.MaxMessageSize(). + // maxMessageSize is the cached value of impl.MaxMessageSize(). maxMessageSize uint32 + // opts is the cached value of impl.ConnectionOpts(). + opts ConnectionOpts + // readonly indicates if this connection is readonly. All write operations // will fail with EROFS. readonly bool @@ -84,17 +91,28 @@ type Connection struct { // CreateConnection initializes a new connection which will be mounted at // mountPath. The connection must be started separately. -func (s *Server) CreateConnection(sock *unet.Socket, mountPath string, readonly bool) (*Connection, error) { +func (s *Server) CreateConnection(sock *unet.Socket, mountPath string, readonly bool, impl ConnectionImpl) (*Connection, error) { mountPath = path.Clean(mountPath) if !filepath.IsAbs(mountPath) { log.Warningf("mountPath %q is not absolute", mountPath) return nil, unix.EINVAL } + if impl == nil { + log.Warningf("ConnectionImpl must not be nil") + return nil, unix.EINVAL + } + maxMessageSize := impl.MaxMessageSize() + if maxMessageSize == 0 { + log.Warningf("ConnectionImpl.MaxMessageSize() must not return 0") + return nil, unix.EINVAL + } c := &Connection{ sockComm: newSockComm(sock), server: s, - maxMessageSize: s.impl.MaxMessageSize(), + impl: impl, + maxMessageSize: maxMessageSize, + opts: impl.ConnectionOpts(), mountPath: mountPath, readonly: readonly, channels: make([]*channel, 0, maxChannels()), @@ -111,7 +129,7 @@ func (s *Server) CreateConnection(sock *unet.Socket, mountPath string, readonly } // ServerImpl returns the associated server implementation. -func (c *Connection) ServerImpl() ServerImpl { +func (c *Connection) ServerImpl() any { return c.server.impl } diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go index 13b1c3d9ba..3c31377a9c 100644 --- a/pkg/lisafs/connection_test.go +++ b/pkg/lisafs/connection_test.go @@ -38,12 +38,12 @@ var handlers = [...]lisafs.RPCHandler{ versionMsgID: versionHandler, } -// testServer implements lisafs.ServerImpl. +// testServer implements lisafs.ConnectionImpl. type testServer struct { lisafs.Server } -var _ lisafs.ServerImpl = (*testServer)(nil) +var _ lisafs.ConnectionImpl = (*testServer)(nil) type testControlFD struct { lisafs.ControlFD @@ -64,12 +64,17 @@ func (s *testServer) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lisaf return dummyRoot.FD(), lisafs.Statx{Mode: uint16(linux.S_IFDIR)}, -1, nil } -// MaxMessageSize implements lisafs.MaxMessageSize. +// ConnectionOpts implements lisafs.ConnectionImpl.ConnectionOpts. +func (s *testServer) ConnectionOpts() lisafs.ConnectionOpts { + return lisafs.ConnectionOpts{} +} + +// MaxMessageSize implements lisafs.ConnectionImpl.MaxMessageSize. func (s *testServer) MaxMessageSize() uint32 { return lisafs.MaxMessageSize() } -// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. +// SupportedMessages implements lisafs.ConnectionImpl.SupportedMessages. func (s *testServer) SupportedMessages() []lisafs.MID { return []lisafs.MID{ lisafs.Mount, @@ -86,9 +91,9 @@ func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) { } ts := &testServer{} - ts.Init(ts, lisafs.ServerOpts{}) + ts.Init(ts) ts.SetHandlers(handlers[:]) - conn, err := ts.CreateConnection(serverSocket, "/" /* mountPath */, false /* readonly */) + conn, err := ts.CreateConnection(serverSocket, "/" /* mountPath */, false /* readonly */, ts) if err != nil { t.Fatalf("starting connection failed: %v", err) return diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go index 9ffab8f56f..7d69031675 100644 --- a/pkg/lisafs/handlers.go +++ b/pkg/lisafs/handlers.go @@ -144,7 +144,7 @@ func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, if mountNode.isDeleted() { return unix.ENOENT } - mountPointFD, mountPointStat, mountPointHostFD, err = c.ServerImpl().Mount(c, mountNode) + mountPointFD, mountPointStat, mountPointHostFD, err = c.impl.Mount(c, mountNode) return err }); err != nil { return 0, err @@ -158,8 +158,8 @@ func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, ControlFD: mountPointFD.id, Stat: mountPointStat, }, - SupportedMs: c.ServerImpl().SupportedMessages(), - MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()), + SupportedMs: c.impl.SupportedMessages(), + MaxMessageSize: primitive.Uint32(c.maxMessageSize), } respPayloadLen := uint32(resp.SizeBytes()) resp.MarshalBytes(comm.PayloadBuf(respPayloadLen)) @@ -168,7 +168,7 @@ func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, // ChannelHandler handles the Channel RPC. func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { - ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize()) + ch, desc, fdSock, err := c.createChannel(c.maxMessageSize) if err != nil { return 0, err } @@ -263,7 +263,7 @@ func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32 var resp SetStatResp if err := fd.safelyWrite(func() error { - if fd.node.isDeleted() && !c.server.opts.SetAttrOnDeleted { + if fd.node.isDeleted() && !c.opts.SetAttrOnDeleted { return unix.EINVAL } failureMask, failureErr := fd.impl.SetStat(req) @@ -413,7 +413,7 @@ func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint3 payloadBuf := comm.PayloadBuf(uint32(maxPayloadSize)) payloadPos := numStats.SizeBytes() - if c.server.opts.WalkStatSupported { + if c.opts.WalkStatSupported { if err = startDir.safelyRead(func() error { return startDir.impl.WalkStat(req.Path, func(s Statx) { s.MarshalUnsafe(payloadBuf[payloadPos:]) @@ -528,7 +528,7 @@ func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, hostOpenFD int ) if err := fd.safelyRead(func() error { - if fd.node.isDeleted() && !c.server.opts.OpenOnDeleted { + if fd.node.isDeleted() && !c.opts.OpenOnDeleted { return unix.EINVAL } if fd.IsSymlink() { @@ -981,7 +981,7 @@ func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint } return 0, fd.controlFD.safelyWrite(func() error { - if fd.controlFD.node.isDeleted() && !c.server.opts.AllocateOnDeleted { + if fd.controlFD.node.isDeleted() && !c.opts.AllocateOnDeleted { return unix.EINVAL } return fd.impl.Allocate(req.Mode, req.Offset, req.Length) diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go index 9e810d403f..df3f0a0fe8 100644 --- a/pkg/lisafs/message.go +++ b/pkg/lisafs/message.go @@ -190,7 +190,7 @@ const ( ) // MaxMessageSize is the recommended max message size that can be used by -// connections. Server implementations may choose to use other values. +// connections. Connections may choose to use other values. func MaxMessageSize() uint32 { // Return HugePageSize - PageSize so that when flipcall packet window is // created with MaxMessageSize() + flipcall header size + channel header diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go index 226d3906b8..a4d1ec165c 100644 --- a/pkg/lisafs/server.go +++ b/pkg/lisafs/server.go @@ -39,16 +39,13 @@ type Server struct { // root is immutable. Server holds a ref on root for its entire lifetime. root *Node - // impl is the server implementation which embeds this server. - impl ServerImpl - - // opts is the server specific options. This dictates how some of the - // messages are handled. - opts ServerOpts + // impl is an opaque value associated with the Server, reachable via + // Connection.ServerImpl() so FD handlers can access server-wide state. + impl any } -// ServerOpts defines some server implementation specific behavior. -type ServerOpts struct { +// ConnectionOpts defines connection-specific behavior. +type ConnectionOpts struct { // WalkStatSupported is set to true if it's safe to call // ControlFDImpl.WalkStat and let the file implementation perform the walk // without holding locks on any of the descendant's Nodes. @@ -68,9 +65,8 @@ type ServerOpts struct { } // Init must be called before first use of the server. -func (s *Server) Init(impl ServerImpl, opts ServerOpts) { +func (s *Server) Init(impl any) { s.impl = impl - s.opts = opts s.handlers = handlers[:] s.root = &Node{} // s owns the ref on s.root. @@ -110,21 +106,22 @@ func (s *Server) Destroy() { s.root.DecRef(nil) } -// ServerImpl contains the implementation details for a Server. -// Implementations of ServerImpl should contain their associated Server by -// value as their first field. -type ServerImpl interface { +// ConnectionImpl contains the implementation details for a Connection. +type ConnectionImpl interface { // Mount is called when a Mount RPC is made. It mounts the connection on // mountNode. Mount may optionally donate a host FD to the mount point. // // Mount has a read concurrency guarantee on mountNode. Mount(c *Connection, mountNode *Node) (*ControlFD, Statx, int, error) - // SupportedMessages returns a list of messages that the server + // SupportedMessages returns a list of messages that the connection // implementation supports. SupportedMessages() []MID // MaxMessageSize is the maximum payload length (in bytes) that can be sent - // to this server implementation. + // to this connection implementation. MaxMessageSize() uint32 + + // ConnectionOpts returns the options for this connection implementation. + ConnectionOpts() ConnectionOpts } diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go index 95db15ef16..18103690e5 100644 --- a/pkg/lisafs/testsuite/testsuite.go +++ b/pkg/lisafs/testsuite/testsuite.go @@ -37,7 +37,7 @@ import ( // away all the caller specific details. type Tester interface { // NewServer returns a new instance of the tester server. - NewServer(t *testing.T) *lisafs.Server + NewServer(t *testing.T) (*lisafs.Server, lisafs.ConnectionImpl) // LinkSupported returns true if the backing server supports LinkAt. LinkSupported() bool @@ -94,8 +94,8 @@ func RunTest(t *testing.T, tester Tester, testName string, testFn TestFunc, moun t.Fatalf("socketpair got err %v expected nil", err) } - server := tester.NewServer(t) - conn, err := server.CreateConnection(serverSocket, mountPath, false /* readonly */) + server, impl := tester.NewServer(t) + conn, err := server.CreateConnection(serverSocket, mountPath, false /* readonly */, impl) if err != nil { t.Fatalf("starting connection failed: %v", err) return diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index dab3006e0f..489b8ae60d 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -99,7 +99,7 @@ func startGofer(root string, conf *config.Config) (int, func(), error) { HostFifo: conf.HostFifo, DonateMountPointFD: conf.DirectFS, }) - c, err := server.CreateConnection(socket, root, true /* readonly */) + c, err := server.CreateConnection(socket, root, true /* readonly */, server.ConnectionImpl()) if err != nil { return 0, nil, err } diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index df3a7e9bef..f17126fa5c 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -384,7 +384,7 @@ func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string, ruid i } for _, cfg := range cfgs { - conn, err := server.CreateConnection(cfg.sock, cfg.mountPath, cfg.readonly) + conn, err := server.CreateConnection(cfg.sock, cfg.mountPath, cfg.readonly, server.ConnectionImpl()) if err != nil { util.Fatalf("starting connection on FD %d for gofer mount failed: %v", cfg.sock.FD(), err) } diff --git a/runsc/fsgofer/lisafs.go b/runsc/fsgofer/lisafs.go index eb8609bafe..1633d81303 100644 --- a/runsc/fsgofer/lisafs.go +++ b/runsc/fsgofer/lisafs.go @@ -51,12 +51,6 @@ const ( // Config sets configuration options for each attach point. type Config struct { - // ROMount is set to true if this is a readonly mount. - ROMount bool - - // PanicOnWrite panics on attempts to write to RO mounts. - PanicOnWrite bool - // HostUDS signals whether the gofer can connect to host unix domain sockets. HostUDS config.HostUDS @@ -93,28 +87,43 @@ func OpenProcSelfFD(path string) error { return nil } -// LisafsServer implements lisafs.ServerImpl for fsgofer. +// LisafsServer serves fsgofer LisaFS connections. type LisafsServer struct { lisafs.Server config Config } -var _ lisafs.ServerImpl = (*LisafsServer)(nil) - // NewLisafsServer initializes a new lisafs server for fsgofer. func NewLisafsServer(config Config) *LisafsServer { s := &LisafsServer{config: config} - s.Server.Init(s, lisafs.ServerOpts{ + s.Server.Init(s) + return s +} + +// ConnectionImpl returns the stock fsgofer implementation for a new connection. +func (s *LisafsServer) ConnectionImpl() lisafs.ConnectionImpl { + return &stockConnectionImpl{server: s} +} + +// stockConnectionImpl implements lisafs.ConnectionImpl for stock fsgofer. +type stockConnectionImpl struct { + server *LisafsServer +} + +var _ lisafs.ConnectionImpl = (*stockConnectionImpl)(nil) + +// ConnectionOpts implements lisafs.ConnectionImpl.ConnectionOpts. +func (i *stockConnectionImpl) ConnectionOpts() lisafs.ConnectionOpts { + return lisafs.ConnectionOpts{ WalkStatSupported: true, SetAttrOnDeleted: true, AllocateOnDeleted: true, OpenOnDeleted: true, - }) - return s + } } -// Mount implements lisafs.ServerImpl.Mount. -func (s *LisafsServer) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lisafs.ControlFD, lisafs.Statx, int, error) { +// Mount implements lisafs.ConnectionImpl.Mount. +func (i *stockConnectionImpl) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lisafs.ControlFD, lisafs.Statx, int, error) { mountPath := mountNode.FilePath() rootHostFD, err := tryOpen(func(flags int) (int, error) { return unix.Open(mountPath, flags, 0) @@ -138,7 +147,7 @@ func (s *LisafsServer) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lis } clientHostFD := -1 - if s.config.DonateMountPointFD { + if i.server.config.DonateMountPointFD { clientHostFD, err = unix.Dup(rootHostFD) if err != nil { return nil, lisafs.Statx{}, -1, err @@ -156,13 +165,13 @@ func (s *LisafsServer) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lis return rootFD.FD(), stat, clientHostFD, nil } -// MaxMessageSize implements lisafs.ServerImpl.MaxMessageSize. -func (s *LisafsServer) MaxMessageSize() uint32 { +// MaxMessageSize implements lisafs.ConnectionImpl.MaxMessageSize. +func (i *stockConnectionImpl) MaxMessageSize() uint32 { return lisafs.MaxMessageSize() } -// SupportedMessages implements lisafs.ServerImpl.SupportedMessages. -func (s *LisafsServer) SupportedMessages() []lisafs.MID { +// SupportedMessages implements lisafs.ConnectionImpl.SupportedMessages. +func (i *stockConnectionImpl) SupportedMessages() []lisafs.MID { // Note that Flush is not supported. return []lisafs.MID{ lisafs.Mount, diff --git a/runsc/fsgofer/lisafs_test.go b/runsc/fsgofer/lisafs_test.go index 7a091794d4..4bae812615 100644 --- a/runsc/fsgofer/lisafs_test.go +++ b/runsc/fsgofer/lisafs_test.go @@ -37,8 +37,9 @@ func init() { type tester struct{} // NewServer implements testsuite.Tester.NewServer. -func (tester) NewServer(t *testing.T) *lisafs.Server { - return &fsgofer.NewLisafsServer(fsgofer.Config{}).Server +func (tester) NewServer(t *testing.T) (*lisafs.Server, lisafs.ConnectionImpl) { + server := fsgofer.NewLisafsServer(fsgofer.Config{}) + return &server.Server, server.ConnectionImpl() } // LinkSupported implements testsuite.Tester.LinkSupported.