Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 90 additions & 11 deletions cmd/image/image.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package image

import (
"crypto/sha512"
"errors"
"fmt"
"log/slog"

Expand All @@ -17,6 +19,11 @@ import (
"time"
)

const (
MD5 = ".md5"
SHA512sum = ".sha512sum"
)

type Image struct {
log *slog.Logger
}
Expand All @@ -25,24 +32,15 @@ func NewImage(log *slog.Logger) *Image {
return &Image{log: log}
}

// Pull a image from s3
func (i *Image) Pull(image, destination string) error {
i.log.Info("pull image", "image", image)
md5destination := destination + ".md5"
md5file := image + ".md5"
err := i.download(image, destination)
if err != nil {
return fmt.Errorf("unable to pull image %s %w", image, err)
}
err = i.download(md5file, md5destination)
defer os.Remove(md5destination)
err = i.verifyChecksumFile(image, destination)
if err != nil {
return fmt.Errorf("unable to pull md5 %s %w", md5file, err)
}
i.log.Info("check md5")
matches, err := i.checkMD5(destination, md5destination)
if err != nil || !matches {
return fmt.Errorf("md5sum mismatch")
return fmt.Errorf("unable to verify checksum file for image %s %w", image, err)
}

i.log.Info("pull image done", "image", image)
Expand Down Expand Up @@ -132,6 +130,33 @@ func (i *Image) checkMD5(file, md5file string) (bool, error) {
return true, nil
}

func (i *Image) checksha512(file, sha512file string) (bool, error) {
sha512fileContent, err := os.ReadFile(sha512file)
if err != nil {
return false, fmt.Errorf("unable to read sha512sum file %s %w", sha512file, err)
}
expectedsha512 := strings.Split(string(sha512fileContent), " ")[0]

f, err := os.Open(file)
if err != nil {
return false, fmt.Errorf("unable to read file: %s %w", file, err)
}
defer f.Close()

//nolint:gosec
Comment thread
simcod marked this conversation as resolved.
Outdated
h := sha512.New()
if _, err := io.Copy(h, f); err != nil {
return false, fmt.Errorf("unable to calculate sha512 of file: %s %w", file, err)
}
sourcesha512 := fmt.Sprintf("%x", h.Sum(nil))
i.log.Info("check sha512", "source sha512", sourcesha512, "expected sha512", expectedsha512)
if sourcesha512 != expectedsha512 {
return false, fmt.Errorf("source sha512:%s expected sha512:%s", sourcesha512, expectedsha512)
}
return true, nil

}

// downloadFile will download from a source url to a local file dest.
// It's efficient because it will write as it downloads
// and not load the whole file into memory.
Expand Down Expand Up @@ -171,3 +196,57 @@ func (i *Image) download(source, dest string) error {

return nil
}

// verifyChecksumFile checks for the existence of .md5 and .sha512sum files
func (i *Image) verifyChecksumFile(image, destination string) error {
var (
matches bool
err error
)

hash := getHash(image)
hashFile := image + hash

hashDestination := destination + hash
err = i.download(hashFile, hashDestination)
defer os.Remove(hashDestination)
if err != nil {
return fmt.Errorf("unable to pull %s %s %w", hash, hashFile, err)
}
i.log.Info("check", "hash", hash)
switch hash {
case SHA512sum:
matches, err = i.checksha512(hashFile, hashDestination)
case MD5:
matches, err = i.checkMD5(hashFile, hashDestination)
default:
return fmt.Errorf("no checksum file found for %s", image)
}
if err != nil || !matches {
return fmt.Errorf("hash mismatch")
}
return nil
}

func getHash(image string) string {
hashFiles := map[string]string{
SHA512sum: image + SHA512sum,
MD5: image + MD5,
}

for hashType, filePath := range hashFiles {
if fileExists(filePath) {
return hashType
}
}

return ""
}

// fileExists checks if a file exists at the given path
func fileExists(path string) bool {
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
return false
}
return true
}
33 changes: 33 additions & 0 deletions cmd/image/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,36 @@ func TestCheckMD5(t *testing.T) {
}

}

func TestCheckSHA512(t *testing.T) {
testfile := "/tmp/testsha512"
testfileSHA512 := "/tmp/testsha512.sha512sum"
content := []byte("This is testcontent")
err := os.WriteFile(testfile, content, os.ModePerm) // nolint:gosec
if err != nil {
t.Error(err)
}
cmd := exec.Command("sha512sum", testfile)
sha512Content, err := cmd.Output()
if err != nil {
t.Error(err)
}
sha512, err := os.Create(testfileSHA512)
if err != nil {
t.Error(err)
}
_, err = sha512.Write(sha512Content)
if err != nil {
t.Error(err)
}
sha512.Close()
defer os.Remove(testfile)
defer os.Remove(testfileSHA512)
matches, err := NewImage(slog.Default()).checksha512(testfile, testfileSHA512)
if err != nil {
t.Error(err)
}
if !matches {
t.Error("expected sha512 matches, but didn't")
}
}
Loading