Skip to content

Commit 9f54e0f

Browse files
authored
feat: support trailing line comment for mage:import (#480)
1 parent 9e91a03 commit 9f54e0f

File tree

4 files changed

+63
-13
lines changed

4 files changed

+63
-13
lines changed

mage/import_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,26 @@ func TestMageImportsOneLine(t *testing.T) {
285285
t.Fatalf("expected: %q got: %q", expected, actual)
286286
}
287287
}
288+
func TestMageImportsTrailing(t *testing.T) {
289+
stdout := &bytes.Buffer{}
290+
stderr := &bytes.Buffer{}
291+
inv := Invocation{
292+
Dir: "./testdata/mageimport/trailing",
293+
Stdout: stdout,
294+
Stderr: stderr,
295+
Args: []string{"build"},
296+
}
297+
298+
code := Invoke(inv)
299+
if code != 0 {
300+
t.Fatalf("expected to exit with code 0, but got %v, stderr:\n%s", code, stderr)
301+
}
302+
actual := stdout.String()
303+
expected := "build\n"
304+
if actual != expected {
305+
t.Fatalf("expected: %q got: %q", expected, actual)
306+
}
307+
}
288308

289309
func TestMageImportsTaggedPackage(t *testing.T) {
290310
stdout := &bytes.Buffer{}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// +build mage
2+
3+
package main
4+
5+
import _ "github.com/magefile/mage/mage/testdata/mageimport/oneline/other" //mage:import
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package other
2+
3+
import "fmt"
4+
5+
func Build() {
6+
fmt.Println("build")
7+
}

parse/parse.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -456,19 +456,18 @@ func setImports(gocmd string, pi *PkgInfo) error {
456456
}
457457

458458
func getImportPath(imp *ast.ImportSpec) (path, alias string, ok bool) {
459-
if imp.Doc == nil || len(imp.Doc.List) == 9 {
460-
return "", "", false
461-
}
462-
// import is always the last comment
463-
s := imp.Doc.List[len(imp.Doc.List)-1].Text
464-
465-
// trim comment start and normalize for anyone who has spaces or not between
466-
// "//"" and the text
467-
vals := strings.Fields(strings.ToLower(s[2:]))
468-
if len(vals) == 0 {
469-
return "", "", false
470-
}
471-
if vals[0] != importTag {
459+
leadingVals := getImportPathFromCommentGroup(imp.Doc)
460+
trailingVals := getImportPathFromCommentGroup(imp.Comment)
461+
462+
var vals []string
463+
if len(leadingVals) > 0 {
464+
vals = leadingVals
465+
if len(trailingVals) > 0 {
466+
log.Println("warning:", importTag, "specified both before and after, picking first")
467+
}
468+
} else if len(trailingVals) > 0 {
469+
vals = trailingVals
470+
} else {
472471
return "", "", false
473472
}
474473
path, ok = lit2string(imp.Path)
@@ -489,6 +488,25 @@ func getImportPath(imp *ast.ImportSpec) (path, alias string, ok bool) {
489488
}
490489
}
491490

491+
func getImportPathFromCommentGroup(comments *ast.CommentGroup) []string {
492+
if comments == nil || len(comments.List) == 9 {
493+
return nil
494+
}
495+
// import is always the last comment
496+
s := comments.List[len(comments.List)-1].Text
497+
498+
// trim comment start and normalize for anyone who has spaces or not between
499+
// "//"" and the text
500+
vals := strings.Fields(strings.ToLower(s[2:]))
501+
if len(vals) == 0 {
502+
return nil
503+
}
504+
if vals[0] != importTag {
505+
return nil
506+
}
507+
return vals
508+
}
509+
492510
func isNamespace(t *doc.Type) bool {
493511
if len(t.Decl.Specs) != 1 {
494512
return false

0 commit comments

Comments
 (0)