@@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
34
34
OptFloat const & topPMin, std::optional<TokenIdType> const & topPResetIds, OptFloat const & topPDecay,
35
35
std::optional<RandomSeedType> const & seed, OptFloat const & temperature, OptSize32 const & minTokens,
36
36
OptFloat const & beamSearchDiversityRate, OptFloat const & repetitionPenalty, OptFloat const & presencePenalty,
37
- OptFloat const & frequencyPenalty, OptFloat const & lengthPenalty, OptSize32 const & earlyStopping ,
38
- OptSize32 const & noRepeatNgramSize , OptSize32 const & numReturnSequences, OptFloat const & minP ,
39
- OptVec<SizeType32> const & beamWidthArray)
37
+ OptFloat const & frequencyPenalty, OptSize32 const & promptIgnoreLength, OptFloat const & lengthPenalty ,
38
+ OptSize32 const & earlyStopping , OptSize32 const & noRepeatNgramSize, OptSize32 const & numReturnSequences ,
39
+ OptFloat const & minP, OptVec<SizeType32> const & beamWidthArray)
40
40
: mBeamWidth (checkBeamWidth(beamWidth))
41
41
, mTopK (checkTopK(topK))
42
42
, mTopP (checkTopP(topP))
@@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
50
50
, mRepetitionPenalty (checkRepetitionPenalty(repetitionPenalty))
51
51
, mPresencePenalty (presencePenalty)
52
52
, mFrequencyPenalty (frequencyPenalty)
53
+ , mPromptIgnoreLength (checkPromptIgnoreLength(promptIgnoreLength))
53
54
, mLengthPenalty (checkLengthPenalty(lengthPenalty))
54
55
, mEarlyStopping (checkEarlyStopping(earlyStopping))
55
56
, mNoRepeatNgramSize (checkNoRepeatNgramSize(noRepeatNgramSize))
@@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
67
68
&& mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
68
69
&& mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
69
70
&& mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
70
- && mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
71
- && mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
72
- && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray ;
71
+ && mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
72
+ && mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
73
+ && mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
74
+ && mBeamWidthArray == other.mBeamWidthArray ;
73
75
}
74
76
75
77
// Getters
@@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
143
145
return mFrequencyPenalty ;
144
146
}
145
147
148
+ OptSize32 SamplingConfig::getPromptIgnoreLength () const
149
+ {
150
+ return mPromptIgnoreLength ;
151
+ }
152
+
146
153
OptFloat SamplingConfig::getLengthPenalty () const
147
154
{
148
155
return mLengthPenalty ;
@@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
240
247
mFrequencyPenalty = frequencyPenalty;
241
248
}
242
249
250
+ void SamplingConfig::setPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
251
+ {
252
+ mPromptIgnoreLength = checkPromptIgnoreLength (promptIgnoreLength);
253
+ }
254
+
243
255
void SamplingConfig::setLengthPenalty (OptFloat const & lengthPenalty)
244
256
{
245
257
mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
@@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
362
374
return repetitionpenalty;
363
375
}
364
376
377
+ OptSize32 const & SamplingConfig::checkPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
378
+ {
379
+ if (promptIgnoreLength.has_value ())
380
+ {
381
+ TLLM_CHECK (promptIgnoreLength.value () >= 0 );
382
+ }
383
+ return promptIgnoreLength;
384
+ }
385
+
365
386
OptFloat const & SamplingConfig::checkLengthPenalty (OptFloat const & lengthPenalty)
366
387
{
367
388
if (lengthPenalty.has_value ())
0 commit comments