2 changes: 1 addition & 1 deletion configs/_base_/datasets/oasst1.py
@@ -42,7 +42,7 @@
 )
 
 train_dataloader = dict(
-    batch_size=16,
+    batch_size=1,
     num_workers=2,
     dataset=oasst1,
     sampler=dict(type=DefaultSampler, shuffle=True))
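
Note: the per-iteration batch size drops from 16 to 1, presumably so QLoRA fine-tuning fits in GPU memory; the effective batch size is unchanged because the schedule diff below adds 16-step gradient accumulation. A quick sanity check in Python (assuming single-GPU training):

batch_size = 1            # per-iteration batch size (this diff)
accumulative_counts = 16  # gradient accumulation (schedule diff below)
effective_batch_size = batch_size * accumulative_counts
assert effective_batch_size == 16  # matches the old batch_size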
8 changes: 5 additions & 3 deletions configs/_base_/schedules/guanaco.py
@@ -1,12 +1,14 @@
 from bitsandbytes.optim import PagedAdamW32bit
-from mmengine.optim import ConstantLR, LinearLR, OptimWrapper
+from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR
 
 # optimizer
 optim_wrapper = dict(
-    type=OptimWrapper,
+    type=AmpOptimWrapper,
     optimizer=dict(type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0),
     clip_grad=dict(max_norm=0.3, error_if_nonfinite=True),
-)
+    accumulative_counts=16,
+    loss_scale='dynamic',
+    dtype='float16')
 
 # learning policy
 param_scheduler = [
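
Note: swapping OptimWrapper for AmpOptimWrapper turns on fp16 mixed-precision training with dynamic loss scaling, on top of the 16-step gradient accumulation. A minimal sketch of what this config resolves to at runtime, assuming a CUDA device, a recent mmengine release (older ones lack the dtype argument), and a stand-in module:

import torch.nn as nn
from bitsandbytes.optim import PagedAdamW32bit
from mmengine.optim import AmpOptimWrapper

model = nn.Linear(8, 8)  # stand-in module for illustration only
optim_wrapper = AmpOptimWrapper(
    optimizer=PagedAdamW32bit(model.parameters(), lr=2e-4, weight_decay=0.0),
    clip_grad=dict(max_norm=0.3, error_if_nonfinite=True),
    accumulative_counts=16,  # step the optimizer once every 16 batches
    loss_scale='dynamic',    # dynamic loss scaling for fp16 stability
    dtype='float16')         # autocast compute dtype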
2 changes: 2 additions & 0 deletions configs/guanaco/gunaco_llama_7B.py
@@ -30,10 +30,12 @@
     type=AutoModelForCausalLM.from_pretrained,
     pretrained_model_name_or_path=pretrained_model_name_or_path,
     torch_dtype=torch.float16,
+    device_map='auto',
     quantization_config=dict(
         type=BitsAndBytesConfig,
         load_in_4bit=True,
+        load_in_8bit=False,
         llm_int8_threshold=6.0,
         llm_int8_has_fp16_weight=False,
         bnb_4bit_compute_dtype=torch.float16,
         bnb_4bit_use_double_quant=True,
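
Note: device_map='auto' lets Hugging Face Accelerate place the quantized weights across available devices, and load_in_8bit=False makes the 4-bit path explicit. The config above roughly corresponds to this direct Transformers call (the checkpoint path is a placeholder; the real value comes from pretrained_model_name_or_path):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

llm = AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b',  # placeholder LLaMA-7B checkpoint
    torch_dtype=torch.float16,
    device_map='auto',
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        load_in_8bit=False,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True))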
101 changes: 93 additions & 8 deletions mmchat/models/algorithms/sft_qlora.py
@@ -1,7 +1,9 @@
+from collections import OrderedDict
+
 import bitsandbytes as bnb
 import torch
-from peft import get_peft_model, prepare_model_for_kbit_training
-from peft.tuners.lora import LoraLayer
+from peft import (PeftType, PromptLearningConfig, get_peft_model,
+                  prepare_model_for_kbit_training)
 
 from mmchat.registry import MODELS
 from .sft import SupervisedFinetune
@@ -35,15 +37,98 @@ def __init__(self, llm, data_preprocessor, lora):
         self.llm = get_peft_model(self.llm, lora)
 
         for name, module in self.llm.named_modules():
-            if isinstance(module, LoraLayer):
-                module = module.to(torch.float16)
+            # todo
+            # if isinstance(module, LoraLayer):
+            #     module = module.to(torch.bfloat16)
             if 'norm' in name:
                 module = module.to(torch.float32)
-            if 'lm_head' in name or 'embed_tokens' in name:
-                if hasattr(module, 'weight'):
-                    if module.weight.dtype == torch.float32:
-                        module = module.to(torch.float16)
+            # if 'lm_head' in name or 'embed_tokens' in name:
+            #     if hasattr(module, 'weight'):
+            #         if module.weight.dtype == torch.float32:
+            #             module = module.to(torch.float16)
         self._is_init = True
 
     def init_weights(self):
         pass
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+
+        def get_peft_model_state_dict(model,
+                                      state_dict=None,
+                                      adapter_name='default'):
+            # Modified from `https://github.com/huggingface/peft/blob/main/src
+            # /peft/utils/save_and_load.py`
+
+            config = model.peft_config[adapter_name]
+            if state_dict is None:
+                state_dict = model.state_dict()
+            if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
+                # to_return = lora_state_dict(model,
+                #                             bias=model.peft_config.bias)
+                # adapted from `https://github.com/microsoft/LoRA/blob/main/
+                # loralib/utils.py`
+                # to be used directly with the state dict which is necessary
+                # when using DeepSpeed or FSDP
+                bias = config.bias
+                if bias == 'none':
+                    to_return = {
+                        k: state_dict[k]
+                        for k in state_dict if 'lora_' in k
+                    }
+                elif bias == 'all':
+                    to_return = {
+                        k: state_dict[k]
+                        for k in state_dict if 'lora_' in k or 'bias' in k
+                    }
+                elif bias == 'lora_only':
+                    to_return = {}
+                    for k in state_dict:
+                        if 'lora_' in k:
+                            to_return[k] = state_dict[k]
+                            bias_name = k.split('lora_')[0] + 'bias'
+                            if bias_name in state_dict:
+                                to_return[bias_name] = state_dict[bias_name]
+                else:
+                    raise NotImplementedError
+                to_return = {
+                    k: v
+                    for k, v in to_return.items()
+                    if (('lora_' in k and adapter_name in k) or ('bias' in k))
+                }
+                if config.peft_type == PeftType.ADALORA:
+                    rank_pattern = config.rank_pattern
+                    if rank_pattern is not None:
+                        rank_pattern = {
+                            k.replace(f'.{adapter_name}', ''): v
+                            for k, v in rank_pattern.items()
+                        }
+                        config.rank_pattern = rank_pattern
+
+            elif config.peft_type == PeftType.ADAPTION_PROMPT:
+                to_return = {
+                    k: state_dict[k]
+                    for k in state_dict
+                    if k.split('.')[-1].startswith('adaption_')
+                }
+            elif isinstance(config, PromptLearningConfig):
+                to_return = {}
+                if config.inference_mode:
+                    prompt_embeddings = model.prompt_encoder[
+                        adapter_name].embedding.weight
+                else:
+                    prompt_embeddings = model.get_prompt_embedding_to_save(
+                        adapter_name)
+                to_return['prompt_embeddings'] = prompt_embeddings
+            else:
+                raise NotImplementedError
+            if model.modules_to_save is not None:
+                for key, value in state_dict.items():
+                    if any(f'{module_name}.modules_to_save.{adapter_name}' in
+                           key for module_name in model.modules_to_save):
+                        to_return[key.replace('modules_to_save.', '')] = value
+
+            return to_return
+
+        state_dict = super().state_dict()
+        to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
+        return OrderedDict(to_return)
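
Note: norm layers are cast to float32 (the usual QLoRA recipe for numerical stability), while the fp16/bf16 casts for LoRA layers and the embedding/head are commented out pending the todo. The overridden state_dict() then filters the full model state down to the trainable PEFT parameters (LoRA weights, optional biases, prompt embeddings, modules_to_save), so checkpoints stay adapter-sized instead of holding the quantized base model. A sketch of the effect, assuming this repo's registry and the guanaco config (the assert holds for a plain LoRA setup with bias='none'):

import torch
from mmengine.config import Config

from mmchat.registry import MODELS

cfg = Config.fromfile('configs/guanaco/gunaco_llama_7B.py')
model = MODELS.build(cfg.model)
adapter_sd = model.state_dict()  # filtered: adapter weights only
assert all('lora_' in k for k in adapter_sd)
torch.save({'state_dict': adapter_sd}, 'adapter.pth')  # MBs, not GBs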
61 changes: 61 additions & 0 deletions tools/model_converters/adapter_pth2hf.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+
+import torch
+from mmengine.config import Config, DictAction
+from mmengine.utils import mkdir_or_exist
+
+from mmchat.registry import MODELS
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert adapter to HF')
+    parser.add_argument('config', help='config file path')
+    parser.add_argument('adapter_checkpoint', help='adapter checkpoint file')
+    parser.add_argument(
+        'save_dir', help='the directory to save the checkpoint')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # load config
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # load on cpu
+    if cfg.model.llm.get('device_map'):
+        cfg.model.llm.device_map = 'cpu'
+    if cfg.model.llm.get('quantization_config'):
+        cfg.model.llm.quantization_config.\
+            llm_int8_enable_fp32_cpu_offload = True
+
+    model = MODELS.build(cfg.model)
+
+    adapter_checkpoint = torch.load(
+        args.adapter_checkpoint, map_location='cpu')
+    model.load_state_dict(adapter_checkpoint['state_dict'], strict=False)
+    print(f'Load adapter from {args.adapter_checkpoint}')
+
+    adapter_path = os.path.join(args.save_dir, 'adapter')
+    mkdir_or_exist(adapter_path)
+    model.llm.save_pretrained(adapter_path)
+    print(f'Save to {adapter_path}')
+
+
+if __name__ == '__main__':
+    main()
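
Note: a typical round trip with the new converter (all paths are hypothetical): convert the mmengine .pth checkpoint into a PEFT adapter directory, then attach it to the base model with stock Hugging Face tooling:

# python tools/model_converters/adapter_pth2hf.py \
#     configs/guanaco/gunaco_llama_7B.py work_dir/epoch_3.pth work_dir
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b')
model = PeftModel.from_pretrained(base, 'work_dir/adapter')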