diff --git a/.gitignore b/.gitignore index cf2d293..6891a41 100644 --- a/.gitignore +++ b/.gitignore @@ -365,3 +365,4 @@ $RECYCLE.BIN/ # Local configs *.secret.yaml +.aider* diff --git a/cors/cors.go b/cors/cors.go new file mode 100644 index 0000000..41eec20 --- /dev/null +++ b/cors/cors.go @@ -0,0 +1,21 @@ +package cors + +import "net/http" + +// ApplyCORSHeaders adds the standard CORS headers to a response +// For handlers that are not wrapped in middleware +func ApplyCORSHeaders(w http.ResponseWriter, allowedMethods string) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", allowedMethods+", OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") +} + +// HandlePreflight checks if the request is a preflight OPTIONS request and handles it +// Returns true if the request was handled (caller should return immediately) +func HandlePreflight(w http.ResponseWriter, r *http.Request) bool { + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return true + } + return false +} diff --git a/plugins/oauth/README.md b/plugins/oauth/README.md index 25f355b..cfa2dd2 100644 --- a/plugins/oauth/README.md +++ b/plugins/oauth/README.md @@ -202,6 +202,8 @@ oauth: client_id: "xxx" client_secret: "xxx" user_info_url: "https://your-tenant.auth0.com/userinfo" + provider_auth_url: "https://your-tenant.auth0.com/authorize" + provider_token_url: "https://your-tenant.auth0.com/oauth/token" scopes: - "openid" - "profile" diff --git a/plugins/oauth/auth_metadata.go b/plugins/oauth/auth_metadata.go index ba54bfc..b0ca24f 100644 --- a/plugins/oauth/auth_metadata.go +++ b/plugins/oauth/auth_metadata.go @@ -15,19 +15,26 @@ type Metadata struct { } func NewMetadata(issuer url.URL, authorizationEndpoint, tokenEndpoint string, registrationEndpoint string) Metadata { + buildURL := func(endpoint string) string { + if u, err := url.Parse(endpoint); err == nil && u.IsAbs() { + return endpoint + } + return (&url.URL{Scheme: issuer.Scheme, Host: issuer.Host, Path: endpoint}).String() + } + metadata := Metadata{ Issuer: issuer.String(), - AuthorizationEndpoint: (&url.URL{Scheme: issuer.Scheme, Host: issuer.Host, Path: authorizationEndpoint}).String(), + AuthorizationEndpoint: buildURL(authorizationEndpoint), ResponseTypesSupported: []string{"code"}, CodeChallengeMethodsSupported: []string{"S256"}, - TokenEndpoint: (&url.URL{Scheme: issuer.Scheme, Host: issuer.Host, Path: tokenEndpoint}).String(), + TokenEndpoint: buildURL(tokenEndpoint), TokenEndpointAuthMethodsSupported: []string{"client_secret_post"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"}, } // Add registration endpoint if provided if registrationEndpoint != "" { - metadata.RegistrationEndpoint = (&url.URL{Scheme: issuer.Scheme, Host: issuer.Host, Path: registrationEndpoint}).String() + metadata.RegistrationEndpoint = buildURL(registrationEndpoint) } return metadata diff --git a/plugins/oauth/config.go b/plugins/oauth/config.go index 08dda70..f0fd43f 100644 --- a/plugins/oauth/config.go +++ b/plugins/oauth/config.go @@ -147,16 +147,16 @@ func (c *Config) WithDefaults() { c.TokenHeader = "Authorization" } if c.AuthURL == "" { - c.AuthURL = "/oauth/authorize" + c.AuthURL = "/oauth/authorize/" } if c.CallbackURL == "" { - c.CallbackURL = "/oauth/callback" + c.CallbackURL = "/oauth/callback/" } if c.TokenURL == "" { - c.TokenURL = "/oauth/token" + c.TokenURL = "/oauth/token/" } if c.RegisterURL == "" { - c.RegisterURL = "/oauth/register" + c.RegisterURL = "/oauth/register/" } if c.ClientRegistration.ClientSecretExpirySeconds == 0 { c.ClientRegistration.ClientSecretExpirySeconds = 30 * 24 * 60 * 60 // 30 days diff --git a/plugins/oauth/http_helpers.go b/plugins/oauth/http_helpers.go index 79bcee8..7a175ec 100644 --- a/plugins/oauth/http_helpers.go +++ b/plugins/oauth/http_helpers.go @@ -1,42 +1,18 @@ package oauth import ( + "github.com/centralmind/gateway/cors" "net/http" ) // CORSMiddleware applies standard CORS headers to the response func CORSMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Add CORS headers - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - // Handle preflight OPTIONS request - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) + cors.ApplyCORSHeaders(w, "GET, POST") + if cors.HandlePreflight(w, r) { return } - // Call the original handler handler.ServeHTTP(w, r) }) } - -// ApplyCORSHeaders adds the standard CORS headers to a response -// For handlers that are not wrapped in middleware -func ApplyCORSHeaders(w http.ResponseWriter, allowedMethods string) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", allowedMethods+", OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") -} - -// HandlePreflight checks if the request is a preflight OPTIONS request and handles it -// Returns true if the request was handled (caller should return immediately) -func HandlePreflight(w http.ResponseWriter, r *http.Request) bool { - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return true - } - return false -} diff --git a/plugins/oauth/plugin.go b/plugins/oauth/plugin.go index 4479f35..a09da99 100644 --- a/plugins/oauth/plugin.go +++ b/plugins/oauth/plugin.go @@ -148,9 +148,17 @@ func (p *Plugin) RegisterRoutes(mux *http.ServeMux) { if err != nil { return } + + getPath := func(rawURL string) string { + if u, urlParseErr := url.Parse(rawURL); urlParseErr == nil && u.Path != "" { + return u.Path + } + return rawURL + } + // Register HTTP handlers with CORS middleware - mux.Handle(p.config.AuthURL, CORSMiddleware(http.HandlerFunc(p.HandleAuthorize))) - mux.Handle(p.config.CallbackURL, CORSMiddleware(http.HandlerFunc(p.HandleCallback))) + mux.Handle(getPath(p.config.AuthURL), CORSMiddleware(http.HandlerFunc(p.HandleAuthorize))) + mux.Handle(getPath(p.config.CallbackURL), CORSMiddleware(http.HandlerFunc(p.HandleCallback))) // Set up and register the token endpoint tokenPath := p.config.TokenURL // Use the configured token URL @@ -160,7 +168,7 @@ func (p *Plugin) RegisterRoutes(mux *http.ServeMux) { // Register the token handler with CORS middleware tokenHandler := http.HandlerFunc(p.HandleToken) - mux.Handle(tokenPath, CORSMiddleware(tokenHandler)) + mux.Handle(getPath(tokenPath), CORSMiddleware(tokenHandler)) // Register dynamic client registration endpoint if enabled if p.config.ClientRegistration.Enabled { @@ -169,7 +177,7 @@ func (p *Plugin) RegisterRoutes(mux *http.ServeMux) { // Register the handler with CORS middleware registrationHandler := http.HandlerFunc(p.HandleRegister) - mux.Handle(p.config.RegisterURL, CORSMiddleware(registrationHandler)) + mux.Handle(getPath(p.config.RegisterURL), CORSMiddleware(registrationHandler)) } // Register the well-known endpoint with CORS middleware diff --git a/server/sse.go b/server/sse.go index 9df9a66..a030590 100644 --- a/server/sse.go +++ b/server/sse.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/centralmind/gateway/cors" "github.com/centralmind/gateway/xcontext" "net/http" "net/http/httptest" @@ -87,10 +88,8 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) + cors.ApplyCORSHeaders(w, "GET") + if cors.HandlePreflight(w, r) { return }