From 8fb186f224a4bdf05deec76afee17794438227cd Mon Sep 17 00:00:00 2001 From: Jose Cespedes Date: Sat, 27 Apr 2024 19:37:13 -0600 Subject: [PATCH] feat: now logger can output csv to file --- cogitare/plugins/logger.py | 9 ++++++--- tests/test_plugins/test_logger.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/cogitare/plugins/logger.py b/cogitare/plugins/logger.py index 7c4654f..6e664cf 100644 --- a/cogitare/plugins/logger.py +++ b/cogitare/plugins/logger.py @@ -39,7 +39,8 @@ class Logger(PluginInterface): model.register_plugin(logger2, 'on_end_batch') """ - def __init__(self, title='[Logger]', msg='Loss: {loss_mean:.6f}', show_time=True, output_file=None, freq=1): + def __init__(self, title='[Logger]', msg='Loss: {loss_mean:.6f}', show_time=True, output_file=None, freq=1, + csv=False): super(Logger, self).__init__(freq=freq) self.title = title @@ -48,7 +49,9 @@ def __init__(self, title='[Logger]', msg='Loss: {loss_mean:.6f}', show_time=True self.output_file = output_file self.logger = logging.getLogger(title) coloredlogs.install(level='DEBUG', logger=self.logger) - + self.separator = '| ' + if csv: + self.separator = ', ' if show_time: self._start_time = time.time() @@ -61,7 +64,7 @@ def _time_spent(self): time_str = ' '.join('{} {}'.format(getattr(seconds, k), k) for k in intervals if getattr(seconds, k)) - return '| ' + time_str + return self.separator + time_str def function(self, *args, **kwargs): log = '%s %s %s' % (self.title, self.msg.format(**kwargs), self._time_spent()) diff --git a/tests/test_plugins/test_logger.py b/tests/test_plugins/test_logger.py index 136cf6f..1d354de 100644 --- a/tests/test_plugins/test_logger.py +++ b/tests/test_plugins/test_logger.py @@ -37,3 +37,14 @@ def test_logger_with_tqdm(capsys): l(loss_mean=1.0) out, err = capsys.readouterr() assert '[Logger] Loss: 1.000000 ' in err + + +def test_logger_file_csv_format(capsys): + f = mock.Mock() + f.write = mock.MagicMock(return_value=None) + + l = Logger(output_file=f, csv=True) + l(loss_mean=1.0) + out, err = capsys.readouterr() + assert f.write.called + assert ', ' in err