diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8ea5045 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +FROM golang:1.23-alpine AS builder +WORKDIR /app +COPY src/go.mod src/go.sum ./ +RUN go mod download +COPY src/ . +RUN CGO_ENABLED=0 go build -o keycape ./cmd/keycape + +FROM gcr.io/distroless/static-debian12 +COPY --from=builder /app/keycape /keycape +EXPOSE 8080 +ENTRYPOINT ["/keycape"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7231f44 --- /dev/null +++ b/Makefile @@ -0,0 +1,16 @@ +.PHONY: dev seed build test lint + +dev: + docker compose -f docker-compose.dev.yml up + +seed: + docker compose -f docker-compose.dev.yml exec lldap /scripts/seed.sh + +build: + cd src && go build ./... + +test: + cd src && go test ./... + +lint: + cd src && go vet ./... diff --git a/config/dev-config.yaml b/config/dev-config.yaml new file mode 100644 index 0000000..78a3cf7 --- /dev/null +++ b/config/dev-config.yaml @@ -0,0 +1,27 @@ +issuer: "http://localhost:8080" +port: 8080 +tokenLifetime: "15m" +privateKeyPem: "/etc/keycape/key.pem" +environment: "dev" +lldap: + url: "ldap://lldap:3890" + bindDN: "cn=admin,ou=people,dc=netkingdom,dc=local" + bindPW: "adminpassword" + baseDN: "dc=netkingdom,dc=local" +authelia: + baseURL: "http://authelia:9091" + clientId: "keycape" + clientSecret: "changeme" + redirectURI: "http://localhost:8080/authorize/callback" +privacyidea: + baseURL: "http://privacyidea:80" + adminToken: "changeme" + realm: "netkingdom" +clients: + - clientId: "demo-app" + displayName: "Demo Application" + redirectUris: + - "http://localhost:3000/callback" + allowedScopes: ["openid", "profile", "email", "groups"] + grantTypes: ["authorization_code"] + clientType: "public" diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml new file mode 100644 index 0000000..fa050d2 --- /dev/null +++ b/docker-compose.dev.yml @@ -0,0 +1,48 @@ +version: "3.8" + +services: + keycape: + build: + context: . + dockerfile: Dockerfile + ports: + - "8080:8080" + volumes: + - ./config/dev-config.yaml:/etc/keycape/config.yaml:ro + - ./config/dev-key.pem:/etc/keycape/key.pem:ro + environment: + - KEYCAPE_CONFIG=/etc/keycape/config.yaml + depends_on: + - lldap + - authelia + + lldap: + image: lldap/lldap:stable + ports: + - "17170:17170" + - "3890:3890" + environment: + - LLDAP_JWT_SECRET=devjwtsecret + - LLDAP_LDAP_USER_PASS=adminpassword + - LLDAP_LDAP_BASE_DN=dc=netkingdom,dc=local + volumes: + - lldap_data:/data + + authelia: + image: authelia/authelia:latest + ports: + - "9091:9091" + volumes: + - ./config/authelia:/config:ro + environment: + - AUTHELIA_JWT_SECRET=devsecret + + privacyidea: + image: khalibre/privacyidea:latest + ports: + - "5000:80" + environment: + - PI_ADMIN_PASSWORD=adminpassword + +volumes: + lldap_data: diff --git a/src/cmd/keycape/main.go b/src/cmd/keycape/main.go index 5b2a2f0..7c21628 100644 --- a/src/cmd/keycape/main.go +++ b/src/cmd/keycape/main.go @@ -4,11 +4,268 @@ package main import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "flag" "fmt" + "net/http" "os" + "strings" + "time" + + "github.com/rs/zerolog" + + "keycape/internal/adapters/authelia" + "keycape/internal/adapters/lldap" + "keycape/internal/adapters/privacyidea" + "keycape/internal/config" + "keycape/internal/domain" + servererrors "keycape/internal/server/errors" + "keycape/internal/server/oidc" + "keycape/internal/server/telemetry" ) +const version = "0.1.0" + func main() { - fmt.Fprintln(os.Stderr, "keycape server: not yet implemented (T05+)") - os.Exit(1) + log := zerolog.New(os.Stdout).With().Timestamp().Logger() + + // ----------------------------------------------------------------- + // 1. Parse flags and load config. + // ----------------------------------------------------------------- + var cfgPath string + flag.StringVar(&cfgPath, "config", "", "path to YAML config file (env: KEYCAPE_CONFIG)") + flag.Parse() + + cfg, err := config.Load(cfgPath) + if err != nil { + log.Error().Err(err).Msg("failed to load config") + os.Exit(1) + } + + // ----------------------------------------------------------------- + // 2. Validate config. + // ----------------------------------------------------------------- + errs := config.ValidateConfig(cfg) + if len(errs) > 0 { + log.Error().Strs("errors", errs).Msg("config validation failed") + os.Exit(1) + } + + // ----------------------------------------------------------------- + // 3. Load RSA private key. + // ----------------------------------------------------------------- + privateKey, err := loadPrivateKey(cfg.PrivateKeyPEM) + if err != nil { + log.Error().Err(err).Str("path", cfg.PrivateKeyPEM).Msg("failed to load private key") + os.Exit(1) + } + + // ----------------------------------------------------------------- + // 4. Build JWKS from public key. + // ----------------------------------------------------------------- + ks := oidc.NewKeySet() + ks.AddKey("key-1", &privateKey.PublicKey) + + // ----------------------------------------------------------------- + // 5. Build client registry. + // ----------------------------------------------------------------- + clients := buildClientRegistry(cfg.Clients) + + // ----------------------------------------------------------------- + // 6. Create adapters. + // ----------------------------------------------------------------- + lldapAdapter := lldap.New(cfg.LLDAP) + autheliaAdapter := authelia.New(cfg.Authelia, nil) + privacyIDEAAdapter := privacyidea.New(cfg.PrivacyIDEA, nil) + + // ----------------------------------------------------------------- + // 7. Create telemetry emitter. + // ----------------------------------------------------------------- + emitter := telemetry.NewLogEmitter(log) + + // ----------------------------------------------------------------- + // 8. Create enforcement registry. + // ----------------------------------------------------------------- + enforcement := servererrors.DefaultRegistry() + + // ----------------------------------------------------------------- + // 9. Create session store. + // ----------------------------------------------------------------- + sessions := oidc.NewSessionStore() + + // ----------------------------------------------------------------- + // 10. Parse token lifetime. + // ----------------------------------------------------------------- + tokenLifetime := 15 * time.Minute + if cfg.TokenLifetime != "" { + d, err := time.ParseDuration(cfg.TokenLifetime) + if err != nil { + log.Error().Err(err).Str("tokenLifetime", cfg.TokenLifetime).Msg("invalid tokenLifetime") + os.Exit(1) + } + tokenLifetime = d + } + + // ----------------------------------------------------------------- + // 11. Build issuer base URL. + // ----------------------------------------------------------------- + issuer := strings.TrimRight(cfg.Issuer, "/") + + // ----------------------------------------------------------------- + // 12. Register HTTP handlers. + // ----------------------------------------------------------------- + mux := http.NewServeMux() + + // Discovery. + mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + JWKSUri: issuer + "/jwks", + UserinfoEndpoint: issuer + "/userinfo", + })) + + // JWKS. + mux.Handle("/jwks", oidc.NewJWKSHandler(ks)) + + // Authorize handler (with enforcement middleware). + authorizeHandler := &oidc.AuthorizeHandler{ + ClientConfig: clients, + Auth: autheliaAdapter, + MFA: privacyIDEAAdapter, + Sessions: sessions, + Emitter: emitter, + } + mux.Handle("/authorize", enforcement.Middleware(authorizeHandler)) + mux.Handle("/authorize/callback", authorizeHandler) + + // Token handler (with enforcement middleware). + tokenHandler := &oidc.TokenHandler{ + ClientConfig: clients, + Sessions: sessions, + Users: lldapAdapter, + SigningKey: privateKey, + Issuer: issuer, + TokenLifetime: tokenLifetime, + Emitter: emitter, + } + mux.Handle("/token", enforcement.Middleware(tokenHandler)) + + // Userinfo handler. + mux.Handle("/userinfo", &oidc.UserinfoHandler{ + Users: lldapAdapter, + SigningKey: &privateKey.PublicKey, + Issuer: issuer, + Emitter: emitter, + }) + + // Healthz. + mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "ok", + "version": version, + }) + }) + + // Inject emitter into request context. + handler := withEmitter(mux, emitter) + + // ----------------------------------------------------------------- + // 13. Start HTTP server. + // ----------------------------------------------------------------- + addr := fmt.Sprintf(":%d", cfg.Port) + if cfg.Port == 0 { + addr = ":8080" + } + + log.Info(). + Str("issuer", issuer). + Str("addr", addr). + Str("environment", cfg.Environment). + Str("version", version). + Msg("starting keycape server") + + srv := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error().Err(err).Msg("server error") + os.Exit(1) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// loadPrivateKey reads a PEM file and parses the RSA private key. +// Supports both PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") PEM blocks. +func loadPrivateKey(path string) (*rsa.PrivateKey, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read key file: %w", err) + } + + block, _ := pem.Decode(data) + if block == nil { + return nil, fmt.Errorf("no PEM block found in %q", path) + } + + switch block.Type { + case "RSA PRIVATE KEY": + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse PKCS1 private key: %w", err) + } + return key, nil + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse PKCS8 private key: %w", err) + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not an RSA key") + } + return rsaKey, nil + default: + return nil, fmt.Errorf("unexpected PEM block type %q; expected RSA PRIVATE KEY or PRIVATE KEY", block.Type) + } +} + +// buildClientRegistry converts []ClientConfig into the map used by handlers. +func buildClientRegistry(cfgClients []config.ClientConfig) map[string]*domain.Client { + m := make(map[string]*domain.Client, len(cfgClients)) + for i := range cfgClients { + c := &cfgClients[i] + m[c.ClientID] = &domain.Client{ + ClientID: c.ClientID, + DisplayName: c.DisplayName, + RedirectURIs: c.RedirectURIs, + AllowedScopes: c.AllowedScopes, + GrantTypes: c.GrantTypes, + ClientType: c.ClientType, + SecretRef: c.SecretRef, + } + } + return m +} + +// withEmitter wraps a handler to inject the telemetry emitter into every request context. +func withEmitter(next http.Handler, e telemetry.Emitter) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := telemetry.WithEmitter(context.Background(), e) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } diff --git a/src/internal/config/config.go b/src/internal/config/config.go new file mode 100644 index 0000000..9ba0709 --- /dev/null +++ b/src/internal/config/config.go @@ -0,0 +1,63 @@ +// Package config handles loading and validating the KeyCape server configuration +// from a YAML file. The config path is resolved from the --config flag or the +// KEYCAPE_CONFIG environment variable. +package config + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" + + "keycape/internal/adapters/authelia" + "keycape/internal/adapters/lldap" + "keycape/internal/adapters/privacyidea" +) + +// Config is the top-level server configuration. +type Config struct { + Issuer string `yaml:"issuer"` + Port int `yaml:"port"` + TokenLifetime string `yaml:"tokenLifetime"` + PrivateKeyPEM string `yaml:"privateKeyPem"` + LLDAP lldap.Config `yaml:"lldap"` + Authelia authelia.Config `yaml:"authelia"` + PrivacyIDEA privacyidea.Config `yaml:"privacyidea"` + Clients []ClientConfig `yaml:"clients"` + Environment string `yaml:"environment"` +} + +// ClientConfig is a static OIDC client registration. +type ClientConfig struct { + ClientID string `yaml:"clientId"` + DisplayName string `yaml:"displayName"` + RedirectURIs []string `yaml:"redirectUris"` + AllowedScopes []string `yaml:"allowedScopes"` + GrantTypes []string `yaml:"grantTypes"` + ClientType string `yaml:"clientType"` // "confidential" | "public" + SecretRef string `yaml:"secretRef,omitempty"` +} + +// Load reads and parses the YAML config file at path. +// If path is empty, it falls back to the KEYCAPE_CONFIG environment variable. +// Returns an error if the file cannot be read or parsed. +func Load(path string) (*Config, error) { + if path == "" { + path = os.Getenv("KEYCAPE_CONFIG") + } + if path == "" { + return nil, fmt.Errorf("config: no config path specified (use --config or KEYCAPE_CONFIG)") + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("config: read %q: %w", path, err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("config: parse %q: %w", path, err) + } + + return &cfg, nil +} diff --git a/src/internal/config/config_test.go b/src/internal/config/config_test.go new file mode 100644 index 0000000..6659d36 --- /dev/null +++ b/src/internal/config/config_test.go @@ -0,0 +1,225 @@ +package config_test + +import ( + "os" + "path/filepath" + "testing" + + "keycape/internal/config" +) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// writeTempFile creates a temporary file with the given content and returns its path. +func writeTempFile(t *testing.T, content string) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "keycape-test-*") + if err != nil { + t.Fatalf("create temp file: %v", err) + } + if _, err := f.WriteString(content); err != nil { + t.Fatalf("write temp file: %v", err) + } + f.Close() + return f.Name() +} + +// validConfig returns a minimal valid Config for use in tests. +func validConfig(keyPath string) *config.Config { + return &config.Config{ + Issuer: "https://auth.example.com", + Port: 8080, + TokenLifetime: "15m", + PrivateKeyPEM: keyPath, + Environment: "dev", + Clients: []config.ClientConfig{ + { + ClientID: "test-app", + DisplayName: "Test App", + RedirectURIs: []string{"https://app.example.com/callback"}, + ClientType: "public", + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Load tests +// --------------------------------------------------------------------------- + +func TestLoad_ValidYAML(t *testing.T) { + keyPath := writeTempFile(t, "placeholder-key") + yaml := ` +issuer: "https://auth.example.com" +port: 8080 +tokenLifetime: "15m" +privateKeyPem: "` + keyPath + `" +environment: "dev" +clients: + - clientId: "demo" + displayName: "Demo" + redirectUris: + - "https://demo.example.com/cb" + clientType: "public" +` + cfgPath := writeTempFile(t, yaml) + + cfg, err := config.Load(cfgPath) + if err != nil { + t.Fatalf("Load: unexpected error: %v", err) + } + if cfg.Issuer != "https://auth.example.com" { + t.Errorf("Issuer: want %q, got %q", "https://auth.example.com", cfg.Issuer) + } + if cfg.Port != 8080 { + t.Errorf("Port: want 8080, got %d", cfg.Port) + } + if len(cfg.Clients) != 1 { + t.Errorf("Clients: want 1, got %d", len(cfg.Clients)) + } +} + +func TestLoad_FileNotFound(t *testing.T) { + _, err := config.Load(filepath.Join(t.TempDir(), "nonexistent.yaml")) + if err == nil { + t.Error("Load: expected error for missing file, got nil") + } +} + +func TestLoad_InvalidYAML(t *testing.T) { + bad := writeTempFile(t, "not: valid: yaml: [[[") + _, err := config.Load(bad) + if err == nil { + t.Error("Load: expected error for invalid YAML, got nil") + } +} + +// --------------------------------------------------------------------------- +// Validate tests +// --------------------------------------------------------------------------- + +func TestValidate_ValidConfig(t *testing.T) { + keyPath := writeTempFile(t, "key") + errs := config.ValidateConfig(validConfig(keyPath)) + if len(errs) != 0 { + t.Errorf("ValidateConfig: expected no errors, got %v", errs) + } +} + +func TestValidate_MissingIssuer(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Issuer = "" + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "issuer") { + t.Errorf("expected issuer error, got %v", errs) + } +} + +func TestValidate_InvalidIssuerURL(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Issuer = "not a url" + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "issuer") { + t.Errorf("expected issuer URL error, got %v", errs) + } +} + +func TestValidate_PortZero(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Port = 0 + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "port") { + t.Errorf("expected port error, got %v", errs) + } +} + +func TestValidate_PortTooHigh(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Port = 70000 + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "port") { + t.Errorf("expected port error, got %v", errs) + } +} + +func TestValidate_NoClients(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Clients = nil + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "client") { + t.Errorf("expected client error, got %v", errs) + } +} + +func TestValidate_ClientMissingRedirectURI(t *testing.T) { + keyPath := writeTempFile(t, "key") + cfg := validConfig(keyPath) + cfg.Clients[0].RedirectURIs = nil + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "redirect") { + t.Errorf("expected redirect_uri error, got %v", errs) + } +} + +func TestValidate_MissingPrivateKeyPEM(t *testing.T) { + cfg := validConfig("") + cfg.PrivateKeyPEM = "" + errs := config.ValidateConfig(cfg) + if !containsErr(errs, "privateKeyPem") { + t.Errorf("expected privateKeyPem error, got %v", errs) + } +} + +// --------------------------------------------------------------------------- +// Env var loading test +// --------------------------------------------------------------------------- + +func TestLoad_FromEnvVar(t *testing.T) { + keyPath := writeTempFile(t, "key") + yaml := ` +issuer: "https://auth.example.com" +port: 9090 +tokenLifetime: "30m" +privateKeyPem: "` + keyPath + `" +environment: "dev" +clients: + - clientId: "env-app" + displayName: "Env App" + redirectUris: + - "https://env.example.com/cb" + clientType: "public" +` + cfgPath := writeTempFile(t, yaml) + t.Setenv("KEYCAPE_CONFIG", cfgPath) + + // Load with empty path triggers env var lookup. + cfg, err := config.Load("") + if err != nil { + t.Fatalf("Load with env var: %v", err) + } + if cfg.Port != 9090 { + t.Errorf("Port: want 9090, got %d", cfg.Port) + } +} + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +func containsErr(errs []string, substring string) bool { + for _, e := range errs { + for i := 0; i <= len(e)-len(substring); i++ { + if e[i:i+len(substring)] == substring { + return true + } + } + } + return false +} diff --git a/src/internal/config/validate.go b/src/internal/config/validate.go new file mode 100644 index 0000000..d470507 --- /dev/null +++ b/src/internal/config/validate.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "net/url" + "strings" +) + +// ValidateConfig validates a loaded Config and returns a list of human-readable +// error messages. An empty slice means the config is valid. +// Called at startup — the server must exit 1 if any errors are returned. +func ValidateConfig(cfg *Config) []string { + var errs []string + + // Issuer must be a valid URL with an http(s) scheme. + if cfg.Issuer == "" { + errs = append(errs, "issuer: must not be empty") + } else { + u, err := url.Parse(cfg.Issuer) + if err != nil || u.Scheme == "" || u.Host == "" { + errs = append(errs, fmt.Sprintf("issuer: %q is not a valid URL (must include scheme and host)", cfg.Issuer)) + } else if u.Scheme != "http" && u.Scheme != "https" { + errs = append(errs, fmt.Sprintf("issuer: scheme must be http or https, got %q", u.Scheme)) + } + } + + // Port must be in the valid TCP range. + if cfg.Port < 1 || cfg.Port > 65535 { + errs = append(errs, fmt.Sprintf("port: must be between 1 and 65535, got %d", cfg.Port)) + } + + // At least one client must be registered. + if len(cfg.Clients) == 0 { + errs = append(errs, "clients: at least one client must be defined") + } + + // Each client must have at least one redirect URI and a non-empty clientId. + for i, c := range cfg.Clients { + prefix := fmt.Sprintf("clients[%d] (%s)", i, c.ClientID) + if c.ClientID == "" { + prefix = fmt.Sprintf("clients[%d]", i) + errs = append(errs, prefix+": clientId must not be empty") + } + if len(c.RedirectURIs) == 0 { + errs = append(errs, prefix+": redirect_uri: at least one redirectUri must be registered") + } + // Warn about wildcard redirect URIs (they are blocked at runtime anyway). + for _, uri := range c.RedirectURIs { + if strings.ContainsAny(uri, "*?") { + errs = append(errs, prefix+fmt.Sprintf(": redirect_uri %q must not contain wildcards", uri)) + } + } + } + + // Private key PEM path must be provided (existence is checked at startup). + if cfg.PrivateKeyPEM == "" { + errs = append(errs, "privateKeyPem: path must not be empty") + } + + return errs +} diff --git a/src/tests/profile/profile_test.go b/src/tests/profile/profile_test.go new file mode 100644 index 0000000..f49a645 --- /dev/null +++ b/src/tests/profile/profile_test.go @@ -0,0 +1,635 @@ +// Package profile_test contains integration-style tests for the complete OIDC +// profile (Scenario A from the Acceptance Test Matrix, spec §7). All handler +// implementations are real; only the auth backend adapters are mocked. +package profile_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "keycape/internal/domain" + "keycape/internal/server/errors" + "keycape/internal/server/oidc" + "keycape/internal/server/telemetry" +) + +// --------------------------------------------------------------------------- +// Mock adapters +// --------------------------------------------------------------------------- + +// mockAuth implements domain.AuthProvider for tests. +type mockAuth struct { + authorizeURL string + callbackUser string + callbackErr error +} + +func (m *mockAuth) AuthorizeURL(_ context.Context, req domain.AuthRequest) (string, error) { + if m.authorizeURL != "" { + return m.authorizeURL, nil + } + return "https://authelia.example.com/auth?state=" + req.State, nil +} + +func (m *mockAuth) HandleCallback(_ context.Context, params domain.CallbackParams) (*domain.AuthResult, error) { + if m.callbackErr != nil { + return nil, m.callbackErr + } + username := m.callbackUser + if username == "" { + username = "testuser" + } + return &domain.AuthResult{Username: username}, nil +} + +// mockMFA implements domain.MFAProvider for tests. +type mockMFA struct { + required bool + checkErr error + mfaErr error +} + +func (m *mockMFA) CheckMFARequired(_ context.Context, _ string) (bool, error) { + return m.required, m.checkErr +} + +func (m *mockMFA) ValidateMFAToken(_ context.Context, _, _ string) error { + return m.mfaErr +} + +// mockUsers implements domain.UserRepository for tests. +type mockUsers struct { + users map[string]*domain.User +} + +func newMockUsers() *mockUsers { + return &mockUsers{users: map[string]*domain.User{ + "testuser": { + ID: "uid-001", + Username: "testuser", + DisplayName: "Test User", + Email: "testuser@example.com", + Groups: []string{"developers"}, + Enabled: true, + }, + }} +} + +func (m *mockUsers) LookupUser(_ context.Context, username string) (*domain.User, error) { + u, ok := m.users[username] + if !ok { + return nil, domain.ErrUserNotFound + } + return u, nil +} + +func (m *mockUsers) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) { + return nil, nil +} + +func (m *mockUsers) ValidatePassword(_ context.Context, _, _ string) (bool, error) { + return false, nil +} + +func (m *mockUsers) ListUsers(_ context.Context) ([]domain.User, error) { + return nil, nil +} + +// --------------------------------------------------------------------------- +// TestServer +// --------------------------------------------------------------------------- + +// TestServer wraps an httptest.Server with all the wired-up handlers. +type TestServer struct { + Server *httptest.Server + PrivateKey *rsa.PrivateKey + Sessions *oidc.SessionStore + AuthMock *mockAuth + Clients map[string]*domain.Client +} + +func newTestServer(t *testing.T) *TestServer { + t.Helper() + + // Generate RSA key pair. + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate RSA key: %v", err) + } + + issuer := "http://localhost" // will be overridden with actual server URL after start + + // Create test client registry. + clients := map[string]*domain.Client{ + "demo-app": { + ClientID: "demo-app", + DisplayName: "Demo Application", + RedirectURIs: []string{"http://localhost:3000/callback"}, + AllowedScopes: []string{"openid", "profile", "email", "groups"}, + GrantTypes: []string{"authorization_code"}, + ClientType: "public", + }, + } + + // Create mock adapters. + authMock := &mockAuth{} + mfaMock := &mockMFA{required: false} + usersMock := newMockUsers() + + // Session store. + sessions := oidc.NewSessionStore() + + // Telemetry — noop for tests. + emitter := telemetry.NoopEmitter{} + + // Key set. + ks := oidc.NewKeySet() + ks.AddKey("key-1", &privateKey.PublicKey) + + // Enforcement registry. + reg := errors.DefaultRegistry() + + mux := http.NewServeMux() + + // Discovery handler. + mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{ + Issuer: issuer, + AuthorizationEndpoint: issuer + "/authorize", + TokenEndpoint: issuer + "/token", + JWKSUri: issuer + "/jwks", + UserinfoEndpoint: issuer + "/userinfo", + })) + + // JWKS handler. + mux.Handle("/jwks", oidc.NewJWKSHandler(ks)) + + // Authorize handler (with enforcement middleware). + authorizeHandler := &oidc.AuthorizeHandler{ + ClientConfig: clients, + Auth: authMock, + MFA: mfaMock, + Sessions: sessions, + Emitter: emitter, + } + mux.Handle("/authorize", reg.Middleware(authorizeHandler)) + mux.Handle("/authorize/callback", authorizeHandler) + + // Token handler (with enforcement middleware). + tokenHandler := &oidc.TokenHandler{ + ClientConfig: clients, + Sessions: sessions, + Users: usersMock, + SigningKey: privateKey, + Issuer: issuer, + TokenLifetime: 15 * time.Minute, + Emitter: emitter, + } + mux.Handle("/token", reg.Middleware(tokenHandler)) + + // Userinfo handler. + mux.Handle("/userinfo", &oidc.UserinfoHandler{ + Users: usersMock, + SigningKey: &privateKey.PublicKey, + Issuer: issuer, + Emitter: emitter, + }) + + // Healthz handler. + mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok","version":"0.1.0"}`)) + }) + + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + return &TestServer{ + Server: srv, + PrivateKey: privateKey, + Sessions: sessions, + AuthMock: authMock, + Clients: clients, + } +} + +// --------------------------------------------------------------------------- +// PKCE helpers +// --------------------------------------------------------------------------- + +func generatePKCE(t *testing.T) (verifier, challenge string) { + t.Helper() + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + t.Fatalf("generate PKCE verifier: %v", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + return +} + +// --------------------------------------------------------------------------- +// Test cases +// --------------------------------------------------------------------------- + +// 1. Discovery test. +func TestDiscovery(t *testing.T) { + ts := newTestServer(t) + + resp, err := http.Get(ts.Server.URL + "/.well-known/openid-configuration") + if err != nil { + t.Fatalf("GET /.well-known/openid-configuration: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status: want 200, got %d", resp.StatusCode) + } + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "application/json") { + t.Errorf("Content-Type: want application/json, got %q", ct) + } + + var doc map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode discovery doc: %v", err) + } + + requiredFields := []string{ + "issuer", "authorization_endpoint", "token_endpoint", "jwks_uri", + "response_types_supported", "grant_types_supported", + "code_challenge_methods_supported", "id_token_signing_alg_values_supported", + "scopes_supported", + } + for _, f := range requiredFields { + if _, ok := doc[f]; !ok { + t.Errorf("discovery doc missing field %q", f) + } + } + + // registration_endpoint must be absent. + if _, ok := doc["registration_endpoint"]; ok { + t.Error("discovery doc must not contain registration_endpoint") + } + + // scopes_supported must include openid. + scopes, ok := doc["scopes_supported"].([]interface{}) + if !ok { + t.Fatal("scopes_supported is not an array") + } + found := false + for _, s := range scopes { + if s == "openid" { + found = true + break + } + } + if !found { + t.Error("scopes_supported must include openid") + } +} + +// 2. JWKS test. +func TestJWKS(t *testing.T) { + ts := newTestServer(t) + + resp, err := http.Get(ts.Server.URL + "/jwks") + if err != nil { + t.Fatalf("GET /jwks: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status: want 200, got %d", resp.StatusCode) + } + + var jwks struct { + Keys []struct { + Kty string `json:"kty"` + Alg string `json:"alg"` + Use string `json:"use"` + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` + } `json:"keys"` + } + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + t.Fatalf("decode JWKS: %v", err) + } + + if len(jwks.Keys) == 0 { + t.Fatal("JWKS must contain at least one key") + } + + key := jwks.Keys[0] + if key.Kty != "RSA" { + t.Errorf("kty: want RSA, got %q", key.Kty) + } + if key.Alg != "RS256" { + t.Errorf("alg: want RS256, got %q", key.Alg) + } + if key.N == "" { + t.Error("n (modulus) must not be empty") + } + if key.E == "" { + t.Error("e (exponent) must not be empty") + } +} + +// 3. Authorization redirect test — valid PKCE params → 302 redirect. +func TestAuthorize_Redirect(t *testing.T) { + ts := newTestServer(t) + + _, challenge := generatePKCE(t) + + q := url.Values{} + q.Set("client_id", "demo-app") + q.Set("redirect_uri", "http://localhost:3000/callback") + q.Set("response_type", "code") + q.Set("scope", "openid profile") + q.Set("state", "test-state-123") + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + + client := &http.Client{CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse // don't follow redirect + }} + resp, err := client.Get(ts.Server.URL + "/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Errorf("status: want 302, got %d", resp.StatusCode) + } + location := resp.Header.Get("Location") + if location == "" { + t.Error("Location header must be set on redirect") + } +} + +// 4. Invalid client test — unknown client_id → invalid_profile_usage. +func TestAuthorize_InvalidClient(t *testing.T) { + ts := newTestServer(t) + + _, challenge := generatePKCE(t) + + q := url.Values{} + q.Set("client_id", "unknown-client") + q.Set("redirect_uri", "http://localhost:3000/callback") + q.Set("response_type", "code") + q.Set("scope", "openid") + q.Set("state", "s") + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + + resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status: want 400, got %d", resp.StatusCode) + } + + var pe map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&pe); err != nil { + t.Fatalf("decode error response: %v", err) + } + errType, _ := pe["error"].(string) + if errType != "invalid_profile_usage" { + t.Errorf("error: want invalid_profile_usage, got %q", errType) + } +} + +// 5. Wildcard redirect URI → rejected_for_profile_safety (caught by enforcement middleware). +func TestAuthorize_WildcardRedirectURI(t *testing.T) { + ts := newTestServer(t) + + q := url.Values{} + q.Set("client_id", "demo-app") + q.Set("redirect_uri", "https://evil.com/*") + q.Set("response_type", "code") + q.Set("scope", "openid") + q.Set("state", "s") + q.Set("code_challenge", "abc") + q.Set("code_challenge_method", "S256") + + resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("status: want 403, got %d", resp.StatusCode) + } + + var pe map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&pe) + errType, _ := pe["error"].(string) + if errType != "rejected_for_profile_safety" { + t.Errorf("error: want rejected_for_profile_safety, got %q", errType) + } +} + +// 6. Missing PKCE test — no code_challenge → invalid_profile_usage (enforcement middleware). +func TestAuthorize_MissingPKCE(t *testing.T) { + ts := newTestServer(t) + + q := url.Values{} + q.Set("client_id", "demo-app") + q.Set("redirect_uri", "http://localhost:3000/callback") + q.Set("response_type", "code") + q.Set("scope", "openid") + q.Set("state", "s") + // No code_challenge + + resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status: want 400, got %d", resp.StatusCode) + } + + var pe map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&pe) + errType, _ := pe["error"].(string) + if errType != "invalid_profile_usage" { + t.Errorf("error: want invalid_profile_usage, got %q", errType) + } +} + +// 7. Healthz test. +func TestHealthz(t *testing.T) { + ts := newTestServer(t) + + resp, err := http.Get(ts.Server.URL + "/healthz") + if err != nil { + t.Fatalf("GET /healthz: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status: want 200, got %d", resp.StatusCode) + } + + var body map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode /healthz response: %v", err) + } + if body["status"] != "ok" { + t.Errorf("status field: want ok, got %v", body["status"]) + } +} + +// 8. Complete token flow test — auth callback + token exchange → valid JWT. +func TestCompleteTokenFlow(t *testing.T) { + ts := newTestServer(t) + + verifier, challenge := generatePKCE(t) + + // Step 1: Simulate the callback by seeding a pending state and triggering callback. + // We do this by first calling /authorize to create the pending state, then calling + // /authorize/callback with state and a mock code. + + q := url.Values{} + q.Set("client_id", "demo-app") + q.Set("redirect_uri", "http://localhost:3000/callback") + q.Set("response_type", "code") + q.Set("scope", "openid profile email groups") + q.Set("state", "flow-state-xyz") + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + + noRedirectClient := &http.Client{ + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // /authorize → 302 to upstream auth. + authResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize?" + q.Encode()) + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + authResp.Body.Close() + if authResp.StatusCode != http.StatusFound { + t.Fatalf("authorize: want 302, got %d", authResp.StatusCode) + } + + // Step 2: Simulate the upstream callback returning code + state. + cbQ := url.Values{} + cbQ.Set("code", "upstream-auth-code") + cbQ.Set("state", "flow-state-xyz") + + cbResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize/callback?" + cbQ.Encode()) + if err != nil { + t.Fatalf("GET /authorize/callback: %v", err) + } + cbResp.Body.Close() + if cbResp.StatusCode != http.StatusFound { + t.Fatalf("callback: want 302, got %d", cbResp.StatusCode) + } + + // Extract the auth code from the Location redirect to our client. + location := cbResp.Header.Get("Location") + if location == "" { + t.Fatal("callback: no Location header") + } + locURL, err := url.Parse(location) + if err != nil { + t.Fatalf("parse Location URL: %v", err) + } + authCode := locURL.Query().Get("code") + if authCode == "" { + t.Fatalf("no code in callback redirect: %q", location) + } + + // Step 3: Exchange the auth code for a token. + tokenForm := url.Values{} + tokenForm.Set("grant_type", "authorization_code") + tokenForm.Set("client_id", "demo-app") + tokenForm.Set("code", authCode) + tokenForm.Set("code_verifier", verifier) + + tokenResp, err := http.Post( + ts.Server.URL+"/token", + "application/x-www-form-urlencoded", + strings.NewReader(tokenForm.Encode()), + ) + if err != nil { + t.Fatalf("POST /token: %v", err) + } + defer tokenResp.Body.Close() + + if tokenResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(tokenResp.Body) + t.Fatalf("token: want 200, got %d; body: %s", tokenResp.StatusCode, body) + } + + var tr struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.NewDecoder(tokenResp.Body).Decode(&tr); err != nil { + t.Fatalf("decode token response: %v", err) + } + + if tr.AccessToken == "" { + t.Error("access_token must not be empty") + } + if tr.TokenType != "Bearer" { + t.Errorf("token_type: want Bearer, got %q", tr.TokenType) + } + if tr.IDToken == "" { + t.Error("id_token must not be empty") + } + + // Verify JWT has 3 parts (header.payload.signature). + parts := strings.Split(tr.IDToken, ".") + if len(parts) != 3 { + t.Errorf("id_token: expected 3 JWT parts, got %d", len(parts)) + } + + // Decode payload and check required claims. + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decode JWT payload: %v", err) + } + var claims map[string]interface{} + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + t.Fatalf("parse JWT claims: %v", err) + } + + requiredClaims := []string{"iss", "sub", "aud", "exp", "iat"} + for _, c := range requiredClaims { + if _, ok := claims[c]; !ok { + t.Errorf("JWT missing claim %q", c) + } + } + + if claims["aud"] != "demo-app" { + t.Errorf("aud: want demo-app, got %v", claims["aud"]) + } +}