diff --git a/docs/plugins/library/archiver.md b/docs/plugins/library/archiver.md index 0db914e7..51642689 100644 --- a/docs/plugins/library/archiver.md +++ b/docs/plugins/library/archiver.md @@ -1,6 +1,6 @@ # Archiver Library -`vfox` provides a decompression tool that supports `tar.gz`, `tgz`, `tar.xz`, `zip`, and `7z`. In Lua scripts, you can +`vfox` provides a decompression tool that supports `tar.gz`, `tgz`, `tar.xz`, `tar.zst`, `tzst`, `zip`, and `7z`. In Lua scripts, you can use `require("vfox.archiver")` to access it. **Usage** @@ -8,4 +8,4 @@ use `require("vfox.archiver")` to access it. ```lua local archiver = require("vfox.archiver") local err = archiver.decompress("testdata/test.zip", "testdata/test") -``` \ No newline at end of file +``` diff --git a/docs/zh-hans/plugins/library/archiver.md b/docs/zh-hans/plugins/library/archiver.md index f14e60ce..e1f8a3a2 100644 --- a/docs/zh-hans/plugins/library/archiver.md +++ b/docs/zh-hans/plugins/library/archiver.md @@ -1,10 +1,11 @@ # Archiver 标准库 -`vfox` 提供了解压工具, 支持`tar.gz`、`tgz`、`tar.xz`、`zip`、`7z`。在Lua脚本中,你可以使用`require("vfox.archiver")`来访问它。 +`vfox` 提供了解压工具, 支持`tar.gz`、`tgz`、`tar.xz`、`tar.zst`、`tzst`、`zip`、`7z`。在Lua脚本中,你可以使用`require("vfox.archiver")`来访问它。 例如: **Usage** -```shell + +```lua local archiver = require("vfox.archiver") local err = archiver.decompress("testdata/test.zip", "testdata/test") -``` \ No newline at end of file +``` diff --git a/go.mod b/go.mod index 31ded9d8..90a4d583 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/Microsoft/go-winio v0.6.2 github.com/PuerkitoBio/goquery v1.9.3 github.com/bodgit/sevenzip v1.5.1 + github.com/klauspost/compress v1.17.7 github.com/lithammer/fuzzysearch v1.1.8 github.com/pterm/pterm v0.12.79 github.com/schollz/progressbar/v3 v3.14.2 @@ -36,7 +37,6 @@ require ( github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect - github.com/klauspost/compress v1.17.7 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect diff --git a/internal/shared/util/decompressor.go b/internal/shared/util/decompressor.go index 6a057ccc..6953f0ec 100644 --- a/internal/shared/util/decompressor.go +++ b/internal/shared/util/decompressor.go @@ -29,6 +29,7 @@ import ( "strings" "github.com/bodgit/sevenzip" + "github.com/klauspost/compress/zstd" "github.com/ulikunitz/xz" ) @@ -262,6 +263,159 @@ loop: return nil } +type ZstdTarDecompressor struct { + src string +} + +func (z *ZstdTarDecompressor) Decompress(dest string) error { + rootFolderInTar := findRootFolderInZstdTar(z.src) + file, err := os.Open(z.src) + if err != nil { + return err + } + defer file.Close() + + zr, err := zstd.NewReader(file) + if err != nil { + return err + } + defer zr.Close() + + tr := tar.NewReader(zr) + var symlinks []symlink + +loop: + for { + header, err := tr.Next() + switch { + case err == io.EOF: + break loop + case err != nil: + return err + case header == nil: + continue + } + + target, err := safeZstdTarTarget(dest, header.Name, rootFolderInTar) + if err != nil { + return err + } + + switch header.Typeflag { + case tar.TypeDir: + if _, err := os.Stat(target); err != nil { + if err := os.MkdirAll(target, 0755); err != nil { + return err + } + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return err + } + f, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) + if err != nil { + return err + } + if _, err := io.Copy(f, tr); err != nil { + _ = f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + case tar.TypeSymlink: + symlinks = append(symlinks, symlink{header.Linkname, target}) + } + } + + for _, s := range symlinks { + dir := filepath.Dir(s.newname) + if _, err := os.Stat(dir); os.IsNotExist(err) { + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + } + if err = os.Symlink(s.oldname, s.newname); err != nil { + return err + } + } + return nil +} + +func findRootFolderInZstdTar(tarFilePath string) string { + file, err := os.Open(tarFilePath) + if err != nil { + return "" + } + defer file.Close() + + zr, err := zstd.NewReader(file) + if err != nil { + return "" + } + defer zr.Close() + + tr := tar.NewReader(zr) + var firstElement string + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil || header == nil { + return "" + } + + normalizedPath := strings.Trim(strings.ReplaceAll(header.Name, "\\", "/"), "/") + if normalizedPath == "" || strings.HasPrefix(normalizedPath, ".DS_Store") || strings.HasPrefix(normalizedPath, "__MACOSX") { + continue + } + + currentFirstElement := strings.Split(normalizedPath, "/")[0] + if firstElement != "" && firstElement != currentFirstElement { + return "" + } + if firstElement == "" { + firstElement = currentFirstElement + } + } + return firstElement +} + +func safeZstdTarTarget(dest string, name string, rootFolderInTar string) (string, error) { + normalizedPath := strings.ReplaceAll(name, "\\", "/") + if strings.HasPrefix(normalizedPath, "/") { + return "", fmt.Errorf("archive entry %q is outside destination", name) + } + normalizedPath = strings.Trim(normalizedPath, "/") + if normalizedPath == "" { + return "", fmt.Errorf("archive entry %q is empty", name) + } + + parts := strings.Split(normalizedPath, "/") + if len(parts) > 1 && rootFolderInTar != "" && parts[0] == rootFolderInTar { + parts = parts[1:] + } + fname := filepath.Clean(strings.Join(parts, "/")) + if fname == "." { + return "", fmt.Errorf("archive entry %q is empty", name) + } + if !filepath.IsLocal(fname) { + return "", fmt.Errorf("archive entry %q is outside destination", name) + } + + target := filepath.Join(dest, fname) + rel, err := filepath.Rel(dest, target) + if err != nil { + return "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("archive entry %q is outside destination", name) + } + return target, nil +} + type ZipDecompressor struct { src string } @@ -525,6 +679,11 @@ func NewDecompressor(src string) Decompressor { src: src, } } + if strings.HasSuffix(filename, ".tar.zst") || strings.HasSuffix(filename, ".tzst") { + return &ZstdTarDecompressor{ + src: src, + } + } if strings.HasSuffix(filename, ".zip") { return &ZipDecompressor{ src: src, diff --git a/internal/shared/util/decompressor_test.go b/internal/shared/util/decompressor_test.go index b2ab93ed..d6c04057 100644 --- a/internal/shared/util/decompressor_test.go +++ b/internal/shared/util/decompressor_test.go @@ -17,7 +17,13 @@ package util import ( + "archive/tar" + "os" + "path/filepath" + "strings" "testing" + + "github.com/klauspost/compress/zstd" ) func TestNewDecompressor(t *testing.T) { @@ -31,12 +37,119 @@ func TestNewDecompressor(t *testing.T) { t.Errorf("Expected ZipDecompressor, got %T", zipDecompressor) } + zstdTarDecompressor := NewDecompressor("test.tar.zst") + if _, ok := zstdTarDecompressor.(*ZstdTarDecompressor); !ok { + t.Errorf("Expected ZstdTarDecompressor, got %T", zstdTarDecompressor) + } + + tzstDecompressor := NewDecompressor("test.tzst") + if _, ok := tzstDecompressor.(*ZstdTarDecompressor); !ok { + t.Errorf("Expected ZstdTarDecompressor, got %T", tzstDecompressor) + } + unknownDecompressor := NewDecompressor("test.unknown") if unknownDecompressor != nil { t.Errorf("Expected nil, got %T", unknownDecompressor) } } +func TestZstdTarDecompressor(t *testing.T) { + tempDir := t.TempDir() + archivePath := filepath.Join(tempDir, "test.tar.zst") + dest := filepath.Join(tempDir, "dest") + body := "Hello, zstd!" + + writeZstdTar(t, archivePath, "test.txt", body) + + decompressor := NewDecompressor(archivePath) + if err := decompressor.Decompress(dest); err != nil { + t.Fatalf("Failed to decompress: %v", err) + } + + decompressedFile, err := os.ReadFile(filepath.Join(dest, "test.txt")) + if err != nil { + t.Fatal(err) + } + if strings.TrimSpace(string(decompressedFile)) != body { + t.Errorf("Expected %q, got %q", body, string(decompressedFile)) + } +} + +func TestZstdTarDecompressorStripsCommonRootFolder(t *testing.T) { + tempDir := t.TempDir() + archivePath := filepath.Join(tempDir, "test.tar.zst") + dest := filepath.Join(tempDir, "dest") + body := "Hello from root!" + + writeZstdTar(t, archivePath, "root/test.txt", body) + + decompressor := NewDecompressor(archivePath) + if err := decompressor.Decompress(dest); err != nil { + t.Fatalf("Failed to decompress: %v", err) + } + + decompressedFile, err := os.ReadFile(filepath.Join(dest, "test.txt")) + if err != nil { + t.Fatal(err) + } + if strings.TrimSpace(string(decompressedFile)) != body { + t.Errorf("Expected %q, got %q", body, string(decompressedFile)) + } +} + +func TestZstdTarDecompressorRejectsPathTraversal(t *testing.T) { + tempDir := t.TempDir() + archivePath := filepath.Join(tempDir, "test.tar.zst") + dest := filepath.Join(tempDir, "dest") + + writeZstdTar(t, archivePath, "root/../../evil.txt", "evil") + + decompressor := NewDecompressor(archivePath) + if err := decompressor.Decompress(dest); err == nil { + t.Fatal("Expected path traversal archive entry to fail") + } + if _, err := os.Stat(filepath.Join(tempDir, "evil.txt")); !os.IsNotExist(err) { + t.Fatalf("Expected no file outside destination, got err %v", err) + } +} + +func writeZstdTar(t *testing.T, archivePath string, name string, body string) { + t.Helper() + + file, err := os.Create(archivePath) + if err != nil { + t.Fatal(err) + } + + zw, err := zstd.NewWriter(file) + if err != nil { + t.Fatal(err) + } + tw := tar.NewWriter(zw) + + err = tw.WriteHeader(&tar.Header{ + Name: name, + Mode: 0600, + Size: int64(len(body)), + Typeflag: tar.TypeReg, + }) + if err != nil { + t.Fatal(err) + } + if _, err := tw.Write([]byte(body)); err != nil { + t.Fatal(err) + } + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := zw.Close(); err != nil { + t.Fatal(err) + } + if err := file.Close(); err != nil { + t.Fatal(err) + } +} + //func TestDecompress(t *testing.T) { // // Create a temporary directory for testing // tempDir, err := os.MkdirTemp("", "decompress_test")