1+ import dataclasses
2+ import functools
3+ from collections .abc import Mapping
14from typing import Any , Dict , List
25
36from common .db_utils import write_to_db
@@ -42,14 +45,65 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
4245 return []
4346
4447
48+ def _ensure_list (obj ):
49+ """
50+ Ensure the object is returned as a list.
51+ """
52+ if isinstance (obj , list ):
53+ return obj
54+ if isinstance (obj , (str , bytes , Mapping )):
55+ return [obj ]
56+ if hasattr (obj , "__iter__" ) and not hasattr (obj , "__len__" ): # 如 generator
57+ return list (obj )
58+ return [obj ]
59+
60+
61+ def _to_dict (obj : Any ) -> Dict [str , Any ]:
62+ """
63+ Convert various object types to a dictionary for DB writing.
64+ """
65+ if isinstance (obj , Mapping ):
66+ return dict (obj )
67+ if dataclasses .is_dataclass (obj ):
68+ return dataclasses .asdict (obj )
69+ if hasattr (obj , "_asdict" ): # namedtuple
70+ return obj ._asdict ()
71+ if hasattr (obj , "__dict__" ):
72+ return vars (obj )
73+ raise TypeError (f"Cannot convert { type (obj )} to dict for DB writing" )
74+
75+
76+ def proj_process (table_name : str , ** kwargs ) -> List [Dict [str , Any ]]:
77+ if "_proj" not in kwargs :
78+ return []
79+
80+ name = kwargs .get ("_name" , table_name )
81+ raw_input = kwargs ["_proj" ]
82+ raw_results = _ensure_list (raw_input )
83+
84+ processed_results = []
85+ for result in raw_results :
86+ try :
87+ dict_result = _to_dict (result )
88+ write_to_db (name , dict_result )
89+ processed_results .append (dict_result )
90+ except Exception as e :
91+ raise ValueError (f"Failed to process item in _proj: { e } " ) from e
92+
93+ return processed_results
94+
95+
4596# ---------------- decorator ----------------
4697def export_vars (func ):
98+ @functools .wraps (func )
4799 def wrapper (* args , ** kwargs ):
48100 result = func (* args , ** kwargs )
49- # If the function returns a dict containing '_data' or 'data ', post-process it
101+ # If the function returns a dict containing '_data' or '_proj ', post-process it
50102 if isinstance (result , dict ):
51- if "_data" in result or "data" in result :
103+ if "_data" in result :
52104 return post_process (func .__name__ , ** result )
105+ if "_proj" in result :
106+ return proj_process (func .__name__ , ** result )
53107 # Otherwise return unchanged
54108 return result
55109
@@ -63,33 +117,6 @@ def capture():
63117 return {"name" : "demo" , "_data" : {"accuracy" : 0.1 , "loss" : 0.3 }}
64118
65119
66- @export_vars
67- def capture_list ():
68- """All lists via '_name' + '_data'"""
69- return {
70- "_name" : "demo" ,
71- "_data" : {
72- "accuracy" : [0.1 , 0.2 , 0.3 ],
73- "loss" : [0.1 , 0.2 , 0.3 ],
74- },
75- }
76-
77-
78- @export_vars
79- def capture_mix ():
80- """Mixed single + lists via '_name' + '_data'"""
81- return {
82- "_name" : "demo" ,
83- "_data" : {
84- "length" : 10086 , # single value
85- "accuracy" : [0.1 , 0.2 , 0.3 ], # list
86- "loss" : [0.1 , 0.2 , 0.3 ], # list
87- },
88- }
89-
90-
91120# quick test
92121if __name__ == "__main__" :
93122 print ("capture(): " , capture ())
94- print ("capture_list(): " , capture_list ())
95- print ("capture_mix(): " , capture_mix ())
0 commit comments