diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 542ff9c71..059e69858 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -25,7 +25,7 @@ import uuid from kubeflow_trainer_api import models -from kubernetes import client, config, watch +from kubernetes import client, config import kubeflow.common.constants as common_constants from kubeflow.common.types import KubernetesBackendConfig @@ -606,16 +606,21 @@ def _read_pod_logs(self, pod_name: str, container_name: str, follow: bool) -> It """Read logs from a pod container.""" try: if follow: - log_stream = watch.Watch().stream( - self.core_api.read_namespaced_pod_log, + # Stream logs using response.stream() after calling read_namespaced_pod_log + # with _preload_content=False to get a streaming response. + response = self.core_api.read_namespaced_pod_log( name=pod_name, namespace=self.namespace, container=container_name, follow=True, + _preload_content=False, ) - # Stream logs incrementally. - yield from log_stream # type: ignore + # Stream logs incrementally using response.stream(). + for line in response.stream(): + if line: + # Decode bytes to string and yield each line. + yield line.decode("utf-8").rstrip("\n") else: logs = self.core_api.read_namespaced_pod_log( name=pod_name, diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 5fcbeb735..ec4b7f1eb 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -452,6 +452,20 @@ def mock_read_namespaced_pod_log(*args, **kwargs): """Simulate log retrieval from a pod.""" if kwargs.get("namespace") == FAIL_LOGS: raise Exception("Failed to read logs") + + # Handle streaming case: when _preload_content=False and follow=True + if kwargs.get("_preload_content") is False and kwargs.get("follow") is True: + # Return a mock response object with a stream() method + mock_response = Mock() + + def mock_stream(): + """Mock stream generator that yields log lines as bytes.""" + yield b"test log content" + + mock_response.stream = mock_stream + return mock_response + + # Non-streaming case: return plain text logs return "test log content" @@ -1418,6 +1432,12 @@ def test_list_jobs(kubernetes_backend, test_case): config={"name": BASIC_TRAIN_JOB_NAME, "namespace": FAIL_LOGS}, expected_error=RuntimeError, ), + TestCase( + name="valid flow with follow=True for streaming logs", + expected_status=SUCCESS, + config={"name": BASIC_TRAIN_JOB_NAME, "follow": True}, + expected_output=["test log content"], + ), ], ) def test_get_job_logs(kubernetes_backend, test_case): @@ -1425,7 +1445,8 @@ def test_get_job_logs(kubernetes_backend, test_case): print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - logs = kubernetes_backend.get_job_logs(test_case.config.get("name")) + follow = test_case.config.get("follow", False) + logs = kubernetes_backend.get_job_logs(test_case.config.get("name"), follow=follow) # Convert iterator to list for comparison. logs_list = list(logs) assert test_case.expected_status == SUCCESS