From 4e50a82f160c006c657bfd132e464b4cbd1b7172 Mon Sep 17 00:00:00 2001 From: magicFeirl <2100709458@qq.com> Date: Tue, 14 Jan 2025 21:34:47 +0800 Subject: [PATCH] feat: Support to customize dataset repeat --- mikazuki/app/api.py | 2 +- mikazuki/schema/shared.ts | 3 ++- mikazuki/utils/train_utils.py | 7 +++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mikazuki/app/api.py b/mikazuki/app/api.py index 68c887e3..bde842db 100644 --- a/mikazuki/app/api.py +++ b/mikazuki/app/api.py @@ -134,7 +134,7 @@ async def create_toml_file(request: Request): trainer_file = trainer_mapping[model_train_type] if model_train_type != "sdxl-finetune": - if not train_utils.validate_data_dir(config["train_data_dir"]): + if not train_utils.validate_data_dir(config["train_data_dir"], config.pop('train_data_dir_repeat', None)): return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。") validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type) diff --git a/mikazuki/schema/shared.ts b/mikazuki/schema/shared.ts index d18e215d..5a147591 100644 --- a/mikazuki/schema/shared.ts +++ b/mikazuki/schema/shared.ts @@ -6,9 +6,10 @@ RAW: { DATASET_SETTINGS: { train_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).default("./train/aki").description("训练数据集路径"), + train_data_dir_repeat: Schema.number().min(1).description("训练数据集重复次数,可选。默认根据图片数量自动选择,如果训练集已指定重复次数则会忽略该参数"), reg_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).description("正则化数据集路径。默认留空,不使用正则化图像"), prior_loss_weight: Schema.number().step(0.1).default(1.0).description("正则化 - 先验损失权重"), - resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"), + resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数"), enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"), min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"), max_bucket_reso: Schema.number().default(1024).description("arb 桶最大分辨率"), diff --git a/mikazuki/utils/train_utils.py b/mikazuki/utils/train_utils.py index 2e27c975..6c35db76 100644 --- a/mikazuki/utils/train_utils.py +++ b/mikazuki/utils/train_utils.py @@ -84,7 +84,7 @@ def match_model_type(sig_content: bytes): return ModelType.UNKNOWN -def validate_data_dir(path): +def validate_data_dir(path, repeat): if not os.path.exists(path): log.error(f"Data dir {path} not exists, check your params") return False @@ -111,7 +111,10 @@ def validate_data_dir(path): captions = glob.glob(path + '/*.txt') log.info(f"{len(imgs)} images found, {len(captions)} captions found") if len(imgs) > 0: - num_repeat = suggest_num_repeat(len(imgs)) + if isinstance(repeat, int) and repeat > 0: + num_repeat = repeat + else: + num_repeat = suggest_num_repeat(len(imgs)) dataset_path = os.path.join(path, f"{num_repeat}_zkz") os.makedirs(dataset_path) for i in imgs: