package oidc

import (
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"keycape/internal/domain"
	profileerrors "keycape/internal/errors"
	"keycape/internal/server/telemetry"
)

// PendingState holds the authorization request parameters while the user is
// being authenticated by the upstream provider (e.g. Authelia). It is keyed
// by the opaque state value that is round-tripped through the upstream.
type PendingState struct {
	ClientID            string
	RedirectURI         string
	PKCEChallenge       string
	PKCEChallengeMethod string
	State               string
	Scopes              []string
	ExpiresAt           time.Time
}

// pendingStateStore is a thread-safe map of state → PendingState.
//
// Entries leave the map when a callback consumes them (Take/Delete) and,
// opportunistically, when Store finds them past their ExpiresAt. Without that
// sweep, every /authorize request whose user never comes back would be kept
// forever, letting unauthenticated clients grow the map without bound.
type pendingStateStore struct {
	mu        sync.Mutex
	store     map[string]*PendingState
	lastSweep time.Time // when Store last purged expired entries
}

func newPendingStateStore() *pendingStateStore {
	return &pendingStateStore{store: make(map[string]*PendingState)}
}

// Store records ps under state, replacing any existing entry for that state.
// At most once per minute it first deletes entries whose ExpiresAt has
// passed; the throttle keeps the O(n) scan off the common request path.
func (p *pendingStateStore) Store(state string, ps *PendingState) {
	now := time.Now()
	p.mu.Lock()
	defer p.mu.Unlock()
	if now.Sub(p.lastSweep) >= time.Minute {
		for k, v := range p.store {
			if v == nil || now.After(v.ExpiresAt) {
				delete(p.store, k) // deleting during range is safe in Go
			}
		}
		p.lastSweep = now
	}
	p.store[state] = ps
}

// Load returns the entry for state without removing it.
func (p *pendingStateStore) Load(state string) (*PendingState, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	ps, ok := p.store[state]
	return ps, ok
}

// Take removes and returns the entry for state within a single critical
// section, so at most one caller can ever obtain a given pending state.
func (p *pendingStateStore) Take(state string) (*PendingState, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	ps, ok := p.store[state]
	if ok {
		delete(p.store, state)
	}
	return ps, ok
}

// Delete removes the entry for state, if any.
func (p *pendingStateStore) Delete(state string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	delete(p.store, state)
}

// AuthorizeHandler implements GET /authorize and GET /authorize/callback.
//
// ClientConfig is looked up by the request's client_id. Auth builds the
// upstream authorization URL and handles the upstream callback; MFA decides
// whether a second factor is required and validates it; Sessions mints the
// authorization code handed back to the client; Emitter receives telemetry.
type AuthorizeHandler struct {
	ClientConfig map[string]*domain.Client
	Auth         domain.AuthProvider
	MFA          domain.MFAProvider
	Sessions     *SessionStore
	Emitter      telemetry.Emitter

	pending *pendingStateStore
	once    sync.Once
}

// PendingStates returns the underlying pending-state store so tests can seed it.
func (h *AuthorizeHandler) PendingStates() *pendingStateStore {
	h.init()
	return h.pending
}

// init lazily creates the pending-state store exactly once; every entry point
// calls it so a zero-value AuthorizeHandler is usable.
func (h *AuthorizeHandler) init() {
	h.once.Do(func() {
		if h.pending == nil {
			h.pending = newPendingStateStore()
		}
	})
}

// ServeHTTP dispatches to the callback handler for paths ending in
// "/callback" and to the authorize handler otherwise.
func (h *AuthorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.init()
	if strings.HasSuffix(r.URL.Path, "/callback") {
		h.ServeHTTPCallback(w, r)
		return
	}
	h.serveAuthorize(w, r)
}

