generated from coulomb/repo-seed
feat: implement T06, T07 — authorization endpoint, token endpoint
- T06: /authorize with full PKCE validation, Authelia delegation, MFA check - T07: /token with RS256 JWT issuance (stdlib only), PKCE verification, scope-filtered claims 50 OIDC tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
320
src/internal/server/oidc/authorize.go
Normal file
320
src/internal/server/oidc/authorize.go
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keycape/internal/domain"
|
||||||
|
profileerrors "keycape/internal/errors"
|
||||||
|
"keycape/internal/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PendingState holds the authorization request parameters while the user is
|
||||||
|
// being authenticated by the upstream provider (e.g. Authelia). It is keyed
|
||||||
|
// by the opaque state value that is round-tripped through the upstream.
|
||||||
|
type PendingState struct {
|
||||||
|
ClientID string
|
||||||
|
RedirectURI string
|
||||||
|
PKCEChallenge string
|
||||||
|
PKCEChallengeMethod string
|
||||||
|
State string
|
||||||
|
Scopes []string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// pendingStateStore is a thread-safe map of state → PendingState.
|
||||||
|
type pendingStateStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
store map[string]*PendingState
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPendingStateStore() *pendingStateStore {
|
||||||
|
return &pendingStateStore{store: make(map[string]*PendingState)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pendingStateStore) Store(state string, ps *PendingState) {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.store[state] = ps
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pendingStateStore) Load(state string) (*PendingState, bool) {
|
||||||
|
p.mu.Lock()
|
||||||
|
ps, ok := p.store[state]
|
||||||
|
p.mu.Unlock()
|
||||||
|
return ps, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pendingStateStore) Delete(state string) {
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.store, state)
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeHandler implements GET /authorize and GET /authorize/callback.
|
||||||
|
type AuthorizeHandler struct {
|
||||||
|
ClientConfig map[string]*domain.Client
|
||||||
|
Auth domain.AuthProvider
|
||||||
|
MFA domain.MFAProvider
|
||||||
|
Sessions *SessionStore
|
||||||
|
Emitter telemetry.Emitter
|
||||||
|
|
||||||
|
pending *pendingStateStore
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// PendingStates returns the underlying pending-state store so tests can seed it.
|
||||||
|
func (h *AuthorizeHandler) PendingStates() *pendingStateStore {
|
||||||
|
h.init()
|
||||||
|
return h.pending
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthorizeHandler) init() {
|
||||||
|
h.once.Do(func() {
|
||||||
|
if h.pending == nil {
|
||||||
|
h.pending = newPendingStateStore()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP dispatches to the authorize or callback handler based on path.
|
||||||
|
func (h *AuthorizeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.init()
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/callback") {
|
||||||
|
h.ServeHTTPCallback(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.serveAuthorize(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveAuthorize handles the initial GET /authorize request.
|
||||||
|
func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
clientID := q.Get("client_id")
|
||||||
|
redirectURI := q.Get("redirect_uri")
|
||||||
|
responseType := q.Get("response_type")
|
||||||
|
scope := q.Get("scope")
|
||||||
|
state := q.Get("state")
|
||||||
|
codeChallenge := q.Get("code_challenge")
|
||||||
|
codeChallengeMethod := q.Get("code_challenge_method")
|
||||||
|
|
||||||
|
// Emit auth_start telemetry immediately.
|
||||||
|
h.Emitter.Emit(ctx, telemetry.Event{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
EventType: telemetry.EventAuthStart,
|
||||||
|
ClientID: clientID,
|
||||||
|
Endpoint: "/authorize",
|
||||||
|
Result: "pending",
|
||||||
|
})
|
||||||
|
|
||||||
|
// 1. Validate client_id.
|
||||||
|
client, ok := h.ClientConfig[clientID]
|
||||||
|
if !ok {
|
||||||
|
profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
|
||||||
|
Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Validate redirect_uri — check for wildcards first, then exact match.
|
||||||
|
for _, registered := range client.RedirectURIs {
|
||||||
|
if strings.ContainsAny(registered, "*?") {
|
||||||
|
profileerrors.RejectedForSafety(
|
||||||
|
"wildcard redirect URIs are not permitted",
|
||||||
|
"redirect_uri",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !uriRegistered(client.RedirectURIs, redirectURI) {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"redirect_uri does not match any registered URI",
|
||||||
|
"redirect_uri",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Validate response_type.
|
||||||
|
if responseType != "code" {
|
||||||
|
profileerrors.FeatureNotSupported(
|
||||||
|
"only response_type=code is supported",
|
||||||
|
"response_type="+responseType,
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Validate scope contains openid.
|
||||||
|
if !scopeContains(scope, "openid") {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"scope must include openid",
|
||||||
|
"scope",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Validate code_challenge is present.
|
||||||
|
if codeChallenge == "" {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"code_challenge is required (PKCE S256)",
|
||||||
|
"code_challenge",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Validate code_challenge_method.
|
||||||
|
if codeChallengeMethod == "plain" {
|
||||||
|
profileerrors.RejectedForSafety(
|
||||||
|
"code_challenge_method=plain is rejected for security; use S256",
|
||||||
|
"code_challenge_method",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if codeChallengeMethod != "S256" {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"code_challenge_method must be S256",
|
||||||
|
"code_challenge_method",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store pending state so the callback can reconstruct the session.
|
||||||
|
h.pending.Store(state, &PendingState{
|
||||||
|
ClientID: clientID,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
PKCEChallenge: codeChallenge,
|
||||||
|
PKCEChallengeMethod: codeChallengeMethod,
|
||||||
|
State: state,
|
||||||
|
Scopes: strings.Fields(scope),
|
||||||
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Delegate to Auth provider.
|
||||||
|
authURL, err := h.Auth.AuthorizeURL(ctx, domain.AuthRequest{
|
||||||
|
ClientID: clientID,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
State: state,
|
||||||
|
Scopes: strings.Fields(scope),
|
||||||
|
PKCEChallenge: codeChallenge,
|
||||||
|
PKCEChallengeMethod: codeChallengeMethod,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "upstream auth provider error", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(w, r, authURL, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTPCallback handles GET /authorize/callback.
|
||||||
|
func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.init()
|
||||||
|
ctx := r.Context()
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
state := q.Get("state")
|
||||||
|
code := q.Get("code")
|
||||||
|
mfaToken := q.Get("mfa_token")
|
||||||
|
|
||||||
|
// Recover pending state keyed by state param.
|
||||||
|
ps, ok := h.pending.Load(state)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "unknown or expired state", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Now().After(ps.ExpiresAt) {
|
||||||
|
h.pending.Delete(state)
|
||||||
|
http.Error(w, "authorization request expired", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.pending.Delete(state)
|
||||||
|
|
||||||
|
// Handle upstream callback.
|
||||||
|
result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{
|
||||||
|
Code: code,
|
||||||
|
State: state,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
h.Emitter.Emit(ctx, telemetry.Event{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
EventType: telemetry.EventAuthFailure,
|
||||||
|
ClientID: ps.ClientID,
|
||||||
|
Endpoint: "/authorize/callback",
|
||||||
|
Result: "failure",
|
||||||
|
ErrorType: "auth_failed",
|
||||||
|
})
|
||||||
|
http.Error(w, "authentication failed", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check MFA requirement.
|
||||||
|
mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "mfa check error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if mfaRequired {
|
||||||
|
if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil {
|
||||||
|
h.Emitter.Emit(ctx, telemetry.Event{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
EventType: telemetry.EventAuthFailure,
|
||||||
|
ClientID: ps.ClientID,
|
||||||
|
Endpoint: "/authorize/callback",
|
||||||
|
Result: "failure",
|
||||||
|
ErrorType: "mfa_failed",
|
||||||
|
})
|
||||||
|
http.Error(w, "MFA validation failed", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate authorization code and store PKCE session.
|
||||||
|
sess := &PKCESession{
|
||||||
|
ClientID: ps.ClientID,
|
||||||
|
RedirectURI: ps.RedirectURI,
|
||||||
|
PKCEChallenge: ps.PKCEChallenge,
|
||||||
|
PKCEChallengeMethod: ps.PKCEChallengeMethod,
|
||||||
|
State: state,
|
||||||
|
Username: result.Username,
|
||||||
|
Scopes: ps.Scopes,
|
||||||
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
authCode := h.Sessions.Create(sess)
|
||||||
|
|
||||||
|
h.Emitter.Emit(ctx, telemetry.Event{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
EventType: telemetry.EventAuthSuccess,
|
||||||
|
ClientID: ps.ClientID,
|
||||||
|
Endpoint: "/authorize/callback",
|
||||||
|
Result: "success",
|
||||||
|
Scopes: ps.Scopes,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redirect to client with code and state.
|
||||||
|
redirectTo := ps.RedirectURI + "?code=" + authCode + "&state=" + state
|
||||||
|
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func uriRegistered(registered []string, target string) bool {
|
||||||
|
for _, u := range registered {
|
||||||
|
if u == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func scopeContains(scope, want string) bool {
|
||||||
|
for _, s := range strings.Fields(scope) {
|
||||||
|
if s == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
565
src/internal/server/oidc/authorize_test.go
Normal file
565
src/internal/server/oidc/authorize_test.go
Normal file
@@ -0,0 +1,565 @@
|
|||||||
|
package oidc_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keycape/internal/domain"
|
||||||
|
profileerrors "keycape/internal/errors"
|
||||||
|
"keycape/internal/server/oidc"
|
||||||
|
"keycape/internal/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock implementations
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// mockAuthProvider implements domain.AuthProvider.
|
||||||
|
type mockAuthProvider struct {
|
||||||
|
authorizeURL string
|
||||||
|
authorizeErr error
|
||||||
|
|
||||||
|
callbackResult *domain.AuthResult
|
||||||
|
callbackErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthProvider) AuthorizeURL(_ context.Context, _ domain.AuthRequest) (string, error) {
|
||||||
|
if m.authorizeErr != nil {
|
||||||
|
return "", m.authorizeErr
|
||||||
|
}
|
||||||
|
return m.authorizeURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthProvider) HandleCallback(_ context.Context, _ domain.CallbackParams) (*domain.AuthResult, error) {
|
||||||
|
return m.callbackResult, m.callbackErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockMFAProvider implements domain.MFAProvider.
|
||||||
|
type mockMFAProvider struct {
|
||||||
|
required bool
|
||||||
|
requiredErr error
|
||||||
|
|
||||||
|
validateErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMFAProvider) CheckMFARequired(_ context.Context, _ string) (bool, error) {
|
||||||
|
return m.required, m.requiredErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMFAProvider) ValidateMFAToken(_ context.Context, _, _ string) error {
|
||||||
|
return m.validateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureEmitter captures the last emitted event.
|
||||||
|
type captureEmitter struct {
|
||||||
|
events []telemetry.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *captureEmitter) Emit(_ context.Context, ev telemetry.Event) {
|
||||||
|
c.events = append(c.events, ev)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *captureEmitter) last() telemetry.Event {
|
||||||
|
if len(c.events) == 0 {
|
||||||
|
return telemetry.Event{}
|
||||||
|
}
|
||||||
|
return c.events[len(c.events)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Test helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func testClient() map[string]*domain.Client {
|
||||||
|
return map[string]*domain.Client{
|
||||||
|
"test-client": {
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||||
|
AllowedScopes: []string{"openid", "profile", "email"},
|
||||||
|
ClientType: "public",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthorizeHandler(auth domain.AuthProvider, mfa domain.MFAProvider, emitter telemetry.Emitter) *oidc.AuthorizeHandler {
|
||||||
|
return &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: oidc.NewSessionStore(),
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validAuthorizeParams() url.Values {
|
||||||
|
return url.Values{
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"scope": []string{"openid profile"},
|
||||||
|
"state": []string{"random-state"},
|
||||||
|
"code_challenge": []string{"E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"},
|
||||||
|
"code_challenge_method": []string{"S256"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authorizeRequest(params url.Values) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, "/authorize?"+params.Encode(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeProfileError(t *testing.T, body string) profileerrors.ErrorType {
|
||||||
|
t.Helper()
|
||||||
|
var pe profileerrors.ProfileError
|
||||||
|
if err := json.Unmarshal([]byte(body), &pe); err != nil {
|
||||||
|
t.Fatalf("could not decode ProfileError: %v (body: %q)", err, body)
|
||||||
|
}
|
||||||
|
return pe.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// T06 Authorization Endpoint Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_ValidRequest_RedirectsToAuthelia(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth?state=xyz"}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
req := authorizeRequest(validAuthorizeParams())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusFound {
|
||||||
|
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
loc := w.Header().Get("Location")
|
||||||
|
if loc != "https://authelia.example.com/auth?state=xyz" {
|
||||||
|
t.Errorf("expected redirect to Authelia, got %q", loc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_EmitsAuthStart(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth"}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
req := authorizeRequest(validAuthorizeParams())
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, ev := range emitter.events {
|
||||||
|
if ev.EventType == telemetry.EventAuthStart {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected auth_start telemetry event to be emitted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_MissingCodeChallenge_InvalidProfileUsage(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Del("code_challenge")
|
||||||
|
params.Del("code_challenge_method")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_WildcardRedirectURI_RejectedForSafety(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
clients := map[string]*domain.Client{
|
||||||
|
"wildcard-client": {
|
||||||
|
ClientID: "wildcard-client",
|
||||||
|
RedirectURIs: []string{"https://app.example.com/*"},
|
||||||
|
ClientType: "public",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: clients,
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: oidc.NewSessionStore(),
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("client_id", "wildcard-client")
|
||||||
|
params.Set("redirect_uri", "https://app.example.com/anything")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrRejectedForSafety {
|
||||||
|
t.Errorf("expected rejected_for_profile_safety, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_UnknownClient_InvalidProfileUsage(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("client_id", "no-such-client")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_WrongResponseType_FeatureNotSupported(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("response_type", "token")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrFeatureNotSupported {
|
||||||
|
t.Errorf("expected feature_not_supported_by_profile, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_MissingOpenIDScope_InvalidProfileUsage(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("scope", "profile email")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_PlainCodeChallengeMethod_RejectedForSafety(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("code_challenge_method", "plain")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrRejectedForSafety {
|
||||||
|
t.Errorf("expected rejected_for_profile_safety, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_UnknownRedirectURI_InvalidProfileUsage(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
params := validAuthorizeParams()
|
||||||
|
params.Set("redirect_uri", "https://evil.example.com/callback")
|
||||||
|
|
||||||
|
req := authorizeRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Callback tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuthorizeCallback_Success_RedirectsWithCode(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{
|
||||||
|
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||||
|
}
|
||||||
|
mfa := &mockMFAProvider{required: false}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
h := &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: sessions,
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/authorize/callback?code=authelia-code&state=random-state", nil)
|
||||||
|
// Simulate that there is an ongoing PKCE flow stored in query param forwarding
|
||||||
|
// The callback needs the original client context. We store it via a pre-seeded
|
||||||
|
// pending session keyed by state.
|
||||||
|
// For the callback handler, we expect it to look up the pending state by the
|
||||||
|
// "state" parameter that was originally embedded. We seed the pending state.
|
||||||
|
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||||
|
PKCEChallengeMethod: "S256",
|
||||||
|
State: "random-state",
|
||||||
|
Scopes: []string{"openid", "profile"},
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTPCallback(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusFound {
|
||||||
|
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
loc := w.Header().Get("Location")
|
||||||
|
parsed, err := url.Parse(loc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("invalid Location header: %v", err)
|
||||||
|
}
|
||||||
|
if parsed.Query().Get("code") == "" {
|
||||||
|
t.Error("expected code param in redirect, got empty")
|
||||||
|
}
|
||||||
|
if parsed.Query().Get("state") != "random-state" {
|
||||||
|
t.Errorf("expected state=random-state, got %q", parsed.Query().Get("state"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeCallback_MFAFailed_AuthFailure(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{
|
||||||
|
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||||
|
}
|
||||||
|
mfa := &mockMFAProvider{
|
||||||
|
required: true,
|
||||||
|
validateErr: domain.ErrMFAFailed,
|
||||||
|
}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
h := &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: sessions,
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||||
|
PKCEChallengeMethod: "S256",
|
||||||
|
State: "random-state",
|
||||||
|
Scopes: []string{"openid"},
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/authorize/callback?code=authelia-code&state=random-state&mfa_token=wrong", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTPCallback(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", w.Code)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, ev := range emitter.events {
|
||||||
|
if ev.EventType == telemetry.EventAuthFailure {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected auth_failure telemetry event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeCallback_AuthProviderFailed_AuthFailure(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{
|
||||||
|
callbackErr: domain.ErrAuthFailed,
|
||||||
|
}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
h := &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: sessions,
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
State: "random-state",
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/authorize/callback?code=bad&state=random-state", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTPCallback(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("expected 401, got %d", w.Code)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, ev := range emitter.events {
|
||||||
|
if ev.EventType == telemetry.EventAuthFailure {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected auth_failure telemetry event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizeCallback_EmitsAuthSuccess(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{
|
||||||
|
callbackResult: &domain.AuthResult{Username: "bob"},
|
||||||
|
}
|
||||||
|
mfa := &mockMFAProvider{required: false}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
h := &oidc.AuthorizeHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Auth: auth,
|
||||||
|
MFA: mfa,
|
||||||
|
Sessions: sessions,
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.PendingStates().Store("s1", &oidc.PendingState{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
PKCEChallenge: "abc",
|
||||||
|
PKCEChallengeMethod: "S256",
|
||||||
|
State: "s1",
|
||||||
|
Scopes: []string{"openid"},
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/authorize/callback?code=c&state=s1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTPCallback(w, req)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, ev := range emitter.events {
|
||||||
|
if ev.EventType == telemetry.EventAuthSuccess {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected auth_success telemetry event, got events: %v", emitter.events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ServeHTTP dispatch
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAuthorizeHandler_ServeHTTP_DispatchesToCallback(t *testing.T) {
|
||||||
|
auth := &mockAuthProvider{
|
||||||
|
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||||
|
}
|
||||||
|
mfa := &mockMFAProvider{}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
|
||||||
|
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||||
|
|
||||||
|
// A request to /authorize/callback should not be treated as the initial
|
||||||
|
// authorize request and must not require PKCE params.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/authorize/callback?code=x&state=y", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Without a seeded pending state for "y", the callback returns an error.
|
||||||
|
// The important thing is that it is NOT a redirect to Authelia.
|
||||||
|
if w.Code == http.StatusFound {
|
||||||
|
loc := w.Header().Get("Location")
|
||||||
|
if strings.Contains(loc, "authelia") {
|
||||||
|
t.Error("callback path must not redirect to Authelia")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
79
src/internal/server/oidc/session.go
Normal file
79
src/internal/server/oidc/session.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCESession stores the in-flight authorization state server-side.
|
||||||
|
type PKCESession struct {
|
||||||
|
Code string
|
||||||
|
ClientID string
|
||||||
|
RedirectURI string
|
||||||
|
PKCEChallenge string // S256 challenge
|
||||||
|
PKCEChallengeMethod string // always "S256"
|
||||||
|
State string
|
||||||
|
Username string // set after auth
|
||||||
|
Scopes []string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStore is an in-memory PKCE session store.
|
||||||
|
type SessionStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
sessions map[string]*PKCESession // keyed by code
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionStore returns an initialised, empty SessionStore.
|
||||||
|
func NewSessionStore() *SessionStore {
|
||||||
|
return &SessionStore{
|
||||||
|
sessions: make(map[string]*PKCESession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create stores the session and returns the generated authorization code.
|
||||||
|
func (s *SessionStore) Create(sess *PKCESession) string {
|
||||||
|
code := generateCode()
|
||||||
|
sess.Code = code
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.sessions[code] = sess
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a session by code. Returns false if not found or expired.
|
||||||
|
func (s *SessionStore) Get(code string) (*PKCESession, bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
sess, ok := s.sessions[code]
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if time.Now().After(sess.ExpiresAt) {
|
||||||
|
s.Delete(code)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return sess, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a session by code. No-op if the code is not present.
|
||||||
|
func (s *SessionStore) Delete(code string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.sessions, code)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCode returns a cryptographically random, URL-safe string suitable
|
||||||
|
// for use as an authorization code.
|
||||||
|
func generateCode() string {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic("oidc: failed to generate random code: " + err.Error())
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
219
src/internal/server/oidc/token.go
Normal file
219
src/internal/server/oidc/token.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keycape/internal/domain"
|
||||||
|
profileerrors "keycape/internal/errors"
|
||||||
|
"keycape/internal/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenHandler implements POST /token.
|
||||||
|
type TokenHandler struct {
|
||||||
|
ClientConfig map[string]*domain.Client
|
||||||
|
Sessions *SessionStore
|
||||||
|
Users domain.UserRepository
|
||||||
|
SigningKey *rsa.PrivateKey
|
||||||
|
Issuer string
|
||||||
|
TokenLifetime time.Duration
|
||||||
|
Emitter telemetry.Emitter
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenResponse is the JSON body returned on a successful token exchange.
|
||||||
|
type tokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP handles POST /token.
|
||||||
|
func (h *TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "invalid form body", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
grantType := r.FormValue("grant_type")
|
||||||
|
clientID := r.FormValue("client_id")
|
||||||
|
code := r.FormValue("code")
|
||||||
|
codeVerifier := r.FormValue("code_verifier")
|
||||||
|
|
||||||
|
// 1. Validate grant_type.
|
||||||
|
if grantType != "authorization_code" {
|
||||||
|
profileerrors.FeatureNotSupported(
|
||||||
|
"only grant_type=authorization_code is supported",
|
||||||
|
"grant_type="+grantType,
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Validate client exists (basic check; secret auth delegated to future work).
|
||||||
|
if _, ok := h.ClientConfig[clientID]; !ok {
|
||||||
|
profileerrors.InvalidProfileUsage("unknown client_id", "client_id").
|
||||||
|
Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Look up PKCE session.
|
||||||
|
sess, ok := h.Sessions.Get(code)
|
||||||
|
if !ok {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"authorization code not found or expired",
|
||||||
|
"code",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify client_id matches the session.
|
||||||
|
if sess.ClientID != clientID {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"client_id does not match the authorization code",
|
||||||
|
"client_id",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Verify PKCE code_verifier.
|
||||||
|
if !verifyPKCE(codeVerifier, sess.PKCEChallenge) {
|
||||||
|
profileerrors.InvalidProfileUsage(
|
||||||
|
"code_verifier does not match code_challenge",
|
||||||
|
"code_verifier",
|
||||||
|
).Write(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Look up user.
|
||||||
|
user, err := h.Users.LookupUser(ctx, sess.Username)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "user not found", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Build JWT claims.
|
||||||
|
now := time.Now()
|
||||||
|
exp := now.Add(h.TokenLifetime)
|
||||||
|
|
||||||
|
claims := map[string]interface{}{
|
||||||
|
"iss": h.Issuer,
|
||||||
|
"sub": user.ID,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": exp.Unix(),
|
||||||
|
"iat": now.Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
scopeSet := make(map[string]bool)
|
||||||
|
for _, s := range sess.Scopes {
|
||||||
|
scopeSet[s] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if scopeSet["profile"] {
|
||||||
|
claims["preferred_username"] = user.Username
|
||||||
|
}
|
||||||
|
if scopeSet["email"] {
|
||||||
|
claims["email"] = user.Email
|
||||||
|
}
|
||||||
|
if scopeSet["groups"] {
|
||||||
|
claims["groups"] = user.Groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Sign JWT with RSA-SHA256.
|
||||||
|
kid := "key-1" // static kid for v0.1
|
||||||
|
jwtToken, err := buildJWT(claims, kid, h.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to build JWT", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Delete used PKCE session (prevent replay).
|
||||||
|
h.Sessions.Delete(code)
|
||||||
|
|
||||||
|
// 9. Build response.
|
||||||
|
resp := tokenResponse{
|
||||||
|
AccessToken: jwtToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: int(h.TokenLifetime.Seconds()),
|
||||||
|
IDToken: jwtToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 10. Emit token_issued telemetry.
|
||||||
|
h.Emitter.Emit(ctx, telemetry.Event{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
EventType: telemetry.EventTokenIssued,
|
||||||
|
ClientID: clientID,
|
||||||
|
Endpoint: "/token",
|
||||||
|
Result: "success",
|
||||||
|
Scopes: sess.Scopes,
|
||||||
|
GrantType: grantType,
|
||||||
|
})
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// PKCE verification
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// verifyPKCE checks BASE64URL(SHA256(verifier)) == challenge (S256 method).
|
||||||
|
func verifyPKCE(verifier, challenge string) bool {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(verifier))
|
||||||
|
computed := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
return computed == challenge
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// JWT construction (stdlib only — no external JWT library)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type jwtHeader struct {
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
Typ string `json:"typ"`
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildJWT constructs and signs a JWT using RSA-SHA256 with the standard library.
|
||||||
|
// Format: base64url(header) + "." + base64url(payload) + "." + base64url(signature)
|
||||||
|
func buildJWT(claims map[string]interface{}, kid string, key *rsa.PrivateKey) (string, error) {
|
||||||
|
// Header.
|
||||||
|
hdr := jwtHeader{Alg: "RS256", Typ: "JWT", Kid: kid}
|
||||||
|
hdrJSON, err := json.Marshal(hdr)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
|
||||||
|
|
||||||
|
// Payload.
|
||||||
|
payloadJSON, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||||
|
|
||||||
|
// Signing input.
|
||||||
|
signingInput := hdrB64 + "." + payloadB64
|
||||||
|
|
||||||
|
// Digest.
|
||||||
|
digest := sha256.Sum256([]byte(signingInput))
|
||||||
|
|
||||||
|
// Sign with PKCS1v15 / SHA256.
|
||||||
|
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, digest[:])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sigB64 := base64.RawURLEncoding.EncodeToString(sig)
|
||||||
|
|
||||||
|
return strings.Join([]string{hdrB64, payloadB64, sigB64}, "."), nil
|
||||||
|
}
|
||||||
496
src/internal/server/oidc/token_test.go
Normal file
496
src/internal/server/oidc/token_test.go
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
package oidc_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"keycape/internal/domain"
|
||||||
|
profileerrors "keycape/internal/errors"
|
||||||
|
"keycape/internal/server/oidc"
|
||||||
|
"keycape/internal/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock UserRepository
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mockUserRepo struct {
|
||||||
|
users map[string]*domain.User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserRepo) LookupUser(_ context.Context, username string) (*domain.User, error) {
|
||||||
|
u, ok := m.users[username]
|
||||||
|
if !ok {
|
||||||
|
return nil, domain.ErrUserNotFound
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserRepo) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserRepo) ValidatePassword(_ context.Context, _, _ string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// PKCE helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// makeVerifierAndChallenge returns a code_verifier and its S256 code_challenge.
|
||||||
|
func makeVerifierAndChallenge() (verifier, challenge string) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Test helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func newTokenHandler(t *testing.T, sessions *oidc.SessionStore, users domain.UserRepository) (*oidc.TokenHandler, *rsa.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate key: %v", err)
|
||||||
|
}
|
||||||
|
emitter := &captureEmitter{}
|
||||||
|
h := &oidc.TokenHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Sessions: sessions,
|
||||||
|
Users: users,
|
||||||
|
SigningKey: key,
|
||||||
|
Issuer: "https://auth.netkingdom.local",
|
||||||
|
TokenLifetime: 15 * time.Minute,
|
||||||
|
Emitter: emitter,
|
||||||
|
}
|
||||||
|
return h, key
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenRequest(params url.Values) *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/token",
|
||||||
|
strings.NewReader(params.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func seededSession(sessions *oidc.SessionStore, verifier string) (code string) {
|
||||||
|
challenge := s256Challenge(verifier)
|
||||||
|
sess := &oidc.PKCESession{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
PKCEChallenge: challenge,
|
||||||
|
PKCEChallengeMethod: "S256",
|
||||||
|
State: "state1",
|
||||||
|
Username: "alice",
|
||||||
|
Scopes: []string{"openid", "profile", "email", "groups"},
|
||||||
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
return sessions.Create(sess)
|
||||||
|
}
|
||||||
|
|
||||||
|
func s256Challenge(verifier string) string {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeTokenResponse(t *testing.T, body string) map[string]interface{} {
|
||||||
|
t.Helper()
|
||||||
|
var m map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &m); err != nil {
|
||||||
|
t.Fatalf("could not decode token response: %v (body: %q)", err, body)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseJWTPayload(t *testing.T, token string) map[string]interface{} {
|
||||||
|
t.Helper()
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("expected 3 JWT parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode JWT payload: %v", err)
|
||||||
|
}
|
||||||
|
var claims map[string]interface{}
|
||||||
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
|
t.Fatalf("unmarshal JWT payload: %v", err)
|
||||||
|
}
|
||||||
|
return claims
|
||||||
|
}
|
||||||
|
|
||||||
|
func aliceUser() *domain.User {
|
||||||
|
return &domain.User{
|
||||||
|
ID: "user-alice",
|
||||||
|
Username: "alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Groups: []string{"admin", "users"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// T07 Token Endpoint Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTokenHandler_ValidExchange_ReturnsJWT(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, verifier)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeTokenResponse(t, w.Body.String())
|
||||||
|
if _, ok := resp["access_token"]; !ok {
|
||||||
|
t.Error("missing access_token")
|
||||||
|
}
|
||||||
|
if _, ok := resp["id_token"]; !ok {
|
||||||
|
t.Error("missing id_token")
|
||||||
|
}
|
||||||
|
if resp["token_type"] != "Bearer" {
|
||||||
|
t.Errorf("expected token_type Bearer, got %v", resp["token_type"])
|
||||||
|
}
|
||||||
|
if _, ok := resp["expires_in"]; !ok {
|
||||||
|
t.Error("missing expires_in")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_WrongGrantType_FeatureNotSupported(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"client_credentials"},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrFeatureNotSupported {
|
||||||
|
t.Errorf("expected feature_not_supported_by_profile, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_PKCEMismatch_InvalidProfileUsage(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
realVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, realVerifier)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{"wrong-verifier-that-does-not-match"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_CodeNotFound_InvalidProfileUsage(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{"no-such-code"},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{"any-verifier"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
errType := decodeProfileError(t, w.Body.String())
|
||||||
|
if errType != profileerrors.ErrInvalidProfileUsage {
|
||||||
|
t.Errorf("expected invalid_profile_usage, got %q", errType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_JWTClaims_CorrectSubAndIssuer(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, verifier)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeTokenResponse(t, w.Body.String())
|
||||||
|
idToken, ok := resp["id_token"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("id_token is not a string")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := parseJWTPayload(t, idToken)
|
||||||
|
|
||||||
|
if claims["sub"] != "user-alice" {
|
||||||
|
t.Errorf("sub: expected user-alice, got %v", claims["sub"])
|
||||||
|
}
|
||||||
|
if claims["iss"] != "https://auth.netkingdom.local" {
|
||||||
|
t.Errorf("iss: expected https://auth.netkingdom.local, got %v", claims["iss"])
|
||||||
|
}
|
||||||
|
if claims["aud"] != "test-client" {
|
||||||
|
t.Errorf("aud: expected test-client, got %v", claims["aud"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_ScopeFiltering_ProfileScope(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
// Seed session with only openid scope (no email, no groups).
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
challenge := s256Challenge(verifier)
|
||||||
|
sess := &oidc.PKCESession{
|
||||||
|
ClientID: "test-client",
|
||||||
|
RedirectURI: "https://app.example.com/callback",
|
||||||
|
PKCEChallenge: challenge,
|
||||||
|
PKCEChallengeMethod: "S256",
|
||||||
|
Username: "alice",
|
||||||
|
Scopes: []string{"openid"}, // no profile/email/groups
|
||||||
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
code := sessions.Create(sess)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeTokenResponse(t, w.Body.String())
|
||||||
|
idToken := resp["id_token"].(string)
|
||||||
|
claims := parseJWTPayload(t, idToken)
|
||||||
|
|
||||||
|
// Without profile scope, preferred_username must not be present.
|
||||||
|
if _, ok := claims["preferred_username"]; ok {
|
||||||
|
t.Error("preferred_username must be absent when profile scope is not granted")
|
||||||
|
}
|
||||||
|
// Without email scope, email must not be present.
|
||||||
|
if _, ok := claims["email"]; ok {
|
||||||
|
t.Error("email must be absent when email scope is not granted")
|
||||||
|
}
|
||||||
|
// Without groups scope, groups must not be present.
|
||||||
|
if _, ok := claims["groups"]; ok {
|
||||||
|
t.Error("groups must be absent when groups scope is not granted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_ScopeFiltering_AllScopes(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, verifier) // has openid, profile, email, groups
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := decodeTokenResponse(t, w.Body.String())
|
||||||
|
idToken := resp["id_token"].(string)
|
||||||
|
claims := parseJWTPayload(t, idToken)
|
||||||
|
|
||||||
|
if claims["preferred_username"] != "alice" {
|
||||||
|
t.Errorf("preferred_username: expected alice, got %v", claims["preferred_username"])
|
||||||
|
}
|
||||||
|
if claims["email"] != "alice@example.com" {
|
||||||
|
t.Errorf("email: expected alice@example.com, got %v", claims["email"])
|
||||||
|
}
|
||||||
|
if _, ok := claims["groups"]; !ok {
|
||||||
|
t.Error("groups claim must be present when groups scope is granted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_TokenIssuedTelemetry(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate key: %v", err)
|
||||||
|
}
|
||||||
|
capture := &captureEmitter{}
|
||||||
|
h := &oidc.TokenHandler{
|
||||||
|
ClientConfig: testClient(),
|
||||||
|
Sessions: sessions,
|
||||||
|
Users: users,
|
||||||
|
SigningKey: key,
|
||||||
|
Issuer: "https://auth.netkingdom.local",
|
||||||
|
TokenLifetime: 15 * time.Minute,
|
||||||
|
Emitter: capture,
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, verifier)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, ev := range capture.events {
|
||||||
|
if ev.EventType == telemetry.EventTokenIssued {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected token_issued telemetry event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenHandler_CodeDeletedAfterUse(t *testing.T) {
|
||||||
|
sessions := oidc.NewSessionStore()
|
||||||
|
users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}}
|
||||||
|
|
||||||
|
h, _ := newTokenHandler(t, sessions, users)
|
||||||
|
|
||||||
|
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
code := seededSession(sessions, verifier)
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"client_id": []string{"test-client"},
|
||||||
|
"redirect_uri": []string{"https://app.example.com/callback"},
|
||||||
|
"code_verifier": []string{verifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First use — should succeed.
|
||||||
|
req := tokenRequest(params)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("first use: expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second use — code should be gone.
|
||||||
|
req2 := tokenRequest(params)
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w2, req2)
|
||||||
|
if w2.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("second use: expected 400 (code replay), got %d", w2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user