Commit 4dfdcce: fix pre-commit (#2)
1 parent e840b6c
30 files changed: +396, -532 lines

configs/_base_/datasets/aplaca.py

Lines changed: 19 additions & 18 deletions
@@ -1,29 +1,30 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset
+
 _alpaca = dict(
-    type = process_hf_dataset,
-    dataset = dict(
-        type = load_dataset,
-        path = 'tatsu-lab/alpaca',
+    type=process_hf_dataset,
+    dataset=dict(
+        type=load_dataset,
+        path='tatsu-lab/alpaca',
     ),
     # map_fn = extract_alpaca_dataset,
-    prompt_input_format = (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
-    ),
-    prompt_no_input_format= (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Response: "
-    ),
+    prompt_input_format=(
+        'Below is an instruction that describes a task, '
+        'paired with an input that provides further context. '
+        'Write a response that appropriately completes the request.\n\n'
+        '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n'
+        '### Response: '),
+    prompt_no_input_format=(
+        'Below is an instruction that describes a task. '
+        'Write a response that appropriately completes the request.\n\n'
+        '### Instruction:\n{instruction}\n\n### Response: '),
     remove_columns=['instruction'],
 )

 train_dataloader = dict(
     batch_size=1,
     num_workers=2,
-    dataset = _alpaca,
-    sampler=dict(type=DefaultSampler, shuffle=True)
-)
+    dataset=_alpaca,
+    sampler=dict(type=DefaultSampler, shuffle=True))
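
The two prompt templates are plain str.format templates over the Alpaca fields. As a minimal sketch of what they render on a raw record, using only the public datasets API (how process_hf_dataset applies them inside mmchat is not shown in this diff):

from datasets import load_dataset

# Pull the Stanford Alpaca data from the Hugging Face Hub.
ds = load_dataset('tatsu-lab/alpaca', split='train')

prompt_input_format = (
    'Below is an instruction that describes a task, '
    'paired with an input that provides further context. '
    'Write a response that appropriately completes the request.\n\n'
    '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n'
    '### Response: ')
prompt_no_input_format = (
    'Below is an instruction that describes a task. '
    'Write a response that appropriately completes the request.\n\n'
    '### Instruction:\n{instruction}\n\n### Response: ')

record = ds[0]  # fields: instruction, input, output, text
template = prompt_input_format if record['input'] else prompt_no_input_format
print(template.format(**record) + record['output'])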

configs/_base_/datasets/mmlu_fs.py

Lines changed: 3 additions & 7 deletions
@@ -1,7 +1,7 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler

+from mmchat.datasets import process_hf_dataset

 data_root = 'data/mmlu/'

@@ -13,19 +13,15 @@
         test=data_root + 'five_shot_mmlu_test.json'))

 val_mmlu_fs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_fs_dataset,
-    mode='val')
+    type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='val')
 val_dataloader = dict(
     batch_size=1,
     num_workers=1,
     dataset=val_mmlu_fs,
     sampler=dict(type=DefaultSampler, shuffle=False))

 test_mmlu_fs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_fs_dataset,
-    mode='test')
+    type=process_hf_dataset, dataset=mmlu_fs_dataset, mode='test')
 test_dataloader = dict(
     batch_size=1,
     num_workers=1,
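
For reference, the dataset dict above resolves to the Hugging Face 'json' loader with named splits; a hedged sketch of the equivalent direct call (the val file name is an assumption by symmetry with the test entry visible in the hunk, and the files must already sit under data/mmlu/):

from datasets import load_dataset

data_root = 'data/mmlu/'
# 'val' file name assumed by symmetry with the 'test' entry shown above.
mmlu_fs = load_dataset(
    'json',
    data_files=dict(
        val=data_root + 'five_shot_mmlu_val.json',
        test=data_root + 'five_shot_mmlu_test.json'))
print(mmlu_fs['val'][0])  # one few-shot MMLU record

The zero-shot variant in the next file differs only in the json file names and variable names.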

configs/_base_/datasets/mmlu_zs.py

Lines changed: 3 additions & 7 deletions
@@ -1,7 +1,7 @@
 from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
 from mmengine.dataset import DefaultSampler

+from mmchat.datasets import process_hf_dataset

 data_root = 'data/mmlu/'

@@ -13,19 +13,15 @@
         test=data_root + 'zero_shot_mmlu_test.json'))

 val_mmlu_zs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_zs_dataset,
-    mode='val')
+    type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='val')
 val_dataloader = dict(
     batch_size=1,
     num_workers=1,
     dataset=val_mmlu_zs,
     sampler=dict(type=DefaultSampler, shuffle=False))

 test_mmlu_zs = dict(
-    type=process_hf_dataset,
-    dataset=mmlu_zs_dataset,
-    mode='test')
+    type=process_hf_dataset, dataset=mmlu_zs_dataset, mode='test')
 test_dataloader = dict(
     batch_size=1,
     num_workers=1,

configs/_base_/datasets/oasst1.py

Lines changed: 11 additions & 12 deletions
@@ -1,7 +1,3 @@
-from datasets import load_dataset
-from mmchat.datasets import process_hf_dataset
-from mmengine.dataset import DefaultSampler
-
 """
 ------------ Dataset Meta Info (after `load_dataset`) ------------
@@ -31,19 +27,22 @@
 """

+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+
+from mmchat.datasets import process_hf_dataset

 oasst1 = dict(
-    type = process_hf_dataset,
-    dataset = dict(
-        type = load_dataset,
-        path = 'timdettmers/openassistant-guanaco',
+    type=process_hf_dataset,
+    dataset=dict(
+        type=load_dataset,
+        path='timdettmers/openassistant-guanaco',
     ),
-    map_fn = "lambda x: {'input': '', 'output': x['text']}",
+    map_fn="lambda x: {'input': '', 'output': x['text']}",
 )

 train_dataloader = dict(
     batch_size=16,
     num_workers=2,
-    dataset = oasst1,
-    sampler=dict(type=DefaultSampler, shuffle=True)
-)
+    dataset=oasst1,
+    sampler=dict(type=DefaultSampler, shuffle=True))
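
map_fn is stored as a string rather than a callable, presumably so the config dict stays serializable and the trainer can eval() it later. Its effect, expressed directly with the public datasets API:

from datasets import load_dataset

ds = load_dataset('timdettmers/openassistant-guanaco', split='train')

# Equivalent of the string-encoded map_fn above: copy each record's
# 'text' into 'output' and leave 'input' empty (plain LM-style data).
ds = ds.map(lambda x: {'input': '', 'output': x['text']})
print(ds[0]['output'][:80])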

configs/_base_/default_runtime.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
-from mmengine.hooks import (IterTimerHook, LoggerHook, ParamSchedulerHook,
-                            CheckpointHook, DistSamplerSeedHook)
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)

 # defaults to use registries in mmpretrain
 default_scope = 'mmchat'
@@ -47,4 +47,4 @@
 resume = False

 # Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
\ No newline at end of file
+randomness = dict(seed=None, deterministic=False)
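
This hunk only reorders the hook imports; the default_hooks dict that consumes them lies outside it. For orientation, the usual MMEngine wiring of these five hooks looks like the sketch below (the intervals are illustrative, not the file's actual values):

from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)

default_hooks = dict(
    timer=dict(type=IterTimerHook),  # measure per-iteration time
    logger=dict(type=LoggerHook, interval=10),  # log every 10 iters
    param_scheduler=dict(type=ParamSchedulerHook),  # step the LR schedulers
    checkpoint=dict(type=CheckpointHook, interval=1),  # save periodically
    sampler_seed=dict(type=DistSamplerSeedHook),  # reseed sampler each epoch
)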

configs/_base_/schedules/guanaco.py

Lines changed: 6 additions & 13 deletions
@@ -1,16 +1,12 @@
-from mmengine.optim import OptimWrapper
-from mmengine.optim import LinearLR, ConstantLR
-from mmengine.runner import IterBasedTrainLoop
-from torch.optim import AdamW
 from bitsandbytes.optim import PagedAdamW32bit
+from mmengine.optim import ConstantLR, LinearLR, OptimWrapper
+
 # optimizer
 optim_wrapper = dict(
     type=OptimWrapper,
-    optimizer=dict(
-        type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0),
+    optimizer=dict(type=PagedAdamW32bit, lr=0.0002, weight_decay=0.0),
     clip_grad=dict(max_norm=0.3, error_if_nonfinite=True),
-)
-
+)

 # learning policy
 param_scheduler = [
@@ -32,11 +28,8 @@
 ]

 # train, val, test setting
-train_cfg = dict(
-    by_epoch=True,
-    max_epochs = 3, val_interval=1)
-
+train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1)

 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.
-auto_scale_lr = dict(base_batch_size=1)
\ No newline at end of file
+auto_scale_lr = dict(base_batch_size=1)
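
The param_scheduler list itself sits between the two hunks. With the two schedulers the file now imports, a typical warmup-then-constant schedule for the 3-epoch run configured above would read (values illustrative, not the file's):

from mmengine.optim import ConstantLR, LinearLR

param_scheduler = [
    # linear warmup from lr * start_factor over the first epoch
    dict(type=LinearLR, start_factor=0.001, by_epoch=True, begin=0, end=1),
    # then hold the base LR for the remaining two epochs
    dict(type=ConstantLR, factor=1.0, by_epoch=True, begin=1, end=3),
]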

configs/_base_/schedules/guanaco_deepspeed.py

Lines changed: 4 additions & 5 deletions
@@ -1,12 +1,11 @@
-from mmengine.optim import DeepSpeedOptimWrapper
 from mmengine._strategy import DeepSpeedStrategy
+from mmengine.optim import DeepSpeedOptimWrapper
 from torch.optim import AdamW
+
 # optimizer
 optim_wrapper = dict(
     type=DeepSpeedOptimWrapper,
-    optimizer=dict(
-        type=AdamW, lr=0.0002, weight_decay=0.0))
+    optimizer=dict(type=AdamW, lr=0.0002, weight_decay=0.0))

 # training strategy
 strategy = dict(
@@ -65,4 +64,4 @@

 # NOTE: `auto_scale_lr` is for automatically scaling LR
 # based on the actual training batch size.
-auto_scale_lr = dict(base_batch_size=64)
\ No newline at end of file
+auto_scale_lr = dict(base_batch_size=64)
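
The strategy dict is cut off after its opening line. MMEngine's DeepSpeedStrategy forwards its fields to the DeepSpeed JSON config schema, so a representative ZeRO-2 body would look roughly like this (an assumption for illustration; the file's real settings are not visible here):

from mmengine._strategy import DeepSpeedStrategy

strategy = dict(
    type=DeepSpeedStrategy,
    # fp16 block and ZeRO stage follow the DeepSpeed config schema.
    fp16=dict(enabled=True, initial_scale_power=16),
    zero_optimization=dict(stage=2, allgather_partitions=True),
    gradient_accumulation_steps=1,
)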

configs/alpaca/alpaca_standford.py

Lines changed: 17 additions & 20 deletions
@@ -1,33 +1,30 @@
 from mmengine.config import read_base
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from mmchat.models import SupervisedFinetune, DataProcesorForCausalLM
-from transformers import BitsAndBytesConfig
-from peft import LoraConfig
-from dataclasses import dataclass
-import torch
+
+from mmchat.models import DataProcesorForCausalLM, SupervisedFinetune
+
 with read_base():
-    from .._base_.datasets.aplaca import *
-    from .._base_.schedules.guanaco import *
-    from .._base_.default_runtime import *
+    from .._base_.datasets.aplaca import *  # noqa: F401,F403
+    from .._base_.default_runtime import *  # noqa: F401,F403
+    from .._base_.schedules.guanaco import *  # noqa: F401,F403

+pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
 model = dict(
-    type = SupervisedFinetune,
-    data_preprocessor = dict(
+    type=SupervisedFinetune,
+    data_preprocessor=dict(
         type=DataProcesorForCausalLM,
         tokenizer=dict(
             type=AutoTokenizer.from_pretrained,
-            pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
-            use_fast = False,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            use_fast=False,
         ),
-        source_max_len = 512,
-        target_max_len = 512,
-        train_on_source = False,
-        predict_with_generate = False,
+        source_max_len=512,
+        target_max_len=512,
+        train_on_source=False,
+        predict_with_generate=False,
     ),
-    llm = dict(
+    llm=dict(
         type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
     ),
-
 )
-
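
Once its _base_ files are merged in, a config like this is consumed through MMEngine's standard entry points. A minimal sketch with the public API (the repo's own launch script is not part of this commit):

from mmengine.config import Config
from mmengine.runner import Runner

# read_base() pulls the _base_ dataset/schedule/runtime files into the
# config; the Runner then builds model, dataloaders and hooks from it.
cfg = Config.fromfile('configs/alpaca/alpaca_standford.py')
runner = Runner.from_cfg(cfg)
runner.train()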
Lines changed: 27 additions & 31 deletions
@@ -1,50 +1,46 @@
+import torch
 from mmengine.config import read_base
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from mmchat.models import SupervisedQloraFinetune, DataProcesorForCausalLM
-from transformers import BitsAndBytesConfig
 from peft import LoraConfig
-from dataclasses import dataclass
-import torch
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
+
+from mmchat.models import DataProcesorForCausalLM, SupervisedQloraFinetune
+
 with read_base():
-    from .._base_.datasets.aplaca import *
-    from .._base_.schedules.guanaco import *
-    from .._base_.default_runtime import *
+    from .._base_.datasets.aplaca import *  # noqa: F401,F403
+    from .._base_.default_runtime import *  # noqa: F401,F403
+    from .._base_.schedules.guanaco import *  # noqa: F401,F403

+pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
 model = dict(
-    type = SupervisedQloraFinetune,
-    data_preprocessor = dict(
+    type=SupervisedQloraFinetune,
+    data_preprocessor=dict(
         type=DataProcesorForCausalLM,
         tokenizer=dict(
            type=AutoTokenizer.from_pretrained,
-            pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
-            use_fast = False,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            use_fast=False,
         ),
-        source_max_len = 512,
-        target_max_len = 512,
-        train_on_source = False,
-        predict_with_generate = False,
+        source_max_len=512,
+        target_max_len=512,
+        train_on_source=False,
+        predict_with_generate=False,
     ),
-    llm = dict(
+    llm=dict(
         type=AutoModelForCausalLM.from_pretrained,
-        pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b',
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
         quantization_config=dict(
-            type = BitsAndBytesConfig,
+            type=BitsAndBytesConfig,
             load_in_4bit=True,
             load_in_8bit=False,
             llm_int8_has_fp16_weight=False,
             bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type = 'nf4'
-        )
-    ),
+            bnb_4bit_quant_type='nf4')),
     lora=dict(
         type=LoraConfig,
-        r = 64,
-        lora_alpha = 16,
-        lora_dropout = 0.1,
-        bias = 'none',
-        task_type = 'CAUSAL_LM'
-    )
-
-)
-
+        r=64,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        bias='none',
+        task_type='CAUSAL_LM'))
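
The quantization_config and lora dicts map one-to-one onto public transformers and peft objects. Built directly, without the SupervisedQloraFinetune wrapper (whose internals this diff does not show), the same QLoRA setup reads:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

path = '/share/gaojianfei/merged_chinese_lora_7b'  # path from the config

# 4-bit NF4 quantization, mirroring the quantization_config dict above.
quant_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')
model = AutoModelForCausalLM.from_pretrained(path, quantization_config=quant_cfg)

# LoRA adapter, mirroring the lora dict above; get_peft_model is the
# standard peft route for attaching it.
lora_cfg = LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0.1, bias='none',
    task_type='CAUSAL_LM')
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()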
