Merged
86 changes: 86 additions & 0 deletions src/instructlab/sdg/checkpointing.py
@@ -0,0 +1,86 @@
# Standard
import logging
import uuid

# Third Party
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets.data_files import EmptyDatasetError

# First Party
from instructlab.sdg.utils import pandas

logger = logging.getLogger(__name__)


class Checkpointer:
    def __init__(self, checkpoint_dir=None, save_freq=1):
        self._checkpoint_dir = checkpoint_dir

        self._save_freq = save_freq
        self._cache = []

    def checkpoint(self, dataset):
        self._cache.append(dataset)
        if len(self._cache) < self._save_freq:
            return
        self.save()
        self._cache.clear()

    def done(self):
        if self._cache:
            self.save()
            self._cache.clear()

    def save(self):
        if self._checkpoint_dir is None:
            return
        checkpoint_id = uuid.uuid4().hex
        checkpoint_file = (
            f"{self._checkpoint_dir}/data_checkpoint_{checkpoint_id}.jsonl"
        )
        logger.info(f"Saving checkpoint to {checkpoint_file}")
        # Saves all the current records to new file in the checkpoint dir
        concatenate_datasets(self._cache).to_json(
            checkpoint_file, orient="records", lines=True
        )

    def load(self, dataset: Dataset) -> "tuple[Dataset, Dataset | None]":
        if self._checkpoint_dir is None:
            return dataset, None

        try:
            pre_generated_data = load_dataset(
                "json", data_dir=self._checkpoint_dir, split="train"
            )
        except EmptyDatasetError:
            logger.info(
                f"No existing checkpoints found in {self._checkpoint_dir}, generating from scratch"
            )
            return dataset, None

        logger.info(
            f"Loading existing checkpoints from {self._checkpoint_dir}, with {pre_generated_data.num_rows} rows"
        )
        seed_data = self._get_missing_data(dataset, pre_generated_data)
        logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
        return seed_data, pre_generated_data

    def _get_missing_data(self, seed_data, generated_data):
        # Get the common columns between the two datasets
        common_columns = list(
            set(seed_data.column_names) & set(generated_data.column_names)
        )

        # Extract the relevant data based on common columns
        seed_data_common = seed_data.select_columns(common_columns)
        generated_data_common = generated_data.select_columns(common_columns)

        # Convert to Pandas DataFrames for easier comparison
        seed_df = seed_data_common.to_pandas()
        generated_df = generated_data_common.to_pandas()

        # Identify missing rows
        missing_rows = ~seed_df.apply(tuple, 1).isin(generated_df.apply(tuple, 1))

        missing_df = seed_data.to_pandas()[missing_rows]
        return pandas.dataset_from_pandas_dataframe(missing_df)
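
A minimal usage sketch of the class above, driven outside the SDG pipeline. It assumes the checkpoint directory already exists and is writable; the seed data, the /tmp/sdg_checkpoints path, and the generate_answers step are illustrative stand-ins, not part of this change.

# Illustrative only: drive the Checkpointer directly with a toy generation step.
from datasets import Dataset

from instructlab.sdg.checkpointing import Checkpointer


def generate_answers(split: Dataset) -> Dataset:
    # Hypothetical per-batch generation step; adds one new column.
    return split.map(lambda sample: {"answer": sample["question"].upper()})


seed = Dataset.from_list([{"question": f"q{i}"} for i in range(20)])

checkpointer = Checkpointer(checkpoint_dir="/tmp/sdg_checkpoints", save_freq=2)
# Drop seed rows already covered by earlier checkpoint files, if any.
remaining, pre_generated = checkpointer.load(seed)

for start in range(0, len(remaining), 10):
    batch = remaining.select(range(start, min(start + 10, len(remaining))))
    # Every save_freq-th call writes the cached batches to a new JSONL file.
    checkpointer.checkpoint(generate_answers(batch))
checkpointer.done()  # flush any batches still cached in memory

On a re-run with the same directory, load() would return only the rows not yet covered by checkpoint files, plus the previously generated data.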
12 changes: 11 additions & 1 deletion src/instructlab/sdg/generate_data.py
@@ -5,6 +5,7 @@
from importlib import resources
from pathlib import Path
from typing import Optional
import dataclasses
import json
import os
import time
@@ -181,6 +182,8 @@ def _context_init(
    model_family: str,
    model_id: str,
    num_instructions_to_generate: int,
    checkpoint_dir: str,
    save_freq: int,
    batch_num_workers: Optional[int],
    batch_size: Optional[int],
):
@@ -194,6 +197,8 @@ def _context_init(
        model_family=model_family,
        model_id=model_id,
        num_instructions_to_generate=num_instructions_to_generate,
        checkpoint_dir=checkpoint_dir,
        save_freq=save_freq,
        **extra_kwargs,
    )

@@ -284,6 +289,7 @@ def generate_data(
    client: Optional[openai.OpenAI] = None,
    pipeline: Optional[str] = "simple",
    batch_size: Optional[int] = None,
    checkpoint_dir: Optional[str] = None,
) -> None:
    """Generate data for training and testing a model.

@@ -348,13 +354,17 @@
        model_family,
        model_name,
        num_instructions_to_generate,
        checkpoint_dir,
        1,  # save_freq
        batch_size=batch_size,
        batch_num_workers=num_cpus,
    )

    sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(ctx, pipeline)

    mmlu_bench_pipe = mmlubench_pipe_init(ctx)
    # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
    mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
    mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

    mixer = _mixer_init(ctx, output_dir, date_suffix)

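The mmlu-bench pipeline reuses the same context but must not pick up checkpoints written by the main data-generation pipeline, hence the copy via dataclasses.replace above. A standalone sketch of that pattern, where the DemoContext dataclass is a stand-in for PipelineContext:

# Illustrative stand-in showing how dataclasses.replace copies a context
# while overriding a single field.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class DemoContext:
    model_id: str
    checkpoint_dir: Optional[str] = None
    save_freq: int = 1


ctx = DemoContext(model_id="mixtral", checkpoint_dir="/tmp/sdg_checkpoints")
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)

assert ctx.checkpoint_dir == "/tmp/sdg_checkpoints"  # original context untouched
assert mmlu_ctx.checkpoint_dir is None  # copy has checkpointing disabled
assert mmlu_ctx.model_id == ctx.model_id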
19 changes: 17 additions & 2 deletions src/instructlab/sdg/pipeline.py
@@ -13,6 +13,7 @@
import yaml

# First Party
from instructlab.sdg.checkpointing import Checkpointer
from instructlab.sdg.utils import pandas

# Local
Expand Down Expand Up @@ -61,6 +62,8 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
    model_id: str
    num_instructions_to_generate: int
    dataset_num_procs: Optional[int] = DEFAULT_DATASET_NUM_PROCS
    checkpoint_dir: Optional[str] = None
    save_freq: Optional[int] = 1
    batch_size: int = DEFAULT_BATCH_SIZE
    batch_num_workers: Optional[int] = None

@@ -129,6 +132,12 @@ def generate(self, dataset) -> Dataset:
        Generate the dataset by running the pipeline steps.
        dataset: the input dataset
        """

        # The checkpointer allows us to resume from where we left off,
        # saving generated output to the checkpoint dir along the way
        checkpointer = Checkpointer(self.ctx.checkpoint_dir, self.ctx.save_freq)
        dataset, pre_generated_data = checkpointer.load(dataset)

        # If not batching, simply delegate to _generate_single
        if not self.ctx.batching_enabled:
            logger.info("Running pipeline single-threaded")
@@ -142,6 +151,7 @@ def generate(self, dataset) -> Dataset:
            self.ctx.batch_size,
        )
        input_splits = self._split_dataset(dataset)
        output_splits = []
        with ThreadPoolExecutor(max_workers=self.ctx.batch_num_workers) as executor:
            futures = [
                executor.submit(self._generate_single, input_split)
@@ -150,8 +160,13 @@

            # Collect the results of each batch as they finish. This needs to
            # wait for them all, so the order of waiting doesn't matter
            output_splits = [future.result() for future in futures]

            for future in futures:
                ds = future.result()
                output_splits.append(ds)
                checkpointer.checkpoint(ds)
        checkpointer.done()
        if pre_generated_data:
            output_splits.append(pre_generated_data)
        return concatenate_datasets(output_splits)

## Implementation Details ##
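A condensed sketch of the batched path in generate() after this change: batch results are checkpointed as they are collected, and rows loaded from earlier checkpoints are appended before concatenation. Here run_split stands in for _generate_single, and the seed data, worker count, and checkpoint path are illustrative.

# Illustrative flattening of Pipeline.generate()'s batched path.
from concurrent.futures import ThreadPoolExecutor

from datasets import Dataset, concatenate_datasets

from instructlab.sdg.checkpointing import Checkpointer


def run_split(split: Dataset) -> Dataset:
    # Stand-in for self._generate_single(split).
    return split.map(lambda sample: {"answer": sample["question"].upper()})


seed = Dataset.from_list([{"question": f"q{i}"} for i in range(30)])
checkpointer = Checkpointer(checkpoint_dir="/tmp/sdg_checkpoints", save_freq=1)
seed, pre_generated_data = checkpointer.load(seed)

input_splits = [
    seed.select(range(i, min(i + 10, len(seed)))) for i in range(0, len(seed), 10)
]
output_splits = []
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(run_split, split) for split in input_splits]
    for future in futures:
        ds = future.result()
        output_splits.append(ds)
        checkpointer.checkpoint(ds)  # persist each completed batch
checkpointer.done()
if pre_generated_data:
    output_splits.append(pre_generated_data)
final = concatenate_datasets(output_splits)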
117 changes: 117 additions & 0 deletions tests/test_checkpointing.py
@@ -0,0 +1,117 @@
# Standard
import json
import os

# Third Party
from datasets import Dataset
import pytest

# First Party
from instructlab.sdg.checkpointing import Checkpointer


def _add_bar(sample, add_value=100):
    sample["bar"] = sample["foo"] + add_value
    return sample


def _populate_checkpoints(tmpdir, dataset, checkpoints_count, remove_column):
    for i in range(0, checkpoints_count):
        checkpoint_dataset = dataset.select(range(i * 10, (i + 1) * 10))
        checkpoint_dataset = checkpoint_dataset.map(
            lambda x: _add_bar(x, add_value=100)
        )
        if remove_column:
            checkpoint_dataset = checkpoint_dataset.remove_columns("foo")
        checkpoint_dataset.to_json(
            os.path.join(tmpdir, f"data_checkpoint_abcde{i}.jsonl"),
            orient="records",
            lines=True,
        )


def _validate_checkpoints(tmpdir, expected_files_count, expected_length, remove_column):
    saved_files = os.listdir(tmpdir)
    assert len(saved_files) == expected_files_count
    assert all(f.startswith("data_checkpoint_") for f in saved_files)
    assert all(f.endswith(".jsonl") for f in saved_files)

    for filename in saved_files:
        with open(os.path.join(tmpdir, filename), "r") as fh:
            lines = list(fh)
        if isinstance(expected_length, list):
            expected_length.remove(len(lines))
        else:
            assert len(lines) == expected_length
        for s in lines:
            data = json.loads(s)
            if remove_column:
                assert "foo" not in data and "bar" in data
            else:
                assert "foo" in data and "bar" in data


@pytest.mark.parametrize(
    "save_freq, remove_column, dataset_size, init_checkpoints, splits, final_checkpoints, checkpoint_length",
    [
        (1, False, 10, 0, 0, 1, 10),
        (1, True, 10, 0, 0, 1, 10),
        (1, False, 100, 1, 9, 10, 10),
        (1, True, 100, 1, 9, 10, 10),
        (1, False, 100, 2, 8, 10, 10),
        (3, False, 100, 2, 8, 5, [10, 10, 30, 30, 20]),
    ],
)
def test_checkpointing(
    tmpdir,
    save_freq,
    remove_column,
    dataset_size,
    init_checkpoints,
    splits,
    final_checkpoints,
    checkpoint_length,
):
    # Our initial dataset
    dataset = Dataset.from_list([{"idx": i, "foo": i} for i in range(dataset_size)])

    # Generate and save some checkpoints to disk
    _populate_checkpoints(tmpdir, dataset, init_checkpoints, remove_column)

    # Load checkpoints, giving us the remaining dataset to process and
    # the generated data loaded from the checkpoints
    checkpointer = Checkpointer(checkpoint_dir=tmpdir, save_freq=save_freq)
    dataset, pre_generated_data = checkpointer.load(dataset)

    # Should be present, even if removed from the checkpoint (remove_column=True)
    assert "foo" in dataset.features

    # When testing save_freq, we will have checkpoints of different lengths
    if isinstance(checkpoint_length, list):
        checkpoints_total = sum(checkpoint_length[:init_checkpoints])
    else:
        checkpoints_total = checkpoint_length * init_checkpoints

    # Validate pre-generated data loaded from the checkpoints
    assert len(dataset) == (dataset_size - checkpoints_total)
    if init_checkpoints > 0:
        assert len(pre_generated_data) == checkpoints_total

    # Apply pipeline to the remaining dataset and save checkpoints
    if splits:
        for i in range(0, splits):
            split = dataset.select(range(i * 10, (i + 1) * 10))
            split = split.map(lambda x: _add_bar(x, add_value=100))
            if remove_column:
                split = split.remove_columns("foo")
            checkpointer.checkpoint(split)
    else:
        dataset = dataset.map(lambda x: _add_bar(x, add_value=10))
        if remove_column:
            dataset = dataset.remove_columns("foo")
        checkpointer.checkpoint(dataset)

    checkpointer.done()

    # Validate that all checkpoints are now saved to disk
    _validate_checkpoints(tmpdir, final_checkpoints, checkpoint_length, remove_column)
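
The last parametrize case above is the only one exercising save_freq > 1. A small worked check of why it expects five files with lengths [10, 10, 30, 30, 20]; this is pure arithmetic mirroring the Checkpointer cache behaviour, independent of the test harness:

# Why the save_freq=3 case yields checkpoint files of 10, 10, 30, 30 and 20 rows.
initial_files = [10, 10]       # two pre-populated checkpoints of 10 rows each
remaining_splits = [10] * 8    # the remaining 80 rows, checkpointed in splits of 10
save_freq = 3

flushed, cache = [], []
for rows in remaining_splits:
    cache.append(rows)
    if len(cache) >= save_freq:  # Checkpointer.save() fires on every 3rd call
        flushed.append(sum(cache))
        cache.clear()
if cache:                        # Checkpointer.done() flushes the leftovers
    flushed.append(sum(cache))

assert initial_files + flushed == [10, 10, 30, 30, 20]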
4 changes: 4 additions & 0 deletions tests/test_generate_data.py
@@ -19,6 +19,8 @@ def test_context_init_batch_size_optional():
        "mixtral",
        "foo.bar",
        1,
        "/checkpoint/dir",
        1,
        batch_size=None,
        batch_num_workers=None,
    )
@@ -32,6 +34,8 @@ def test_context_init_batch_size_optional():
        "mixtral",
        "foo.bar",
        1,
        "/checkpoint/dir",
        1,
        batch_size=20,
        batch_num_workers=32,
    )