37 changes: 19 additions & 18 deletions configs/_base_/datasets/aplaca.py
@@ -1,29 +1,30 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset

 _alpaca = dict(
-    type = process_hf_dataset,
-    dataset = dict(
-        type = load_dataset,
-        path = 'tatsu-lab/alpaca',
+    type=process_hf_dataset,
+    dataset=dict(
+        type=load_dataset,
+        path='tatsu-lab/alpaca',
     ),
     # map_fn = extract_alpaca_dataset,
-    prompt_input_format = (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
-    ),
-    prompt_no_input_format= (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Response: "
-    ),
+    prompt_input_format=(
+        'Below is an instruction that describes a task, '
+        'paired with an input that provides further context. '
+        'Write a response that appropriately completes the request.\n\n'
+        '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n'
+        '### Response: '),
+    prompt_no_input_format=(
+        'Below is an instruction that describes a task. '
+        'Write a response that appropriately completes the request.\n\n'
+        '### Instruction:\n{instruction}\n\n### Response: '),
     remove_columns=['instruction'],
 )

 train_dataloader = dict(
     batch_size=1,
     num_workers=2,
-    dataset = _alpaca,
-    sampler=dict(type=DefaultSampler, shuffle=True)
-)
+    dataset=_alpaca,
+    sampler=dict(type=DefaultSampler, shuffle=True))
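Reviewer note: the two strings above are the standard Stanford Alpaca prompt templates, and process_hf_dataset presumably selects between them based on whether a record's 'input' field is empty; that selection rule is an assumption about the downstream code, not something visible in this PR. A minimal, self-contained sketch of how one record would be rendered:

# Sketch: render one tatsu-lab/alpaca record with the templates above.
# Assumption: the empty-'input' check mirrors what process_hf_dataset does.
prompt_input_format = (
    'Below is an instruction that describes a task, '
    'paired with an input that provides further context. '
    'Write a response that appropriately completes the request.\n\n'
    '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n'
    '### Response: ')
prompt_no_input_format = (
    'Below is an instruction that describes a task. '
    'Write a response that appropriately completes the request.\n\n'
    '### Instruction:\n{instruction}\n\n### Response: ')

record = {'instruction': 'Add the numbers.', 'input': '2 and 3', 'output': '5'}
template = prompt_input_format if record.get('input') else prompt_no_input_format
source = template.format(**record)  # prompt fed to the model
target = record['output']           # completion the loss is computed on
print(source + target)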
10 changes: 3 additions & 7 deletions configs/_base_/datasets/mmlu_fs.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset

 data_root = 'data/mmlu/'

@@ -13,19 +13,15 @@
         test=data_root + 'five_shot_mmlu_test.json'))

 val_mmlu_fs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_fs_dataset,
-    mode='val')
+    type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='val')
 val_dataloader = dict(
     batch_size=1,
     num_workers=1,
     dataset=val_mmlu_fs,
     sampler=dict(type=DefaultSampler, shuffle=False))

 test_mmlu_fs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_fs_dataset,
-    mode='test')
+    type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='test')
 test_dataloader = dict(
     batch_size=1,
     num_workers=1,
10 changes: 3 additions & 7 deletions configs/_base_/datasets/mmlu_zs.py
@@ -1,7 +1,7 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset

 data_root = 'data/mmlu/'

@@ -13,19 +13,15 @@
         test=data_root + 'zero_shot_mmlu_test.json'))

 val_mmlu_zs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_zs_dataset,
-    mode='val')
+    type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='val')
 val_dataloader = dict(
     batch_size=1,
     num_workers=1,
     dataset=val_mmlu_zs,
     sampler=dict(type=DefaultSampler, shuffle=False))

 test_mmlu_zs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_zs_dataset,
-    mode='test')
+    type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='test')
 test_dataloader = dict(
     batch_size=1,
     num_workers=1,
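Reviewer note: the five-shot and zero-shot MMLU configs are identical in shape and differ only in the JSON file names, so one trace covers both. mmlu_fs_dataset should resolve to roughly the following when mmengine instantiates it; 'json' and the val file name are inferred by symmetry, since only the test file name is visible in the truncated hunks:

from datasets import load_dataset

data_root = 'data/mmlu/'
# Assumed: the collapsed lines pass path='json' plus these data_files;
# 'five_shot_mmlu_val.json' is a guess mirroring the visible test file.
mmlu_fs_dataset = load_dataset(
    'json',
    data_files=dict(
        val=data_root + 'five_shot_mmlu_val.json',
        test=data_root + 'five_shot_mmlu_test.json'))
print(mmlu_fs_dataset['test'][0])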
25 changes: 12 additions & 13 deletions configs/_base_/datasets/oasst1.py
@@ -1,7 +1,3 @@
-from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
-from mmengine.dataset import DefaultSampler
-
 """
 ------------ Dataset Meta Info (after `load_dataset`) ------------

@@ -31,19 +27,22 @@

 """

+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset
+
 oasst1 = dict(
-    type = process_hf_dataset,
-    dataset = dict(
-        type = load_dataset,
-        path = 'timdettmers/openassistant-guanaco',
+    type=process_hf_dataset,
+    dataset=dict(
+        type=load_dataset,
+        path='timdettmers/openassistant-guanaco',
     ),
-    map_fn = "lambda x: {'input': '', 'output': x['text']}",
+    map_fn="lambda x: {'input': '', 'output': x['text']}",
 )

 train_dataloader = dict(
-    batch_size=16,
+    batch_size=1,
     num_workers=2,
-    dataset = oasst1,
-    sampler=dict(type=DefaultSampler, shuffle=True)
-)
+    dataset=oasst1,
+    sampler=dict(type=DefaultSampler, shuffle=True))
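Reviewer note: map_fn is a string here, so process_hf_dataset presumably eval()s it before mapping it over the dataset; worth confirming, since eval on config strings is easy to get wrong. A sketch of the resulting transform on one record (openassistant-guanaco stores a whole conversation in a single 'text' field):

# Assumption: string-valued map_fn entries are eval()'d downstream.
map_fn = eval("lambda x: {'input': '', 'output': x['text']}")

record = {'text': '### Human: Hello!### Assistant: Hi, how can I help?'}
print(map_fn(record))
# {'input': '', 'output': '### Human: Hello!### Assistant: Hi, how can I help?'}

The batch_size drop from 16 to 1 reads as intentional: it lines up with the new accumulative_counts=16 in configs/_base_/schedules/guanaco.py, keeping the effective batch at 16.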
6 changes: 3 additions & 3 deletions configs/_base_/default_runtime.py
@@ -1,5 +1,5 @@
-from mmengine.hooks import (IterTimerHook, LoggerHook, ParamSchedulerHook,
-                            CheckpointHook, DistSamplerSeedHook)
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)

 # defaults to use registries in mmpretrain
 default_scope = 'mmchat'
@@ -47,4 +47,4 @@
 resume = False

 # Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
\ No newline at end of file
+randomness = dict(seed=None, deterministic=False)
23 changes: 9 additions & 14 deletions configs/_base_/schedules/guanaco.py
@@ -1,16 +1,14 @@
-from mmengine.optim import OptimWrapper
-from mmengine.optim import LinearLR, ConstantLR
-from mmengine.runner import IterBasedTrainLoop
-from torch.optim import AdamW
 from bitsandbytes.optim import PagedAdamW32bit
+from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR

 # optimizer
 optim_wrapper = dict(
-    type=OptimWrapper,
-    optimizer=dict(
-        type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0),
+    type=AmpOptimWrapper,
+    optimizer=dict(type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0),
     clip_grad=dict(max_norm=0.3, error_if_nonfinite=True),
-)
+    accumulative_counts=16,
+    loss_scale='dynamic',
+    dtype='float16')

 # learning policy
 param_scheduler = [
@@ -32,11 +30,8 @@
 ]

 # train, val, test setting
-train_cfg = dict(
-    by_epoch=True,
-    max_epochs = 3, val_interval=1)
-
+train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1)

 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.
-auto_scale_lr = dict(base_batch_size=1)
\ No newline at end of file
+auto_scale_lr = dict(base_batch_size=1)
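Reviewer note: with the dataloaders at batch_size=1, accumulative_counts=16 keeps the effective optimisation batch at 16 per GPU. For readers unfamiliar with AmpOptimWrapper, a rough plain-PyTorch equivalent of what these settings do (standard AdamW stands in for PagedAdamW32bit so the sketch runs without bitsandbytes; needs a CUDA device):

import torch
from torch.cuda.amp import GradScaler, autocast

accumulative_counts = 16                 # mirrors the config
model = torch.nn.Linear(8, 2).cuda()     # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.0)
scaler = GradScaler()                    # loss_scale='dynamic'

for step, batch in enumerate(torch.randn(64, 8).split(1)):
    with autocast(dtype=torch.float16):  # dtype='float16'
        loss = model(batch.cuda()).mean()
    scaler.scale(loss / accumulative_counts).backward()
    if (step + 1) % accumulative_counts == 0:
        scaler.unscale_(optimizer)       # clip true, unscaled gradients
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=0.3, error_if_nonfinite=True)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()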
9 changes: 4 additions & 5 deletions configs/_base_/schedules/guanaco_deepspeed.py
@@ -1,12 +1,11 @@
-from mmengine.optim import DeepSpeedOptimWrapper
 from mmengine._strategy import DeepSpeedStrategy
+from mmengine.optim import DeepSpeedOptimWrapper
 from torch.optim import AdamW

 # optimizer
 optim_wrapper = dict(
     type=DeepSpeedOptimWrapper,
-    optimizer=dict(
-        type=AdamW, lr=0.0002, weight_decay=0.0))
-
+    optimizer=dict(type=AdamW, lr=0.0002, weight_decay=0.0))

 # training strategy
 strategy = dict(
@@ -65,4 +64,4 @@

 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.
-auto_scale_lr = dict(base_batch_size=64)
\ No newline at end of file
+auto_scale_lr = dict(base_batch_size=64)
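Reviewer note: DeepSpeedOptimWrapper only records the optimizer settings; stepping, ZeRO partitioning, and mixed precision come from the strategy dict, which mmengine consumes through a DeepSpeed-capable runner rather than the default Runner. The interesting settings sit in the collapsed part of this hunk; the overall shape is roughly the sketch below, with illustrative values where the diff is truncated:

from mmengine._strategy import DeepSpeedStrategy

# Illustrative only: the real fp16/ZeRO settings live in the collapsed
# region of this file and are not visible in the diff.
strategy = dict(
    type=DeepSpeedStrategy,
    fp16=dict(enabled=True, initial_scale_power=16),
    zero_optimization=dict(stage=2))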
37 changes: 17 additions & 20 deletions configs/alpaca/alpaca_standford.py
@@ -1,33 +1,30 @@
 from mmengine.config import read_base
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from mmchat.models import SupervisedFinetune, DataProcesorForCausalLM
-from transformers import BitsAndBytesConfig
-from peft import LoraConfig
-from dataclasses import dataclass
-import torch
+
+from mmchat.models import DataProcesorForCausalLM, SupervisedFinetune

 with read_base():
-    from .._base_.datasets.aplaca import *
-    from .._base_.schedules.guanaco import *
-    from .._base_.default_runtime import *
+    from .._base_.datasets.aplaca import *  # noqa: F401,F403
+    from .._base_.default_runtime import *  # noqa: F401,F403
+    from .._base_.schedules.guanaco import *  # noqa: F401,F403

+pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
 model = dict(
-    type = SupervisedFinetune,
-    data_preprocessor = dict(
+    type=SupervisedFinetune,
+    data_preprocessor=dict(
         type=DataProcesorForCausalLM,
         tokenizer=dict(
             type=AutoTokenizer.from_pretrained,
-            pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
-            use_fast = False,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            use_fast=False,
         ),
-        source_max_len = 512,
-        target_max_len = 512,
-        train_on_source = False,
-        predict_with_generate = False,
+        source_max_len=512,
+        target_max_len=512,
+        train_on_source=False,
+        predict_with_generate=False,
     ),
-    llm = dict(
+    llm=dict(
         type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
     ),
-
 )
-
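Reviewer note: hoisting the checkpoint path into pretrained_model_name_or_path is a welcome dedup. For readers new to mmengine lazy configs: `type` holds a callable and the sibling keys become its kwargs at build time, so the tokenizer/llm entries reduce to plain transformers calls, roughly as below (a public model id stands in for the private /share/... checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

# 'huggyllama/llama-7b' is a stand-in; the config points at a private
# merged Chinese-LoRA checkpoint that is not publicly available.
pretrained_model_name_or_path = 'huggyllama/llama-7b'

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, use_fast=False)
llm = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)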
58 changes: 27 additions & 31 deletions configs/alpaca/alpaca_standford_qlora.py
@@ -1,50 +1,46 @@
+import torch
 from mmengine.config import read_base
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from mmchat.models import SupervisedQloraFinetune, DataProcesorForCausalLM
-from transformers import BitsAndBytesConfig
 from peft import LoraConfig
-from dataclasses import dataclass
-import torch
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
+
+from mmchat.models import DataProcesorForCausalLM, SupervisedQloraFinetune

 with read_base():
-    from .._base_.datasets.aplaca import *
-    from .._base_.schedules.guanaco import *
-    from .._base_.default_runtime import *
+    from .._base_.datasets.aplaca import *  # noqa: F401,F403
+    from .._base_.default_runtime import *  # noqa: F401,F403
+    from .._base_.schedules.guanaco import *  # noqa: F401,F403

+pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
 model = dict(
-    type = SupervisedQloraFinetune,
-    data_preprocessor = dict(
+    type=SupervisedQloraFinetune,
+    data_preprocessor=dict(
         type=DataProcesorForCausalLM,
         tokenizer=dict(
             type=AutoTokenizer.from_pretrained,
-            pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
-            use_fast = False,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            use_fast=False,
         ),
-        source_max_len = 512,
-        target_max_len = 512,
-        train_on_source = False,
-        predict_with_generate = False,
+        source_max_len=512,
+        target_max_len=512,
+        train_on_source=False,
+        predict_with_generate=False,
     ),
-    llm = dict(
+    llm=dict(
         type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
         quantization_config=dict(
-            type = BitsAndBytesConfig,
+            type=BitsAndBytesConfig,
             load_in_4bit=True,
             load_in_8bit=False,
             llm_int8_has_fp16_weight=False,
             bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type = 'nf4'
-        )
-    ),
+            bnb_4bit_quant_type='nf4')),
     lora=dict(
         type=LoraConfig,
-        r = 64,
-        lora_alpha = 16,
-        lora_dropout = 0.1,
-        bias = 'none',
-        task_type = 'CAUSAL_LM'
-    )
-
-)
-
+        r=64,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        bias='none',
+        task_type='CAUSAL_LM'))
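Reviewer note: taken together, quantization_config and lora are the standard QLoRA recipe: a 4-bit NF4 double-quantized base with fp16 compute and LoRA adapters on top. A sketch of the equivalent direct calls, presumably close to what SupervisedQloraFinetune does internally with these two dicts (public model id standing in for the private path):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = 'huggyllama/llama-7b'  # stand-in for the private /share/... path

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

llm = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config)

lora = LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0.1, bias='none',
    task_type='CAUSAL_LM')
llm = get_peft_model(llm, lora)  # wrap the 4-bit base with trainable adapters
llm.print_trainable_parameters()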