diff --git a/7z.go b/7z.go index 2e50203..c81494a 100644 --- a/7z.go +++ b/7z.go @@ -8,6 +8,7 @@ import ( "io" "io/fs" "log" + "path/filepath" "strings" "github.com/bodgit/sevenzip" @@ -37,7 +38,7 @@ func (z SevenZip) Match(_ context.Context, filename string, stream io.Reader) (M var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), z.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == z.Extension() { mr.ByName = true } diff --git a/brotli.go b/brotli.go index 9fb5a17..a832f20 100644 --- a/brotli.go +++ b/brotli.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "unicode/utf8" @@ -26,7 +27,7 @@ func (br Brotli) Match(ctx context.Context, filename string, stream io.Reader) ( var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), br.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == br.Extension() { mr.ByName = true } diff --git a/bz2.go b/bz2.go index ff7bb3d..cc3a751 100644 --- a/bz2.go +++ b/bz2.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "github.com/dsnet/compress/bzip2" @@ -25,7 +26,7 @@ func (bz Bz2) Match(_ context.Context, filename string, stream io.Reader) (Match var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), bz.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == bz.Extension() { mr.ByName = true } diff --git a/formats_test.go b/formats_test.go index 8220245..927b689 100644 --- a/formats_test.go +++ b/formats_test.go @@ -114,6 +114,44 @@ func checkErr(t *testing.T, err error, msgFmt string, args ...any) { t.Fatalf(msgFmt+": %s", args...) } +func TestIdentifyFindFormatByFileName(t *testing.T) { + tests := []struct { + filename string + expected string + }{ + { + filename: "test.tar", + expected: ".tar", + }, + { + filename: "test.tar.bz2", + expected: ".tar.bz2", + }, + { + filename: "test.tar.br", + expected: ".tar.br", + }, + { + filename: "test.tar.bru", + expected: ".tar", + }, + { + filename: "test.7z", + expected: ".7z", + }, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + format, _, err := Identify(context.Background(), tt.filename, nil) + checkErr(t, err, "identifying") + if format.Extension() != tt.expected { + t.Errorf("unexpected extension: %v, expected: %v", format.Extension(), tt.expected) + } + }) + } +} + func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testing.T) { // Using the outcome of `n, err := io.ReadFull(stream, buf)` without minding n // may lead to a mis-characterization for cases with known header ending with 0x0 diff --git a/gz.go b/gz.go index adbf1ed..e0b9d99 100644 --- a/gz.go +++ b/gz.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "github.com/klauspost/compress/gzip" @@ -37,7 +38,7 @@ func (gz Gz) Match(_ context.Context, filename string, stream io.Reader) (MatchR var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), gz.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == gz.Extension() { mr.ByName = true } diff --git a/lz4.go b/lz4.go index 39fce3c..d5f2b86 100644 --- a/lz4.go +++ b/lz4.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "github.com/pierrec/lz4/v4" @@ -25,7 +26,7 @@ func (lz Lz4) Match(_ context.Context, filename string, stream io.Reader) (Match var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), lz.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == lz.Extension() { mr.ByName = true } diff --git a/minlz.go b/minlz.go index 72aede0..b00a30c 100644 --- a/minlz.go +++ b/minlz.go @@ -27,7 +27,7 @@ func (mz MinLZ) Match(_ context.Context, filename string, stream io.Reader) (Mat var mr MatchResult // match filename - if filepath.Ext(strings.ToLower(filename)) == ".mz" { + if filepath.Ext(strings.ToLower(filename)) == mz.Extension() { mr.ByName = true } diff --git a/rar.go b/rar.go index 388ecab..9669780 100644 --- a/rar.go +++ b/rar.go @@ -10,6 +10,7 @@ import ( "log" "os" "path" + "path/filepath" "strings" "time" @@ -60,7 +61,7 @@ func (r Rar) Match(_ context.Context, filename string, stream io.Reader) (MatchR var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), r.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == r.Extension() { mr.ByName = true } diff --git a/sz.go b/sz.go index bb23f21..bbad8f1 100644 --- a/sz.go +++ b/sz.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "github.com/klauspost/compress/s2" @@ -51,8 +52,8 @@ func (sz Sz) Match(_ context.Context, filename string, stream io.Reader) (MatchR var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), sz.Extension()) || - strings.Contains(strings.ToLower(filename), ".s2") { + if filepath.Ext(strings.ToLower(filename)) == sz.Extension() || + filepath.Ext(strings.ToLower(filename)) == ".s2" { mr.ByName = true } diff --git a/xz.go b/xz.go index e213bae..717d477 100644 --- a/xz.go +++ b/xz.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" fastxz "github.com/mikelolasagasti/xz" @@ -24,7 +25,7 @@ func (x Xz) Match(_ context.Context, filename string, stream io.Reader) (MatchRe var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), x.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == x.Extension() { mr.ByName = true } diff --git a/zip.go b/zip.go index be0b7bc..630bd6c 100644 --- a/zip.go +++ b/zip.go @@ -10,6 +10,7 @@ import ( "log" "os" "path" + "path/filepath" "strings" szip "github.com/STARRY-S/zip" @@ -85,7 +86,7 @@ func (z Zip) Match(_ context.Context, filename string, stream io.Reader) (MatchR var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), z.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == z.Extension() { mr.ByName = true } diff --git a/zlib.go b/zlib.go index 9ee64f4..3730e48 100644 --- a/zlib.go +++ b/zlib.go @@ -3,6 +3,7 @@ package archives import ( "context" "io" + "path/filepath" "strings" "github.com/klauspost/compress/zlib" @@ -24,7 +25,7 @@ func (zz Zlib) Match(_ context.Context, filename string, stream io.Reader) (Matc var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), zz.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == zz.Extension() { mr.ByName = true } diff --git a/zstd.go b/zstd.go index c36c6b9..e5f717a 100644 --- a/zstd.go +++ b/zstd.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "path/filepath" "strings" "github.com/klauspost/compress/zstd" @@ -26,7 +27,7 @@ func (zs Zstd) Match(_ context.Context, filename string, stream io.Reader) (Matc var mr MatchResult // match filename - if strings.Contains(strings.ToLower(filename), zs.Extension()) { + if filepath.Ext(strings.ToLower(filename)) == zs.Extension() { mr.ByName = true }