Skip to content

Commit 35d59b1

Browse files
committed
[bugfix]Adapt the method of writing objects to the database.
1 parent e83b944 commit 35d59b1

File tree

5 files changed

+138
-47
lines changed

5 files changed

+138
-47
lines changed

test/common/capture_utils.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import dataclasses
2+
import functools
3+
from collections.abc import Mapping
14
from typing import Any, Dict, List
25

36
from common.db_utils import write_to_db
@@ -42,14 +45,65 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
4245
return []
4346

4447

48+
def _ensure_list(obj):
49+
"""
50+
Ensure the object is returned as a list.
51+
"""
52+
if isinstance(obj, list):
53+
return obj
54+
if isinstance(obj, (str, bytes, Mapping)):
55+
return [obj]
56+
if hasattr(obj, "__iter__") and not hasattr(obj, "__len__"): # 如 generator
57+
return list(obj)
58+
return [obj]
59+
60+
61+
def _to_dict(obj: Any) -> Dict[str, Any]:
62+
"""
63+
Convert various object types to a dictionary for DB writing.
64+
"""
65+
if isinstance(obj, Mapping):
66+
return dict(obj)
67+
if dataclasses.is_dataclass(obj):
68+
return dataclasses.asdict(obj)
69+
if hasattr(obj, "_asdict"): # namedtuple
70+
return obj._asdict()
71+
if hasattr(obj, "__dict__"):
72+
return vars(obj)
73+
raise TypeError(f"Cannot convert {type(obj)} to dict for DB writing")
74+
75+
76+
def proj_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
77+
if "_proj" not in kwargs:
78+
return []
79+
80+
name = kwargs.get("_name", table_name)
81+
raw_input = kwargs["_proj"]
82+
raw_results = _ensure_list(raw_input)
83+
84+
processed_results = []
85+
for result in raw_results:
86+
try:
87+
dict_result = _to_dict(result)
88+
write_to_db(name, dict_result)
89+
processed_results.append(dict_result)
90+
except Exception as e:
91+
raise ValueError(f"Failed to process item in _proj: {e}") from e
92+
93+
return processed_results
94+
95+
4596
# ---------------- decorator ----------------
4697
def export_vars(func):
98+
@functools.wraps(func)
4799
def wrapper(*args, **kwargs):
48100
result = func(*args, **kwargs)
49-
# If the function returns a dict containing '_data' or 'data', post-process it
101+
# If the function returns a dict containing '_data' or '_proj', post-process it
50102
if isinstance(result, dict):
51-
if "_data" in result or "data" in result:
103+
if "_data" in result:
52104
return post_process(func.__name__, **result)
105+
if "_proj" in result:
106+
return proj_process(func.__name__, **result)
53107
# Otherwise return unchanged
54108
return result
55109

@@ -63,33 +117,6 @@ def capture():
63117
return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}}
64118

65119

66-
@export_vars
67-
def capture_list():
68-
"""All lists via '_name' + '_data'"""
69-
return {
70-
"_name": "demo",
71-
"_data": {
72-
"accuracy": [0.1, 0.2, 0.3],
73-
"loss": [0.1, 0.2, 0.3],
74-
},
75-
}
76-
77-
78-
@export_vars
79-
def capture_mix():
80-
"""Mixed single + lists via '_name' + '_data'"""
81-
return {
82-
"_name": "demo",
83-
"_data": {
84-
"length": 10086, # single value
85-
"accuracy": [0.1, 0.2, 0.3], # list
86-
"loss": [0.1, 0.2, 0.3], # list
87-
},
88-
}
89-
90-
91120
# quick test
92121
if __name__ == "__main__":
93122
print("capture(): ", capture())
94-
print("capture_list(): ", capture_list())
95-
print("capture_mix(): ", capture_mix())

test/common/db.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
CREATE TABLE test_results (
2+
id INT AUTO_INCREMENT PRIMARY KEY,
3+
test_case VARCHAR(255) NOT NULL,
4+
status VARCHAR(50) NOT NULL,
5+
error TEXT,
6+
test_build_id VARCHAR(100) NOT NULL,
7+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
8+
);

test/common/db_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _get_db() -> Optional[MySQLDatabase]:
3636
backup_str = db_config.get("backup", "results/")
3737
_backup_path = Path(backup_str).resolve()
3838
_backup_path.mkdir(parents=True, exist_ok=True)
39-
logger.info(f"Backup directory set to: {_backup_path}")
39+
# logger.info(f"Backup directory set to: {_backup_path}")
4040

