@jalengg commented Jan 15, 2026


Created directory structure following PyHealth conventions:
- pyhealth/models/promptehr/ (subdirectory per MedLink precedent)
  - __init__.py: Exports PromptEHR and reusable components
  - model.py: Main PromptEHR(BaseModel) class
  - conditional_prompt.py: ConditionalPromptEncoder
  - bart_encoder.py: PromptBartEncoder with prompt injection
  - bart_decoder.py: PromptBartDecoder with prompt injection
  - utils.py: VisitStructureSampler and helpers

- pyhealth/datasets/promptehr_dataset.py (flat file per convention)
- pyhealth/tasks/ehr_generation.py (flat file)

Key architectural decisions:
- Subdirectory for models: enables component reusability
- Flat files for datasets/tasks: follows PyHealth conventions
- All files have placeholder implementations with NotImplementedError

Phase 1: Foundation setup complete
Next: Phase 2 - Data infrastructure (tokenizer, dataset)

Implemented PromptEHR tokenization by extending PyHealth's existing
tokenizer infrastructure instead of creating redundant custom files.

Key implementation:
- Created create_promptehr_tokenizer() helper function in promptehr_dataset.py
- Uses PyHealth's Tokenizer class with 7 custom special tokens
- Maintains 1:1 code-to-token mapping (no fragmentation)
- code_offset=7 (medical codes start at ID 7)

Special tokens (IDs 0-6):
- <pad> (0): Padding token
- <s> (1): Start of sequence (BART BOS)
- </s> (2): End of sequence (BART EOS)
- <unk> (3): Unknown token
- <v> (4): Visit start marker
- </v> (5): Visit end marker
- <mask> (6): Masking token for corruption
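
For reference, a minimal sketch of the helper (illustrative, not the committed code; assumes PyHealth's Tokenizer places special_tokens at the start of the vocabulary):

```python
from pyhealth.tokenizer import Tokenizer

SPECIAL_TOKENS = ["<pad>", "<s>", "</s>", "<unk>", "<v>", "</v>", "<mask>"]

def create_promptehr_tokenizer(medical_codes):
    """Sketch: one vocabulary entry per medical code, codes start at ID 7."""
    tokenizer = Tokenizer(
        tokens=sorted(set(medical_codes)),  # 1:1 code-to-token mapping
        special_tokens=SPECIAL_TOKENS,      # occupy IDs 0-6
    )
    # code_offset = 7: the first medical code sits right after <mask>
    assert tokenizer.convert_tokens_to_indices(["<mask>"]) == [6]
    return tokenizer
```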

Validation:
✓ All special tokens have correct IDs (0-6)
✓ code_offset = 7 verified
✓ 1:1 mapping prevents fragmentation (e.g., "401.9" → single token)
✓ Visit structure encoding works with <v>/</v> markers
✓ Batch encoding with padding works correctly

Advantages over custom tokenizer:
- No redundant files (follows PyHealth conventions)
- Leverages existing batch_encode_2d/3d methods
- Backward compatible with pehr_scratch checkpoints
- Clean, minimal implementation (60 lines)

Next: Port dataset class and data collator from pehr_scratch

Ported complete dataset infrastructure from pehr_scratch to PyHealth:

Components added:
1. PatientRecord class - Container for patient demographics + visit history
2. load_mimic_data() - MIMIC-III CSV loading with age calculation
3. CorruptionFunctions - Three corruption strategies (mask infilling sketched after this list):
   - Mask infilling: Poisson-distributed span masking
   - Token deletion: Binomial deletion (keep ≥1 code/visit)
   - Token replacement: Random code substitution
4. PromptEHRDataset - PyTorch Dataset wrapping patient records
5. EHRDataCollator - Batch collator with corruption augmentation
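
A hedged sketch of the mask-infilling strategy (function name and the lam default are illustrative):

```python
import numpy as np

def mask_infill(codes, mask_id, lam=3.0, rng=None):
    """Replace one Poisson-length span of codes with a single <mask> token."""
    rng = rng or np.random.default_rng()
    if len(codes) <= 1:
        return list(codes)
    span = min(int(rng.poisson(lam)), len(codes) - 1)  # keep >= 1 code
    if span == 0:
        return list(codes)
    start = int(rng.integers(0, len(codes) - span + 1))
    return list(codes[:start]) + [mask_id] + list(codes[start + span:])
```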

Key features:
- Integrated with PyHealth's Tokenizer (no custom tokenizer needed)
- Visit structure encoding: <s> <v> codes </v> ... </s>
- Code shuffling within visits (treats codes as unordered sets)
- Corruption probability: 50% (configurable)
- Padding with -100 for labels (PyTorch ignore index)
- Demographics: age (normalized) + gender only (ethnicity removed)

Data flow:
1. load_mimic_data() → PatientRecord list + diagnosis codes
2. create_promptehr_tokenizer(codes) → Tokenizer
3. PromptEHRDataset(records, tokenizer) → PyTorch Dataset
4. EHRDataCollator(tokenizer, max_seq_len) → DataLoader collate_fn
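
Assuming the ported names keep the signatures implied above, the pipeline reads roughly as:

```python
from torch.utils.data import DataLoader
from pyhealth.datasets.promptehr_dataset import (
    PromptEHRDataset, create_promptehr_tokenizer, load_mimic_data,
)
from pyhealth.datasets.promptehr_collator import EHRDataCollator  # path assumed

records, diagnosis_codes = load_mimic_data("/path/to/mimic3")
tokenizer = create_promptehr_tokenizer(diagnosis_codes)
dataset = PromptEHRDataset(records, tokenizer)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=EHRDataCollator(tokenizer, max_seq_len=512),
)
```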

Validation needed:
- Test MIMIC-III loading with real data
- Verify corruption strategies preserve visit structure
- Check batch shapes match model expectations

Files modified: 1 (promptehr_dataset.py: 449 lines)
Files created: 1 (promptehr_collator.py: 209 lines)
Total: 658 lines of ported dataset code

Phase 2 Data Infrastructure: Complete
Next: Phase 3 - Port model architecture (conditional prompt, BART encoder/decoder)

Ported conditional prompt encoder from pehr_scratch with demographic
conditioning through reparameterization.

Components added:
1. NumericalConditionalPrompt - Age embedding with d_hidden=128 bottleneck
2. CategoricalConditionalPrompt - Gender embedding with offset-based indexing
3. ConditionalPromptEncoder - Combined encoder for age + gender

Key features:
- Reparameterization: feature → 128-dim → 768-dim (prevents overfitting)
- Offset-based categorical indexing (prevents category collision)
- Xavier uniform initialization for all parameters
- Supports prompt_length scaling (currently 1)

Demographics used (verified in config.py):
- ✓ n_num_features = 1 (age only)
- ✓ cat_cardinalities = [2] (gender M/F only)
- ✓ d_hidden = 128 (reparameterization bottleneck)
- ✓ Ethnicity removed for medical validity

Architecture:
Age (continuous) → weight * value + bias → proj → 768-dim prompt
Gender (categorical) → embedding + bias → proj → 768-dim prompt
Combined: [batch, 2, 768] (2 prompts: age + gender)
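
A sketch of the numerical branch under these dimensions (module internals simplified):

```python
import torch
import torch.nn as nn

class NumericalConditionalPrompt(nn.Module):
    """Sketch: age -> 128-dim bottleneck -> 768-dim prompt embedding."""

    def __init__(self, n_num_features=1, d_hidden=128, d_model=768):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_num_features, d_hidden))
        self.bias = nn.Parameter(torch.empty(n_num_features, d_hidden))
        self.proj = nn.Linear(d_hidden, d_model)
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.bias)

    def forward(self, x_num):
        # x_num: [batch, n_num_features] -> [batch, n_num_features, d_model]
        hidden = self.weight[None] * x_num[..., None] + self.bias[None]
        return self.proj(hidden)
```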

Source: pehr_scratch/conditional_prompt.py (lines 1-219)
No changes to logic - direct port maintaining exact functionality

Phase 3.1 complete: 252 lines
Next: Phase 3.2 - BART encoder with prompt injection

Ported PromptBartEncoder from pehr_scratch/prompt_bart_encoder.py (149 lines).

Key features:
- Extends transformers BartEncoder to accept demographic prompts
- Prepends prompt embeddings to input token embeddings
- Extends attention masks to cover prepended prompts
- Processes through standard BART encoder layers

Implementation:
- PromptBartEncoder: Main encoder class (lines 16-190)
- _expand_mask: Helper for attention mask expansion (lines 193-216)
- Maintains compatibility with transformers BartConfig
- Supports prompt conditioning via inputs_prompt_embeds parameter
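
The injection step itself is small; a sketch assuming prompts are simply prepended along the sequence dimension:

```python
import torch

def inject_prompts(inputs_embeds, attention_mask, inputs_prompt_embeds):
    """Prepend prompt embeddings and widen the padding mask to match."""
    hidden = torch.cat([inputs_prompt_embeds, inputs_embeds], dim=1)
    if attention_mask is not None:  # None-check first (see bug fix below)
        prompt_mask = attention_mask.new_ones(
            attention_mask.size(0), inputs_prompt_embeds.size(1)
        )
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
    return hidden, attention_mask
```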

Verified in config.py:
- Uses "facebook/bart-base" pre-trained model ✓
- No disabled features (all code actively used) ✓
- dropout=0.3 (increased from BART default) ✓

Source: pehr_scratch/prompt_bart_encoder.py
Phase 3.2 complete: 217 lines
Next: Phase 3.3 - BART decoder with prompt injection

Ported PromptBartDecoder from pehr_scratch/prompt_bart_decoder.py (207 lines).

Key features:
- Extends transformers BartDecoder to accept demographic prompts
- Prepends prompt embeddings to decoder input token embeddings
- Extends attention masks (both causal and padding masks)
- Handles cross-attention to encoder outputs
- Supports key-value caching for efficient generation

Implementation:
- PromptBartDecoder: Main decoder class (lines 16-250)
- _make_causal_mask: Helper for autoregressive masking (lines 253-289; sketched below)
- _expand_mask: Helper for attention mask expansion (lines 292-315)
- Dual prompt injection: Prompts injected in both encoder AND decoder
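
A sketch of the causal-mask helper, assuming the standard additive-mask convention (large negatives on future positions, zeros over cached positions):

```python
import torch

def _make_causal_mask(tgt_len, past_key_values_length, dtype, device):
    """Lower-triangular mask, widened so cached positions stay visible."""
    mask = torch.full(
        (tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device
    )
    mask = torch.triu(mask, diagonal=1)  # block attention to future tokens
    if past_key_values_length > 0:
        past = torch.zeros(
            tgt_len, past_key_values_length, dtype=dtype, device=device
        )
        mask = torch.cat([past, mask], dim=-1)
    return mask  # [tgt_len, past_key_values_length + tgt_len]
```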

Verified in config.py:
- Uses "facebook/bart-base" pre-trained model ✓
- No disabled features (all code actively used) ✓
- dropout=0.3 (increased from BART default) ✓

Critical: Dual prompt injection (encoder + decoder) prevents demographic drift
per implementation decision D008.

Source: pehr_scratch/prompt_bart_decoder.py
Phase 3.3 complete: 316 lines
Next: Phase 3.4 - Main PromptEHR model

…ase 3.4)

Ported PromptBartModel from pehr_scratch/prompt_bart_model.py (lines 16-262).
Created PyHealth wrapper class PromptEHR extending BaseModel.

Core Components:
1. PromptBartModel (lines 21-332):
   - Extends transformers BartForConditionalGeneration
   - Dual prompt encoders (encoder + decoder separate, per D008)
   - Label smoothing = 0.1 (active regularization)
   - Generation methods with demographics passing
   - Prompt position slicing before loss computation

2. PromptEHR (lines 359-475):
   - Extends PyHealth BaseModel for ecosystem integration
   - Wraps PromptBartModel following pyhealth-expert guidance
   - forward() returns {"loss": ...} for training
   - generate() method for inference
   - mode=None to skip discriminative evaluation

3. shift_tokens_right helper (lines 335-356)
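
A skeleton of the wrapper (the BaseModel keyword arguments are assumed; PromptBartModel is the class defined in the same module):

```python
import torch
from pyhealth.models import BaseModel

class PromptEHR(BaseModel):
    """Sketch: compose a PromptBartModel rather than inherit from BART."""

    def __init__(self, dataset, bart_config_name="facebook/bart-base"):
        # mode=None skips PyHealth's discriminative evaluation path
        super().__init__(dataset=dataset, feature_keys=[],
                         label_key=None, mode=None)
        self.bart_model = PromptBartModel(bart_config_name)

    def forward(self, **batch):
        return {"loss": self.bart_model(**batch).loss}

    @torch.no_grad()
    def generate(self, **kwargs):
        return self.bart_model.generate(**kwargs)
```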

Key Decisions:
- Excluded auxiliary losses per D003 (caused mode collapse)
- Excluded code_offset parameter per pehr-scratch-expert (unused)
- Applied dropout=0.3 (increased from BART default 0.1)
- Architecture follows pyhealth-expert Option A (composition not inheritance)

Expert Guidance:
- pehr-scratch-expert: Confirmed what features are actually used vs. disabled
- pyhealth-expert: Provided BaseModel integration pattern for generative models

Verified in config.py:
- Label smoothing: 0.1 ✓
- Dual prompts: both encoder and decoder ✓
- Generation methods: used in generate.py ✓
- Auxiliary losses: DISABLED (weights=0.0) ✓

Source: pehr_scratch/prompt_bart_model.py (lines 16-262 + 264-276)
Phase 3 complete: 476 lines
Next: Phase 4 - Visit structure constraints (or testing/integration)

Fixed two bugs found during testing (both also present in pehr_scratch):

1. **Encoder: None attention_mask handling** (bart_encoder.py)
   - Bug: Accessed attention_mask.dtype before checking if None
   - Fix: Check None first, then access dtype/device
   - Impact: Makes code more robust for direct encoder usage
   - Confirmed by pehr-scratch-expert: Same bug in pehr_scratch

2. **Decoder: Defensive cache handling** (bart_decoder.py)
   - Bug: IndexError on empty cache structures during generation
   - Fix: Try-except with safe fallback (past_key_values_length=0)
   - Impact: Handles transformers 4.53.3 Cache API differences
   - Expert confirmed: Safe fallback, slight quality degradation but no crash
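
The second fix reduces to a small helper; a sketch assuming the legacy tuple-of-tuples cache layout:

```python
def safe_past_length(past_key_values):
    """Return the cached sequence length, or 0 when the cache is empty/absent."""
    try:
        return past_key_values[0][0].shape[2]
    except (TypeError, IndexError, AttributeError):
        return 0  # safe fallback: treat the step as having no past
```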

Testing:
- All 9 tests pass (test_promptehr_basic.py)
- Generation works correctly with demographics
- Compatible with transformers 4.53.3

Per pehr-scratch-expert: Our fixes make PyHealth MORE defensive than pehr_scratch.

…ions (Phase 4)

Add VisitStructureSampler and 5 core generation functions ported from pehr_scratch:
- sample_demographics: Realistic demographic sampling
- decode_patient_demographics: Demographics formatting
- parse_sequence_to_visits: Token sequence to visit structure conversion (sketched below)
- generate_patient_sequence_conditional: Reconstruction from partial prompts
- generate_patient_with_structure_constraints: PRIMARY generation method
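
A sketch of the parser, assuming the token IDs fixed in Phase 2 (<v>=4, </v>=5, code_offset=7):

```python
def parse_sequence_to_visits(token_ids, v_start=4, v_end=5, code_offset=7):
    """Split a generated ID sequence into visits of medical-code IDs."""
    visits, current, inside = [], [], False
    for tid in token_ids:
        if tid == v_start:
            inside, current = True, []
        elif tid == v_end and inside:
            visits.append(current)
            inside = False
        elif inside and tid >= code_offset:  # drop special tokens
            current.append(tid)
    return visits

# e.g. parse_sequence_to_visits([1, 4, 42, 57, 5, 2]) -> [[42, 57]]
```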

Key improvements over pehr_scratch:
- Flexible patient record handling (dict or object)
- Flexible demographic extraction (gender or sex attribute)
- No external config dependencies (fully self-contained)
- Python 3.9 type hint compatibility

Testing: 12/12 sanity tests passing
Code reduction: 894→465 lines (kept production essentials only)
Files: pyhealth/models/promptehr/{visit_sampler.py, generation.py}

…ding (Phase 5)

Add checkpoint loading utility and validate PyHealth Trainer compatibility:
- load_from_checkpoint() classmethod with auto vocab size detection
- Supports pehr_scratch checkpoint format (model_state_dict extraction)
- Auto-detects custom vocabularies (6992 vs 50265 tokens)
- Test 13: Trainer integration validated (mode=None generative evaluation)
- Test 14: Checkpoint loading validated (loads best_model.pt successfully)
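
A sketch of the loader as it would sit on PromptEHR (state-dict key names and constructor arguments are assumptions based on this description):

```python
import torch

@classmethod
def load_from_checkpoint(cls, path, dataset=None):
    """Sketch: load a pehr_scratch checkpoint, auto-detecting vocab size."""
    # weights_only=False: the checkpoint pickles a tokenizer object
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    state_dict = ckpt.get("model_state_dict", ckpt)
    vocab_size = state_dict["model.shared.weight"].shape[0]  # 6992 vs 50265
    model = cls(dataset=dataset, _custom_vocab_size=vocab_size)
    model.bart_model.load_state_dict(state_dict)
    return model
```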

Key findings:
- PyHealth Trainer works out-of-the-box with mode=None (no custom Trainer needed)
- Demographics (x_num, x_cat) auto-forward through Trainer
- Generative evaluation returns only loss (no accuracy/f1 metrics)

Testing: 14/14 sanity tests passing
Files: pyhealth/models/promptehr/model.py, test_promptehr_basic.py
Validation: Trainer integration + checkpoint loading from pehr_scratch

Add example scripts for training and generating synthetic patients:
- examples/promptehr_mimic3.py: Complete training and generation pipeline
- examples/promptehr_train.slurm: SLURM batch script for GPU training

Training example includes:
- MIMIC-III data loading (PATIENTS.csv, ADMISSIONS.csv, DIAGNOSES_ICD.csv)
- PromptEHR model training with PyHealth Trainer
- Checkpoint loading/saving
- 4 generation methods (structure-constrained, conditional, demographic, warm-start)
- Synthetic dataset export (CSV/JSON)

SLURM script includes:
- GPU resource allocation (1x GPU, 64GB RAM, 8 CPUs)
- Environment validation
- Automatic directory creation
- Full logging and monitoring

Usage:
  # Interactive training
  python examples/promptehr_mimic3.py --mimic3_root /path/to/data --num_epochs 20

  # SLURM submission
  sbatch examples/promptehr_train.slurm

Files:
- examples/promptehr_mimic3.py (15KB)
- examples/promptehr_train.slurm (4.1KB)
- pyhealth/models/__init__.py (added PromptEHR export)

- AdamW moved from transformers to torch.optim in newer versions
- Fix virtual environment path in SLURM script to use pehr_scratch venv
- Import fix: torch.optim.AdamW instead of transformers.AdamW
- Fix model initialization to use correct parameters (dataset, bart_config_name, _custom_vocab_size)
- Remove invalid dropout parameters (handled internally by PromptEHR)
- Fix checkpoint saving to save bart_model.state_dict() instead of model.state_dict()
- Fix checkpoint loading to load into bart_model.state_dict()
- Update checkpoint config to match PromptEHR.__init__ signature

Resolves TypeError: BaseModel.__init__() got unexpected keyword argument 'vocab_size'

This commit fixes the RuntimeError: "Expected all tensors to be on the same
device, but found at least two devices, cuda:0 and cpu!" that occurred during
PromptEHR training.

Root Cause:
- PyHealth Trainer moves model to CUDA but NOT data (trainer.py line 206)
- EHRDataCollator returns CPU tensors
- Forward pass fails when model parameters (cuda:0) interact with data (cpu)
- Error occurred in conditional_prompt.py line 64: self.weight[None] * x_num[..., None]

Solution:
- Added DeviceAwareCollatorWrapper class to wrap EHRDataCollator
- Wrapper intercepts collator output and moves all tensors to target device
- Ensures model and data are on same device before forward pass
- Modified training pipeline to use wrapped collator
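
A minimal reconstruction of such a wrapper:

```python
import torch

class DeviceAwareCollatorWrapper:
    """Move every tensor the base collator returns onto the target device."""

    def __init__(self, base_collator, device):
        self.base_collator = base_collator
        self.device = device

    def __call__(self, examples):
        batch = self.base_collator(examples)
        return {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }
```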

Changes:
- examples/promptehr_mimic3.py:
  - Added DeviceAwareCollatorWrapper class (lines 38-73)
  - Modified train_promptehr() to wrap base collator (lines 182-186)
  - All tensors (age, sex, input_ids, decoder_input_ids, etc.) now on correct device

This addresses PyHealth framework limitation where Trainer assumes data is
already on correct device.

Issue: Job 6147766 failed during post-training generation phase
Error: ImportError on line 297 of examples/promptehr_mimic3.py
Root Cause: VisitStructureSampler is in pyhealth.models.promptehr, not pyhealth.datasets.promptehr_dataset

Changes:
- Consolidated imports on line 297 to use correct module path
- Both VisitStructureSampler and generate_patient_with_structure_constraints now imported from pyhealth.models.promptehr

Note: Training phase completed successfully (19/19 epochs, loss=1.3366)
This fix enables synthetic patient generation using the saved checkpoint.

PyTorch 2.6+ changed default from weights_only=False to True.
Checkpoints containing custom objects (tokenizer) require explicit weights_only=False.
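
Concretely, each load site becomes:

```python
import torch

# PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
# custom classes; this checkpoint stores the tokenizer object inside it.
ckpt = torch.load("best_model.pt", map_location="cpu", weights_only=False)
```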

New Features:
1. Created examples/promptehr_generate_local.py - standalone script for quick CPU-based generation
   - Generates 10 synthetic patients in ~15 seconds on CPU
   - No SLURM or GPU required
   - Human-readable console output

2. Extended tokenizer API for pehr_scratch compatibility
   - Added convenience properties: bos_token_id, pad_token_id, eos_token_id, code_offset
   - Added method alias: convert_tokens_to_ids() wrapping convert_tokens_to_indices()
   - Added vocab object with idx2code and code2idx mappings
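
A hedged sketch of the compatibility layer (the PyHealth Vocabulary internals are assumed):

```python
class VocabCompat:
    """Module-level class so tokenizers stay picklable (see the next commit)."""

    def __init__(self, vocabulary):
        self.idx2code = dict(vocabulary.idx2token)
        self.code2idx = dict(vocabulary.token2idx)

def add_pehr_compat(tokenizer):
    """Attach pehr_scratch-style attributes onto a PyHealth Tokenizer."""
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    tokenizer.code_offset = 7
    tokenizer.convert_tokens_to_ids = tokenizer.convert_tokens_to_indices
    tokenizer.vocab = VocabCompat(tokenizer.vocabulary)
    return tokenizer
```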

Changes to pyhealth/datasets/promptehr_dataset.py:
- create_promptehr_tokenizer() now adds 4 properties and 1 method alias
- Maintains backward compatibility with checkpoints

Changes to examples/promptehr_generate_local.py:
- New file: 141 lines
- Adds compatibility shims for old checkpoints without new properties
- Loads checkpoint, generates patients, displays results

Tested: Successfully generates 10 patients with realistic demographics and ICD-9 codes

Core Changes:
- Fix tokenizer pickling: Move VocabCompat to module level for serialization
- Fix generation API: Update to use PyHealth convert_tokens_to_indices()
- Fix token bug: Correct visit end token from <\v> to </v>

Scripts Added:
- split_mimic_train_holdout.py: Split MIMIC-III into train/holdout sets
- promptehr_train_holdout.slurm: Train on 45,520 patients (1k held out)
- promptehr_generate_10k.slurm: Generate 10,000 synthetic patients

Training Results:
- Trained on 45,520 patients (1,000 holdout for evaluation)
- 20 epochs, loss converged to 1.346
- Successfully generated 10,000 synthetic patients
- Dataset: 12,905 visits, 535,420 diagnosis codes, 517 unique codes
- Average: 1.29 visits/patient, 53.5 codes/patient

Technical Details:
- Fixed local function pickling issues in tokenizer creation
- Added VocabCompat wrapper class for pehr_scratch API compatibility
- All comprehensive tests passing (tokenizer save/load validated)
- Generation with proper error handling (set -e, set -o pipefail)

Documentation:
- Added IMPLEMENTATION_REPORT.md with full technical details

Generated Outputs:
- synthetic_patients_10000.csv: 10,000 synthetic patients (930KB)
  - 12,905 visits total
  - 535,420 diagnosis codes
  - 517 unique ICD-9 codes
  - 1.29 avg visits/patient, 53.5 avg codes/patient

- synthetic_patients_200.csv: 200 validation patients (22KB)
  - Generated during training completion

Generated from model trained on:
- 45,520 training patients (1,000 held out)
- 20 epochs, converged loss 1.346
- Job 6623691 completed successfully Jan 13, 2026