// serveAuthorize handles the initial GET /authorize request: it validates the
// client, redirect URI, response type, scope, PKCE parameters and state,
// records a PendingState, and redirects the user to the upstream provider.
// Validation failures are written directly to w and never redirect.
func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	q := r.URL.Query()

	clientID := q.Get("client_id")
	redirectURI := q.Get("redirect_uri")
	responseType := q.Get("response_type")
	scope := q.Get("scope")
	state := q.Get("state")
	codeChallenge := q.Get("code_challenge")
	codeChallengeMethod := q.Get("code_challenge_method")

	// Emit auth_start telemetry immediately.
	h.Emitter.Emit(ctx, telemetry.Event{
		Timestamp: time.Now(),
		EventType: telemetry.EventAuthStart,
		ClientID:  clientID,
		Endpoint:  "/authorize",
		Result:    "pending",
	})

	// 1. Validate client_id.
	client, ok := h.ClientConfig[clientID]
	if !ok {
		profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
			Write(w, http.StatusBadRequest)
		return
	}

	// 2. Validate redirect_uri — check for wildcards first, then exact match.
	for _, registered := range client.RedirectURIs {
		if strings.ContainsAny(registered, "*?") {
			profileerrors.RejectedForSafety(
				"wildcard redirect URIs are not permitted",
				"redirect_uri",
			).Write(w, http.StatusBadRequest)
			return
		}
	}
	if !uriRegistered(client.RedirectURIs, redirectURI) {
		profileerrors.InvalidProfileUsage(
			"redirect_uri does not match any registered URI",
			"redirect_uri",
		).Write(w, http.StatusBadRequest)
		return
	}

	// 3. Validate response_type.
	if responseType != "code" {
		profileerrors.FeatureNotSupported(
			"only response_type=code is supported",
			"response_type="+responseType,
		).Write(w, http.StatusBadRequest)
		return
	}

	// 4. Validate scope contains openid.
	if !scopeContains(scope, "openid") {
		profileerrors.InvalidProfileUsage(
			"scope must include openid",
			"scope",
		).Write(w, http.StatusBadRequest)
		return
	}

	// 5. Validate code_challenge is present.
	if codeChallenge == "" {
		profileerrors.InvalidProfileUsage(
			"code_challenge is required (PKCE S256)",
			"code_challenge",
		).Write(w, http.StatusBadRequest)
		return
	}

	// 6. Validate code_challenge_method.
	if codeChallengeMethod == "plain" {
		profileerrors.RejectedForSafety(
			"code_challenge_method=plain is rejected for security; use S256",
			"code_challenge_method",
		).Write(w, http.StatusBadRequest)
		return
	}
	if codeChallengeMethod != "S256" {
		profileerrors.InvalidProfileUsage(
			"code_challenge_method must be S256",
			"code_challenge_method",
		).Write(w, http.StatusBadRequest)
		return
	}

	// 7. Require state. It is the key of the pending-state map, so an empty
	// value would make every state-less flow share (and overwrite) the same
	// entry, letting one flow's PKCE challenge be attached to another user's
	// login.
	// NOTE(review): state is still client-chosen; anyone who knows another
	// flow's state can overwrite its pending entry. A server-generated key
	// sent upstream would close that fully.
	if state == "" {
		profileerrors.InvalidProfileUsage(
			"state is required",
			"state",
		).Write(w, http.StatusBadRequest)
		return
	}

	// Store pending state so the callback can reconstruct the session.
	h.pending.Store(state, &PendingState{
		ClientID:            clientID,
		RedirectURI:         redirectURI,
		PKCEChallenge:       codeChallenge,
		PKCEChallengeMethod: codeChallengeMethod,
		State:               state,
		Scopes:              strings.Fields(scope),
		ExpiresAt:           time.Now().Add(10 * time.Minute),
	})

	// Delegate to Auth provider.
	authURL, err := h.Auth.AuthorizeURL(ctx, domain.AuthRequest{
		ClientID:            clientID,
		RedirectURI:         redirectURI,
		State:               state,
		Scopes:              strings.Fields(scope),
		PKCEChallenge:       codeChallenge,
		PKCEChallengeMethod: codeChallengeMethod,
	})
	if err != nil {
		// The user is never sent upstream, so no callback will consume this
		// entry; drop it now instead of leaving it for the expiry sweep.
		h.pending.Delete(state)
		http.Error(w, "upstream auth provider error", http.StatusBadGateway)
		return
	}

	http.Redirect(w, r, authURL, http.StatusFound)
}

