Skip to content

Commit 7e9e7de

Browse files
committed
Add batch transcode function to convert utils
Signed-off-by: Joaquin Anton Guirao <[email protected]>
1 parent 3b5bd1c commit 7e9e7de

File tree

4 files changed

+775
-204
lines changed

4 files changed

+775
-204
lines changed

monailabel/datastore/utils/convert.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,46 @@
4040

4141
logger = logging.getLogger(__name__)
4242

43+
# Global singleton instances for nvimgcodec encoder/decoder
44+
# These are initialized lazily on first use to avoid import errors
45+
# when nvimgcodec is not available
46+
_NVIMGCODEC_ENCODER = None
47+
_NVIMGCODEC_DECODER = None
48+
49+
50+
def _get_nvimgcodec_encoder():
    """Return the module-wide nvimgcodec encoder, creating it on first use.

    The encoder is cached in the module global ``_NVIMGCODEC_ENCODER`` so that
    repeated calls reuse one instance. ``nvimgcodec`` is imported lazily here so
    that importing this module does not fail when the package is absent.

    Returns:
        The cached ``nvimgcodec.Encoder`` instance.

    Raises:
        ImportError: If ``nvidia.nvimgcodec`` cannot be imported. The original
            import failure is attached as ``__cause__``.
    """
    global _NVIMGCODEC_ENCODER
    if _NVIMGCODEC_ENCODER is None:
        # Keep the try body to the import alone so only a missing package is
        # translated into the installation hint below.
        try:
            from nvidia import nvimgcodec
        except ImportError as err:
            raise ImportError(
                "nvidia-nvimgcodec is required for HTJ2K transcoding. "
                "Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
                "(replace {XX} with your CUDA version, e.g., cu13)"
            ) from err

        _NVIMGCODEC_ENCODER = nvimgcodec.Encoder()
        logger.debug("Initialized global nvimgcodec.Encoder singleton")
    return _NVIMGCODEC_ENCODER
65+
66+
67+
def _get_nvimgcodec_decoder():
    """Return the module-wide nvimgcodec decoder, creating it on first use.

    The decoder is cached in the module global ``_NVIMGCODEC_DECODER`` so that
    repeated calls reuse one instance. ``nvimgcodec`` is imported lazily here so
    that importing this module does not fail when the package is absent.

    Returns:
        The cached ``nvimgcodec.Decoder`` instance.

    Raises:
        ImportError: If ``nvidia.nvimgcodec`` cannot be imported. The original
            import failure is attached as ``__cause__``.
    """
    global _NVIMGCODEC_DECODER
    if _NVIMGCODEC_DECODER is None:
        # Keep the try body to the import alone so only a missing package is
        # translated into the installation hint below.
        try:
            from nvidia import nvimgcodec
        except ImportError as err:
            raise ImportError(
                "nvidia-nvimgcodec is required for HTJ2K decoding. "
                "Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
                "(replace {XX} with your CUDA version, e.g., cu13)"
            ) from err

        _NVIMGCODEC_DECODER = nvimgcodec.Decoder()
        logger.debug("Initialized global nvimgcodec.Decoder singleton")
    return _NVIMGCODEC_DECODER
82+
4383

