generated from coulomb/repo-seed
bootrapping support
Some checks failed
Build and Publish Container Image / build-and-push (push) Has been cancelled
Some checks failed
Build and Publish Container Image / build-and-push (push) Has been cancelled
This commit is contained in:
@@ -43,7 +43,7 @@ func New(cfg Config, httpClient HTTPClient) *AutheliaAdapter {
|
||||
// values — and requests the full fixed scope set. PKCE is omitted because
|
||||
// the confidential client_secret authenticates the token exchange instead.
|
||||
func (a *AutheliaAdapter) AuthorizeURL(_ context.Context, req domain.AuthRequest) (string, error) {
|
||||
base := strings.TrimRight(a.cfg.BaseURL, "/") + "/api/oidc/authorization"
|
||||
base := strings.TrimRight(a.authorizeBaseURL(), "/") + "/api/oidc/authorization"
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", a.cfg.ClientID)
|
||||
@@ -136,7 +136,7 @@ type tokenResponse struct {
|
||||
// exchangeCode sends a POST to Authelia's token endpoint and returns the
|
||||
// parsed token response. On any HTTP or status error it returns a non-nil error.
|
||||
func (a *AutheliaAdapter) exchangeCode(_ context.Context, code string) (*tokenResponse, error) {
|
||||
tokenURL := strings.TrimRight(a.cfg.BaseURL, "/") + "/api/oidc/token"
|
||||
tokenURL := strings.TrimRight(a.tokenBaseURL(), "/") + "/api/oidc/token"
|
||||
|
||||
body := url.Values{}
|
||||
body.Set("grant_type", "authorization_code")
|
||||
@@ -173,6 +173,20 @@ func (a *AutheliaAdapter) exchangeCode(_ context.Context, code string) (*tokenRe
|
||||
return &tr, nil
|
||||
}
|
||||
|
||||
func (a *AutheliaAdapter) authorizeBaseURL() string {
|
||||
if a.cfg.BrowserBaseURL != "" {
|
||||
return a.cfg.BrowserBaseURL
|
||||
}
|
||||
return a.cfg.BaseURL
|
||||
}
|
||||
|
||||
func (a *AutheliaAdapter) tokenBaseURL() string {
|
||||
if a.cfg.TokenBaseURL != "" {
|
||||
return a.cfg.TokenBaseURL
|
||||
}
|
||||
return a.cfg.BaseURL
|
||||
}
|
||||
|
||||
// parseIDTokenClaims extracts the JWT payload claims without verifying the
|
||||
// signature. This is intentional — the token is received directly from the
|
||||
// upstream OIDC provider over a server-to-server TLS connection.
|
||||
|
||||
@@ -136,6 +136,33 @@ func TestAuthorizeURL_UsesBaseURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeURL_UsesBrowserBaseURLWhenConfigured(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.BaseURL = "http://authelia.sso.svc.cluster.local:9091"
|
||||
cfg.BrowserBaseURL = "https://auth.coulomb.social"
|
||||
|
||||
adapter := authelia.New(cfg, &mockHTTPClient{})
|
||||
req := domain.AuthRequest{
|
||||
ClientID: "app",
|
||||
RedirectURI: "https://app.local/cb",
|
||||
State: "s",
|
||||
PKCEChallenge: "c",
|
||||
PKCEChallengeMethod: "S256",
|
||||
Scopes: []string{"openid"},
|
||||
}
|
||||
|
||||
u, err := adapter.AuthorizeURL(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(u, "https://auth.coulomb.social") {
|
||||
t.Errorf("expected URL to start with BrowserBaseURL, got: %s", u)
|
||||
}
|
||||
if strings.Contains(u, "authelia.sso.svc.cluster.local") {
|
||||
t.Errorf("browser redirect must not use internal service URL, got: %s", u)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleCallback — successful token exchange
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -172,6 +199,32 @@ func TestHandleCallback_Success_PreferredUsername(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCallback_UsesTokenBaseURLWhenConfigured(t *testing.T) {
|
||||
tokenBody := buildTokenResponse(map[string]interface{}{
|
||||
"sub": "user-uuid-123",
|
||||
"preferred_username": "alice",
|
||||
})
|
||||
var tokenURL string
|
||||
client := &mockHTTPClient{
|
||||
doFn: func(req *http.Request) (*http.Response, error) {
|
||||
tokenURL = req.URL.String()
|
||||
return jsonResponse(tokenBody), nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.BaseURL = "https://auth.coulomb.social"
|
||||
cfg.TokenBaseURL = "http://authelia.sso.svc.cluster.local:9091"
|
||||
|
||||
adapter := authelia.New(cfg, client)
|
||||
if _, err := adapter.HandleCallback(context.Background(), domain.CallbackParams{Code: "code"}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(tokenURL, "http://authelia.sso.svc.cluster.local:9091") {
|
||||
t.Errorf("expected token exchange to use TokenBaseURL, got: %s", tokenURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCallback_Success_FallsBackToSub(t *testing.T) {
|
||||
tokenBody := buildTokenResponse(map[string]interface{}{
|
||||
"sub": "user-uuid-456",
|
||||
|
||||
@@ -8,17 +8,25 @@ import "net/http"
|
||||
// Config holds all connection parameters for the Authelia adapter.
|
||||
type Config struct {
|
||||
// BaseURL is the Authelia server base URL, e.g. "https://authelia.local".
|
||||
BaseURL string
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
|
||||
// BrowserBaseURL is the public Authelia URL used for browser redirects.
|
||||
// If empty, BaseURL is used.
|
||||
BrowserBaseURL string `yaml:"browserBaseURL,omitempty"`
|
||||
|
||||
// TokenBaseURL is the server-side Authelia URL used for token exchange.
|
||||
// If empty, BaseURL is used.
|
||||
TokenBaseURL string `yaml:"tokenBaseURL,omitempty"`
|
||||
|
||||
// ClientID is the client ID registered in Authelia for KeyCape.
|
||||
ClientID string
|
||||
ClientID string `yaml:"clientId"`
|
||||
|
||||
// ClientSecret is the client secret for the KeyCape client registration.
|
||||
ClientSecret string
|
||||
ClientSecret string `yaml:"clientSecret"`
|
||||
|
||||
// RedirectURI is the callback URL registered in Authelia that points back
|
||||
// to KeyCape's callback handler.
|
||||
RedirectURI string
|
||||
RedirectURI string `yaml:"redirectURI"`
|
||||
}
|
||||
|
||||
// HTTPClient is a minimal interface over net/http.Client for test injection.
|
||||
|
||||
@@ -81,6 +81,52 @@ clients:
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_AutheliaSplitURLs(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "placeholder-key")
|
||||
yaml := `
|
||||
issuer: "https://kc.example.com"
|
||||
port: 8080
|
||||
tokenLifetime: "15m"
|
||||
privateKeyPem: "` + keyPath + `"
|
||||
environment: "dev"
|
||||
authelia:
|
||||
baseURL: "http://authelia.sso.svc.cluster.local:9091"
|
||||
browserBaseURL: "https://auth.example.com"
|
||||
tokenBaseURL: "http://authelia.sso.svc.cluster.local:9091"
|
||||
clientId: "keycape"
|
||||
clientSecret: "secret"
|
||||
redirectURI: "https://kc.example.com/authorize/callback"
|
||||
clients:
|
||||
- clientId: "netkingdom-bootstrap-console"
|
||||
displayName: "NetKingdom Bootstrap Console"
|
||||
redirectUris:
|
||||
- "http://127.0.0.1:8876/oidc/callback"
|
||||
- "http://localhost:8876/oidc/callback"
|
||||
clientType: "public"
|
||||
`
|
||||
cfgPath := writeTempFile(t, yaml)
|
||||
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Authelia.BaseURL != "http://authelia.sso.svc.cluster.local:9091" {
|
||||
t.Errorf("Authelia.BaseURL: got %q", cfg.Authelia.BaseURL)
|
||||
}
|
||||
if cfg.Authelia.BrowserBaseURL != "https://auth.example.com" {
|
||||
t.Errorf("Authelia.BrowserBaseURL: got %q", cfg.Authelia.BrowserBaseURL)
|
||||
}
|
||||
if cfg.Authelia.TokenBaseURL != "http://authelia.sso.svc.cluster.local:9091" {
|
||||
t.Errorf("Authelia.TokenBaseURL: got %q", cfg.Authelia.TokenBaseURL)
|
||||
}
|
||||
if len(cfg.Clients) != 1 || cfg.Clients[0].ClientID != "netkingdom-bootstrap-console" {
|
||||
t.Fatalf("bootstrap client not loaded: %+v", cfg.Clients)
|
||||
}
|
||||
if got := cfg.Clients[0].RedirectURIs; len(got) != 2 || got[0] != "http://127.0.0.1:8876/oidc/callback" {
|
||||
t.Errorf("bootstrap redirect URIs not loaded: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_FileNotFound(t *testing.T) {
|
||||
_, err := config.Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
|
||||
if err == nil {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -22,6 +25,7 @@ type PendingState struct {
|
||||
State string
|
||||
Scopes []string
|
||||
ExpiresAt time.Time
|
||||
AuthenticatedUser string
|
||||
}
|
||||
|
||||
// pendingStateStore is a thread-safe map of state → PendingState.
|
||||
@@ -212,6 +216,17 @@ func (h *AuthorizeHandler) serveAuthorize(w http.ResponseWriter, r *http.Request
|
||||
func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Request) {
|
||||
h.init()
|
||||
ctx := r.Context()
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
h.serveMFASubmission(w, r)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
w.Header().Set("Allow", "GET, POST")
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
|
||||
state := q.Get("state")
|
||||
@@ -229,7 +244,6 @@ func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Requ
|
||||
http.Error(w, "authorization request expired", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
h.pending.Delete(state)
|
||||
|
||||
// Handle upstream callback.
|
||||
result, err := h.Auth.HandleCallback(ctx, domain.CallbackParams{
|
||||
@@ -248,6 +262,19 @@ func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Requ
|
||||
http.Error(w, "authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if result == nil || result.Username == "" {
|
||||
h.pending.Delete(state)
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "auth_failed",
|
||||
})
|
||||
http.Error(w, "authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Check MFA requirement.
|
||||
mfaRequired, err := h.MFA.CheckMFARequired(ctx, result.Username)
|
||||
@@ -256,34 +283,80 @@ func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
if mfaRequired {
|
||||
if mfaToken == "" {
|
||||
ps.AuthenticatedUser = result.Username
|
||||
h.pending.Store(state, ps)
|
||||
h.renderMFAChallenge(w, ps, "")
|
||||
return
|
||||
}
|
||||
if err := h.MFA.ValidateMFAToken(ctx, result.Username, mfaToken); err != nil {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: ps.ClientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "mfa_failed",
|
||||
})
|
||||
h.pending.Delete(state)
|
||||
h.emitMFAFailure(ctx, ps.ClientID)
|
||||
http.Error(w, "MFA validation failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.pending.Delete(state)
|
||||
h.completeAuthorization(w, r, ps, result.Username)
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) serveMFASubmission(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.Form.Get("state")
|
||||
mfaToken := r.Form.Get("mfa_token")
|
||||
|
||||
ps, ok := h.pending.Load(state)
|
||||
if !ok {
|
||||
http.Error(w, "unknown or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if time.Now().After(ps.ExpiresAt) {
|
||||
h.pending.Delete(state)
|
||||
http.Error(w, "authorization request expired", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if ps.AuthenticatedUser == "" {
|
||||
h.pending.Delete(state)
|
||||
http.Error(w, "mfa challenge not active", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(mfaToken) == "" {
|
||||
h.renderMFAChallenge(w, ps, "Enter the one-time code.")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.MFA.ValidateMFAToken(ctx, ps.AuthenticatedUser, mfaToken); err != nil {
|
||||
h.pending.Delete(state)
|
||||
h.emitMFAFailure(ctx, ps.ClientID)
|
||||
http.Error(w, "MFA validation failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
h.pending.Delete(state)
|
||||
h.completeAuthorization(w, r, ps, ps.AuthenticatedUser)
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) completeAuthorization(w http.ResponseWriter, r *http.Request, ps *PendingState, username string) {
|
||||
// Generate authorization code and store PKCE session.
|
||||
sess := &PKCESession{
|
||||
ClientID: ps.ClientID,
|
||||
RedirectURI: ps.RedirectURI,
|
||||
PKCEChallenge: ps.PKCEChallenge,
|
||||
PKCEChallengeMethod: ps.PKCEChallengeMethod,
|
||||
State: state,
|
||||
Username: result.Username,
|
||||
State: ps.State,
|
||||
Username: username,
|
||||
Scopes: ps.Scopes,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
authCode := h.Sessions.Create(sess)
|
||||
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
h.Emitter.Emit(r.Context(), telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthSuccess,
|
||||
ClientID: ps.ClientID,
|
||||
@@ -293,14 +366,94 @@ func (h *AuthorizeHandler) ServeHTTPCallback(w http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
|
||||
// Redirect to client with code and state.
|
||||
redirectTo := ps.RedirectURI + "?code=" + authCode + "&state=" + state
|
||||
http.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
redirectTo, err := url.Parse(ps.RedirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
q := redirectTo.Query()
|
||||
q.Set("code", authCode)
|
||||
q.Set("state", ps.State)
|
||||
redirectTo.RawQuery = q.Encode()
|
||||
http.Redirect(w, r, redirectTo.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) emitMFAFailure(ctx context.Context, clientID string) {
|
||||
h.Emitter.Emit(ctx, telemetry.Event{
|
||||
Timestamp: time.Now(),
|
||||
EventType: telemetry.EventAuthFailure,
|
||||
ClientID: clientID,
|
||||
Endpoint: "/authorize/callback",
|
||||
Result: "failure",
|
||||
ErrorType: "mfa_failed",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthorizeHandler) renderMFAChallenge(w http.ResponseWriter, ps *PendingState, errorMessage string) {
|
||||
clientName := ps.ClientID
|
||||
if client, ok := h.ClientConfig[ps.ClientID]; ok && client.DisplayName != "" {
|
||||
clientName = client.DisplayName
|
||||
}
|
||||
status := http.StatusOK
|
||||
if errorMessage != "" {
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_ = mfaChallengeTemplate.Execute(w, struct {
|
||||
State string
|
||||
Username string
|
||||
ClientName string
|
||||
ErrorMessage string
|
||||
}{
|
||||
State: ps.State,
|
||||
Username: ps.AuthenticatedUser,
|
||||
ClientName: clientName,
|
||||
ErrorMessage: errorMessage,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var mfaChallengeTemplate = template.Must(template.New("mfa-challenge").Parse(`<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>KeyCape MFA</title>
|
||||
<style>
|
||||
:root { color-scheme: light; font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; }
|
||||
body { margin: 0; min-height: 100vh; display: grid; place-items: center; background: #f6f7f9; color: #17202a; }
|
||||
main { width: min(420px, calc(100vw - 32px)); background: #fff; border: 1px solid #dfe4ea; border-radius: 8px; padding: 28px; box-shadow: 0 18px 45px rgba(23, 32, 42, .08); }
|
||||
h1 { margin: 0 0 6px; font-size: 22px; font-weight: 650; letter-spacing: 0; }
|
||||
p { margin: 0 0 20px; color: #52606d; line-height: 1.45; }
|
||||
label { display: block; margin: 0 0 8px; font-size: 13px; font-weight: 650; color: #344054; }
|
||||
input[type="text"] { width: 100%; box-sizing: border-box; height: 44px; border: 1px solid #c9d3df; border-radius: 6px; padding: 0 12px; font: inherit; background: #fff; }
|
||||
input[type="text"]:focus { outline: 2px solid #2f80ed; outline-offset: 2px; border-color: #2f80ed; }
|
||||
button { width: 100%; height: 44px; border: 0; border-radius: 6px; margin-top: 16px; background: #17324d; color: #fff; font: inherit; font-weight: 650; cursor: pointer; }
|
||||
button:focus { outline: 2px solid #2f80ed; outline-offset: 2px; }
|
||||
.meta { font-size: 13px; color: #667085; }
|
||||
.error { margin: 0 0 12px; color: #b42318; font-size: 13px; font-weight: 650; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<h1>Verify sign-in</h1>
|
||||
<p class="meta">{{.Username}} for {{.ClientName}}</p>
|
||||
{{if .ErrorMessage}}<p class="error">{{.ErrorMessage}}</p>{{end}}
|
||||
<form method="post" action="/authorize/callback" autocomplete="off">
|
||||
<input type="hidden" name="state" value="{{.State}}">
|
||||
<label for="mfa_token">One-time code</label>
|
||||
<input id="mfa_token" name="mfa_token" type="text" inputmode="numeric" autocomplete="one-time-code" required autofocus>
|
||||
<button type="submit">Verify</button>
|
||||
</form>
|
||||
</main>
|
||||
</body>
|
||||
</html>`))
|
||||
|
||||
func uriRegistered(registered []string, target string) bool {
|
||||
for _, u := range registered {
|
||||
if u == target {
|
||||
|
||||
@@ -45,14 +45,20 @@ type mockMFAProvider struct {
|
||||
required bool
|
||||
requiredErr error
|
||||
|
||||
validateErr error
|
||||
validateErr error
|
||||
validateCalls int
|
||||
validatedUser string
|
||||
validatedToken string
|
||||
}
|
||||
|
||||
func (m *mockMFAProvider) CheckMFARequired(_ context.Context, _ string) (bool, error) {
|
||||
return m.required, m.requiredErr
|
||||
}
|
||||
|
||||
func (m *mockMFAProvider) ValidateMFAToken(_ context.Context, _, _ string) error {
|
||||
func (m *mockMFAProvider) ValidateMFAToken(_ context.Context, user, token string) error {
|
||||
m.validateCalls++
|
||||
m.validatedUser = user
|
||||
m.validatedToken = token
|
||||
return m.validateErr
|
||||
}
|
||||
|
||||
@@ -80,10 +86,21 @@ func testClient() map[string]*domain.Client {
|
||||
return map[string]*domain.Client{
|
||||
"test-client": {
|
||||
ClientID: "test-client",
|
||||
DisplayName: "Test Client",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
AllowedScopes: []string{"openid", "profile", "email"},
|
||||
ClientType: "public",
|
||||
},
|
||||
"netkingdom-bootstrap-console": {
|
||||
ClientID: "netkingdom-bootstrap-console",
|
||||
DisplayName: "NetKingdom Bootstrap Console",
|
||||
RedirectURIs: []string{
|
||||
"http://127.0.0.1:8876/oidc/callback",
|
||||
"http://localhost:8876/oidc/callback",
|
||||
},
|
||||
AllowedScopes: []string{"openid", "profile", "email", "groups"},
|
||||
ClientType: "public",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +163,28 @@ func TestAuthorizeHandler_ValidRequest_RedirectsToAuthelia(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_BootstrapConsoleRedirectURI_RedirectsToAuthelia(t *testing.T) {
|
||||
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth?state=bootstrap"}
|
||||
mfa := &mockMFAProvider{}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := newAuthorizeHandler(auth, mfa, emitter)
|
||||
params := validAuthorizeParams()
|
||||
params.Set("client_id", "netkingdom-bootstrap-console")
|
||||
params.Set("redirect_uri", "http://127.0.0.1:8876/oidc/callback")
|
||||
|
||||
req := authorizeRequest(params)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
if loc := w.Header().Get("Location"); loc != "https://authelia.example.com/auth?state=bootstrap" {
|
||||
t.Errorf("expected Authelia redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeHandler_EmitsAuthStart(t *testing.T) {
|
||||
auth := &mockAuthProvider{authorizeURL: "https://authelia.example.com/auth"}
|
||||
mfa := &mockMFAProvider{}
|
||||
@@ -449,6 +488,164 @@ func TestAuthorizeCallback_MFAFailed_AuthFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_MFARequired_RendersChallengeWithoutToken(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackResult: &domain.AuthResult{Username: "alice"},
|
||||
}
|
||||
mfa := &mockMFAProvider{required: true}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "random-state",
|
||||
Scopes: []string{"openid"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/authorize/callback?code=authelia-code&state=random-state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 challenge page, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
body := w.Body.String()
|
||||
for _, want := range []string{"Verify sign-in", "alice", "Test Client", `name="mfa_token"`} {
|
||||
if !strings.Contains(body, want) {
|
||||
t.Errorf("challenge page missing %q in body: %s", want, body)
|
||||
}
|
||||
}
|
||||
if mfa.validateCalls != 0 {
|
||||
t.Errorf("MFA token should not be validated until form submission, got %d calls", mfa.validateCalls)
|
||||
}
|
||||
ps, ok := h.PendingStates().Load("random-state")
|
||||
if !ok {
|
||||
t.Fatal("expected pending state to remain for MFA form submission")
|
||||
}
|
||||
if ps.AuthenticatedUser != "alice" {
|
||||
t.Errorf("AuthenticatedUser: want alice, got %q", ps.AuthenticatedUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_MFASubmission_ValidToken_RedirectsWithCode(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{required: true}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
sessions := oidc.NewSessionStore()
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback?from=bootstrap",
|
||||
PKCEChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "random-state",
|
||||
Scopes: []string{"openid"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
AuthenticatedUser: "alice",
|
||||
})
|
||||
|
||||
form := url.Values{"state": {"random-state"}, "mfa_token": {"123456"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/authorize/callback", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("expected 302 redirect, got %d (body: %s)", w.Code, w.Body.String())
|
||||
}
|
||||
if mfa.validatedUser != "alice" || mfa.validatedToken != "123456" {
|
||||
t.Errorf("validated MFA: want alice/123456, got %q/%q", mfa.validatedUser, mfa.validatedToken)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
parsed, err := url.Parse(loc)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid Location header: %v", err)
|
||||
}
|
||||
if parsed.Query().Get("from") != "bootstrap" {
|
||||
t.Errorf("expected original redirect query to be preserved, got %q", loc)
|
||||
}
|
||||
if parsed.Query().Get("code") == "" {
|
||||
t.Error("expected code param in redirect, got empty")
|
||||
}
|
||||
if parsed.Query().Get("state") != "random-state" {
|
||||
t.Errorf("expected state=random-state, got %q", parsed.Query().Get("state"))
|
||||
}
|
||||
if _, ok := h.PendingStates().Load("random-state"); ok {
|
||||
t.Error("expected pending MFA state to be deleted after successful submission")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_MFASubmission_InvalidToken_AuthFailure(t *testing.T) {
|
||||
auth := &mockAuthProvider{}
|
||||
mfa := &mockMFAProvider{
|
||||
required: true,
|
||||
validateErr: domain.ErrMFAFailed,
|
||||
}
|
||||
emitter := &captureEmitter{}
|
||||
|
||||
h := &oidc.AuthorizeHandler{
|
||||
ClientConfig: testClient(),
|
||||
Auth: auth,
|
||||
MFA: mfa,
|
||||
Sessions: oidc.NewSessionStore(),
|
||||
Emitter: emitter,
|
||||
}
|
||||
h.PendingStates().Store("random-state", &oidc.PendingState{
|
||||
ClientID: "test-client",
|
||||
RedirectURI: "https://app.example.com/callback",
|
||||
PKCEChallenge: "abc",
|
||||
PKCEChallengeMethod: "S256",
|
||||
State: "random-state",
|
||||
Scopes: []string{"openid"},
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
AuthenticatedUser: "alice",
|
||||
})
|
||||
|
||||
form := url.Values{"state": {"random-state"}, "mfa_token": {"wrong"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/authorize/callback", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
h.ServeHTTPCallback(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
if _, ok := h.PendingStates().Load("random-state"); ok {
|
||||
t.Error("expected pending MFA state to be deleted after invalid submission")
|
||||
}
|
||||
found := false
|
||||
for _, ev := range emitter.events {
|
||||
if ev.EventType == telemetry.EventAuthFailure && ev.ErrorType == "mfa_failed" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected mfa_failed auth_failure telemetry event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback_AuthProviderFailed_AuthFailure(t *testing.T) {
|
||||
auth := &mockAuthProvider{
|
||||
callbackErr: domain.ErrAuthFailed,
|
||||
|
||||
Reference in New Issue
Block a user