diff --git a/transformerlab/plugins/mlx_lora_trainer/main.py b/transformerlab/plugins/mlx_lora_trainer/main.py
index d34cd4c21..abfd6c385 100644
--- a/transformerlab/plugins/mlx_lora_trainer/main.py
+++ b/transformerlab/plugins/mlx_lora_trainer/main.py
@@ -46,6 +46,10 @@ def train_mlx_lora():
     chat_column = tlab_trainer.params.get("chatml_formatted_column", "messages")
     formatting_template = tlab_trainer.params.get("formatting_template", None)
 
+    # Check if we are starting from a checkpoint:
+    restart_from_checkpoint = tlab_trainer.params.get("restart_from_checkpoint", None)
+    print("Restarting from checkpoint:", restart_from_checkpoint)
+
     if num_train_epochs is not None and num_train_epochs != "" and int(num_train_epochs) >= 0:
         if num_train_epochs == 0:
             print(
@@ -109,6 +113,11 @@ def train_mlx_lora():
     if not os.path.exists(adaptor_output_dir):
         os.makedirs(adaptor_output_dir)
 
+    # If checkpointing is enabled, then set the path to be the checkpoint name in the adaptor output directory
+    if restart_from_checkpoint:
+        checkpoint_path = os.path.join(adaptor_output_dir, restart_from_checkpoint)
+        print("Using checkpoint output directory:", checkpoint_path)
+
     # Get Python executable (from venv if available)
     python_executable = get_python_executable(plugin_dir)
     env = os.environ.copy()
@@ -142,6 +151,9 @@ def train_mlx_lora():
         str(save_every),
     ]
 
+    if restart_from_checkpoint:
+        popen_command.extend(["--resume-adapter-file", checkpoint_path])
+
     # If a config file has been created then include it
     if config_file:
         popen_command.extend(["--config", config_file])
diff --git a/transformerlab/routers/tasks.py b/transformerlab/routers/tasks.py
index 333529a4f..c8f69eacb 100644
--- a/transformerlab/routers/tasks.py
+++ b/transformerlab/routers/tasks.py
@@ -223,6 +223,8 @@ async def queue_task(task_id: int, input_override: str = "{}", output_override:
     if not isinstance(output_override, dict):
         output_override = json.loads(output_override)
 
+    print(f"Input override: {input_override}")
+
     if not isinstance(task_to_queue["config"], dict):
         task_to_queue["config"] = json.loads(task_to_queue["config"])