Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 56 additions & 79 deletions lua/neotest-jest/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,46 @@ function adapter.build_position(file_path, source, captured_nodes)
end

---@type TSNode
local test_name_node = captured_nodes[match_type .. ".name"]
local name = vim.treesitter.get_node_text(test_name_node, source)
local node = captured_nodes[match_type .. ".name"]
local name = vim.treesitter.get_node_text(node, source)
local definition = captured_nodes[match_type .. ".definition"]
local type = node:type()
local nonStringNode = false

if type == "string" then
-- If the node is a string then strip the quotes from the name by getting
-- it's first named child (string_fragment). This works for single- and
-- double-quotes and is necessary since we match anything in the queries
-- used in discover_positions
local content = node:named_child(0)

if content then
name = vim.treesitter.get_node_text(content, source)
end
elseif type == "template_string" then
-- If the node is a template string then concatenate its named children
-- which is essentially the inner part of the backticks thus stripping
-- backticks. This is necessary since we match anything in the queries used
-- in discover_positions
local new_name = {}

for _, named_child in ipairs(node:named_children()) do
table.insert(new_name, vim.treesitter.get_node_text(named_child, source))
end

name = table.concat(new_name, "")
else
nonStringNode = true
end

return {
type = match_type,
path = file_path,
name = name,
range = { definition:range() },
-- Record the position of the line where the string name occurs
test_name_range = match_type == "test" and { test_name_node:range() } or nil,
is_parameterized = captured_nodes["each_property"] and true or false,
test_name_range = match_type == "test" and { node:range() } or nil,
is_parameterized = (captured_nodes["each_property"] or nonStringNode) and true or false,
}
end

Expand All @@ -100,61 +128,43 @@ function adapter.discover_positions(path)
; Matches: `describe('context', () => {})`
((call_expression
function: (identifier) @func_name (#eq? @func_name "describe")
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (arrow_function))
arguments: (arguments ((_) @namespace.name) (arrow_function))
)) @namespace.definition

; Matches: `describe('context', function() {})`
((call_expression
function: (identifier) @func_name (#eq? @func_name "describe")
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (function_expression))
arguments: (arguments ((_) @namespace.name) (function_expression))
)) @namespace.definition

; Matches: `describe('context', wrapper())`
((call_expression
function: (identifier) @func_name (#eq? @func_name "describe")
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (call_expression))
arguments: (arguments ((_) @namespace.name) (call_expression))
)) @namespace.definition

; Matches: `describe.only('context', () => {})`
((call_expression
function: (member_expression
object: (identifier) @func_name (#eq? @func_name "describe")
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (arrow_function))
arguments: (arguments ((_) @namespace.name) (arrow_function))
)) @namespace.definition

; Matches: `describe.only('context', function() {})`
((call_expression
function: (member_expression
object: (identifier) @func_name (#eq? @func_name "describe")
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (function_expression))
arguments: (arguments ((_) @namespace.name) (function_expression))
)) @namespace.definition

; Matches: `describe.only('context', wrapper())`
((call_expression
function: (member_expression
object: (identifier) @func_name (#eq? @func_name "describe")
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (call_expression))
arguments: (arguments ((_) @namespace.name) (call_expression))
)) @namespace.definition

; Matches: `describe.each(['data'])('context', () => {})`
Expand All @@ -165,10 +175,7 @@ function adapter.discover_positions(path)
property: (property_identifier) @each_property (#eq? @each_property "each")
)
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (arrow_function))
arguments: (arguments ((_) @namespace.name) (arrow_function))
)) @namespace.definition

; Matches: `describe.each(['data'])('context', function() {})`
Expand All @@ -179,10 +186,7 @@ function adapter.discover_positions(path)
property: (property_identifier) @each_property (#eq? @each_property "each")
)
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (function_expression))
arguments: (arguments ((_) @namespace.name) (function_expression))
)) @namespace.definition

; Matches: `describe.each(['data'])('context', wrapper())`
Expand All @@ -192,10 +196,7 @@ function adapter.discover_positions(path)
object: (identifier) @func_name (#eq? @func_name "describe")
)
)
arguments: (arguments ([
(string (string_fragment) @namespace.name)
(template_string (_) @namespace.name)
]) (call_expression))
arguments: (arguments ((_) @namespace.name) (call_expression))
)) @namespace.definition

; #########
Expand All @@ -205,61 +206,43 @@ function adapter.discover_positions(path)
; Matches: `it('test', () => {}) / test('test', () => {})`
((call_expression
function: (identifier) @func_name (#any-of? @func_name "it" "test")
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (arrow_function))
arguments: (arguments ((_) @test.name) (arrow_function))
)) @test.definition

; Matches: `it('test', function() {}) / test('test', function() {})`
((call_expression
function: (identifier) @func_name (#any-of? @func_name "it" "test")
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (function_expression))
arguments: (arguments ((_) @test.name) (function_expression))
)) @test.definition

; Matches: `it('test', wrapper()) / test('test', wrapper())`
((call_expression
function: (identifier) @func_name (#any-of? @func_name "it" "test")
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (call_expression))
arguments: (arguments ((_) @test.name) (call_expression))
)) @test.definition

