generated from coulomb/repo-seed
feat: implement T22, T18, T23 — dev stack, profile tests, server binary
- T22: docker-compose.dev.yml dev stack, Dockerfile, root Makefile - T18: Profile test suite (Scenario A) — 8 integration tests with real handlers - T23: Server binary wiring all components, config validation, /healthz - Config: ValidateConfig with startup validation 14 test packages pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
11
Dockerfile
Normal file
11
Dockerfile
Normal file
@@ -0,0 +1,11 @@
|
||||
FROM golang:1.23-alpine AS builder
|
||||
WORKDIR /app
|
||||
COPY src/go.mod src/go.sum ./
|
||||
RUN go mod download
|
||||
COPY src/ .
|
||||
RUN CGO_ENABLED=0 go build -o keycape ./cmd/keycape
|
||||
|
||||
FROM gcr.io/distroless/static-debian12
|
||||
COPY --from=builder /app/keycape /keycape
|
||||
EXPOSE 8080
|
||||
ENTRYPOINT ["/keycape"]
|
||||
16
Makefile
Normal file
16
Makefile
Normal file
@@ -0,0 +1,16 @@
|
||||
.PHONY: dev seed build test lint
|
||||
|
||||
dev:
|
||||
docker compose -f docker-compose.dev.yml up
|
||||
|
||||
seed:
|
||||
docker compose -f docker-compose.dev.yml exec lldap /scripts/seed.sh
|
||||
|
||||
build:
|
||||
cd src && go build ./...
|
||||
|
||||
test:
|
||||
cd src && go test ./...
|
||||
|
||||
lint:
|
||||
cd src && go vet ./...
|
||||
27
config/dev-config.yaml
Normal file
27
config/dev-config.yaml
Normal file
@@ -0,0 +1,27 @@
|
||||
issuer: "http://localhost:8080"
|
||||
port: 8080
|
||||
tokenLifetime: "15m"
|
||||
privateKeyPem: "/etc/keycape/key.pem"
|
||||
environment: "dev"
|
||||
lldap:
|
||||
url: "ldap://lldap:3890"
|
||||
bindDN: "cn=admin,ou=people,dc=netkingdom,dc=local"
|
||||
bindPW: "adminpassword"
|
||||
baseDN: "dc=netkingdom,dc=local"
|
||||
authelia:
|
||||
baseURL: "http://authelia:9091"
|
||||
clientId: "keycape"
|
||||
clientSecret: "changeme"
|
||||
redirectURI: "http://localhost:8080/authorize/callback"
|
||||
privacyidea:
|
||||
baseURL: "http://privacyidea:80"
|
||||
adminToken: "changeme"
|
||||
realm: "netkingdom"
|
||||
clients:
|
||||
- clientId: "demo-app"
|
||||
displayName: "Demo Application"
|
||||
redirectUris:
|
||||
- "http://localhost:3000/callback"
|
||||
allowedScopes: ["openid", "profile", "email", "groups"]
|
||||
grantTypes: ["authorization_code"]
|
||||
clientType: "public"
|
||||
48
docker-compose.dev.yml
Normal file
48
docker-compose.dev.yml
Normal file
@@ -0,0 +1,48 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
keycape:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8080:8080"
|
||||
volumes:
|
||||
- ./config/dev-config.yaml:/etc/keycape/config.yaml:ro
|
||||
- ./config/dev-key.pem:/etc/keycape/key.pem:ro
|
||||
environment:
|
||||
- KEYCAPE_CONFIG=/etc/keycape/config.yaml
|
||||
depends_on:
|
||||
- lldap
|
||||
- authelia
|
||||
|
||||
lldap:
|
||||
image: lldap/lldap:stable
|
||||
ports:
|
||||
- "17170:17170"
|
||||
- "3890:3890"
|
||||
environment:
|
||||
- LLDAP_JWT_SECRET=devjwtsecret
|
||||
- LLDAP_LDAP_USER_PASS=adminpassword
|
||||
- LLDAP_LDAP_BASE_DN=dc=netkingdom,dc=local
|
||||
volumes:
|
||||
- lldap_data:/data
|
||||
|
||||
authelia:
|
||||
image: authelia/authelia:latest
|
||||
ports:
|
||||
- "9091:9091"
|
||||
volumes:
|
||||
- ./config/authelia:/config:ro
|
||||
environment:
|
||||
- AUTHELIA_JWT_SECRET=devsecret
|
||||
|
||||
privacyidea:
|
||||
image: khalibre/privacyidea:latest
|
||||
ports:
|
||||
- "5000:80"
|
||||
environment:
|
||||
- PI_ADMIN_PASSWORD=adminpassword
|
||||
|
||||
volumes:
|
||||
lldap_data:
|
||||
@@ -4,11 +4,268 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"keycape/internal/adapters/authelia"
|
||||
"keycape/internal/adapters/lldap"
|
||||
"keycape/internal/adapters/privacyidea"
|
||||
"keycape/internal/config"
|
||||
"keycape/internal/domain"
|
||||
servererrors "keycape/internal/server/errors"
|
||||
"keycape/internal/server/oidc"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
const version = "0.1.0"
|
||||
|
||||
func main() {
|
||||
fmt.Fprintln(os.Stderr, "keycape server: not yet implemented (T05+)")
|
||||
os.Exit(1)
|
||||
log := zerolog.New(os.Stdout).With().Timestamp().Logger()
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 1. Parse flags and load config.
|
||||
// -----------------------------------------------------------------
|
||||
var cfgPath string
|
||||
flag.StringVar(&cfgPath, "config", "", "path to YAML config file (env: KEYCAPE_CONFIG)")
|
||||
flag.Parse()
|
||||
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to load config")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 2. Validate config.
|
||||
// -----------------------------------------------------------------
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if len(errs) > 0 {
|
||||
log.Error().Strs("errors", errs).Msg("config validation failed")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 3. Load RSA private key.
|
||||
// -----------------------------------------------------------------
|
||||
privateKey, err := loadPrivateKey(cfg.PrivateKeyPEM)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("path", cfg.PrivateKeyPEM).Msg("failed to load private key")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 4. Build JWKS from public key.
|
||||
// -----------------------------------------------------------------
|
||||
ks := oidc.NewKeySet()
|
||||
ks.AddKey("key-1", &privateKey.PublicKey)
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 5. Build client registry.
|
||||
// -----------------------------------------------------------------
|
||||
clients := buildClientRegistry(cfg.Clients)
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 6. Create adapters.
|
||||
// -----------------------------------------------------------------
|
||||
lldapAdapter := lldap.New(cfg.LLDAP)
|
||||
autheliaAdapter := authelia.New(cfg.Authelia, nil)
|
||||
privacyIDEAAdapter := privacyidea.New(cfg.PrivacyIDEA, nil)
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 7. Create telemetry emitter.
|
||||
// -----------------------------------------------------------------
|
||||
emitter := telemetry.NewLogEmitter(log)
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 8. Create enforcement registry.
|
||||
// -----------------------------------------------------------------
|
||||
enforcement := servererrors.DefaultRegistry()
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 9. Create session store.
|
||||
// -----------------------------------------------------------------
|
||||
sessions := oidc.NewSessionStore()
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 10. Parse token lifetime.
|
||||
// -----------------------------------------------------------------
|
||||
tokenLifetime := 15 * time.Minute
|
||||
if cfg.TokenLifetime != "" {
|
||||
d, err := time.ParseDuration(cfg.TokenLifetime)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("tokenLifetime", cfg.TokenLifetime).Msg("invalid tokenLifetime")
|
||||
os.Exit(1)
|
||||
}
|
||||
tokenLifetime = d
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 11. Build issuer base URL.
|
||||
// -----------------------------------------------------------------
|
||||
issuer := strings.TrimRight(cfg.Issuer, "/")
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 12. Register HTTP handlers.
|
||||
// -----------------------------------------------------------------
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Discovery.
|
||||
mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
JWKSUri: issuer + "/jwks",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
}))
|
||||
|
||||
// JWKS.
|
||||
mux.Handle("/jwks", oidc.NewJWKSHandler(ks))
|
||||
|
||||
// Authorize handler (with enforcement middleware).
|
||||
authorizeHandler := &oidc.AuthorizeHandler{
|
||||
ClientConfig: clients,
|
||||
Auth: autheliaAdapter,
|
||||
MFA: privacyIDEAAdapter,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
mux.Handle("/authorize", enforcement.Middleware(authorizeHandler))
|
||||
mux.Handle("/authorize/callback", authorizeHandler)
|
||||
|
||||
// Token handler (with enforcement middleware).
|
||||
tokenHandler := &oidc.TokenHandler{
|
||||
ClientConfig: clients,
|
||||
Sessions: sessions,
|
||||
Users: lldapAdapter,
|
||||
SigningKey: privateKey,
|
||||
Issuer: issuer,
|
||||
TokenLifetime: tokenLifetime,
|
||||
Emitter: emitter,
|
||||
}
|
||||
mux.Handle("/token", enforcement.Middleware(tokenHandler))
|
||||
|
||||
// Userinfo handler.
|
||||
mux.Handle("/userinfo", &oidc.UserinfoHandler{
|
||||
Users: lldapAdapter,
|
||||
SigningKey: &privateKey.PublicKey,
|
||||
Issuer: issuer,
|
||||
Emitter: emitter,
|
||||
})
|
||||
|
||||
// Healthz.
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
"version": version,
|
||||
})
|
||||
})
|
||||
|
||||
// Inject emitter into request context.
|
||||
handler := withEmitter(mux, emitter)
|
||||
|
||||
// -----------------------------------------------------------------
|
||||
// 13. Start HTTP server.
|
||||
// -----------------------------------------------------------------
|
||||
addr := fmt.Sprintf(":%d", cfg.Port)
|
||||
if cfg.Port == 0 {
|
||||
addr = ":8080"
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("issuer", issuer).
|
||||
Str("addr", addr).
|
||||
Str("environment", cfg.Environment).
|
||||
Str("version", version).
|
||||
Msg("starting keycape server")
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// loadPrivateKey reads a PEM file and parses the RSA private key.
|
||||
// Supports both PKCS#1 ("RSA PRIVATE KEY") and PKCS#8 ("PRIVATE KEY") PEM blocks.
|
||||
func loadPrivateKey(path string) (*rsa.PrivateKey, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read key file: %w", err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(data)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block found in %q", path)
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse PKCS1 private key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse PKCS8 private key: %w", err)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private key is not an RSA key")
|
||||
}
|
||||
return rsaKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected PEM block type %q; expected RSA PRIVATE KEY or PRIVATE KEY", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// buildClientRegistry converts []ClientConfig into the map used by handlers.
|
||||
func buildClientRegistry(cfgClients []config.ClientConfig) map[string]*domain.Client {
|
||||
m := make(map[string]*domain.Client, len(cfgClients))
|
||||
for i := range cfgClients {
|
||||
c := &cfgClients[i]
|
||||
m[c.ClientID] = &domain.Client{
|
||||
ClientID: c.ClientID,
|
||||
DisplayName: c.DisplayName,
|
||||
RedirectURIs: c.RedirectURIs,
|
||||
AllowedScopes: c.AllowedScopes,
|
||||
GrantTypes: c.GrantTypes,
|
||||
ClientType: c.ClientType,
|
||||
SecretRef: c.SecretRef,
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// withEmitter wraps a handler to inject the telemetry emitter into every request context.
|
||||
func withEmitter(next http.Handler, e telemetry.Emitter) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := telemetry.WithEmitter(context.Background(), e)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
63
src/internal/config/config.go
Normal file
63
src/internal/config/config.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Package config handles loading and validating the KeyCape server configuration
|
||||
// from a YAML file. The config path is resolved from the --config flag or the
|
||||
// KEYCAPE_CONFIG environment variable.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"keycape/internal/adapters/authelia"
|
||||
"keycape/internal/adapters/lldap"
|
||||
"keycape/internal/adapters/privacyidea"
|
||||
)
|
||||
|
||||
// Config is the top-level server configuration.
|
||||
type Config struct {
|
||||
Issuer string `yaml:"issuer"`
|
||||
Port int `yaml:"port"`
|
||||
TokenLifetime string `yaml:"tokenLifetime"`
|
||||
PrivateKeyPEM string `yaml:"privateKeyPem"`
|
||||
LLDAP lldap.Config `yaml:"lldap"`
|
||||
Authelia authelia.Config `yaml:"authelia"`
|
||||
PrivacyIDEA privacyidea.Config `yaml:"privacyidea"`
|
||||
Clients []ClientConfig `yaml:"clients"`
|
||||
Environment string `yaml:"environment"`
|
||||
}
|
||||
|
||||
// ClientConfig is a static OIDC client registration.
|
||||
type ClientConfig struct {
|
||||
ClientID string `yaml:"clientId"`
|
||||
DisplayName string `yaml:"displayName"`
|
||||
RedirectURIs []string `yaml:"redirectUris"`
|
||||
AllowedScopes []string `yaml:"allowedScopes"`
|
||||
GrantTypes []string `yaml:"grantTypes"`
|
||||
ClientType string `yaml:"clientType"` // "confidential" | "public"
|
||||
SecretRef string `yaml:"secretRef,omitempty"`
|
||||
}
|
||||
|
||||
// Load reads and parses the YAML config file at path.
|
||||
// If path is empty, it falls back to the KEYCAPE_CONFIG environment variable.
|
||||
// Returns an error if the file cannot be read or parsed.
|
||||
func Load(path string) (*Config, error) {
|
||||
if path == "" {
|
||||
path = os.Getenv("KEYCAPE_CONFIG")
|
||||
}
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("config: no config path specified (use --config or KEYCAPE_CONFIG)")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: read %q: %w", path, err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("config: parse %q: %w", path, err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
225
src/internal/config/config_test.go
Normal file
225
src/internal/config/config_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"keycape/internal/config"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// writeTempFile creates a temporary file with the given content and returns its path.
|
||||
func writeTempFile(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
f, err := os.CreateTemp(t.TempDir(), "keycape-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp file: %v", err)
|
||||
}
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
t.Fatalf("write temp file: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
// validConfig returns a minimal valid Config for use in tests.
|
||||
func validConfig(keyPath string) *config.Config {
|
||||
return &config.Config{
|
||||
Issuer: "https://auth.example.com",
|
||||
Port: 8080,
|
||||
TokenLifetime: "15m",
|
||||
PrivateKeyPEM: keyPath,
|
||||
Environment: "dev",
|
||||
Clients: []config.ClientConfig{
|
||||
{
|
||||
ClientID: "test-app",
|
||||
DisplayName: "Test App",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
ClientType: "public",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Load tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLoad_ValidYAML(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "placeholder-key")
|
||||
yaml := `
|
||||
issuer: "https://auth.example.com"
|
||||
port: 8080
|
||||
tokenLifetime: "15m"
|
||||
privateKeyPem: "` + keyPath + `"
|
||||
environment: "dev"
|
||||
clients:
|
||||
- clientId: "demo"
|
||||
displayName: "Demo"
|
||||
redirectUris:
|
||||
- "https://demo.example.com/cb"
|
||||
clientType: "public"
|
||||
`
|
||||
cfgPath := writeTempFile(t, yaml)
|
||||
|
||||
cfg, err := config.Load(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Issuer != "https://auth.example.com" {
|
||||
t.Errorf("Issuer: want %q, got %q", "https://auth.example.com", cfg.Issuer)
|
||||
}
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("Port: want 8080, got %d", cfg.Port)
|
||||
}
|
||||
if len(cfg.Clients) != 1 {
|
||||
t.Errorf("Clients: want 1, got %d", len(cfg.Clients))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_FileNotFound(t *testing.T) {
|
||||
_, err := config.Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
|
||||
if err == nil {
|
||||
t.Error("Load: expected error for missing file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML(t *testing.T) {
|
||||
bad := writeTempFile(t, "not: valid: yaml: [[[")
|
||||
_, err := config.Load(bad)
|
||||
if err == nil {
|
||||
t.Error("Load: expected error for invalid YAML, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Validate tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidate_ValidConfig(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
errs := config.ValidateConfig(validConfig(keyPath))
|
||||
if len(errs) != 0 {
|
||||
t.Errorf("ValidateConfig: expected no errors, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MissingIssuer(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Issuer = ""
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "issuer") {
|
||||
t.Errorf("expected issuer error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_InvalidIssuerURL(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Issuer = "not a url"
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "issuer") {
|
||||
t.Errorf("expected issuer URL error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_PortZero(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Port = 0
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "port") {
|
||||
t.Errorf("expected port error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_PortTooHigh(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Port = 70000
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "port") {
|
||||
t.Errorf("expected port error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_NoClients(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Clients = nil
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "client") {
|
||||
t.Errorf("expected client error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_ClientMissingRedirectURI(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
cfg := validConfig(keyPath)
|
||||
cfg.Clients[0].RedirectURIs = nil
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "redirect") {
|
||||
t.Errorf("expected redirect_uri error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MissingPrivateKeyPEM(t *testing.T) {
|
||||
cfg := validConfig("")
|
||||
cfg.PrivateKeyPEM = ""
|
||||
errs := config.ValidateConfig(cfg)
|
||||
if !containsErr(errs, "privateKeyPem") {
|
||||
t.Errorf("expected privateKeyPem error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Env var loading test
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLoad_FromEnvVar(t *testing.T) {
|
||||
keyPath := writeTempFile(t, "key")
|
||||
yaml := `
|
||||
issuer: "https://auth.example.com"
|
||||
port: 9090
|
||||
tokenLifetime: "30m"
|
||||
privateKeyPem: "` + keyPath + `"
|
||||
environment: "dev"
|
||||
clients:
|
||||
- clientId: "env-app"
|
||||
displayName: "Env App"
|
||||
redirectUris:
|
||||
- "https://env.example.com/cb"
|
||||
clientType: "public"
|
||||
`
|
||||
cfgPath := writeTempFile(t, yaml)
|
||||
t.Setenv("KEYCAPE_CONFIG", cfgPath)
|
||||
|
||||
// Load with empty path triggers env var lookup.
|
||||
cfg, err := config.Load("")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with env var: %v", err)
|
||||
}
|
||||
if cfg.Port != 9090 {
|
||||
t.Errorf("Port: want 9090, got %d", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func containsErr(errs []string, substring string) bool {
|
||||
for _, e := range errs {
|
||||
for i := 0; i <= len(e)-len(substring); i++ {
|
||||
if e[i:i+len(substring)] == substring {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
61
src/internal/config/validate.go
Normal file
61
src/internal/config/validate.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateConfig validates a loaded Config and returns a list of human-readable
|
||||
// error messages. An empty slice means the config is valid.
|
||||
// Called at startup — the server must exit 1 if any errors are returned.
|
||||
func ValidateConfig(cfg *Config) []string {
|
||||
var errs []string
|
||||
|
||||
// Issuer must be a valid URL with an http(s) scheme.
|
||||
if cfg.Issuer == "" {
|
||||
errs = append(errs, "issuer: must not be empty")
|
||||
} else {
|
||||
u, err := url.Parse(cfg.Issuer)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
errs = append(errs, fmt.Sprintf("issuer: %q is not a valid URL (must include scheme and host)", cfg.Issuer))
|
||||
} else if u.Scheme != "http" && u.Scheme != "https" {
|
||||
errs = append(errs, fmt.Sprintf("issuer: scheme must be http or https, got %q", u.Scheme))
|
||||
}
|
||||
}
|
||||
|
||||
// Port must be in the valid TCP range.
|
||||
if cfg.Port < 1 || cfg.Port > 65535 {
|
||||
errs = append(errs, fmt.Sprintf("port: must be between 1 and 65535, got %d", cfg.Port))
|
||||
}
|
||||
|
||||
// At least one client must be registered.
|
||||
if len(cfg.Clients) == 0 {
|
||||
errs = append(errs, "clients: at least one client must be defined")
|
||||
}
|
||||
|
||||
// Each client must have at least one redirect URI and a non-empty clientId.
|
||||
for i, c := range cfg.Clients {
|
||||
prefix := fmt.Sprintf("clients[%d] (%s)", i, c.ClientID)
|
||||
if c.ClientID == "" {
|
||||
prefix = fmt.Sprintf("clients[%d]", i)
|
||||
errs = append(errs, prefix+": clientId must not be empty")
|
||||
}
|
||||
if len(c.RedirectURIs) == 0 {
|
||||
errs = append(errs, prefix+": redirect_uri: at least one redirectUri must be registered")
|
||||
}
|
||||
// Warn about wildcard redirect URIs (they are blocked at runtime anyway).
|
||||
for _, uri := range c.RedirectURIs {
|
||||
if strings.ContainsAny(uri, "*?") {
|
||||
errs = append(errs, prefix+fmt.Sprintf(": redirect_uri %q must not contain wildcards", uri))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Private key PEM path must be provided (existence is checked at startup).
|
||||
if cfg.PrivateKeyPEM == "" {
|
||||
errs = append(errs, "privateKeyPem: path must not be empty")
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
635
src/tests/profile/profile_test.go
Normal file
635
src/tests/profile/profile_test.go
Normal file
@@ -0,0 +1,635 @@
|
||||
// Package profile_test contains integration-style tests for the complete OIDC
|
||||
// profile (Scenario A from the Acceptance Test Matrix, spec §7). All handler
|
||||
// implementations are real; only the auth backend adapters are mocked.
|
||||
package profile_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"keycape/internal/domain"
|
||||
"keycape/internal/server/errors"
|
||||
"keycape/internal/server/oidc"
|
||||
"keycape/internal/server/telemetry"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock adapters
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockAuth implements domain.AuthProvider for tests.
|
||||
type mockAuth struct {
|
||||
authorizeURL string
|
||||
callbackUser string
|
||||
callbackErr error
|
||||
}
|
||||
|
||||
func (m *mockAuth) AuthorizeURL(_ context.Context, req domain.AuthRequest) (string, error) {
|
||||
if m.authorizeURL != "" {
|
||||
return m.authorizeURL, nil
|
||||
}
|
||||
return "https://authelia.example.com/auth?state=" + req.State, nil
|
||||
}
|
||||
|
||||
func (m *mockAuth) HandleCallback(_ context.Context, params domain.CallbackParams) (*domain.AuthResult, error) {
|
||||
if m.callbackErr != nil {
|
||||
return nil, m.callbackErr
|
||||
}
|
||||
username := m.callbackUser
|
||||
if username == "" {
|
||||
username = "testuser"
|
||||
}
|
||||
return &domain.AuthResult{Username: username}, nil
|
||||
}
|
||||
|
||||
// mockMFA implements domain.MFAProvider for tests.
|
||||
type mockMFA struct {
|
||||
required bool
|
||||
checkErr error
|
||||
mfaErr error
|
||||
}
|
||||
|
||||
func (m *mockMFA) CheckMFARequired(_ context.Context, _ string) (bool, error) {
|
||||
return m.required, m.checkErr
|
||||
}
|
||||
|
||||
func (m *mockMFA) ValidateMFAToken(_ context.Context, _, _ string) error {
|
||||
return m.mfaErr
|
||||
}
|
||||
|
||||
// mockUsers implements domain.UserRepository for tests.
|
||||
type mockUsers struct {
|
||||
users map[string]*domain.User
|
||||
}
|
||||
|
||||
func newMockUsers() *mockUsers {
|
||||
return &mockUsers{users: map[string]*domain.User{
|
||||
"testuser": {
|
||||
ID: "uid-001",
|
||||
Username: "testuser",
|
||||
DisplayName: "Test User",
|
||||
Email: "testuser@example.com",
|
||||
Groups: []string{"developers"},
|
||||
Enabled: true,
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
func (m *mockUsers) LookupUser(_ context.Context, username string) (*domain.User, error) {
|
||||
u, ok := m.users[username]
|
||||
if !ok {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (m *mockUsers) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockUsers) ValidatePassword(_ context.Context, _, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockUsers) ListUsers(_ context.Context) ([]domain.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestServer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestServer wraps an httptest.Server with all the wired-up handlers.
|
||||
type TestServer struct {
|
||||
Server *httptest.Server
|
||||
PrivateKey *rsa.PrivateKey
|
||||
Sessions *oidc.SessionStore
|
||||
AuthMock *mockAuth
|
||||
Clients map[string]*domain.Client
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *TestServer {
|
||||
t.Helper()
|
||||
|
||||
// Generate RSA key pair.
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
issuer := "http://localhost" // will be overridden with actual server URL after start
|
||||
|
||||
// Create test client registry.
|
||||
clients := map[string]*domain.Client{
|
||||
"demo-app": {
|
||||
ClientID: "demo-app",
|
||||
DisplayName: "Demo Application",
|
||||
RedirectURIs: []string{"http://localhost:3000/callback"},
|
||||
AllowedScopes: []string{"openid", "profile", "email", "groups"},
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
ClientType: "public",
|
||||
},
|
||||
}
|
||||
|
||||
// Create mock adapters.
|
||||
authMock := &mockAuth{}
|
||||
mfaMock := &mockMFA{required: false}
|
||||
usersMock := newMockUsers()
|
||||
|
||||
// Session store.
|
||||
sessions := oidc.NewSessionStore()
|
||||
|
||||
// Telemetry — noop for tests.
|
||||
emitter := telemetry.NoopEmitter{}
|
||||
|
||||
// Key set.
|
||||
ks := oidc.NewKeySet()
|
||||
ks.AddKey("key-1", &privateKey.PublicKey)
|
||||
|
||||
// Enforcement registry.
|
||||
reg := errors.DefaultRegistry()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Discovery handler.
|
||||
mux.Handle("/.well-known/openid-configuration", oidc.NewDiscoveryHandler(oidc.DiscoveryConfig{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
JWKSUri: issuer + "/jwks",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
}))
|
||||
|
||||
// JWKS handler.
|
||||
mux.Handle("/jwks", oidc.NewJWKSHandler(ks))
|
||||
|
||||
// Authorize handler (with enforcement middleware).
|
||||
authorizeHandler := &oidc.AuthorizeHandler{
|
||||
ClientConfig: clients,
|
||||
Auth: authMock,
|
||||
MFA: mfaMock,
|
||||
Sessions: sessions,
|
||||
Emitter: emitter,
|
||||
}
|
||||
mux.Handle("/authorize", reg.Middleware(authorizeHandler))
|
||||
mux.Handle("/authorize/callback", authorizeHandler)
|
||||
|
||||
// Token handler (with enforcement middleware).
|
||||
tokenHandler := &oidc.TokenHandler{
|
||||
ClientConfig: clients,
|
||||
Sessions: sessions,
|
||||
Users: usersMock,
|
||||
SigningKey: privateKey,
|
||||
Issuer: issuer,
|
||||
TokenLifetime: 15 * time.Minute,
|
||||
Emitter: emitter,
|
||||
}
|
||||
mux.Handle("/token", reg.Middleware(tokenHandler))
|
||||
|
||||
// Userinfo handler.
|
||||
mux.Handle("/userinfo", &oidc.UserinfoHandler{
|
||||
Users: usersMock,
|
||||
SigningKey: &privateKey.PublicKey,
|
||||
Issuer: issuer,
|
||||
Emitter: emitter,
|
||||
})
|
||||
|
||||
// Healthz handler.
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok","version":"0.1.0"}`))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
return &TestServer{
|
||||
Server: srv,
|
||||
PrivateKey: privateKey,
|
||||
Sessions: sessions,
|
||||
AuthMock: authMock,
|
||||
Clients: clients,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PKCE helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func generatePKCE(t *testing.T) (verifier, challenge string) {
|
||||
t.Helper()
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
t.Fatalf("generate PKCE verifier: %v", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test cases
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// 1. Discovery test.
|
||||
func TestDiscovery(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/.well-known/openid-configuration")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /.well-known/openid-configuration: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status: want 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if !strings.HasPrefix(ct, "application/json") {
|
||||
t.Errorf("Content-Type: want application/json, got %q", ct)
|
||||
}
|
||||
|
||||
var doc map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
||||
t.Fatalf("decode discovery doc: %v", err)
|
||||
}
|
||||
|
||||
requiredFields := []string{
|
||||
"issuer", "authorization_endpoint", "token_endpoint", "jwks_uri",
|
||||
"response_types_supported", "grant_types_supported",
|
||||
"code_challenge_methods_supported", "id_token_signing_alg_values_supported",
|
||||
"scopes_supported",
|
||||
}
|
||||
for _, f := range requiredFields {
|
||||
if _, ok := doc[f]; !ok {
|
||||
t.Errorf("discovery doc missing field %q", f)
|
||||
}
|
||||
}
|
||||
|
||||
// registration_endpoint must be absent.
|
||||
if _, ok := doc["registration_endpoint"]; ok {
|
||||
t.Error("discovery doc must not contain registration_endpoint")
|
||||
}
|
||||
|
||||
// scopes_supported must include openid.
|
||||
scopes, ok := doc["scopes_supported"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatal("scopes_supported is not an array")
|
||||
}
|
||||
found := false
|
||||
for _, s := range scopes {
|
||||
if s == "openid" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("scopes_supported must include openid")
|
||||
}
|
||||
}
|
||||
|
||||
// 2. JWKS test.
|
||||
func TestJWKS(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/jwks")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /jwks: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status: want 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var jwks struct {
|
||||
Keys []struct {
|
||||
Kty string `json:"kty"`
|
||||
Alg string `json:"alg"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
} `json:"keys"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
t.Fatalf("decode JWKS: %v", err)
|
||||
}
|
||||
|
||||
if len(jwks.Keys) == 0 {
|
||||
t.Fatal("JWKS must contain at least one key")
|
||||
}
|
||||
|
||||
key := jwks.Keys[0]
|
||||
if key.Kty != "RSA" {
|
||||
t.Errorf("kty: want RSA, got %q", key.Kty)
|
||||
}
|
||||
if key.Alg != "RS256" {
|
||||
t.Errorf("alg: want RS256, got %q", key.Alg)
|
||||
}
|
||||
if key.N == "" {
|
||||
t.Error("n (modulus) must not be empty")
|
||||
}
|
||||
if key.E == "" {
|
||||
t.Error("e (exponent) must not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Authorization redirect test — valid PKCE params → 302 redirect.
|
||||
func TestAuthorize_Redirect(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
_, challenge := generatePKCE(t)
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", "demo-app")
|
||||
q.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid profile")
|
||||
q.Set("state", "test-state-123")
|
||||
q.Set("code_challenge", challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
client := &http.Client{CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse // don't follow redirect
|
||||
}}
|
||||
resp, err := client.Get(ts.Server.URL + "/authorize?" + q.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Errorf("status: want 302, got %d", resp.StatusCode)
|
||||
}
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Location header must be set on redirect")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Invalid client test — unknown client_id → invalid_profile_usage.
|
||||
func TestAuthorize_InvalidClient(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
_, challenge := generatePKCE(t)
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", "unknown-client")
|
||||
q.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid")
|
||||
q.Set("state", "s")
|
||||
q.Set("code_challenge", challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("status: want 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var pe map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&pe); err != nil {
|
||||
t.Fatalf("decode error response: %v", err)
|
||||
}
|
||||
errType, _ := pe["error"].(string)
|
||||
if errType != "invalid_profile_usage" {
|
||||
t.Errorf("error: want invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Wildcard redirect URI → rejected_for_profile_safety (caught by enforcement middleware).
|
||||
func TestAuthorize_WildcardRedirectURI(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", "demo-app")
|
||||
q.Set("redirect_uri", "https://evil.com/*")
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid")
|
||||
q.Set("state", "s")
|
||||
q.Set("code_challenge", "abc")
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("status: want 403, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var pe map[string]interface{}
|
||||
_ = json.NewDecoder(resp.Body).Decode(&pe)
|
||||
errType, _ := pe["error"].(string)
|
||||
if errType != "rejected_for_profile_safety" {
|
||||
t.Errorf("error: want rejected_for_profile_safety, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Missing PKCE test — no code_challenge → invalid_profile_usage (enforcement middleware).
|
||||
func TestAuthorize_MissingPKCE(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", "demo-app")
|
||||
q.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid")
|
||||
q.Set("state", "s")
|
||||
// No code_challenge
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/authorize?" + q.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("status: want 400, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var pe map[string]interface{}
|
||||
_ = json.NewDecoder(resp.Body).Decode(&pe)
|
||||
errType, _ := pe["error"].(string)
|
||||
if errType != "invalid_profile_usage" {
|
||||
t.Errorf("error: want invalid_profile_usage, got %q", errType)
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Healthz test.
|
||||
func TestHealthz(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
resp, err := http.Get(ts.Server.URL + "/healthz")
|
||||
if err != nil {
|
||||
t.Fatalf("GET /healthz: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status: want 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode /healthz response: %v", err)
|
||||
}
|
||||
if body["status"] != "ok" {
|
||||
t.Errorf("status field: want ok, got %v", body["status"])
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Complete token flow test — auth callback + token exchange → valid JWT.
|
||||
func TestCompleteTokenFlow(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
|
||||
verifier, challenge := generatePKCE(t)
|
||||
|
||||
// Step 1: Simulate the callback by seeding a pending state and triggering callback.
|
||||
// We do this by first calling /authorize to create the pending state, then calling
|
||||
// /authorize/callback with state and a mock code.
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", "demo-app")
|
||||
q.Set("redirect_uri", "http://localhost:3000/callback")
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid profile email groups")
|
||||
q.Set("state", "flow-state-xyz")
|
||||
q.Set("code_challenge", challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
noRedirectClient := &http.Client{
|
||||
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
// /authorize → 302 to upstream auth.
|
||||
authResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize?" + q.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize: %v", err)
|
||||
}
|
||||
authResp.Body.Close()
|
||||
if authResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("authorize: want 302, got %d", authResp.StatusCode)
|
||||
}
|
||||
|
||||
// Step 2: Simulate the upstream callback returning code + state.
|
||||
cbQ := url.Values{}
|
||||
cbQ.Set("code", "upstream-auth-code")
|
||||
cbQ.Set("state", "flow-state-xyz")
|
||||
|
||||
cbResp, err := noRedirectClient.Get(ts.Server.URL + "/authorize/callback?" + cbQ.Encode())
|
||||
if err != nil {
|
||||
t.Fatalf("GET /authorize/callback: %v", err)
|
||||
}
|
||||
cbResp.Body.Close()
|
||||
if cbResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("callback: want 302, got %d", cbResp.StatusCode)
|
||||
}
|
||||
|
||||
// Extract the auth code from the Location redirect to our client.
|
||||
location := cbResp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("callback: no Location header")
|
||||
}
|
||||
locURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("parse Location URL: %v", err)
|
||||
}
|
||||
authCode := locURL.Query().Get("code")
|
||||
if authCode == "" {
|
||||
t.Fatalf("no code in callback redirect: %q", location)
|
||||
}
|
||||
|
||||
// Step 3: Exchange the auth code for a token.
|
||||
tokenForm := url.Values{}
|
||||
tokenForm.Set("grant_type", "authorization_code")
|
||||
tokenForm.Set("client_id", "demo-app")
|
||||
tokenForm.Set("code", authCode)
|
||||
tokenForm.Set("code_verifier", verifier)
|
||||
|
||||
tokenResp, err := http.Post(
|
||||
ts.Server.URL+"/token",
|
||||
"application/x-www-form-urlencoded",
|
||||
strings.NewReader(tokenForm.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("POST /token: %v", err)
|
||||
}
|
||||
defer tokenResp.Body.Close()
|
||||
|
||||
if tokenResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(tokenResp.Body)
|
||||
t.Fatalf("token: want 200, got %d; body: %s", tokenResp.StatusCode, body)
|
||||
}
|
||||
|
||||
var tr struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
if err := json.NewDecoder(tokenResp.Body).Decode(&tr); err != nil {
|
||||
t.Fatalf("decode token response: %v", err)
|
||||
}
|
||||
|
||||
if tr.AccessToken == "" {
|
||||
t.Error("access_token must not be empty")
|
||||
}
|
||||
if tr.TokenType != "Bearer" {
|
||||
t.Errorf("token_type: want Bearer, got %q", tr.TokenType)
|
||||
}
|
||||
if tr.IDToken == "" {
|
||||
t.Error("id_token must not be empty")
|
||||
}
|
||||
|
||||
// Verify JWT has 3 parts (header.payload.signature).
|
||||
parts := strings.Split(tr.IDToken, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("id_token: expected 3 JWT parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
// Decode payload and check required claims.
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("decode JWT payload: %v", err)
|
||||
}
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payloadJSON, &claims); err != nil {
|
||||
t.Fatalf("parse JWT claims: %v", err)
|
||||
}
|
||||
|
||||
requiredClaims := []string{"iss", "sub", "aud", "exp", "iat"}
|
||||
for _, c := range requiredClaims {
|
||||
if _, ok := claims[c]; !ok {
|
||||
t.Errorf("JWT missing claim %q", c)
|
||||
}
|
||||
}
|
||||
|
||||
if claims["aud"] != "demo-app" {
|
||||
t.Errorf("aud: want demo-app, got %v", claims["aud"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user