Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mikazuki/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def create_toml_file(request: Request):
trainer_file = trainer_mapping[model_train_type]

if model_train_type != "sdxl-finetune":
if not train_utils.validate_data_dir(config["train_data_dir"]):
if not train_utils.validate_data_dir(config["train_data_dir"], config.pop('train_data_dir_repeat', None)):
return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。")

validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type)
Expand Down
3 changes: 2 additions & 1 deletion mikazuki/schema/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
RAW: {
DATASET_SETTINGS: {
train_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).default("./train/aki").description("训练数据集路径"),
train_data_dir_repeat: Schema.number().min(1).description("训练数据集重复次数,可选。默认根据图片数量自动选择,如果训练集已指定重复次数则会忽略该参数"),
reg_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).description("正则化数据集路径。默认留空,不使用正则化图像"),
prior_loss_weight: Schema.number().step(0.1).default(1.0).description("正则化 - 先验损失权重"),
resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数"),
resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数"),
enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"),
min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"),
max_bucket_reso: Schema.number().default(1024).description("arb 桶最大分辨率"),
Expand Down
7 changes: 5 additions & 2 deletions mikazuki/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def validate_model(model_name: str, training_type: str = "sd-lora"):
return False, "model not found"


def validate_data_dir(path):
def validate_data_dir(path, repeat):
if not os.path.exists(path):
log.error(f"Data dir {path} not exists, check your params")
return False
Expand All @@ -187,7 +187,10 @@ def validate_data_dir(path):
captions = glob.glob(path + '/*.txt')
log.info(f"{len(imgs)} images found, {len(captions)} captions found")
if len(imgs) > 0:
num_repeat = suggest_num_repeat(len(imgs))
if isinstance(repeat, int) and repeat > 0:
num_repeat = repeat
else:
num_repeat = suggest_num_repeat(len(imgs))
dataset_path = os.path.join(path, f"{num_repeat}_zkz")
os.makedirs(dataset_path)
for i in imgs:
Expand Down