Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions plugins/huggingface/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Hugging Face Datasets Plugin

Native support for Hugging Face Datasets in Flyte: prefetch datasets from the Hub to remote storage, and pass `datasets.Dataset` objects between tasks with automatic Parquet serialization.

## Installation

```bash
pip install flyteplugins-huggingface
```

## Prefetch from HuggingFace Hub

Stream a dataset from the Hub directly to Flyte's remote storage:

```python
import flyte
from flyteplugins.huggingface import hf_dataset

flyte.init(endpoint="my-flyte-endpoint")

run = hf_dataset(repo="stanfordnlp/imdb", split="train")
run.wait()
data_dir = run.outputs()[0] # flyte.io.Dir with parquet files
```

## Type transformer

Pass `datasets.Dataset` between tasks with automatic serialization:

```python
import flyte
import datasets

env = flyte.TaskEnvironment(
name="hf-example",
image=flyte.Image.from_debian_base().with_pip_packages(
"flyteplugins-huggingface",
),
)


@env.task
async def create_dataset() -> datasets.Dataset:
return datasets.Dataset.from_dict({
"text": ["hello", "world", "foo"],
"label": [0, 1, 0],
})


@env.task
async def filter_positive(ds: datasets.Dataset) -> datasets.Dataset:
return ds.filter(lambda x: x["label"] == 1)
```

## Column filtering

Use type annotations to load only specific columns:

```python
from typing import Annotated
from collections import OrderedDict

@env.task
async def load_text_only(
ds: Annotated[datasets.Dataset, OrderedDict(text=str)],
) -> list:
return ds["text"]
```
76 changes: 76 additions & 0 deletions plugins/huggingface/examples/hf_dataset_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Example: HuggingFace Datasets with Flyte.

This example demonstrates:
- Prefetching a dataset from HuggingFace Hub to remote storage
- Loading a prefetched Dir into a datasets.Dataset inside a task
- Passing datasets.Dataset between tasks via the type transformer
- Creating and returning new datasets from tasks
"""

import datasets
import flyte
import pyarrow.parquet as pq

from flyteplugins.huggingface import hf_dataset

# Task environment used by every task in this example: a Debian-based image
# with the flyteplugins-huggingface package pip-installed.
env = flyte.TaskEnvironment(
name="hf-dataset-example",
image=flyte.Image.from_debian_base(name="hf-dataset-example").with_pip_packages(
"flyteplugins-huggingface",
),
)


@env.task
async def load_from_dir(data_dir: flyte.io.Dir) -> datasets.Dataset:
    """Load every Parquet file under a prefetched Dir into one datasets.Dataset.

    Args:
        data_dir: Directory that is walked for files whose path ends in
            ``.parquet``; each match is downloaded and read as an Arrow table.

    Returns:
        A dataset built from the concatenation of all Parquet tables found.

    Raises:
        ValueError: If no ``.parquet`` file is found under ``data_dir``.
    """
    import pyarrow as pa

    tables = []
    async for file in data_dir.walk():
        # Only Parquet files are read; every other entry is ignored.
        if file.path.endswith(".parquet"):
            local = await file.download()
            tables.append(pq.read_table(local))

    if not tables:
        # Concatenating zero tables fails with an opaque pyarrow error, so fail
        # early with a message that names the directory that was searched.
        raise ValueError(f"No .parquet files found in {data_dir!r}")

    return datasets.Dataset(pa.concat_tables(tables))


@env.task
async def tokenize(ds: datasets.Dataset) -> datasets.Dataset:
    """Add a ``word_count`` column with the whitespace-split word count of each ``text`` value."""
    counts = []
    for entry in ds["text"]:
        counts.append(len(entry.split()))
    return ds.add_column("word_count", counts)


@env.task
async def filter_long(ds: datasets.Dataset) -> datasets.Dataset:
    """Return only the rows whose ``word_count`` is strictly greater than 100."""

    def _has_many_words(example):
        return example["word_count"] > 100

    return ds.filter(_has_many_words)


@env.task
async def summary(ds: datasets.Dataset) -> str:
    """Describe the dataset as a one-line string: its row count and its column names."""
    row_count = len(ds)
    columns = ds.column_names
    return f"{row_count} rows, columns: {columns}"


if __name__ == "__main__":
    flyte.init()

    # Step 1: prefetch the IMDB train split from the Hub and wait for it to finish.
    prefetch_run = hf_dataset(repo="stanfordnlp/imdb", split="train")
    prefetch_run.wait()
    data_dir = prefetch_run.outputs()[0]

    # Step 2: turn the prefetched directory into a datasets.Dataset inside a task.
    load_run = flyte.with_runcontext("local").run(load_from_dir, data_dir)
    ds = load_run.outputs()[0]

    # Step 3: hand the datasets.Dataset from task to task via the type transformer.
    tokenize_run = flyte.with_runcontext("local").run(tokenize, ds)
    tokenized = tokenize_run.outputs()[0]

    filter_run = flyte.with_runcontext("local").run(filter_long, tokenized)
    filtered = filter_run.outputs()[0]

    summary_run = flyte.with_runcontext("local").run(summary, filtered)
    print(summary_run.outputs()[0])
85 changes: 85 additions & 0 deletions plugins/huggingface/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
[project]
name = "flyteplugins-huggingface"
dynamic = ["version"]
description = "Hugging Face Datasets plugin for flyte"
readme = "README.md"
authors = [{ name = "Flyte Contributors", email = "admin@flyte.org" }]
requires-python = ">=3.10"
dependencies = [
"datasets>=2.14.5",
"huggingface-hub>=0.27.0",
"pyarrow",
"flyte"
]

[project.entry-points."flyte.plugins.types"]
huggingface = "flyteplugins.huggingface.df_transformer:register_huggingface_df_transformers"

[build-system]
requires = ["setuptools", "setuptools_scm"]
build-backend = "setuptools.build_meta"

[dependency-groups]
dev = [
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"pandas",
]

[tool.setuptools]
include-package-data = true
license-files = ["licenses/*.txt", "LICENSE"]

[tool.setuptools.packages.find]
where = ["src"]
include = ["flyteplugins*"]

[tool.setuptools_scm]
root = "../../"

[tool.pytest.ini_options]
norecursedirs = []
log_cli = true
log_cli_level = 20
markers = []
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.run]
branch = true

[tool.ruff]
line-length = 120

[tool.ruff.lint]
# Each rule selector is listed once ("PLW" previously appeared twice).
select = [
    "E",
    "W",
    "F",
    "I",
    "PLW",
    "YTT",
    "ASYNC",
    "C4",
    "T10",
    "EXE",
    "ISC",
    "LOG",
    "PIE",
    "Q",
    "RSE",
    "FLY",
    "PGH",
    "PLC",
    "PLE",
    "FURB",
    "RUF",
]
ignore = ["PGH003", "PLC0415", "ASYNC240"]

[tool.ruff.lint.per-file-ignores]
"examples/*" = ["E402"]

[tool.uv.sources]
flyte = { path = "../../", editable = true }
6 changes: 6 additions & 0 deletions plugins/huggingface/src/flyteplugins/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._prefetch import HuggingFaceDatasetInfo, hf_dataset

# Public API of the package: the names re-exported from the ._prefetch module.
__all__ = [
"HuggingFaceDatasetInfo",
"hf_dataset",
]
Loading
Loading