Skip to content

Commit c17aafc

Browse files
committed
Update with code from sensAI 1.4.0
1 parent 835ca7b commit c17aafc

File tree

6 files changed

+181
-5
lines changed

6 files changed

+181
-5
lines changed

src/sensai/util/deprecation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ def deprecated(message):
1515
def deprecated_decorator(func):
1616
@wraps(func)
1717
def deprecated_func(*args, **kwargs):
18-
msg = "{} is a deprecated function. {}".format(func.__name__, message)
18+
func_name = func.__name__
19+
if func_name == "__init__":
20+
class_name = func.__qualname__.split('.')[0]
21+
msg = "{} is a deprecated class. {}".format(class_name, message)
22+
else:
23+
msg = "{} is a deprecated function. {}".format(func_name, message)
1924
if logging.Logger.root.hasHandlers():
2025
log.warning(msg)
2126
else:

src/sensai/util/git.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import logging
2+
from dataclasses import dataclass
3+
import subprocess
4+
from typing import Optional
5+
6+
from .string import ToStringMixin
7+
8+
log = logging.getLogger(__name__)
9+
10+
11+
@dataclass
12+
class GitStatus(ToStringMixin):
13+
commit: str
14+
has_unstaged_changes: bool
15+
has_staged_uncommitted_changes: bool
16+
has_untracked_files: bool
17+
18+
@property
19+
def is_clean(self) -> bool:
20+
return not (self.has_unstaged_changes or
21+
self.has_staged_uncommitted_changes or
22+
self.has_untracked_files)
23+
24+
25+
def git_status() -> Optional[GitStatus]:
26+
try:
27+
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
28+
unstaged = bool(subprocess.check_output(['git', 'diff', '--name-only']).decode('ascii').strip())
29+
staged = bool(subprocess.check_output(['git', 'diff', '--staged', '--name-only']).decode('ascii').strip())
30+
untracked = bool(subprocess.check_output(['git', 'ls-files', '--others', '--exclude-standard']).decode('ascii').strip())
31+
return GitStatus(
32+
commit=commit_hash,
33+
has_unstaged_changes=unstaged,
34+
has_staged_uncommitted_changes=staged,
35+
has_untracked_files=untracked
36+
)
37+
except Exception as e:
38+
log.error("Error determining Git status", exc_info=e)
39+
return None

src/sensai/util/helper.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This module contains various helper functions.
33
"""
44
import math
5-
from typing import Any, Sequence, Union, TypeVar, List, Optional, Dict, Container, Iterable
5+
from typing import Any, Sequence, Union, TypeVar, List, Optional, Dict, Container, Iterable, Tuple
66

77
T = TypeVar("T")
88

@@ -110,3 +110,24 @@ def kwarg_if_not_none(arg_name: str, arg_value: Any) -> Dict[str, Any]:
110110
return {}
111111
else:
112112
return {arg_name: arg_value}
113+
114+
115+
def flatten_dict(d: Dict[str, Any], sep: str = '.') -> Dict[str, Any]:
116+
"""
117+
Flatten a nested dictionary by concatenating nested keys with a separator.
118+
119+
:param d: the dictionary to flatten
120+
:param sep: the separator to use in order to join the keys of nested dictionaries
121+
:return: a flattened dictionary
122+
"""
123+
def _flatten(d: Dict[str, Any], parent_key: str = '') -> List[Tuple[str, Any]]:
124+
items: List[Tuple[str, Any]] = []
125+
for k, v in d.items():
126+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
127+
if isinstance(v, dict):
128+
items.extend(_flatten(v, new_key))
129+
else:
130+
items.append((new_key, v))
131+
return items
132+
133+
return dict(_flatten(d))

src/sensai/util/plot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def save(self, path):
181181
log.info(f"Saving figure in {path}")
182182
self.fig.savefig(path)
183183

184+
def show(self):
185+
self.fig.show()
186+
184187
def xtick(self: TPlot, major=None, minor=None) -> TPlot:
185188
"""
186189
Sets a tick on every integer multiple of the given base values.