; Matches: `test.only('test', () => {}) / it.only('test', () => {})`
((call_expression
function: (member_expression
object: (identifier) @func_name (#any-of? @func_name "test" "it")
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (arrow_function))
arguments: (arguments ((_) @test.name) (arrow_function))
)) @test.definition

; Matches: `test.only('test', function() {}) / it.only('test', function() {})`
((call_expression
function: (member_expression
object: (identifier) @func_name (#any-of? @func_name "test" "it")
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (function_expression))
arguments: (arguments ((_) @test.name) (function_expression))
)) @test.definition

; Matches: `test.only('test', wrapper()) / it.only('test', wrapper())`
((call_expression
function: (member_expression
object: (identifier) @func_name (#any-of? @func_name "test" "it")
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (call_expression))
arguments: (arguments ((_) @test.name) (call_expression))
)) @test.definition

; Matches: `test.each(['data'])('test', () => {}) / it.each(['data'])('test', () => {})`
Expand All @@ -270,10 +253,7 @@ function adapter.discover_positions(path)
property: (property_identifier) @each_property (#eq? @each_property "each")
)
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (arrow_function))
arguments: (arguments ((_) @test.name) (arrow_function))
)) @test.definition

; Matches: `test.each(['data'])('test', function() {}) / it.each(['data'])('test', function() {})`
Expand All @@ -284,10 +264,7 @@ function adapter.discover_positions(path)
property: (property_identifier) @each_property (#eq? @each_property "each")
)
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (function_expression))
arguments: (arguments ((_) @test.name) (function_expression))
)) @test.definition

; Matches: `test.each(['data'])('test', wrapper()) / it.each(['data'])('test', wrapper())`
Expand All @@ -298,10 +275,7 @@ function adapter.discover_positions(path)
property: (property_identifier) @each_property (#eq? @each_property "each")
)
)
arguments: (arguments ([
(string (string_fragment) @test.name)
(template_string (_) @test.name)
]) (call_expression))
arguments: (arguments ((_) @test.name) (call_expression))
)) @test.definition
]]

Expand Down Expand Up @@ -447,9 +421,12 @@ function adapter.build_spec(args)
local testNamePattern = ".*"

if pos.type == types.PositionType.test or pos.type == types.PositionType.namespace then
-- pos.id in form "path/to/file::Describe text::test text"
local testName = pos.id:sub(pos.id:find("::") + 2)
testName, _ = testName:gsub("::", " ")
-- Check if the position is a parametric (runtime) test by seeing if there is a corresponding
-- source-level test
local sourceLevelTest = parameterized_tests.getParametricTestToSourceLevelTest(pos)
local testName = sourceLevelTest or pos.id

testName, _ = testName:sub(pos.id:find("::") + 2):gsub("::", " ")
testNamePattern = util.escapeTestPattern(testName)

-- If the position or any of its enclosing blocks are parameterized, replace any
Expand Down
34 changes: 20 additions & 14 deletions lua/neotest-jest/parameterized-tests.lua
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
local lib = require("neotest.lib")
local jest_util = require("neotest-jest.jest-util")
local types = require("neotest.types")
local logger = require("neotest.logging")

local M = {}

---@class neotest-jest.RuntimeTestInfo
---@field pos_id string
---@field name string
---@field namespace_pos_id string
---@field namespace_name string

---@type table<string, table<string, string>>
local parametricTestToSourceLevelTest = {}

local JEST_PARAMETER_TYPES = {
"%%p",
Expand Down Expand Up @@ -207,13 +207,12 @@ function M.enrichPositionsWithParameterizedTests(file_path, parsed_parameterized
-- Get all runtime test information for path
local jest_test_discovery_output = runJestTestDiscovery(file_path)

logger.warn(jest_test_discovery_output)

if jest_test_discovery_output == nil then
return
end

local tests_by_position = getTestsByPosition(jest_test_discovery_output)
parametricTestToSourceLevelTest[file_path] = {}

-- For each parameterized test, find all tests that were in the same position
-- as it and add new range-less (range = nil) children to the tree
Expand Down Expand Up @@ -246,15 +245,6 @@ function M.enrichPositionsWithParameterizedTests(file_path, parsed_parameterized
for _, test_result in ipairs(parameterized_test_results_for_position) do
tryCreateNamespaceNodes(tree, test_result.pos_id)

-- Only create a new node if the test position has any test parameters
-- ('$param' or '%j') in the name. Otherwise, we would use a position
-- id that matches the source-level test name which would overwrite
-- the real position id in the tree.
--
-- There is no way for neotest-jest or jest to distinguish between
-- tests that share the same name anyway so not creating new nodes is
-- acceptable for now
-- if hasTestParameters(tree, pos) then
if not tree:get_key(test_result.pos_id) then
createNewChildNode(
tree,
Expand All @@ -270,6 +260,12 @@ function M.enrichPositionsWithParameterizedTests(file_path, parsed_parameterized
source_pos_id = pos.id,
}
)

if not parametricTestToSourceLevelTest[file_path] then
parametricTestToSourceLevelTest[file_path] = {}
end

parametricTestToSourceLevelTest[file_path][pos.id] = test_result.pos_id
end
end
end
Expand Down Expand Up @@ -327,4 +323,14 @@ function M.replaceTestParametersWithRegex(test_name)
return result
end

---@param pos neotest.Position
---@return string?
function M.getParametricTestToSourceLevelTest(pos)
if parametricTestToSourceLevelTest[pos.path] then
return parametricTestToSourceLevelTest[pos.path][pos.id]
end

return nil
end

return M
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@
],
"endTime": 1751198748669,
"message": "\u001b[1m\u001b[31m \u001b[1m? \u001b[22m\u001b[1mdescribe text � 3\u001b[39m\u001b[22m\n\n ReferenceError: assert is not defined",
"name": "./spec/tests/basic-skipped-failed.test.ts",
"name": "./spec/tests/basicSkippedFailed.test.ts",
"startTime": 1751198748377,
"status": "failed",
"summary": ""
Expand Down
File renamed without changes.
Loading