[Issue]: dumpProxyState cannot be correctly triggered with multiple communication groups in PyTorch #1930

@Rhai2307

Question

I’m trying to use dumpProxyState to debug a hang issue, but the behavior doesn’t match my expectations.

nccl/src/proxy.cc

Lines 281 to 312 in 59242d7

```cpp
ncclResult_t dumpProxyState(struct ncclProxyProgressState* state) {
  struct ncclProxyArgs* op = state->active;
  int poolIndex, opIndex;
  printf("ACTIVE OPS\n");
  while (op) {
    NCCLCHECK(getOpIndex(op, state, &poolIndex, &opIndex));
    if (op->state & OP_SEEN) {
      WARN("List loop at element %d-%d", poolIndex, opIndex);
    }
    NCCLCHECK(printProxyOp(op, poolIndex, opIndex));
    op->state |= OP_SEEN;
    printf("\n");
    struct ncclProxyArgs* nextOp = op->nextPeer;
    while (nextOp) {
      NCCLCHECK(getOpIndex(nextOp, state, &poolIndex, &opIndex));
      if (nextOp->state & OP_SEEN) {
        WARN("List loop at element %d-%d", poolIndex, opIndex);
      }
      printf("| `-> ");
      NCCLCHECK(printProxyOp(nextOp, poolIndex, opIndex));
      nextOp->state |= OP_SEEN;
      printf("\n");
      if (nextOp->next) {
        WARN("Inactive op has next set!");
      }
      nextOp = nextOp->nextPeer;
    }
    if (op->nextPeer == NULL) printf("|\n");
    op = op->next;
    printf("v\n");
  }
  printf("[X]\n");
  // ... (excerpt ends at line 312; the rest of the function is not shown)
```

What I did:

  • I set the PROXY_DUMP_SIGNAL parameter (environment variable NCCL_PROXY_DUMP_SIGNAL) to a signal number.
  • I trained with Megatron-LM (plain PyTorch, using torch.distributed.new_group to create multiple communication groups, should be enough to reproduce this).
  • I sent that signal to the process and confirmed that dumpProxyState was triggered.
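The trigger mechanism in the steps above boils down to reading a signal number from the environment and installing a handler for it. The sketch below imitates that pattern so the moving parts are explicit. It is not NCCL's code: the helper `installDumpSignal`, the handler `onDumpSignal`, and the flag `gDumpRequested` are hypothetical, and the handler only sets a flag instead of dumping anything.

```cpp
#include <cassert>
#include <climits>
#include <csignal>
#include <cstdlib>
#include <string>

// Set from the signal handler; writing a volatile sig_atomic_t is
// async-signal-safe, unlike calling printf inside the handler.
volatile std::sig_atomic_t gDumpRequested = 0;

// Hypothetical handler: only records that a dump was requested.
extern "C" void onDumpSignal(int) { gDumpRequested = 1; }

// Hypothetical helper: reads a signal number from the environment variable
// `envName` and installs `handler` for it. Returns the installed signal
// number, or -1 if the variable is unset, not a positive integer, or
// rejected by std::signal.
int installDumpSignal(const char* envName, void (*handler)(int)) {
  const char* value = std::getenv(envName);
  if (value == nullptr || *value == '\0') return -1;
  char* end = nullptr;
  long sig = std::strtol(value, &end, 10);
  if (*end != '\0' || sig <= 0 || sig > INT_MAX) return -1;
  if (std::signal(static_cast<int>(sig), handler) == SIG_ERR) return -1;
  return static_cast<int>(sig);
}
```

From a shell, `kill -USR1 <pid>` would exercise the same path, assuming the variable was set to SIGUSR1's numeric value on that platform.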

However, the output only shows:

```
ACTIVE OPS
[X]
```

No per-operation details from printProxyOp are printed, even though there are active proxy operations, and I expect all of them to be dumped for debugging.
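This output is exactly what the quoted traversal produces when `state->active` is NULL: the outer `while (op)` loop never runs, so only the header and the `[X]` terminator are printed. The sketch below models that control flow with a hypothetical, heavily simplified `DemoOp` type and a hypothetical `dumpActiveOps` function that writes to a string instead of stdout. It omits NCCL's pool-index lookup and the "Inactive op has next set" warning.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical stand-in for ncclProxyArgs: only the fields the dump touches.
struct DemoOp {
  int id = 0;
  bool seen = false;           // plays the role of the OP_SEEN bit
  DemoOp* next = nullptr;      // next op in the active list
  DemoOp* nextPeer = nullptr;  // next op in this op's peer chain
};

// Mirrors the loop structure of the dumpProxyState excerpt.
std::string dumpActiveOps(DemoOp* active) {
  std::ostringstream out;
  out << "ACTIVE OPS\n";
  for (DemoOp* op = active; op != nullptr; op = op->next) {
    if (op->seen) out << "(list loop at op " << op->id << ")\n";
    out << "[op " << op->id << "]\n";
    op->seen = true;
    for (DemoOp* peer = op->nextPeer; peer != nullptr; peer = peer->nextPeer) {
      if (peer->seen) out << "(list loop at op " << peer->id << ")\n";
      out << "| `-> [op " << peer->id << "]\n";
      peer->seen = true;
    }
    if (op->nextPeer == nullptr) out << "|\n";
    out << "v\n";
  }
  out << "[X]\n";
  return out.str();
}
```

With an empty list, `dumpActiveOps(nullptr)` returns `"ACTIVE OPS\n[X]\n"`, the same two lines seen in the report.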

Analysis

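The dump is performed on ncclLastProxyState, which is stored in a process-global static variable. However, when Megatron-LM or PyTorch creates multiple communication groups, each communicator can run its own proxy thread, so multiple proxy threads exist within the same process.
As a result, the static variable can only reference the state of the most recently created proxy thread, and proxy operations owned by the other proxy threads are not visible when dumpProxyState is triggered. If that most recent proxy thread happens to have no active operations, state->active is NULL and the dump prints nothing but ACTIVE OPS and [X], which matches the output above.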

```cpp
static ncclProxyProgressState* ncclLastProxyState;
```
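To make the analysis concrete, the sketch below contrasts the single static pointer with one possible alternative. All names are hypothetical. The alternative is only a suggestion, not existing NCCL behavior: the signal handler bumps a generation counter (assumed to be a lock-free atomic, which is safe to touch from a handler), and every proxy progress loop compares it with the last generation it served and dumps its own state. No thread needs a pointer to the others, and printing moves out of the signal handler, where printf is not async-signal-safe.

```cpp
#include <atomic>
#include <cassert>
#include <string>

// Hypothetical per-communicator proxy state; `activeOps` stands in for the
// real active-op list.
struct DemoProxyState {
  std::string name;
  int activeOps = 0;
};

// Current pattern (simplified): one process-wide pointer that every starting
// proxy thread overwrites, so only the last one remains reachable.
static DemoProxyState* gLastProxyState = nullptr;
void onProxyThreadStart(DemoProxyState* s) { gLastProxyState = s; }

// Alternative: the handler only increments a generation counter.
static std::atomic<unsigned> gDumpGeneration{0};
extern "C" void requestProxyDump(int) { gDumpGeneration.fetch_add(1); }

// One per proxy thread; poll() would be called once per progress-loop
// iteration. Returns a dump line when a new dump was requested, else "".
struct DemoProgressLoop {
  DemoProxyState* state;
  unsigned servedGeneration = 0;
  std::string poll() {
    unsigned gen = gDumpGeneration.load();
    if (gen == servedGeneration) return "";
    servedGeneration = gen;
    return state->name + ": " + std::to_string(state->activeOps) + " active ops";
  }
};
```

In this model, after two proxy threads start, `gLastProxyState` points only at the second one, while a single `requestProxyDump` makes each loop's next `poll()` report its own state exactly once.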

Additional Question

Currently, the dump logic only handles the Send / Recv patterns and does not include collective (Coll) operations. I think dumpProxyState needs to support Coll operations; without it, debugging collective-heavy workloads (e.g., AllReduce, AllGather) is very limited.
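As an illustration of the request, a Coll-aware dump could print the same per-channel detail for every pattern. The sketch below only shows the shape of such an output line; the `DemoPattern` enum, the `DemoSub`/`DemoOpInfo` structs, the counter names, and `describeOp` are all hypothetical and are not taken from NCCL's printProxyOp.

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

enum class DemoPattern { Send, Recv, Coll };

// Hypothetical per-channel progress counters (illustrative names only).
struct DemoSub {
  int peer;
  int channel;
  uint64_t posted, transmitted, done;
};

struct DemoOpInfo {
  DemoPattern pattern;
  uint64_t opCount;
  std::vector<DemoSub> subs;
};

// Formats one op; Coll gets the same per-channel detail as Send/Recv, so a
// stalled collective channel shows where its counters stopped.
std::string describeOp(const DemoOpInfo& op) {
  static const char* const kNames[] = {"Send", "Recv", "Coll"};
  std::ostringstream out;
  out << '[' << op.opCount << "| " << kNames[static_cast<int>(op.pattern)];
  for (const DemoSub& s : op.subs) {
    out << ' ' << s.peer << '/' << s.channel
        << " p" << s.posted << " t" << s.transmitted << " d" << s.done;
  }
  out << ']';
  return out.str();
}
```

For example, a Coll op with one sub on peer 1, channel 0 whose counters are 8, 6, and 4 would be formatted as `[7| Coll 1/0 p8 t6 d4]` when its op count is 7.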
