From c6900f4efbd4688ff0503621acaf62693c17720d Mon Sep 17 00:00:00 2001 From: LZHgrla Date: Fri, 14 Jul 2023 17:43:22 +0800 Subject: [PATCH 1/5] fix pre-commit --- configs/_base_/datasets/aplaca.py | 37 ++-- configs/_base_/datasets/mmlu_fs.py | 10 +- configs/_base_/datasets/mmlu_zs.py | 10 +- configs/_base_/datasets/oasst1.py | 23 +- configs/_base_/default_runtime.py | 6 +- configs/_base_/schedules/guanaco.py | 19 +- configs/_base_/schedules/guanaco_deepspeed.py | 9 +- configs/alpaca/alpaca_standford.py | 37 ++-- configs/alpaca/alpaca_standford_qlora.py | 58 +++-- configs/datasets/alpaca.py | 165 -------------- configs/guanaco/gunaco_llama_7B.py | 77 ++++--- mmchat/__init__.py | 6 +- mmchat/datasets/__init__.py | 4 +- mmchat/datasets/huggingface.py | 48 ++--- mmchat/evaluation/__init__.py | 4 +- mmchat/evaluation/metrics/mmlu_metric.py | 201 ++++++++++-------- mmchat/models/__init__.py | 9 +- mmchat/models/algorithms/__init__.py | 6 +- mmchat/models/algorithms/sft.py | 73 ++++--- mmchat/models/algorithms/sft_distill.py | 3 +- mmchat/models/algorithms/sft_lora.py | 3 +- mmchat/models/algorithms/sft_lora_distill.py | 3 +- mmchat/models/algorithms/sft_qlora.py | 24 +-- mmchat/models/algorithms/sft_qlora_distill.py | 3 +- mmchat/models/utils/__init__.py | 4 +- mmchat/models/utils/data_processor.py | 67 ++++-- requirements.txt | 5 +- tools/dist_train.sh | 2 +- tools/test.py | 6 +- tools/train.py | 6 +- 30 files changed, 396 insertions(+), 532 deletions(-) delete mode 100644 configs/datasets/alpaca.py diff --git a/configs/_base_/datasets/aplaca.py b/configs/_base_/datasets/aplaca.py index 489090f27..64d929c06 100644 --- a/configs/_base_/datasets/aplaca.py +++ b/configs/_base_/datasets/aplaca.py @@ -1,29 +1,30 @@ from datasets import load_dataset -from mmchat.datasets import process_hf_dataset from mmengine.dataset import DefaultSampler + +from mmchat.datasets import process_hf_dataset + _alpaca = dict( - type = process_hf_dataset, - dataset = dict( - type = load_dataset, - path = 'tatsu-lab/alpaca', + type=process_hf_dataset, + dataset=dict( + type=load_dataset, + path='tatsu-lab/alpaca', ), # map_fn = extract_alpaca_dataset, - prompt_input_format = ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: " - ), - prompt_no_input_format= ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response: " - ), + prompt_input_format=( + 'Below is an instruction that describes a task, ' + 'paired with an input that provides further context. ' + 'Write a response that appropriately completes the request.\n\n' + '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n' + '### Response: '), + prompt_no_input_format=( + 'Below is an instruction that describes a task. 
' + 'Write a response that appropriately completes the request.\n\n' + '### Instruction:\n{instruction}\n\n### Response: '), remove_columns=['instruction'], ) train_dataloader = dict( batch_size=1, num_workers=2, - dataset = _alpaca, - sampler=dict(type=DefaultSampler, shuffle=True) -) \ No newline at end of file + dataset=_alpaca, + sampler=dict(type=DefaultSampler, shuffle=True)) diff --git a/configs/_base_/datasets/mmlu_fs.py b/configs/_base_/datasets/mmlu_fs.py index 2179deb96..4cffd9163 100644 --- a/configs/_base_/datasets/mmlu_fs.py +++ b/configs/_base_/datasets/mmlu_fs.py @@ -1,7 +1,7 @@ from datasets import load_dataset -from mmchat.datasets import process_hf_dataset from mmengine.dataset import DefaultSampler +from mmchat.datasets import process_hf_dataset data_root = 'data/mmlu/' @@ -13,9 +13,7 @@ test=data_root + 'five_shot_mmlu_test.json')) val_mmlu_fs = dict( - type=process_hf_dataset, - dataset=mmlu_fs_dataset, - mode='val') + type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='val') val_dataloader = dict( batch_size=1, num_workers=1, @@ -23,9 +21,7 @@ sampler=dict(type=DefaultSampler, shuffle=False)) test_mmlu_fs = dict( - type=process_hf_dataset, - dataset=mmlu_fs_dataset, - mode='test') + type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='test') test_dataloader = dict( batch_size=1, num_workers=1, diff --git a/configs/_base_/datasets/mmlu_zs.py b/configs/_base_/datasets/mmlu_zs.py index 5be62de31..29df361c8 100644 --- a/configs/_base_/datasets/mmlu_zs.py +++ b/configs/_base_/datasets/mmlu_zs.py @@ -1,7 +1,7 @@ from datasets import load_dataset -from mmchat.datasets import process_hf_dataset from mmengine.dataset import DefaultSampler +from mmchat.datasets import process_hf_dataset data_root = 'data/mmlu/' @@ -13,9 +13,7 @@ test=data_root + 'zero_shot_mmlu_test.json')) val_mmlu_zs = dict( - type=process_hf_dataset, - dataset=mmlu_zs_dataset, - mode='val') + type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='val') val_dataloader = dict( batch_size=1, num_workers=1, @@ -23,9 +21,7 @@ sampler=dict(type=DefaultSampler, shuffle=False)) test_mmlu_zs = dict( - type=process_hf_dataset, - dataset=mmlu_zs_dataset, - mode='test') + type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='test') test_dataloader = dict( batch_size=1, num_workers=1, diff --git a/configs/_base_/datasets/oasst1.py b/configs/_base_/datasets/oasst1.py index 2784ce26e..7368c2e61 100644 --- a/configs/_base_/datasets/oasst1.py +++ b/configs/_base_/datasets/oasst1.py @@ -1,7 +1,3 @@ -from datasets import load_dataset -from mmchat.datasets import process_hf_dataset -from mmengine.dataset import DefaultSampler - """ ------------ Dataset Meta Info (after `load_dataset`) ------------ @@ -31,19 +27,22 @@ """ +from datasets import load_dataset +from mmengine.dataset import DefaultSampler + +from mmchat.datasets import process_hf_dataset oasst1 = dict( - type = process_hf_dataset, - dataset = dict( - type = load_dataset, - path = 'timdettmers/openassistant-guanaco', + type=process_hf_dataset, + dataset=dict( + type=load_dataset, + path='timdettmers/openassistant-guanaco', ), - map_fn = "lambda x: {'input': '', 'output': x['text']}", + map_fn="lambda x: {'input': '', 'output': x['text']}", ) train_dataloader = dict( batch_size=16, num_workers=2, - dataset = oasst1, - sampler=dict(type=DefaultSampler, shuffle=True) -) \ No newline at end of file + dataset=oasst1, + sampler=dict(type=DefaultSampler, shuffle=True)) diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 
ae7a42185..444aaae62 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -1,5 +1,5 @@ -from mmengine.hooks import (IterTimerHook, LoggerHook, ParamSchedulerHook, - CheckpointHook, DistSamplerSeedHook) +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) # defaults to use registries in mmpretrain default_scope = 'mmchat' @@ -47,4 +47,4 @@ resume = False # Defaults to use random seed and disable `deterministic` -randomness = dict(seed=None, deterministic=False) \ No newline at end of file +randomness = dict(seed=None, deterministic=False) diff --git a/configs/_base_/schedules/guanaco.py b/configs/_base_/schedules/guanaco.py index e8ae2259f..89909190f 100644 --- a/configs/_base_/schedules/guanaco.py +++ b/configs/_base_/schedules/guanaco.py @@ -1,16 +1,12 @@ -from mmengine.optim import OptimWrapper -from mmengine.optim import LinearLR, ConstantLR -from mmengine.runner import IterBasedTrainLoop -from torch.optim import AdamW from bitsandbytes.optim import PagedAdamW32bit +from mmengine.optim import ConstantLR, LinearLR, OptimWrapper + # optimizer optim_wrapper = dict( type=OptimWrapper, - optimizer=dict( - type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0), + optimizer=dict(type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0), clip_grad=dict(max_norm=0.3, error_if_nonfinite=True), - ) - +) # learning policy param_scheduler = [ @@ -32,11 +28,8 @@ ] # train, val, test setting -train_cfg = dict( - by_epoch=True, - max_epochs = 3, val_interval=1) - +train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1) # NOTE: `auto_scale_lr` is for automatically scaling LR # based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=1) \ No newline at end of file +auto_scale_lr = dict(base_batch_size=1) diff --git a/configs/_base_/schedules/guanaco_deepspeed.py b/configs/_base_/schedules/guanaco_deepspeed.py index 9af453153..73a5d9a5a 100644 --- a/configs/_base_/schedules/guanaco_deepspeed.py +++ b/configs/_base_/schedules/guanaco_deepspeed.py @@ -1,12 +1,11 @@ -from mmengine.optim import DeepSpeedOptimWrapper from mmengine._strategy import DeepSpeedStrategy +from mmengine.optim import DeepSpeedOptimWrapper from torch.optim import AdamW + # optimizer optim_wrapper = dict( type=DeepSpeedOptimWrapper, - optimizer=dict( - type=AdamW, lr=0.0002, weight_decay=0.0)) - + optimizer=dict(type=AdamW, lr=0.0002, weight_decay=0.0)) # training strategy strategy = dict( @@ -65,4 +64,4 @@ # NOTE: `auto_scale_lr` is for automatically scaling LR # based on the actual training batch size. 
-auto_scale_lr = dict(base_batch_size=64) \ No newline at end of file +auto_scale_lr = dict(base_batch_size=64) diff --git a/configs/alpaca/alpaca_standford.py b/configs/alpaca/alpaca_standford.py index 1d23e49fe..124915987 100644 --- a/configs/alpaca/alpaca_standford.py +++ b/configs/alpaca/alpaca_standford.py @@ -1,33 +1,30 @@ from mmengine.config import read_base from transformers import AutoModelForCausalLM, AutoTokenizer -from mmchat.models import SupervisedFinetune, DataProcesorForCausalLM -from transformers import BitsAndBytesConfig -from peft import LoraConfig -from dataclasses import dataclass -import torch + +from mmchat.models import DataProcesorForCausalLM, SupervisedFinetune + with read_base(): - from .._base_.datasets.aplaca import * - from .._base_.schedules.guanaco import * - from .._base_.default_runtime import * + from .._base_.datasets.aplaca import * # noqa: F401,F403 + from .._base_.default_runtime import * # noqa: F401,F403 + from .._base_.schedules.guanaco import * # noqa: F401,F403 +pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b' model = dict( - type = SupervisedFinetune, - data_preprocessor = dict( + type=SupervisedFinetune, + data_preprocessor=dict( type=DataProcesorForCausalLM, tokenizer=dict( type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', - use_fast = False, + pretrained_model_name_or_path=pretrained_model_name_or_path, + use_fast=False, ), - source_max_len = 512, - target_max_len = 512, - train_on_source = False, - predict_with_generate = False, + source_max_len=512, + target_max_len=512, + train_on_source=False, + predict_with_generate=False, ), - llm = dict( + llm=dict( type=AutoModelForCausalLM.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', + pretrained_model_name_or_path=pretrained_model_name_or_path, ), - ) - diff --git a/configs/alpaca/alpaca_standford_qlora.py b/configs/alpaca/alpaca_standford_qlora.py index 2b468915c..2abea8d5d 100644 --- a/configs/alpaca/alpaca_standford_qlora.py +++ b/configs/alpaca/alpaca_standford_qlora.py @@ -1,50 +1,46 @@ +import torch from mmengine.config import read_base -from transformers import AutoModelForCausalLM, AutoTokenizer -from mmchat.models import SupervisedQloraFinetune, DataProcesorForCausalLM -from transformers import BitsAndBytesConfig from peft import LoraConfig -from dataclasses import dataclass -import torch +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from mmchat.models import DataProcesorForCausalLM, SupervisedQloraFinetune + with read_base(): - from .._base_.datasets.aplaca import * - from .._base_.schedules.guanaco import * - from .._base_.default_runtime import * + from .._base_.datasets.aplaca import * # noqa: F401,F403 + from .._base_.default_runtime import * # noqa: F401,F403 + from .._base_.schedules.guanaco import * # noqa: F401,F403 +pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b' model = dict( - type = SupervisedQloraFinetune, - data_preprocessor = dict( + type=SupervisedQloraFinetune, + data_preprocessor=dict( type=DataProcesorForCausalLM, tokenizer=dict( type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', - use_fast = False, + pretrained_model_name_or_path=pretrained_model_name_or_path, + use_fast=False, ), - source_max_len = 512, - target_max_len = 512, - train_on_source = False, - predict_with_generate = False, + 
source_max_len=512, + target_max_len=512, + train_on_source=False, + predict_with_generate=False, ), - llm = dict( + llm=dict( type=AutoModelForCausalLM.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', + pretrained_model_name_or_path=pretrained_model_name_or_path, quantization_config=dict( - type = BitsAndBytesConfig, + type=BitsAndBytesConfig, load_in_4bit=True, load_in_8bit=False, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type = 'nf4' - ) - ), + bnb_4bit_quant_type='nf4')), lora=dict( type=LoraConfig, - r = 64, - lora_alpha = 16, - lora_dropout = 0.1, - bias = 'none', - task_type = 'CAUSAL_LM' - ) - -) - + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) diff --git a/configs/datasets/alpaca.py b/configs/datasets/alpaca.py deleted file mode 100644 index eab8beaf7..000000000 --- a/configs/datasets/alpaca.py +++ /dev/null @@ -1,165 +0,0 @@ -from datasets import load_dataset -from mmchat.datasets import process_hf_dataset, DataCollatorForCausalLM -from mmengine.dataset import DefaultSampler -from transformers import AutoModel, AutoTokenizer -from mmchat.models import SupervisedFinetune -from mmchat.models.utils import DataProcesorForCausalLM -from mmchat.visualization import AttentionScoreVisualizer - - - -""" ------------- Dataset Example (after `load_dataset`) ------------ - -DatasetDict({ - train: Dataset({ - features: ['instruction', 'input', 'output', 'text'], - num_rows: 52002 - }) -}) - ------------- Dataset Example (after `process_hf_dataset`) ------------ - -DatasetDict({ - train: Dataset({ - features: ['text', 'input', 'output'], - num_rows: 9846 - }) - test: Dataset({ - features: ['text', 'input', 'output'], - num_rows: 518 - }) -}) - -""" - -alpaca = dict( - type = process_hf_dataset, - dataset = dict( - type = load_dataset, - path = 'tatsu-lab/alpaca', - ), - # map_fn = extract_alpaca_dataset, - prompt_input_format = ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: " - ), - prompt_no_input_format= ( - "Below is an instruction that describes a task. 
" - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response: " - ), - remove_columns=['instruction'], -) - - - -oasst1 = dict( - type = process_hf_dataset, - dataset = dict( - type = load_dataset, - path = 'timdettmers/openassistant-guanaco', - ), - map_fn = lambda x: {'input': '', 'output': x['text']}, -) - - -train_dataloader = dict( - batch_size = 32, - num_workers = 8, - dataset = oasst1, - sampler = dict(type=DefaultSampler, shuffle=True), - persistent_workers = True, -) - -model = dict( - type = SupervisedFinetune, - data_preprocessor = dict( - type=DataProcesorForCausalLM, - tokenizer=dict( - type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', - use_fast = False, - ), - source_max_len = 512, - target_max_len = 512, - train_on_source = False, - predict_with_generate = False, - ), - llm = dict( - type=AutoModel.from_pretrained, - pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b', - ), - -) - - -# optimizer -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) -# learning policy -param_scheduler = dict( - type='MultiStepLR', by_epoch=True, milestones=[100, 150], gamma=0.1) - -# train, val, test setting -train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) -# val_cfg = dict() -# test_cfg = dict() - -# NOTE: `auto_scale_lr` is for automatically scaling LR -# based on the actual training batch size. -auto_scale_lr = dict(base_batch_size=128) - - -# defaults to use registries in mmpretrain -default_scope = 'mmchat' - -# configure default hooks -default_hooks = dict( - # record the time of every iteration. - timer=dict(type='IterTimerHook'), - - # print log every 100 iterations. - logger=dict(type='LoggerHook', interval=100), - - # enable the parameter scheduler. - param_scheduler=dict(type='ParamSchedulerHook'), - - # save checkpoint per epoch. - checkpoint=dict(type='CheckpointHook', interval=1), - - # set sampler seed in distributed evrionment. - sampler_seed=dict(type='DistSamplerSeedHook'), - - # validation results visualization, set True to enable it. 
- # visualization=dict(type='VisualizationHook', enable=False), -) - -# configure environment -env_cfg = dict( - # whether to enable cudnn benchmark - cudnn_benchmark=False, - - # set multi process parameters - mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), - - # set distributed parameters - dist_cfg=dict(backend='nccl'), -) - -# # set visualizer -vis_backends = [dict(type='LocalVisBackend')] -visualizer = None - -# set log level -log_level = 'INFO' - -# load from which checkpoint -load_from = None - -# whether to resume training from the loaded checkpoint -resume = False - -# Defaults to use random seed and disable `deterministic` -randomness = dict(seed=None, deterministic=False) \ No newline at end of file diff --git a/configs/guanaco/gunaco_llama_7B.py b/configs/guanaco/gunaco_llama_7B.py index f8d299b29..e21d1af94 100644 --- a/configs/guanaco/gunaco_llama_7B.py +++ b/configs/guanaco/gunaco_llama_7B.py @@ -1,64 +1,59 @@ +import torch from mmengine.config import read_base -from transformers import AutoModelForCausalLM, AutoTokenizer -from mmchat.models import SupervisedQloraFinetune, DataProcesorForCausalLM -from transformers import BitsAndBytesConfig from peft import LoraConfig -from dataclasses import dataclass -import torch +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from mmchat.models import DataProcesorForCausalLM, SupervisedQloraFinetune + with read_base(): - from .._base_.datasets.oasst1 import * - from .._base_.datasets.mmlu_fs import * - from .._base_.schedules.guanaco import * - from .._base_.default_runtime import * + from .._base_.datasets.mmlu_fs import * # noqa: F401,F403 + from .._base_.datasets.oasst1 import * # noqa: F401,F403 + from .._base_.default_runtime import * # noqa: F401,F403 + from .._base_.schedules.guanaco import * # noqa: F401,F403 +pretrained_model_name_or_path = '/nvme/share_data/llama-7b' model = dict( - type = SupervisedQloraFinetune, - data_preprocessor = dict( + type=SupervisedQloraFinetune, + data_preprocessor=dict( type=DataProcesorForCausalLM, tokenizer=dict( type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path = '/nvme/share_data/llama-7b', - use_fast = False, - padding_side="right", - ), - source_max_len = 2048, - target_max_len = 512, - train_on_source = False, - predict_with_generate = False, - ), - llm = dict( + pretrained_model_name_or_path=pretrained_model_name_or_path, + use_fast=False, + padding_side='right'), + source_max_len=2048, + target_max_len=512, + train_on_source=False, + predict_with_generate=False), + llm=dict( type=AutoModelForCausalLM.from_pretrained, - pretrained_model_name_or_path = '/nvme/share_data/llama-7b', - torch_dtype = torch.float16, + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch.float16, quantization_config=dict( - type = BitsAndBytesConfig, + type=BitsAndBytesConfig, load_in_4bit=True, load_in_8bit=False, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type = 'nf4' - ) - ), + bnb_4bit_quant_type='nf4')), lora=dict( type=LoraConfig, - r = 64, - lora_alpha = 16, - lora_dropout = 0.1, - bias = 'none', - task_type = 'CAUSAL_LM' - ) - -) + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) -val_evaluator['tokenizer'] = dict( +val_evaluator['tokenizer'] = dict( # noqa: F405 type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path='/nvme/share_data/llama-7b', + 
pretrained_model_name_or_path=pretrained_model_name_or_path, use_fast=False, - padding_side="right") + padding_side='right') -test_evaluator['tokenizer'] = dict( +test_evaluator['tokenizer'] = dict( # noqa: F405 type=AutoTokenizer.from_pretrained, - pretrained_model_name_or_path='/nvme/share_data/llama-7b', + pretrained_model_name_or_path=pretrained_model_name_or_path, use_fast=False, - padding_side="right") + padding_side='right') diff --git a/mmchat/__init__.py b/mmchat/__init__.py index 7127db9aa..eeacb447c 100644 --- a/mmchat/__init__.py +++ b/mmchat/__init__.py @@ -1,3 +1,5 @@ -from .datasets import * # noqa: F401,F403 +from mmengine.utils import digit_version -# from .models import * # noqa: F401,F403 +from .version import __version__, version_info + +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmchat/datasets/__init__.py b/mmchat/datasets/__init__.py index 4c363cb9f..bc97494b5 100644 --- a/mmchat/datasets/__init__.py +++ b/mmchat/datasets/__init__.py @@ -1 +1,3 @@ -from .huggingface import process_hf_dataset \ No newline at end of file +from .huggingface import process_hf_dataset + +__all__ = ['process_hf_dataset'] diff --git a/mmchat/datasets/huggingface.py b/mmchat/datasets/huggingface.py index b7041b178..8f89c1415 100644 --- a/mmchat/datasets/huggingface.py +++ b/mmchat/datasets/huggingface.py @@ -1,23 +1,21 @@ +from mmchat.registry import DATASETS -from typing import Sequence -from mmchat.registry import DATASETS, TOKENIZER -from torch.utils.data import Dataset, DataLoader, ChainDataset -import transformers -import torch -from dataclasses import dataclass -import copy -from torch.nn.utils.rnn import pad_sequence IGNORE_INDEX = -100 -DEFAULT_PAD_TOKEN = "[PAD]" -def process_hf_dataset(dataset, mode='train', +DEFAULT_PAD_TOKEN = '[PAD]' + + +def process_hf_dataset(dataset, + mode='train', prompt_input_format=None, prompt_no_input_format=None, - map_fn=lambda x:x,remove_columns=[], rename_maps=[]): + map_fn=lambda x: x, + remove_columns=[], + rename_maps=[]): dataset = DATASETS.build(dataset) - + def _prompt_format(example): - if example.get("input", "") != "": + if example.get('input', '') != '': prompt_format = prompt_input_format else: prompt_format = prompt_no_input_format @@ -34,24 +32,8 @@ def _prompt_format(example): # Remove unused columns. 
if 'train' in dataset.column_names: - dataset = dataset.remove_columns( - [col for col in dataset.column_names['train'] if col not in ['input', 'output']] - ) + dataset = dataset.remove_columns([ + col for col in dataset.column_names['train'] + if col not in ['input', 'output'] + ]) return dataset[mode] - - -class HuggingFaceDataset(Dataset): - def __init__(self, hf_dataset, tokenizer, map_fn, remove_columns=[], rename_maps=[]) -> None: - super().__init__() - - dataset = DATASETS.build(dataset) - dataset = dataset.map(map_fn, remove_columns=remove_columns) - for old, new in rename_maps: - dataset = dataset.rename_column(old, new) - self.dataset = dataset - - self.tokenizer = TOKENIZER.build(tokenizer) - - - - diff --git a/mmchat/evaluation/__init__.py b/mmchat/evaluation/__init__.py index 1761d1aa7..1678b3d53 100644 --- a/mmchat/evaluation/__init__.py +++ b/mmchat/evaluation/__init__.py @@ -1 +1,3 @@ -from .metrics import * +from .metrics import MMLUMetric + +__all__ = ['MMLUMetric'] diff --git a/mmchat/evaluation/metrics/mmlu_metric.py b/mmchat/evaluation/metrics/mmlu_metric.py index 56092e491..a29149922 100644 --- a/mmchat/evaluation/metrics/mmlu_metric.py +++ b/mmchat/evaluation/metrics/mmlu_metric.py @@ -1,11 +1,11 @@ -from typing import Any, List, Optional, Sequence, Union -from rich.console import Console -from rich.table import Table +from typing import Any, Sequence import numpy as np import torch from mmengine.evaluator import BaseMetric from mmengine.logging import MMLogger +from rich.console import Console +from rich.table import Table from mmchat.registry import METRICS, TOKENIZER @@ -14,83 +14,89 @@ class MMLUMetric(BaseMetric): METAINFO = { 'subcategories': { - "abstract_algebra": ["math"], - "anatomy": ["health"], - "astronomy": ["physics"], - "business_ethics": ["business"], - "clinical_knowledge": ["health"], - "college_biology": ["biology"], - "college_chemistry": ["chemistry"], - "college_computer_science": ["computer science"], - "college_mathematics": ["math"], - "college_medicine": ["health"], - "college_physics": ["physics"], - "computer_security": ["computer science"], - "conceptual_physics": ["physics"], - "econometrics": ["economics"], - "electrical_engineering": ["engineering"], - "elementary_mathematics": ["math"], - "formal_logic": ["philosophy"], - "global_facts": ["other"], - "high_school_biology": ["biology"], - "high_school_chemistry": ["chemistry"], - "high_school_computer_science": ["computer science"], - "high_school_european_history": ["history"], - "high_school_geography": ["geography"], - "high_school_government_and_politics": ["politics"], - "high_school_macroeconomics": ["economics"], - "high_school_mathematics": ["math"], - "high_school_microeconomics": ["economics"], - "high_school_physics": ["physics"], - "high_school_psychology": ["psychology"], - "high_school_statistics": ["math"], - "high_school_us_history": ["history"], - "high_school_world_history": ["history"], - "human_aging": ["health"], - "human_sexuality": ["culture"], - "international_law": ["law"], - "jurisprudence": ["law"], - "logical_fallacies": ["philosophy"], - "machine_learning": ["computer science"], - "management": ["business"], - "marketing": ["business"], - "medical_genetics": ["health"], - "miscellaneous": ["other"], - "moral_disputes": ["philosophy"], - "moral_scenarios": ["philosophy"], - "nutrition": ["health"], - "philosophy": ["philosophy"], - "prehistory": ["history"], - "professional_accounting": ["other"], - "professional_law": ["law"], - "professional_medicine": 
["health"], - "professional_psychology": ["psychology"], - "public_relations": ["politics"], - "security_studies": ["politics"], - "sociology": ["culture"], - "us_foreign_policy": ["politics"], - "virology": ["health"], - "world_religions": ["philosophy"], + 'abstract_algebra': ['math'], + 'anatomy': ['health'], + 'astronomy': ['physics'], + 'business_ethics': ['business'], + 'clinical_knowledge': ['health'], + 'college_biology': ['biology'], + 'college_chemistry': ['chemistry'], + 'college_computer_science': ['computer science'], + 'college_mathematics': ['math'], + 'college_medicine': ['health'], + 'college_physics': ['physics'], + 'computer_security': ['computer science'], + 'conceptual_physics': ['physics'], + 'econometrics': ['economics'], + 'electrical_engineering': ['engineering'], + 'elementary_mathematics': ['math'], + 'formal_logic': ['philosophy'], + 'global_facts': ['other'], + 'high_school_biology': ['biology'], + 'high_school_chemistry': ['chemistry'], + 'high_school_computer_science': ['computer science'], + 'high_school_european_history': ['history'], + 'high_school_geography': ['geography'], + 'high_school_government_and_politics': ['politics'], + 'high_school_macroeconomics': ['economics'], + 'high_school_mathematics': ['math'], + 'high_school_microeconomics': ['economics'], + 'high_school_physics': ['physics'], + 'high_school_psychology': ['psychology'], + 'high_school_statistics': ['math'], + 'high_school_us_history': ['history'], + 'high_school_world_history': ['history'], + 'human_aging': ['health'], + 'human_sexuality': ['culture'], + 'international_law': ['law'], + 'jurisprudence': ['law'], + 'logical_fallacies': ['philosophy'], + 'machine_learning': ['computer science'], + 'management': ['business'], + 'marketing': ['business'], + 'medical_genetics': ['health'], + 'miscellaneous': ['other'], + 'moral_disputes': ['philosophy'], + 'moral_scenarios': ['philosophy'], + 'nutrition': ['health'], + 'philosophy': ['philosophy'], + 'prehistory': ['history'], + 'professional_accounting': ['other'], + 'professional_law': ['law'], + 'professional_medicine': ['health'], + 'professional_psychology': ['psychology'], + 'public_relations': ['politics'], + 'security_studies': ['politics'], + 'sociology': ['culture'], + 'us_foreign_policy': ['politics'], + 'virology': ['health'], + 'world_religions': ['philosophy'], }, 'categories': { - "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"], - "humanities": ["history", "philosophy", "law"], - "social sciences": ["politics", "culture", "economics", "geography", "psychology"], - "other (business, health, misc.)": ["other", "business", "health"], + 'STEM': [ + 'physics', 'chemistry', 'biology', 'computer science', 'math', + 'engineering' + ], + 'humanities': ['history', 'philosophy', 'law'], + 'social sciences': + ['politics', 'culture', 'economics', 'geography', 'psychology'], + 'other (business, health, misc.)': ['other', 'business', 'health'], }, } - METAINFO['subcategories_list'] = list(set([subcat for subcats in METAINFO['subcategories'].values() - for subcat in subcats])) + METAINFO['subcategories_list'] = list({ + subcat + for subcats in METAINFO['subcategories'].values() for subcat in subcats + }) def __init__(self, tokenizer, *args, **kwargs): super().__init__(*args, **kwargs) self.logger: MMLogger = MMLogger.get_current_instance() tokenizer = TOKENIZER.build(tokenizer) self.abcd_idx = [ - tokenizer("A", add_special_tokens=False).input_ids[0], - tokenizer("B", 
add_special_tokens=False).input_ids[0], - tokenizer("C", add_special_tokens=False).input_ids[0], - tokenizer("D", add_special_tokens=False).input_ids[0], + tokenizer('A', add_special_tokens=False).input_ids[0], + tokenizer('B', add_special_tokens=False).input_ids[0], + tokenizer('C', add_special_tokens=False).input_ids[0], + tokenizer('D', add_special_tokens=False).input_ids[0], ] @staticmethod @@ -99,7 +105,7 @@ def ABCD_to_0123(abcd): @staticmethod def accuracy(preds, gts): - """Computes the accuracy for preds and gts""" + """Computes the accuracy for preds and gts.""" correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)] acc = np.mean(correct) * 100 return acc @@ -121,7 +127,8 @@ def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None: pred_logits = sample['logits'] labels = sample['labels'] labels_non_zero_id = (labels != -100).nonzero()[0][0] - pred_logtis_abcd = pred_logits[labels_non_zero_id-1, self.abcd_idx] + pred_logtis_abcd = pred_logits[labels_non_zero_id - 1, + self.abcd_idx] pred = torch.argmax(pred_logtis_abcd).item() preds.append(pred) self.results.append((subject, pred, gt)) @@ -136,9 +143,27 @@ def compute_metrics(self, results: list) -> dict: dict: The computed metrics. The keys are the names of the metrics, and the values are corresponding results. """ - subjects_results = {subject: {'preds': [], 'gts': []} for subject in self.METAINFO['subcategories'].keys()} - subcats_results = {subcat: {'preds': [], 'gts': []} for subcat in self.METAINFO['subcategories_list']} - cats_results = {cat: {'preds': [], 'gts': []} for cat in self.METAINFO['categories'].keys()} + subjects_results = { + subject: { + 'preds': [], + 'gts': [] + } + for subject in self.METAINFO['subcategories'].keys() + } + subcats_results = { + subcat: { + 'preds': [], + 'gts': [] + } + for subcat in self.METAINFO['subcategories_list'] + } + cats_results = { + cat: { + 'preds': [], + 'gts': [] + } + for cat in self.METAINFO['categories'].keys() + } for subject, pred, gt in results: subjects_results[subject]['preds'].append(pred) subjects_results[subject]['gts'].append(gt) @@ -149,32 +174,40 @@ def compute_metrics(self, results: list) -> dict: for cat, subcats in self.METAINFO['categories'].items(): for subcat in subcats: if subcat in subcats_results: - cats_results[cat]['preds'].extend(subcats_results[subcat]['preds']) - cats_results[cat]['gts'].extend(subcats_results[subcat]['gts']) + cats_results[cat]['preds'].extend( + subcats_results[subcat]['preds']) + cats_results[cat]['gts'].extend( + subcats_results[subcat]['gts']) subjects_metrics = dict() subcats_metrics = dict() cats_metrics = dict() for subject in self.METAINFO['subcategories'].keys(): - assert len(subjects_results[subject]['preds']) == len(subjects_results[subject]['gts']) + assert len(subjects_results[subject]['preds']) == len( + subjects_results[subject]['gts']) if len(subjects_results[subject]['preds']) == 0: self.logger.info(f'Skip subject {subject} for mmlu') else: - score = self.accuracy(subjects_results[subject]['preds'], subjects_results[subject]['gts']) + score = self.accuracy(subjects_results[subject]['preds'], + subjects_results[subject]['gts']) subjects_metrics[f'{subject}'] = score for subcat in self.METAINFO['subcategories_list']: - assert len(subcats_results[subcat]['preds']) == len(subcats_results[subcat]['gts']) + assert len(subcats_results[subcat]['preds']) == len( + subcats_results[subcat]['gts']) if len(subcats_results[subcat]['preds']) == 0: self.logger.info(f'Skip subcategory {subcat} for mmlu') 
else: - score = self.accuracy(subcats_results[subcat]['preds'], subcats_results[subcat]['gts']) + score = self.accuracy(subcats_results[subcat]['preds'], + subcats_results[subcat]['gts']) subcats_metrics[f'{subcat}'] = score for cat in self.METAINFO['categories'].keys(): - assert len(cats_results[cat]['preds']) == len(cats_results[cat]['gts']) + assert len(cats_results[cat]['preds']) == len( + cats_results[cat]['gts']) if len(cats_results[cat]['preds']) == 0: self.logger.info(f'Skip category {cat} for mmlu') else: - score = self.accuracy(cats_results[cat]['preds'], cats_results[cat]['gts']) + score = self.accuracy(cats_results[cat]['preds'], + cats_results[cat]['gts']) cats_metrics[f'{cat}'] = score metrics = dict() @@ -196,7 +229,7 @@ def _print_results(self, table_metrics: dict) -> None: table.add_column('Categories', justify='left') table.add_column('Accuracy (%)', justify='right') for cat, acc in table_metrics.items(): - table.add_row(cat, '{:.1f}'.format(acc)) + table.add_row(cat, f'{acc:.1f}') with console.capture() as capture: console.print(table, end='') self.logger.info('\n' + capture.get()) diff --git a/mmchat/models/__init__.py b/mmchat/models/__init__.py index c8a78d27b..d21f4c5a4 100644 --- a/mmchat/models/__init__.py +++ b/mmchat/models/__init__.py @@ -1,3 +1,8 @@ -from .algorithms import SupervisedFinetune, SupervisedLoraFinetune, SupervisedQloraFinetune +from .algorithms import (SupervisedFinetune, SupervisedLoraFinetune, + SupervisedQloraFinetune) from .utils import DataProcesorForCausalLM -__all__ = ['SupervisedFinetune', 'SupervisedLoraFinetune', 'SupervisedQloraFinetune' ] \ No newline at end of file + +__all__ = [ + 'SupervisedFinetune', 'SupervisedLoraFinetune', 'SupervisedQloraFinetune', + 'DataProcesorForCausalLM' +] diff --git a/mmchat/models/algorithms/__init__.py b/mmchat/models/algorithms/__init__.py index b4c1b3400..408c84ff3 100644 --- a/mmchat/models/algorithms/__init__.py +++ b/mmchat/models/algorithms/__init__.py @@ -1,3 +1,7 @@ from .sft import SupervisedFinetune from .sft_lora import SupervisedLoraFinetune -from .sft_qlora import SupervisedQloraFinetune \ No newline at end of file +from .sft_qlora import SupervisedQloraFinetune + +__all__ = [ + 'SupervisedFinetune', 'SupervisedLoraFinetune', 'SupervisedQloraFinetune' +] diff --git a/mmchat/models/algorithms/sft.py b/mmchat/models/algorithms/sft.py index 0a2bcf458..ecd8561e7 100644 --- a/mmchat/models/algorithms/sft.py +++ b/mmchat/models/algorithms/sft.py @@ -1,15 +1,16 @@ - +import dataclasses from typing import Dict -from torch import nn -from mmengine.model import BaseModel -from mmengine import Config -from mmchat.registry import MODELS, TOKENIZER, LLM + import torch import transformers -import dataclasses from mmengine import print_log +from mmengine.model import BaseModel +from torch import nn + +from mmchat.registry import LLM + +DEFAULT_PAD_TOKEN = '[PAD]' -DEFAULT_PAD_TOKEN = "[PAD]" def traverse_dict(d): if isinstance(d, dict): @@ -27,6 +28,7 @@ def traverse_dict(d): for element in d: traverse_dict(element) + def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, @@ -34,7 +36,8 @@ def smart_tokenizer_and_embedding_resize( ): """Resize tokenizer and embedding. - Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + Note: This is the unoptimized version that may make your embedding size + not be divisible by 64. 
""" num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) @@ -43,8 +46,10 @@ def smart_tokenizer_and_embedding_resize( input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg @@ -64,26 +69,32 @@ def __init__(self, llm, data_preprocessor): model=self.llm, ) from transformers.models.llama import LlamaTokenizer - - if isinstance(self.tokenizer, LlamaTokenizer): - # LLaMA tokenizer may not have correct special tokens set. - # Check and add them if missing to prevent them from being parsed into different tokens. - # Note that these are present in the vocabulary. - # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. + + if isinstance(self.tokenizer, LlamaTokenizer): + # LLaMA tokenizer may not have correct special tokens set. + # Check and add them if missing to prevent them from being + # parsed into different tokens. + # Note that these are present in the vocabulary. + # Note also that `model.config.pad_token_id` is 0 which + # corresponds to `` token. print('Adding special tokens.') self.tokenizer.add_special_tokens({ - "eos_token": self.tokenizer.convert_ids_to_tokens(self.llm.config.eos_token_id), - "bos_token": self.tokenizer.convert_ids_to_tokens(self.llm.config.bos_token_id), - "unk_token": self.tokenizer.convert_ids_to_tokens( - self.llm.config.pad_token_id if self.llm.config.pad_token_id != -1 else self.tokenizer.pad_token_id - ), + 'eos_token': + self.tokenizer.convert_ids_to_tokens( + self.llm.config.eos_token_id), + 'bos_token': + self.tokenizer.convert_ids_to_tokens( + self.llm.config.bos_token_id), + 'unk_token': + self.tokenizer.convert_ids_to_tokens( + self.llm.config.pad_token_id if self.llm.config. 
+ pad_token_id != -1 else self.tokenizer.pad_token_id), }) @property def tokenizer(self): return self.data_preprocessor.tokenizer - def _build_from_cfg_or_module(self, cfg_or_mod, registry): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod @@ -91,10 +102,10 @@ def _build_from_cfg_or_module(self, cfg_or_mod, registry): traverse_dict(cfg_or_mod) return registry.build(cfg_or_mod) else: - raise NotImplemented - + raise NotImplementedError + def forward(self, data, data_samples=None, mode='loss'): - + if mode == 'loss': return self.compute_loss(data, data_samples) elif mode == 'predict': @@ -105,15 +116,17 @@ def forward(self, data, data_samples=None, mode='loss'): raise NotImplementedError def _forward(self, data, data_samples=None): - + outputs = self.llm(**data) - + return outputs def predict(self, data, data_samples=None): outputs = self.llm(**data) - logits_dict = [{'labels': labels, 'logits': logits} \ - for labels, logits in zip(data['labels'], outputs.logits)] + logits_dict = [{ + 'labels': labels, + 'logits': logits + } for labels, logits in zip(data['labels'], outputs.logits)] return logits_dict def compute_loss(self, data, data_samples=None): @@ -121,5 +134,3 @@ def compute_loss(self, data, data_samples=None): # import pdb;pdb.set_trace() loss_dict = {'loss_llm': outputs.loss} return loss_dict - - \ No newline at end of file diff --git a/mmchat/models/algorithms/sft_distill.py b/mmchat/models/algorithms/sft_distill.py index f1112bbb6..185da8ab6 100644 --- a/mmchat/models/algorithms/sft_distill.py +++ b/mmchat/models/algorithms/sft_distill.py @@ -1,6 +1,7 @@ from .sft import SupervisedFinetune + class DistillFinetune(SupervisedFinetune): def __init__(self, llm, tokenizer): - super().__init__(llm, tokenizer) \ No newline at end of file + super().__init__(llm, tokenizer) diff --git a/mmchat/models/algorithms/sft_lora.py b/mmchat/models/algorithms/sft_lora.py index 0b1941fa8..39f40ee16 100644 --- a/mmchat/models/algorithms/sft_lora.py +++ b/mmchat/models/algorithms/sft_lora.py @@ -1,6 +1,7 @@ from .sft import SupervisedFinetune + class SupervisedLoraFinetune(SupervisedFinetune): def __init__(self, llm, tokenizer, lora): - super().__init__(llm, tokenizer) \ No newline at end of file + super().__init__(llm, tokenizer) diff --git a/mmchat/models/algorithms/sft_lora_distill.py b/mmchat/models/algorithms/sft_lora_distill.py index 25b688542..69f91dcf6 100644 --- a/mmchat/models/algorithms/sft_lora_distill.py +++ b/mmchat/models/algorithms/sft_lora_distill.py @@ -1,6 +1,7 @@ from .sft_distill import DistillFinetune + class LoraDistillFinetune(DistillFinetune): def __init__(self, llm, tokenizer): - super().__init__(llm, tokenizer) \ No newline at end of file + super().__init__(llm, tokenizer) diff --git a/mmchat/models/algorithms/sft_qlora.py b/mmchat/models/algorithms/sft_qlora.py index f814679fa..a58891be7 100644 --- a/mmchat/models/algorithms/sft_qlora.py +++ b/mmchat/models/algorithms/sft_qlora.py @@ -1,15 +1,11 @@ +import bitsandbytes as bnb import torch -from .sft import SupervisedFinetune, traverse_dict -from peft import ( - prepare_model_for_kbit_training, - LoraConfig, - get_peft_model, - PeftModel -) +from peft import get_peft_model, prepare_model_for_kbit_training from peft.tuners.lora import LoraLayer -import bitsandbytes as bnb + from mmchat.registry import MODELS -from mmengine import print_log +from .sft import SupervisedFinetune + def find_all_linear_names(model): cls = bnb.nn.Linear4bit @@ -19,11 +15,11 @@ def find_all_linear_names(model): names = name.split('.') 
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if 'lm_head' in lora_module_names: # needed for 16-bit + if 'lm_head' in lora_module_names: # needed for 16-bit lora_module_names.remove('lm_head') return list(lora_module_names) + class SupervisedQloraFinetune(SupervisedFinetune): def __init__(self, llm, data_preprocessor, lora): @@ -35,7 +31,7 @@ def __init__(self, llm, data_preprocessor, lora): lora = MODELS.build(lora) lora.target_modules = modules - + self.llm = get_peft_model(self.llm, lora) for name, module in self.llm.named_modules(): @@ -47,7 +43,7 @@ def __init__(self, llm, data_preprocessor, lora): if hasattr(module, 'weight'): if module.weight.dtype == torch.float32: module = module.to(torch.float16) - self._is_init=True - + self._is_init = True + def init_weights(self): pass diff --git a/mmchat/models/algorithms/sft_qlora_distill.py b/mmchat/models/algorithms/sft_qlora_distill.py index 192d76907..063c5745e 100644 --- a/mmchat/models/algorithms/sft_qlora_distill.py +++ b/mmchat/models/algorithms/sft_qlora_distill.py @@ -1,6 +1,7 @@ from .sft import SupervisedFinetune + class QloraDistillFinetune(SupervisedFinetune): def __init__(self, llm, tokenizer): - super().__init__(llm, tokenizer) \ No newline at end of file + super().__init__(llm, tokenizer) diff --git a/mmchat/models/utils/__init__.py b/mmchat/models/utils/__init__.py index cda41baf5..c8a0aa422 100644 --- a/mmchat/models/utils/__init__.py +++ b/mmchat/models/utils/__init__.py @@ -1 +1,3 @@ -from .data_processor import DataProcesorForCausalLM \ No newline at end of file +from .data_processor import DataProcesorForCausalLM + +__all__ = ['DataProcesorForCausalLM'] diff --git a/mmchat/models/utils/data_processor.py b/mmchat/models/utils/data_processor.py index 674e6e7cd..5187be23d 100644 --- a/mmchat/models/utils/data_processor.py +++ b/mmchat/models/utils/data_processor.py @@ -1,15 +1,20 @@ -from typing import Optional, Sequence, Dict +import copy +from typing import Dict, Sequence + import torch from mmengine.model import BaseDataPreprocessor -from mmchat.registry import TOKENIZER -import copy from torch.nn.utils.rnn import pad_sequence + +from mmchat.registry import TOKENIZER + IGNORE_INDEX = -100 -DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_PAD_TOKEN = '[PAD]' + + class DataProcesorForCausalLM(BaseDataPreprocessor): - def __init__(self, - tokenizer, + def __init__(self, + tokenizer, source_max_len, target_max_len, train_on_source, @@ -22,11 +27,19 @@ def __init__(self, self.target_max_len = target_max_len self.train_on_source = train_on_source self.predict_with_generate = predict_with_generate - - def forward(self,instances: Sequence[Dict], training=True) -> Dict[str, torch.Tensor]: + + def forward(self, + instances: Sequence[Dict], + training=True) -> Dict[str, torch.Tensor]: # Extract elements - sources = [f"{self.tokenizer.bos_token}{example}" for example in instances['input']] - targets = [f"{example}{self.tokenizer.eos_token}" for example in instances['output']] + sources = [ + f'{self.tokenizer.bos_token}{example}' + for example in instances['input'] + ] + targets = [ + f'{example}{self.tokenizer.eos_token}' + for example in instances['output'] + ] # Tokenize tokenized_sources_with_prompt = self.tokenizer( sources, @@ -44,31 +57,39 @@ def forward(self,instances: Sequence[Dict], training=True) -> Dict[str, torch.Te input_ids = [] labels = [] for tokenized_source, tokenized_target in zip( - tokenized_sources_with_prompt['input_ids'], - tokenized_targets['input_ids'] - ): + 
tokenized_sources_with_prompt['input_ids'], + tokenized_targets['input_ids']): if not self.predict_with_generate: - input_ids.append(torch.tensor(tokenized_source + tokenized_target)) + input_ids.append( + torch.tensor(tokenized_source + tokenized_target)) if not self.train_on_source: labels.append( - torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target)) - ) + torch.tensor([ + IGNORE_INDEX for _ in range(len(tokenized_source)) + ] + copy.deepcopy(tokenized_target))) else: - labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))) + labels.append( + torch.tensor( + copy.deepcopy(tokenized_source + + tokenized_target))) else: input_ids.append(torch.tensor(tokenized_source)) # import pdb;pdb.set_trace() - + # Apply padding - input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) - labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None + input_ids = pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) if not self.predict_with_generate else None data_dict = { 'input_ids': input_ids, - 'attention_mask':input_ids.ne(self.tokenizer.pad_token_id), + 'attention_mask': input_ids.ne(self.tokenizer.pad_token_id), } - + if labels is not None: data_dict['labels'] = labels return self.cast_data({'data': data_dict, 'data_samples': None}) - diff --git a/requirements.txt b/requirements.txt index 534a840f4..7b5b904b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ bitsandbytes==0.39.1 +datasets mmengine==0.8.1 peft@git+https://github.com/huggingface/peft.git -transformers==4.30.2 scipy SentencePiece -datasets - +transformers==4.30.2 diff --git a/tools/dist_train.sh b/tools/dist_train.sh index 1eb32aa9c..3fca7641d 100644 --- a/tools/dist_train.sh +++ b/tools/dist_train.sh @@ -16,4 +16,4 @@ python -m torch.distributed.launch \ --master_port=$PORT \ $(dirname "$0")/train.py \ $CONFIG \ - --launcher pytorch ${@:3} \ No newline at end of file + --launcher pytorch ${@:3} diff --git a/tools/test.py b/tools/test.py index 1a4807f3f..9b783dd43 100644 --- a/tools/test.py +++ b/tools/test.py @@ -2,10 +2,7 @@ import argparse import os import os.path as osp -import warnings -from copy import deepcopy -from mmengine import ConfigDict from mmengine.config import Config, DictAction from mmengine.runner import Runner @@ -14,8 +11,7 @@ # TODO: support fuse_conv_bn and format_only def parse_args(): - parser = argparse.ArgumentParser( - description='MMChat test a model') + parser = argparse.ArgumentParser(description='MMChat test a model') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument( diff --git a/tools/train.py b/tools/train.py index 37a75d6fe..e7ffa8798 100644 --- a/tools/train.py +++ b/tools/train.py @@ -49,14 +49,12 @@ def parse_args(): return args - - def main(): args = parse_args() # load config cfg = Config.fromfile(args.config) - + cfg.launcher = args.launcher if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) @@ -73,7 +71,7 @@ def main(): # enable automatic-mixed-precision training if args.amp is True: optim_wrapper = cfg.optim_wrapper.type - from mmengine.optim import OptimWrapper, AmpOptimWrapper + from mmengine.optim import AmpOptimWrapper, OptimWrapper if optim_wrapper == AmpOptimWrapper: 
print_log( 'AMP training is already enabled in your config.', From 226715377e7a7b131ac4abcd7d10416809ae3a08 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Mon, 17 Jul 2023 10:49:08 +0800 Subject: [PATCH 2/5] add qlora --- configs/_base_/datasets/oasst1.py | 2 +- configs/_base_/schedules/guanaco.py | 8 +++++--- configs/guanaco/gunaco_llama_7B.py | 8 ++++++-- mmchat/models/algorithms/sft_qlora.py | 14 +++++++------- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/configs/_base_/datasets/oasst1.py b/configs/_base_/datasets/oasst1.py index 7368c2e61..0926572bb 100644 --- a/configs/_base_/datasets/oasst1.py +++ b/configs/_base_/datasets/oasst1.py @@ -42,7 +42,7 @@ ) train_dataloader = dict( - batch_size=16, + batch_size=1, num_workers=2, dataset=oasst1, sampler=dict(type=DefaultSampler, shuffle=True)) diff --git a/configs/_base_/schedules/guanaco.py b/configs/_base_/schedules/guanaco.py index 89909190f..5251391c9 100644 --- a/configs/_base_/schedules/guanaco.py +++ b/configs/_base_/schedules/guanaco.py @@ -1,12 +1,14 @@ from bitsandbytes.optim import PagedAdamW32bit -from mmengine.optim import ConstantLR, LinearLR, OptimWrapper +from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR # optimizer optim_wrapper = dict( - type=OptimWrapper, + type=AmpOptimWrapper, optimizer=dict(type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0), clip_grad=dict(max_norm=0.3, error_if_nonfinite=True), -) + accumulative_counts=16, + loss_scale='dynamic', + dtype='float16') # learning policy param_scheduler = [ diff --git a/configs/guanaco/gunaco_llama_7B.py b/configs/guanaco/gunaco_llama_7B.py index e21d1af94..8a0b34f95 100644 --- a/configs/guanaco/gunaco_llama_7B.py +++ b/configs/guanaco/gunaco_llama_7B.py @@ -21,19 +21,23 @@ type=AutoTokenizer.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, use_fast=False, - padding_side='right'), + padding_side='right', + ), source_max_len=2048, target_max_len=512, train_on_source=False, - predict_with_generate=False), + predict_with_generate=False, + ), llm=dict( type=AutoModelForCausalLM.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, torch_dtype=torch.float16, + device_map='auto', quantization_config=dict( type=BitsAndBytesConfig, load_in_4bit=True, load_in_8bit=False, + llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, diff --git a/mmchat/models/algorithms/sft_qlora.py b/mmchat/models/algorithms/sft_qlora.py index a58891be7..207d19621 100644 --- a/mmchat/models/algorithms/sft_qlora.py +++ b/mmchat/models/algorithms/sft_qlora.py @@ -1,7 +1,6 @@ import bitsandbytes as bnb import torch from peft import get_peft_model, prepare_model_for_kbit_training -from peft.tuners.lora import LoraLayer from mmchat.registry import MODELS from .sft import SupervisedFinetune @@ -35,14 +34,15 @@ def __init__(self, llm, data_preprocessor, lora): self.llm = get_peft_model(self.llm, lora) for name, module in self.llm.named_modules(): - if isinstance(module, LoraLayer): - module = module.to(torch.float16) + # todo + # if isinstance(module, LoraLayer): + # module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) - if 'lm_head' in name or 'embed_tokens' in name: - if hasattr(module, 'weight'): - if module.weight.dtype == torch.float32: - module = module.to(torch.float16) + # if 'lm_head' in name or 'embed_tokens' in name: + # if hasattr(module, 'weight'): + # if module.weight.dtype == 
torch.float32: + # module = module.to(torch.float16) self._is_init = True def init_weights(self): From 27a041a01243f25fa10efa7e9b3566d04fdd1cae Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Mon, 17 Jul 2023 14:01:13 +0800 Subject: [PATCH 3/5] support save lora-related state_dict --- mmchat/models/algorithms/sft_qlora.py | 85 ++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/mmchat/models/algorithms/sft_qlora.py b/mmchat/models/algorithms/sft_qlora.py index 207d19621..a93101109 100644 --- a/mmchat/models/algorithms/sft_qlora.py +++ b/mmchat/models/algorithms/sft_qlora.py @@ -1,6 +1,7 @@ import bitsandbytes as bnb import torch -from peft import get_peft_model, prepare_model_for_kbit_training +from peft import (PeftType, PromptLearningConfig, get_peft_model, + prepare_model_for_kbit_training) from mmchat.registry import MODELS from .sft import SupervisedFinetune @@ -47,3 +48,85 @@ def __init__(self, llm, data_preprocessor, lora): def init_weights(self): pass + + def state_dict(self, destination=None, prefix='', keep_vars=False): + + def get_peft_model_state_dict(model, + state_dict=None, + adapter_name='default'): + # Modified from `https://github.com/huggingface/peft/blob/main/src + # /peft/utils/save_and_load.py` + + config = model.peft_config[adapter_name] + if state_dict is None: + state_dict = model.state_dict() + if config.peft_type in (PeftType.LORA, PeftType.ADALORA): + # to_return = lora_state_dict(model, + # bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/ + # loralib/utils.py` + # to be used directly with the state dict which is necessary + # when using DeepSpeed or FSDP + bias = config.bias + if bias == 'none': + to_return = { + k: state_dict[k] + for k in state_dict if 'lora_' in k + } + elif bias == 'all': + to_return = { + k: state_dict[k] + for k in state_dict if 'lora_' in k or 'bias' in k + } + elif bias == 'lora_only': + to_return = {} + for k in state_dict: + if 'lora_' in k: + to_return[k] = state_dict[k] + bias_name = k.split('lora_')[0] + 'bias' + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] + else: + raise NotImplementedError + to_return = { + k: v + for k, v in to_return.items() + if (('lora_' in k and adapter_name in k) or ('bias' in k)) + } + if config.peft_type == PeftType.ADALORA: + rank_pattern = config.rank_pattern + if rank_pattern is not None: + rank_pattern = { + k.replace(f'.{adapter_name}', ''): v + for k, v in rank_pattern.items() + } + config.rank_pattern = rank_pattern + + elif config.peft_type == PeftType.ADAPTION_PROMPT: + to_return = { + k: state_dict[k] + for k in state_dict + if k.split('.')[-1].startswith('adaption_') + } + elif isinstance(config, PromptLearningConfig): + to_return = {} + if config.inference_mode: + prompt_embeddings = model.prompt_encoder[ + adapter_name].embedding.weight + else: + prompt_embeddings = model.get_prompt_embedding_to_save( + adapter_name) + to_return['prompt_embeddings'] = prompt_embeddings + else: + raise NotImplementedError + if model.modules_to_save is not None: + for key, value in state_dict.items(): + if any(f'{module_name}.modules_to_save.{adapter_name}' in + key for module_name in model.modules_to_save): + to_return[key.replace('modules_to_save.', '')] = value + + return to_return + + state_dict = super().state_dict() + to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict) + return to_return From 19c4c67ad9d911dea15a63f19b2d977bcad194c9 Mon Sep 17 00:00:00 2001 
From: HIT-cwh <2892770585@qq.com> Date: Mon, 17 Jul 2023 16:21:09 +0800 Subject: [PATCH 4/5] replace dict with OrderedDict --- mmchat/models/algorithms/sft_qlora.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmchat/models/algorithms/sft_qlora.py b/mmchat/models/algorithms/sft_qlora.py index a93101109..29515dbdf 100644 --- a/mmchat/models/algorithms/sft_qlora.py +++ b/mmchat/models/algorithms/sft_qlora.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import bitsandbytes as bnb import torch from peft import (PeftType, PromptLearningConfig, get_peft_model, @@ -129,4 +131,4 @@ def get_peft_model_state_dict(model, state_dict = super().state_dict() to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict) - return to_return + return OrderedDict(to_return) From 176ff07aafaaa11491891c7cb95eecc9662c9d24 Mon Sep 17 00:00:00 2001 From: LZHgrla Date: Tue, 18 Jul 2023 12:04:20 +0800 Subject: [PATCH 5/5] add adapter_pth2hf --- tools/model_converters/adapter_pth2hf.py | 61 ++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tools/model_converters/adapter_pth2hf.py diff --git a/tools/model_converters/adapter_pth2hf.py b/tools/model_converters/adapter_pth2hf.py new file mode 100644 index 000000000..f21482b73 --- /dev/null +++ b/tools/model_converters/adapter_pth2hf.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os + +import torch +from mmengine.config import Config, DictAction +from mmengine.utils import mkdir_or_exist + +from mmchat.registry import MODELS + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMChat test a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('adapter_checkpoint', help='adapter checkpoint file') + parser.add_argument( + 'save_dir', help='the directory to save the checkpoint') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # load on cpu + if cfg.model.llm.get('device_map'): + cfg.model.llm.device_map = 'cpu' + if cfg.model.llm.get('quantization_config'): + cfg.model.llm.quantization_config.\ + llm_int8_enable_fp32_cpu_offload = True + + model = MODELS.build(cfg.model) + + adapter_checkpoint = torch.load( + args.adapter_checkpoint, map_location='cpu') + model.load_state_dict(adapter_checkpoint['state_dict'], strict=False) + print(f'Load adapter from {args.adapter_checkpoint}') + + adapter_path = os.path.join(args.save_dir, 'adapter') + mkdir_or_exist(adapter_path) + model.llm.save_pretrained(adapter_path) + print(f'Save to {adapter_path}') + + +if __name__ == '__main__': + main()
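
A usage sketch for the converter added in [PATCH 5/5] (the checkpoint and output paths below are placeholders, not taken from the patch): the script takes a training config, the adapter .pth checkpoint saved during training (which, with the state_dict() override from [PATCH 3/5], holds only the LoRA-related weights), and a save directory, then writes a PEFT-style adapter folder via model.llm.save_pretrained().

    python tools/model_converters/adapter_pth2hf.py \
        configs/guanaco/gunaco_llama_7B.py \
        work_dirs/gunaco_llama_7B/epoch_3.pth \
        work_dirs/gunaco_llama_7B/hf_adapter

Config values can additionally be overridden on the command line with --cfg-options key=value pairs, following the parse_args() definition above.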