Conversation

@xhr15 commented Oct 21, 2025

Logits processing is a powerful tool, particularly for using smaller language models for tasks such as named entity recognition. @seanmor5 started work in this area with #354.

Whatever the approach, it will require some kind of state.

This pull request is a proposal to allow logits processors to be stateful.

This would enable the use of deterministic finite automata (DFAs) or pushdown automata (PDAs) to enforce constrained grammars during logits processing. bitcrowd#6 shows how this would be used. We will follow up on this PR if this approach is favoured.
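
For illustration only, here is a minimal sketch of how a stateful processor could thread automaton state through a generation step. The module name, state layout, and the precomputed `allowed`/`transitions` tables are hypothetical, not part of this PR:

defmodule ConstrainedSketch do
  import Nx.Defn

  # `allowed` is a {num_states, vocab_size} 0/1 mask and `transitions`
  # a {num_states, vocab_size} next-state table, both precomputed from
  # a DFA (names and layout are hypothetical).
  defn constrained_process(state, logits, allowed, transitions) do
    # Mask out logits for tokens that are not valid transitions
    # from the current automaton state.
    mask = Nx.take(allowed, state.dfa_state)
    logits = Nx.select(mask, logits, Nx.Constants.neg_infinity())

    # Advance the automaton with the greedy token (shown here purely
    # to illustrate how state threads through the generation loop).
    token = Nx.argmax(logits)
    next = Nx.take(Nx.take(transitions, state.dfa_state), token)
    {%{state | dfa_state: next}, logits}
  end
end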

@jonatanklosko (Member) left a comment

Hey @xhr15 and @joelpaulkoch, thanks for the PR!

I dropped a few comments, but the main one is about the API. I know it's a bit more involved, but probably worth it. Let me know what you think, and if you have any concerns!

context =
  put_in(
    context,
    [:logits_processor_state, :next_suppressed_token_id],
@jonatanklosko (Member) commented:

With the current API, the state is always initialized to %{}, and then the first invocation of the processor adds a key, here %{next_suppressed_token_id: %Nx.Tensor{...}}.

This can be problematic in a defn while loop, which requires the accumulated state to always have the same shape. In other words, the initial state should already include :next_suppressed_token_id with a default tensor. It is possible that this didn't come up during your tests because, depending on the model/input, we do the first generation step outside of the while loop, and that first call would initialize the state. However, if we are going to support stateful processors, I would rather do it in a more robust way.
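
A minimal sketch of the difference (the -1 default is a hypothetical placeholder):

# Current: the state starts empty and gains a key on the first call,
# so its shape changes between while-loop iterations.
state = %{}

# More robust: the key exists from the start with a default tensor,
# so every iteration carries the same shapes.
state = %{next_suppressed_token_id: Nx.tensor(-1)}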

Given the above, a stateful logits processor would involve two steps (functions):

  1. Building an initial state.
  2. Performing logits processing, which receives logits and state, and returns updated logits and state.

This way we can call (1) when initializing the generation context, and for the actual processing we call (2).

The behaviour can be similar to Bumblebee.Scheduler. Something like this:

defmodule Bumblebee.LogitsProcessor do
  @moduledoc """
  An interface for configuring and using logits processors.

  Logits processors are used during autoregressive generation to modify
  predicted scores at each generation step. This allows for applying
  certain rules to the model output to control which tokens are picked
  at each generation step, and which are not.

  Every module implementing this behaviour is expected to also define
  a configuration struct.
  """

  @type t :: Bumblebee.Configurable.t()

  @type state :: Nx.Container.t()

  @doc """
  Initializes state for a new logits processor.

  Returns `state`, which is an opaque `Nx.Container` that is then
  passed to and returned from `process/4`.

  Oftentimes logits processors are stateless, in which case this
  function can return an empty container, such as `{}`.
  """
  @callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

  @doc """
  Processes logits, applying specific rules.
  """
  @callback process(
              t(),
              state(),
              logits :: Nx.Tensor.t(),
              context :: context
            ) :: {state(), logits :: Nx.Tensor.t()}
            when context: %{
                   sequence: Nx.Tensor.t(),
                   length: Nx.Tensor.t(),
                   input_length: Nx.Tensor.t()
                 }
end

Technically, the :logits_processors option is public API, but we can make this backward-compatible. For example, we can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define a bunch of new modules.
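
A sketch of what that wrapper could look like, assuming the behaviour proposed above and the current two-argument processor functions (not part of this PR):

defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do
  @moduledoc false

  # Wraps a plain logits-processing function in the stateful
  # interface, so existing :logits_processors entries keep working.
  defstruct [:fun]

  @behaviour Bumblebee.LogitsProcessor

  @impl true
  def init(_processor, _context), do: {}

  @impl true
  def process(%__MODULE__{fun: fun}, {} = state, logits, context) do
    {state, fun.(logits, context)}
  end
end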

@xhr15 (Author) replied:

@jonatanklosko Thank you very much for your comments! I think especially the two-step call makes sense. We'll move in that direction :)

@xhr15 (Author) replied:

@jonatanklosko, as an afterthought:

What is the use case for context here:

@callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

Later in the loop, context holds:

context = %{
  sequences: sequences,
  input_length: length,
  length: length
}

I am wondering how those would influence the initialisation of the logits processors?

Or are you planning on using additional keys? E.g. from the state as returned by the init sequence:

%{
  sequences: sequences,
  input_length: length,
  length: length,
  finished_length: finished_length,
  ignored: Nx.broadcast(0, {batch_size})
}

If that were the case, we should probably rename the parameter to state or initial_state.

Wdyt?

@xhr15 (Author) commented Oct 24, 2025

@jonatanklosko Before we add more tests and do further refactoring: do you think this goes in the right direction? Please let me know if you have concerns or if anything could be improved.
