Skip to content

Commit 6997c66

Browse files
committed
Add prompt_ignore_length to SamplingConfig to control how many tokens to ignore from the prompt for precense and frequency penalties for trt and torch path.
Signed-off-by: Xuanyu Chen <[email protected]>
1 parent a36b48b commit 6997c66

30 files changed

+534
-196
lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class SamplingConfig
7171
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
7272
std::optional<FloatType> const& presencePenalty = std::nullopt,
7373
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
74+
std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
7475
std::optional<FloatType> const& lengthPenalty = std::nullopt,
7576
std::optional<SizeType32> const& earlyStopping = std::nullopt,
7677
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
@@ -94,6 +95,7 @@ class SamplingConfig
9495
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
9596
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
9697
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
98+
[[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
9799
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
98100
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
99101
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
@@ -114,6 +116,7 @@ class SamplingConfig
114116
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
115117
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
116118
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
119+
void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
117120
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
118121
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
119122
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@@ -133,6 +136,8 @@ class SamplingConfig
133136
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
134137
std::optional<FloatType> const& beamSearchDiversityRate);
135138
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
139+
static std::optional<SizeType32> const& checkPromptIgnoreLength(
140+
std::optional<SizeType32> const& promptIgnoreLength);
136141
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
137142
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
138143
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@@ -174,6 +179,9 @@ class SamplingConfig
174179
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
175180
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
176181
std::optional<FloatType> mFrequencyPenalty;
182+
/// @brief Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have
183+
/// no effect. Values > input (prompt) length will be clamped. Default is 0.
184+
std::optional<SizeType32> mPromptIgnoreLength;
177185
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
178186
std::optional<FloatType> mLengthPenalty;
179187
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with

cpp/include/tensorrt_llm/layers/defaultDecodingParams.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ class DefaultDecodingParams
5656
return 1;
5757
}
5858

59+
[[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getPromptIgnoreLength()
60+
{
61+
return 0;
62+
}
63+
5964
[[nodiscard]] __host__ __device__ static constexpr uint64_t getSeed()
6065
{
6166
return 0;

cpp/include/tensorrt_llm/runtime/samplingConfig.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ class SamplingConfig
133133
frequencyPenalty = fuseValues<FloatType>(
134134
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
135135
layers::DefaultDecodingParams::getFrequencyPenalty());
136+
promptIgnoreLength = fuseValues<SizeType32>(
137+
configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; },
138+
layers::DefaultDecodingParams::getPromptIgnoreLength());
136139
noRepeatNgramSize = fuseValues<SizeType32>(
137140
configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; },
138141
layers::DefaultDecodingParams::getNoRepeatNgramSize());
@@ -224,6 +227,7 @@ class SamplingConfig
224227
SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType)
225228
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
226229
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
230+
SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32)
227231
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
228232
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32)
229233
SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
@@ -342,6 +346,7 @@ class SamplingConfig
342346
OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
343347
OptVec<FloatType> presencePenalty; // [1] or [batchSize]
344348
OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
349+
OptVec<SizeType32> promptIgnoreLength; // [1] or [batchSize]
345350
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]
346351

347352
// probs
@@ -377,13 +382,14 @@ class SamplingConfig
377382
&& temperature == other.temperature && originalTemperature == other.originalTemperature
378383
&& minLength == other.minLength && repetitionPenalty == other.repetitionPenalty
379384
&& presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty
380-
&& noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP
381-
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
382-
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
383-
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
384-
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
385-
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
386-
&& cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray;
385+
&& promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize
386+
&& topK == other.topK && topP == other.topP && randomSeed == other.randomSeed
387+
&& topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds
388+
&& beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty
389+
&& earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold
390+
&& topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs
391+
&& outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP
392+
&& beamWidthArray == other.beamWidthArray;
387393
}
388394

389395
SizeType32 getNumReturnBeams() const

