diff --git a/attributes.go b/attributes.go index a98a3be..8e963c4 100644 --- a/attributes.go +++ b/attributes.go @@ -131,46 +131,52 @@ func (t AttrType) Value() uint16 { return uint16(t) } -func attrNames() map[AttrType]string { - return map[AttrType]string{ - AttrMappedAddress: "MAPPED-ADDRESS", - AttrUsername: "USERNAME", - AttrErrorCode: "ERROR-CODE", - AttrMessageIntegrity: "MESSAGE-INTEGRITY", - AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES", - AttrRealm: "REALM", - AttrNonce: "NONCE", - AttrXORMappedAddress: "XOR-MAPPED-ADDRESS", - AttrSoftware: "SOFTWARE", - AttrAlternateServer: "ALTERNATE-SERVER", - AttrFingerprint: "FINGERPRINT", - AttrPriority: "PRIORITY", - AttrUseCandidate: "USE-CANDIDATE", - AttrICEControlled: "ICE-CONTROLLED", - AttrICEControlling: "ICE-CONTROLLING", - AttrChannelNumber: "CHANNEL-NUMBER", - AttrLifetime: "LIFETIME", - AttrXORPeerAddress: "XOR-PEER-ADDRESS", - AttrData: "DATA", - AttrXORRelayedAddress: "XOR-RELAYED-ADDRESS", - AttrEvenPort: "EVEN-PORT", - AttrRequestedTransport: "REQUESTED-TRANSPORT", - AttrDontFragment: "DONT-FRAGMENT", - AttrReservationToken: "RESERVATION-TOKEN", - AttrConnectionID: "CONNECTION-ID", - AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY", - AttrMessageIntegritySHA256: "MESSAGE-INTEGRITY-SHA256", - AttrPasswordAlgorithm: "PASSWORD-ALGORITHM", - AttrUserhash: "USERHASH", - AttrPasswordAlgorithms: "PASSWORD-ALGORITHMS", - AttrAlternateDomain: "ALTERNATE-DOMAIN", - AttrDtlsInStun: "DTLS-IN-STUN", - AttrDtlsInStunAck: "DTLS-IN-STUN-ACKNOWLEDGEMENT", - } +// attrNames maps each attribute type implemented by this library to its +// human-readable name. It is consulted by AttrType.String() and by +// AttrType.Known() (which determines whether an unknown +// comprehension-required attribute should cause a message to be rejected +// in strict mode), so it MUST be kept in sync with the attributes +// actually implemented in the package. +// +//nolint:gochecknoglobals +var attrNames = map[AttrType]string{ + AttrMappedAddress: "MAPPED-ADDRESS", + AttrUsername: "USERNAME", + AttrErrorCode: "ERROR-CODE", + AttrMessageIntegrity: "MESSAGE-INTEGRITY", + AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES", + AttrRealm: "REALM", + AttrNonce: "NONCE", + AttrXORMappedAddress: "XOR-MAPPED-ADDRESS", + AttrSoftware: "SOFTWARE", + AttrAlternateServer: "ALTERNATE-SERVER", + AttrFingerprint: "FINGERPRINT", + AttrPriority: "PRIORITY", + AttrUseCandidate: "USE-CANDIDATE", + AttrICEControlled: "ICE-CONTROLLED", + AttrICEControlling: "ICE-CONTROLLING", + AttrChannelNumber: "CHANNEL-NUMBER", + AttrLifetime: "LIFETIME", + AttrXORPeerAddress: "XOR-PEER-ADDRESS", + AttrData: "DATA", + AttrXORRelayedAddress: "XOR-RELAYED-ADDRESS", + AttrEvenPort: "EVEN-PORT", + AttrRequestedTransport: "REQUESTED-TRANSPORT", + AttrDontFragment: "DONT-FRAGMENT", + AttrReservationToken: "RESERVATION-TOKEN", + AttrConnectionID: "CONNECTION-ID", + AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY", + AttrMessageIntegritySHA256: "MESSAGE-INTEGRITY-SHA256", + AttrPasswordAlgorithm: "PASSWORD-ALGORITHM", + AttrUserhash: "USERHASH", + AttrPasswordAlgorithms: "PASSWORD-ALGORITHMS", + AttrAlternateDomain: "ALTERNATE-DOMAIN", + AttrDtlsInStun: "DTLS-IN-STUN", + AttrDtlsInStunAck: "DTLS-IN-STUN-ACKNOWLEDGEMENT", } func (t AttrType) String() string { - s, ok := attrNames()[t] + s, ok := attrNames[t] if !ok { // Just return hex representation of unknown attribute type. return fmt.Sprintf("0x%x", uint16(t)) @@ -182,7 +188,7 @@ func (t AttrType) String() string { // Known returns true if AttrType is known and implemented // by this library. func (t AttrType) Known() bool { - _, valid := attrNames()[t] + _, valid := attrNames[t] return valid } diff --git a/attributes_test.go b/attributes_test.go index d910376..0a29f4c 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -108,7 +108,7 @@ func TestAttrTypeRange(t *testing.T) { func TestAttrTypeKnown(t *testing.T) { // All Attributes in attrNames should be known - for attr := range attrNames() { + for attr := range attrNames { assert.True(t, attr.Known()) } diff --git a/iana_test.go b/iana_test.go index abf139c..271173e 100644 --- a/iana_test.go +++ b/iana_test.go @@ -71,7 +71,7 @@ func TestIANA(t *testing.T) { //nolint:cyclop maps.Copy(attrTypes, map[string]AttrType{ "ORIGIN": 0x802F, }) - for val, name := range attrNames() { + for val, name := range attrNames { mapped, ok := attrTypes[name] assert.True(t, ok, "failed to find attribute %s in IANA", name) assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val) diff --git a/message.go b/message.go index b3c4c2d..b1227a9 100644 --- a/message.go +++ b/message.go @@ -407,8 +407,13 @@ var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to rea // ErrInvalidType means that the message type is 0 (reserved). var ErrInvalidType = errors.New("STUN message type 0 is reserved") +// ErrUnknownComprehensionRequired means that the message contains an +// attribute in the comprehension-required range (0x0000-0x7FFF) that +// is not implemented by this library and therefore cannot be processed. +var ErrUnknownComprehensionRequired = errors.New("unknown comprehension-required attribute") + // Decode decodes m.Raw into m. -func (m *Message) Decode() error { //nolint:cyclop +func (m *Message) Decode() error { //nolint:gocognit,cyclop // decoding message header buf := m.Raw if len(buf) < messageHeaderSize { @@ -490,6 +495,18 @@ func (m *Message) Decode() error { //nolint:cyclop if m.strict && afterIntegrity { continue } + if attr.Type.Required() && !attr.Type.Known() { + if m.strict { + if m.logger != nil { + m.logger.Errorf("unknown comprehension-required attribute %s", attr.Type.String()) + } + + return ErrUnknownComprehensionRequired + } + if m.logger != nil { + m.logger.Warnf("unknown comprehension-required attribute %s", attr.Type.String()) + } + } m.Attributes = append(m.Attributes, attr) if isMI { seenMI = true diff --git a/message_test.go b/message_test.go index bda23d5..1d4ce51 100644 --- a/message_test.go +++ b/message_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/pion/logging" "github.com/stretchr/testify/assert" ) @@ -905,3 +906,49 @@ func TestMessageReservedType(t *testing.T) { _, err = mDecodedStrict.ReadFrom(bytes.NewReader(m.Raw)) assert.ErrorIs(t, err, ErrInvalidType) } + +func TestMessageUnknownComprehensionRequired(t *testing.T) { + const unknownAttr AttrType = 0x7FFE // comprehension-required, not implemented + + loggerFactory := logging.NewDefaultLoggerFactory() + loggerFactory.DefaultLogLevel = logging.LogLevelWarn + var logOutput bytes.Buffer + loggerFactory.Writer = &logOutput + logger := loggerFactory.NewLogger("stun") + + m1 := New() + m1.Type = BindingRequest + m1.TransactionID = NewTransactionID() + m1.Add(unknownAttr, []byte{1, 2, 3, 4}) + m1.WriteHeader() + + // Backward-compat behavior: unknown attribute is retained. + mDecoded := NewWithOptions(WithStrict(false), withMessageLogger(logger)) + _, err := mDecoded.ReadFrom(bytes.NewReader(m1.Raw)) + assert.NoError(t, err) + _, found := mDecoded.Attributes.Get(unknownAttr) + assert.True(t, found) + assert.Contains(t, logOutput.String(), "unknown comprehension-required attribute 0x7ffe") + + logOutput.Reset() + + // Strict mode: decoding fails. + mDecodedStrict := NewWithOptions(WithStrict(true), withMessageLogger(logger)) + _, err = mDecodedStrict.ReadFrom(bytes.NewReader(m1.Raw)) + assert.ErrorIs(t, err, ErrUnknownComprehensionRequired) + assert.Contains(t, logOutput.String(), "unknown comprehension-required attribute 0x7ffe") + + logOutput.Reset() + + // Comprehension-optional unknown attribute is always retained. + const unknownOptional AttrType = 0xFFFE + m2 := New() + m2.Type = BindingRequest + m2.TransactionID = NewTransactionID() + m2.Add(unknownOptional, []byte{1, 2, 3, 4}) + m2.WriteHeader() + mDecodedStrict2 := NewWithOptions(WithStrict(true), withMessageLogger(logger)) + _, err = mDecodedStrict2.ReadFrom(bytes.NewReader(m2.Raw)) + assert.NoError(t, err) + assert.Empty(t, logOutput.String()) +}