|
| 1 | +import torch |
1 | 2 | from mmengine.config import read_base |
2 | | -from transformers import AutoModelForCausalLM, AutoTokenizer |
3 | | -from mmchat.models import SupervisedQloraFinetune, DataProcesorForCausalLM |
4 | | -from transformers import BitsAndBytesConfig |
5 | 3 | from peft import LoraConfig |
6 | | -from dataclasses import dataclass |
7 | | -import torch |
| 4 | +from transformers import (AutoModelForCausalLM, AutoTokenizer, |
| 5 | + BitsAndBytesConfig) |
| 6 | + |
| 7 | +from mmchat.models import DataProcesorForCausalLM, SupervisedQloraFinetune |
| 8 | + |
# mmengine config inheritance: inside read_base() these relative star imports
# are intercepted by mmengine and merged as base configs (dataset, runtime,
# schedule) rather than executed as ordinary Python imports.
# NOTE: 'aplaca' is spelled this way in the base-config filename itself.
with read_base():
    from .._base_.datasets.aplaca import *  # noqa: F401,F403
    from .._base_.default_runtime import *  # noqa: F401,F403
    from .._base_.schedules.guanaco import *  # noqa: F401,F403

|
# Checkpoint shared by the tokenizer and the LLM below.
# NOTE(review): machine-specific absolute path — confirm it exists (or
# parameterize it) before reusing this config elsewhere.
pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'
# Supervised QLoRA fine-tuning setup: a 4-bit quantized base LLM wrapped
# with LoRA adapters, plus a data preprocessor for causal-LM batches.
model = dict(
    type=SupervisedQloraFinetune,
    data_preprocessor=dict(
        # (sic) "DataProcesor" matches the class name exported by mmchat.
        type=DataProcesorForCausalLM,
        tokenizer=dict(
            type=AutoTokenizer.from_pretrained,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            # Slow (non-Rust) tokenizer — presumably required by this
            # LLaMA-derived vocab; verify before switching to use_fast=True.
            use_fast=False,
        ),
        # Truncation lengths — presumably source=prompt / target=response
        # token budgets; confirm against DataProcesorForCausalLM.
        source_max_len=512,
        target_max_len=512,
        train_on_source=False,
        predict_with_generate=False,
    ),
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        # bitsandbytes 4-bit NF4 quantization with double quantization —
        # the standard QLoRA recipe; 8-bit path explicitly disabled.
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_has_fp16_weight=False,
            # Matmuls run in fp16 on top of the 4-bit stored weights.
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    # LoRA adapter hyper-parameters: rank 64, scaling alpha 16, no bias terms.
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
0 commit comments