Skip to content

Commit cb91bd9

Browse files
committed
fix: address data races and validation issues for #2248
This commit fixes 6 issues (4 HIGH, 2 MEDIUM severity) identified in the review of PR #2274. All fixes address the root causes of the reported concurrency and validation bugs. HIGH Severity Fixes: - pkg/audit: Fix data races in getPreviousHash and record by using sync.RWMutex and ensuring atomic updates to the cryptographic chain. - pkg/audit: Add error check for crypto/rand.Read to prevent predictable ID generation. - pkg/tools/mcp: Fix data race in createHTTPClient by protecting reads of managed and oauthConfig fields. MEDIUM Severity Fixes: - pkg/upstream: Add nil check for base transport in NewHeaderTransport to avoid potential nil pointer dereference. - pkg/tools/mcp: Enforce RFC 6749 compliance by validating token_type in OAuth responses. Verification: - All pkg/audit tests pass (13/13) - All pkg/tools/mcp OAuth tests pass (6/6) - Successful build of project and final binary Fixes #2248 Ref: PR #2274
1 parent ba36b3f commit cb91bd9

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

pkg/audit/audit.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ type ErrorAction struct {
153153

154154
// Auditor handles audit trail recording and verification
155155
type Auditor struct {
156-
mu sync.Mutex
156+
mu sync.RWMutex
157157
privateKey ed25519.PrivateKey
158158
publicKey ed25519.PublicKey
159159
records []*AuditRecord
@@ -448,10 +448,11 @@ func (a *Auditor) record(ctx context.Context, sess *session.Session, agentName s
448448
// Get timestamp
449449
timestamp := time.Now().UTC().Format(time.RFC3339Nano)
450450

451-
// Get previous hash for chain integrity
452-
a.mu.Unlock()
453-
previousHash := a.getPreviousHash(sess.ID)
454-
a.mu.Lock()
451+
// Get previous hash for chain integrity (read while holding write lock)
452+
var previousHash string
453+
if hash, ok := a.sessionHash[sess.ID]; ok {
454+
previousHash = hash
455+
}
455456

456457
// Create record
457458
record := &AuditRecord{
@@ -486,15 +487,21 @@ func (a *Auditor) record(ctx context.Context, sess *session.Session, agentName s
486487
// Store record
487488
a.records = append(a.records, record)
488489

489-
// Persist to disk
490-
if err := a.persistRecord(record); err != nil {
490+
// Unlock before persisting to disk (I/O operation)
491+
a.mu.Unlock()
492+
err = a.persistRecord(record)
493+
a.mu.Lock()
494+
495+
if err != nil {
491496
return record, fmt.Errorf("failed to persist record: %w", err)
492497
}
493498

494499
return record, nil
495500
}
496501

497502
func (a *Auditor) getPreviousHash(sessionID string) string {
503+
a.mu.RLock()
504+
defer a.mu.RUnlock()
498505
if hash, ok := a.sessionHash[sessionID]; ok {
499506
return hash
500507
}
@@ -534,7 +541,9 @@ func calculateHash(record *AuditRecord) (string, error) {
534541

535542
func generateID() string {
536543
b := make([]byte, 16)
537-
rand.Read(b)
544+
if _, err := rand.Read(b); err != nil {
545+
panic(fmt.Sprintf("failed to generate random ID: %v", err))
546+
}
538547
return hex.EncodeToString(b)
539548
}
540549

pkg/tools/mcp/oauth.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,11 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe
449449
return errors.New("access_token missing or invalid in client response")
450450
}
451451

452-
if tokenType, ok := tokenData["token_type"].(string); ok {
452+
// token_type is required per OAuth 2.0 RFC 6749
453+
if tokenType, ok := tokenData["token_type"].(string); ok && tokenType != "" {
453454
token.TokenType = tokenType
455+
} else {
456+
return errors.New("token_type missing or invalid in client response")
454457
}
455458

456459
if expiresIn, ok := tokenData["expires_in"].(float64); ok {

pkg/tools/mcp/remote.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,20 @@ func (c *remoteMCPClient) SetManagedOAuth(managed bool) {
106106
func (c *remoteMCPClient) createHTTPClient() *http.Client {
107107
transport := c.headerTransport()
108108

109+
// Read managed and oauthConfig with lock to prevent data race
110+
c.mu.RLock()
111+
managed := c.managed
112+
oauthConfig := c.oauthConfig
113+
c.mu.RUnlock()
114+
109115
// Then wrap with OAuth support
110116
transport = &oauthTransport{
111117
base: transport,
112118
client: c,
113119
tokenStore: c.tokenStore,
114120
baseURL: c.url,
115-
managed: c.managed,
116-
oauthConfig: c.oauthConfig,
121+
managed: managed,
122+
oauthConfig: oauthConfig,
117123
}
118124

119125
return &http.Client{

pkg/upstream/headers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ func Handler(next http.Handler) http.Handler {
3939
// placeholders that are resolved at request time from upstream headers
4040
// stored in the request context.
4141
func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper {
42+
if base == nil {
43+
base = http.DefaultTransport
44+
}
4245
return &headerTransport{base: base, headers: headers}
4346
}
4447

0 commit comments

Comments
 (0)