
Commit dfffa22

Enable fine tuning on HPU
Signed-off-by: Sergey Plotnikov <[email protected]>
1 parent a479f0b commit dfffa22

10 files changed, +243 −19 lines


docs/hpu.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# InstructLab Training on HPU

## HPU specific changes

The following changes are required to enable training on HPU:

|GPU|HPU|
|---|---|
|`from accelerate import Accelerator` | `from optimum.habana.accelerate import GaudiAccelerator`|
|`from accelerate.utils import FullyShardedDataParallelPlugin` | `from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin` |

It is also recommended to use the HPU-optimized version of transformers:

```Python
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
```
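Both substitutions can be gated on the requested device so that the CUDA path is left untouched. A minimal sketch of the pattern (the `get_accelerator_cls` helper name is illustrative; the `accelerator.py` change in this commit applies the same conditional import inline):

```Python
def get_accelerator_cls(device=None):
    # Import lazily so optimum-habana is only required when training on HPU.
    if device == "hpu":
        from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
    else:
        from accelerate import Accelerator as TransformersAccel
    return TransformersAccel
```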
## Bucketing

The Multipack sampler implementation produces a wide range of batches with different sample lengths and numbers of samples. Each new combination leads to a graph recompilation, and this recompilation takes time and slows down training. To reduce the number of recompilations, the HPU implementation uses a bucketing approach, in which the maximum sample length in a batch is aligned to a predefined value. It is similar to padding, except that the samples in a batch are padded not to the longest sample but to a slightly larger, bucket-aligned value.

![bucketing vs. padding](./hpu_pic/bucketing_vs_padding.png)

To compute the bucket size, we use the following algorithm:
- First, find the most significant set bit (MSB) of the longest sample in the batch; call it S.
- Then slice the range [2 ** S, 2 ** (S+1)] into 16 buckets of the same size.
- Then use the upper boundary of the smallest suitable bucket as the padding value.

This approach limits the bucketing overhead to at most 1/16th of the longest sample and significantly reduces the number of recompilations; the sketch below shows the computation for a concrete batch.
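A minimal sketch of the scheme described above (the helper that actually ships in this commit is `simple_bucket()` in `src/instructlab/training/hpu_utils.py`; the names used here are illustrative):

```Python
def bucket_size(longest: int, num_buckets: int = 16) -> int:
    """Pad `longest` up to the top boundary of the smallest suitable bucket."""
    s = longest.bit_length() - 1            # position of the MSB of the longest sample
    step = max((1 << s) // num_buckets, 1)  # width of one bucket in [2 ** S, 2 ** (S+1)]
    return -(-longest // step) * step       # round up to the next bucket boundary


print(bucket_size(2000))  # 2000 falls in [1024, 2048], bucket width 64 -> padded to 2048
```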
## How to run

To run training, build a Docker image using the following Dockerfile:

```Dockerfile
FROM vault.habana.ai/gaudi-docker/1.21.0/rhel9.4/habanalabs/pytorch-installer-2.6.0:1.21.0-555

ARG CMAKE_ARGS="-DGGML_NATIVE=off"

WORKDIR /app
RUN pip install git+https://github.com/instructlab/[email protected]

WORKDIR /app
RUN pip install git+https://github.com/huggingface/[email protected]
```
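Inside the resulting container you can verify that the Habana PyTorch bridge is importable; this is the same check that the `is_torch_hpu_available()` helper added in `src/instructlab/training/hpu_utils.py` performs:

```Python
try:
    import habana_frameworks.torch.core  # noqa: F401
    print("HPU support detected")
except ImportError:
    print("HPU support NOT available, check the habana_frameworks installation")
```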
Then make the following changes to the config file:

```YAML
train:
  device: hpu
  distributed_backend: fsdp
  fsdp_cpu_offload_optimizer: false
  is_padding_free: true
  pipeline: accelerated
  disable_flash_attn: true
```
And finally run this command line:

```BASH
ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl
```
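The same run can also be started programmatically through `run_training()`, which this commit extends to forward `device` to the training subprocess as `--device`. This is only a hedged sketch; apart from `device`, the `TrainingArgs`/`TorchrunArgs` field names below are assumptions based on the library's typical usage and may differ in your installed version:

```Python
from instructlab.training import TorchrunArgs, TrainingArgs, run_training

# Field names other than `device` are assumed; `device` is the option added by this commit.
train_args = TrainingArgs(
    model_path="ibm-granite/granite-7b-base",
    data_path="./data.jsonl",
    ckpt_output_dir="./checkpoints",
    data_output_dir="./data_outputs",
    max_seq_len=4096,
    max_batch_len=60000,
    num_epochs=1,
    effective_batch_size=128,
    learning_rate=2e-6,
    warmup_steps=25,
    save_samples=0,
    device="hpu",
)
torch_args = TorchrunArgs(
    nnodes=1, nproc_per_node=8, node_rank=0, rdzv_id=123, rdzv_endpoint="127.0.0.1:12345"
)
run_training(torch_args=torch_args, train_args=train_args)
```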
docs/hpu_pic/bucketing_vs_padding.png

Binary file added (30.9 KB)

src/instructlab/training/accelerator.py

Lines changed: 12 additions & 1 deletion
@@ -3,7 +3,6 @@
 from typing import Callable, Optional

 # Third Party
-from accelerate import Accelerator as TransformersAccel
 from torch.utils.data import DataLoader
 from transformers import get_scheduler
 import torch
@@ -32,6 +31,7 @@ def __init__(
         deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
         deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
         fsdp_cpu_offload_params: Optional[bool] = False,
+        device: Optional[str] = None,
     ):
         self.samples_per_gpu = samples_per_gpu
         self.save_samples = save_samples
@@ -48,6 +48,7 @@ def __init__(
             deepspeed_cpu_offload_optimizer_ratio
         )
         self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
+        self.device_str = device

         if self.distributed_framework == DistributedBackend.DEEPSPEED:
             # Standard
@@ -69,6 +70,12 @@ def __init__(
                 "fsdp_plugin": self.get_fsdp_config(),
                 "mixed_precision": "bf16",
             }
+
+        if device == "hpu":
+            from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
+        else:
+            from accelerate import Accelerator as TransformersAccel
+
         self.accelerator = TransformersAccel(
             **accel_args,
         )
@@ -160,6 +167,10 @@ def get_fsdp_config(self):
             cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
         )

+        if self.device_str == "hpu":
+            fsdp_plugin.use_orig_params = True
+            fsdp_plugin.sync_module_states = True
+
         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
         if self.model.lora_config is not None:

src/instructlab/training/config.py

Lines changed: 2 additions & 0 deletions
@@ -246,3 +246,5 @@ class TrainingArgs(BaseModel):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
         default="INFO"
     )
+
+    device: Optional[str] = None
src/instructlab/training/hpu_utils.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
import torch
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
    try:
        import habana_frameworks.torch.core  # noqa: F401
    except ImportError:
        return False
    return True


def simple_bucket(length):
    """
    This bucket algorithm merely relies on the given number instead of based on
    slicing the known (min, max) range for several reasons:
        1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
           (min, max) sequence length of each rank will be much smaller than the
           (min, max) sequence length of the dataset. Bucketing on the
           (min, max) sequence length of the dataset is not practical
        2) The (min, max) sequence length of a given rank is unknown until
           finishing 1 epoch since the packing is done on the fly
        3) Due to the shuffling, the (min, max) sequence length of a given rank
           may vary between ranks. Once the (min, max) sequence length of a
           given rank changes, the bucketing also needs adjustment

    This bucket algorithm is based on the most significant set bit of the input number.
    It first checks what the most significant set bit is, assuming it's bit "S",
    and then slices the range [2 ** S, 2 ** (S+1)] into buckets of the same size.
    By default the range is divided into 16 buckets, so the bucket size will be
    2 ** (S - 4).
    For example, 0b10001 will be padded to 0b10010.
    This approach can limit the overhead of bucketing (at most 1/16 of the input
    number) and also prevent recompilation due to a too small bucket size.
    """
    l = length
    msb = 0
    while l > 0:
        msb += 1
        l = l // 2

    align = (1 << (msb - 4)) if msb >= 4 else 1

    return (length + align - 1) // align * align


def bucket(length):
    return simple_bucket(length)
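For reference, a few values produced by the helper above (the module path matches the `from instructlab.training.hpu_utils import ...` line added to `main_ds.py` in this commit):

```Python
from instructlab.training.hpu_utils import bucket

print(bucket(1000))  # -> 1024  (align = 64 for lengths in [512, 1024))
print(bucket(2000))  # -> 2048  (align = 128 for lengths in [1024, 2048))
print(bucket(3000))  # -> 3072  (align = 256 for lengths in [2048, 4096))
```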

src/instructlab/training/main_ds.py

Lines changed: 62 additions & 13 deletions
@@ -33,6 +33,14 @@
     UserWarning,
 )

+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -122,7 +130,7 @@ def train(
         if local_rank == 0:
             inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}")

-        # blast through the batches in the train loader up to the last step within the epoch.
+        # blast through the batches in the train loader up to the last step within the epoch.
         for batch in accelerator.train_loader:
             if global_step <= args.last_step:
                 # in the case of resuming, last_step > 0
@@ -137,10 +145,19 @@ def train(
             micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
             total_length = float(torch.tensor([batch.pop("total_length")]))
             for k in batch:
-                batch[k] = batch[k].to(local_rank)
+                batch[k] = batch[k].to('hpu' if args.device == "hpu" else local_rank)
+
+            hpu_args = {}
+            if args.device == "hpu":
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -177,8 +194,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if args.device == "hpu":
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -200,8 +223,8 @@ def train(
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
@@ -234,7 +257,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if args.device != "hpu":
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -312,17 +338,24 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if args.device == "hpu":
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    if timeout is not None:
-        torch.distributed.init_process_group(timeout=timeout)
-    else:
-        torch.distributed.init_process_group()
+    backend = "hccl" if args.device == "hpu" else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)

     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if args.device == "hpu":
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()

@@ -369,6 +402,7 @@ def main(args):
         flash_enabled=flash_enabled,
         noise_alpha=args.NEFTune_alpha,
         lora_quant_bits=args.lora_quant_bits,
+        device=args.device,
     )

     args.base_model_args = m.base_model_args
@@ -407,6 +441,7 @@ def main(args):
         samples_per_gpu=args.samples_per_gpu,
         sampler=args.sampler,
         seed=args.seed,
+        device=args.device,
     )
     if len(train_loader) == 0:
         # this happens sometimes when we have more GPUs than data to process. In this case
@@ -426,6 +461,7 @@ def main(args):
         samples_per_gpu=args.samples_per_gpu,
         sampler=args.sampler,
         seed=args.seed,
+        device=args.device,
     )

     if args.local_rank == 0:
@@ -457,6 +493,7 @@ def main(args):
         deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
         fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
         save_samples=args.save_samples,
+        device=args.device,
     )
     # optimizer needs model that has been prepared by accelerator
     # and then accelerator needs to be prepared AGAIN once optimizer is initialized
@@ -636,6 +673,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.keep_last_checkpoint_only:
         command.append("--keep_last_checkpoint_only")

+    command.append(
+        f"--device={train_args.device}"
+    )
+
     logger.info("Running training command as subprocess: %s", " ".join(command))
     process = None
     interrupt: KeyboardInterrupt | Exception | None = None
@@ -837,6 +878,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="PyTorch device to use.",
+    )
+
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