cpp/tensorrt_llm/executor/samplingConfig.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
3434
OptFloat const& topPMin, std::optional<TokenIdType> const& topPResetIds, OptFloat const& topPDecay,
3535
std::optional<RandomSeedType> const& seed, OptFloat const& temperature, OptSize32 const& minTokens,
3636
OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty,
37-
OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping,
38-
OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP,
39-
OptVec<SizeType32> const& beamWidthArray)
37+
OptFloat const& frequencyPenalty, OptSize32 const& promptIgnoreLength, OptFloat const& lengthPenalty,
38+
OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences,
39+
OptFloat const& minP, OptVec<SizeType32> const& beamWidthArray)
4040
: mBeamWidth(checkBeamWidth(beamWidth))
4141
, mTopK(checkTopK(topK))
4242
, mTopP(checkTopP(topP))
@@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
5050
, mRepetitionPenalty(checkRepetitionPenalty(repetitionPenalty))
5151
, mPresencePenalty(presencePenalty)
5252
, mFrequencyPenalty(frequencyPenalty)
53+
, mPromptIgnoreLength(checkPromptIgnoreLength(promptIgnoreLength))
5354
, mLengthPenalty(checkLengthPenalty(lengthPenalty))
5455
, mEarlyStopping(checkEarlyStopping(earlyStopping))
5556
, mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize))
@@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
6768
&& mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
6869
&& mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
6970
&& mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
70-
&& mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
71-
&& mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
72-
&& mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray;
71+
&& mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
72+
&& mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
73+
&& mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
74+
&& mBeamWidthArray == other.mBeamWidthArray;
7375
}
7476

7577
// Getters
@@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
143145
return mFrequencyPenalty;
144146
}
145147

148+
OptSize32 SamplingConfig::getPromptIgnoreLength() const
149+
{
150+
return mPromptIgnoreLength;
151+
}
152+
146153
OptFloat SamplingConfig::getLengthPenalty() const
147154
{
148155
return mLengthPenalty;
@@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
240247
mFrequencyPenalty = frequencyPenalty;
241248
}
242249

250+
void SamplingConfig::setPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
251+
{
252+
mPromptIgnoreLength = checkPromptIgnoreLength(promptIgnoreLength);
253+
}
254+
243255
void SamplingConfig::setLengthPenalty(OptFloat const& lengthPenalty)
244256
{
245257
mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
@@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
362374
return repetitionpenalty;
363375
}
364376

377+
OptSize32 const& SamplingConfig::checkPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
378+
{
379+
if (promptIgnoreLength.has_value())
380+
{
381+
TLLM_CHECK(promptIgnoreLength.value() >= 0);
382+
}
383+
return promptIgnoreLength;
384+
}
385+
365386
OptFloat const& SamplingConfig::checkLengthPenalty(OptFloat const& lengthPenalty)
366387
{
367388
if (lengthPenalty.has_value())

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
158158
auto repetitionPenalty = su::deserialize<std::optional<FloatType>>(is);
159159
auto presencePenalty = su::deserialize<std::optional<FloatType>>(is);
160160
auto frequencyPenalty = su::deserialize<std::optional<FloatType>>(is);
161+
auto promptIgnoreLength = su::deserialize<std::optional<SizeType32>>(is);
161162
auto lengthPenalty = su::deserialize<std::optional<FloatType>>(is);
162163
auto earlyStopping = su::deserialize<std::optional<SizeType32>>(is);
163164
auto noRepeatNgramSize = su::deserialize<std::optional<SizeType32>>(is);
@@ -166,8 +167,8 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
166167
auto beamWidthArray = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
167168

168169
return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength,
169-
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping,
170-
noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
170+
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, promptIgnoreLength,
171+
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
171172
}
172173

173174
void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
@@ -185,6 +186,7 @@ void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
185186
su::serialize(config.mRepetitionPenalty, os);
186187
su::serialize(config.mPresencePenalty, os);
187188
su::serialize(config.mFrequencyPenalty, os);
189+
su::serialize(config.mPromptIgnoreLength, os);
188190
su::serialize(config.mLengthPenalty, os);
189191
su::serialize(config.mEarlyStopping, os);
190192
su::serialize(config.mNoRepeatNgramSize, os);
@@ -209,6 +211,7 @@ size_t Serialization::serializedSize(SamplingConfig const& config)
209211
totalSize += su::serializedSize(config.mRepetitionPenalty);
210212
totalSize += su::serializedSize(config.mPresencePenalty);
211213
totalSize += su::serializedSize(config.mFrequencyPenalty);
214+
totalSize += su::serializedSize(config.mPromptIgnoreLength);
212215
totalSize += su::serializedSize(config.mLengthPenalty);
213216
totalSize += su::serializedSize(config.mEarlyStopping);
214217
totalSize += su::serializedSize(config.mNoRepeatNgramSize);

0 commit comments

Comments
 (0)