// ServeHTTPCallback handles GET /authorize/callback: it consumes the pending
// state named by the "state" parameter, completes upstream authentication,
// enforces MFA when required, mints an authorization code, and redirects the
// user back to the client's redirect URI with code and state.
func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) {
	h.init()
	ctx := r.Context()
	q := r.URL.Query()

	state := q.Get("state")
	code := q.Get("code")
	mfaToken := q.Get("mfa_token")

	// Take loads and deletes under one lock, so a pending state is single-use
	// even when two callbacks carrying the same state race each other.
	ps, ok := h.pending.Take(state)
	if !ok {
		http.Error(w, "unknown or expired state", http.StatusBadRequest)
		return
	}
	if time.Now().After(ps.ExpiresAt) {
		http.Error(w, "authorization request expired", http.StatusBadRequest)
		return
	}

	// Parse the client redirect URI before doing any work so a malformed
	// value fails before an authorization code is minted.
	clientRedirect, err := url.Parse(ps.RedirectURI)
	if err != nil {
		http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
		return
	}

	// Handle upstream callback.
	result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{
		Code:  code,
		State: state,
	})
	if err != nil {
		h.Emitter.Emit(ctx, telemetry.Event{
			Timestamp: time.Now(),
			EventType: telemetry.EventAuthFailure,
			ClientID:  ps.ClientID,
			Endpoint:  "/authorize/callback",
			Result:    "failure",
			ErrorType: "auth_failed",
		})
		http.Error(w, "authentication failed", http.StatusUnauthorized)
		return
	}

	// Check MFA requirement.
	mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username)
	if err != nil {
		http.Error(w, "mfa check error", http.StatusInternalServerError)
		return
	}
	if mfaRequired {
		if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil {
			h.Emitter.Emit(ctx, telemetry.Event{
				Timestamp: time.Now(),
				EventType: telemetry.EventAuthFailure,
				ClientID:  ps.ClientID,
				Endpoint:  "/authorize/callback",
				Result:    "failure",
				ErrorType: "mfa_failed",
			})
			http.Error(w, "MFA validation failed", http.StatusUnauthorized)
			return
		}
	}

	// Generate authorization code and store PKCE session.
	sess := &PKCESession{
		ClientID:            ps.ClientID,
		RedirectURI:         ps.RedirectURI,
		PKCEChallenge:       ps.PKCEChallenge,
		PKCEChallengeMethod: ps.PKCEChallengeMethod,
		State:               state,
		Username:            result.Username,
		Scopes:              ps.Scopes,
		ExpiresAt:           time.Now().Add(10 * time.Minute),
	}
	authCode := h.Sessions.Create(sess)

	h.Emitter.Emit(ctx, telemetry.Event{
		Timestamp: time.Now(),
		EventType: telemetry.EventAuthSuccess,
		ClientID:  ps.ClientID,
		Endpoint:  "/authorize/callback",
		Result:    "success",
		Scopes:    ps.Scopes,
	})

	// Redirect to client with code and state. Building the query through
	// url.Values escapes the client-supplied state (so it cannot inject extra
	// parameters) and merges correctly with any query the registered URI
	// already carries. Per RFC 6749 §4.1.2, state is echoed only if present.
	params := clientRedirect.Query()
	params.Set("code", authCode)
	if state != "" {
		params.Set("state", state)
	}
	clientRedirect.RawQuery = params.Encode()
	http.Redirect(w, r, clientRedirect.String(), http.StatusFound)
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

// uriRegistered reports whether target exactly equals one of registered.
func uriRegistered(registered []string, target string) bool {
	for _, u := range registered {
		if u == target {
			return true
		}
	}
	return false
}

// scopeContains reports whether the space-separated scope string contains
// the token want.
func scopeContains(scope, want string) bool {
	for _, s := range strings.Fields(scope) {
		if s == want {
			return true
		}
	}
	return false
}