generated from coulomb/repo-seed
feat: implement T06, T07 — authorization endpoint, token endpoint
- T06: /authorize with full PKCE validation, Authelia delegation, MFA check - T07: /token with RS256 JWT issuance (stdlib only), PKCE verification, scope-filtered claims 50 OIDC tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
320
src/internal/server/oidc/authorize.go
Normal file
320
src/internal/server/oidc/authorize.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
profileerrors "keycape/internal/errors"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// PendingState holds the authorization request parameters while the user is
|
||||
// being authenticated by the upstream provider (e.g. Authelia). It is keyed
|
||||
// by the opaque state value that is round-tripped through the upstream.
|
||||
type PendingState struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
PKCEChallenge string
|
||||
PKCEChallengeMethod string
|
||||
State string
|
||||
Scopes []string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// pendingStateStore is a thread-safe map of state → PendingState.
|
||||
type pendingStateStore struct {
|
||||
mu sync.Mutex
|
||||
store map[string]*PendingState
|
||||
}
|
||||
|
||||
func newPendingStateStore() *pendingStateStore {
|
||||
return &pendingStateStore{store: make(map[string]*PendingState)}
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Store(state string, ps *PendingState) {
|
||||
p.mu.Lock()
|
||||
p.store[state] = ps
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Load(state string) (*PendingState, bool) {
|
||||
p.mu.Lock()
|
||||
ps, ok := p.store[state]
|
||||
p.mu.Unlock()
|
||||
return ps, ok
|
||||
}
|
||||
|
||||
func (p *pendingStateStore) Delete(state string) {
|
||||
p.mu.Lock()
|
||||
delete(p.store, state)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// AuthorizeHandler implements GET /authorize and GET /authorize/callback.
|
||||
type AuthorizeHandler struct {
|
||||
ClientConfig map[string]*domain.Client
|
||||
Auth domain.AuthProvider
|
||||
MFA domain.MFAProvider
|
||||
Sessions *SessionStore
|
||||
Emitter telemetry.Emitter
|
||||
|
||||
pending *pendingStateStore
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// PendingStates returns the underlying pending-state store so tests can seed it.
|
||||
func (h *AuthorizeHandler) PendingStates() *pendingStateStore {
|
||||
h.init()
|
||||
return h.pending
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) init() {
|
||||
h.once.Do(func() {
|
||||
if h.pending == nil {
|
||||
h.pending = newPendingStateStore()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ServeHTTP dispatches to the authorize or callback handler based on path.
|
||||
func (h *AuthorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.init()
|
||||
if strings.HasSuffix(r.URL.Path, "/callback") {
|
||||
h.ServeHTTPCallback(w, r)
|
||||
return
|
||||
}
|
||||
h.serveAuthorize(w, r)
|
||||
}
|
||||
|
||||
// serveAuthorize handles the initial GET /authorize request.
|
||||
func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
q := r.URL.Query()
|
||||
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
responseType := q.Get("response_type")
|
||||
scope := q.Get("scope")
|
||||
state := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
|
||||
// Emit auth_start telemetry immediately.
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthStart,
|
||||
ClientID: clientID,
|
||||
Endpoint: "/authorize",
|
||||
Result: "pending",
|
||||
})
|
||||
|
||||
// 1. Validate client_id.
|
||||
client, ok := h.ClientConfig[clientID]
|
||||
if !ok {
|
||||
profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
|
||||
Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Validate redirect_uri — check for wildcards first, then exact match.
|
||||
for _, registered := range client.RedirectURIs {
|
||||
if strings.ContainsAny(registered, "*?") {
|
||||
profileerrors.RejectedForSafety(
|
||||
"wildcard redirect URIs are not permitted",
|
||||
"redirect_uri",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !uriRegistered(client.RedirectURIs, redirectURI) {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"redirect_uri does not match any registered URI",
|
||||
"redirect_uri",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Validate response_type.
|
||||
if responseType != "code" {
|
||||
profileerrors.FeatureNotSupported(
|
||||
"only response_type=code is supported",
|
||||
"response_type="+responseType,
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Validate scope contains openid.
|
||||
if !scopeContains(scope, "openid") {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"scope must include openid",
|
||||
"scope",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Validate code_challenge is present.
|
||||
if codeChallenge == "" {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"code_challenge is required (PKCE S256)",
|
||||
"code_challenge",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Validate code_challenge_method.
|
||||
if codeChallengeMethod == "plain" {
|
||||
profileerrors.RejectedForSafety(
|
||||
"code_challenge_method=plain is rejected for security; use S256",
|
||||
"code_challenge_method",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "S256" {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"code_challenge_method must be S256",
|
||||
"code_challenge_method",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store pending state so the callback can reconstruct the session.
|
||||
h.pending.Store(state, &PendingState{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
PKCEChallenge: codeChallenge,
|
||||
PKCEChallengeMethod: codeChallengeMethod,
|
||||
State: state,
|
||||
Scopes: strings.Fields(scope),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
|
||||
// Delegate to Auth provider.
|
||||
authURL, err := h.Auth.AuthorizeURL(ctx, domain.AuthRequest{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
State: state,
|
||||
Scopes: strings.Fields(scope),
|
||||
PKCEChallenge: codeChallenge,
|
||||
PKCEChallengeMethod: codeChallengeMethod,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "upstream auth provider error", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// ServeHTTPCallback handles GET /authorize/callback.
|
||||
func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) {
|
||||
h.init()
|
||||
ctx := r.Context()
|
||||
q := r.URL.Query()
|
||||
|
||||
state := q.Get("state")
|
||||
code := q.Get("code")
|
||||
mfaToken := q.Get("mfa_token")
|
||||
|
||||
// Recover pending state keyed by state param.
|
||||
ps, ok := h.pending.Load(state)
|
||||
if !ok {
|
||||
http.Error(w, "unknown or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if time.Now().After(ps.ExpiresAt) {
|
||||
h.pending.Delete(state)
|
||||
http.Error(w, "authorization request expired", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h.pending.Delete(state)
|
||||
|
||||
// Handle upstream callback.
|
||||
result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{
|
||||
Code: code,
|
||||
State: state,
|
||||
})
|
||||
if err != nil {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "auth_failed",
|
||||
})
|
||||
http.Error(w, "authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Check MFA requirement.
|
||||
mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username)
|
||||
if err != nil {
|
||||
http.Error(w, "mfa check error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if mfaRequired {
|
||||
if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "mfa_failed",
|
||||
})
|
||||
http.Error(w, "MFA validation failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Generate authorization code and store PKCE session.
|
||||
sess := &PKCESession{
|
||||
ClientID: ps.ClientID,
|
||||
RedirectURI: ps.RedirectURI,
|
||||
PKCEChallenge: ps.PKCEChallenge,
|
||||
PKCEChallengeMethod: ps.PKCEChallengeMethod,
|
||||
State: state,
|
||||
Username: result.Username,
|
||||
Scopes: ps.Scopes,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
authCode := h.Sessions.Create(sess)
|
||||
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthSuccess,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "success",
|
||||
Scopes: ps.Scopes,
|
||||
})
|
||||
|
||||
// Redirect to client with code and state.
|
||||
redirectTo := ps.RedirectURI + "?code=" + authCode + "&state=" + state
|
||||
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func uriRegistered(registered []string, target string) bool {
|
||||
for _, u := range registered {
|
||||
if u == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func scopeContains(scope, want string) bool {
|
||||
for _, s := range strings.Fields(scope) {
|
||||
if s == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
565
src/internal/server/oidc/authorize_test.go
Normal file
565
src/internal/server/oidc/authorize_test.go
Normal file
@@ -0,0 +1,565 @@
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
profileerrors "keycape/internal/errors"
|
||||
"keycape/internal/server/oidc"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock implementations
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockAuthProvider implements domain.AuthProvider.
|
||||
type mockAuthProvider struct {
|
||||
authorizeURL string
|
||||
authorizeErr error
|
||||
|
||||
callbackResult *domain.AuthResult
|
||||
callbackErr error
|
||||
}
|
||||
|
||||
func (m *mockAuthProvider) AuthorizeURL(_ context.Context, _ domain.AuthRequest) (string, error) {
|
||||
if m.authorizeErr != nil {
|
||||
return "", m.authorizeErr
|
||||
}
|
||||
return m.authorizeURL, nil
|
||||
}
|
||||
|
||||
func (m *mockAuthProvider) HandleCallback(_ context.Context, _ domain.CallbackParams) (*domain.AuthResult, error) {
|
||||
return m.callbackResult, m.callbackErr
|
||||
}
|
||||
|
||||
// mockMFAProvider implements domain.MFAProvider.
|
||||
type mockMFAProvider struct {
|
||||
required bool
|
||||
requiredErr error
|
||||
|
||||
validateErr error
|
||||
}
|
||||
|
||||
func (m *mockMFAProvider) CheckMFARequired(_ context.Context, _ string) (bool, error) {
|
||||
return m.required, m.requiredErr
|
||||
}
|
||||
|
||||
func (m *mockMFAProvider) ValidateMFAToken(_ context.Context, _, _ string) error {
|
||||
return m.validateErr
|
||||
}
|
||||
|
||||
// captureEmitter captures the last emitted event.
|
||||
type captureEmitter struct {
|
||||
events []telemetry.Event
|
||||
}
|
||||
|
||||
func (c *captureEmitter) Emit(_ context.Context, ev telemetry.Event) {
|
||||
c.events = append(c.events, ev)
|
||||
}
|
||||
|
||||
func (c *captureEmitter) last() telemetry.Event {
|
||||
if len(c.events) == 0 {
|
||||
return telemetry.Event{}
|
||||
}
|
||||
return c.events[len(c.events)-1]
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func testClient() map[string]*domain.Client {
|
||||
return map[string]*domain.Client{
|
||||
"test-client": {
|
||||
ClientID: "test-client",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
AllowedScopes: []string{"openid", "profile", "email"},
|
||||
ClientType: "public",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newAuthorizeHandler(auth domain.AuthProvider, mfa domain.MFAProvider, emitter telemetry.Emitter) *oidc.AuthorizeHandler {
|
||||
return &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: oidc.NewSessionStore(),
|
||||
Emitter: emitter,
|
||||
}
|
||||
}
|
||||
|
||||
func validAuthorizeParams() url.Values {
|
||||
return url.Values{
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"openid profile"},
|
||||
"state": []string{"random-state"},
|
||||
"code_challenge": []string{"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"},
|
||||
"code_challenge_method": []string{"S256"},
|
||||
}
|
||||
}
|
||||
|
||||
func authorizeRequest(params url.Values) *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/authorize?"+params.Encode(), nil)
|
||||
}
|
||||
|
||||
func decodeProfileError(t *testing.T, body string) profileerrors.ErrorType {
|
||||
t.Helper()
|
||||
var pe profileerrors.ProfileError
|
||||
if err := json.Unmarshal([]byte(body), &pe); err != nil {
|
||||
t.Fatalf("could not decode ProfileError: %v (body: %q)", err, body)
|
||||
}
|
||||
return pe.Error
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// T06 Authorization Endpoint Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAuthorizeHandler_ValidRequest_RedirectsToAuthelia(t *testing.T) {
|
||||
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth?state=xyz"}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
req := authorizeRequest(validAuthorizeParams())
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "https://authelia.example.com/auth?state=xyz" {
|
||||
t.Errorf("expected redirect to Authelia, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_EmitsAuthStart(t *testing.T) {
|
||||
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth"}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
req := authorizeRequest(validAuthorizeParams())
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
found := false
|
||||
for _, ev := range emitter.events {
|
||||
if ev.EventType == telemetry.EventAuthStart {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected auth_start telemetry event to be emitted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_MissingCodeChallenge_InvalidProfileUsage(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Del("code_challenge")
|
||||
params.Del("code_challenge_method")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_WildcardRedirectURI_RejectedForSafety(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
clients := map[string]*domain.Client{
|
||||
"wildcard-client": {
|
||||
ClientID: "wildcard-client",
|
||||
RedirectURIs: []string{"https://app.example.com/*"},
|
||||
ClientType: "public",
|
||||
},
|
||||
}
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: clients,
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: oidc.NewSessionStore(),
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("client_id", "wildcard-client")
|
||||
params.Set("redirect_uri", "https://app.example.com/anything")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrRejectedForSafety {
|
||||
t.Errorf("expected rejected_for_profile_safety, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_UnknownClient_InvalidProfileUsage(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("client_id", "no-such-client")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_WrongResponseType_FeatureNotSupported(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("response_type", "token")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrFeatureNotSupported {
|
||||
t.Errorf("expected feature_not_supported_by_profile, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_MissingOpenIDScope_InvalidProfileUsage(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("scope", "profile email")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_PlainCodeChallengeMethod_RejectedForSafety(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("code_challenge_method", "plain")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrRejectedForSafety {
|
||||
t.Errorf("expected rejected_for_profile_safety, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_UnknownRedirectURI_InvalidProfileUsage(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
params := validAuthorizeParams()
|
||||
params.Set("redirect_uri", "https://evil.example.com/callback")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Callback tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAuthorizeCallback_Success_RedirectsWithCode(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||
}
|
||||
mfa := &mockMFAProvider{required: false}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/authorize/callback?code=authelia-code&state=random-state", nil)
|
||||
// Simulate that there is an ongoing PKCE flow stored in query param forwarding
|
||||
// The callback needs the original client context. We store it via a pre-seeded
|
||||
// pending session keyed by state.
|
||||
// For the callback handler, we expect it to look up the pending state by the
|
||||
// "state" parameter that was originally embedded. We seed the pending state.
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "random-state",
|
||||
Scopes: []string{"openid", "profile"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
parsed, err := url.Parse(loc)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid Location header: %v", err)
|
||||
}
|
||||
if parsed.Query().Get("code") == "" {
|
||||
t.Error("expected code param in redirect, got empty")
|
||||
}
|
||||
if parsed.Query().Get("state") != "random-state" {
|
||||
t.Errorf("expected state=random-state, got %q", parsed.Query().Get("state"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_MFAFailed_AuthFailure(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||
}
|
||||
mfa := &mockMFAProvider{
|
||||
required: true,
|
||||
validateErr: domain.ErrMFAFailed,
|
||||
}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "random-state",
|
||||
Scopes: []string{"openid"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/authorize/callback?code=authelia-code&state=random-state&mfa_token=wrong", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
found := false
|
||||
for _, ev := range emitter.events {
|
||||
if ev.EventType == telemetry.EventAuthFailure {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected auth_failure telemetry event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_AuthProviderFailed_AuthFailure(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackErr: domain.ErrAuthFailed,
|
||||
}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
State: "random-state",
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/authorize/callback?code=bad&state=random-state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
found := false
|
||||
for _, ev := range emitter.events {
|
||||
if ev.EventType == telemetry.EventAuthFailure {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected auth_failure telemetry event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_EmitsAuthSuccess(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackResult: &domain.AuthResult{Username: "bob"},
|
||||
}
|
||||
mfa := &mockMFAProvider{required: false}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
h.PendingStates().Store("s1", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: "abc",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "s1",
|
||||
Scopes: []string{"openid"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/authorize/callback?code=c&state=s1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
found := false
|
||||
for _, ev := range emitter.events {
|
||||
if ev.EventType == telemetry.EventAuthSuccess {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected auth_success telemetry event, got events: %v", emitter.events)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ServeHTTP dispatch
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAuthorizeHandler_ServeHTTP_DispatchesToCallback(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||
}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
|
||||
// A request to /authorize/callback should not be treated as the initial
|
||||
// authorize request and must not require PKCE params.
|
||||
req := httptest.NewRequest(http.MethodGet, "/authorize/callback?code=x&state=y", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
// Without a seeded pending state for "y", the callback returns an error.
|
||||
// The important thing is that it is NOT a redirect to Authelia.
|
||||
if w.Code == http.StatusFound {
|
||||
loc := w.Header().Get("Location")
|
||||
if strings.Contains(loc, "authelia") {
|
||||
t.Error("callback path must not redirect to Authelia")
|
||||
}
|
||||
}
|
||||
}
|
||||
79
src/internal/server/oidc/session.go
Normal file
79
src/internal/server/oidc/session.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PKCESession stores the in-flight authorization state server-side.
|
||||
type PKCESession struct {
|
||||
Code string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
PKCEChallenge string // S256 challenge
|
||||
PKCEChallengeMethod string // always "S256"
|
||||
State string
|
||||
Username string // set after auth
|
||||
Scopes []string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// SessionStore is an in-memory PKCE session store.
|
||||
type SessionStore struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*PKCESession // keyed by code
|
||||
}
|
||||
|
||||
// NewSessionStore returns an initialised, empty SessionStore.
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{
|
||||
sessions: make(map[string]*PKCESession),
|
||||
}
|
||||
}
|
||||
|
||||
// Create stores the session and returns the generated authorization code.
|
||||
func (s *SessionStore) Create(sess *PKCESession) string {
|
||||
code := generateCode()
|
||||
sess.Code = code
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[code] = sess
|
||||
s.mu.Unlock()
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
// Get retrieves a session by code. Returns false if not found or expired.
|
||||
func (s *SessionStore) Get(code string) (*PKCESession, bool) {
|
||||
s.mu.Lock()
|
||||
sess, ok := s.sessions[code]
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(sess.ExpiresAt) {
|
||||
s.Delete(code)
|
||||
return nil, false
|
||||
}
|
||||
return sess, true
|
||||
}
|
||||
|
||||
// Delete removes a session by code. No-op if the code is not present.
|
||||
func (s *SessionStore) Delete(code string) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, code)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// generateCode returns a cryptographically random, URL-safe string suitable
|
||||
// for use as an authorization code.
|
||||
func generateCode() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("oidc: failed to generate random code: " + err.Error())
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
219
src/internal/server/oidc/token.go
Normal file
219
src/internal/server/oidc/token.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
profileerrors "keycape/internal/errors"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// TokenHandler implements POST /token.
|
||||
type TokenHandler struct {
|
||||
ClientConfig map[string]*domain.Client
|
||||
Sessions *SessionStore
|
||||
Users domain.UserRepository
|
||||
SigningKey *rsa.PrivateKey
|
||||
Issuer string
|
||||
TokenLifetime time.Duration
|
||||
Emitter telemetry.Emitter
|
||||
}
|
||||
|
||||
// tokenResponse is the JSON body returned on a successful token exchange.
|
||||
type tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
// ServeHTTP handles POST /token.
|
||||
func (h *TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.FormValue("grant_type")
|
||||
clientID := r.FormValue("client_id")
|
||||
code := r.FormValue("code")
|
||||
codeVerifier := r.FormValue("code_verifier")
|
||||
|
||||
// 1. Validate grant_type.
|
||||
if grantType != "authorization_code" {
|
||||
profileerrors.FeatureNotSupported(
|
||||
"only grant_type=authorization_code is supported",
|
||||
"grant_type="+grantType,
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Validate client exists (basic check; secret auth delegated to future work).
|
||||
if _, ok := h.ClientConfig[clientID]; !ok {
|
||||
profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
|
||||
Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Look up PKCE session.
|
||||
sess, ok := h.Sessions.Get(code)
|
||||
if !ok {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"authorization code not found or expired",
|
||||
"code",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify client_id matches the session.
|
||||
if sess.ClientID != clientID {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"client_id does not match the authorization code",
|
||||
"client_id",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Verify PKCE code_verifier.
|
||||
if !verifyPKCE(codeVerifier, sess.PKCEChallenge) {
|
||||
profileerrors.InvalidProfileUsage(
|
||||
"code_verifier does not match code_challenge",
|
||||
"code_verifier",
|
||||
).Write(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Look up user.
|
||||
user, err := h.Users.LookupUser(ctx, sess.Username)
|
||||
if err != nil {
|
||||
http.Error(w, "user not found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Build JWT claims.
|
||||
now := time.Now()
|
||||
exp := now.Add(h.TokenLifetime)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"iss": h.Issuer,
|
||||
"sub": user.ID,
|
||||
"aud": clientID,
|
||||
"exp": exp.Unix(),
|
||||
"iat": now.Unix(),
|
||||
}
|
||||
|
||||
scopeSet := make(map[string]bool)
|
||||
for _, s := range sess.Scopes {
|
||||
scopeSet[s] = true
|
||||
}
|
||||
|
||||
if scopeSet["profile"] {
|
||||
claims["preferred_username"] = user.Username
|
||||
}
|
||||
if scopeSet["email"] {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
if scopeSet["groups"] {
|
||||
claims["groups"] = user.Groups
|
||||
}
|
||||
|
||||
// 7. Sign JWT with RSA-SHA256.
|
||||
kid := "key-1" // static kid for v0.1
|
||||
jwtToken, err := buildJWT(claims, kid, h.SigningKey)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to build JWT", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Delete used PKCE session (prevent replay).
|
||||
h.Sessions.Delete(code)
|
||||
|
||||
// 9. Build response.
|
||||
resp := tokenResponse{
|
||||
AccessToken: jwtToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(h.TokenLifetime.Seconds()),
|
||||
IDToken: jwtToken,
|
||||
}
|
||||
|
||||
// 10. Emit token_issued telemetry.
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventTokenIssued,
|
||||
ClientID: clientID,
|
||||
Endpoint: "/token",
|
||||
Result: "success",
|
||||
Scopes: sess.Scopes,
|
||||
GrantType: grantType,
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PKCE verification
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// verifyPKCE checks BASE64URL(SHA256(verifier)) == challenge (S256 method).
|
||||
func verifyPKCE(verifier, challenge string) bool {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(verifier))
|
||||
computed := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
return computed == challenge
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// JWT construction (stdlib only — no external JWT library)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type jwtHeader struct {
|
||||
Alg string `json:"alg"`
|
||||
Typ string `json:"typ"`
|
||||
Kid string `json:"kid"`
|
||||
}
|
||||
|
||||
// buildJWT constructs and signs a JWT using RSA-SHA256 with the standard library.
|
||||
// Format: base64url(header) + "." + base64url(payload) + "." + base64url(signature)
|
||||
func buildJWT(claims map[string]interface{}, kid string, key *rsa.PrivateKey) (string, error) {
|
||||
// Header.
|
||||
hdr := jwtHeader{Alg: "RS256", Typ: "JWT", Kid: kid}
|
||||
hdrJSON, err := json.Marshal(hdr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
|
||||
|
||||
// Payload.
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
// Signing input.
|
||||
signingInput := hdrB64 + "." + payloadB64
|
||||
|
||||
// Digest.
|
||||
digest := sha256.Sum256([]byte(signingInput))
|
||||
|
||||
// Sign with PKCS1v15 / SHA256.
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest[:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sigB64 := base64.RawURLEncoding.EncodeToString(sig)
|
||||
|
||||
return strings.Join([]string{hdrB64, payloadB64, sigB64}, "."), nil
|
||||
}
|
||||
496
src/internal/server/oidc/token_test.go
Normal file
496
src/internal/server/oidc/token_test.go
Normal file
@@ -0,0 +1,496 @@
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
profileerrors "keycape/internal/errors"
|
||||
"keycape/internal/server/oidc"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock UserRepository
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type mockUserRepo struct {
|
||||
users map[string]*domain.User
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) LookupUser(_ context.Context, username string) (*domain.User, error) {
|
||||
u, ok := m.users[username]
|
||||
if !ok {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) ValidatePassword(_ context.Context, _, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PKCE helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// makeVerifierAndChallenge returns a code_verifier and its S256 code_challenge.
|
||||
func makeVerifierAndChallenge() (verifier, challenge string) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
h := sha256.New()
|
||||
h.Write([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTokenHandler(t *testing.T, sessions *oidc.SessionStore, users domain.UserRepository) (*oidc.TokenHandler, *rsa.PrivateKey) {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
emitter := &captureEmitter{}
|
||||
h := &oidc.TokenHandler{
|
||||
ClientConfig: testClient(),
|
||||
Sessions: sessions,
|
||||
Users: users,
|
||||
SigningKey: key,
|
||||
Issuer: "https://auth.netkingdom.local",
|
||||
TokenLifetime: 15 * time.Minute,
|
||||
Emitter: emitter,
|
||||
}
|
||||
return h, key
|
||||
}
|
||||
|
||||
func tokenRequest(params url.Values) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/token",
|
||||
strings.NewReader(params.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
}
|
||||
|
||||
func seededSession(sessions *oidc.SessionStore, verifier string) (code string) {
|
||||
challenge := s256Challenge(verifier)
|
||||
sess := &oidc.PKCESession{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: challenge,
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "state1",
|
||||
Username: "alice",
|
||||
Scopes: []string{"openid", "profile", "email", "groups"},
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
return sessions.Create(sess)
|
||||
}
|
||||
|
||||
func s256Challenge(verifier string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func decodeTokenResponse(t *testing.T, body string) map[string]interface{} {
|
||||
t.Helper()
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &m); err != nil {
|
||||
t.Fatalf("could not decode token response: %v (body: %q)", err, body)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func parseJWTPayload(t *testing.T, token string) map[string]interface{} {
|
||||
t.Helper()
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 JWT parts, got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("decode JWT payload: %v", err)
|
||||
}
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
t.Fatalf("unmarshal JWT payload: %v", err)
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
||||
func aliceUser() *domain.User {
|
||||
return &domain.User{
|
||||
ID: "user-alice",
|
||||
Username: "alice",
|
||||
Email: "alice@example.com",
|
||||
Groups: []string{"admin", "users"},
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// T07 Token Endpoint Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestTokenHandler_ValidExchange_ReturnsJWT(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, verifier)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
resp := decodeTokenResponse(t, w.Body.String())
|
||||
if _, ok := resp["access_token"]; !ok {
|
||||
t.Error("missing access_token")
|
||||
}
|
||||
if _, ok := resp["id_token"]; !ok {
|
||||
t.Error("missing id_token")
|
||||
}
|
||||
if resp["token_type"] != "Bearer" {
|
||||
t.Errorf("expected token_type Bearer, got %v", resp["token_type"])
|
||||
}
|
||||
if _, ok := resp["expires_in"]; !ok {
|
||||
t.Error("missing expires_in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_WrongGrantType_FeatureNotSupported(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"client_credentials"},
|
||||
"client_id": []string{"test-client"},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrFeatureNotSupported {
|
||||
t.Errorf("expected feature_not_supported_by_profile, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_PKCEMismatch_InvalidProfileUsage(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
realVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, realVerifier)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{"wrong-verifier-that-does-not-match"},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_CodeNotFound_InvalidProfileUsage(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{"no-such-code"},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{"any-verifier"},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
errType := decodeProfileError(t, w.Body.String())
|
||||
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_JWTClaims_CorrectSubAndIssuer(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, verifier)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
resp := decodeTokenResponse(t, w.Body.String())
|
||||
idToken, ok := resp["id_token"].(string)
|
||||
if !ok {
|
||||
t.Fatal("id_token is not a string")
|
||||
}
|
||||
|
||||
claims := parseJWTPayload(t, idToken)
|
||||
|
||||
if claims["sub"] != "user-alice" {
|
||||
t.Errorf("sub: expected user-alice, got %v", claims["sub"])
|
||||
}
|
||||
if claims["iss"] != "https://auth.netkingdom.local" {
|
||||
t.Errorf("iss: expected https://auth.netkingdom.local, got %v", claims["iss"])
|
||||
}
|
||||
if claims["aud"] != "test-client" {
|
||||
t.Errorf("aud: expected test-client, got %v", claims["aud"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_ScopeFiltering_ProfileScope(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
// Seed session with only openid scope (no email, no groups).
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
challenge := s256Challenge(verifier)
|
||||
sess := &oidc.PKCESession{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: challenge,
|
||||
PKCEChallengeMethod: "S256",
|
||||
Username: "alice",
|
||||
Scopes: []string{"openid"}, // no profile/email/groups
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
code := sessions.Create(sess)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
resp := decodeTokenResponse(t, w.Body.String())
|
||||
idToken := resp["id_token"].(string)
|
||||
claims := parseJWTPayload(t, idToken)
|
||||
|
||||
// Without profile scope, preferred_username must not be present.
|
||||
if _, ok := claims["preferred_username"]; ok {
|
||||
t.Error("preferred_username must be absent when profile scope is not granted")
|
||||
}
|
||||
// Without email scope, email must not be present.
|
||||
if _, ok := claims["email"]; ok {
|
||||
t.Error("email must be absent when email scope is not granted")
|
||||
}
|
||||
// Without groups scope, groups must not be present.
|
||||
if _, ok := claims["groups"]; ok {
|
||||
t.Error("groups must be absent when groups scope is not granted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_ScopeFiltering_AllScopes(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, verifier) // has openid, profile, email, groups
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
resp := decodeTokenResponse(t, w.Body.String())
|
||||
idToken := resp["id_token"].(string)
|
||||
claims := parseJWTPayload(t, idToken)
|
||||
|
||||
if claims["preferred_username"] != "alice" {
|
||||
t.Errorf("preferred_username: expected alice, got %v", claims["preferred_username"])
|
||||
}
|
||||
if claims["email"] != "alice@example.com" {
|
||||
t.Errorf("email: expected alice@example.com, got %v", claims["email"])
|
||||
}
|
||||
if _, ok := claims["groups"]; !ok {
|
||||
t.Error("groups claim must be present when groups scope is granted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_TokenIssuedTelemetry(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
capture := &captureEmitter{}
|
||||
h := &oidc.TokenHandler{
|
||||
ClientConfig: testClient(),
|
||||
Sessions: sessions,
|
||||
Users: users,
|
||||
SigningKey: key,
|
||||
Issuer: "https://auth.netkingdom.local",
|
||||
TokenLifetime: 15 * time.Minute,
|
||||
Emitter: capture,
|
||||
}
|
||||
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, verifier)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, ev := range capture.events {
|
||||
if ev.EventType == telemetry.EventTokenIssued {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected token_issued telemetry event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandler_CodeDeletedAfterUse(t *testing.T) {
|
||||
sessions := oidc.NewSessionStore()
|
||||
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||
|
||||
h, _ := newTokenHandler(t, sessions, users)
|
||||
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code := seededSession(sessions, verifier)
|
||||
|
||||
params := url.Values{
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"code": []string{code},
|
||||
"client_id": []string{"test-client"},
|
||||
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||
"code_verifier": []string{verifier},
|
||||
}
|
||||
|
||||
// First use — should succeed.
|
||||
req := tokenRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("first use: expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Second use — code should be gone.
|
||||
req2 := tokenRequest(params)
|
||||
w2 := httptest.NewRecorder()
|
||||
h.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusBadRequest {
|
||||
t.Errorf("second use: expected 400 (code replay), got %d", w2.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user