@@ -108,7 +108,7 @@ def _translate_module(self, root: TSNode, program: Program) -> None:
108108 """Translate module-level definitions"""
109109 for child in root .children :
110110 if child .type == "function_definition" :
111- proc = self ._translate_function (child )
111+ proc = self ._translate_function (child , program = program )
112112 if proc :
113113 program .add_procedure (proc )
114114
@@ -125,7 +125,7 @@ def _translate_module(self, root: TSNode, program: Program) -> None:
125125
126126 if definition :
127127 if definition .type == "function_definition" :
128- proc = self ._translate_function (definition )
128+ proc = self ._translate_function (definition , program = program )
129129 if proc :
130130 program .add_procedure (proc )
131131 elif definition .type == "class_definition" :
@@ -143,7 +143,7 @@ def _translate_class(self, node: TSNode, program: Program) -> None:
143143 if body_node :
144144 for child in body_node .children :
145145 if child .type == "function_definition" :
146- proc = self ._translate_function (child , is_method = True )
146+ proc = self ._translate_function (child , is_method = True , program = program )
147147 if proc :
148148 proc .class_name = class_name
149149 proc .name = f"{ class_name } .{ proc .name } "
@@ -152,7 +152,7 @@ def _translate_class(self, node: TSNode, program: Program) -> None:
152152 elif child .type == "decorated_definition" :
153153 for c in child .children :
154154 if c .type == "function_definition" :
155- proc = self ._translate_function (c , is_method = True )
155+ proc = self ._translate_function (c , is_method = True , program = program )
156156 if proc :
157157 proc .class_name = class_name
158158 proc .name = f"{ class_name } .{ proc .name } "
@@ -161,7 +161,7 @@ def _translate_class(self, node: TSNode, program: Program) -> None:
161161
162162 self ._current_class = None
163163
164- def _translate_function (self , node : TSNode , is_method : bool = False ) -> Optional [Procedure ]:
164+ def _translate_function (self , node : TSNode , is_method : bool = False , program : Program = None ) -> Optional [Procedure ]:
165165 """Translate function definition to SIL Procedure"""
166166 # Get function name
167167 name_node = node .child_by_field_name ("name" )
@@ -205,10 +205,10 @@ def _translate_function(self, node: TSNode, is_method: bool = False) -> Optional
205205 proc .entry_node = entry .id
206206 self ._current_node = entry
207207
208- # Translate body
208+ # Translate body (pass program for nested functions)
209209 body_node = node .child_by_field_name ("body" )
210210 if body_node :
211- self ._translate_block (body_node )
211+ self ._translate_block (body_node , program = program )
212212
213213 # Create exit node
214214 exit_node = proc .new_node (NodeKind .EXIT )
@@ -264,10 +264,26 @@ def _translate_parameters(self, node: TSNode) -> List[Tuple[PVar, Typ]]:
264264
265265 return params
266266
267- def _translate_block (self , node : TSNode ) -> None :
267+ def _translate_block (self , node : TSNode , program : Program = None ) -> None :
268268 """Translate a block of statements"""
269269 for child in node .children :
270- self ._translate_statement (child )
270+ # Handle nested function definitions
271+ if child .type == "function_definition" :
272+ if program :
273+ proc = self ._translate_function (child )
274+ if proc :
275+ program .add_procedure (proc )
276+ elif child .type == "decorated_definition" :
277+ # Handle decorated nested functions
278+ if program :
279+ for c in child .children :
280+ if c .type == "function_definition" :
281+ proc = self ._translate_function (c )
282+ if proc :
283+ program .add_procedure (proc )
284+ break
285+ else :
286+ self ._translate_statement (child )
271287
272288 def _translate_statement (self , node : TSNode ) -> None :
273289 """Translate a single statement"""
@@ -385,7 +401,17 @@ def _translate_call_assignment(
385401
386402 func_name = self ._get_call_name (call_node )
387403 args = self ._get_call_args (call_node )
388- args_exp = [(self ._translate_expression (a ), Typ .unknown_type ()) for a in args ]
404+
405+ # Handle nested calls - expand them first
406+ args_exp = []
407+ for i , arg in enumerate (args ):
408+ if arg .type == "call" :
409+ # Nested call - expand it first
410+ nested_instrs , nested_var = self ._expand_nested_call (arg , loc )
411+ instrs .extend (nested_instrs )
412+ args_exp .append ((ExpVar (PVar (nested_var )), Typ .unknown_type ()))
413+ else :
414+ args_exp .append ((self ._translate_expression (arg ), Typ .unknown_type ()))
389415
390416 # Create return identifier
391417 ret_id = self ._new_ident (target )
@@ -421,17 +447,27 @@ def _translate_call_assignment(
421447 if spec and spec .is_taint_sink ():
422448 kind = SinkKind (spec .is_sink ) if spec .is_sink in [s .value for s in SinkKind ] else SinkKind .SQL_QUERY
423449 for arg_idx in spec .sink_args :
424- if arg_idx < len (args ):
425- arg_exp = self ._translate_expression (args [arg_idx ])
450+ if arg_idx < len (args_exp ):
426451 instrs .append (TaintSink (
427452 loc = loc ,
428- exp = arg_exp ,
453+ exp = args_exp [ arg_idx ][ 0 ] ,
429454 kind = kind ,
430455 description = spec .description
431456 ))
432457
433458 return instrs
434459
460+ def _expand_nested_call (self , call_node : TSNode , loc : Location ) -> Tuple [List [Instr ], str ]:
461+ """Expand a nested call and return (instructions, result_var_name)"""
462+ # Generate a temp variable for the result
463+ temp_var = f"__nested_{ self ._ident_counter } "
464+ self ._ident_counter += 1
465+
466+ # Translate the nested call as an assignment
467+ nested_instrs = self ._translate_call_assignment (temp_var , call_node , loc )
468+
469+ return nested_instrs , temp_var
470+
435471 def _translate_call_expr (self , call_node : TSNode ) -> List [Instr ]:
436472 """Translate standalone call: func(args)"""
437473 instrs = []
@@ -503,9 +539,10 @@ def _extract_fstring_parts(self, node: TSNode) -> List[Exp]:
503539 parts = []
504540
505541 def walk (n : TSNode ):
506- if n .type == "string_content" or n .type == "string" :
542+ if n .type == "string_content" :
543+ # Literal string content between interpolations
507544 text = self ._get_text (n )
508- if text and not text . startswith ( "f" ) :
545+ if text :
509546 parts .append (ExpConst .string (text ))
510547
511548 elif n .type == "interpolation" :
@@ -515,11 +552,17 @@ def walk(n: TSNode):
515552 exp = self ._translate_expression (child )
516553 parts .append (exp )
517554
518- elif n .type == "formatted_string" or n .type == "f_string" :
555+ elif n .type in ("string_start" , "string_end" ):
556+ # Skip f-string delimiters
557+ pass
558+
559+ elif n .type == "string" or n .type == "formatted_string" or n .type == "f_string" :
560+ # Walk children of string node
519561 for child in n .children :
520562 walk (child )
521563
522564 else :
565+ # For other nodes, recurse into children
523566 for child in n .children :
524567 walk (child )
525568
@@ -848,8 +891,27 @@ def _translate_expression(self, node: TSNode) -> Exp:
848891 return ExpConst .integer (0 )
849892
850893 elif node .type == "string" or node .type == "concatenated_string" :
851- text = self ._get_string_content (node )
852- return ExpConst .string (text )
894+ # Check if it's an f-string by looking for interpolation children
895+ has_interpolation = False
896+ for child in node .children :
897+ if child .type == "interpolation" :
898+ has_interpolation = True
899+ break
900+ elif child .type == "string_start" :
901+ start_text = self ._get_text (child )
902+ if start_text .startswith ('f' ) or start_text .startswith ('F' ):
903+ has_interpolation = True
904+ break
905+
906+ if has_interpolation :
907+ # It's an f-string - extract parts
908+ parts = self ._extract_fstring_parts (node )
909+ if parts :
910+ return ExpStringConcat (parts )
911+ return ExpConst .string ("" )
912+ else :
913+ text = self ._get_string_content (node )
914+ return ExpConst .string (text )
853915
854916 elif node .type in ("true" , "True" ):
855917 return ExpConst .boolean (True )
0 commit comments