src/sensai/util/string.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def to_string(x, converter: StringConverter = None, apply_converter_to_non_compl
8484
return list_string(x, brackets="()", converter=converter)
8585
elif type(x) == dict:
8686
return dict_string(x, brackets="{}", converter=converter)
87+
elif type(x) == str:
88+
return repr(x)
8789
elif type(x) == types.MethodType:
8890
# could be bound method of a ToStringMixin instance (which would print the repr of the instance, which can potentially cause
8991
# an infinite recursion)
@@ -426,7 +428,7 @@ def take(cnt=1):
426428
def find_matching(j):
427429
start = j
428430
op = s[j]
429-
cl = {"[": "]", "(": ")", "'": "'"}[s[j]]
431+
cl = {"[": "]", "(": ")", "'": "'", "{": "}"}[s[j]]
430432
is_bracket = cl != s[j]
431433
stack = 0
432434
while j < len(s):
@@ -439,7 +441,7 @@ def find_matching(j):
439441
j += 1
440442
return None
441443

442-
brackets = "[("
444+
brackets = "[({"
443445
quotes = "'"
444446
while i < len(s):
445447
is_bracket = s[i] in brackets
@@ -457,7 +459,7 @@ def find_matching(j):
457459
take(1)
458460
indent += 1
459461
nl()
460-
elif s[i] in "])":
462+
elif s[i] in "])}":
461463
take(1)
462464
indent -= 1
463465
elif s[i:i+2] == ", ":

src/sensai/util/tensorboard.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
import pandas as pd
3+
from tensorboard.backend.event_processing import event_accumulator
4+
5+
from matplotlib import pyplot as plt
6+
7+
from .pandas import SeriesInterpolationLinearIndex
8+
9+
10+
class TensorboardData:
11+
def __init__(self, events: event_accumulator.EventAccumulator):
12+
self.events = events
13+
self.events.Reload()
14+
15+
def get_series(self, tag: str, smoothing_factor: float = 0.0) -> pd.Series:
16+
"""
17+
Gets the (smoothed) pandas Series for a specific tensorboard tag.
18+
19+
:param tag: the tensorboard tag
20+
:param smoothing_factor: the smoothing factor between 0 and 1 which determines the relative importance of past values.
21+
0: no smoothing
22+
1: maximum smoothing (all values will be equal to the first value)
23+
:return: the pandas series with the step as the index
24+
"""
25+
if not 0 <= smoothing_factor <= 1:
26+
raise ValueError("Smoothing factor must be between 0 and 1")
27+
28+
try:
29+
scalar_events = self.events.Scalars(tag)
30+
except KeyError:
31+
raise KeyError(f"Tag '{tag}' not found in tensorboard events")
32+
33+
steps = [event.step for event in scalar_events]
34+
values = [event.value for event in scalar_events]
35+
36+
if smoothing_factor > 0:
37+
smoothed_values = []
38+
last = values[0]
39+
for value in values:
40+
last = smoothing_factor * last + (1 - smoothing_factor) * value
41+
smoothed_values.append(last)
42+
values = smoothed_values
43+
44+
return pd.Series(values, index=steps, name=tag)
45+
46+
def get_tags(self) -> list[str]:
47+
"""
48+
Get list of available scalar tags in the events.
49+
50+
:return: list of tag names
51+
"""
52+
return self.events.Tags()['scalars']
53+
54+
def get_data_frame(self, tags: list[str] | None = None, smoothing_factor: float = 0.0) -> pd.DataFrame:
55+
"""
56+
Gets multiple series as a DataFrame.
57+
58+
:param tags: the list of tensorboard tags to consider; if None, use all
59+
:param smoothing_factor: smoothing factor to apply to all series
60+
:return: DataFrame with steps as index and tags as columns
61+
"""
62+
if tags is None:
63+
tags = self.get_tags()
64+
series_dict = {}
65+
for tag in tags:
66+
series = self.get_series(tag, smoothing_factor)
67+
series_dict[series.name] = series
68+
69+
return pd.DataFrame(series_dict)
70+
71+
72+
class TensorboardSeriesComparison:
73+
def __init__(self, tb_reference: TensorboardData, tb_current: TensorboardData,
74+
tag: str, index_start: int, index_end: int):
75+
s_ref = tb_reference.get_series(tag)
76+
s_cur = tb_current.get_series(tag)
77+
78+
interp = SeriesInterpolationLinearIndex(ffill=True, bfill=True)
79+
s_ref, s_cur = interp.interpolate_all_with_combined_index([s_ref, s_cur])
80+
81+
self.s_ref = s_ref.loc[index_start:index_end]
82+
self.s_cur = s_cur.loc[index_start:index_end]
83+
84+
def mean_relative_difference(self):
85+
"""
86+
Computes the difference between the current series and the reference series, relative to the reference,
87+
e.g. if the current series is on average 105% of the reference series (5% relative difference), then
88+
the value will be 0.05.
89+
Since we divide by the absolute value of the reference, this also works for negative cases, i.e.
90+
if the reference series value is -0.10 and the current series value is -0.08, then the relative
91+
difference is 0.2 (20%).
92+
93+
:return: the mean relative difference
94+
"""
95+
diff = self.s_cur - self.s_ref
96+
diff_rel = diff / abs(self.s_ref)
97+
return np.mean(diff_rel)
98+
99+
def plot_series(self, show=False) -> plt.Figure:
100+
fig = plt.figure()
101+
self.s_ref.plot()
102+
self.s_cur.plot()
103+
plt.title(self.s_ref.name)
104+
if show:
105+
plt.show()
106+
return fig

0 commit comments

Comments
 (0)