feat: implement T22, T18, T23 — dev stack, profile tests, server binary

- T22: docker-compose.dev.yml dev stack, Dockerfile, root Makefile
- T18: Profile test suite (Scenario A) — 8 integration tests with real handlers
- T23: Server binary wiring all components, config validation, /healthz
- Config: ValidateConfig with startup validation

14 test packages pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-13 02:18:36 +01:00
parent fa27adbc77
commit c18adb6441
9 changed files with 1345 additions and 2 deletions

11
Dockerfile Normal file
View File

@@ -0,0 +1,11 @@
FROM golang:1.23-alpine AS builder
WORKDIR /app
COPY src/go.mod src/go.sum ./
RUN go mod download
COPY src/ .
RUN CGO_ENABLED=0 go build -o keycape ./cmd/keycape
FROM gcr.io/distroless/static-debian12
COPY --from=builder /app/keycape /keycape
EXPOSE 8080
ENTRYPOINT ["/keycape"]

16
Makefile Normal file
View File

@@ -0,0 +1,16 @@
.PHONY: dev seed build test lint
dev:
docker compose -f docker-compose.dev.yml up
seed:
docker compose -f docker-compose.dev.yml exec lldap /scripts/seed.sh
build:
cd src && go build ./...
test:
cd src && go test ./...
lint:
cd src && go vet ./...

27
config/dev-config.yaml Normal file
View File

@@ -0,0 +1,27 @@
issuer: "http://localhost:8080"
port: 8080
tokenLifetime: "15m"
privateKeyPem: "/etc/keycape/key.pem"
environment: "dev"
lldap:
url: "ldap://lldap:3890"
bindDN: "cn=admin,ou=people,dc=netkingdom,dc=local"
bindPW: "adminpassword"
baseDN: "dc=netkingdom,dc=local"
authelia:
baseURL: "http://authelia:9091"
clientId: "keycape"
clientSecret: "changeme"
redirectURI: "http://localhost:8080/authorize/callback"
privacyidea:
baseURL: "http://privacyidea:80"
adminToken: "changeme"
realm: "netkingdom"
clients:
- clientId: "demo-app"
displayName: "Demo Application"
redirectUris:
- "http://localhost:3000/callback"
allowedScopes: ["openid", "profile", "email", "groups"]
grantTypes: ["authorization_code"]
clientType: "public"

48
docker-compose.dev.yml Normal file
View File

@@ -0,0 +1,48 @@
version: "3.8"
services:
keycape:
build:
context: .
dockerfile: Dockerfile
ports:
- "8080:8080"
volumes:
- ./config/dev-config.yaml:/etc/keycape/config.yaml:ro
- ./config/dev-key.pem:/etc/keycape/key.pem:ro
environment:
- KEYCAPE_CONFIG=/etc/keycape/config.yaml
depends_on:
- lldap
- authelia
lldap:
image: lldap/lldap:stable
ports:
- "17170:17170"
- "3890:3890"
environment:
- LLDAP_JWT_SECRET=devjwtsecret
- LLDAP_LDAP_USER_PASS=adminpassword
- LLDAP_LDAP_BASE_DN=dc=netkingdom,dc=local
volumes:
- lldap_data:/data
authelia:
image: authelia/authelia:latest
ports:
- "9091:9091"
volumes:
- ./config/authelia:/config:ro
environment:
- AUTHELIA_JWT_SECRET=devsecret
privacyidea:
image: khalibre/privacyidea:latest
ports:
- "5000:80"
environment:
- PI_ADMIN_PASSWORD=adminpassword
volumes:
lldap_data:

View File

