diff --git a/cmd/status.go b/cmd/status.go index 273f17dc..80fd5001 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -300,7 +300,7 @@ func getCurrentlyBootedPartition(a *core.ABRootManager) (string, string, error) if err != nil { return "", "", err } - defer bootPart.Unmount() + defer core.UnmountRecursive(tmpBootMount, 0) g, err := core.NewGrub(bootPart) if err != nil { diff --git a/core/chroot.go b/core/chroot.go index 59b0e7b1..1c61121a 100644 --- a/core/chroot.go +++ b/core/chroot.go @@ -93,32 +93,12 @@ func NewChroot(root string, rootUuid string, rootDevice string, mountUserEtc boo func (c *Chroot) Close() error { PrintVerboseInfo("Chroot.Close", "running...") - err := syscall.Unmount(filepath.Join(c.root, "/dev/pts"), 0) + err := UnmountRecursive(c.root, 0) if err != nil { PrintVerboseErr("Chroot.Close", 0, err) return err } - mountList := ReservedMounts - if c.etcMounted { - mountList = append(mountList, "/etc") - } - mountList = append(mountList, "") - - for _, mount := range mountList { - if mount == "/dev/pts" { - continue - } - - mountDir := filepath.Join(c.root, mount) - PrintVerboseInfo("Chroot.Close", "unmounting", mountDir) - err := syscall.Unmount(mountDir, 0) - if err != nil { - PrintVerboseErr("Chroot.Close", 1, err) - return err - } - } - PrintVerboseInfo("Chroot.Close", "successfully closed.") return nil } diff --git a/core/disk-manager.go b/core/disk-manager.go index 88325075..798dc2a1 100644 --- a/core/disk-manager.go +++ b/core/disk-manager.go @@ -14,13 +14,16 @@ package core */ import ( + "bufio" "encoding/json" - "errors" "fmt" "os" "os/exec" + "path/filepath" "strings" "syscall" + + "golang.org/x/sys/unix" ) // DiskManager exposes functions to interact with the system's disks @@ -192,33 +195,102 @@ func (p *Partition) Mount(destination string) error { return nil } -// Unmount unmounts a partition -func (p *Partition) Unmount() error { - PrintVerboseInfo("Partition.Unmount", "running...") +// Returns whether the partition is a device-mapper virtual partition +func (p *Partition) IsDevMapper() bool { + return p.Parent != nil +} + +// IsEncrypted returns whether the partition is encrypted +func (p *Partition) IsEncrypted() bool { + return strings.HasPrefix(p.FsType, "crypto_") +} - if p.MountPoint == "" { - PrintVerboseErr("Partition.Unmount", 0, errors.New("no mount point")) - return errors.New("no mount point") +func UnmountRecursive(mountPoint string, flags int) error { + mountPointOld := mountPoint + mountPoint, err := filepath.EvalSymlinks(mountPoint) + if err != nil { + return fmt.Errorf("could not find real path for %s: %w", mountPointOld, err) } - err := syscall.Unmount(p.MountPoint, 0) + systemMountpoints, err := readMountPoints() if err != nil { - PrintVerboseErr("Partition.Unmount", 1, err) + return fmt.Errorf("Could not load system mounts: %w", err) + } + + mountId := "" + + for id, systemMount := range systemMountpoints { + if systemMount.mountPoint == mountPoint { + mountId = id + } + } + + if mountId == "" { + PrintVerboseInfo("Partition.UnmountRecursive", "umounting "+mountPoint) + err := unix.Unmount(mountPoint, flags) + PrintVerboseErr("Partition.UnmountRecursive", 1, err) return err } - PrintVerboseInfo("Partition.Unmount", "successfully unmounted", p.MountPoint) - p.MountPoint = "" + err = unmountRecursive(mountId, systemMountpoints, flags, 0) + if err != nil { + newErr := fmt.Errorf("could not recursively unmount %s: %w", mountPoint, err) + PrintVerboseErr("Partition.UnmountRecursive", 1, newErr) + return newErr + } return nil } -// Returns whether the partition is a device-mapper virtual partition -func (p *Partition) IsDevMapper() bool { - return p.Parent != nil +func unmountRecursive(mountPointId string, systemMountpoints map[string]mount, flags int, depth int) error { + if depth >= 1000 { + return fmt.Errorf("too many layers when trying to recursively unmount") + } + + mount, ok := systemMountpoints[mountPointId] + if !ok { + return fmt.Errorf("could not find mountpoint with id %s", mountPointId) + } + + for childMountId, childMount := range systemMountpoints { + if childMount.parentId == mountPointId { + err := unmountRecursive(childMountId, systemMountpoints, flags, depth+1) + if err != nil { + return fmt.Errorf("could not unmount %s: %w", childMount.mountPoint, err) + } + } + } + + PrintVerboseInfo("Partition.UnmountRecursive", "umounting "+mount.mountPoint) + return unix.Unmount(mount.mountPoint, flags) } -// IsEncrypted returns whether the partition is encrypted -func (p *Partition) IsEncrypted() bool { - return strings.HasPrefix(p.FsType, "crypto_") +type mount struct { + id string + parentId string + mountPoint string +} + +func readMountPoints() (map[string]mount, error) { + f, err := os.Open("/proc/self/mountinfo") + if err != nil { + return nil, err + } + defer f.Close() + + mounts := make(map[string]mount) + scanner := bufio.NewScanner(f) + + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 5 { + continue + } + + id := fields[0] + + mounts[id] = mount{id: id, parentId: fields[1], mountPoint: fields[4]} + } + + return mounts, scanner.Err() } diff --git a/core/system.go b/core/system.go index a010c7b6..6f230f4d 100644 --- a/core/system.go +++ b/core/system.go @@ -272,10 +272,9 @@ func (s *ABSystem) RunOperation(operation ABSystemOperation, deleteBeforeCopy bo return err } - partFuture.Partition.Unmount() // just in case - partBoot.Unmount() - futureRoot := "/part-future" + + UnmountRecursive(futureRoot, 0) err = partFuture.Partition.Mount(futureRoot) if err != nil { PrintVerboseErr("ABSystem.RunOperation", 2.3, err) @@ -283,7 +282,7 @@ func (s *ABSystem) RunOperation(operation ABSystemOperation, deleteBeforeCopy bo } cq.Add(func(args ...interface{}) error { - return partFuture.Partition.Unmount() + return UnmountRecursive(futureRoot, 0) }, nil, 90, &goodies.NoErrorHandler{}, false) // Stage 3: Make a imageRecipe with user packages @@ -537,7 +536,7 @@ func (s *ABSystem) RunOperation(operation ABSystemOperation, deleteBeforeCopy bo } cq.Add(func(args ...interface{}) error { - return initPartition.Unmount() + return UnmountRecursive(initMountpoint, 0) }, nil, 80, &goodies.NoErrorHandler{}, false) futureInitDir := filepath.Join(initMountpoint, partFuture.Label) @@ -645,7 +644,7 @@ func (s *ABSystem) RunOperation(operation ABSystemOperation, deleteBeforeCopy bo } cq.Add(func(args ...interface{}) error { - return partBoot.Unmount() + return UnmountRecursive(tmpBootMount, 0) }, nil, 100, &goodies.NoErrorHandler{}, false) // Stage 9: Atomic swap the bootloader @@ -758,7 +757,7 @@ func (s *ABSystem) Rollback(checkOnly bool) (response ABRollbackResponse, err er } cq.Add(func(args ...interface{}) error { - return partBoot.Unmount() + return UnmountRecursive(tmpBootMount, 0) }, nil, 100, &goodies.NoErrorHandler{}, false) grub, err := NewGrub(partBoot)