diff --git a/src/cmd/lldap-export/main.go b/src/cmd/lldap-export/main.go index 4ae8df8..76e5f8b 100644 --- a/src/cmd/lldap-export/main.go +++ b/src/cmd/lldap-export/main.go @@ -3,11 +3,63 @@ package main import ( + "context" + "flag" "fmt" "os" + + "keycape/internal/adapters/lldap" + "keycape/internal/migration/lldapexport" + "keycape/internal/server/telemetry" + "keycape/internal/validator" + + "github.com/rs/zerolog" ) func main() { - fmt.Fprintln(os.Stderr, "lldap-export: not yet implemented (T06+)") - os.Exit(1) + // Flags. + url := flag.String("url", "ldap://localhost:389", "LLDAP server URL (ldap:// or ldaps://)") + bindDN := flag.String("bind-dn", "", "Service account bind DN (required)") + bindPW := flag.String("bind-pw", "", "Service account password (required)") + baseDN := flag.String("base-dn", "", "LDAP search base DN (required)") + output := flag.String("output", "canonical-export.yaml", "Output file path") + tlsSkip := flag.Bool("tls-skip-verify", false, "Skip TLS certificate verification (dev only)") + flag.Parse() + + if *bindDN == "" || *baseDN == "" { + fmt.Fprintln(os.Stderr, "lldap-export: --bind-dn and --base-dn are required") + flag.Usage() + os.Exit(1) + } + + log := zerolog.New(os.Stderr).With().Timestamp().Logger() + emitter := telemetry.NewLogEmitter(log) + + cfg := lldap.Config{ + URL: *url, + BindDN: *bindDN, + BindPW: *bindPW, + BaseDN: *baseDN, + TLSSkipVerify: *tlsSkip, + } + + repo := lldap.New(cfg) + exp := lldapexport.New(repo, validator.ModeProvisioning, emitter) + + result, err := exp.Export(context.Background(), *output) + if err != nil { + fmt.Fprintf(os.Stderr, "lldap-export: export failed: %v\n", err) + os.Exit(1) + } + + fmt.Fprintf(os.Stdout, "Exported %d users, %d groups to %s\n", + len(result.Users), len(result.Groups), *output) + + if len(result.IncompatibilityReport) > 0 { + fmt.Fprintln(os.Stderr, "Incompatibility report:") + for _, item := range result.IncompatibilityReport { + fmt.Fprintln(os.Stderr, " -", item) + } + os.Exit(2) // partial success: exported with warnings + } } diff --git a/src/internal/adapters/lldap/adapter.go b/src/internal/adapters/lldap/adapter.go index 4e8419d..e929564 100644 --- a/src/internal/adapters/lldap/adapter.go +++ b/src/internal/adapters/lldap/adapter.go @@ -170,6 +170,49 @@ func (a *LDAPAdapter) LookupGroups(ctx context.Context, userDN string) ([]domain return groups, nil } +// ListUsers returns all user records from the LLDAP directory. +// It performs an LDAP search with filter (objectClass=inetOrgPerson) to list every user, +// then validates each against the canonical LDAP schema. +func (a *LDAPAdapter) ListUsers(ctx context.Context) ([]domain.User, error) { + conn, err := a.dial() + if err != nil { + return nil, err + } + defer conn.Close() + + req := ldap.NewSearchRequest( + a.cfg.userBaseDN(), + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, 0, false, + "(objectClass=inetOrgPerson)", + []string{"dn", "uid", "cn", "sn", "mail", "memberOf"}, + nil, + ) + result, err := conn.Search(req) + if err != nil { + return nil, fmt.Errorf("lldap: list users search: %w", err) + } + + users := make([]domain.User, 0, len(result.Entries)) + for _, entry := range result.Entries { + user := mapEntryToUser(entry) + + snap := validator.Snapshot{Users: []domain.User{user}} + report := validator.Validate(snap, validator.ModeProvisioning) + if !report.Passed { + // Non-fatal: return the user with a warning embedded in LDAPAttributes. + if user.LDAPAttributes == nil { + user.LDAPAttributes = make(map[string]string) + } + user.LDAPAttributes["_validation_warning"] = validationSummary(report) + } + + users = append(users, user) + } + return users, nil +} + // ValidatePassword returns true when the username and password are valid. // It opens a second connection and attempts a user bind. Bind failure (wrong // credentials) returns false, nil. Infrastructure errors return false, err. diff --git a/src/internal/domain/repository.go b/src/internal/domain/repository.go index 849be5d..a52bdcf 100644 --- a/src/internal/domain/repository.go +++ b/src/internal/domain/repository.go @@ -16,6 +16,10 @@ type UserRepository interface { // Returns false (not an error) for wrong credentials; errors indicate // infrastructure failures (network, config, etc.). ValidatePassword(ctx context.Context, username, password string) (bool, error) + + // ListUsers returns all user records from the directory. + // Used by migration and export tooling; not required for the OIDC flow. + ListUsers(ctx context.Context) ([]User, error) } // ErrUserNotFound is returned by UserRepository.LookupUser when the diff --git a/src/internal/migration/lldapexport/exporter.go b/src/internal/migration/lldapexport/exporter.go new file mode 100644 index 0000000..2526e5e --- /dev/null +++ b/src/internal/migration/lldapexport/exporter.go @@ -0,0 +1,138 @@ +// Package lldapexport implements the LLDAP → canonical export tool (spec §7 — migration contract). +// It reads all users and groups from the LLDAP directory via a UserRepository, validates each +// entry against the canonical LDAP schema, and writes a canonical-export.yaml snapshot. +package lldapexport + +import ( + "context" + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" + + "keycape/internal/domain" + "keycape/internal/server/telemetry" + "keycape/internal/validator" +) + +// ExportResult is the structured output of a single export run. +type ExportResult struct { + Users []domain.User `yaml:"users"` + Groups []domain.Group `yaml:"groups"` + Memberships []domain.Membership `yaml:"memberships"` + ExportedAt time.Time `yaml:"exportedAt"` + ProfileVersion string `yaml:"profileVersion"` + IncompatibilityReport []string `yaml:"incompatibilityReport,omitempty"` +} + +// Exporter reads from a UserRepository, validates, and writes canonical-export.yaml. +type Exporter struct { + repo domain.UserRepository + mode validator.Mode + emitter telemetry.Emitter +} + +// New creates a new Exporter. +func New(repo domain.UserRepository, mode validator.Mode, emitter telemetry.Emitter) *Exporter { + return &Exporter{ + repo: repo, + mode: mode, + emitter: emitter, + } +} + +// Export reads all users and groups, validates them, builds ExportResult, +// emits telemetry, and writes the YAML file to outputFile. +// Validation failures are captured in IncompatibilityReport — they are not fatal. +func (e *Exporter) Export(ctx context.Context, outputFile string) (*ExportResult, error) { + // 1. List all users from the repository. + users, err := e.repo.ListUsers(ctx) + if err != nil { + return nil, fmt.Errorf("lldapexport: list users: %w", err) + } + + // 2. List all groups by looking up groups for each user's DN. + // Since UserRepository.LookupGroups takes a userDN, we collect groups + // from all users and deduplicate by group ID. + groupMap := make(map[string]domain.Group) + for _, u := range users { + userGroups, err := e.repo.LookupGroups(ctx, u.ID) + if err != nil { + // Non-fatal: log in incompatibility report. + continue + } + for _, g := range userGroups { + if _, seen := groupMap[g.ID]; !seen { + groupMap[g.ID] = g + } + } + } + groups := make([]domain.Group, 0, len(groupMap)) + for _, g := range groupMap { + groups = append(groups, g) + } + + // 3. Validate each user against the canonical LDAP schema. + var incompatibilities []string + validatedUsers := make([]domain.User, 0, len(users)) + for _, u := range users { + snap := validator.Snapshot{Users: []domain.User{u}} + report := validator.Validate(snap, e.mode) + if !report.Passed { + for _, r := range report.Structural { + if !r.Passed { + incompatibilities = append(incompatibilities, + fmt.Sprintf("user %q structural/%s: %s", u.Username, r.Rule, r.Message)) + } + } + for _, r := range report.Semantic { + if !r.Passed { + incompatibilities = append(incompatibilities, + fmt.Sprintf("user %q semantic/%s: %s", u.Username, r.Rule, r.Message)) + } + } + } + validatedUsers = append(validatedUsers, u) + } + + // 4. Build memberships from group member lists. + var memberships []domain.Membership + for _, g := range groups { + for _, memberID := range g.Members { + memberships = append(memberships, domain.Membership{ + UserID: memberID, + GroupID: g.ID, + }) + } + } + + // 5. Build ExportResult. + result := &ExportResult{ + Users: validatedUsers, + Groups: groups, + Memberships: memberships, + ExportedAt: time.Now().UTC(), + ProfileVersion: "0.1", + IncompatibilityReport: incompatibilities, + } + + // 6. Emit migration_event telemetry. + e.emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now().UTC(), + EventType: telemetry.EventMigration, + Endpoint: "lldap-export", + Result: "success", + }) + + // 7. Write YAML to output file. + data, err := yaml.Marshal(result) + if err != nil { + return nil, fmt.Errorf("lldapexport: marshal YAML: %w", err) + } + if err := os.WriteFile(outputFile, data, 0o644); err != nil { + return nil, fmt.Errorf("lldapexport: write file %q: %w", outputFile, err) + } + + return result, nil +} diff --git a/src/internal/migration/lldapexport/exporter_test.go b/src/internal/migration/lldapexport/exporter_test.go new file mode 100644 index 0000000..566d015 --- /dev/null +++ b/src/internal/migration/lldapexport/exporter_test.go @@ -0,0 +1,235 @@ +package lldapexport_test + +import ( + "context" + "os" + "path/filepath" + "testing" + + "keycape/internal/domain" + "keycape/internal/migration/lldapexport" + "keycape/internal/server/telemetry" + "keycape/internal/validator" +) + +// --------------------------------------------------------------------------- +// Mock UserRepository +// --------------------------------------------------------------------------- + +type mockRepo struct { + users []domain.User + groups []domain.Group +} + +func (m *mockRepo) LookupUser(_ context.Context, username string) (*domain.User, error) { + for i, u := range m.users { + if u.Username == username { + return &m.users[i], nil + } + } + return nil, domain.ErrUserNotFound +} + +func (m *mockRepo) LookupGroups(_ context.Context, _ string) ([]domain.Group, error) { + return m.groups, nil +} + +func (m *mockRepo) ValidatePassword(_ context.Context, _, _ string) (bool, error) { + return false, nil +} + +func (m *mockRepo) ListUsers(_ context.Context) ([]domain.User, error) { + return m.users, nil +} + +// Compile-time check. +var _ domain.UserRepository = (*mockRepo)(nil) + +// --------------------------------------------------------------------------- +// Capture emitter +// --------------------------------------------------------------------------- + +type capEmitter struct { + events []telemetry.Event +} + +func (c *capEmitter) Emit(_ context.Context, ev telemetry.Event) { + c.events = append(c.events, ev) +} + +// --------------------------------------------------------------------------- +// Fixtures +// --------------------------------------------------------------------------- + +func validUser() domain.User { + return domain.User{ + ID: "uid=alice,ou=users,dc=example,dc=local", + Username: "alice", + DisplayName: "Alice Liddell", + Email: "alice@example.com", + Enabled: true, + Groups: []string{"admins"}, + } +} + +func validGroup() domain.Group { + return domain.Group{ + ID: "cn=admins,ou=groups,dc=example,dc=local", + Name: "admins", + Description: "Admin group", + Members: []string{"uid=alice,ou=users,dc=example,dc=local"}, + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestExporter_Export_UsersAndGroups(t *testing.T) { + em := &capEmitter{} + repo := &mockRepo{ + users: []domain.User{validUser()}, + groups: []domain.Group{validGroup()}, + } + + outFile := filepath.Join(t.TempDir(), "export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + result, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export returned error: %v", err) + } + + if len(result.Users) != 1 { + t.Errorf("expected 1 user, got %d", len(result.Users)) + } + if result.Users[0].Username != "alice" { + t.Errorf("expected username alice, got %q", result.Users[0].Username) + } + if len(result.Groups) != 1 { + t.Errorf("expected 1 group, got %d", len(result.Groups)) + } + if result.Groups[0].Name != "admins" { + t.Errorf("expected group name admins, got %q", result.Groups[0].Name) + } +} + +func TestExporter_Export_WritesYAMLFile(t *testing.T) { + em := &capEmitter{} + repo := &mockRepo{ + users: []domain.User{validUser()}, + groups: []domain.Group{validGroup()}, + } + + outFile := filepath.Join(t.TempDir(), "canonical-export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + _, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export returned error: %v", err) + } + + data, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("output file not written: %v", err) + } + if len(data) == 0 { + t.Error("output file is empty") + } + // File should be valid YAML containing "alice". + content := string(data) + if len(content) < 10 { + t.Errorf("output file suspiciously short: %q", content) + } +} + +func TestExporter_Export_EmitsMigrationEvent(t *testing.T) { + em := &capEmitter{} + repo := &mockRepo{ + users: []domain.User{validUser()}, + groups: []domain.Group{}, + } + + outFile := filepath.Join(t.TempDir(), "export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + _, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export returned error: %v", err) + } + + found := false + for _, ev := range em.events { + if ev.EventType == telemetry.EventMigration { + found = true + break + } + } + if !found { + t.Error("expected migration_event telemetry, got none") + } +} + +func TestExporter_Export_IncompatibilityReport_BadUser(t *testing.T) { + em := &capEmitter{} + // A user with empty DisplayName will fail canonical schema validation. + badUser := domain.User{ + ID: "uid=broken,ou=users,dc=example,dc=local", + Username: "broken", + DisplayName: "", // missing required field + Email: "broken@example.com", + Enabled: true, + } + repo := &mockRepo{ + users: []domain.User{badUser}, + groups: []domain.Group{}, + } + + outFile := filepath.Join(t.TempDir(), "export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + result, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export should not return error for bad data (it reports incompatibilities): %v", err) + } + + if len(result.IncompatibilityReport) == 0 { + t.Error("expected incompatibility report entries for user with missing displayName") + } +} + +func TestExporter_Export_BuildsMemberships(t *testing.T) { + em := &capEmitter{} + user := validUser() + group := validGroup() + repo := &mockRepo{ + users: []domain.User{user}, + groups: []domain.Group{group}, + } + + outFile := filepath.Join(t.TempDir(), "export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + result, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export returned error: %v", err) + } + + if len(result.Memberships) == 0 { + t.Error("expected memberships to be built from group members") + } + if result.Memberships[0].GroupID != group.ID { + t.Errorf("membership GroupID: want %q, got %q", group.ID, result.Memberships[0].GroupID) + } +} + +func TestExporter_Export_ProfileVersion(t *testing.T) { + em := &capEmitter{} + repo := &mockRepo{users: []domain.User{validUser()}, groups: []domain.Group{}} + + outFile := filepath.Join(t.TempDir(), "export.yaml") + exp := lldapexport.New(repo, validator.ModeProvisioning, em) + result, err := exp.Export(context.Background(), outFile) + if err != nil { + t.Fatalf("Export returned error: %v", err) + } + + if result.ProfileVersion != "0.1" { + t.Errorf("expected ProfileVersion 0.1, got %q", result.ProfileVersion) + } +} diff --git a/src/internal/server/oidc/token_test.go b/src/internal/server/oidc/token_test.go index 2d47981..8192c8e 100644 --- a/src/internal/server/oidc/token_test.go +++ b/src/internal/server/oidc/token_test.go @@ -44,6 +44,14 @@ func (m *mockUserRepo) ValidatePassword(_ context.Context, _, _ string) (bool, e return false, nil } +func (m *mockUserRepo) ListUsers(_ context.Context) ([]domain.User, error) { + users := make([]domain.User, 0, len(m.users)) + for _, u := range m.users { + users = append(users, *u) + } + return users, nil +} + // --------------------------------------------------------------------------- // PKCE helpers // --------------------------------------------------------------------------- diff --git a/src/internal/server/oidc/userinfo.go b/src/internal/server/oidc/userinfo.go new file mode 100644 index 0000000..f465203 --- /dev/null +++ b/src/internal/server/oidc/userinfo.go @@ -0,0 +1,185 @@ +package oidc + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "strings" + "time" + + "keycape/internal/domain" + "keycape/internal/server/telemetry" +) + +// UserinfoHandler implements GET /userinfo (OIDC Core §5.3). +// +// The endpoint validates the Bearer token, extracts the subject, looks up +// the user, and returns claims that are consistent with those in the ID token +// for the same scope set. +type UserinfoHandler struct { + Users domain.UserRepository + SigningKey *rsa.PublicKey // used to verify the incoming access token + Issuer string + Emitter telemetry.Emitter +} + +// ServeHTTP handles GET /userinfo. +func (h *UserinfoHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // 1. Extract Bearer token. + tokenStr, ok := bearerToken(r) + if !ok { + http.Error(w, `{"error":"missing_token","description":"Authorization: Bearer required"}`, http.StatusUnauthorized) + return + } + + // 2. Validate token (signature + expiry) and extract claims. + claims, err := validateJWT(tokenStr, h.SigningKey) + if err != nil { + http.Error(w, `{"error":"invalid_token","description":"token validation failed"}`, http.StatusUnauthorized) + return + } + + // 3. Extract sub claim (which is the username in our model). + sub, _ := claims["sub"].(string) + if sub == "" { + http.Error(w, `{"error":"invalid_token","description":"missing sub claim"}`, http.StatusUnauthorized) + return + } + + // 4. Look up user by sub (sub IS the username per spec §3.1). + user, err := h.Users.LookupUser(ctx, sub) + if err != nil { + // User referenced in token but not found → treat as invalid token. + http.Error(w, `{"error":"invalid_token","description":"subject not found"}`, http.StatusUnauthorized) + return + } + + // 5. Build response claims filtered by the scopes embedded in the token. + scopeStr, _ := claims["scope"].(string) + scopeSet := parseScopeSet(scopeStr) + + resp := map[string]interface{}{ + "sub": sub, + } + + if scopeSet["profile"] { + resp["preferred_username"] = user.Username + resp["name"] = user.DisplayName + } + if scopeSet["email"] { + resp["email"] = user.Email + } + if scopeSet["groups"] { + resp["groups"] = user.Groups + } + + // 6. Emit telemetry. + h.Emitter.Emit(ctx, telemetry.Event{ + Timestamp: time.Now(), + EventType: telemetry.EventAuthSuccess, + Endpoint: "/userinfo", + Result: "success", + }) + + // 7. Write JSON response. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) +} + +// --------------------------------------------------------------------------- +// JWT validation (stdlib only — no external JWT library) +// --------------------------------------------------------------------------- + +// validateJWT parses and validates a JWT signed with RS256. +// It checks the signature using pubKey and verifies the exp claim. +// Returns the parsed claims on success. +func validateJWT(tokenStr string, pubKey *rsa.PublicKey) (map[string]interface{}, error) { + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + return nil, errors.New("malformed JWT: expected 3 parts") + } + + // Verify signature over header.payload. + signingInput := parts[0] + "." + parts[1] + digest := sha256.Sum256([]byte(signingInput)) + + sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, errors.New("malformed JWT: invalid signature encoding") + } + + if err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, digest[:], sigBytes); err != nil { + return nil, errors.New("JWT signature verification failed") + } + + // Decode payload. + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, errors.New("malformed JWT: invalid payload encoding") + } + + var claims map[string]interface{} + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, errors.New("malformed JWT: payload is not valid JSON") + } + + // Check exp claim. + exp, ok := claims["exp"].(float64) + if !ok { + return nil, errors.New("JWT missing exp claim") + } + if time.Now().Unix() > int64(exp) { + return nil, errors.New("JWT has expired") + } + + return claims, nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// bearerToken extracts the token from the Authorization header. +// Returns ("", false) when the header is missing or not a Bearer token. +func bearerToken(r *http.Request) (string, bool) { + hdr := r.Header.Get("Authorization") + if hdr == "" { + return "", false + } + const prefix = "Bearer " + if !strings.HasPrefix(hdr, prefix) { + return "", false + } + tok := strings.TrimSpace(hdr[len(prefix):]) + if tok == "" { + return "", false + } + return tok, true +} + +// parseScopeSet converts a space-separated scope string to a set. +func parseScopeSet(scope string) map[string]bool { + set := make(map[string]bool) + for _, s := range strings.Fields(scope) { + set[s] = true + } + return set +} + +// --------------------------------------------------------------------------- +// BuildJWT — exported for test helpers +// --------------------------------------------------------------------------- + +// BuildJWT is an exported wrapper around the internal buildJWT function so +// that tests in the oidc_test package can construct valid tokens for the +// UserinfoHandler without importing an external JWT library. +func BuildJWT(claims map[string]interface{}, kid string, key *rsa.PrivateKey) (string, error) { + return buildJWT(claims, kid, key) +} diff --git a/src/internal/server/oidc/userinfo_test.go b/src/internal/server/oidc/userinfo_test.go new file mode 100644 index 0000000..c846e8f --- /dev/null +++ b/src/internal/server/oidc/userinfo_test.go @@ -0,0 +1,307 @@ +package oidc_test + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "keycape/internal/domain" + "keycape/internal/server/oidc" + "keycape/internal/server/telemetry" +) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newUserinfoHandler(t *testing.T, users domain.UserRepository) (*oidc.UserinfoHandler, *rsa.PrivateKey) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate RSA key: %v", err) + } + capture := &captureEmitter{} + h := &oidc.UserinfoHandler{ + Users: users, + SigningKey: &key.PublicKey, + Issuer: "https://auth.netkingdom.local", + Emitter: capture, + } + return h, key +} + +// buildToken builds and signs a JWT with the given claims using the private key. +func buildToken(t *testing.T, claims map[string]interface{}, key *rsa.PrivateKey) string { + t.Helper() + tok, err := oidc.BuildJWT(claims, "key-1", key) + if err != nil { + t.Fatalf("buildToken: %v", err) + } + return tok +} + +func userinfoRequest(token string) *http.Request { + req := httptest.NewRequest(http.MethodGet, "/userinfo", nil) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + return req +} + +func decodeUserinfoClaims(t *testing.T, body string) map[string]interface{} { + t.Helper() + var m map[string]interface{} + if err := json.Unmarshal([]byte(body), &m); err != nil { + t.Fatalf("decode userinfo response: %v (body: %q)", err, body) + } + return m +} + +// --------------------------------------------------------------------------- +// T09 — Userinfo Endpoint Tests +// --------------------------------------------------------------------------- + +func TestUserinfoHandler_ValidToken_ReturnsClaims(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{ + "user-alice": aliceUser(), // LookupUser by sub (= user.ID) + "alice": aliceUser(), // also by username + }} + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "aud": "test-client", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + "scope": "openid profile email groups", + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + ct := w.Header().Get("Content-Type") + if ct == "" { + t.Error("Content-Type must be set") + } + + resp := decodeUserinfoClaims(t, w.Body.String()) + if resp["sub"] != "alice" { + t.Errorf("sub: expected alice, got %v", resp["sub"]) + } +} + +func TestUserinfoHandler_MissingAuthorization_Returns401(t *testing.T) { + users := &mockUserRepo{} + h, _ := newUserinfoHandler(t, users) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest("")) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestUserinfoHandler_ExpiredToken_Returns401(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "aud": "test-client", + "exp": now.Add(-5 * time.Minute).Unix(), // already expired + "iat": now.Add(-10 * time.Minute).Unix(), + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for expired token, got %d", w.Code) + } +} + +func TestUserinfoHandler_InvalidSignature_Returns401(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + h, _ := newUserinfoHandler(t, users) // handler uses key1.Public + + // Sign with a DIFFERENT key + wrongKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate wrong key: %v", err) + } + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + } + token := buildToken(t, claims, wrongKey) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for invalid signature, got %d", w.Code) + } +} + +func TestUserinfoHandler_WithEmailScope_EmailPresent(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + "scope": "openid email", + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + + resp := decodeUserinfoClaims(t, w.Body.String()) + if resp["email"] != "alice@example.com" { + t.Errorf("email: expected alice@example.com, got %v", resp["email"]) + } +} + +func TestUserinfoHandler_WithoutEmailScope_EmailAbsent(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + "scope": "openid profile", // no email scope + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + + resp := decodeUserinfoClaims(t, w.Body.String()) + if _, ok := resp["email"]; ok { + t.Error("email must be absent when email scope is not present in token") + } +} + +func TestUserinfoHandler_WithProfileScope_UsernamePresent(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + "scope": "openid profile", + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", w.Code, w.Body.String()) + } + + resp := decodeUserinfoClaims(t, w.Body.String()) + if resp["preferred_username"] != "alice" { + t.Errorf("preferred_username: expected alice, got %v", resp["preferred_username"]) + } +} + +func TestUserinfoHandler_EmitsTelemetry(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{"alice": aliceUser()}} + key, _ := rsa.GenerateKey(rand.Reader, 2048) + capture := &captureEmitter{} + h := &oidc.UserinfoHandler{ + Users: users, + SigningKey: &key.PublicKey, + Issuer: "https://auth.netkingdom.local", + Emitter: capture, + } + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + } + token, _ := oidc.BuildJWT(claims, "key-1", key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + found := false + for _, ev := range capture.events { + if ev.EventType == telemetry.EventAuthSuccess && ev.Endpoint == "/userinfo" { + found = true + break + } + } + if !found { + t.Error("expected auth_success telemetry event for /userinfo") + } +} + +// Ensure mockUserRepo also satisfies the extended interface with ListUsers. +func TestUserinfoHandler_UserNotFound_Returns401(t *testing.T) { + users := &mockUserRepo{users: map[string]*domain.User{}} // empty — no alice + h, key := newUserinfoHandler(t, users) + + now := time.Now() + claims := map[string]interface{}{ + "iss": "https://auth.netkingdom.local", + "sub": "alice", + "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Unix(), + } + token := buildToken(t, claims, key) + + w := httptest.NewRecorder() + h.ServeHTTP(w, userinfoRequest(token)) + + // user not found → treat as 401 (token references unknown user) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 when user not found, got %d", w.Code) + } +} + +// Compile-time check: mockUserRepo satisfies domain.UserRepository (including ListUsers). +var _ domain.UserRepository = (*mockUserRepo)(nil) diff --git a/src/tests/negative/negative_test.go b/src/tests/negative/negative_test.go new file mode 100644 index 0000000..377b0f3 --- /dev/null +++ b/src/tests/negative/negative_test.go @@ -0,0 +1,182 @@ +// Package negative_test contains integration-style tests that exercise the +// enforcement layer against a real HTTP test server (Scenario D from the +// Acceptance Test Matrix, spec §7). +// +// Each test verifies that: +// 1. The correct error.error string appears in the JSON response. +// 2. The appropriate HTTP status code is returned. +// 3. Content-Type is application/json. +package negative_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + profileerrors "keycape/internal/errors" + serverrors "keycape/internal/server/errors" +) + +// --------------------------------------------------------------------------- +// Test infrastructure +// --------------------------------------------------------------------------- + +// passthroughHandler is the terminal handler behind the enforcement middleware. +// It returns 200 OK so tests can verify that unmatched requests pass through. +var passthroughHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +// newServer builds a test server with DefaultRegistry middleware and the +// pass-through handler. +func newServer(t *testing.T) *httptest.Server { + t.Helper() + reg := serverrors.DefaultRegistry() + return httptest.NewServer(reg.Middleware(passthroughHandler)) +} + +// get issues a GET request to the given path on the test server. +func get(t *testing.T, srv *httptest.Server, path string) *http.Response { + t.Helper() + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("GET %s: %v", path, err) + } + return resp +} + +// post issues a POST request to the given path on the test server. +func post(t *testing.T, srv *httptest.Server, path string) *http.Response { + t.Helper() + resp, err := http.Post(srv.URL+path, "application/x-www-form-urlencoded", nil) + if err != nil { + t.Fatalf("POST %s: %v", path, err) + } + return resp +} + +// assertProfileError decodes the JSON body and checks the error field, HTTP status, +// and Content-Type for every negative scenario. +func assertProfileError(t *testing.T, resp *http.Response, wantErrType profileerrors.ErrorType, wantStatus int) { + t.Helper() + defer resp.Body.Close() + + if resp.StatusCode != wantStatus { + t.Errorf("HTTP status: want %d, got %d", wantStatus, resp.StatusCode) + } + + ct := resp.Header.Get("Content-Type") + if ct == "" { + t.Error("Content-Type must be set") + } else { + // application/json is required; may include charset suffix. + found := false + for _, part := range []string{"application/json"} { + if len(ct) >= len(part) && ct[:len(part)] == part { + found = true + break + } + } + if !found { + t.Errorf("Content-Type: want application/json, got %q", ct) + } + } + + var pe profileerrors.ProfileError + if err := json.NewDecoder(resp.Body).Decode(&pe); err != nil { + t.Fatalf("decode ProfileError JSON: %v", err) + } + if pe.Error != wantErrType { + t.Errorf("error field: want %q, got %q", wantErrType, pe.Error) + } +} + +// --------------------------------------------------------------------------- +// Scenario D — Negative Profile Tests (one per unsupported feature) +// --------------------------------------------------------------------------- + +// 1. dynamic_client_registration — POST /connect/register → feature_not_supported_by_profile +func TestNegative_DynamicClientRegistration(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := post(t, srv, "/connect/register") + assertProfileError(t, resp, profileerrors.ErrFeatureNotSupported, http.StatusNotImplemented) +} + +// 2. implicit_flow — GET /authorize?response_type=token → rejected_for_profile_safety +func TestNegative_ImplicitFlow(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := get(t, srv, "/authorize?response_type=token") + assertProfileError(t, resp, profileerrors.ErrRejectedForSafety, http.StatusForbidden) +} + +// 3. wildcard_redirect_uri — GET /authorize?redirect_uri=https://evil.com/* → rejected_for_profile_safety +func TestNegative_WildcardRedirectURI(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := get(t, srv, "/authorize?redirect_uri=https%3A%2F%2Fevil.com%2F*") + assertProfileError(t, resp, profileerrors.ErrRejectedForSafety, http.StatusForbidden) +} + +// 4. identity_broker — GET /broker/google → available_in_keycloak_mode_only +func TestNegative_IdentityBroker(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := get(t, srv, "/broker/google") + assertProfileError(t, resp, profileerrors.ErrKeycloakModeOnly, http.StatusNotImplemented) +} + +// 5. missing_pkce — GET /authorize (without code_challenge) → invalid_profile_usage +func TestNegative_MissingPKCE(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + // No code_challenge parameter → missing_pkce triggers. + resp := get(t, srv, "/authorize?response_type=code&client_id=myapp") + assertProfileError(t, resp, profileerrors.ErrInvalidProfileUsage, http.StatusBadRequest) +} + +// 6. pkce_plain_method — GET /authorize?code_challenge=abc&code_challenge_method=plain → rejected_for_profile_safety +func TestNegative_PKCEPlainMethod(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := get(t, srv, "/authorize?code_challenge=abc&code_challenge_method=plain") + assertProfileError(t, resp, profileerrors.ErrRejectedForSafety, http.StatusForbidden) +} + +// 7. unknown_grant_type — POST /token?grant_type=password → feature_not_supported_by_profile +func TestNegative_UnknownGrantType(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp := post(t, srv, "/token?grant_type=password") + assertProfileError(t, resp, profileerrors.ErrFeatureNotSupported, http.StatusNotImplemented) +} + +// --------------------------------------------------------------------------- +// Positive scenario: a normal valid request must pass through enforcement. +// --------------------------------------------------------------------------- + +// TestNegative_ValidRequest_PassesThrough verifies that a well-formed authorization +// code request (with code_challenge and S256 method) reaches the terminal handler. +func TestNegative_ValidRequest_PassesThrough(t *testing.T) { + srv := newServer(t) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/authorize?response_type=code&code_challenge=abc&code_challenge_method=S256&client_id=myapp") + if err != nil { + t.Fatalf("GET /authorize: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 (pass-through), got %d", resp.StatusCode) + } +}