From e41152f89c5f9968f0602c7aecf09072b7cdfada Mon Sep 17 00:00:00 2001
From: qianyu chen <38046403+qyc-98@users.noreply.github.com>
Date: Fri, 12 Sep 2025 15:53:48 +0800
Subject: [PATCH] Update trainer.py

---
 finetune/trainer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/finetune/trainer.py b/finetune/trainer.py
index 7da95ed8..9d659557 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -7,7 +7,7 @@
 from transformers.utils import is_sagemaker_mp_enabled
 from transformers.trainer import *
 from transformers.integrations import is_deepspeed_zero3_enabled
-
+from typing import Dict, List, Optional, Tuple
 
 class CPMTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
@@ -170,7 +170,7 @@ def prediction_step(
 
         return (loss, logits, labels)
 
-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
 
@@ -189,8 +189,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             `torch.Tensor`: The tensor with training loss on this batch.
         """
         model.train()
-        inputs = self._prepare_inputs(inputs)
-
+        inputs = self._prepare_inputs(inputs)
         if is_sagemaker_mp_enabled():
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
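
Note: newer transformers releases (roughly 4.46 and later) call Trainer.training_step
with an extra num_items_in_batch argument, so an override that keeps the old
two-argument signature raises a TypeError there; accepting the argument with a
None default, as this patch does, keeps the override callable from both old and
new Trainer loops. A minimal sketch of that compatibility pattern, assuming the
newer call convention described above (the class name below is illustrative and
not part of this repository):

    from typing import Any, Dict, Optional, Union

    import torch
    import torch.nn as nn
    from transformers import Trainer

    class CompatTrainer(Trainer):
        # Accept the extra argument passed by newer transformers releases;
        # default it to None so older releases that omit it still work.
        def training_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            num_items_in_batch: Optional[int] = None,
        ) -> torch.Tensor:
            # Defer to the stock implementation, forwarding the extra
            # argument only when the running version actually supplied it.
            if num_items_in_batch is not None:
                return super().training_step(model, inputs, num_items_in_batch)
            return super().training_step(model, inputs)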