4484
class SegmentDescription:
4585
"""Wrapper class for segment description following MONAI Deploy pattern.
@@ -597,3 +637,272 @@ def dicom_seg_to_itk_image(label, output_ext=".seg.nrrd"):
597637

598638
logger.info(f"Result/Output File: {output_file}")
599639
return output_file
640+
641+
642+
def _find_dicom_files(input_dir):
    """Return each regular file directly inside *input_dir* that has the DICOM magic.

    A file qualifies when the four bytes at offset 128 equal ``b"DICM"``.
    Files that cannot be opened or read are silently ignored. Every path is
    returned at most once, with ``*.dcm`` matches listed first.
    """
    import glob

    candidates = []
    for pattern in ("*.dcm", "*"):
        candidates.extend(glob.glob(os.path.join(input_dir, pattern)))

    found = []
    # "*" also matches every "*.dcm" path, so without de-duplication each .dcm
    # file would be loaded, encoded and counted twice. dict.fromkeys keeps the
    # first occurrence and preserves order.
    for file_path in dict.fromkeys(candidates):
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "rb") as f:
                f.seek(128)
                magic = f.read(4)
        except OSError:
            # Best effort: an unreadable file is simply not a candidate.
            continue
        if magic == b"DICM":
            found.append(file_path)
    return found


def transcode_dicom_to_htj2k(
    input_dir: str,
    output_dir: str = None,
    num_resolutions: int = 6,
    code_block_size: tuple = (64, 64),
    verify: bool = False,
) -> str:
    """
    Transcode DICOM files to HTJ2K (High Throughput JPEG 2000) lossless compression.

    The function uses nvidia-nvimgcodec for encoding and always requests
    lossless quality. It operates in three phases:

    1. Load every DICOM file in ``input_dir`` and decode its pixel data with pydicom.
       Files whose transfer syntax UID already starts with ``1.2.840.10008.1.2.4.20``
       are copied to the output directory unchanged.
    2. Encode all loaded pixel arrays with a single batch call; if that call raises,
       fall back to encoding the images one at a time.
    3. Encapsulate each encoded image, set the transfer syntax to
       ``1.2.840.10008.1.2.4.201`` and save the dataset under its original file name.

    Args:
        input_dir: Path to directory containing DICOM files to transcode
            (only files directly inside it are considered).
        output_dir: Path to output directory for transcoded files. If None, a
            temporary directory with prefix ``htj2k_`` is created.
        num_resolutions: Number of resolution levels (default: 6).
        code_block_size: Code block size tuple passed to the encoder (default: (64, 64)).
        verify: If True, re-read each written file, decode it with nvimgcodec and
            compare the result with pydicom's decoding of the same file (default: False).

    Returns:
        Path to output directory containing the transcoded (or copied) DICOM files.
        Files that fail to load, encode, save or verify are logged and counted as
        failed; they do not abort the run.

    Raises:
        ImportError: If nvidia-nvimgcodec cannot be imported.
        ValueError: If input directory doesn't exist, is not a directory, or
            contains no DICOM files.

    Example:
        >>> output_dir = transcode_dicom_to_htj2k("/path/to/dicoms")
        >>> # Transcoded files are now in output_dir with lossless HTJ2K compression

    Note:
        Requires nvidia-nvimgcodec to be installed:
            pip install nvidia-nvimgcodec-cu{XX}[all]
        Replace {XX} with your CUDA version (e.g., cu13 for CUDA 13.x)
    """
    import shutil

    # Check for nvidia-nvimgcodec
    try:
        from nvidia import nvimgcodec
    except ImportError as err:
        raise ImportError(
            "nvidia-nvimgcodec is required for HTJ2K transcoding. "
            "Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
            "(replace {XX} with your CUDA version, e.g., cu13)"
        ) from err

    # Validate input
    if not os.path.exists(input_dir):
        raise ValueError(f"Input directory does not exist: {input_dir}")

    if not os.path.isdir(input_dir):
        raise ValueError(f"Input path is not a directory: {input_dir}")

    valid_dicom_files = _find_dicom_files(input_dir)
    if not valid_dicom_files:
        raise ValueError(f"No valid DICOM files found in {input_dir}")

    total_files = len(valid_dicom_files)
    logger.info(f"Found {total_files} DICOM files to transcode")

    # Create output directory
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="htj2k_")
    else:
        os.makedirs(output_dir, exist_ok=True)

    # Encoder/decoder instances are shared singletons; the decoder is only
    # needed when verification is requested.
    encoder = _get_nvimgcodec_encoder()
    decoder = _get_nvimgcodec_decoder() if verify else None

    # HTJ2K Transfer Syntax UID - Lossless Only
    # 1.2.840.10008.1.2.4.201 = HTJ2K Lossless Only
    target_transfer_syntax = "1.2.840.10008.1.2.4.201"
    quality_type = nvimgcodec.QualityType.LOSSLESS
    logger.info("Using lossless HTJ2K compression")

    # Configure JPEG2K encoding parameters
    jpeg2k_encode_params = nvimgcodec.Jpeg2kEncodeParams()
    jpeg2k_encode_params.num_resolutions = num_resolutions
    jpeg2k_encode_params.code_block_size = code_block_size
    # NOTE(review): DICOM PS3.5 requires encapsulated JPEG 2000 pixel data to be a
    # bare codestream without the JP2 file-format header; confirm whether JP2
    # here should be the J2K bitstream type instead.
    jpeg2k_encode_params.bitstream_type = nvimgcodec.Jpeg2kBitstreamType.JP2
    jpeg2k_encode_params.prog_order = nvimgcodec.Jpeg2kProgOrder.LRCP
    jpeg2k_encode_params.ht = True  # Enable High Throughput mode

    encode_params = nvimgcodec.EncodeParams(
        quality_type=quality_type,
        jpeg2k_encode_params=jpeg2k_encode_params,
    )

    start_time = time.time()
    transcoded_count = 0
    skipped_count = 0
    failed_count = 0

    # Phase 1: Load all DICOM files and prepare pixel arrays for batch encoding
    logger.info("Phase 1: Loading DICOM files and preparing pixel arrays...")
    dicom_datasets = []
    pixel_arrays = []
    files_to_encode = []

    for i, input_file in enumerate(valid_dicom_files, 1):
        try:
            ds = pydicom.dcmread(input_file)

            # Files that are already HTJ2K are copied through untouched.
            current_ts = getattr(ds, "file_meta", {}).get("TransferSyntaxUID", None)
            if current_ts and str(current_ts).startswith("1.2.840.10008.1.2.4.20"):
                logger.debug(f"[{i}/{total_files}] Already HTJ2K: {os.path.basename(input_file)}")
                output_file = os.path.join(output_dir, os.path.basename(input_file))
                shutil.copy2(input_file, output_file)
                skipped_count += 1
                continue

            # pydicom decodes the source image regardless of its transfer syntax.
            source_pixel_array = ds.pixel_array

            if not isinstance(source_pixel_array, np.ndarray):
                source_pixel_array = np.array(source_pixel_array)

            # Add channel dimension if needed (nvimgcodec expects shape like (H, W, C)).
            # NOTE(review): a 3-D array is passed through as-is; presumably a
            # multi-frame (frames, H, W) image would be treated as (H, W, C) —
            # confirm multi-frame input is out of scope.
            if source_pixel_array.ndim == 2:
                source_pixel_array = source_pixel_array[:, :, np.newaxis]

            # The three lists stay index-aligned for phases 2 and 3.
            dicom_datasets.append(ds)
            pixel_arrays.append(source_pixel_array)
            files_to_encode.append(input_file)

            if i % 50 == 0 or i == total_files:
                logger.info(f"Loading progress: {i}/{total_files} files loaded")

        except Exception as e:
            logger.error(f"[{i}/{total_files}] Error loading {os.path.basename(input_file)}: {e}")
            failed_count += 1
            continue

    if not pixel_arrays:
        logger.warning("No images to encode")
        return output_dir

    # Phase 2: Batch encode all images to HTJ2K
    logger.info(f"Phase 2: Batch encoding {len(pixel_arrays)} images to HTJ2K...")
    encode_start = time.time()

    try:
        encoded_htj2k_images = encoder.encode(
            pixel_arrays,
            codec="jpeg2k",
            params=encode_params,
        )
    except Exception as e:
        logger.error(f"Batch encoding failed: {e}")
        logger.warning("Falling back to individual encoding...")
        encoded_htj2k_images = []
        for idx, pixel_array in enumerate(pixel_arrays):
            try:
                single_result = encoder.encode(
                    [pixel_array],
                    codec="jpeg2k",
                    params=encode_params,
                )
                # Append exactly one entry per input so results stay aligned
                # with dicom_datasets / files_to_encode.
                encoded_htj2k_images.append(single_result[0] if single_result else None)
            except Exception as e2:
                logger.error(f"Failed to encode image {idx}: {e2}")
                encoded_htj2k_images.append(None)
    else:
        # Reporting lives outside the try so that a failure while formatting
        # the log line (e.g. a zero elapsed time) cannot trigger the fallback
        # and re-encode a batch that already succeeded.
        encode_time = time.time() - encode_start
        images_per_sec = len(pixel_arrays) / encode_time if encode_time > 0 else float("inf")
        logger.info(
            f"Batch encoding completed in {encode_time:.2f} seconds ({images_per_sec:.1f} images/sec)"
        )

    # Phase 3: Save encoded data back to DICOM files
    logger.info("Phase 3: Saving encoded DICOM files...")
    save_start = time.time()

    for idx, (ds, encoded_data, input_file) in enumerate(
        zip(dicom_datasets, encoded_htj2k_images, files_to_encode)
    ):
        try:
            if encoded_data is None:
                logger.error(f"Skipping {os.path.basename(input_file)} - encoding failed")
                failed_count += 1
                continue

            # Encapsulate the single encoded frame for DICOM
            new_encoded_frames = [bytes(encoded_data)]
            encapsulated_pixel_data = pydicom.encaps.encapsulate(new_encoded_frames)
            ds.PixelData = encapsulated_pixel_data

            # Update transfer syntax UID
            ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)

            # Save to output directory
            output_file = os.path.join(output_dir, os.path.basename(input_file))
            ds.save_as(output_file)

            if verify:
                # NOTE(review): this compares nvimgcodec's decoding with pydicom's
                # decoding of the same written file, not with the source pixels —
                # confirm that is the intended check.
                ds_verify = pydicom.dcmread(output_file)
                pixel_data = ds_verify.PixelData
                data_sequence = pydicom.encaps.decode_data_sequence(pixel_data)
                images_verify = decoder.decode(
                    data_sequence,
                    params=nvimgcodec.DecodeParams(
                        allow_any_depth=True,
                        color_spec=nvimgcodec.ColorSpec.UNCHANGED,
                    ),
                )
                image_verify = np.array(images_verify[0].cpu()).squeeze()

                if not np.allclose(image_verify, ds_verify.pixel_array):
                    logger.warning(f"Verification failed for {os.path.basename(input_file)}")
                    failed_count += 1
                    continue

            transcoded_count += 1

            if (idx + 1) % 50 == 0 or (idx + 1) == len(dicom_datasets):
                logger.info(f"Saving progress: {idx + 1}/{len(dicom_datasets)} files saved")

        except Exception as e:
            logger.error(f"Error saving {os.path.basename(input_file)}: {e}")
            failed_count += 1
            continue

    save_time = time.time() - save_start
    logger.info(f"Saving completed in {save_time:.2f} seconds")

    elapsed_time = time.time() - start_time

    logger.info("Transcoding complete:")
    logger.info(f"  Total files: {total_files}")
    logger.info(f"  Successfully transcoded: {transcoded_count}")
    logger.info(f"  Already HTJ2K (copied): {skipped_count}")
    logger.info(f"  Failed: {failed_count}")
    logger.info(f"  Time elapsed: {elapsed_time:.2f} seconds")
    logger.info(f"  Output directory: {output_dir}")

    return output_dir

monailabel/transform/reader.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import logging
1515
import os
16+
import threading
1617
import warnings
1718
from collections.abc import Sequence
1819
from typing import TYPE_CHECKING, Any
@@ -45,6 +46,22 @@
4546

4647
__all__ = ["NvDicomReader"]
4748

49+
# Thread-local storage for nvimgcodec decoder
50+
# Each thread gets its own decoder instance for thread safety
51+
_thread_local = threading.local()
52+
53+
54+
def _get_nvimgcodec_decoder():
    """Return the calling thread's nvimgcodec decoder, creating it on first use.

    The decoder is stored on the module's ``threading.local`` object, so each
    thread that calls this function gets and reuses its own instance.

    Raises:
        RuntimeError: If ``has_nvimgcodec`` is false.
    """
    if has_nvimgcodec:
        cached = getattr(_thread_local, "decoder", None)
        if cached is None:
            # First call on this thread (or the slot was reset to None).
            cached = nvimgcodec.Decoder()
            _thread_local.decoder = cached
            logger.debug(f"Initialized thread-local nvimgcodec.Decoder for thread {threading.current_thread().name}")
        return cached

    raise RuntimeError("nvimgcodec is not available. Cannot create decoder.")
64+
4865

4966
def _copy_compatible_dict(from_dict: dict, to_dict: dict):
5067
if not isinstance(to_dict, dict):
@@ -173,13 +190,12 @@ def __init__(
173190
self.use_nvimgcodec = use_nvimgcodec
174191
self.prefer_gpu_output = prefer_gpu_output
175192
self.allow_fallback_decode = allow_fallback_decode
176-
# Initialize nvImageCodec decoder if needed
193+
# Initialize decode params for nvImageCodec if needed
177194
if self.use_nvimgcodec:
178195
if not has_nvimgcodec:
179196
warnings.warn("NvDicomReader: nvImageCodec not installed, will use pydicom for decoding.")
180197
self.use_nvimgcodec = False
181198
else:
182-
self._nvimgcodec_decoder = nvimgcodec.Decoder()
183199
self.decode_params = nvimgcodec.DecodeParams(
184200
allow_any_depth=True, color_spec=nvimgcodec.ColorSpec.UNCHANGED
185201
)
@@ -314,7 +330,8 @@ def _nvimgcodec_decode(self, img, filename):
314330
if fragment and fragment != b"\x00\x00\x00\x00"
315331
]
316332
logger.info(f"NvDicomReader: Decoding {len(data_sequence)} fragment(s) with nvImageCodec")
317-
decoded_data = self._nvimgcodec_decoder.decode(data_sequence, params=self.decode_params)
333+
decoder = _get_nvimgcodec_decoder()
334+
decoded_data = decoder.decode(data_sequence, params=self.decode_params)
318335

319336
# Check if decode succeeded (nvImageCodec returns None on failure)
320337
if not decoded_data or decoded_data[0] is None:
@@ -637,7 +654,8 @@ def _process_dicom_series(self, file_paths: list) -> tuple[np.ndarray, dict]:
637654
all_frames.extend(frames)
638655

639656
# Decode all frames at once
640-
decoded_data = self._nvimgcodec_decoder.decode(all_frames, params=self.decode_params)
657+
decoder = _get_nvimgcodec_decoder()
658+
decoded_data = decoder.decode(all_frames, params=self.decode_params)
641659

642660
if not decoded_data or any(d is None for d in decoded_data):
643661
raise ValueError("nvImageCodec batch decode failed")

0 commit comments

Comments
 (0)