Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions composer/loggers/tensorboard_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pathlib import Path
from typing import Any, Optional, Sequence, Union
from urllib.parse import urlparse

import numpy as np
import torch
Expand Down Expand Up @@ -108,9 +109,16 @@ def _initialize_summary_writer(self):

assert self.run_name is not None
assert self.log_dir is not None
# We name the child directory after the run_name to ensure the run_name shows up
# in the Tensorboard GUI.
summary_writer_log_dir = Path(self.log_dir) / self.run_name

parsed = urlparse(self.log_dir)
# TODO: Handle other remote storage schemes
if parsed.scheme == 's3':
scheme, bucket, prefix, _, _, _ = parsed
summary_writer_log_dir = f"{scheme}://{bucket}/{prefix.strip('/')}/{self.run_name}"
else:
# We name the child directory after the run_name to ensure the run_name shows up
# in the Tensorboard GUI.
summary_writer_log_dir = str(Path(self.log_dir) / self.run_name)

# Disable SummaryWriter's internal flushing to avoid file corruption while
# file staged for upload to an ObjectStore.
Expand Down
19 changes: 19 additions & 0 deletions tests/loggers/test_tensorboard_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from typing import Sequence

import boto3
import moto
import pytest
import torch

Expand Down Expand Up @@ -52,3 +54,20 @@ def test_tensorboard_log_image(test_tensorboard_logger, dummy_state):
logger = Logger(dummy_state, [])
test_tensorboard_logger.close(dummy_state, logger)
# Tensorboard images are stored inline, so we can't check them automatically.


@moto.mock_aws
def test_tensorboard_logger_s3_log_dir(dummy_state):
bucket_name = 'test-tensorboard-bucket'
s3 = boto3.client('s3')
s3.create_bucket(Bucket=bucket_name)

test_s3_log_dir = f's3://{bucket_name}/log_prefix'

dummy_state.run_name = 'tensorboard-test-log-s3'
logger = Logger(dummy_state, [])
tensorboard_logger = TensorboardLogger(log_dir=test_s3_log_dir)
tensorboard_logger.init(dummy_state, logger)
assert tensorboard_logger.writer is not None
expected_log_dir = f'{test_s3_log_dir}/{dummy_state.run_name}'
assert tensorboard_logger.writer.log_dir == expected_log_dir