diff --git a/src/internal/server/oidc/authorize.go b/src/internal/server/oidc/authorize.go new file mode 100644 index 0000000..87dd58d --- /dev/null +++ b/src/internal/server/oidc/authorize.go @@ -0,0 +1,320 @@ +package oidc + +import ( + "net/http" + "strings" + "sync" + "time" + + "keycape/internal/domain" + profileerrors "keycape/internal/errors" + "keycape/internal/server/telemetry" +) + +// PendingState holds the authorization request parameters while the user is +// being authenticated by the upstream provider (e.g. Authelia). It is keyed +// by the opaque state value that is round-tripped through the upstream. +type PendingState struct { + ClientID string + RedirectURI string + PKCEChallenge string + PKCEChallengeMethod string + State string + Scopes []string + ExpiresAt time.Time +} + +// pendingStateStore is a thread-safe map of state → PendingState. +type pendingStateStore struct { + mu sync.Mutex + store map[string]*PendingState +} + +func newPendingStateStore() *pendingStateStore { + return &pendingStateStore{store: make(map[string]*PendingState)} +} + +func (p *pendingStateStore) Store(state string, ps *PendingState) { + p.mu.Lock() + p.store[state] = ps + p.mu.Unlock() +} + +func (p *pendingStateStore) Load(state string) (*PendingState, bool) { + p.mu.Lock() + ps, ok := p.store[state] + p.mu.Unlock() + return ps, ok +} + +func (p *pendingStateStore) Delete(state string) { + p.mu.Lock() + delete(p.store, state) + p.mu.Unlock() +} + +// AuthorizeHandler implements GET /authorize and GET /authorize/callback. +type AuthorizeHandler struct { + ClientConfig map[string]*domain.Client + Auth domain.AuthProvider + MFA domain.MFAProvider + Sessions *SessionStore + Emitter telemetry.Emitter + + pending *pendingStateStore + once sync.Once +} + +// PendingStates returns the underlying pending-state store so tests can seed it. +func (h *AuthorizeHandler) PendingStates() *pendingStateStore { + h.init() + return h.pending +} + +func (h *AuthorizeHandler) init() { + h.once.Do(func() { + if h.pending == nil { + h.pending = newPendingStateStore() + } + }) +} + +// ServeHTTP dispatches to the authorize or callback handler based on path. +func (h *AuthorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.init() + if strings.HasSuffix(r.URL.Path, "/callback") { + h.ServeHTTPCallback(w, r) + return + } + h.serveAuthorize(w, r) +} + +// serveAuthorize handles the initial GET /authorize request. +func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + q := r.URL.Query() + + clientID := q.Get("client_id") + redirectURI := q.Get("redirect_uri") + responseType := q.Get("response_type") + scope := q.Get("scope") + state := q.Get("state") + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + + // Emit auth_start telemetry immediately. + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventAuthStart, + ClientID: clientID, + Endpoint: "/authorize", + Result: "pending", + }) + + // 1. Validate client_id. + client, ok := h.ClientConfig[clientID] + if !ok { + profileerrors.InvalidProfileUsage("unknown client_id", "client_id"). + Write(w, http.StatusBadRequest) + return + } + + // 2. Validate redirect_uri — check for wildcards first, then exact match. + for _, registered := range client.RedirectURIs { + if strings.ContainsAny(registered, "*?") { + profileerrors.RejectedForSafety( + "wildcard redirect URIs are not permitted", + "redirect_uri", + ).Write(w, http.StatusBadRequest) + return + } + } + if !uriRegistered(client.RedirectURIs, redirectURI) { + profileerrors.InvalidProfileUsage( + "redirect_uri does not match any registered URI", + "redirect_uri", + ).Write(w, http.StatusBadRequest) + return + } + + // 3. Validate response_type. + if responseType != "code" { + profileerrors.FeatureNotSupported( + "only response_type=code is supported", + "response_type="+responseType, + ).Write(w, http.StatusBadRequest) + return + } + + // 4. Validate scope contains openid. + if !scopeContains(scope, "openid") { + profileerrors.InvalidProfileUsage( + "scope must include openid", + "scope", + ).Write(w, http.StatusBadRequest) + return + } + + // 5. Validate code_challenge is present. + if codeChallenge == "" { + profileerrors.InvalidProfileUsage( + "code_challenge is required (PKCE S256)", + "code_challenge", + ).Write(w, http.StatusBadRequest) + return + } + + // 6. Validate code_challenge_method. + if codeChallengeMethod == "plain" { + profileerrors.RejectedForSafety( + "code_challenge_method=plain is rejected for security; use S256", + "code_challenge_method", + ).Write(w, http.StatusBadRequest) + return + } + if codeChallengeMethod != "S256" { + profileerrors.InvalidProfileUsage( + "code_challenge_method must be S256", + "code_challenge_method", + ).Write(w, http.StatusBadRequest) + return + } + + // Store pending state so the callback can reconstruct the session. + h.pending.Store(state, &PendingState{ + ClientID: clientID, + RedirectURI: redirectURI, + PKCEChallenge: codeChallenge, + PKCEChallengeMethod: codeChallengeMethod, + State: state, + Scopes: strings.Fields(scope), + ExpiresAt: time.Now().Add(10 * time.Minute), + }) + + // Delegate to Auth provider. + authURL, err := h.Auth.AuthorizeURL(ctx, domain.AuthRequest{ + ClientID: clientID, + RedirectURI: redirectURI, + State: state, + Scopes: strings.Fields(scope), + PKCEChallenge: codeChallenge, + PKCEChallengeMethod: codeChallengeMethod, + }) + if err != nil { + http.Error(w, "upstream auth provider error", http.StatusBadGateway) + return + } + + http.Redirect(w, r, authURL, http.StatusFound) +} + +// ServeHTTPCallback handles GET /authorize/callback. +func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) { + h.init() + ctx := r.Context() + q := r.URL.Query() + + state := q.Get("state") + code := q.Get("code") + mfaToken := q.Get("mfa_token") + + // Recover pending state keyed by state param. + ps, ok := h.pending.Load(state) + if !ok { + http.Error(w, "unknown or expired state", http.StatusBadRequest) + return + } + if time.Now().After(ps.ExpiresAt) { + h.pending.Delete(state) + http.Error(w, "authorization request expired", http.StatusBadRequest) + return + } + h.pending.Delete(state) + + // Handle upstream callback. + result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{ + Code: code, + State: state, + }) + if err != nil { + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventAuthFailure, + ClientID: ps.ClientID, + Endpoint: "/authorize/callback", + Result: "failure", + ErrorType: "auth_failed", + }) + http.Error(w, "authentication failed", http.StatusUnauthorized) + return + } + + // Check MFA requirement. + mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username) + if err != nil { + http.Error(w, "mfa check error", http.StatusInternalServerError) + return + } + if mfaRequired { + if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil { + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventAuthFailure, + ClientID: ps.ClientID, + Endpoint: "/authorize/callback", + Result: "failure", + ErrorType: "mfa_failed", + }) + http.Error(w, "MFA validation failed", http.StatusUnauthorized) + return + } + } + + // Generate authorization code and store PKCE session. + sess := &PKCESession{ + ClientID: ps.ClientID, + RedirectURI: ps.RedirectURI, + PKCEChallenge: ps.PKCEChallenge, + PKCEChallengeMethod: ps.PKCEChallengeMethod, + State: state, + Username: result.Username, + Scopes: ps.Scopes, + ExpiresAt: time.Now().Add(10 * time.Minute), + } + authCode := h.Sessions.Create(sess) + + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventAuthSuccess, + ClientID: ps.ClientID, + Endpoint: "/authorize/callback", + Result: "success", + Scopes: ps.Scopes, + }) + + // Redirect to client with code and state. + redirectTo := ps.RedirectURI + "?code=" + authCode + "&state=" + state + http.Redirect(w, r, redirectTo, http.StatusFound) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func uriRegistered(registered []string, target string) bool { + for _, u := range registered { + if u == target { + return true + } + } + return false +} + +func scopeContains(scope, want string) bool { + for _, s := range strings.Fields(scope) { + if s == want { + return true + } + } + return false +} diff --git a/src/internal/server/oidc/authorize_test.go b/src/internal/server/oidc/authorize_test.go new file mode 100644 index 0000000..ea1055d --- /dev/null +++ b/src/internal/server/oidc/authorize_test.go @@ -0,0 +1,565 @@ +package oidc_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "keycape/internal/domain" + profileerrors "keycape/internal/errors" + "keycape/internal/server/oidc" + "keycape/internal/server/telemetry" +) + +// --------------------------------------------------------------------------- +// Mock implementations +// --------------------------------------------------------------------------- + +// mockAuthProvider implements domain.AuthProvider. +type mockAuthProvider struct { + authorizeURL string + authorizeErr error + + callbackResult *domain.AuthResult + callbackErr error +} + +func (m *mockAuthProvider) AuthorizeURL(_ context.Context, _ domain.AuthRequest) (string, error) { + if m.authorizeErr != nil { + return "", m.authorizeErr + } + return m.authorizeURL, nil +} + +func (m *mockAuthProvider) HandleCallback(_ context.Context, _ domain.CallbackParams) (*domain.AuthResult, error) { + return m.callbackResult, m.callbackErr +} + +// mockMFAProvider implements domain.MFAProvider. +type mockMFAProvider struct { + required bool + requiredErr error + + validateErr error +} + +func (m *mockMFAProvider) CheckMFARequired(_ context.Context, _ string) (bool, error) { + return m.required, m.requiredErr +} + +func (m *mockMFAProvider) ValidateMFAToken(_ context.Context, _, _ string) error { + return m.validateErr +} + +// captureEmitter captures the last emitted event. +type captureEmitter struct { + events []telemetry.Event +} + +func (c *captureEmitter) Emit(_ context.Context, ev telemetry.Event) { + c.events = append(c.events, ev) +} + +func (c *captureEmitter) last() telemetry.Event { + if len(c.events) == 0 { + return telemetry.Event{} + } + return c.events[len(c.events)-1] +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +func testClient() map[string]*domain.Client { + return map[string]*domain.Client{ + "test-client": { + ClientID: "test-client", + RedirectURIs: []string{"https://app.example.com/callback"}, + AllowedScopes: []string{"openid", "profile", "email"}, + ClientType: "public", + }, + } +} + +func newAuthorizeHandler(auth domain.AuthProvider, mfa domain.MFAProvider, emitter telemetry.Emitter) *oidc.AuthorizeHandler { + return &oidc.AuthorizeHandler{ + ClientConfig: testClient(), + Auth: auth, + MFA: mfa, + Sessions: oidc.NewSessionStore(), + Emitter: emitter, + } +} + +func validAuthorizeParams() url.Values { + return url.Values{ + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "response_type": []string{"code"}, + "scope": []string{"openid profile"}, + "state": []string{"random-state"}, + "code_challenge": []string{"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"}, + "code_challenge_method": []string{"S256"}, + } +} + +func authorizeRequest(params url.Values) *http.Request { + return httptest.NewRequest(http.MethodGet, "/authorize?"+params.Encode(), nil) +} + +func decodeProfileError(t *testing.T, body string) profileerrors.ErrorType { + t.Helper() + var pe profileerrors.ProfileError + if err := json.Unmarshal([]byte(body), &pe); err != nil { + t.Fatalf("could not decode ProfileError: %v (body: %q)", err, body) + } + return pe.Error +} + +// --------------------------------------------------------------------------- +// T06 Authorization Endpoint Tests +// --------------------------------------------------------------------------- + +func TestAuthorizeHandler_ValidRequest_RedirectsToAuthelia(t *testing.T) { + auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth?state=xyz"} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + req := authorizeRequest(validAuthorizeParams()) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusFound { + t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String()) + } + loc := w.Header().Get("Location") + if loc != "https://authelia.example.com/auth?state=xyz" { + t.Errorf("expected redirect to Authelia, got %q", loc) + } +} + +func TestAuthorizeHandler_EmitsAuthStart(t *testing.T) { + auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth"} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + req := authorizeRequest(validAuthorizeParams()) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + found := false + for _, ev := range emitter.events { + if ev.EventType == telemetry.EventAuthStart { + found = true + break + } + } + if !found { + t.Error("expected auth_start telemetry event to be emitted") + } +} + +func TestAuthorizeHandler_MissingCodeChallenge_InvalidProfileUsage(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Del("code_challenge") + params.Del("code_challenge_method") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +func TestAuthorizeHandler_WildcardRedirectURI_RejectedForSafety(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + clients := map[string]*domain.Client{ + "wildcard-client": { + ClientID: "wildcard-client", + RedirectURIs: []string{"https://app.example.com/*"}, + ClientType: "public", + }, + } + h := &oidc.AuthorizeHandler{ + ClientConfig: clients, + Auth: auth, + MFA: mfa, + Sessions: oidc.NewSessionStore(), + Emitter: emitter, + } + + params := validAuthorizeParams() + params.Set("client_id", "wildcard-client") + params.Set("redirect_uri", "https://app.example.com/anything") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrRejectedForSafety { + t.Errorf("expected rejected_for_profile_safety, got %q", errType) + } +} + +func TestAuthorizeHandler_UnknownClient_InvalidProfileUsage(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Set("client_id", "no-such-client") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +func TestAuthorizeHandler_WrongResponseType_FeatureNotSupported(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Set("response_type", "token") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrFeatureNotSupported { + t.Errorf("expected feature_not_supported_by_profile, got %q", errType) + } +} + +func TestAuthorizeHandler_MissingOpenIDScope_InvalidProfileUsage(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Set("scope", "profile email") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +func TestAuthorizeHandler_PlainCodeChallengeMethod_RejectedForSafety(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Set("code_challenge_method", "plain") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrRejectedForSafety { + t.Errorf("expected rejected_for_profile_safety, got %q", errType) + } +} + +func TestAuthorizeHandler_UnknownRedirectURI_InvalidProfileUsage(t *testing.T) { + auth := &mockAuthProvider{} + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + params := validAuthorizeParams() + params.Set("redirect_uri", "https://evil.example.com/callback") + + req := authorizeRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +// --------------------------------------------------------------------------- +// Callback tests +// --------------------------------------------------------------------------- + +func TestAuthorizeCallback_Success_RedirectsWithCode(t *testing.T) { + auth := &mockAuthProvider{ + callbackResult: &domain.AuthResult{Username: "alice"}, + } + mfa := &mockMFAProvider{required: false} + emitter := &captureEmitter{} + + sessions := oidc.NewSessionStore() + h := &oidc.AuthorizeHandler{ + ClientConfig: testClient(), + Auth: auth, + MFA: mfa, + Sessions: sessions, + Emitter: emitter, + } + + req := httptest.NewRequest(http.MethodGet, + "/authorize/callback?code=authelia-code&state=random-state", nil) + // Simulate that there is an ongoing PKCE flow stored in query param forwarding + // The callback needs the original client context. We store it via a pre-seeded + // pending session keyed by state. + // For the callback handler, we expect it to look up the pending state by the + // "state" parameter that was originally embedded. We seed the pending state. + h.PendingStates().Store("random-state", &oidc.PendingState{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + PKCEChallengeMethod: "S256", + State: "random-state", + Scopes: []string{"openid", "profile"}, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + + w := httptest.NewRecorder() + h.ServeHTTPCallback(w, req) + + if w.Code != http.StatusFound { + t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String()) + } + loc := w.Header().Get("Location") + parsed, err := url.Parse(loc) + if err != nil { + t.Fatalf("invalid Location header: %v", err) + } + if parsed.Query().Get("code") == "" { + t.Error("expected code param in redirect, got empty") + } + if parsed.Query().Get("state") != "random-state" { + t.Errorf("expected state=random-state, got %q", parsed.Query().Get("state")) + } +} + +func TestAuthorizeCallback_MFAFailed_AuthFailure(t *testing.T) { + auth := &mockAuthProvider{ + callbackResult: &domain.AuthResult{Username: "alice"}, + } + mfa := &mockMFAProvider{ + required: true, + validateErr: domain.ErrMFAFailed, + } + emitter := &captureEmitter{} + + sessions := oidc.NewSessionStore() + h := &oidc.AuthorizeHandler{ + ClientConfig: testClient(), + Auth: auth, + MFA: mfa, + Sessions: sessions, + Emitter: emitter, + } + + h.PendingStates().Store("random-state", &oidc.PendingState{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + PKCEChallengeMethod: "S256", + State: "random-state", + Scopes: []string{"openid"}, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + + req := httptest.NewRequest(http.MethodGet, + "/authorize/callback?code=authelia-code&state=random-state&mfa_token=wrong", nil) + w := httptest.NewRecorder() + h.ServeHTTPCallback(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } + found := false + for _, ev := range emitter.events { + if ev.EventType == telemetry.EventAuthFailure { + found = true + break + } + } + if !found { + t.Error("expected auth_failure telemetry event") + } +} + +func TestAuthorizeCallback_AuthProviderFailed_AuthFailure(t *testing.T) { + auth := &mockAuthProvider{ + callbackErr: domain.ErrAuthFailed, + } + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + sessions := oidc.NewSessionStore() + h := &oidc.AuthorizeHandler{ + ClientConfig: testClient(), + Auth: auth, + MFA: mfa, + Sessions: sessions, + Emitter: emitter, + } + + h.PendingStates().Store("random-state", &oidc.PendingState{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + State: "random-state", + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + + req := httptest.NewRequest(http.MethodGet, + "/authorize/callback?code=bad&state=random-state", nil) + w := httptest.NewRecorder() + h.ServeHTTPCallback(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } + found := false + for _, ev := range emitter.events { + if ev.EventType == telemetry.EventAuthFailure { + found = true + break + } + } + if !found { + t.Error("expected auth_failure telemetry event") + } +} + +func TestAuthorizeCallback_EmitsAuthSuccess(t *testing.T) { + auth := &mockAuthProvider{ + callbackResult: &domain.AuthResult{Username: "bob"}, + } + mfa := &mockMFAProvider{required: false} + emitter := &captureEmitter{} + + sessions := oidc.NewSessionStore() + h := &oidc.AuthorizeHandler{ + ClientConfig: testClient(), + Auth: auth, + MFA: mfa, + Sessions: sessions, + Emitter: emitter, + } + + h.PendingStates().Store("s1", &oidc.PendingState{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + PKCEChallenge: "abc", + PKCEChallengeMethod: "S256", + State: "s1", + Scopes: []string{"openid"}, + ExpiresAt: time.Now().Add(5 * time.Minute), + }) + + req := httptest.NewRequest(http.MethodGet, + "/authorize/callback?code=c&state=s1", nil) + w := httptest.NewRecorder() + h.ServeHTTPCallback(w, req) + + found := false + for _, ev := range emitter.events { + if ev.EventType == telemetry.EventAuthSuccess { + found = true + break + } + } + if !found { + t.Errorf("expected auth_success telemetry event, got events: %v", emitter.events) + } +} + +// --------------------------------------------------------------------------- +// ServeHTTP dispatch +// --------------------------------------------------------------------------- + +func TestAuthorizeHandler_ServeHTTP_DispatchesToCallback(t *testing.T) { + auth := &mockAuthProvider{ + callbackResult: &domain.AuthResult{Username: "alice"}, + } + mfa := &mockMFAProvider{} + emitter := &captureEmitter{} + + h := newAuthorizeHandler(auth, mfa, emitter) + + // A request to /authorize/callback should not be treated as the initial + // authorize request and must not require PKCE params. + req := httptest.NewRequest(http.MethodGet, "/authorize/callback?code=x&state=y", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + // Without a seeded pending state for "y", the callback returns an error. + // The important thing is that it is NOT a redirect to Authelia. + if w.Code == http.StatusFound { + loc := w.Header().Get("Location") + if strings.Contains(loc, "authelia") { + t.Error("callback path must not redirect to Authelia") + } + } +} diff --git a/src/internal/server/oidc/session.go b/src/internal/server/oidc/session.go new file mode 100644 index 0000000..75ea6f4 --- /dev/null +++ b/src/internal/server/oidc/session.go @@ -0,0 +1,79 @@ +package oidc + +import ( + "crypto/rand" + "encoding/base64" + "sync" + "time" +) + +// PKCESession stores the in-flight authorization state server-side. +type PKCESession struct { + Code string + ClientID string + RedirectURI string + PKCEChallenge string // S256 challenge + PKCEChallengeMethod string // always "S256" + State string + Username string // set after auth + Scopes []string + ExpiresAt time.Time +} + +// SessionStore is an in-memory PKCE session store. +type SessionStore struct { + mu sync.Mutex + sessions map[string]*PKCESession // keyed by code +} + +// NewSessionStore returns an initialised, empty SessionStore. +func NewSessionStore() *SessionStore { + return &SessionStore{ + sessions: make(map[string]*PKCESession), + } +} + +// Create stores the session and returns the generated authorization code. +func (s *SessionStore) Create(sess *PKCESession) string { + code := generateCode() + sess.Code = code + + s.mu.Lock() + s.sessions[code] = sess + s.mu.Unlock() + + return code +} + +// Get retrieves a session by code. Returns false if not found or expired. +func (s *SessionStore) Get(code string) (*PKCESession, bool) { + s.mu.Lock() + sess, ok := s.sessions[code] + s.mu.Unlock() + + if !ok { + return nil, false + } + if time.Now().After(sess.ExpiresAt) { + s.Delete(code) + return nil, false + } + return sess, true +} + +// Delete removes a session by code. No-op if the code is not present. +func (s *SessionStore) Delete(code string) { + s.mu.Lock() + delete(s.sessions, code) + s.mu.Unlock() +} + +// generateCode returns a cryptographically random, URL-safe string suitable +// for use as an authorization code. +func generateCode() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic("oidc: failed to generate random code: " + err.Error()) + } + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/src/internal/server/oidc/token.go b/src/internal/server/oidc/token.go new file mode 100644 index 0000000..ee0848a --- /dev/null +++ b/src/internal/server/oidc/token.go @@ -0,0 +1,219 @@ +package oidc + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "strings" + "time" + + "keycape/internal/domain" + profileerrors "keycape/internal/errors" + "keycape/internal/server/telemetry" +) + +// TokenHandler implements POST /token. +type TokenHandler struct { + ClientConfig map[string]*domain.Client + Sessions *SessionStore + Users domain.UserRepository + SigningKey *rsa.PrivateKey + Issuer string + TokenLifetime time.Duration + Emitter telemetry.Emitter +} + +// tokenResponse is the JSON body returned on a successful token exchange. +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` +} + +// ServeHTTP handles POST /token. +func (h *TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form body", http.StatusBadRequest) + return + } + + grantType := r.FormValue("grant_type") + clientID := r.FormValue("client_id") + code := r.FormValue("code") + codeVerifier := r.FormValue("code_verifier") + + // 1. Validate grant_type. + if grantType != "authorization_code" { + profileerrors.FeatureNotSupported( + "only grant_type=authorization_code is supported", + "grant_type="+grantType, + ).Write(w, http.StatusBadRequest) + return + } + + // 2. Validate client exists (basic check; secret auth delegated to future work). + if _, ok := h.ClientConfig[clientID]; !ok { + profileerrors.InvalidProfileUsage("unknown client_id", "client_id"). + Write(w, http.StatusBadRequest) + return + } + + // 3. Look up PKCE session. + sess, ok := h.Sessions.Get(code) + if !ok { + profileerrors.InvalidProfileUsage( + "authorization code not found or expired", + "code", + ).Write(w, http.StatusBadRequest) + return + } + + // Verify client_id matches the session. + if sess.ClientID != clientID { + profileerrors.InvalidProfileUsage( + "client_id does not match the authorization code", + "client_id", + ).Write(w, http.StatusBadRequest) + return + } + + // 4. Verify PKCE code_verifier. + if !verifyPKCE(codeVerifier, sess.PKCEChallenge) { + profileerrors.InvalidProfileUsage( + "code_verifier does not match code_challenge", + "code_verifier", + ).Write(w, http.StatusBadRequest) + return + } + + // 5. Look up user. + user, err := h.Users.LookupUser(ctx, sess.Username) + if err != nil { + http.Error(w, "user not found", http.StatusInternalServerError) + return + } + + // 6. Build JWT claims. + now := time.Now() + exp := now.Add(h.TokenLifetime) + + claims := map[string]interface{}{ + "iss": h.Issuer, + "sub": user.ID, + "aud": clientID, + "exp": exp.Unix(), + "iat": now.Unix(), + } + + scopeSet := make(map[string]bool) + for _, s := range sess.Scopes { + scopeSet[s] = true + } + + if scopeSet["profile"] { + claims["preferred_username"] = user.Username + } + if scopeSet["email"] { + claims["email"] = user.Email + } + if scopeSet["groups"] { + claims["groups"] = user.Groups + } + + // 7. Sign JWT with RSA-SHA256. + kid := "key-1" // static kid for v0.1 + jwtToken, err := buildJWT(claims, kid, h.SigningKey) + if err != nil { + http.Error(w, "failed to build JWT", http.StatusInternalServerError) + return + } + + // 8. Delete used PKCE session (prevent replay). + h.Sessions.Delete(code) + + // 9. Build response. + resp := tokenResponse{ + AccessToken: jwtToken, + TokenType: "Bearer", + ExpiresIn: int(h.TokenLifetime.Seconds()), + IDToken: jwtToken, + } + + // 10. Emit token_issued telemetry. + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventTokenIssued, + ClientID: clientID, + Endpoint: "/token", + Result: "success", + Scopes: sess.Scopes, + GrantType: grantType, + }) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) +} + +// --------------------------------------------------------------------------- +// PKCE verification +// --------------------------------------------------------------------------- + +// verifyPKCE checks BASE64URL(SHA256(verifier)) == challenge (S256 method). +func verifyPKCE(verifier, challenge string) bool { + h := sha256.New() + h.Write([]byte(verifier)) + computed := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + return computed == challenge +} + +// --------------------------------------------------------------------------- +// JWT construction (stdlib only — no external JWT library) +// --------------------------------------------------------------------------- + +type jwtHeader struct { + Alg string `json:"alg"` + Typ string `json:"typ"` + Kid string `json:"kid"` +} + +// buildJWT constructs and signs a JWT using RSA-SHA256 with the standard library. +// Format: base64url(header) + "." + base64url(payload) + "." + base64url(signature) +func buildJWT(claims map[string]interface{}, kid string, key *rsa.PrivateKey) (string, error) { + // Header. + hdr := jwtHeader{Alg: "RS256", Typ: "JWT", Kid: kid} + hdrJSON, err := json.Marshal(hdr) + if err != nil { + return "", err + } + hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON) + + // Payload. + payloadJSON, err := json.Marshal(claims) + if err != nil { + return "", err + } + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) + + // Signing input. + signingInput := hdrB64 + "." + payloadB64 + + // Digest. + digest := sha256.Sum256([]byte(signingInput)) + + // Sign with PKCS1v15 / SHA256. + sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest[:]) + if err != nil { + return "", err + } + sigB64 := base64.RawURLEncoding.EncodeToString(sig) + + return strings.Join([]string{hdrB64, payloadB64, sigB64}, "."), nil +} diff --git a/src/internal/server/oidc/token_test.go b/src/internal/server/oidc/token_test.go new file mode 100644 index 0000000..2d47981 --- /dev/null +++ b/src/internal/server/oidc/token_test.go @@ -0,0 +1,496 @@ +package oidc_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "keycape/internal/domain" + profileerrors "keycape/internal/errors" + "keycape/internal/server/oidc" + "keycape/internal/server/telemetry" +) + +// --------------------------------------------------------------------------- +// Mock UserRepository +// --------------------------------------------------------------------------- + +type mockUserRepo struct { + users map[string]*domain.User +} + +func (m *mockUserRepo) LookupUser(_ context.Context, username string) (*domain.User, error) { + u, ok := m.users[username] + if !ok { + return nil, domain.ErrUserNotFound + } + return u, nil +} + +func (m *mockUserRepo) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) { + return nil, nil +} + +func (m *mockUserRepo) ValidatePassword(_ context.Context, _, _ string) (bool, error) { + return false, nil +} + +// --------------------------------------------------------------------------- +// PKCE helpers +// --------------------------------------------------------------------------- + +// makeVerifierAndChallenge returns a code_verifier and its S256 code_challenge. +func makeVerifierAndChallenge() (verifier, challenge string) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic(err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + h := sha256.New() + h.Write([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + return +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +func newTokenHandler(t *testing.T, sessions *oidc.SessionStore, users domain.UserRepository) (*oidc.TokenHandler, *rsa.PrivateKey) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + emitter := &captureEmitter{} + h := &oidc.TokenHandler{ + ClientConfig: testClient(), + Sessions: sessions, + Users: users, + SigningKey: key, + Issuer: "https://auth.netkingdom.local", + TokenLifetime: 15 * time.Minute, + Emitter: emitter, + } + return h, key +} + +func tokenRequest(params url.Values) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/token", + strings.NewReader(params.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req +} + +func seededSession(sessions *oidc.SessionStore, verifier string) (code string) { + challenge := s256Challenge(verifier) + sess := &oidc.PKCESession{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + PKCEChallenge: challenge, + PKCEChallengeMethod: "S256", + State: "state1", + Username: "alice", + Scopes: []string{"openid", "profile", "email", "groups"}, + ExpiresAt: time.Now().Add(10 * time.Minute), + } + return sessions.Create(sess) +} + +func s256Challenge(verifier string) string { + h := sha256.New() + h.Write([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) +} + +func decodeTokenResponse(t *testing.T, body string) map[string]interface{} { + t.Helper() + var m map[string]interface{} + if err := json.Unmarshal([]byte(body), &m); err != nil { + t.Fatalf("could not decode token response: %v (body: %q)", err, body) + } + return m +} + +func parseJWTPayload(t *testing.T, token string) map[string]interface{} { + t.Helper() + parts := strings.Split(token, ".") + if len(parts) != 3 { + t.Fatalf("expected 3 JWT parts, got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decode JWT payload: %v", err) + } + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + t.Fatalf("unmarshal JWT payload: %v", err) + } + return claims +} + +func aliceUser() *domain.User { + return &domain.User{ + ID: "user-alice", + Username: "alice", + Email: "alice@example.com", + Groups: []string{"admin", "users"}, + Enabled: true, + } +} + +// --------------------------------------------------------------------------- +// T07 Token Endpoint Tests +// --------------------------------------------------------------------------- + +func TestTokenHandler_ValidExchange_ReturnsJWT(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, verifier) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + + resp := decodeTokenResponse(t, w.Body.String()) + if _, ok := resp["access_token"]; !ok { + t.Error("missing access_token") + } + if _, ok := resp["id_token"]; !ok { + t.Error("missing id_token") + } + if resp["token_type"] != "Bearer" { + t.Errorf("expected token_type Bearer, got %v", resp["token_type"]) + } + if _, ok := resp["expires_in"]; !ok { + t.Error("missing expires_in") + } +} + +func TestTokenHandler_WrongGrantType_FeatureNotSupported(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{} + + h, _ := newTokenHandler(t, sessions, users) + + params := url.Values{ + "grant_type": []string{"client_credentials"}, + "client_id": []string{"test-client"}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrFeatureNotSupported { + t.Errorf("expected feature_not_supported_by_profile, got %q", errType) + } +} + +func TestTokenHandler_PKCEMismatch_InvalidProfileUsage(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + realVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, realVerifier) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{"wrong-verifier-that-does-not-match"}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +func TestTokenHandler_CodeNotFound_InvalidProfileUsage(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{} + + h, _ := newTokenHandler(t, sessions, users) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{"no-such-code"}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{"any-verifier"}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } + errType := decodeProfileError(t, w.Body.String()) + if errType != profileerrors.ErrInvalidProfileUsage { + t.Errorf("expected invalid_profile_usage, got %q", errType) + } +} + +func TestTokenHandler_JWTClaims_CorrectSubAndIssuer(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, verifier) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + resp := decodeTokenResponse(t, w.Body.String()) + idToken, ok := resp["id_token"].(string) + if !ok { + t.Fatal("id_token is not a string") + } + + claims := parseJWTPayload(t, idToken) + + if claims["sub"] != "user-alice" { + t.Errorf("sub: expected user-alice, got %v", claims["sub"]) + } + if claims["iss"] != "https://auth.netkingdom.local" { + t.Errorf("iss: expected https://auth.netkingdom.local, got %v", claims["iss"]) + } + if claims["aud"] != "test-client" { + t.Errorf("aud: expected test-client, got %v", claims["aud"]) + } +} + +func TestTokenHandler_ScopeFiltering_ProfileScope(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + // Seed session with only openid scope (no email, no groups). + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + challenge := s256Challenge(verifier) + sess := &oidc.PKCESession{ + ClientID: "test-client", + RedirectURI: "https://app.example.com/callback", + PKCEChallenge: challenge, + PKCEChallengeMethod: "S256", + Username: "alice", + Scopes: []string{"openid"}, // no profile/email/groups + ExpiresAt: time.Now().Add(10 * time.Minute), + } + code := sessions.Create(sess) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + resp := decodeTokenResponse(t, w.Body.String()) + idToken := resp["id_token"].(string) + claims := parseJWTPayload(t, idToken) + + // Without profile scope, preferred_username must not be present. + if _, ok := claims["preferred_username"]; ok { + t.Error("preferred_username must be absent when profile scope is not granted") + } + // Without email scope, email must not be present. + if _, ok := claims["email"]; ok { + t.Error("email must be absent when email scope is not granted") + } + // Without groups scope, groups must not be present. + if _, ok := claims["groups"]; ok { + t.Error("groups must be absent when groups scope is not granted") + } +} + +func TestTokenHandler_ScopeFiltering_AllScopes(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, verifier) // has openid, profile, email, groups + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + resp := decodeTokenResponse(t, w.Body.String()) + idToken := resp["id_token"].(string) + claims := parseJWTPayload(t, idToken) + + if claims["preferred_username"] != "alice" { + t.Errorf("preferred_username: expected alice, got %v", claims["preferred_username"]) + } + if claims["email"] != "alice@example.com" { + t.Errorf("email: expected alice@example.com, got %v", claims["email"]) + } + if _, ok := claims["groups"]; !ok { + t.Error("groups claim must be present when groups scope is granted") + } +} + +func TestTokenHandler_TokenIssuedTelemetry(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + capture := &captureEmitter{} + h := &oidc.TokenHandler{ + ClientConfig: testClient(), + Sessions: sessions, + Users: users, + SigningKey: key, + Issuer: "https://auth.netkingdom.local", + TokenLifetime: 15 * time.Minute, + Emitter: capture, + } + + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, verifier) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + found := false + for _, ev := range capture.events { + if ev.EventType == telemetry.EventTokenIssued { + found = true + break + } + } + if !found { + t.Error("expected token_issued telemetry event") + } +} + +func TestTokenHandler_CodeDeletedAfterUse(t *testing.T) { + sessions := oidc.NewSessionStore() + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + + h, _ := newTokenHandler(t, sessions, users) + + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + code := seededSession(sessions, verifier) + + params := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{"test-client"}, + "redirect_uri": []string{"https://app.example.com/callback"}, + "code_verifier": []string{verifier}, + } + + // First use — should succeed. + req := tokenRequest(params) + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("first use: expected 200, got %d", w.Code) + } + + // Second use — code should be gone. + req2 := tokenRequest(params) + w2 := httptest.NewRecorder() + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusBadRequest { + t.Errorf("second use: expected 400 (code replay), got %d", w2.Code) + } +}