Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix prefix validation edge-case when extracting #715

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


SAMPLES_REPO ?= chainguard-dev/malcontent-samples
SAMPLES_COMMIT ?= 528a7e975638d2c5ce06da1af32c5918aa4d6c7e
SAMPLES_COMMIT ?= 2bd3bff19c0253821b3886db65a5059587cac893

# BEGIN: lint-install ../malcontent
# http://github.com/tinkerbell/lint-install
Expand Down
17 changes: 11 additions & 6 deletions pkg/action/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func isSupportedArchive(path string) bool {
return archiveMap[getExt(path)]
}

// isValidPath checks if the target file is within the given directory.
func isValidPath(target, dir string) bool {
Copy link
Member Author

@egibs egibs Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still works as expected:

$ go run cmd/mal/mal.go --all analyze out/chainguard-dev/malcontent-samples/linux/clean/gitlab-rails/iosswift.tar.gz 
πŸ”Ž Scanning "out/chainguard-dev/malcontent-samples/linux/clean/gitlab-rails/iosswift.tar.gz"
β”œβ”€ 🟑 out/chainguard-dev/malcontent-samples/linux/clean/gitlab-rails/iosswift.tar.gz ∴ /project.bundle [MEDIUM]
β”‚     ≑ credential [MEDIUM]
β”‚       🟑 sniffer/bpf β€” BPF (Berkeley Packet Filter): bpf
β”‚     ≑ networking [MEDIUM]
β”‚       🟑 tcp/ssh β€” Supports SSH (secure shell)
β”‚     ≑ process [LOW]
β”‚       πŸ”΅ chdir β€” changes working directory: cd B
β”‚
β”œβ”€ πŸ”΅ out/chainguard-dev/malcontent-samples/linux/clean/gitlab-rails/iosswift.tar.gz ∴ /tree/project.json [LOW]
β”‚     ≑ credential [LOW]
β”‚       πŸ”΅ password β€” references a 'password': require_password_to_approve
β”‚

return strings.HasPrefix(filepath.Clean(target), filepath.Clean(dir))
}

// getExt returns the extension of a file path
// and attempts to avoid including fragments of filenames with other dots before the extension.
func getExt(path string) string {
Expand Down Expand Up @@ -163,8 +168,8 @@ func extractTar(ctx context.Context, d string, f string) error {
}

target := filepath.Join(d, clean)
if !strings.HasPrefix(target, filepath.Clean(d)+string(os.PathSeparator)) {
return fmt.Errorf("invalid file path: %s", header.Name)
if !isValidPath(target, d) {
return fmt.Errorf("invalid file path: %s", target)
}

switch header.Typeflag {
Expand Down Expand Up @@ -206,7 +211,7 @@ func extractTar(ctx context.Context, d string, f string) error {
if err != nil {
return fmt.Errorf("failed to evaluate symlink: %w", err)
}
if !strings.HasPrefix(linkReal, filepath.Clean(d)+string(os.PathSeparator)) {
if !isValidPath(target, d) {
return fmt.Errorf("symlink points outside temporary directory: %s", linkReal)
}
if err := os.Symlink(linkReal, target); err != nil {
Expand Down Expand Up @@ -281,8 +286,8 @@ func extractZip(ctx context.Context, d string, f string) error {
}

name := filepath.Join(d, clean)
if !strings.HasPrefix(name, filepath.Clean(d)+string(os.PathSeparator)) {
logger.Warnf("skipping file path outside extraction directory: %s", file.Name)
if !isValidPath(name, d) {
logger.Warnf("skipping file path outside extraction directory: %s", name)
continue
}

Expand Down Expand Up @@ -494,7 +499,7 @@ func extractDeb(ctx context.Context, d, f string) error {
if err != nil {
return fmt.Errorf("failed to evaluate symlink: %w", err)
}
if !strings.HasPrefix(linkReal, filepath.Clean(d)+string(os.PathSeparator)) {
if !isValidPath(linkReal, d) {
return fmt.Errorf("symlink points outside temporary directory: %s", linkReal)
}
if err := os.Symlink(linkReal, target); err != nil {
Expand Down
79 changes: 79 additions & 0 deletions pkg/action/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"os"
"path/filepath"
"runtime"
"strings"
"testing"

"github.com/chainguard-dev/clog"
Expand Down Expand Up @@ -330,3 +331,81 @@
})
}
}

func TestIsValidPath(t *testing.T) {
tmpRoot, err := os.MkdirTemp("", "isValidPath-*")
if err != nil {
t.Fatalf("Failed to create temp base directory: %v", err)
}
defer os.RemoveAll(tmpRoot)

tempSubDir, err := os.MkdirTemp(tmpRoot, "isValidPathSub-*")
if err != nil {
t.Fatalf("Failed to create temp sub directory: %v", err)
}

tests := []struct {
name string
target string
baseDir string
expected bool
}{
{
name: "Valid direct child path",
target: filepath.Join(tmpRoot, "file.txt"),
baseDir: tmpRoot,
expected: true,
},
{
name: "Valid nested path",
target: filepath.Join(tempSubDir, "file.txt"),
baseDir: tmpRoot,
expected: true,
},
{
name: "Invalid parent directory traversal",
target: filepath.Join(tmpRoot, "../file.txt"),
baseDir: tmpRoot,
expected: false,
},
{
name: "Invalid absolute path outside base",
target: "/etc/passwd",
baseDir: tmpRoot,
expected: false,
},
{
name: "Invalid relative path outside base",
target: "../../etc/passwd",
baseDir: tmpRoot,
expected: false,
},
{
name: "Empty target path",
target: "",
baseDir: tmpRoot,
expected: false,
},
{
name: "Empty base directory",
target: filepath.Join(tmpRoot, "file.txt"),
baseDir: "",
expected: false,
},
{
name: "Path with irregular separators",
target: strings.Replace(filepath.Join(tmpRoot, "file.txt"), "/", "//", -1),

Check failure on line 397 in pkg/action/archive_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint

wrapperFunc: use strings.ReplaceAll method in `strings.Replace(filepath.Join(tmpRoot, "file.txt"), "/", "//", -1)` (gocritic)
baseDir: tmpRoot,
expected: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isValidPath(tt.target, tt.baseDir)
if result != tt.expected {
t.Errorf("isValidPath(%q, %q) = %v, want %v", tt.target, tt.baseDir, result, tt.expected)
}
})
}
}
Loading