4141
if not _db_enabled:
4242
return None
@@ -94,7 +94,7 @@ def _backup_to_file(table_name: str, data: Dict[str, Any]) -> None:
9494
with file_path.open("a", encoding="utf-8") as f:
9595
json.dump(data, f, ensure_ascii=False)
9696
f.write("\n")
97-
logger.info(f"Data backed up to {file_path}")
97+
# logger.info(f"Data backed up to {file_path}")
9898
except Exception as e:
9999
logger.error(f"Failed to write backup file {file_path}: {e}")
100100

@@ -140,7 +140,7 @@ def write_to_db(table_name: str, data: Dict[str, Any]) -> bool:
140140

141141
with db.atomic():
142142
DynamicEntity.insert(filtered_data).execute()
143-
logger.info(f"Successfully inserted data into table '{table_name}'.")
143+
# logger.info(f"Successfully inserted data into table '{table_name}'.")
144144
return True
145145

146146
except peewee.PeeweeException as e:

test/config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ reports:
33
use_timestamp: true
44
directory_prefix: "pytest"
55
html: # pytest-html
6-
enabled: true
6+
enabled: false
77
filename: "report.html"
88
title: "UCM Pytest Test Report"
99

1010
database:
1111
backup: "results/"
12-
enabled: true
12+
enabled: false
1313
host: "127.0.0.1"
1414
port: 3306
15-
name: "ucm_pytest"
15+
name: "ucm_test"
1616
user: "root"
1717
password: "123456"
1818
charset: "utf8mb4"

test/suites/E2E/test_demo_performance.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,74 @@ def test_divide_by_zero(self, calc):
4141

4242

4343
@pytest.mark.feature("capture") # pytest must be the top
44-
@export_vars
45-
def test_capture_mix():
46-
"""Mixed single + lists via '_name' + '_data'"""
47-
assert 1 == 1
48-
return {
49-
"_name": "demo",
50-
"_data": {
51-
"length": 10086, # single value
52-
"accuracy": [0.1, 0.2, 0.3], # list
53-
"loss": [0.1, 0.2, 0.3], # list
54-
},
55-
}
44+
class TestCapture:
45+
@export_vars
46+
def test_capture_mix(self):
47+
"""Mixed single + lists via '_name' + '_data'"""
48+
assert 1 == 1
49+
return {
50+
"_name": "capture_demo",
51+
"_data": {
52+
"length": 1, # single value
53+
"accuracy": [0.1, 0.2, 0.3], # list
54+
"loss": [0.1, 0.2, 0.3], # list
55+
},
56+
}
57+
58+
@export_vars
59+
def test_capture_dict(self):
60+
"""Mixed single + lists via '_name' + '_proj'"""
61+
return {
62+
"_name": "capture_demo",
63+
"_proj": {"length": 2, "accuracy": 0.1, "loss": 0.1}
64+
65+
}
66+
67+
@export_vars
68+
def test_capture_list_dict(self):
69+
"""Mixed single + lists via '_name' + '_proj'"""
70+
return {
71+
"_name": "capture_demo",
72+
"_proj": [
73+
{"length": 3, "accuracy": 0.1, "loss": 0.1},
74+
{"length": 3, "accuracy": 0.2, "loss": 0.2},
75+
{"length": 3, "accuracy": 0.3, "loss": 0.3},
76+
],
77+
}
78+
79+
@export_vars
80+
def test_capture_proj(self):
81+
"""Mixed single + lists via '_name' + '_proj'"""
82+
83+
class Result:
84+
def __init__(self, length, accuracy, loss):
85+
self.length = length
86+
self.accuracy = accuracy
87+
self.loss = loss
88+
89+
return {
90+
"_name": "capture_demo",
91+
"_proj": Result(4, 0.1, 0.1),
92+
}
93+
94+
@export_vars
95+
def test_capture_list_proj(self):
96+
"""Mixed single + lists via '_name' + '_proj'"""
97+
98+
class Result:
99+
def __init__(self, length, accuracy, loss):
100+
self.length = length
101+
self.accuracy = accuracy
102+
self.loss = loss
103+
104+
return {
105+
"_name": "capture_demo",
106+
"_proj": [
107+
Result(5, 0.1, 0.1),
108+
Result(5, 0.2, 0.2),
109+
Result(5, 0.3, 0.3),
110+
],
111+
}
56112

57113

58114
# ---------------- Read Config Example ----------------

0 commit comments

Comments
 (0)