Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1714,10 +1714,16 @@ def _report_duration(
tbe_id=self.uuid,
)

def _get_tensor_memory(self, tensor_name: str) -> int:
"""Get memory usage of a tensor in bytes."""
if not hasattr(self, tensor_name):
self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
return 0
tensor = getattr(self, tensor_name)
return tensor.numel() * tensor.element_size()

@torch.jit.ignore
def _report_tbe_mem_usage(
self,
) -> None:
def _report_tbe_mem_usage(self) -> None:
if self.stats_reporter is None:
return

Expand All @@ -1726,18 +1732,17 @@ def _report_tbe_mem_usage(
return

total_mem_usage = sum(
param.numel() * param.element_size() for param in self.parameters()
) + sum(buffer.numel() * buffer.element_size() for buffer in self.buffers())
p.numel() * p.element_size() for p in self.parameters()
) + sum(b.numel() * b.element_size() for b in self.buffers())

if self.use_cpu:
total_hbm_usage = 0
total_uvm_usage = total_mem_usage
else:
# hbm usage is total usage minus uvm usage
total_uvm_usage = sum(
getattr(self, tensor_name).numel()
* getattr(self, tensor_name).element_size()
for tensor_name in self._uvm_tensors_log
if hasattr(self, tensor_name)
self._get_tensor_memory(name)
for name in self._uvm_tensors_log
if hasattr(self, name)
)
total_hbm_usage = total_mem_usage - total_uvm_usage

Expand Down
Loading