2424from pathwaysutils import plugin_executable
2525import uvicorn
2626
27- logging .set_verbosity (logging .INFO )
28-
2927
3028class _ProfileState :
3129 def __init__ (self ):
@@ -102,7 +100,7 @@ def start_server(port: int):
102100 port : The port to start the server on.
103101 """
104102 def server_loop (port : int ):
105- logging .info ("Starting JAX profiler server on port %s" , port )
103+ logging .debug ("Starting JAX profiler server on port %s" , port )
106104 app = fastapi .FastAPI ()
107105
108106 @dataclasses .dataclass
@@ -111,9 +109,9 @@ class ProfilingConfig:
111109 repository_path : str
112110
113111 @app .post ("/profiling" )
114- async def profiling (pc : ProfilingConfig ):
115- logging .info ("Capturing profiling data for %s ms" , pc .duration_ms )
116- logging .info ("Writing profiling data to %s" , pc .repository_path )
112+ async def profiling (pc : ProfilingConfig ): # pylint: disable=unused-variable
113+ logging .debug ("Capturing profiling data for %s ms" , pc .duration_ms )
114+ logging .debug ("Writing profiling data to %s" , pc .repository_path )
117115 jax .profiler .start_trace (pc .repository_path )
118116 time .sleep (pc .duration_ms / 1e3 )
119117 jax .profiler .stop_trace ()
@@ -158,25 +156,27 @@ def start_trace_patch(
158156 create_perfetto_link : bool = False , # pylint: disable=unused-argument
159157 create_perfetto_trace : bool = False , # pylint: disable=unused-argument
160158 ) -> None :
161- logging .info ("jax.profile.start_trace patched with pathways' start_trace" )
159+ logging .debug ("jax.profile.start_trace patched with pathways' start_trace" )
162160 return start_trace (log_dir )
163161
164162 jax .profiler .start_trace = start_trace_patch
165163
166164 def stop_trace_patch () -> None :
167- logging .info ("jax.profile.stop_trace patched with pathways' stop_trace" )
165+ logging .debug ("jax.profile.stop_trace patched with pathways' stop_trace" )
168166 return stop_trace ()
169167
170168 jax .profiler .stop_trace = stop_trace_patch
171169
172170 def start_server_patch (port : int ):
173- logging .info ("jax.profile.start_server patched with pathways' start_server" )
171+ logging .debug (
172+ "jax.profile.start_server patched with pathways' start_server"
173+ )
174174 return start_server (port )
175175
176176 jax .profiler .start_server = start_server_patch
177177
178178 def stop_server_patch ():
179- logging .info ("jax.profile.stop_server patched with pathways' stop_server" )
179+ logging .debug ("jax.profile.stop_server patched with pathways' stop_server" )
180180 return stop_server ()
181181
182182 jax .profiler .stop_server = stop_server_patch
0 commit comments