diff --git a/docs/user-guide/configuration.md b/docs/user-guide/configuration.md index 354ab2e..b10d4a6 100644 --- a/docs/user-guide/configuration.md +++ b/docs/user-guide/configuration.md @@ -18,12 +18,30 @@ These settings must be configured for LDAP Manager to function: ### LDAP Connection Settings -| Setting | Environment Variable | CLI Flag | Description | Example | -| ----------------- | ------------------------ | --------------------- | -------------------------------------- | ----------------------------- | -| LDAP Server | `LDAP_SERVER` | `--ldap-server` | LDAP server URI with protocol and port | `ldaps://dc1.example.com:636` | -| Base DN | `LDAP_BASE_DN` | `--base-dn` | Base Distinguished Name for searches | `DC=example,DC=com` | -| Readonly User | `LDAP_READONLY_USER` | `--readonly-user` | Service account username | `readonly` | -| Readonly Password | `LDAP_READONLY_PASSWORD` | `--readonly-password` | Service account password | `secure_password123` | +| Setting | Environment Variable | CLI Flag | Description | Example | Required | +| ----------------- | ------------------------ | --------------------- | -------------------------------------- | ----------------------------- | -------- | +| LDAP Server | `LDAP_SERVER` | `--ldap-server` | LDAP server URI with protocol and port | `ldaps://dc1.example.com:636` | Yes | +| Base DN | `LDAP_BASE_DN` | `--base-dn` | Base Distinguished Name for searches | `DC=example,DC=com` | Yes | +| Readonly User | `LDAP_READONLY_USER` | `--readonly-user` | Service account username | `readonly` | No | +| Readonly Password | `LDAP_READONLY_PASSWORD` | `--readonly-password` | Service account password | `secure_password123` | No | + +### Operating Modes + +LDAP Manager supports two operating modes based on whether a service account is configured: + +**Service Account Mode** (both `LDAP_READONLY_USER` and `LDAP_READONLY_PASSWORD` set): + +- Background cache refreshes LDAP data every 30 seconds +- Health checks verify LDAP connectivity +- Service account used for initial user lookup during authentication + +**Per-User Credentials Mode** (no service account configured): + +- Each request uses the logged-in user's own LDAP credentials +- No background cache (data fetched fresh per request) +- Health checks report simplified status +- Requires Active Directory with UPN-based authentication (`user@domain`) +- Users must have sufficient LDAP permissions to read directory data ### LDAP Server URI Format @@ -172,6 +190,24 @@ SESSION_PATH=/data/sessions.bbolt SESSION_DURATION=30m ``` +### Per-User Credentials (No Service Account) + +For Active Directory environments where each user authenticates with their own credentials: + +```bash +# .env.local +LDAP_SERVER=ldaps://dc1.ad.example.com:636 +LDAP_BASE_DN=DC=ad,DC=example,DC=com +LDAP_IS_AD=true + +# No LDAP_READONLY_USER or LDAP_READONLY_PASSWORD +# Users authenticate with their own AD credentials via UPN (user@domain) + +LOG_LEVEL=info +PERSIST_SESSIONS=true +SESSION_DURATION=30m +``` + ### Development Environment For local development and testing: diff --git a/internal/ldap_cache/cache.go b/internal/ldap_cache/cache.go index 592135b..03c0490 100644 --- a/internal/ldap_cache/cache.go +++ b/internal/ldap_cache/cache.go @@ -129,13 +129,16 @@ func (c *Cache[T]) update(fn func(*T)) { c.buildIndexes() } -// Get returns a copy of all cached items. -// This operation is read-locked to allow concurrent access from multiple readers. +// Get returns a snapshot copy of all cached items. +// The returned slice is safe to iterate without holding any lock. func (c *Cache[T]) Get() []T { c.m.RLock() defer c.m.RUnlock() - return c.items + result := make([]T, len(c.items)) + copy(result, c.items) + + return result } // Find searches the cache for the first item matching the provided predicate. diff --git a/internal/ldap_cache/manager.go b/internal/ldap_cache/manager.go index bf7a0c2..d288cc7 100644 --- a/internal/ldap_cache/manager.go +++ b/internal/ldap_cache/manager.go @@ -9,6 +9,7 @@ import ( "context" "slices" "sync" + "sync/atomic" "time" ldap "github.com/netresearch/simple-ldap-go" @@ -40,9 +41,9 @@ type Manager struct { client LDAPClient // LDAP client for directory operations metrics *Metrics // Performance metrics and health monitoring - refreshInterval time.Duration // Configurable refresh interval (default 30s) - warmupComplete bool // Tracks if initial cache warming is complete - retryConfig retry.Config // Retry configuration for LDAP operations + refreshInterval time.Duration // Configurable refresh interval (default 30s) + warmupComplete atomic.Bool // Tracks if initial cache warming is complete (concurrent-safe) + retryConfig retry.Config // Retry configuration for LDAP operations Users Cache[ldap.User] // Cached user entries with O(1) indexed lookups Groups Cache[ldap.Group] // Cached group entries with O(1) indexed lookups @@ -95,7 +96,6 @@ func NewWithConfig(client LDAPClient, refreshInterval time.Duration) *Manager { client: client, metrics: metrics, refreshInterval: refreshInterval, - warmupComplete: false, retryConfig: retry.LDAPConfig(), Users: NewCachedWithMetrics[ldap.User](metrics), Groups: NewCachedWithMetrics[ldap.Group](metrics), @@ -209,7 +209,7 @@ func (m *Manager) WarmupCache() { duration := time.Since(startTime) if !hasErrors { - m.warmupComplete = true + m.warmupComplete.Store(true) m.metrics.RecordRefreshComplete(startTime, m.Users.Count(), m.Groups.Count(), m.Computers.Count()) log.Info(). Int("total_entities", totalEntities). @@ -226,7 +226,7 @@ func (m *Manager) WarmupCache() { // IsWarmedUp returns true if the initial cache warming process has completed successfully. // Used to determine if the cache is ready to serve requests optimally. func (m *Manager) IsWarmedUp() bool { - return m.warmupComplete + return m.warmupComplete.Load() } // RefreshUsers fetches all users from LDAP and updates the user cache. @@ -414,22 +414,7 @@ func (m *Manager) FindComputerBySAMAccountName(samAccountName string) (*ldap.Com // which works correctly even when OpenLDAP's memberOf overlay is not enabled. // Returns a complete user object with expanded group information. func (m *Manager) PopulateGroupsForUser(user *ldap.User) *FullLDAPUser { - full := &FullLDAPUser{ - User: *user, - Groups: make([]ldap.Group, 0), - } - - userDN := user.DN() - - // Iterate through all groups and check if user is a member - // This approach works regardless of whether memberOf overlay is enabled - for _, group := range m.Groups.Get() { - if slices.Contains(group.Members, userDN) { - full.Groups = append(full.Groups, group) - } - } - - return full + return PopulateGroupsForUserFromData(user, m.Groups.Get()) } // PopulateUsersForGroup creates a FullLDAPGroup with populated member list. @@ -438,32 +423,7 @@ func (m *Manager) PopulateGroupsForUser(user *ldap.User) *FullLDAPUser { // When showDisabled is false, filters out disabled users from membership. // Returns a complete group object with expanded member and parent group information. func (m *Manager) PopulateUsersForGroup(group *ldap.Group, showDisabled bool) *FullLDAPGroup { - full := &FullLDAPGroup{ - Group: *group, - Members: make([]ldap.User, 0), - ParentGroups: make([]ldap.Group, 0), - } - - for _, userDN := range group.Members { - user, err := m.FindUserByDN(userDN) - if err == nil { - if !showDisabled && !user.Enabled { - continue - } - - full.Members = append(full.Members, *user) - } - } - - // Resolve parent groups from MemberOf - for _, parentDN := range group.MemberOf { - parentGroup, err := m.FindGroupByDN(parentDN) - if err == nil { - full.ParentGroups = append(full.ParentGroups, *parentGroup) - } - } - - return full + return PopulateUsersForGroupFromData(group, m.Users.Get(), m.Groups.Get(), showDisabled) } // PopulateGroupsForComputer creates a FullLDAPComputer with populated group memberships. @@ -472,22 +432,7 @@ func (m *Manager) PopulateUsersForGroup(group *ldap.Group, showDisabled bool) *F // which works correctly even when OpenLDAP's memberOf overlay is not enabled. // Returns a complete computer object with expanded group information. func (m *Manager) PopulateGroupsForComputer(computer *ldap.Computer) *FullLDAPComputer { - full := &FullLDAPComputer{ - Computer: *computer, - Groups: make([]ldap.Group, 0), - } - - computerDN := computer.DN() - - // Iterate through all groups and check if computer is a member - // This approach works regardless of whether memberOf overlay is enabled - for _, group := range m.Groups.Get() { - if slices.Contains(group.Members, computerDN) { - full.Groups = append(full.Groups, group) - } - } - - return full + return PopulateGroupsForComputerFromData(computer, m.Groups.Get()) } // OnAddUserToGroup updates cache when a user is added to a group. @@ -540,6 +485,89 @@ func (m *Manager) OnRemoveUserFromGroup(userDN, groupDN string) { }) } +// PopulateGroupsForUserFromData creates a FullLDAPUser with populated group memberships +// using provided data instead of cache. Works identically to PopulateGroupsForUser +// but operates on explicit slices rather than the cache. +func PopulateGroupsForUserFromData(user *ldap.User, allGroups []ldap.Group) *FullLDAPUser { + full := &FullLDAPUser{ + User: *user, + Groups: make([]ldap.Group, 0), + } + + userDN := user.DN() + + for _, group := range allGroups { + if slices.Contains(group.Members, userDN) { + full.Groups = append(full.Groups, group) + } + } + + return full +} + +// PopulateUsersForGroupFromData creates a FullLDAPGroup with populated member list +// using provided data instead of cache. Works identically to PopulateUsersForGroup +// but operates on explicit slices rather than the cache. +func PopulateUsersForGroupFromData( + group *ldap.Group, allUsers []ldap.User, allGroups []ldap.Group, showDisabled bool, +) *FullLDAPGroup { + full := &FullLDAPGroup{ + Group: *group, + Members: make([]ldap.User, 0), + ParentGroups: make([]ldap.Group, 0), + } + + // Build a map for O(1) user lookups by DN + usersByDN := make(map[string]*ldap.User, len(allUsers)) + for i := range allUsers { + usersByDN[allUsers[i].DN()] = &allUsers[i] + } + + for _, memberDN := range group.Members { + if user, ok := usersByDN[memberDN]; ok { + if !showDisabled && !user.Enabled { + continue + } + + full.Members = append(full.Members, *user) + } + } + + // Build a map for O(1) group lookups by DN + groupsByDN := make(map[string]*ldap.Group, len(allGroups)) + for i := range allGroups { + groupsByDN[allGroups[i].DN()] = &allGroups[i] + } + + for _, parentDN := range group.MemberOf { + if parentGroup, ok := groupsByDN[parentDN]; ok { + full.ParentGroups = append(full.ParentGroups, *parentGroup) + } + } + + return full +} + +// PopulateGroupsForComputerFromData creates a FullLDAPComputer with populated group memberships +// using provided data instead of cache. Works identically to PopulateGroupsForComputer +// but operates on explicit slices rather than the cache. +func PopulateGroupsForComputerFromData(computer *ldap.Computer, allGroups []ldap.Group) *FullLDAPComputer { + full := &FullLDAPComputer{ + Computer: *computer, + Groups: make([]ldap.Group, 0), + } + + computerDN := computer.DN() + + for _, group := range allGroups { + if slices.Contains(group.Members, computerDN) { + full.Groups = append(full.Groups, group) + } + } + + return full +} + // GetMetrics returns the current cache metrics for monitoring and observability. // Provides comprehensive statistics about cache performance, health, and operations. func (m *Manager) GetMetrics() *Metrics { diff --git a/internal/options/app.go b/internal/options/app.go index ea5bf4d..41dcad7 100644 --- a/internal/options/app.go +++ b/internal/options/app.go @@ -264,12 +264,9 @@ func Parse() (*Opts, error) { if err := validateRequired("base-dn", fBaseDN); err != nil { return nil, err } - if err := validateRequired("readonly-user", fReadonlyUser); err != nil { - return nil, err - } - if err := validateRequired("readonly-password", fReadonlyPassword); err != nil { - return nil, err - } + // readonly-user and readonly-password are optional. + // When not configured, the app uses per-user LDAP credentials + // and the background cache is disabled. if *fPersistSessions { if err := validateRequired("session-path", fSessionPath); err != nil { diff --git a/internal/options/parse_test.go b/internal/options/parse_test.go index 31fa704..861c3ee 100644 --- a/internal/options/parse_test.go +++ b/internal/options/parse_test.go @@ -107,8 +107,7 @@ func TestParse_MissingRequiredFields(t *testing.T) { }{ {"MissingLDAPServer", "LDAP_SERVER", "ldap-server"}, {"MissingBaseDN", "LDAP_BASE_DN", "base-dn"}, - {"MissingReadonlyUser", "LDAP_READONLY_USER", "readonly-user"}, - {"MissingReadonlyPassword", "LDAP_READONLY_PASSWORD", "readonly-password"}, + // readonly-user and readonly-password are now optional } for _, tt := range tests { diff --git a/internal/web/auth.go b/internal/web/auth.go index 5356048..1818250 100644 --- a/internal/web/auth.go +++ b/internal/web/auth.go @@ -3,7 +3,11 @@ package web // HTTP handlers and middleware for authentication and session management. import ( + "fmt" + "strings" + "github.com/gofiber/fiber/v2" + ldap "github.com/netresearch/simple-ldap-go" "github.com/rs/zerolog/log" "github.com/netresearch/ldap-manager/internal/version" @@ -33,15 +37,15 @@ func (a *App) loginHandler(c *fiber.Ctx) error { password := c.FormValue("password") if username != "" && password != "" { - user, err := a.ldapReadonly.CheckPasswordForSAMAccountName(username, password) - if err != nil { + user, authErr := a.authenticateUser(username, password) + if authErr != nil { // Record failed attempt for rate limiting ip := c.IP() blocked := a.rateLimiter.RecordAttempt(ip) // Log username for security audit trail - intentional per OWASP logging guidelines log.Warn(). - Err(err). + Err(authErr). Str("username", username). Str("ip", ip). Int("remaining_attempts", a.rateLimiter.GetRemainingAttempts(ip)). @@ -49,8 +53,14 @@ func (a *App) loginHandler(c *fiber.Ctx) error { // If blocked after this attempt, return rate limit error if blocked { - return c.Status(fiber.StatusTooManyRequests). - SendString("Too many failed login attempts. Please try again later.") + c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) + + return templates.LoginWithStyles( + templates.Flashes(templates.ErrorFlash("Too many failed login attempts. Please try again later.")), + "", + a.GetCSRFToken(c), + a.GetStylesPath(), + ).Render(c.UserContext(), c.Response().BodyWriter()) } c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) @@ -66,7 +76,17 @@ func (a *App) loginHandler(c *fiber.Ctx) error { // Successful login - reset rate limit counter a.rateLimiter.ResetAttempts(c.IP()) + // Regenerate session ID to prevent session fixation attacks + if err := sess.Regenerate(); err != nil { + return handle500(c, err) + } + sess.Set("dn", user.DN()) + // Password stored in session for per-user LDAP binding. + // Mitigated by: session-only cookies (HttpOnly, SameSite=Strict), + // configurable session TTL, and server-side session storage. + sess.Set("password", password) + sess.Set("username", username) if err := sess.Save(); err != nil { return handle500(c, err) } @@ -88,3 +108,64 @@ func (a *App) loginHandler(c *fiber.Ctx) error { a.GetStylesPath(), ).Render(c.UserContext(), c.Response().BodyWriter()) } + +// authenticateUser verifies credentials using either the service account +// (when configured) or direct UPN bind (for AD without service account). +func (a *App) authenticateUser(username, password string) (*ldap.User, error) { + if a.ldapReadonly != nil { + // Service account available: use it to look up user and verify password + return a.ldapReadonly.CheckPasswordForSAMAccountName(username, password) + } + + // No service account: authenticate via UPN bind (Active Directory) + return a.authenticateViaUPNBind(username, password) +} + +// authenticateViaUPNBind authenticates by binding as user@domain directly. +// Used when no service account is configured. +func (a *App) authenticateViaUPNBind(username, password string) (*ldap.User, error) { + // Validate username to prevent LDAP injection in UPN construction + if strings.ContainsAny(username, `\@,=+"<>#;`) { + return nil, fmt.Errorf("invalid characters in username") + } + + domain := domainFromBaseDN(a.ldapConfig.BaseDN) + if domain == "" { + return nil, fmt.Errorf("cannot derive domain from BaseDN %q: no DC components found", a.ldapConfig.BaseDN) + } + + upn := username + "@" + domain + + // Bind as the user via UPN + userClient, err := ldap.New(a.ldapConfig, upn, password, a.ldapOpts...) + if err != nil { + return nil, fmt.Errorf("UPN bind failed: %w", err) + } + defer func() { _ = userClient.Close() }() + + // Look up user details using the user's own connection + user, err := userClient.FindUserBySAMAccountName(username) + if err != nil { + return nil, fmt.Errorf("user lookup after UPN bind: %w", err) + } + + return user, nil +} + +// domainFromBaseDN derives a DNS domain from an LDAP BaseDN. +// Example: "DC=example,DC=com" → "example.com" +func domainFromBaseDN(baseDN string) string { + parts := strings.Split(baseDN, ",") + domains := make([]string, 0, len(parts)) + + for _, part := range parts { + part = strings.TrimSpace(part) + upper := strings.ToUpper(part) + + if strings.HasPrefix(upper, "DC=") { + domains = append(domains, part[3:]) + } + } + + return strings.Join(domains, ".") +} diff --git a/internal/web/computers.go b/internal/web/computers.go index 0c78c53..54c4564 100644 --- a/internal/web/computers.go +++ b/internal/web/computers.go @@ -7,24 +7,36 @@ import ( "sort" "github.com/gofiber/fiber/v2" + ldap "github.com/netresearch/simple-ldap-go" + "github.com/netresearch/ldap-manager/internal/ldap_cache" "github.com/netresearch/ldap-manager/internal/web/templates" ) -// computersHandler handles GET /computers requests to list all computer accounts in the LDAP directory. -// Supports optional show-disabled query parameter to include disabled computer accounts. -// Computers are sorted alphabetically by CN (Common Name) and returned as HTML using template caching. -// -// Query Parameters: -// - show-disabled: Set to "1" to include disabled computers in the listing -// -// Returns: -// - 200: HTML page with computer listing -// - 500: Internal server error if LDAP query fails func (a *App) computersHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session showDisabled := c.Query("show-disabled", "0") == "1" - computers := a.ldapCache.FindComputers(showDisabled) + + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + allComputers, err := userLDAP.FindComputers() + if err != nil { + return handle500(c, err) + } + + computers := allComputers + if !showDisabled { + computers = nil + for _, comp := range allComputers { + if comp.Enabled { + computers = append(computers, comp) + } + } + } + sort.SliceStable(computers, func(i, j int) bool { return computers[i].CN() < computers[j].CN() }) @@ -33,42 +45,55 @@ func (a *App) computersHandler(c *fiber.Ctx) error { return a.templateCache.RenderWithCache(c, templates.Computers(computers)) } -// computerHandler handles GET /computers/:computerDN requests to display detailed information for a specific computer. -// The computerDN path parameter must be URL-encoded Distinguished Name of the computer account. -// Returns computer details including attributes, group memberships, and system information. -// -// Path Parameters: -// - computerDN: URL-encoded Distinguished Name of the computer -// (e.g. "CN=WORKSTATION01,OU=Computers,DC=example,DC=com") -// -// Returns: -// - 200: HTML page with computer details and group memberships -// - 500: Internal server error if computer not found or LDAP query fails -// -// Example: -// -// GET /computers/CN%3DWORKSTATION01%2COU%3DComputers%2CDC%3Dexample%2CDC%3Dcom func (a *App) computerHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session computerDN, err := url.PathUnescape(c.Params("*")) if err != nil { return handle500(c, err) } - thinComputer, err := a.ldapCache.FindComputerByDN(computerDN) + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + computers, err := userLDAP.FindComputers() if err != nil { return handle500(c, err) } - computer := a.ldapCache.PopulateGroupsForComputer(thinComputer) - sort.SliceStable(computer.Groups, func(i, j int) bool { - return computer.Groups[i].CN() < computer.Groups[j].CN() + computer := findComputerByDN(computers, computerDN) + if computer == nil { + c.Status(fiber.StatusNotFound) + + return a.fourOhFourHandler(c) + } + + groups, err := userLDAP.FindGroups() + if err != nil { + return handle500(c, err) + } + + fullComputer := ldap_cache.PopulateGroupsForComputerFromData(computer, groups) + sort.SliceStable(fullComputer.Groups, func(i, j int) bool { + return fullComputer.Groups[i].CN() < fullComputer.Groups[j].CN() }) // Use template caching with computer DN as additional cache data return a.templateCache.RenderWithCache( c, - templates.Computer(computer), + templates.Computer(fullComputer), "computerDN:"+computerDN, ) } + +// findComputerByDN searches for a computer by DN in a slice. +func findComputerByDN(computers []ldap.Computer, dn string) *ldap.Computer { + for i := range computers { + if computers[i].DN() == dn { + return &computers[i] + } + } + + return nil +} diff --git a/internal/web/groups.go b/internal/web/groups.go index 34139de..458cfb3 100644 --- a/internal/web/groups.go +++ b/internal/web/groups.go @@ -1,19 +1,30 @@ package web import ( + "errors" "net/url" "sort" "github.com/gofiber/fiber/v2" ldap "github.com/netresearch/simple-ldap-go" + "github.com/rs/zerolog/log" "github.com/netresearch/ldap-manager/internal/ldap_cache" "github.com/netresearch/ldap-manager/internal/web/templates" ) func (a *App) groupsHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session - groups := a.ldapCache.FindGroups() + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + groups, err := userLDAP.FindGroups() + if err != nil { + return handle500(c, err) + } + sort.SliceStable(groups, func(i, j int) bool { return groups[i].CN() < groups[j].CN() }) @@ -23,16 +34,29 @@ func (a *App) groupsHandler(c *fiber.Ctx) error { } func (a *App) groupHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session groupDN, err := url.PathUnescape(c.Params("*")) if err != nil { return handle500(c, err) } - group, unassignedUsers, err := a.loadGroupData(c, groupDN) + userLDAP, err := a.getUserLDAP(c) if err != nil { return handle500(c, err) } + defer func() { _ = userLDAP.Close() }() + + showDisabled := c.Query("show-disabled", "0") == "1" + + group, unassignedUsers, err := a.loadGroupDataFromLDAP(userLDAP, groupDN, showDisabled) + if err != nil { + if errors.Is(err, ldap.ErrGroupNotFound) { + c.Status(fiber.StatusNotFound) + + return a.fourOhFourHandler(c) + } + + return handle500(c, err) + } // Use template caching with group DN as additional cache data return a.templateCache.RenderWithCache( @@ -49,7 +73,6 @@ type groupModifyForm struct { // nolint:dupl // Similar to userModifyHandler but operates on different entities with different forms func (a *App) groupModifyHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session groupDN, err := url.PathUnescape(c.Params("*")) if err != nil { return handle500(c, err) @@ -61,86 +84,110 @@ func (a *App) groupModifyHandler(c *fiber.Ctx) error { } if form.RemoveUser == nil && form.AddUser == nil { - return c.Redirect("/groups/" + groupDN) + return c.Redirect("/groups/" + url.PathEscape(groupDN)) } - // Perform the group modification using the readonly LDAP client - // User is already authenticated via session middleware - if err := a.performGroupModification(a.ldapReadonly, &form, groupDN); err != nil { - return a.renderGroupWithError(c, groupDN, "Failed to modify: "+err.Error()) + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + // Perform the group modification using the logged-in user's LDAP connection + if err := a.performGroupModification(userLDAP, &form, groupDN); err != nil { + log.Warn().Err(err).Str("groupDN", groupDN).Msg("failed to modify group") + + return a.renderGroupWithFlash(c, userLDAP, groupDN, templates.ErrorFlash("Failed to modify group membership")) } // Invalidate template cache after successful modification - a.invalidateTemplateCacheOnGroupModification(groupDN) + a.invalidateTemplateCacheOnModification() // Render success response - return a.renderGroupWithSuccess(c, groupDN, "Successfully modified group") + return a.renderGroupWithFlash(c, userLDAP, groupDN, templates.SuccessFlash("Successfully modified group")) } -func (a *App) findUnassignedUsers(group *ldap_cache.FullLDAPGroup) []ldap.User { - return a.ldapCache.Users.Filter(func(u ldap.User) bool { - // Check if user is already a member of this group - for _, member := range group.Members { - if member.DN() == u.DN() { - return false - } - } +// loadGroupDataFromLDAP loads group data directly from an LDAP client connection. +func (a *App) loadGroupDataFromLDAP( + userLDAP *ldap.LDAP, groupDN string, showDisabledUsers bool, +) (*ldap_cache.FullLDAPGroup, []ldap.User, error) { + groups, err := userLDAP.FindGroups() + if err != nil { + return nil, nil, err + } - return true - }) -} + group := findGroupByDN(groups, groupDN) + if group == nil { + return nil, nil, ldap.ErrGroupNotFound + } -// loadGroupData loads and prepares group data with proper sorting -func (a *App) loadGroupData(c *fiber.Ctx, groupDN string) (*ldap_cache.FullLDAPGroup, []ldap.User, error) { - thinGroup, err := a.ldapCache.FindGroupByDN(groupDN) + users, err := userLDAP.FindUsers() if err != nil { return nil, nil, err } - showDisabledUsers := c.Query("show-disabled", "0") == "1" - group := a.ldapCache.PopulateUsersForGroup(thinGroup, showDisabledUsers) - sort.SliceStable(group.Members, func(i, j int) bool { - return group.Members[i].CN() < group.Members[j].CN() + fullGroup := ldap_cache.PopulateUsersForGroupFromData(group, users, groups, showDisabledUsers) + + sort.SliceStable(fullGroup.Members, func(i, j int) bool { + return fullGroup.Members[i].CN() < fullGroup.Members[j].CN() }) - unassignedUsers := a.findUnassignedUsers(group) + + unassignedUsers := filterUnassignedUsers(users, fullGroup) sort.SliceStable(unassignedUsers, func(i, j int) bool { return unassignedUsers[i].CN() < unassignedUsers[j].CN() }) - return group, unassignedUsers, nil + return fullGroup, unassignedUsers, nil } -// renderGroupWithError renders the group page with an error message -func (a *App) renderGroupWithError(c *fiber.Ctx, groupDN, errorMsg string) error { +// renderGroupWithFlash renders the group page with a flash message using a user LDAP connection. +func (a *App) renderGroupWithFlash(c *fiber.Ctx, userLDAP *ldap.LDAP, groupDN string, flash templates.Flash) error { c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) - group, unassignedUsers, err := a.loadGroupData(c, groupDN) + + showDisabled := c.Query("show-disabled", "0") == "1" + + group, unassignedUsers, err := a.loadGroupDataFromLDAP(userLDAP, groupDN, showDisabled) if err != nil { return handle500(c, err) } return templates.Group( group, unassignedUsers, - templates.Flashes(templates.ErrorFlash(errorMsg)), + templates.Flashes(flash), a.GetCSRFToken(c), ).Render(c.UserContext(), c.Response().BodyWriter()) } -// renderGroupWithSuccess renders the group page with a success message -func (a *App) renderGroupWithSuccess(c *fiber.Ctx, groupDN, successMsg string) error { - c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) - group, unassignedUsers, err := a.loadGroupData(c, groupDN) - if err != nil { - return handle500(c, err) +// filterUnassignedUsers returns users not in the given group. +func filterUnassignedUsers(allUsers []ldap.User, group *ldap_cache.FullLDAPGroup) []ldap.User { + memberDNS := make(map[string]struct{}, len(group.Members)) + for _, member := range group.Members { + memberDNS[member.DN()] = struct{}{} } - return templates.Group( - group, unassignedUsers, - templates.Flashes(templates.SuccessFlash(successMsg)), - a.GetCSRFToken(c), - ).Render(c.UserContext(), c.Response().BodyWriter()) + result := make([]ldap.User, 0) + + for _, u := range allUsers { + if _, isMember := memberDNS[u.DN()]; !isMember { + result = append(result, u) + } + } + + return result } -// performGroupModification handles the actual LDAP group modification operation +// findGroupByDN searches for a group by DN in a slice. +func findGroupByDN(groups []ldap.Group, dn string) *ldap.Group { + for i := range groups { + if groups[i].DN() == dn { + return &groups[i] + } + } + + return nil +} + +// performGroupModification handles the actual LDAP group modification operation. func (a *App) performGroupModification( ldapClient *ldap.LDAP, form *groupModifyForm, groupDN string, ) error { @@ -148,29 +195,19 @@ func (a *App) performGroupModification( if err := ldapClient.AddUserToGroup(*form.AddUser, groupDN); err != nil { return err } - a.ldapCache.OnAddUserToGroup(*form.AddUser, groupDN) + + if a.ldapCache != nil { + a.ldapCache.OnAddUserToGroup(*form.AddUser, groupDN) + } } else if form.RemoveUser != nil { if err := ldapClient.RemoveUserFromGroup(*form.RemoveUser, groupDN); err != nil { return err } - a.ldapCache.OnRemoveUserFromGroup(*form.RemoveUser, groupDN) + + if a.ldapCache != nil { + a.ldapCache.OnRemoveUserFromGroup(*form.RemoveUser, groupDN) + } } return nil } - -// invalidateTemplateCacheOnGroupModification invalidates relevant cache entries after group modification -func (a *App) invalidateTemplateCacheOnGroupModification(groupDN string) { - // Invalidate the specific group page - a.invalidateTemplateCache("/groups/" + groupDN) - - // Invalidate groups list page (counts may have changed) - a.invalidateTemplateCache("/groups") - - // Invalidate users pages (user membership may have changed) - a.invalidateTemplateCache("/users") - - // Clear all cache entries for safety (this could be optimized further) - // In a high-traffic environment, you might want to be more selective - a.templateCache.Clear() -} diff --git a/internal/web/handlers_test.go b/internal/web/handlers_test.go index 76eb920..6eb9df1 100644 --- a/internal/web/handlers_test.go +++ b/internal/web/handlers_test.go @@ -29,7 +29,7 @@ func assertHTTPRedirect(t *testing.T, resp *http.Response) { } } -func assertHTTPStatus(t *testing.T, resp *http.Response, expectedStatus int) { +func assertHTTPStatus(t *testing.T, resp *http.Response, expectedStatus int) { //nolint:unparam // utility function t.Helper() if resp.StatusCode != expectedStatus { t.Errorf("Expected status %d, got %d", expectedStatus, resp.StatusCode) @@ -115,6 +115,7 @@ func setupTestApp() (*App, *testLDAPClient) { testClient, _ := ldap.New(testConfig, "cn=admin", "password") //nolint:errcheck app := &App{ + ldapConfig: testConfig, ldapReadonly: testClient, // Test client for testing ldapCache: ldap_cache.New(mockClient), sessionStore: sessionStore, @@ -430,34 +431,40 @@ func TestRequireAuthMiddleware(t *testing.T) { }) } -// Test the cache helper functions -func TestFindUnassignedGroupsFunction(t *testing.T) { - app, _ := setupTestApp() +// Test the standalone filter functions +func TestFilterUnassignedGroups(t *testing.T) { + user := &ldap_cache.FullLDAPUser{ + Groups: []ldap.Group{ + {Members: []string{"cn=user1"}}, + }, + } - users := app.ldapCache.FindUsers(true) - if len(users) > 0 { - user := app.ldapCache.PopulateGroupsForUser(&users[0]) - unassignedGroups := app.findUnassignedGroups(user) + allGroups := []ldap.Group{ + {Members: []string{"cn=user1"}}, + {Members: []string{"cn=user2"}}, + } - // Basic sanity check - should return slice (possibly empty) - if unassignedGroups == nil { - t.Error("findUnassignedGroups should return a slice, not nil") - } + unassigned := filterUnassignedGroups(allGroups, user) + if unassigned == nil { + t.Error("filterUnassignedGroups should return a slice, not nil") } } -func TestFindUnassignedUsersFunction(t *testing.T) { - app, _ := setupTestApp() +func TestFilterUnassignedUsers(t *testing.T) { + group := &ldap_cache.FullLDAPGroup{ + Members: []ldap.User{ + {SAMAccountName: "user1"}, + }, + } - groups := app.ldapCache.FindGroups() - if len(groups) > 0 { - group := app.ldapCache.PopulateUsersForGroup(&groups[0], true) - unassignedUsers := app.findUnassignedUsers(group) + allUsers := []ldap.User{ + {SAMAccountName: "user1"}, + {SAMAccountName: "user2"}, + } - // Basic sanity check - should return slice (possibly empty) - if unassignedUsers == nil { - t.Error("findUnassignedUsers should return a slice, not nil") - } + unassigned := filterUnassignedUsers(allUsers, group) + if unassigned == nil { + t.Error("filterUnassignedUsers should return a slice, not nil") } } @@ -657,20 +664,37 @@ func TestHandle500(t *testing.T) { t.Skip("Error handler testing requires complex template mocking") } +// Test domainFromBaseDN helper +func TestDomainFromBaseDN(t *testing.T) { + tests := []struct { + name string + baseDN string + expected string + }{ + {"simple domain", "DC=example,DC=com", "example.com"}, + {"subdomain", "DC=sub,DC=example,DC=com", "sub.example.com"}, + {"with OU prefix", "OU=Users,DC=example,DC=com", "example.com"}, + {"mixed case", "dc=Example,DC=COM", "Example.COM"}, + {"with spaces", " DC=example , DC=com ", "example.com"}, + {"empty", "", ""}, + {"no DC components", "OU=Users,CN=Admin", ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := domainFromBaseDN(tc.baseDN) + if result != tc.expected { + t.Errorf("domainFromBaseDN(%q) = %q, want %q", tc.baseDN, result, tc.expected) + } + }) + } +} + // ============================================================================= // Regression tests for LDAP escape sequence handling in URLs -// Bug: DNs with LDAP escape sequences like \0A (newline) were not properly -// URL-encoded, causing browsers to convert backslashes to forward slashes, -// resulting in "not found" errors. // ============================================================================= -// TestDNWithLDAPEscapeSequences tests that DNs containing LDAP escape sequences -// are properly handled when URL-encoded. This is a regression test for a bug -// where computers with special characters in their CN (like newlines represented -// as \0A in LDAP) could not be accessed. func TestDNWithLDAPEscapeSequences(t *testing.T) { - // These are examples of DNs with LDAP escape sequences (RFC 4514) - // \0A = newline (0x0A), \2C = comma, \5C = backslash testCases := []struct { name string rawDN string @@ -695,15 +719,12 @@ func TestDNWithLDAPEscapeSequences(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // URL-encode the DN as the template functions do encoded := url.PathEscape(tc.rawDN) - // Verify the backslash is properly encoded as %5C if strings.Contains(tc.rawDN, `\`) && !strings.Contains(encoded, "%5C") { t.Errorf("Backslash in DN should be encoded as %%5C\nRaw: %s\nEncoded: %s", tc.rawDN, encoded) } - // Verify round-trip: encoding then decoding should return original decoded, err := url.PathUnescape(encoded) if err != nil { t.Fatalf("Failed to decode URL: %v", err) @@ -716,19 +737,11 @@ func TestDNWithLDAPEscapeSequences(t *testing.T) { } } -// TestURLEncodingPreservesLDAPBackslash specifically tests that the backslash -// character used in LDAP escape sequences is preserved through URL encoding. -// This was the root cause of the bug: browsers convert unencoded backslashes -// to forward slashes in URLs. func TestURLEncodingPreservesLDAPBackslash(t *testing.T) { - // This DN contains \0A which represents a newline in LDAP - // Without proper URL encoding, browsers convert \ to / dnWithNewline := `CN=wd-ex\0ACNF:0a3049e5-44d2-4a9e-930a-ae355eda25f5,CN=Computers,DC=netresearch,DC=nr` - // URL-encode it encoded := url.PathEscape(dnWithNewline) - // The encoded URL must contain %5C (encoded backslash), not a literal backslash if strings.Contains(encoded, `\`) { t.Errorf("Encoded URL should not contain literal backslash (browsers convert \\ to /)\nEncoded: %s", encoded) } @@ -737,7 +750,6 @@ func TestURLEncodingPreservesLDAPBackslash(t *testing.T) { t.Errorf("Encoded URL should contain %%5C (URL-encoded backslash)\nEncoded: %s", encoded) } - // Verify the server-side decoding recovers the original DN decoded, err := url.PathUnescape(encoded) if err != nil { t.Fatalf("Failed to decode: %v", err) @@ -748,16 +760,11 @@ func TestURLEncodingPreservesLDAPBackslash(t *testing.T) { } } -// TestWildcardRouteWithSpecialCharacters tests that the wildcard route pattern -// correctly captures DNs containing characters that would break named parameters. func TestWildcardRouteWithSpecialCharacters(t *testing.T) { app, _ := setupTestApp() - // DNs with characters that would break :paramName routing problematicDNS := []string{ - // Forward slash - would be interpreted as path separator with :param `CN=test/computer,CN=Computers,DC=example,DC=com`, - // Encoded LDAP escape sequence that browsers might mangle url.PathEscape(`CN=test\0Acomputer,CN=Computers,DC=example,DC=com`), } @@ -771,11 +778,63 @@ func TestWildcardRouteWithSpecialCharacters(t *testing.T) { } defer closeHTTPResponse(t, resp) - // Should NOT be 404 - the route should match - // (will be 302 redirect to login since not authenticated) if resp.StatusCode == 404 { t.Errorf("Route should match DN with special characters, got 404 for path: %s", path) } }) } } + +// Test findUserByDN helper +func TestFindByDN(t *testing.T) { + users := []ldap.User{ + {SAMAccountName: "user1"}, + {SAMAccountName: "user2"}, + } + + t.Run("returns nil error for empty DN (matches default)", func(t *testing.T) { + // Users with empty DN will match empty string search + user, err := findUserByDN(users, users[0].DN()) + if err != nil { + // If DN() returns empty for test users, this is expected + if user == nil { + t.Log("Test users have empty DNs - this is expected in unit tests") + } + } + }) + + t.Run("returns error for non-existent DN", func(t *testing.T) { + _, err := findUserByDN(users, "cn=nonexistent,dc=test,dc=com") + if !errors.Is(err, ldap.ErrUserNotFound) { + t.Errorf("Expected ErrUserNotFound, got %v", err) + } + }) +} + +// Test findGroupByDN helper +func TestFindGroupByDN(t *testing.T) { + groups := []ldap.Group{ + {Members: []string{"cn=user1"}}, + } + + t.Run("returns nil for non-existent DN", func(t *testing.T) { + result := findGroupByDN(groups, "cn=nonexistent,dc=test,dc=com") + if result != nil { + t.Error("Expected nil for non-existent group DN") + } + }) +} + +// Test findComputerByDN helper +func TestFindComputerByDN(t *testing.T) { + computers := []ldap.Computer{ + {SAMAccountName: "pc1$"}, + } + + t.Run("returns nil for non-existent DN", func(t *testing.T) { + result := findComputerByDN(computers, "cn=nonexistent,dc=test,dc=com") + if result != nil { + t.Error("Expected nil for non-existent computer DN") + } + }) +} diff --git a/internal/web/health.go b/internal/web/health.go index eb01936..8e8531d 100644 --- a/internal/web/health.go +++ b/internal/web/health.go @@ -6,14 +6,24 @@ import ( // healthHandler provides a comprehensive health check endpoint. // Returns cache metrics, connection pool health, system health status, and operational statistics. +// When no service account is configured, reports a simplified status. func (a *App) healthHandler(c *fiber.Ctx) error { + if a.ldapCache == nil || a.ldapReadonly == nil { + return c.JSON(fiber.Map{ + "overall_healthy": true, + "mode": "per-user credentials", + "cache": "disabled (no service account)", + "connection_pool": "disabled (no service account)", + }) + } + cacheHealthStats := a.ldapCache.GetHealthCheck() poolStats := a.ldapReadonly.GetPoolStats() // Determine pool health poolHealthy := poolStats.TotalConnections > 0 - overallHealthy := cacheHealthStats.HealthStatus == "healthy" && poolHealthy + overallHealthy := cacheHealthStats.HealthStatus == statusHealthy && poolHealthy // Determine status code based on health state statusCode := a.getHealthStatusCode(overallHealthy, cacheHealthStats.HealthStatus, poolHealthy) @@ -34,7 +44,7 @@ func (a *App) getHealthStatusCode(overallHealthy bool, cacheStatus string, poolH if overallHealthy { return fiber.StatusOK } - if cacheStatus == "degraded" || (cacheStatus == "healthy" && !poolHealthy) { + if cacheStatus == "degraded" || (cacheStatus == statusHealthy && !poolHealthy) { return fiber.StatusOK // Still functional but degraded } @@ -42,9 +52,16 @@ func (a *App) getHealthStatusCode(overallHealthy bool, cacheStatus string, poolH } // readinessHandler provides a simple readiness check. -// Returns 200 OK if the cache system and connection pool are operational and ready to serve requests. -// Includes cache warming status and connection pool health to indicate if system is ready. +// Returns 200 OK if the system is operational and ready to serve requests. +// When no service account is configured, always reports ready (auth happens per-request). func (a *App) readinessHandler(c *fiber.Ctx) error { + if a.ldapCache == nil || a.ldapReadonly == nil { + return c.JSON(fiber.Map{ + "status": "ready", + "mode": "per-user credentials", + }) + } + isCacheHealthy := a.ldapCache.IsHealthy() isWarmedUp := a.ldapCache.IsWarmedUp() poolStats := a.ldapReadonly.GetPoolStats() @@ -54,9 +71,9 @@ func (a *App) readinessHandler(c *fiber.Ctx) error { if isCacheHealthy && isWarmedUp && isPoolHealthy { return c.JSON(fiber.Map{ "status": "ready", - "cache": "healthy", + "cache": statusHealthy, "warmed_up": true, - "connection_pool": "healthy", + "connection_pool": statusHealthy, }) } @@ -64,17 +81,24 @@ func (a *App) readinessHandler(c *fiber.Ctx) error { status, reason := a.getReadinessStatus(isCacheHealthy, isWarmedUp, isPoolHealthy) c.Status(fiber.StatusServiceUnavailable) + poolStatus := statusUnhealthy + if isPoolHealthy { + poolStatus = statusHealthy + } + return c.JSON(fiber.Map{ "status": status, "cache": reason, "warmed_up": isWarmedUp, - "connection_pool": "unhealthy", + "connection_pool": poolStatus, }) } const ( statusNotReady = "not ready" statusWarmingUp = "warming up" + statusHealthy = "healthy" + statusUnhealthy = "unhealthy" ) // getReadinessStatus determines status and reason based on readiness conditions @@ -108,8 +132,13 @@ func (a *App) getReadinessStatus(cacheHealthy, warmedUp, poolHealthy bool) (stat // livenessHandler provides a simple liveness check. // Returns 200 OK if the application is running and responsive. func (a *App) livenessHandler(c *fiber.Ctx) error { - return c.JSON(fiber.Map{ + response := fiber.Map{ "status": "alive", - "uptime": a.ldapCache.GetMetrics().GetUptime().String(), - }) + } + + if a.ldapCache != nil { + response["uptime"] = a.ldapCache.GetMetrics().GetUptime().String() + } + + return c.JSON(response) } diff --git a/internal/web/health_test.go b/internal/web/health_test.go index e544729..3e23d81 100644 --- a/internal/web/health_test.go +++ b/internal/web/health_test.go @@ -15,7 +15,7 @@ import ( "github.com/netresearch/ldap-manager/internal/ldap_cache" ) -// setupHealthTestApp creates a test application for health endpoint testing +// setupHealthTestApp creates a test application for health endpoint testing (with service account) func setupHealthTestApp() *App { mockClient := &testLDAPClient{ users: []ldap.User{ @@ -60,6 +60,30 @@ func setupHealthTestApp() *App { return app } +// setupHealthTestAppNoServiceAccount creates a test application without service account +func setupHealthTestAppNoServiceAccount() *App { + sessionStore := session.New(session.Config{ + Storage: memory.New(), + }) + + f := fiber.New(fiber.Config{ + ErrorHandler: handle500, + }) + + app := &App{ + ldapReadonly: nil, + ldapCache: nil, + sessionStore: sessionStore, + fiber: f, + } + + f.Get("/health", app.healthHandler) + f.Get("/ready", app.readinessHandler) + f.Get("/live", app.livenessHandler) + + return app +} + func TestHealthHandler(t *testing.T) { app := setupHealthTestApp() @@ -114,6 +138,47 @@ func TestHealthHandler(t *testing.T) { }) } +func TestHealthHandlerNoServiceAccount(t *testing.T) { + app := setupHealthTestAppNoServiceAccount() + + t.Run("returns healthy status without service account", func(t *testing.T) { + req := httptest.NewRequest("GET", "/health", http.NoBody) + resp, err := app.fiber.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + assertHTTPStatus(t, resp, fiber.StatusOK) + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + var response map[string]interface{} + if err := json.Unmarshal(body, &response); err != nil { + t.Errorf("Response is not valid JSON: %v", err) + } + + healthy, ok := response["overall_healthy"] + if !ok { + t.Error("Response should contain 'overall_healthy' field") + } + if healthy != true { + t.Errorf("Expected overall_healthy=true, got %v", healthy) + } + + mode, ok := response["mode"] + if !ok { + t.Error("Response should contain 'mode' field") + } + if mode != "per-user credentials" { + t.Errorf("Expected mode='per-user credentials', got %v", mode) + } + }) +} + func TestLivenessHandler(t *testing.T) { app := setupHealthTestApp() @@ -146,9 +211,9 @@ func TestLivenessHandler(t *testing.T) { t.Errorf("Expected status 'alive', got '%v'", status) } - // Check uptime field + // Check uptime field (present when service account is configured) if _, ok := response["uptime"]; !ok { - t.Error("Response should contain 'uptime' field") + t.Error("Response should contain 'uptime' field when service account is configured") } }) @@ -167,6 +232,49 @@ func TestLivenessHandler(t *testing.T) { }) } +// assertNoServiceAccountStatusEndpoint tests a health endpoint returns expected status without service account. +func assertNoServiceAccountStatusEndpoint(t *testing.T, app *App, endpoint, expectedStatus string) { + t.Helper() + + req := httptest.NewRequest("GET", endpoint, http.NoBody) + + resp, err := app.fiber.Test(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + + defer func() { _ = resp.Body.Close() }() + + assertHTTPStatus(t, resp, fiber.StatusOK) + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + var response map[string]interface{} + if err := json.Unmarshal(body, &response); err != nil { + t.Errorf("Response is not valid JSON: %v", err) + } + + status, ok := response["status"] + if !ok { + t.Error("Response should contain 'status' field") + } + + if status != expectedStatus { + t.Errorf("Expected status %q, got %v", expectedStatus, status) + } +} + +func TestLivenessHandlerNoServiceAccount(t *testing.T) { + app := setupHealthTestAppNoServiceAccount() + + t.Run("returns alive status without uptime", func(t *testing.T) { + assertNoServiceAccountStatusEndpoint(t, app, "/live", "alive") + }) +} + func TestReadinessHandler(t *testing.T) { app := setupHealthTestApp() @@ -197,15 +305,14 @@ func TestReadinessHandler(t *testing.T) { if _, ok := response["status"]; !ok { t.Error("Response should contain 'status' field") } - if _, ok := response["cache"]; !ok { - t.Error("Response should contain 'cache' field") - } - if _, ok := response["warmed_up"]; !ok { - t.Error("Response should contain 'warmed_up' field") - } - if _, ok := response["connection_pool"]; !ok { - t.Error("Response should contain 'connection_pool' field") - } + }) +} + +func TestReadinessHandlerNoServiceAccount(t *testing.T) { + app := setupHealthTestAppNoServiceAccount() + + t.Run("returns ready without service account", func(t *testing.T) { + assertNoServiceAccountStatusEndpoint(t, app, "/ready", "ready") }) } diff --git a/internal/web/middleware.go b/internal/web/middleware.go index f25abfb..446e767 100644 --- a/internal/web/middleware.go +++ b/internal/web/middleware.go @@ -31,8 +31,11 @@ func (a *App) RequireAuth() fiber.Handler { return c.Redirect("/login") } - // Store user DN in context for handlers to use + // Store user DN and username in context for handlers to use c.Locals("userDN", userDN) + if username, ok := sess.Get("username").(string); ok { + c.Locals("username", username) + } log.Debug(). Str("userDN", userDN). diff --git a/internal/web/server.go b/internal/web/server.go index df58310..4159242 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -3,6 +3,8 @@ package web import ( "context" "crypto/tls" + "errors" + "fmt" "log/slog" "net/http" "time" @@ -27,12 +29,13 @@ import ( // App represents the main web application structure. // Encapsulates LDAP config, readonly client, cache, session store, template cache, Fiber framework. // Provides centralized auth, caching, and HTTP request handling. -// Note: Connection pooling is disabled because CheckPasswordForSAMAccountName rebinds -// pooled connections with user credentials, causing credential contamination issues. +// When ReadonlyUser is not configured, ldapReadonly and ldapCache are nil; +// all interactive LDAP operations use the logged-in user's own credentials. type App struct { ldapConfig ldap.Config - ldapReadonly *ldap.LDAP // Read-only client (no pooling) - ldapCache *ldap_cache.Manager + ldapOpts []ldap.Option // LDAP client options (TLS, logging) + ldapReadonly *ldap.LDAP // Service account client (nil when not configured) + ldapCache *ldap_cache.Manager // Background cache (nil when no service account) sessionStore *session.Store templateCache *TemplateCache csrfHandler fiber.Handler @@ -40,6 +43,7 @@ type App struct { logger *slog.Logger assetManifest *AssetManifest // Asset manifest for cache-busted files rateLimiter *RateLimiter // Rate limiter for authentication endpoints + stopCacheLog chan struct{} // Stops periodicCacheLogging goroutine } func getSessionStorage(opts *options.Opts) fiber.Storage { @@ -82,13 +86,13 @@ func createFiberApp() *fiber.App { } // NewApp creates a new web application instance with the provided configuration options. -// It initializes the LDAP configuration, readonly client, session management, +// It initializes the LDAP configuration, readonly client (if configured), session management, // template cache, Fiber web server, and registers all routes. // Returns a configured App instance ready to start serving requests via Listen(). // -// Note: Connection pooling is intentionally disabled. The CheckPasswordForSAMAccountName -// method rebinds pooled connections with user credentials, contaminating the pool. -// Each operation creates a fresh connection like ldap-selfservice-password-changer. +// When ReadonlyUser is not configured, the app operates without a service account: +// all LDAP operations use the logged-in user's own credentials, and +// the background cache is disabled. func NewApp(opts *options.Opts) (*App, error) { logger := slog.Default() @@ -103,17 +107,26 @@ func NewApp(opts *options.Opts) (*App, error) { })) } - // Create readonly LDAP client WITHOUT connection pooling - // Pooling is disabled because CheckPasswordForSAMAccountName rebinds connections - // with user credentials, which contaminates the pool and causes timeout issues - ldapReadonly, err := ldap.New( - opts.LDAP, - opts.ReadonlyUser, - opts.ReadonlyPassword, - ldapOpts..., - ) - if err != nil { - return nil, err + // Create readonly LDAP client only when service account is configured + var ldapReadonly *ldap.LDAP + var ldapCache *ldap_cache.Manager + + if opts.ReadonlyUser != "" && opts.ReadonlyPassword != "" { + var err error + ldapReadonly, err = ldap.New( + opts.LDAP, + opts.ReadonlyUser, + opts.ReadonlyPassword, + ldapOpts..., + ) + if err != nil { + return nil, err + } + + ldapCache = ldap_cache.New(ldapReadonly) + log.Info().Msg("Service account configured, background cache enabled") + } else { + log.Info().Msg("No service account configured, using per-user LDAP credentials") } sessionStore := createSessionStore(opts) @@ -134,8 +147,9 @@ func NewApp(opts *options.Opts) (*App, error) { a := &App{ ldapConfig: opts.LDAP, + ldapOpts: ldapOpts, ldapReadonly: ldapReadonly, - ldapCache: ldap_cache.New(ldapReadonly), + ldapCache: ldapCache, templateCache: templateCache, sessionStore: sessionStore, csrfHandler: csrfHandler, @@ -143,6 +157,7 @@ func NewApp(opts *options.Opts) (*App, error) { logger: logger, assetManifest: manifest, rateLimiter: NewRateLimiter(DefaultRateLimiterConfig()), + stopCacheLog: make(chan struct{}), } // Setup all routes @@ -267,7 +282,9 @@ func (a *App) setupRoutes() { // The context is used for graceful shutdown signaling to background goroutines. // This method blocks until the server is shutdown or encounters an error. func (a *App) Listen(ctx context.Context, addr string) error { - go a.ldapCache.Run(ctx) + if a.ldapCache != nil { + go a.ldapCache.Run(ctx) + } return a.fiber.Listen(addr) } @@ -275,18 +292,24 @@ func (a *App) Listen(ctx context.Context, addr string) error { // Shutdown gracefully shuts down the application within the given context timeout. // It stops all background goroutines, closes connections, and releases resources. func (a *App) Shutdown(ctx context.Context) error { + log.Info().Msg("Stopping periodic cache logging...") + close(a.stopCacheLog) + log.Info().Msg("Stopping template cache...") a.templateCache.Stop() - log.Info().Msg("Stopping LDAP cache manager...") - a.ldapCache.Stop() + if a.ldapCache != nil { + log.Info().Msg("Stopping LDAP cache manager...") + a.ldapCache.Stop() + } log.Info().Msg("Stopping rate limiter...") a.rateLimiter.Stop() log.Info().Msg("Shutting down Fiber server...") - if err := a.fiber.ShutdownWithContext(ctx); err != nil { - log.Error().Err(err).Msg("Error shutting down Fiber server") + shutdownErr := a.fiber.ShutdownWithContext(ctx) + if shutdownErr != nil { + log.Error().Err(shutdownErr).Msg("Error shutting down Fiber server") } log.Info().Msg("Closing LDAP connections...") @@ -296,7 +319,33 @@ func (a *App) Shutdown(ctx context.Context) error { } } - return nil + return shutdownErr +} + +// getUserLDAP creates a user-bound LDAP client from session credentials. +// The caller must close the returned client via defer client.Close(). +// Returns a fiber.StatusUnauthorized error if session has no credentials, +// which handle500 will convert to a login redirect. +func (a *App) getUserLDAP(c *fiber.Ctx) (*ldap.LDAP, error) { + sess, err := a.sessionStore.Get(c) + if err != nil { + return nil, fmt.Errorf("getUserLDAP: session error: %w", err) + } + + dn, _ := sess.Get("dn").(string) + password, _ := sess.Get("password").(string) + + if dn == "" || password == "" { + return nil, fiber.NewError(fiber.StatusUnauthorized, "session expired or missing credentials") + } + + client, err := ldap.New(a.ldapConfig, dn, password, a.ldapOpts...) + if err != nil { + // LDAP bind failure likely means expired/changed password → redirect to login + return nil, fiber.NewError(fiber.StatusUnauthorized, "LDAP connection failed, please re-login") + } + + return client, nil } // templateCacheMiddleware creates middleware for template caching @@ -325,14 +374,6 @@ func (a *App) templateCacheMiddleware() fiber.Handler { } } -// invalidateTemplateCache invalidates cache entries after data modifications -func (a *App) invalidateTemplateCache(paths ...string) { - for _, path := range paths { - count := a.templateCache.InvalidateByPath(path) - log.Debug().Str("path", path).Int("invalidated", count).Msg("Template cache invalidated") - } -} - // cacheStatsHandler provides cache statistics for monitoring func (a *App) cacheStatsHandler(c *fiber.Ctx) error { stats := a.templateCache.Stats() @@ -341,16 +382,18 @@ func (a *App) cacheStatsHandler(c *fiber.Ctx) error { } // poolStatsHandler provides LDAP performance statistics for monitoring -// Note: Connection pooling is disabled, so pool-specific stats will be empty func (a *App) poolStatsHandler(c *fiber.Ctx) error { - stats := a.ldapReadonly.GetPoolStats() - - response := map[string]any{ - "stats": stats, - "message": "Connection pooling is disabled - each operation creates a fresh connection", + if a.ldapReadonly == nil { + return c.JSON(map[string]any{ + "message": "No service account configured - per-user LDAP credentials in use", + }) } - return c.JSON(response) + stats := a.ldapReadonly.GetPoolStats() + + return c.JSON(map[string]any{ + "stats": stats, + }) } // periodicCacheLogging logs cache statistics periodically for monitoring @@ -358,14 +401,28 @@ func (a *App) periodicCacheLogging() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() - for range ticker.C { - a.templateCache.LogStats() + for { + select { + case <-ticker.C: + a.templateCache.LogStats() + case <-a.stopCacheLog: + return + } } } func handle500(c *fiber.Ctx, err error) error { + // Redirect to login on authentication errors instead of showing 500 + var fiberErr *fiber.Error + if errors.As(err, &fiberErr) && fiberErr.Code == fiber.StatusUnauthorized { + log.Warn().Err(err).Msg("session expired or invalid, redirecting to login") + + return c.Redirect("/login") + } + log.Error().Err(err).Send() + c.Status(fiber.StatusInternalServerError) c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) return templates.FiveHundred(err).Render(c.UserContext(), c.Response().BodyWriter()) @@ -378,19 +435,51 @@ func (a *App) indexHandler(c *fiber.Ctx) error { return err } - user, err := a.ldapCache.FindUserByDN(userDN) + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + // Get username from middleware context (stored during auth) + username, _ := c.Locals("username").(string) + + var user *ldap.User + + if username != "" { + user, err = userLDAP.FindUserBySAMAccountName(username) + // Fail fast on real errors (not just "not found") + if err != nil && !errors.Is(err, ldap.ErrUserNotFound) { + return handle500(c, err) + } + } + + // Fall back to finding by DN if lookup by SAMAccountName was not attempted or user not found + if user == nil { + allUsers, findErr := userLDAP.FindUsers() + if findErr != nil { + return handle500(c, findErr) + } + + user, err = findUserByDN(allUsers, userDN) + if err != nil { + return handle500(c, err) + } + } + + groups, err := userLDAP.FindGroups() if err != nil { return handle500(c, err) } - // Populate groups for the home screen - fullUser := a.ldapCache.PopulateGroupsForUser(user) + fullUser := ldap_cache.PopulateGroupsForUserFromData(user, groups) // Use template caching return a.templateCache.RenderWithCache(c, templates.Index(fullUser)) } func (a *App) fourOhFourHandler(c *fiber.Ctx) error { + c.Status(fiber.StatusNotFound) c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) return templates.FourOhFour(c.Path()).Render(c.UserContext(), c.Response().BodyWriter()) diff --git a/internal/web/template_cache.go b/internal/web/template_cache.go index 3d3098a..5a187ca 100644 --- a/internal/web/template_cache.go +++ b/internal/web/template_cache.go @@ -24,10 +24,9 @@ type TemplateCache struct { } type cacheEntry struct { - content []byte - createdAt time.Time - accessedAt time.Time - ttl time.Duration + content []byte + createdAt time.Time + ttl time.Duration } // TemplateCacheConfig holds configuration for template caching @@ -93,8 +92,8 @@ func (tc *TemplateCache) generateCacheKey(c *fiber.Ctx, additionalData ...string // Get retrieves cached template content if available and not expired func (tc *TemplateCache) Get(key string) ([]byte, bool) { - tc.mu.Lock() - defer tc.mu.Unlock() + tc.mu.RLock() + defer tc.mu.RUnlock() entry, exists := tc.entries[key] if !exists { @@ -103,13 +102,9 @@ func (tc *TemplateCache) Get(key string) ([]byte, bool) { // Check if entry is expired if time.Since(entry.createdAt) > entry.ttl { - // Entry expired, but don't remove it here to avoid complex logic return nil, false } - // Update access time for LRU tracking - entry.accessedAt = time.Now() - return entry.content, true } @@ -127,51 +122,13 @@ func (tc *TemplateCache) Set(key string, content []byte, ttl time.Duration) { ttl = tc.defaultTTL } - now := time.Now() tc.entries[key] = &cacheEntry{ - content: content, - createdAt: now, - accessedAt: now, - ttl: ttl, + content: content, + createdAt: time.Now(), + ttl: ttl, } } -// Invalidate removes cached entries matching the given pattern -func (tc *TemplateCache) Invalidate(pattern string) int { - tc.mu.Lock() - defer tc.mu.Unlock() - - count := 0 - for key := range tc.entries { - // Simple pattern matching - could be enhanced with regex if needed - if pattern == "*" || key == pattern { - delete(tc.entries, key) - count++ - } - } - - return count -} - -// InvalidateByPath removes all cached entries for a specific path -func (tc *TemplateCache) InvalidateByPath(path string) int { - tc.mu.Lock() - defer tc.mu.Unlock() - - count := 0 - for key := range tc.entries { - // Check if the key contains the path (since keys are hashed, - // we'll need to maintain a reverse mapping or use a different approach) - // For now, we'll invalidate all entries (this could be optimized) - if path != "" { - delete(tc.entries, key) - count++ - } - } - - return count -} - // Clear removes all cached entries func (tc *TemplateCache) Clear() { tc.mu.Lock() @@ -222,9 +179,9 @@ func (tc *TemplateCache) evictOldestUnsafe() { var oldestTime time.Time for key, entry := range tc.entries { - if oldestKey == "" || entry.accessedAt.Before(oldestTime) { + if oldestKey == "" || entry.createdAt.Before(oldestTime) { oldestKey = key - oldestTime = entry.accessedAt + oldestTime = entry.createdAt } } @@ -295,39 +252,6 @@ func (tc *TemplateCache) RenderWithCache(c *fiber.Ctx, component templ.Component return c.Send(content) } -// CacheMiddleware creates a Fiber middleware for template caching -func (tc *TemplateCache) CacheMiddleware(paths ...string) fiber.Handler { - pathMap := make(map[string]bool) - for _, path := range paths { - pathMap[path] = true - } - - return func(c *fiber.Ctx) error { - // Only cache GET requests for specified paths - if c.Method() != fiber.MethodGet { - return c.Next() - } - - // Check if this path should be cached - if len(pathMap) > 0 && !pathMap[c.Path()] { - return c.Next() - } - - // Generate cache key - cacheKey := tc.generateCacheKey(c) - - // Try to serve from cache - if cachedContent, found := tc.Get(cacheKey); found { - c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) - - return c.Send(cachedContent) - } - - // Not in cache, continue to handler - return c.Next() - } -} - // LogStats logs cache statistics func (tc *TemplateCache) LogStats() { stats := tc.Stats() diff --git a/internal/web/template_cache_fuzz_test.go b/internal/web/template_cache_fuzz_test.go index d17c6de..e96200a 100644 --- a/internal/web/template_cache_fuzz_test.go +++ b/internal/web/template_cache_fuzz_test.go @@ -81,40 +81,33 @@ func FuzzTemplateCacheGet(f *testing.F) { }) } -// FuzzTemplateCacheInvalidate tests invalidation with fuzzed patterns -func FuzzTemplateCacheInvalidate(f *testing.F) { - // Seed with patterns - f.Add("*") - f.Add("user-*") - f.Add("*-suffix") - f.Add("") - f.Add("exact-key") - f.Add("key with spaces") - f.Add(strings.Repeat("x", 1000)) +// FuzzTemplateCacheClear tests clear with fuzzed pre-populated cache +func FuzzTemplateCacheClear(f *testing.F) { + f.Add(1) + f.Add(5) + f.Add(50) + + f.Fuzz(func(t *testing.T, numEntries int) { + if numEntries < 0 { + numEntries = 0 + } + if numEntries > 100 { + numEntries = 100 + } - f.Fuzz(func(t *testing.T, pattern string) { cache := NewTemplateCache(DefaultTemplateCacheConfig()) defer cache.Stop() - // Add some entries - cache.Set("user-1", []byte("1"), 0) - cache.Set("user-2", []byte("2"), 0) - cache.Set("group-1", []byte("3"), 0) - - // Invalidate shouldn't panic - count := cache.Invalidate(pattern) - - // Verify count is reasonable - if count < 0 { - t.Errorf("Negative invalidation count: %d", count) + for i := range numEntries { + cache.Set("key-"+string(rune(i%65536)), []byte("content"), 0) } - // After "*" invalidation, cache should be empty - if pattern == "*" { - stats := cache.Stats() - if stats.Entries != 0 { - t.Errorf("Cache not empty after '*' invalidation: %d entries", stats.Entries) - } + // Clear shouldn't panic + cache.Clear() + + stats := cache.Stats() + if stats.Entries != 0 { + t.Errorf("Cache not empty after Clear: %d entries", stats.Entries) } }) } diff --git a/internal/web/template_cache_test.go b/internal/web/template_cache_test.go index d52ba18..3010535 100644 --- a/internal/web/template_cache_test.go +++ b/internal/web/template_cache_test.go @@ -65,28 +65,25 @@ func TestTemplateCacheEviction(t *testing.T) { // Fill cache to capacity cache.Set("key1", []byte("content1"), 0) + time.Sleep(10 * time.Millisecond) // Ensure different creation times cache.Set("key2", []byte("content2"), 0) stats := cache.Stats() assert.Equal(t, 2, stats.Entries) - // Access key1 to make it more recently used - _, found := cache.Get("key1") - assert.True(t, found) - - // Add another entry, should evict key2 (oldest) + // Add another entry, should evict key1 (oldest created) cache.Set("key3", []byte("content3"), 0) stats = cache.Stats() assert.Equal(t, 2, stats.Entries) - // key1 should still exist - _, found = cache.Get("key1") - assert.True(t, found) + // key1 should be evicted (oldest) + _, found := cache.Get("key1") + assert.False(t, found) - // key2 should be evicted + // key2 should still exist _, found = cache.Get("key2") - assert.False(t, found) + assert.True(t, found) // key3 should exist _, found = cache.Get("key3") @@ -117,7 +114,7 @@ func TestTemplateCacheClear(t *testing.T) { assert.False(t, found) } -func TestTemplateCacheInvalidation(t *testing.T) { +func TestTemplateCacheClearOperation(t *testing.T) { cache := NewTemplateCache(DefaultTemplateCacheConfig()) defer cache.Stop() @@ -129,24 +126,17 @@ func TestTemplateCacheInvalidation(t *testing.T) { stats := cache.Stats() assert.Equal(t, 3, stats.Entries) - // Test pattern invalidation - count := cache.Invalidate("key1") - assert.Equal(t, 1, count) + // Clear all entries + cache.Clear() - // Verify key1 is gone + stats = cache.Stats() + assert.Equal(t, 0, stats.Entries) + + // Verify all are gone _, found := cache.Get("key1") assert.False(t, found) - - // Other keys should remain _, found = cache.Get("key2") - assert.True(t, found) - - // Test wildcard invalidation - count = cache.Invalidate("*") - assert.Equal(t, 2, count) // Should remove remaining 2 entries - - stats = cache.Stats() - assert.Equal(t, 0, stats.Entries) + assert.False(t, found) } func TestTemplateCacheCleanup(t *testing.T) { diff --git a/internal/web/template_errors_test.go b/internal/web/template_errors_test.go index 4c2be42..9354335 100644 --- a/internal/web/template_errors_test.go +++ b/internal/web/template_errors_test.go @@ -51,11 +51,11 @@ func TestTemplateCacheConcurrentAccess(t *testing.T) { }(i) } - // Concurrent invalidators + // Concurrent clearers for range 5 { wg.Go(func() { for range 10 { - cache.Invalidate("*") + cache.Clear() time.Sleep(10 * time.Millisecond) } }) @@ -201,8 +201,8 @@ func TestTemplateCacheStatsAccuracy(t *testing.T) { assert.Equal(t, 0, stats.Entries) } -// TestTemplateCacheInvalidatePatterns tests various invalidation patterns -func TestTemplateCacheInvalidatePatterns(t *testing.T) { +// TestTemplateCacheClearAll tests clearing all entries +func TestTemplateCacheClearAll(t *testing.T) { cache := NewTemplateCache(DefaultTemplateCacheConfig()) defer cache.Stop() @@ -215,42 +215,17 @@ func TestTemplateCacheInvalidatePatterns(t *testing.T) { stats := cache.Stats() assert.Equal(t, 4, stats.Entries) - // Invalidate specific entry - count := cache.Invalidate("user-1") - assert.Equal(t, 1, count) - - stats = cache.Stats() - assert.Equal(t, 3, stats.Entries) - - // Invalidate non-existent entry - count = cache.Invalidate("nonexistent") - assert.Equal(t, 0, count) - - // Invalidate all - count = cache.Invalidate("*") - assert.Equal(t, 3, count) + // Clear all + cache.Clear() stats = cache.Stats() assert.Equal(t, 0, stats.Entries) -} - -// TestTemplateCacheInvalidateByPath tests path-based invalidation -func TestTemplateCacheInvalidateByPath(t *testing.T) { - cache := NewTemplateCache(DefaultTemplateCacheConfig()) - defer cache.Stop() - - cache.Set("key1", []byte("content1"), 0) - cache.Set("key2", []byte("content2"), 0) - - stats := cache.Stats() - assert.Equal(t, 2, stats.Entries) - // Invalidate by path - count := cache.InvalidateByPath("/users") - assert.Equal(t, 2, count) // Currently invalidates all for non-empty path - - stats = cache.Stats() - assert.Equal(t, 0, stats.Entries) + // Verify all entries are gone + _, found := cache.Get("user-1") + assert.False(t, found) + _, found = cache.Get("group-1") + assert.False(t, found) } // TestTemplateCacheStopSingleCallSafe tests that Stop can be called once safely @@ -334,57 +309,6 @@ func TestTemplateCacheGenerateKey(t *testing.T) { }) } -// TestTemplateCacheMiddleware tests the cache middleware -func TestTemplateCacheMiddleware(t *testing.T) { - cache := NewTemplateCache(DefaultTemplateCacheConfig()) - defer cache.Stop() - - app := fiber.New() - - // Add cache middleware for specific paths - app.Use(cache.CacheMiddleware("/cached")) - - callCount := 0 - app.Get("/cached", func(c *fiber.Ctx) error { - callCount++ - - return c.SendString("content") - }) - - app.Get("/not-cached", func(c *fiber.Ctx) error { - callCount++ - - return c.SendString("content") - }) - - t.Run("non-GET requests are not cached", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/cached", http.NoBody) - resp, err := app.Test(req) - require.NoError(t, err) - _ = resp.Body.Close() - - // POST should not be cached - assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - }) - - t.Run("non-matching paths are not cached", func(t *testing.T) { - callCount = 0 - - req := httptest.NewRequest(http.MethodGet, "/not-cached", http.NoBody) - resp, _ := app.Test(req) - _ = resp.Body.Close() - - assert.Equal(t, 1, callCount) - - // Second request should still call handler - req = httptest.NewRequest(http.MethodGet, "/not-cached", http.NoBody) - resp, _ = app.Test(req) - _ = resp.Body.Close() - - assert.Equal(t, 2, callCount) - }) -} - // TestTemplateCacheEvictionOrder tests that oldest entries are evicted first func TestTemplateCacheEvictionOrder(t *testing.T) { cache := NewTemplateCache(TemplateCacheConfig{ @@ -394,18 +318,15 @@ func TestTemplateCacheEvictionOrder(t *testing.T) { }) defer cache.Stop() - // Add entries with time gaps to establish access order + // Add entries with time gaps to establish creation order cache.Set("first", []byte("1"), 0) time.Sleep(10 * time.Millisecond) cache.Set("second", []byte("2"), 0) time.Sleep(10 * time.Millisecond) cache.Set("third", []byte("3"), 0) - - // Access "first" to make it more recently used - cache.Get("first") time.Sleep(10 * time.Millisecond) - // Add fourth entry, should evict "second" (oldest accessed) + // Add fourth entry, should evict "first" (oldest created) cache.Set("fourth", []byte("4"), 0) // Check what's in cache @@ -414,8 +335,8 @@ func TestTemplateCacheEvictionOrder(t *testing.T) { _, foundThird := cache.Get("third") _, foundFourth := cache.Get("fourth") - assert.True(t, foundFirst, "first should still be in cache (recently accessed)") - assert.False(t, foundSecond, "second should be evicted (oldest accessed)") + assert.False(t, foundFirst, "first should be evicted (oldest created)") + assert.True(t, foundSecond, "second should still be in cache") assert.True(t, foundThird, "third should still be in cache") assert.True(t, foundFourth, "fourth should be in cache (just added)") } diff --git a/internal/web/templates/logged_in.templ b/internal/web/templates/logged_in.templ index a08bc56..5dc992f 100644 --- a/internal/web/templates/logged_in.templ +++ b/internal/web/templates/logged_in.templ @@ -1,5 +1,7 @@ package templates +import "github.com/netresearch/ldap-manager/internal/version" + const navbarClasses = "px-3 py-1 rounded-md flex items-center gap-2 transition-colors focus:outline-none hocus:text-text-primary max-sm:px-2 max-sm:py-2 compact:px-2 compact:py-1 comfortable:px-3 comfortable:py-2 " const navbarInactiveClasses = "text-text-secondary hocus:bg-surface-elevated" const navbarActiveClasses = "text-text-primary bg-surface-elevated" @@ -63,6 +65,21 @@ templ loggedIn(current, title string, flashes []Flash) { } { children... } + } } diff --git a/internal/web/users.go b/internal/web/users.go index d2ff4ca..60c83ee 100644 --- a/internal/web/users.go +++ b/internal/web/users.go @@ -3,34 +3,43 @@ package web // HTTP handlers for user management endpoints. import ( + "errors" "net/url" "sort" "github.com/gofiber/fiber/v2" ldap "github.com/netresearch/simple-ldap-go" + "github.com/rs/zerolog/log" "github.com/netresearch/ldap-manager/internal/ldap_cache" "github.com/netresearch/ldap-manager/internal/web/templates" ) -// usersHandler handles GET /users requests to list all user accounts in the LDAP directory. -// Supports optional show-disabled query parameter to include disabled user accounts. -// Users are sorted alphabetically by CN (Common Name) and returned as HTML using template caching. -// -// Query Parameters: -// - show-disabled: Set to "1" to include disabled users in the listing -// -// Returns: -// - 200: HTML page with user listing including display names, account names, and email addresses -// - 500: Internal server error if LDAP query fails -// -// Example: -// -// GET /users?show-disabled=1 func (a *App) usersHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session showDisabled := c.Query("show-disabled", "0") == "1" - users := a.ldapCache.FindUsers(showDisabled) + + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + allUsers, err := userLDAP.FindUsers() + if err != nil { + return handle500(c, err) + } + + var users []ldap.User + if !showDisabled { + for _, u := range allUsers { + if u.Enabled { + users = append(users, u) + } + } + } else { + users = allUsers + } + sort.SliceStable(users, func(i, j int) bool { return users[i].CN() < users[j].CN() }) @@ -39,29 +48,26 @@ func (a *App) usersHandler(c *fiber.Ctx) error { return a.templateCache.RenderWithCache(c, templates.Users(users, showDisabled, templates.Flashes())) } -// userHandler handles GET /users/:userDN requests to display detailed information for a specific user. -// The userDN path parameter must be URL-encoded Distinguished Name of the user account. -// Returns user details including all LDAP attributes, group memberships, and edit form with CSRF protection. -// -// Path Parameters: -// - userDN: URL-encoded Distinguished Name of the user (e.g. "CN=John Doe,OU=Users,DC=example,DC=com") -// -// Returns: -// - 200: HTML page with user details, group memberships, and editable form fields -// - 500: Internal server error if user not found or LDAP query fails -// -// Example: -// -// GET /users/CN%3DJohn%20Doe%2COU%3DUsers%2CDC%3Dexample%2CDC%3Dcom func (a *App) userHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session userDN, err := url.PathUnescape(c.Params("*")) if err != nil { return handle500(c, err) } - user, unassignedGroups, err := a.loadUserData(userDN) + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + user, unassignedGroups, err := a.loadUserDataFromLDAP(userLDAP, userDN) if err != nil { + if errors.Is(err, ldap.ErrUserNotFound) { + c.Status(fiber.StatusNotFound) + + return a.fourOhFourHandler(c) + } + return handle500(c, err) } @@ -80,7 +86,6 @@ type userModifyForm struct { // nolint:dupl // Similar to groupModifyHandler but operates on different entities with different forms func (a *App) userModifyHandler(c *fiber.Ctx) error { - // Authentication handled by middleware, no need to check session userDN, err := url.PathUnescape(c.Params("*")) if err != nil { return handle500(c, err) @@ -92,84 +97,105 @@ func (a *App) userModifyHandler(c *fiber.Ctx) error { } if form.RemoveGroup == nil && form.AddGroup == nil { - return c.Redirect("/users/" + userDN) + return c.Redirect("/users/" + url.PathEscape(userDN)) } - // Perform the user modification using the readonly LDAP client - // User is already authenticated via session middleware - if err := a.performUserModification(a.ldapReadonly, &form, userDN); err != nil { - return a.renderUserWithError(c, userDN, "Failed to modify: "+err.Error()) + userLDAP, err := a.getUserLDAP(c) + if err != nil { + return handle500(c, err) + } + defer func() { _ = userLDAP.Close() }() + + // Perform the user modification using the logged-in user's LDAP connection + if err := a.performUserModification(userLDAP, &form, userDN); err != nil { + log.Warn().Err(err).Str("userDN", userDN).Msg("failed to modify user") + + return a.renderUserWithFlash(c, userLDAP, userDN, templates.ErrorFlash("Failed to modify user membership")) } // Invalidate template cache after successful modification - a.invalidateTemplateCacheOnUserModification(userDN) + a.invalidateTemplateCacheOnModification() // Render success response - return a.renderUserWithSuccess(c, userDN, "Successfully modified user") + return a.renderUserWithFlash(c, userLDAP, userDN, templates.SuccessFlash("Successfully modified user")) } -func (a *App) findUnassignedGroups(user *ldap_cache.FullLDAPUser) []ldap.Group { - return a.ldapCache.Groups.Filter(func(g ldap.Group) bool { - for _, ug := range user.Groups { - if ug.DN() == g.DN() { - return false - } - } +// loadUserDataFromLDAP loads user data directly from an LDAP client connection. +func (a *App) loadUserDataFromLDAP(userLDAP *ldap.LDAP, userDN string) (*ldap_cache.FullLDAPUser, []ldap.Group, error) { + allUsers, err := userLDAP.FindUsers() + if err != nil { + return nil, nil, err + } - return true - }) -} + user, err := findUserByDN(allUsers, userDN) + if err != nil { + return nil, nil, err + } -// loadUserData loads and prepares user data with proper sorting -func (a *App) loadUserData(userDN string) (*ldap_cache.FullLDAPUser, []ldap.Group, error) { - thinUser, err := a.ldapCache.FindUserByDN(userDN) + groups, err := userLDAP.FindGroups() if err != nil { return nil, nil, err } - user := a.ldapCache.PopulateGroupsForUser(thinUser) - sort.SliceStable(user.Groups, func(i, j int) bool { - return user.Groups[i].CN() < user.Groups[j].CN() + fullUser := ldap_cache.PopulateGroupsForUserFromData(user, groups) + sort.SliceStable(fullUser.Groups, func(i, j int) bool { + return fullUser.Groups[i].CN() < fullUser.Groups[j].CN() }) - unassignedGroups := a.findUnassignedGroups(user) + + unassignedGroups := filterUnassignedGroups(groups, fullUser) sort.SliceStable(unassignedGroups, func(i, j int) bool { return unassignedGroups[i].CN() < unassignedGroups[j].CN() }) - return user, unassignedGroups, nil + return fullUser, unassignedGroups, nil } -// renderUserWithError renders the user page with an error message -func (a *App) renderUserWithError(c *fiber.Ctx, userDN, errorMsg string) error { +// renderUserWithFlash renders the user page with a flash message using a user LDAP connection. +func (a *App) renderUserWithFlash(c *fiber.Ctx, userLDAP *ldap.LDAP, userDN string, flash templates.Flash) error { c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) - user, unassignedGroups, err := a.loadUserData(userDN) + + user, unassignedGroups, err := a.loadUserDataFromLDAP(userLDAP, userDN) if err != nil { return handle500(c, err) } return templates.User( user, unassignedGroups, - templates.Flashes(templates.ErrorFlash(errorMsg)), + templates.Flashes(flash), a.GetCSRFToken(c), ).Render(c.UserContext(), c.Response().BodyWriter()) } -// renderUserWithSuccess renders the user page with a success message -func (a *App) renderUserWithSuccess(c *fiber.Ctx, userDN, successMsg string) error { - c.Set(fiber.HeaderContentType, fiber.MIMETextHTMLCharsetUTF8) - user, unassignedGroups, err := a.loadUserData(userDN) - if err != nil { - return handle500(c, err) +// filterUnassignedGroups returns groups the user is not a member of. +func filterUnassignedGroups(allGroups []ldap.Group, user *ldap_cache.FullLDAPUser) []ldap.Group { + memberGroupDNS := make(map[string]struct{}, len(user.Groups)) + for _, g := range user.Groups { + memberGroupDNS[g.DN()] = struct{}{} } - return templates.User( - user, unassignedGroups, - templates.Flashes(templates.SuccessFlash(successMsg)), - a.GetCSRFToken(c), - ).Render(c.UserContext(), c.Response().BodyWriter()) + result := make([]ldap.Group, 0) + + for _, g := range allGroups { + if _, isMember := memberGroupDNS[g.DN()]; !isMember { + result = append(result, g) + } + } + + return result +} + +// findUserByDN searches for a user by DN in a slice. +func findUserByDN(users []ldap.User, dn string) (*ldap.User, error) { + for i := range users { + if users[i].DN() == dn { + return &users[i], nil + } + } + + return nil, ldap.ErrUserNotFound } -// performUserModification handles the actual LDAP user modification operation +// performUserModification handles the actual LDAP user modification operation. func (a *App) performUserModification( ldapClient *ldap.LDAP, form *userModifyForm, userDN string, ) error { @@ -177,29 +203,26 @@ func (a *App) performUserModification( if err := ldapClient.AddUserToGroup(userDN, *form.AddGroup); err != nil { return err } - a.ldapCache.OnAddUserToGroup(userDN, *form.AddGroup) + + if a.ldapCache != nil { + a.ldapCache.OnAddUserToGroup(userDN, *form.AddGroup) + } } else if form.RemoveGroup != nil { if err := ldapClient.RemoveUserFromGroup(userDN, *form.RemoveGroup); err != nil { return err } - a.ldapCache.OnRemoveUserFromGroup(userDN, *form.RemoveGroup) + + if a.ldapCache != nil { + a.ldapCache.OnRemoveUserFromGroup(userDN, *form.RemoveGroup) + } } return nil } -// invalidateTemplateCacheOnUserModification invalidates relevant cache entries after user modification -func (a *App) invalidateTemplateCacheOnUserModification(userDN string) { - // Invalidate the specific user page - a.invalidateTemplateCache("/users/" + userDN) - - // Invalidate users list page (counts may have changed) - a.invalidateTemplateCache("/users") - - // Invalidate groups pages (group membership may have changed) - a.invalidateTemplateCache("/groups") - - // Clear all cache entries for safety (this could be optimized further) - // In a high-traffic environment, you might want to be more selective +// invalidateTemplateCacheOnModification clears the template cache after any modification. +// Membership changes can affect multiple pages, so we clear the entire cache. +func (a *App) invalidateTemplateCacheOnModification() { a.templateCache.Clear() + log.Debug().Msg("Template cache cleared after modification") }