@@ -4,11 +4,268 @@
package main
import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"keycape/internal/adapters/authelia"
"keycape/internal/adapters/lldap"
"keycape/internal/adapters/privacyidea"
"keycape/internal/config"
"keycape/internal/domain"
servererrors "keycape/internal/server/errors"
"keycape/internal/server/oidc"
"keycape/internal/server/telemetry"
)
const version = "0.1.0"
func main() {
fmt.Fprintln(os.Stderr, "keycape server: not yet implemented (T05+)")
os.Exit(1)
log := zerolog.New(os.Stdout).With().Timestamp().Logger()
// -----------------------------------------------------------------
// 1. Parse flags and load config.
// -----------------------------------------------------------------
var cfgPath string
flag.StringVar(&cfgPath, "config", "", "path to YAML config file (env: KEYCAPE_CONFIG)")
flag.Parse()
cfg, err := config.Load(cfgPath)
if err != nil {
log.Error().Err(err).Msg("failed to load config")
os.Exit(1)
}
// -----------------------------------------------------------------
// 2. Validate config.
// -----------------------------------------------------------------
errs := config.ValidateConfig(cfg)
if len(errs) > 0 {
log.Error().Strs("errors", errs).Msg("config validation failed")
os.Exit(1)
}
// -----------------------------------------------------------------
// 3. Load RSA private key.
// -----------------------------------------------------------------
privateKey, err := loadPrivateKey(cfg.PrivateKeyPEM)
if err != nil {
log.Error().Err(err).Str("path", cfg.PrivateKeyPEM).Msg("failed to load private key")
os.Exit(1)
}
// -----------------------------------------------------------------
// 4. Build JWKS from public key.
// -----------------------------------------------------------------
ks := oidc.NewKeySet()
ks.AddKey("key-1", &privateKey.PublicKey)
// -----------------------------------------------------------------
// 5. Build client registry.
// -----------------------------------------------------------------
clients := buildClientRegistry(cfg.Clients)
// -----------------------------------------------------------------
// 6. Create adapters.
// -----------------------------------------------------------------
lldapAdapter := lldap.New(cfg.LLDAP)
autheliaAdapter := authelia.New(cfg.Authelia, nil)
privacyIDEAAdapter := privacyidea.New(cfg.PrivacyIDEA, nil)
// -----------------------------------------------------------------
// 7. Create telemetry emitter.
// -----------------------------------------------------------------
emitter := telemetry.NewLogEmitter(log)
// -----------------------------------------------------------------
// 8. Create enforcement registry.
// -----------------------------------------------------------------
enforcement := servererrors.DefaultRegistry()
// -----------------------------------------------------------------
// 9. Create session store.
// -----------------------------------------------------------------
sessions := oidc.NewSessionStore()
// -----------------------------------------------------------------
// 10. Parse token lifetime.
// -----------------------------------------------------------------
tokenLifetime := 15 * time.Minute
if cfg.TokenLifetime != "" {
d, err := time.ParseDuration(cfg.TokenLifetime)
if err != nil {
log.Error().Err(err).Str("tokenLifetime", cfg.TokenLifetime).Msg("invalid tokenLifetime")
os.Exit(1)
}
tokenLifetime = d
}
// -----------------------------------------------------------------
// 11. Build issuer base URL.
// -----------------------------------------------------------------
issuer := strings.TrimRight(cfg.Issuer, "/")
// -----------------------------------------------------------------
// 12. Register HTTP handlers.
// -----------------------------------------------------------------
mux := http.NewServeMux()
// Discovery.
mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
JWKSUri: issuer + "/jwks",
UserinfoEndpoint: issuer + "/userinfo",
}))
// JWKS.
mux.Handle("/jwks", oidc.NewJWKSHandler(ks))
// Authorize handler (with enforcement middleware).
authorizeHandler := &oidc.AuthorizeHandler{
ClientConfig: clients,
Auth: autheliaAdapter,
MFA: privacyIDEAAdapter,
Sessions: sessions,
Emitter: emitter,
}
mux.Handle("/authorize", enforcement.Middleware(authorizeHandler))
mux.Handle("/authorize/callback", authorizeHandler)
// Token handler (with enforcement middleware).
tokenHandler := &oidc.TokenHandler{
ClientConfig: clients,
Sessions: sessions,
Users: lldapAdapter,
SigningKey: privateKey,
Issuer: issuer,
TokenLifetime: tokenLifetime,
Emitter: emitter,
}
mux.Handle("/token", enforcement.Middleware(tokenHandler))
// Userinfo handler.
mux.Handle("/userinfo", &oidc.UserinfoHandler{
Users: lldapAdapter,
SigningKey: &privateKey.PublicKey,
Issuer: issuer,
Emitter: emitter,
})
// Healthz.
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "ok",
"version": version,
})
})
// Inject emitter into request context.
handler := withEmitter(mux, emitter)
// -----------------------------------------------------------------
// 13. Start HTTP server.
// -----------------------------------------------------------------
addr := fmt.Sprintf(":%d", cfg.Port)
if cfg.Port == 0 {
addr = ":8080"
}
log.Info().
Str("issuer", issuer).
Str("addr", addr).
Str("environment", cfg.Environment).
Str("version", version).
Msg("starting keycape server")
srv := &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
os.Exit(1)
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// loadPrivateKey reads a PEM file and parses the RSA private key.
// Supports both PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") PEM blocks.
func loadPrivateKey(path string) (*rsa.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read key file: %w", err)
}
block, _ := pem.Decode(data)
if block == nil {
return nil, fmt.Errorf("no PEM block found in %q", path)
}
switch block.Type {
case "RSA PRIVATE KEY":
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse PKCS1 private key: %w", err)
}
return key, nil
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse PKCS8 private key: %w", err)
}
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("private key is not an RSA key")
}
return rsaKey, nil
default:
return nil, fmt.Errorf("unexpected PEM block type %q; expected RSA PRIVATE KEY or PRIVATE KEY", block.Type)
}
}
// buildClientRegistry converts []ClientConfig into the map used by handlers.
func buildClientRegistry(cfgClients []config.ClientConfig) map[string]*domain.Client {
m := make(map[string]*domain.Client, len(cfgClients))
for i := range cfgClients {
c := &cfgClients[i]
m[c.ClientID] = &domain.Client{
ClientID: c.ClientID,
DisplayName: c.DisplayName,
RedirectURIs: c.RedirectURIs,
AllowedScopes: c.AllowedScopes,
GrantTypes: c.GrantTypes,
ClientType: c.ClientType,
SecretRef: c.SecretRef,
}
}
return m
}
// withEmitter wraps a handler to inject the telemetry emitter into every request context.
func withEmitter(next http.Handler, e telemetry.Emitter) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := telemetry.WithEmitter(context.Background(), e)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -0,0 +1,63 @@
// Package config handles loading and validating the KeyCape server configuration
// from a YAML file. The config path is resolved from the --config flag or the
// KEYCAPE_CONFIG environment variable.
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
"keycape/internal/adapters/authelia"
"keycape/internal/adapters/lldap"
"keycape/internal/adapters/privacyidea"
)
// Config is the top-level server configuration.
type Config struct {
Issuer string `yaml:"issuer"`
Port int `yaml:"port"`
TokenLifetime string `yaml:"tokenLifetime"`
PrivateKeyPEM string `yaml:"privateKeyPem"`
LLDAP lldap.Config `yaml:"lldap"`
Authelia authelia.Config `yaml:"authelia"`
PrivacyIDEA privacyidea.Config `yaml:"privacyidea"`
Clients []ClientConfig `yaml:"clients"`
Environment string `yaml:"environment"`
}
// ClientConfig is a static OIDC client registration.
type ClientConfig struct {
ClientID string `yaml:"clientId"`
DisplayName string `yaml:"displayName"`
RedirectURIs []string `yaml:"redirectUris"`
AllowedScopes []string `yaml:"allowedScopes"`
GrantTypes []string `yaml:"grantTypes"`
ClientType string `yaml:"clientType"` // "confidential" | "public"
SecretRef string `yaml:"secretRef,omitempty"`
}
// Load reads and parses the YAML config file at path.
// If path is empty, it falls back to the KEYCAPE_CONFIG environment variable.
// Returns an error if the file cannot be read or parsed.
func Load(path string) (*Config, error) {
if path == "" {
path = os.Getenv("KEYCAPE_CONFIG")
}
if path == "" {
return nil, fmt.Errorf("config: no config path specified (use --config or KEYCAPE_CONFIG)")
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("config: read %q: %w", path, err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("config: parse %q: %w", path, err)
}
return &cfg, nil
}

View File

@@ -0,0 +1,225 @@
package config_test
import (
"os"
"path/filepath"
"testing"
"keycape/internal/config"
)
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// writeTempFile creates a temporary file with the given content and returns its path.
func writeTempFile(t *testing.T, content string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "keycape-test-*")
if err != nil {
t.Fatalf("create temp file: %v", err)
}
if _, err := f.WriteString(content); err != nil {
t.Fatalf("write temp file: %v", err)
}
f.Close()
return f.Name()
}
// validConfig returns a minimal valid Config for use in tests.
func validConfig(keyPath string) *config.Config {
return &config.Config{
Issuer: "https://auth.example.com",
Port: 8080,
TokenLifetime: "15m",
PrivateKeyPEM: keyPath,
Environment: "dev",
Clients: []config.ClientConfig{
{
ClientID: "test-app",
DisplayName: "Test App",
RedirectURIs: []string{"https://app.example.com/callback"},
ClientType: "public",
},
},
}
}
// ---------------------------------------------------------------------------
// Load tests
// ---------------------------------------------------------------------------
func TestLoad_ValidYAML(t *testing.T) {
keyPath := writeTempFile(t, "placeholder-key")
yaml := `
issuer: "https://auth.example.com"
port: 8080
tokenLifetime: "15m"
privateKeyPem: "` + keyPath + `"
environment: "dev"
clients:
- clientId: "demo"
displayName: "Demo"
redirectUris:
- "https://demo.example.com/cb"
clientType: "public"
`
cfgPath := writeTempFile(t, yaml)
cfg, err := config.Load(cfgPath)
if err != nil {
t.Fatalf("Load: unexpected error: %v", err)
}
if cfg.Issuer != "https://auth.example.com" {
t.Errorf("Issuer: want %q, got %q", "https://auth.example.com", cfg.Issuer)
}
if cfg.Port != 8080 {
t.Errorf("Port: want 8080, got %d", cfg.Port)
}
if len(cfg.Clients) != 1 {
t.Errorf("Clients: want 1, got %d", len(cfg.Clients))
}
}
func TestLoad_FileNotFound(t *testing.T) {
_, err := config.Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
if err == nil {
t.Error("Load: expected error for missing file, got nil")
}
}
func TestLoad_InvalidYAML(t *testing.T) {
bad := writeTempFile(t, "not: valid: yaml: [[[")
_, err := config.Load(bad)
if err == nil {
t.Error("Load: expected error for invalid YAML, got nil")
}
}
// ---------------------------------------------------------------------------
// Validate tests
// ---------------------------------------------------------------------------
func TestValidate_ValidConfig(t *testing.T) {
keyPath := writeTempFile(t, "key")
errs := config.ValidateConfig(validConfig(keyPath))
if len(errs) != 0 {
t.Errorf("ValidateConfig: expected no errors, got %v", errs)
}
}
func TestValidate_MissingIssuer(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Issuer = ""
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "issuer") {
t.Errorf("expected issuer error, got %v", errs)
}
}
func TestValidate_InvalidIssuerURL(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Issuer = "not a url"
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "issuer") {
t.Errorf("expected issuer URL error, got %v", errs)
}
}
func TestValidate_PortZero(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Port = 0
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "port") {
t.Errorf("expected port error, got %v", errs)
}
}
func TestValidate_PortTooHigh(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Port = 70000
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "port") {
t.Errorf("expected port error, got %v", errs)
}
}
func TestValidate_NoClients(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Clients = nil
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "client") {
t.Errorf("expected client error, got %v", errs)
}
}
func TestValidate_ClientMissingRedirectURI(t *testing.T) {
keyPath := writeTempFile(t, "key")
cfg := validConfig(keyPath)
cfg.Clients[0].RedirectURIs = nil
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "redirect") {
t.Errorf("expected redirect_uri error, got %v", errs)
}
}
func TestValidate_MissingPrivateKeyPEM(t *testing.T) {
cfg := validConfig("")
cfg.PrivateKeyPEM = ""
errs := config.ValidateConfig(cfg)
if !containsErr(errs, "privateKeyPem") {
t.Errorf("expected privateKeyPem error, got %v", errs)
}
}
// ---------------------------------------------------------------------------
// Env var loading test
// ---------------------------------------------------------------------------
func TestLoad_FromEnvVar(t *testing.T) {
keyPath := writeTempFile(t, "key")
yaml := `
issuer: "https://auth.example.com"
port: 9090
tokenLifetime: "30m"
privateKeyPem: "` + keyPath + `"
environment: "dev"
clients:
- clientId: "env-app"
displayName: "Env App"
redirectUris:
- "https://env.example.com/cb"
clientType: "public"
`
cfgPath := writeTempFile(t, yaml)
t.Setenv("KEYCAPE_CONFIG", cfgPath)
// Load with empty path triggers env var lookup.
cfg, err := config.Load("")
if err != nil {
t.Fatalf("Load with env var: %v", err)
}
if cfg.Port != 9090 {
t.Errorf("Port: want 9090, got %d", cfg.Port)
}
}
// ---------------------------------------------------------------------------
// Helper
// ---------------------------------------------------------------------------
func containsErr(errs []string, substring string) bool {
for _, e := range errs {
for i := 0; i <= len(e)-len(substring); i++ {
if e[i:i+len(substring)] == substring {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,61 @@
package config
import (
"fmt"
"net/url"
"strings"
)
// ValidateConfig validates a loaded Config and returns a list of human-readable
// error messages. An empty slice means the config is valid.
// Called at startup — the server must exit 1 if any errors are returned.
func ValidateConfig(cfg *Config) []string {
var errs []string
// Issuer must be a valid URL with an http(s) scheme.
if cfg.Issuer == "" {
errs = append(errs, "issuer: must not be empty")
} else {
u, err := url.Parse(cfg.Issuer)
if err != nil || u.Scheme == "" || u.Host == "" {
errs = append(errs, fmt.Sprintf("issuer: %q is not a valid URL (must include scheme and host)", cfg.Issuer))
} else if u.Scheme != "http" && u.Scheme != "https" {
errs = append(errs, fmt.Sprintf("issuer: scheme must be http or https, got %q", u.Scheme))
}
}
// Port must be in the valid TCP range.
if cfg.Port < 1 || cfg.Port > 65535 {
errs = append(errs, fmt.Sprintf("port: must be between 1 and 65535, got %d", cfg.Port))
}
// At least one client must be registered.
if len(cfg.Clients) == 0 {
errs = append(errs, "clients: at least one client must be defined")
}
// Each client must have at least one redirect URI and a non-empty clientId.
for i, c := range cfg.Clients {
prefix := fmt.Sprintf("clients[%d] (%s)", i, c.ClientID)
if c.ClientID == "" {
prefix = fmt.Sprintf("clients[%d]", i)
errs = append(errs, prefix+": clientId must not be empty")
}
if len(c.RedirectURIs) == 0 {
errs = append(errs, prefix+": redirect_uri: at least one redirectUri must be registered")
}
// Warn about wildcard redirect URIs (they are blocked at runtime anyway).
for _, uri := range c.RedirectURIs {
if strings.ContainsAny(uri, "*?") {
errs = append(errs, prefix+fmt.Sprintf(": redirect_uri %q must not contain wildcards", uri))
}
}
}
// Private key PEM path must be provided (existence is checked at startup).
if cfg.PrivateKeyPEM == "" {
errs = append(errs, "privateKeyPem: path must not be empty")
}
return errs
}

View File

@@ -0,0 +1,635 @@
// Package profile_test contains integration-style tests for the complete OIDC
// profile (Scenario A from the Acceptance Test Matrix, spec §7). All handler
// implementations are real; only the auth backend adapters are mocked.
package profile_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"keycape/internal/domain"
"keycape/internal/server/errors"
"keycape/internal/server/oidc"
"keycape/internal/server/telemetry"
)
// ---------------------------------------------------------------------------
// Mock adapters
// ---------------------------------------------------------------------------
// mockAuth implements domain.AuthProvider for tests.
type mockAuth struct {
authorizeURL string
callbackUser string
callbackErr error
}
func (m *mockAuth) AuthorizeURL(_ context.Context, req domain.AuthRequest) (string, error) {
if m.authorizeURL != "" {
return m.authorizeURL, nil
}
return "https://authelia.example.com/auth?state=" + req.State, nil
}
func (m *mockAuth) HandleCallback(_ context.Context, params domain.CallbackParams) (*domain.AuthResult, error) {
if m.callbackErr != nil {
return nil, m.callbackErr
}
username := m.callbackUser
if username == "" {
username = "testuser"
}
return &domain.AuthResult{Username: username}, nil
}
// mockMFA implements domain.MFAProvider for tests.
type mockMFA struct {
required bool
checkErr error
mfaErr error
}
func (m *mockMFA) CheckMFARequired(_ context.Context, _ string) (bool, error) {
return m.required, m.checkErr
}
func (m *mockMFA) ValidateMFAToken(_ context.Context, _, _ string) error {
return m.mfaErr
}
// mockUsers implements domain.UserRepository for tests.
type mockUsers struct {
users map[string]*domain.User
}
func newMockUsers() *mockUsers {
return &mockUsers{users: map[string]*domain.User{
"testuser": {
ID: "uid-001",
Username: "testuser",
DisplayName: "Test User",
Email: "testuser@example.com",
Groups: []string{"developers"},
Enabled: true,
},
}}
}
func (m *mockUsers) LookupUser(_ context.Context, username string) (*domain.User, error) {
u, ok := m.users[username]
if !ok {
return nil, domain.ErrUserNotFound
}
return u, nil
}
func (m *mockUsers) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) {
return nil, nil
}
func (m *mockUsers) ValidatePassword(_ context.Context, _, _ string) (bool, error) {
return false, nil
}
func (m *mockUsers) ListUsers(_ context.Context) ([]domain.User, error) {
return nil, nil
}
// ---------------------------------------------------------------------------
// TestServer
// ---------------------------------------------------------------------------
// TestServer wraps an httptest.Server with all the wired-up handlers.
type TestServer struct {
Server *httptest.Server
PrivateKey *rsa.PrivateKey
Sessions *oidc.SessionStore
AuthMock *mockAuth
Clients map[string]*domain.Client
}
func newTestServer(t *testing.T) *TestServer {
t.Helper()
// Generate RSA key pair.
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate RSA key: %v", err)
}
issuer := "http://localhost" // will be overridden with actual server URL after start
// Create test client registry.
clients := map[string]*domain.Client{
"demo-app": {
ClientID: "demo-app",
DisplayName: "Demo Application",
RedirectURIs: []string{"http://localhost:3000/callback"},
AllowedScopes: []string{"openid", "profile", "email", "groups"},
GrantTypes: []string{"authorization_code"},
ClientType: "public",
},
}
// Create mock adapters.
authMock := &mockAuth{}
mfaMock := &mockMFA{required: false}
usersMock := newMockUsers()
// Session store.
sessions := oidc.NewSessionStore()
// Telemetry — noop for tests.
emitter := telemetry.NoopEmitter{}
// Key set.
ks := oidc.NewKeySet()
ks.AddKey("key-1", &privateKey.PublicKey)
// Enforcement registry.
reg := errors.DefaultRegistry()
mux := http.NewServeMux()
// Discovery handler.
mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
JWKSUri: issuer + "/jwks",
UserinfoEndpoint: issuer + "/userinfo",
}))
// JWKS handler.
mux.Handle("/jwks", oidc.NewJWKSHandler(ks))
// Authorize handler (with enforcement middleware).
authorizeHandler := &oidc.AuthorizeHandler{
ClientConfig: clients,
Auth: authMock,
MFA: mfaMock,
Sessions: sessions,
Emitter: emitter,
}
mux.Handle("/authorize", reg.Middleware(authorizeHandler))
mux.Handle("/authorize/callback", authorizeHandler)
// Token handler (with enforcement middleware).
tokenHandler := &oidc.TokenHandler{
ClientConfig: clients,
Sessions: sessions,
Users: usersMock,
SigningKey: privateKey,
Issuer: issuer,
TokenLifetime: 15 * time.Minute,
Emitter: emitter,
}
mux.Handle("/token", reg.Middleware(tokenHandler))
// Userinfo handler.
mux.Handle("/userinfo", &oidc.UserinfoHandler{
Users: usersMock,
SigningKey: &privateKey.PublicKey,
Issuer: issuer,
Emitter: emitter,
})
// Healthz handler.
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"status":"ok","version":"0.1.0"}`))
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return &TestServer{
Server: srv,
PrivateKey: privateKey,
Sessions: sessions,
AuthMock: authMock,
Clients: clients,
}
}
// ---------------------------------------------------------------------------
// PKCE helpers
// ---------------------------------------------------------------------------
func generatePKCE(t *testing.T) (verifier, challenge string) {
t.Helper()
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
t.Fatalf("generate PKCE verifier: %v", err)
}
verifier = base64.RawURLEncoding.EncodeToString(b)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return
}
// ---------------------------------------------------------------------------
// Test cases
// ---------------------------------------------------------------------------
// 1. Discovery test.
func TestDiscovery(t *testing.T) {
ts := newTestServer(t)
resp, err := http.Get(ts.Server.URL + "/.well-known/openid-configuration")
if err != nil {
t.Fatalf("GET /.well-known/openid-configuration: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("status: want 200, got %d", resp.StatusCode)
}
ct := resp.Header.Get("Content-Type")
if !strings.HasPrefix(ct, "application/json") {
t.Errorf("Content-Type: want application/json, got %q", ct)
}
var doc map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
t.Fatalf("decode discovery doc: %v", err)
}
requiredFields := []string{
"issuer", "authorization_endpoint", "token_endpoint", "jwks_uri",
"response_types_supported", "grant_types_supported",
"code_challenge_methods_supported", "id_token_signing_alg_values_supported",
"scopes_supported",
}
for _, f := range requiredFields {
if _, ok := doc[f]; !ok {
t.Errorf("discovery doc missing field %q", f)
}
}
// registration_endpoint must be absent.
if _, ok := doc["registration_endpoint"]; ok {
t.Error("discovery doc must not contain registration_endpoint")
}
// scopes_supported must include openid.
scopes, ok := doc["scopes_supported"].([]interface{})
if !ok {
t.Fatal("scopes_supported is not an array")
}
found := false
for _, s := range scopes {
if s == "openid" {
found = true
break
}
}
if !found {
t.Error("scopes_supported must include openid")
}
}
// 2. JWKS test.
func TestJWKS(t *testing.T) {
ts := newTestServer(t)
resp, err := http.Get(ts.Server.URL + "/jwks")
if err != nil {
t.Fatalf("GET /jwks: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("status: want 200, got %d", resp.StatusCode)
}
var jwks struct {
Keys []struct {
Kty string `json:"kty"`
Alg string `json:"alg"`
Use string `json:"use"`
Kid string `json:"kid"`
N string `json:"n"`
E string `json:"e"`
} `json:"keys"`
}
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
t.Fatalf("decode JWKS: %v", err)
}
if len(jwks.Keys) == 0 {
t.Fatal("JWKS must contain at least one key")
}
key := jwks.Keys[0]
if key.Kty != "RSA" {
t.Errorf("kty: want RSA, got %q", key.Kty)
}
if key.Alg != "RS256" {
t.Errorf("alg: want RS256, got %q", key.Alg)
}
if key.N == "" {
t.Error("n (modulus) must not be empty")
}
if key.E == "" {
t.Error("e (exponent) must not be empty")
}
}
// 3. Authorization redirect test — valid PKCE params → 302 redirect.
func TestAuthorize_Redirect(t *testing.T) {
ts := newTestServer(t)
_, challenge := generatePKCE(t)
q := url.Values{}
q.Set("client_id", "demo-app")
q.Set("redirect_uri", "http://localhost:3000/callback")
q.Set("response_type", "code")
q.Set("scope", "openid profile")
q.Set("state", "test-state-123")
q.Set("code_challenge", challenge)
q.Set("code_challenge_method", "S256")
client := &http.Client{CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse // don't follow redirect
}}
resp, err := client.Get(ts.Server.URL + "/authorize?" + q.Encode())
if err != nil {
t.Fatalf("GET /authorize: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Errorf("status: want 302, got %d", resp.StatusCode)
}
location := resp.Header.Get("Location")
if location == "" {
t.Error("Location header must be set on redirect")
}
}
// 4. Invalid client test — unknown client_id → invalid_profile_usage.
func TestAuthorize_InvalidClient(t *testing.T) {
ts := newTestServer(t)
_, challenge := generatePKCE(t)
q := url.Values{}
q.Set("client_id", "unknown-client")
q.Set("redirect_uri", "http://localhost:3000/callback")
q.Set("response_type", "code")
q.Set("scope", "openid")
q.Set("state", "s")
q.Set("code_challenge", challenge)
q.Set("code_challenge_method", "S256")
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
if err != nil {
t.Fatalf("GET /authorize: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("status: want 400, got %d", resp.StatusCode)
}
var pe map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&pe); err != nil {
t.Fatalf("decode error response: %v", err)
}
errType, _ := pe["error"].(string)
if errType != "invalid_profile_usage" {
t.Errorf("error: want invalid_profile_usage, got %q", errType)
}
}
// 5. Wildcard redirect URI → rejected_for_profile_safety (caught by enforcement middleware).
func TestAuthorize_WildcardRedirectURI(t *testing.T) {
ts := newTestServer(t)
q := url.Values{}
q.Set("client_id", "demo-app")
q.Set("redirect_uri", "https://evil.com/*")
q.Set("response_type", "code")
q.Set("scope", "openid")
q.Set("state", "s")
q.Set("code_challenge", "abc")
q.Set("code_challenge_method", "S256")
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
if err != nil {
t.Fatalf("GET /authorize: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Errorf("status: want 403, got %d", resp.StatusCode)
}
var pe map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&pe)
errType, _ := pe["error"].(string)
if errType != "rejected_for_profile_safety" {
t.Errorf("error: want rejected_for_profile_safety, got %q", errType)
}
}
// 6. Missing PKCE test — no code_challenge → invalid_profile_usage (enforcement middleware).
func TestAuthorize_MissingPKCE(t *testing.T) {
ts := newTestServer(t)
q := url.Values{}
q.Set("client_id", "demo-app")
q.Set("redirect_uri", "http://localhost:3000/callback")
q.Set("response_type", "code")
q.Set("scope", "openid")
q.Set("state", "s")
// No code_challenge
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
if err != nil {
t.Fatalf("GET /authorize: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("status: want 400, got %d", resp.StatusCode)
}
var pe map[string]interface{}
_ = json.NewDecoder(resp.Body).Decode(&pe)
errType, _ := pe["error"].(string)
if errType != "invalid_profile_usage" {
t.Errorf("error: want invalid_profile_usage, got %q", errType)
}
}
// 7. Healthz test.
func TestHealthz(t *testing.T) {
ts := newTestServer(t)
resp, err := http.Get(ts.Server.URL + "/healthz")
if err != nil {
t.Fatalf("GET /healthz: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("status: want 200, got %d", resp.StatusCode)
}
var body map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
t.Fatalf("decode /healthz response: %v", err)
}
if body["status"] != "ok" {
t.Errorf("status field: want ok, got %v", body["status"])
}
}
// 8. Complete token flow test — auth callback + token exchange → valid JWT.
func TestCompleteTokenFlow(t *testing.T) {
ts := newTestServer(t)
verifier, challenge := generatePKCE(t)
// Step 1: Simulate the callback by seeding a pending state and triggering callback.
// We do this by first calling /authorize to create the pending state, then calling
// /authorize/callback with state and a mock code.
q := url.Values{}
q.Set("client_id", "demo-app")
q.Set("redirect_uri", "http://localhost:3000/callback")
q.Set("response_type", "code")
q.Set("scope", "openid profile email groups")
q.Set("state", "flow-state-xyz")
q.Set("code_challenge", challenge)
q.Set("code_challenge_method", "S256")
noRedirectClient := &http.Client{
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
},
}
// /authorize → 302 to upstream auth.
authResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize?" + q.Encode())
if err != nil {
t.Fatalf("GET /authorize: %v", err)
}
authResp.Body.Close()
if authResp.StatusCode != http.StatusFound {
t.Fatalf("authorize: want 302, got %d", authResp.StatusCode)
}
// Step 2: Simulate the upstream callback returning code + state.
cbQ := url.Values{}
cbQ.Set("code", "upstream-auth-code")
cbQ.Set("state", "flow-state-xyz")
cbResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize/callback?" + cbQ.Encode())
if err != nil {
t.Fatalf("GET /authorize/callback: %v", err)
}
cbResp.Body.Close()
if cbResp.StatusCode != http.StatusFound {
t.Fatalf("callback: want 302, got %d", cbResp.StatusCode)
}
// Extract the auth code from the Location redirect to our client.
location := cbResp.Header.Get("Location")
if location == "" {
t.Fatal("callback: no Location header")
}
locURL, err := url.Parse(location)
if err != nil {
t.Fatalf("parse Location URL: %v", err)
}
authCode := locURL.Query().Get("code")
if authCode == "" {
t.Fatalf("no code in callback redirect: %q", location)
}
// Step 3: Exchange the auth code for a token.
tokenForm := url.Values{}
tokenForm.Set("grant_type", "authorization_code")
tokenForm.Set("client_id", "demo-app")
tokenForm.Set("code", authCode)
tokenForm.Set("code_verifier", verifier)
tokenResp, err := http.Post(
ts.Server.URL+"/token",
"application/x-www-form-urlencoded",
strings.NewReader(tokenForm.Encode()),
)
if err != nil {
t.Fatalf("POST /token: %v", err)
}
defer tokenResp.Body.Close()
if tokenResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(tokenResp.Body)
t.Fatalf("token: want 200, got %d; body: %s", tokenResp.StatusCode, body)
}
var tr struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.NewDecoder(tokenResp.Body).Decode(&tr); err != nil {
t.Fatalf("decode token response: %v", err)
}
if tr.AccessToken == "" {
t.Error("access_token must not be empty")
}
if tr.TokenType != "Bearer" {
t.Errorf("token_type: want Bearer, got %q", tr.TokenType)
}
if tr.IDToken == "" {
t.Error("id_token must not be empty")
}
// Verify JWT has 3 parts (header.payload.signature).
parts := strings.Split(tr.IDToken, ".")
if len(parts) != 3 {
t.Errorf("id_token: expected 3 JWT parts, got %d", len(parts))
}
// Decode payload and check required claims.
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("decode JWT payload: %v", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payloadJSON, &claims); err != nil {
t.Fatalf("parse JWT claims: %v", err)
}
requiredClaims := []string{"iss", "sub", "aud", "exp", "iat"}
for _, c := range requiredClaims {
if _, ok := claims[c]; !ok {
t.Errorf("JWT missing claim %q", c)
}
}
if claims["aud"] != "demo-app" {
t.Errorf("aud: want demo-app, got %v", claims["aud"])
}
}