diff --git a/Makefile b/Makefile index d0712b81..1b0eea89 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ SAMPLES_REPO ?= chainguard-dev/malcontent-samples -SAMPLES_COMMIT ?= 528a7e975638d2c5ce06da1af32c5918aa4d6c7e +SAMPLES_COMMIT ?= 2bd3bff19c0253821b3886db65a5059587cac893 # BEGIN: lint-install ../malcontent # http://github.com/tinkerbell/lint-install diff --git a/pkg/action/archive.go b/pkg/action/archive.go index 8c637a9f..c94ac9e4 100644 --- a/pkg/action/archive.go +++ b/pkg/action/archive.go @@ -48,6 +48,11 @@ func isSupportedArchive(path string) bool { return archiveMap[getExt(path)] } +// isValidPath checks if the target file is within the given directory. +func isValidPath(target, dir string) bool { + return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir)) +} + // getExt returns the extension of a file path // and attempts to avoid including fragments of filenames with other dots before the extension. func getExt(path string) string { @@ -163,8 +168,8 @@ func extractTar(ctx context.Context, d string, f string) error { } target := filepath.Join(d, clean) - if !strings.HasPrefix(target, filepath.Clean(d)+string(os.PathSeparator)) { - return fmt.Errorf("invalid file path: %s", header.Name) + if !isValidPath(target, d) { + return fmt.Errorf("invalid file path: %s", target) } switch header.Typeflag { @@ -206,7 +211,7 @@ func extractTar(ctx context.Context, d string, f string) error { if err != nil { return fmt.Errorf("failed to evaluate symlink: %w", err) } - if !strings.HasPrefix(linkReal, filepath.Clean(d)+string(os.PathSeparator)) { + if !isValidPath(target, d) { return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) } if err := os.Symlink(linkReal, target); err != nil { @@ -281,8 +286,8 @@ func extractZip(ctx context.Context, d string, f string) error { } name := filepath.Join(d, clean) - if !strings.HasPrefix(name, filepath.Clean(d)+string(os.PathSeparator)) { - logger.Warnf("skipping file path outside extraction directory: %s", file.Name) + if !isValidPath(name, d) { + logger.Warnf("skipping file path outside extraction directory: %s", name) continue } @@ -494,7 +499,7 @@ func extractDeb(ctx context.Context, d, f string) error { if err != nil { return fmt.Errorf("failed to evaluate symlink: %w", err) } - if !strings.HasPrefix(linkReal, filepath.Clean(d)+string(os.PathSeparator)) { + if !isValidPath(linkReal, d) { return fmt.Errorf("symlink points outside temporary directory: %s", linkReal) } if err := os.Symlink(linkReal, target); err != nil { diff --git a/pkg/action/archive_test.go b/pkg/action/archive_test.go index 930b2ae5..571f903e 100644 --- a/pkg/action/archive_test.go +++ b/pkg/action/archive_test.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "github.com/chainguard-dev/clog" @@ -330,3 +331,81 @@ func TestGetExt(t *testing.T) { }) } } + +func TestIsValidPath(t *testing.T) { + tmpRoot, err := os.MkdirTemp("", "isValidPath-*") + if err != nil { + t.Fatalf("Failed to create temp base directory: %v", err) + } + defer os.RemoveAll(tmpRoot) + + tempSubDir, err := os.MkdirTemp(tmpRoot, "isValidPathSub-*") + if err != nil { + t.Fatalf("Failed to create temp sub directory: %v", err) + } + + tests := []struct { + name string + target string + baseDir string + expected bool + }{ + { + name: "Valid direct child path", + target: filepath.Join(tmpRoot, "file.txt"), + baseDir: tmpRoot, + expected: true, + }, + { + name: "Valid nested path", + target: filepath.Join(tempSubDir, "file.txt"), + baseDir: tmpRoot, + expected: true, + }, + { + name: "Invalid parent directory traversal", + target: filepath.Join(tmpRoot, "../file.txt"), + baseDir: tmpRoot, + expected: false, + }, + { + name: "Invalid absolute path outside base", + target: "/etc/passwd", + baseDir: tmpRoot, + expected: false, + }, + { + name: "Invalid relative path outside base", + target: "../../etc/passwd", + baseDir: tmpRoot, + expected: false, + }, + { + name: "Empty target path", + target: "", + baseDir: tmpRoot, + expected: false, + }, + { + name: "Empty base directory", + target: filepath.Join(tmpRoot, "file.txt"), + baseDir: "", + expected: false, + }, + { + name: "Path with irregular separators", + target: strings.ReplaceAll(filepath.Join(tmpRoot, "file.txt"), "/", "//"), + baseDir: tmpRoot, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidPath(tt.target, tt.baseDir) + if result != tt.expected { + t.Errorf("isValidPath(%q, %q) = %v, want %v", tt.target, tt.baseDir, result, tt.expected) + } + }) + } +}