Skip to content

feat: add column_map support to collate functions#561

Open
zamal-db wants to merge 1 commit intoPrunaAI:mainfrom
zamal-db:feat/collate-column-map
Open

feat: add column_map support to collate functions#561
zamal-db wants to merge 1 commit intoPrunaAI:mainfrom
zamal-db:feat/collate-column-map

Conversation

@zamal-db
Copy link

@zamal-db zamal-db commented Mar 1, 2026

Description

After working through the evaluation metric issues in #560, I moved on to benchmarking against a couple of HF image-preference datasets. Tried loading data-is-better-together/open-image-preferences-v1-binarized via PrunaDataModule.from_datasets and immediately got a KeyError: 'image' because that dataset uses chosen / prompt instead of image / text.

The workaround is calling dataset.rename_column() on every split before passing them in, but that gets old fast when you're iterating over several datasets. Noticed #297 describes the same problem and has been open for a while, so I went ahead and fixed it.

What changed

Added an optional column_map parameter to all seven collate functions in collate.py. It accepts a dict mapping canonical names to actual dataset column names:

dm = PrunaDataModule.from_datasets(
    (train_ds, val_ds, test_ds),
    collate_fn="image_generation_collate",
    collate_fn_args={
        "img_size": 512,
        "column_map": {"image": "chosen", "text": "prompt"},
    },
)

No changes to PrunaDataModule itself were needed since collate_fn_args already flows through to functools.partial.

  • Default behavior is unchanged (column_map=None)
  • One small helper _resolve_column() keeps things DRY
  • All seven collate functions updated: image_generation_collate, prompt_collate, prompt_with_auxiliaries_collate, audio_collate, image_classification_collate, text_generation_collate, question_answering_collate

Related Issue

Closes #297

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

New tests/data/test_collate.py with 11 CPU tests covering every collate function with remapped columns, plus two end-to-end tests through the full PrunaDataModule.from_datasets pipeline (one with column_map, one without to verify backward compatibility). All pass with pytest -m cpu.

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

None

Add an optional column_map parameter to all seven collate
functions so users can pass custom HF dataset column names
without renaming them first.

Closes PrunaAI#297
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEATURE] Enhance PrunaDataModule to accept any Column Name from HF datasets

1 participant