generated from coulomb/repo-seed
feat: implement T06, T07 — authorization endpoint, token endpoint
- T06: /authorize with full PKCE validation, Authelia delegation, MFA check - T07: /token with RS256 JWT issuance (stdlib only), PKCE verification, scope-filtered claims 50 OIDC tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
320
src/internal/server/oidc/authorize.go
Normal file
320
src/internal/server/oidc/authorize.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
profileerrors "keycape/internal/errors"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// PendingState holds the authorization request parameters while the user is
|
||||
// being authenticated by the upstream provider (e.g. Authelia). It is keyed
|
||||
// by the opaque state value that is round-tripped through the upstream.
|
||||
type PendingState struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
PKCEChallenge string
|
||||
PKCEChallengeMethod string
|
||||
State string
|
||||
Scopes []string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// pendingStateStore is a thread-safe map of state → PendingState.
|
||||
type pendingStateStore struct {
|
||||
mu sync.Mutex
|
||||
store map[string]*PendingState
|
||||
}
|
||||
|
||||
func newPendingStateStore() *pendingStateStore {
|
||||
return &pendingStateStore{store: make(map[string]*PendingState)}
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Store(state string, ps *PendingState) {
|
||||
p.mu.Lock()
|
||||
p.store[state] = ps
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Load(state string) (*PendingState, bool) {
|
||||
p.mu.Lock()
|
||||
ps, ok := p.store[state]
|
||||
p.mu.Unlock()
|
||||
return ps, ok
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Delete(state string) {
|
||||
p.mu.Lock()
|
||||
delete(p.store, state)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// AuthorizeHandler implements GET /authorize and GET /authorize/callback.
|
||||
type AuthorizeHandler struct {
|
||||
ClientConfig map[string]*domain.Client
|
||||
Auth domain.AuthProvider
|
||||
MFA domain.MFAProvider
|
||||
Sessions *SessionStore
|
||||
Emitter telemetry.Emitter
|
||||
|
||||
pending *pendingStateStore
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// PendingStates returns the underlying pending-state store so tests can seed it.
|
||||
func (h *AuthorizeHandler) PendingStates() *pendingStateStore {
|
||||
h.init()
|
||||
return h.pending
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) init() {
|
||||
h.once.Do(func() {
|
||||
if h.pending == nil {
|
||||
h.pending = newPendingStateStore()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ServeHTTP dispatches to the authorize or callback handler based on path.
|
||||
func (h *AuthorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.init()
|
||||
if strings.HasSuffix(r.URL.Path, "/callback") {
|
||||
h.ServeHTTPCallback(w, r)
|
||||
return
|
||||
}
|
||||
h.serveAuthorize(w, r)
|
||||
}
|
||||
|
||||
// serveAuthorize handles the initial GET /authorize request.
|
||||
func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
q := r.URL.Query()
|
||||
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
responseType := q.Get("response_type")
|
||||
scope := q.Get("scope")
|
||||
state := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
|
||||
// Emit auth_start telemetry immediately.
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthStart,
|
||||
ClientID: clientID,
|
||||
Endpoint: "/authorize",
|
||||
Result: "pending",
|
||||
})
|
||||
|
||||
// 1. Validate client_id.
|
||||
client, ok := h.ClientConfig[clientID]
|
||||
if !ok {
|
||||
profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
|
||||
Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Validate redirect_uri — check for wildcards first, then exact match.
|
||||
for _, registered := range client.RedirectURIs {
|
||||
if strings.ContainsAny(registered, "*?") {
|
||||
profileerrors.RejectedForSafety(
|
||||
"wildcard redirect URIs are not permitted",
|
||||
"redirect_uri",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !uriRegistered(client.RedirectURIs, redirectURI) {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"redirect_uri does not match any registered URI",
|
||||
"redirect_uri",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Validate response_type.
|
||||
if responseType != "code" {
|
||||
profileerrors.FeatureNotSupported(
|
||||
"only response_type=code is supported",
|
||||
"response_type="+responseType,
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Validate scope contains openid.
|
||||
if !scopeContains(scope, "openid") {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"scope must include openid",
|
||||
"scope",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Validate code_challenge is present.
|
||||
if codeChallenge == "" {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"code_challenge is required (PKCE S256)",
|
||||
"code_challenge",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Validate code_challenge_method.
|
||||
if codeChallengeMethod == "plain" {
|
||||
profileerrors.RejectedForSafety(
|
||||
"code_challenge_method=plain is rejected for security; use S256",
|
||||
"code_challenge_method",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "S256" {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"code_challenge_method must be S256",
|
||||
"code_challenge_method",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store pending state so the callback can reconstruct the session.
|
||||
h.pending.Store(state, &PendingState{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
PKCEChallenge: codeChallenge,
|
||||
PKCEChallengeMethod: codeChallengeMethod,
|
||||
State: state,
|
||||
Scopes: strings.Fields(scope),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
|
||||
// Delegate to Auth provider.
|
||||
authURL, err := h.Auth.AuthorizeURL(ctx, domain.AuthRequest{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
State: state,
|
||||
Scopes: strings.Fields(scope),
|
||||
PKCEChallenge: codeChallenge,
|
||||
PKCEChallengeMethod: codeChallengeMethod,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "upstream auth provider error", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// ServeHTTPCallback handles GET /authorize/callback.
|
||||
func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) {
|
||||
h.init()
|
||||
ctx := r.Context()
|
||||
q := r.URL.Query()
|
||||
|
||||
state := q.Get("state")
|
||||
code := q.Get("code")
|
||||
mfaToken := q.Get("mfa_token")
|
||||
|
||||
// Recover pending state keyed by state param.
|
||||
ps, ok := h.pending.Load(state)
|
||||
if !ok {
|
||||
http.Error(w, "unknown or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if time.Now().After(ps.ExpiresAt) {
|
||||
h.pending.Delete(state)
|
||||
http.Error(w, "authorization request expired", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h.pending.Delete(state)
|
||||
|
||||
// Handle upstream callback.
|
||||
result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{
|
||||
Code: code,
|
||||
State: state,
|
||||
})
|
||||
if err != nil {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "auth_failed",
|
||||
})
|
||||
http.Error(w, "authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Check MFA requirement.
|
||||
mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username)
|
||||
if err != nil {
|
||||
http.Error(w, "mfa check error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if mfaRequired {
|
||||
if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "mfa_failed",
|
||||
})
|
||||
http.Error(w, "MFA validation failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Generate authorization code and store PKCE session.
|
||||
sess := &PKCESession{
|
||||
ClientID: ps.ClientID,
|
||||
RedirectURI: ps.RedirectURI,
|
||||
PKCEChallenge: ps.PKCEChallenge,
|
||||
PKCEChallengeMethod: ps.PKCEChallengeMethod,
|
||||
State: state,
|
||||
Username: result.Username,
|
||||
Scopes: ps.Scopes,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
authCode := h.Sessions.Create(sess)
|
||||
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthSuccess,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "success",
|
||||
Scopes: ps.Scopes,
|
||||
})
|
||||
|
||||
// Redirect to client with code and state.
|
||||
redirectTo := ps.RedirectURI + "?code=" + authCode + "&state=" + state
|
||||
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func uriRegistered(registered []string, target string) bool {
|
||||
for _, u := range registered {
|
||||
if u == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func scopeContains(scope, want string) bool {
|
||||
for _, s := range strings.Fields(scope) {
|
||||
if s == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user