Skip to content

Commit 1f2b640

Browse files
committed
Logging updates
Updated logging levels to make Pathways Utils less noisy Using default verbosity. Lets usersset the verbosity as with `absl.logging.set_verbosity(absl.logging.INFO)`. PiperOrigin-RevId: 683380846
1 parent 359776d commit 1f2b640

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

pathwaysutils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _is_persistence_enabled():
4545

4646

4747
if _is_pathways_used():
48-
logging.warning(
48+
logging.debug(
4949
"pathwaysutils: Detected Pathways-on-Cloud backend. Applying changes."
5050
)
5151
proxy_backend.register_backend_factory()
@@ -58,9 +58,9 @@ def _is_persistence_enabled():
5858
try:
5959
cloud_logging.setup()
6060
except OSError as e:
61-
logging.warning("pathwaysutils: Failed to set up cloud logging.")
61+
logging.debug("pathwaysutils: Failed to set up cloud logging.")
6262
else:
63-
logging.warning(
63+
logging.debug(
6464
"pathwaysutils: Did not detect Pathways-on-Cloud backend. No changes"
6565
" applied."
6666
)

pathwaysutils/persistence/pathways_orbax_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def register_pathways_handlers(
180180
read_timeout: Optional[datetime.timedelta] = None,
181181
):
182182
"""Function that must be called before saving or restoring with Pathways."""
183-
logging.warning(
183+
logging.debug(
184184
'Registering CloudPathwaysArrayHandler (Pathways Persistence API).'
185185
)
186186
type_handlers.register_type_handler(

pathwaysutils/profiling.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from pathwaysutils import plugin_executable
2525
import uvicorn
2626

27-
logging.set_verbosity(logging.INFO)
28-
2927

3028
class _ProfileState:
3129
def __init__(self):
@@ -102,7 +100,7 @@ def start_server(port: int):
102100
port : The port to start the server on.
103101
"""
104102
def server_loop(port: int):
105-
logging.info("Starting JAX profiler server on port %s", port)
103+
logging.debug("Starting JAX profiler server on port %s", port)
106104
app = fastapi.FastAPI()
107105

108106
@dataclasses.dataclass
@@ -111,9 +109,9 @@ class ProfilingConfig:
111109
repository_path: str
112110

113111
@app.post("/profiling")
114-
async def profiling(pc: ProfilingConfig):
115-
logging.info("Capturing profiling data for %s ms", pc.duration_ms)
116-
logging.info("Writing profiling data to %s", pc.repository_path)
112+
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
113+
logging.debug("Capturing profiling data for %s ms", pc.duration_ms)
114+
logging.debug("Writing profiling data to %s", pc.repository_path)
117115
jax.profiler.start_trace(pc.repository_path)
118116
time.sleep(pc.duration_ms / 1e3)
119117
jax.profiler.stop_trace()
@@ -158,25 +156,27 @@ def start_trace_patch(
158156
create_perfetto_link: bool = False, # pylint: disable=unused-argument
159157
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
160158
) -> None:
161-
logging.info("jax.profile.start_trace patched with pathways' start_trace")
159+
logging.debug("jax.profile.start_trace patched with pathways' start_trace")
162160
return start_trace(log_dir)
163161

164162
jax.profiler.start_trace = start_trace_patch
165163

166164
def stop_trace_patch() -> None:
167-
logging.info("jax.profile.stop_trace patched with pathways' stop_trace")
165+
logging.debug("jax.profile.stop_trace patched with pathways' stop_trace")
168166
return stop_trace()
169167

170168
jax.profiler.stop_trace = stop_trace_patch
171169

172170
def start_server_patch(port: int):
173-
logging.info("jax.profile.start_server patched with pathways' start_server")
171+
logging.debug(
172+
"jax.profile.start_server patched with pathways' start_server"
173+
)
174174
return start_server(port)
175175

176176
jax.profiler.start_server = start_server_patch
177177

178178
def stop_server_patch():
179-
logging.info("jax.profile.stop_server patched with pathways' stop_server")
179+
logging.debug("jax.profile.stop_server patched with pathways' stop_server")
180180
return stop_server()
181181

182182
jax.profiler.stop_server = stop_server_patch

0 commit comments

Comments
 (0)