From b0adbc5daaa5d0a6143fd060aad6b48e3c336e5c Mon Sep 17 00:00:00 2001 From: tegwick Date: Fri, 13 Mar 2026 01:45:21 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20implement=20T14,=20T10=20=E2=80=94=20en?= =?UTF-8?q?forcement=20middleware,=20LLDAP=20adapter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - T14: Unsupported feature registry with 7 pre-registered profile boundaries - T10: LLDAP adapter implementing UserRepository; validator-gated reads 24 tests pass, go vet clean. Co-Authored-By: Claude Sonnet 4.6 --- src/go.mod | 9 +- src/go.sum | 37 +- src/internal/adapters/lldap/adapter.go | 294 +++++++++++++++ src/internal/adapters/lldap/adapter_test.go | 341 ++++++++++++++++++ src/internal/adapters/lldap/config.go | 55 +++ src/internal/domain/repository.go | 27 ++ src/internal/server/errors/enforcement.go | 203 +++++++++++ .../server/errors/enforcement_test.go | 299 +++++++++++++++ 8 files changed, 1262 insertions(+), 3 deletions(-) create mode 100644 src/internal/adapters/lldap/adapter.go create mode 100644 src/internal/adapters/lldap/adapter_test.go create mode 100644 src/internal/adapters/lldap/config.go create mode 100644 src/internal/domain/repository.go create mode 100644 src/internal/server/errors/enforcement.go create mode 100644 src/internal/server/errors/enforcement_test.go diff --git a/src/go.mod b/src/go.mod index d84f4aa..992c33f 100644 --- a/src/go.mod +++ b/src/go.mod @@ -1,14 +1,19 @@ module keycape -go 1.22 +go 1.23.0 require ( + github.com/go-ldap/ldap/v3 v3.4.12 github.com/rs/zerolog v1.34.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/sys v0.31.0 // indirect ) diff --git a/src/go.sum b/src/go.sum index ae93fc4..d1f81c7 100644 --- a/src/go.sum +++ b/src/go.sum @@ -1,18 +1,53 @@ +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/src/internal/adapters/lldap/adapter.go b/src/internal/adapters/lldap/adapter.go new file mode 100644 index 0000000..4e8419d --- /dev/null +++ b/src/internal/adapters/lldap/adapter.go @@ -0,0 +1,294 @@ +package lldap + +import ( + "context" + "crypto/tls" + "fmt" + "net/url" + "strings" + + "github.com/go-ldap/ldap/v3" + + "keycape/internal/domain" + "keycape/internal/validator" +) + +// LDAPConn is a minimal interface over an LDAP connection, enabling test injection. +// Only the operations used by the adapter are included; no concrete LDAP types are +// exposed through return values or parameters visible outside this package. +type LDAPConn interface { + Bind(username, password string) error + Search(request *ldap.SearchRequest) (*ldap.SearchResult, error) + Close() error +} + +// LDAPAdapter implements domain.UserRepository using an LLDAP backend. +// All LDAP types are confined to this package — the domain and server layers +// are not aware of any LDAP-specific constructs. +type LDAPAdapter struct { + cfg Config + dialFn func(addr string) (LDAPConn, error) +} + +// New returns a production-ready LDAPAdapter that dials real LDAP connections. +func New(cfg Config) *LDAPAdapter { + return &LDAPAdapter{ + cfg: cfg, + dialFn: defaultDialFn(cfg), + } +} + +// NewForTest returns an LDAPAdapter with a custom dial function for test injection. +// Production code should use New instead. +func NewForTest(cfg Config, dialFn func(addr string) (LDAPConn, error)) *LDAPAdapter { + return &LDAPAdapter{cfg: cfg, dialFn: dialFn} +} + +// defaultDialFn returns a dial function that establishes a real LDAP connection. +func defaultDialFn(cfg Config) func(addr string) (LDAPConn, error) { + return func(addr string) (LDAPConn, error) { + u, err := url.Parse(cfg.URL) + if err != nil { + return nil, fmt.Errorf("lldap: invalid URL %q: %w", cfg.URL, err) + } + if u.Scheme == "ldaps" { + conn, err := ldap.DialTLS("tcp", addr, &tls.Config{ + InsecureSkipVerify: cfg.TLSSkipVerify, //nolint:gosec // dev flag, documented + }) + if err != nil { + return nil, fmt.Errorf("lldap: TLS dial %q: %w", addr, err) + } + return conn, nil + } + conn, err := ldap.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("lldap: dial %q: %w", addr, err) + } + return conn, nil + } +} + +// dial opens a new LDAP connection and performs the service-account bind. +func (a *LDAPAdapter) dial() (LDAPConn, error) { + u, err := url.Parse(a.cfg.URL) + if err != nil { + return nil, fmt.Errorf("lldap: invalid URL %q: %w", a.cfg.URL, err) + } + host := u.Host + if host == "" { + host = a.cfg.URL // fallback for bare addr passed in tests + } + conn, err := a.dialFn(host) + if err != nil { + return nil, err + } + if err := conn.Bind(a.cfg.BindDN, a.cfg.BindPW); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("lldap: service bind failed: %w", err) + } + return conn, nil +} + +// --------------------------------------------------------------------------- +// domain.UserRepository implementation +// --------------------------------------------------------------------------- + +// LookupUser retrieves the canonical User for the given username. +// Returns domain.ErrUserNotFound when no matching entry exists. +// After mapping LDAP attributes the result is run through the canonical +// LDAP schema validator; a validation failure is returned as an error. +func (a *LDAPAdapter) LookupUser(ctx context.Context, username string) (*domain.User, error) { + conn, err := a.dial() + if err != nil { + return nil, err + } + defer conn.Close() + + filter := fmt.Sprintf("(uid=%s)", ldap.EscapeFilter(username)) + req := ldap.NewSearchRequest( + a.cfg.userBaseDN(), + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, 0, false, + filter, + []string{"dn", "uid", "cn", "sn", "mail", "memberOf"}, + nil, + ) + result, err := conn.Search(req) + if err != nil { + return nil, fmt.Errorf("lldap: search for user %q: %w", username, err) + } + if len(result.Entries) == 0 { + return nil, domain.ErrUserNotFound + } + + entry := result.Entries[0] + user := mapEntryToUser(entry) + + // Run the canonical LDAP schema validator. + snap := validator.Snapshot{Users: []domain.User{user}} + report := validator.Validate(snap, validator.ModeProvisioning) + if !report.Passed { + return nil, fmt.Errorf("lldap: validation failed for user %q: %s", username, validationSummary(report)) + } + + return &user, nil +} + +// LookupGroups retrieves all groups the user (identified by their LDAP DN) belongs to. +func (a *LDAPAdapter) LookupGroups(ctx context.Context, userDN string) ([]domain.Group, error) { + conn, err := a.dial() + if err != nil { + return nil, err + } + defer conn.Close() + + // Search for groups that list the user as a member. + filter := fmt.Sprintf("(member=%s)", ldap.EscapeFilter(userDN)) + req := ldap.NewSearchRequest( + a.cfg.groupBaseDN(), + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, 0, false, + filter, + []string{"dn", "cn", "description"}, + nil, + ) + result, err := conn.Search(req) + if err != nil { + return nil, fmt.Errorf("lldap: group search for DN %q: %w", userDN, err) + } + + groups := make([]domain.Group, 0, len(result.Entries)) + for _, entry := range result.Entries { + groups = append(groups, domain.Group{ + ID: entry.DN, + Name: entry.GetAttributeValue("cn"), + Description: entry.GetAttributeValue("description"), + }) + } + return groups, nil +} + +// ValidatePassword returns true when the username and password are valid. +// It opens a second connection and attempts a user bind. Bind failure (wrong +// credentials) returns false, nil. Infrastructure errors return false, err. +func (a *LDAPAdapter) ValidatePassword(ctx context.Context, username, password string) (bool, error) { + // First resolve the user DN. + conn, err := a.dial() + if err != nil { + return false, err + } + + filter := fmt.Sprintf("(uid=%s)", ldap.EscapeFilter(username)) + req := ldap.NewSearchRequest( + a.cfg.userBaseDN(), + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, 0, false, + filter, + []string{"dn"}, + nil, + ) + result, err := conn.Search(req) + conn.Close() + if err != nil { + return false, fmt.Errorf("lldap: DN lookup for user %q: %w", username, err) + } + if len(result.Entries) == 0 { + return false, nil + } + + userDN := result.Entries[0].DN + + // Attempt a user bind with the provided password using a fresh connection. + host := ldapHost(a.cfg.URL) + userConn, err := a.dialFn(host) + if err != nil { + return false, err + } + defer userConn.Close() + + if err := userConn.Bind(userDN, password); err != nil { + // Distinguish authentication failure from infrastructure error. + if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { + return false, nil + } + return false, fmt.Errorf("lldap: user bind for %q: %w", username, err) + } + return true, nil +} + +// --------------------------------------------------------------------------- +// Attribute mapping helpers (LDAP → canonical domain model). +// --------------------------------------------------------------------------- + +// mapEntryToUser converts an LDAP entry to a canonical domain.User. +// Attribute mapping per spec: +// - uid → Username +// - cn → DisplayName (sn as fallback) +// - sn → DisplayName fallback if cn is empty +// - mail → Email +// - memberOf → Groups (DNs parsed to group names) +// - dn → ID (stable identifier) +func mapEntryToUser(entry *ldap.Entry) domain.User { + displayName := entry.GetAttributeValue("cn") + if displayName == "" { + displayName = entry.GetAttributeValue("sn") + } + + memberOfs := entry.GetAttributeValues("memberOf") + groups := make([]string, 0, len(memberOfs)) + for _, dn := range memberOfs { + groups = append(groups, groupNameFromDN(dn)) + } + + return domain.User{ + ID: entry.DN, + Username: entry.GetAttributeValue("uid"), + DisplayName: displayName, + Email: entry.GetAttributeValue("mail"), + Groups: groups, + Enabled: true, // LLDAP does not expose a disabled flag in base schema + } +} + +// groupNameFromDN extracts the cn value from an LDAP DN such as +// "cn=admins,ou=groups,dc=netkingdom,dc=local" → "admins". +// If parsing fails the full DN is returned unchanged. +func groupNameFromDN(dn string) string { + parts := strings.SplitN(dn, ",", 2) + if len(parts) == 0 { + return dn + } + kv := strings.SplitN(parts[0], "=", 2) + if len(kv) == 2 { + return kv[1] + } + return dn +} + +// ldapHost extracts host:port from a URL string; falls back to the raw value. +func ldapHost(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return rawURL + } + return u.Host +} + +// validationSummary produces a short string summarising all failed rules. +func validationSummary(r validator.Report) string { + var msgs []string + for _, rule := range r.Structural { + if !rule.Passed { + msgs = append(msgs, rule.Message) + } + } + for _, rule := range r.Semantic { + if !rule.Passed { + msgs = append(msgs, rule.Message) + } + } + return strings.Join(msgs, "; ") +} diff --git a/src/internal/adapters/lldap/adapter_test.go b/src/internal/adapters/lldap/adapter_test.go new file mode 100644 index 0000000..b556af4 --- /dev/null +++ b/src/internal/adapters/lldap/adapter_test.go @@ -0,0 +1,341 @@ +package lldap_test + +import ( + "context" + "errors" + "testing" + + "github.com/go-ldap/ldap/v3" + + "keycape/internal/adapters/lldap" + "keycape/internal/domain" +) + +// --------------------------------------------------------------------------- +// Mock LDAP connection +// --------------------------------------------------------------------------- + +// mockConn implements lldap.LDAPConn for test injection. +type mockConn struct { + bindFn func(username, password string) error + searchFn func(req *ldap.SearchRequest) (*ldap.SearchResult, error) + closed bool +} + +func (m *mockConn) Bind(username, password string) error { + if m.bindFn != nil { + return m.bindFn(username, password) + } + return nil +} + +func (m *mockConn) Search(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + if m.searchFn != nil { + return m.searchFn(req) + } + return &ldap.SearchResult{}, nil +} + +func (m *mockConn) Close() error { + m.closed = true + return nil +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +// testConfig returns a minimal Config suitable for tests. +func testConfig() lldap.Config { + return lldap.Config{ + URL: "ldap://lldap:389", + BindDN: "cn=admin,dc=netkingdom,dc=local", + BindPW: "secret", + BaseDN: "dc=netkingdom,dc=local", + } +} + +// singleEntryResult builds a SearchResult with one entry for LookupUser tests. +func singleEntryResult(dn, uid, cn, sn, mail string, memberOfs []string) *ldap.SearchResult { + attrs := []*ldap.EntryAttribute{ + {Name: "uid", Values: []string{uid}}, + {Name: "cn", Values: []string{cn}}, + {Name: "sn", Values: []string{sn}}, + {Name: "mail", Values: []string{mail}}, + } + if len(memberOfs) > 0 { + attrs = append(attrs, &ldap.EntryAttribute{Name: "memberOf", Values: memberOfs}) + } + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + {DN: dn, Attributes: attrs}, + }, + } +} + +// makeAdapter returns an LDAPAdapter using the exported NewForTest constructor. +// We use the package-level helper exported for testing. +func makeAdapter(cfg lldap.Config, conn lldap.LDAPConn) *lldap.LDAPAdapter { + return lldap.NewForTest(cfg, func(_ string) (lldap.LDAPConn, error) { + return conn, nil + }) +} + +// --------------------------------------------------------------------------- +// LookupUser +// --------------------------------------------------------------------------- + +func TestLookupUser_Success(t *testing.T) { + dn := "uid=alice,ou=users,dc=netkingdom,dc=local" + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return singleEntryResult( + dn, "alice", "Alice Liddell", "Liddell", "alice@example.com", + []string{"cn=admins,ou=groups,dc=netkingdom,dc=local"}, + ), nil + }, + } + + adapter := makeAdapter(testConfig(), conn) + user, err := adapter.LookupUser(context.Background(), "alice") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if user.Username != "alice" { + t.Errorf("Username: want %q, got %q", "alice", user.Username) + } + if user.DisplayName != "Alice Liddell" { + t.Errorf("DisplayName: want %q, got %q", "Alice Liddell", user.DisplayName) + } + if user.Email != "alice@example.com" { + t.Errorf("Email: want %q, got %q", "alice@example.com", user.Email) + } + if user.ID != dn { + t.Errorf("ID: want %q, got %q", dn, user.ID) + } + if len(user.Groups) != 1 || user.Groups[0] != "admins" { + t.Errorf("Groups: want [admins], got %v", user.Groups) + } +} + +func TestLookupUser_DisplayName_FallsBackToSN(t *testing.T) { + dn := "uid=bob,ou=users,dc=netkingdom,dc=local" + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return singleEntryResult(dn, "bob", "", "Builder", "bob@example.com", nil), nil + }, + } + + adapter := makeAdapter(testConfig(), conn) + user, err := adapter.LookupUser(context.Background(), "bob") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.DisplayName != "Builder" { + t.Errorf("DisplayName fallback: want %q, got %q", "Builder", user.DisplayName) + } +} + +func TestLookupUser_NotFound(t *testing.T) { + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return &ldap.SearchResult{}, nil // zero entries + }, + } + + adapter := makeAdapter(testConfig(), conn) + _, err := adapter.LookupUser(context.Background(), "ghost") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, domain.ErrUserNotFound) { + t.Errorf("expected domain.ErrUserNotFound, got %v", err) + } +} + +func TestLookupUser_ValidationFailure(t *testing.T) { + // Return an entry with an empty DisplayName and empty sn — will fail validator. + dn := "uid=broken,ou=users,dc=netkingdom,dc=local" + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + attrs := []*ldap.EntryAttribute{ + {Name: "uid", Values: []string{"broken"}}, + {Name: "cn", Values: []string{""}}, + {Name: "sn", Values: []string{""}}, + {Name: "mail", Values: []string{"broken@example.com"}}, + } + return &ldap.SearchResult{ + Entries: []*ldap.Entry{{DN: dn, Attributes: attrs}}, + }, nil + }, + } + + adapter := makeAdapter(testConfig(), conn) + _, err := adapter.LookupUser(context.Background(), "broken") + if err == nil { + t.Fatal("expected validation error, got nil") + } +} + +// --------------------------------------------------------------------------- +// LookupGroups +// --------------------------------------------------------------------------- + +func TestLookupGroups_Success(t *testing.T) { + userDN := "uid=alice,ou=users,dc=netkingdom,dc=local" + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: "cn=admins,ou=groups,dc=netkingdom,dc=local", + Attributes: []*ldap.EntryAttribute{ + {Name: "cn", Values: []string{"admins"}}, + {Name: "description", Values: []string{"Admins group"}}, + }, + }, + }, + }, nil + }, + } + + adapter := makeAdapter(testConfig(), conn) + groups, err := adapter.LookupGroups(context.Background(), userDN) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(groups) != 1 { + t.Fatalf("want 1 group, got %d", len(groups)) + } + if groups[0].Name != "admins" { + t.Errorf("Group name: want %q, got %q", "admins", groups[0].Name) + } +} + +func TestLookupGroups_Empty(t *testing.T) { + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return &ldap.SearchResult{}, nil + }, + } + adapter := makeAdapter(testConfig(), conn) + groups, err := adapter.LookupGroups(context.Background(), "uid=nobody,ou=users,dc=test,dc=local") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(groups) != 0 { + t.Errorf("expected 0 groups, got %d", len(groups)) + } +} + +// --------------------------------------------------------------------------- +// ValidatePassword +// --------------------------------------------------------------------------- + +func TestValidatePassword_Success(t *testing.T) { + userDN := "uid=alice,ou=users,dc=netkingdom,dc=local" + + callCount := 0 + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + attrs := []*ldap.EntryAttribute{{Name: "dn", Values: []string{userDN}}} + return &ldap.SearchResult{ + Entries: []*ldap.Entry{{DN: userDN, Attributes: attrs}}, + }, nil + }, + bindFn: func(username, password string) error { + callCount++ + // First call: service bind (BindDN); second call: user bind. + return nil + }, + } + + // Provide two connections: one for the DN lookup and one for the user bind. + connIdx := 0 + conns := []*mockConn{conn, {bindFn: func(u, p string) error { return nil }}} + adapter := lldap.NewForTest(testConfig(), func(_ string) (lldap.LDAPConn, error) { + c := conns[connIdx] + connIdx++ + return c, nil + }) + + ok, err := adapter.ValidatePassword(context.Background(), "alice", "correct") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Error("expected ValidatePassword to return true") + } +} + +func TestValidatePassword_WrongPassword(t *testing.T) { + userDN := "uid=alice,ou=users,dc=netkingdom,dc=local" + + searchConn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + attrs := []*ldap.EntryAttribute{{Name: "dn", Values: []string{userDN}}} + return &ldap.SearchResult{ + Entries: []*ldap.Entry{{DN: userDN, Attributes: attrs}}, + }, nil + }, + } + userConn := &mockConn{ + bindFn: func(username, password string) error { + return ldap.NewError(ldap.LDAPResultInvalidCredentials, errors.New("invalid credentials")) + }, + } + + connIdx := 0 + conns := []lldap.LDAPConn{searchConn, userConn} + adapter := lldap.NewForTest(testConfig(), func(_ string) (lldap.LDAPConn, error) { + c := conns[connIdx] + connIdx++ + return c, nil + }) + + ok, err := adapter.ValidatePassword(context.Background(), "alice", "wrong") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Error("expected ValidatePassword to return false for wrong password") + } +} + +func TestValidatePassword_BindFailure(t *testing.T) { + // Service bind fails — infrastructure error. + conn := &mockConn{ + bindFn: func(username, password string) error { + return errors.New("connection refused") + }, + } + adapter := lldap.NewForTest(testConfig(), func(_ string) (lldap.LDAPConn, error) { + return conn, nil + }) + + ok, err := adapter.ValidatePassword(context.Background(), "alice", "pass") + if err == nil { + t.Fatal("expected infrastructure error, got nil") + } + if ok { + t.Error("expected false on bind failure") + } +} + +func TestValidatePassword_UserNotFound(t *testing.T) { + conn := &mockConn{ + searchFn: func(req *ldap.SearchRequest) (*ldap.SearchResult, error) { + return &ldap.SearchResult{}, nil // no entries + }, + } + adapter := makeAdapter(testConfig(), conn) + + ok, err := adapter.ValidatePassword(context.Background(), "ghost", "pass") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Error("expected false for non-existent user") + } +} diff --git a/src/internal/adapters/lldap/config.go b/src/internal/adapters/lldap/config.go new file mode 100644 index 0000000..bb2b60b --- /dev/null +++ b/src/internal/adapters/lldap/config.go @@ -0,0 +1,55 @@ +// Package lldap implements the UserRepository adapter for LLDAP (Lightweight LDAP). +// No LDAP types are exposed beyond this package — the domain and server layers +// interact exclusively through the domain.UserRepository interface. +package lldap + +// Config holds all connection parameters for the LLDAP adapter. +type Config struct { + // URL is the LDAP server address, e.g. "ldap://lldap:389" or "ldaps://lldap:636". + URL string + + // BindDN is the distinguished name used for the service account bind, + // e.g. "cn=admin,dc=netkingdom,dc=local". + BindDN string + + // BindPW is the service account password. + BindPW string + + // BaseDN is the search base, e.g. "dc=netkingdom,dc=local". + BaseDN string + + // UserOU is the organisational unit for users. Defaults to "ou=users" when empty. + UserOU string + + // GroupOU is the organisational unit for groups. Defaults to "ou=groups" when empty. + GroupOU string + + // TLSSkipVerify disables TLS certificate verification. For development only. + TLSSkipVerify bool +} + +// userOU returns the effective UserOU, falling back to the default. +func (c Config) userOU() string { + if c.UserOU != "" { + return c.UserOU + } + return "ou=users" +} + +// groupOU returns the effective GroupOU, falling back to the default. +func (c Config) groupOU() string { + if c.GroupOU != "" { + return c.GroupOU + } + return "ou=groups" +} + +// userBaseDN returns the full DN for the user search base. +func (c Config) userBaseDN() string { + return c.userOU() + "," + c.BaseDN +} + +// groupBaseDN returns the full DN for the group search base. +func (c Config) groupBaseDN() string { + return c.groupOU() + "," + c.BaseDN +} diff --git a/src/internal/domain/repository.go b/src/internal/domain/repository.go new file mode 100644 index 0000000..849be5d --- /dev/null +++ b/src/internal/domain/repository.go @@ -0,0 +1,27 @@ +package domain + +import "context" + +// UserRepository is the adapter interface between the OIDC layer and the identity directory. +// The server/ layer sees ONLY this interface — no LDAP types leak through. +type UserRepository interface { + // LookupUser retrieves the canonical User record for the given username. + // Returns an error wrapping ErrUserNotFound when the user does not exist. + LookupUser(ctx context.Context, username string) (*User, error) + + // LookupGroups retrieves all groups the user (identified by their LDAP DN) belongs to. + LookupGroups(ctx context.Context, userDN string) ([]Group, error) + + // ValidatePassword returns true when the username and password are correct. + // Returns false (not an error) for wrong credentials; errors indicate + // infrastructure failures (network, config, etc.). + ValidatePassword(ctx context.Context, username, password string) (bool, error) +} + +// ErrUserNotFound is returned by UserRepository.LookupUser when the +// requested user does not exist in the directory. +const ErrUserNotFound = userNotFound("user not found") + +type userNotFound string + +func (e userNotFound) Error() string { return string(e) } diff --git a/src/internal/server/errors/enforcement.go b/src/internal/server/errors/enforcement.go new file mode 100644 index 0000000..03f875b --- /dev/null +++ b/src/internal/server/errors/enforcement.go @@ -0,0 +1,203 @@ +// Package errors implements the unsupported feature enforcement layer for KeyCape. +// Every request passes through the Registry middleware before reaching any handler. +// If a registered feature is detected the middleware writes a ProfileError JSON +// response, emits an EventUnsupportedFeature telemetry event, and short-circuits +// the handler chain. Adding a new unsupported feature requires only a call to +// Register — no handler changes are needed. +package errors + +import ( + "net/http" + "strings" + "time" + + profileerrors "keycape/internal/errors" + "keycape/internal/server/telemetry" +) + +// UnsupportedFeature describes a profile boundary that KeyCape enforces. +type UnsupportedFeature struct { + // Name is a stable string identifier used in telemetry and error payloads. + Name string + // ErrorType is the profile error category emitted when this feature is triggered. + ErrorType profileerrors.ErrorType + // Description is a human-readable explanation of why the feature is blocked. + Description string + // Detector reports whether the given request triggers this feature. + Detector func(r *http.Request) bool +} + +// Registry holds all known unsupported features and exposes middleware that +// enforces them on every incoming request. +type Registry struct { + features []UnsupportedFeature +} + +// NewRegistry returns an empty Registry. Use Register to add features and +// DefaultRegistry to obtain one pre-populated with the spec-mandated set. +func NewRegistry() *Registry { + return &Registry{} +} + +// Register appends a feature to the registry. Registered features are checked +// in insertion order; the first match wins. +func (reg *Registry) Register(f UnsupportedFeature) { + reg.features = append(reg.features, f) +} + +// Middleware returns an http.Handler that evaluates all registered features +// for every request before delegating to next. +// +// If a feature is triggered: +// - A ProfileError JSON response is written with an appropriate HTTP status. +// - An EventUnsupportedFeature telemetry event is emitted via the Emitter +// stored in the request context (a NoopEmitter is used when none is set). +// - next is NOT called. +// +// If no feature matches, next is called normally. +func (reg *Registry) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, f := range reg.features { + if f.Detector(r) { + pe := &profileerrors.ProfileError{ + Error: f.ErrorType, + Description: f.Description, + Feature: f.Name, + } + pe.Write(w, httpStatusFor(f.ErrorType)) + + em := telemetry.EmitterFromContext(r.Context()) + em.Emit(r.Context(), telemetry.Event{ + Timestamp: time.Now().UTC(), + EventType: telemetry.EventUnsupportedFeature, + Feature: f.Name, + ErrorType: string(f.ErrorType), + Endpoint: r.URL.Path, + Result: "failure", + Environment: "", + TraceID: "", + ClientID: r.URL.Query().Get("client_id"), + }) + return + } + } + next.ServeHTTP(w, r) + }) +} + +// httpStatusFor maps an ErrorType to its canonical HTTP status code. +func httpStatusFor(et profileerrors.ErrorType) int { + switch et { + case profileerrors.ErrInvalidProfileUsage: + return http.StatusBadRequest + case profileerrors.ErrRejectedForSafety: + return http.StatusForbidden + case profileerrors.ErrKeycloakModeOnly: + return http.StatusNotImplemented + default: // ErrFeatureNotSupported + return http.StatusNotImplemented + } +} + +// --------------------------------------------------------------------------- +// Default feature set (spec §4 — normative). +// --------------------------------------------------------------------------- + +// DefaultRegistry returns a Registry pre-populated with all spec-mandated +// unsupported features. No handler changes are required to enforce new entries. +func DefaultRegistry() *Registry { + reg := NewRegistry() + + // 1. Dynamic client registration (RFC 7591) — not in the profile. + reg.Register(UnsupportedFeature{ + Name: "dynamic_client_registration", + ErrorType: profileerrors.ErrFeatureNotSupported, + Description: "Dynamic client registration is not part of the NetKingdom IAM Profile. Register clients statically in KeyCape configuration.", + Detector: func(r *http.Request) bool { + return (r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/connect/register")) || + strings.Contains(r.URL.Path, "registration") + }, + }) + + // 2. Implicit flow — blocked for security. + reg.Register(UnsupportedFeature{ + Name: "implicit_flow", + ErrorType: profileerrors.ErrRejectedForSafety, + Description: "The implicit flow (response_type=token or id_token) is rejected. Use the authorization code flow with PKCE.", + Detector: func(r *http.Request) bool { + rt := r.URL.Query().Get("response_type") + if rt == "" { + return false + } + // Blocked when response_type contains "token" or "id_token" but NOT when it is exactly "code". + // "code token" (hybrid) is also blocked. + return rt == "token" || rt == "id_token" || strings.Contains(rt, "token") && rt != "code" + }, + }) + + // 3. Wildcard redirect_uri — blocked for security. + reg.Register(UnsupportedFeature{ + Name: "wildcard_redirect_uri", + ErrorType: profileerrors.ErrRejectedForSafety, + Description: "Wildcard redirect URIs are not permitted. Register exact redirect URIs in the client configuration.", + Detector: func(r *http.Request) bool { + return strings.Contains(r.URL.Query().Get("redirect_uri"), "*") + }, + }) + + // 4. Identity brokering — available only in Keycloak mode. + reg.Register(UnsupportedFeature{ + Name: "identity_broker", + ErrorType: profileerrors.ErrKeycloakModeOnly, + Description: "Identity brokering is available only in expanded (Keycloak) mode.", + Detector: func(r *http.Request) bool { + return strings.Contains(r.URL.Path, "/broker/") + }, + }) + + // 5. PKCE plain method — blocked for security (must use S256). + // Registered BEFORE missing_pkce so a plain-method request is reported + // as pkce_plain_method, not missing_pkce. + reg.Register(UnsupportedFeature{ + Name: "pkce_plain_method", + ErrorType: profileerrors.ErrRejectedForSafety, + Description: "PKCE plain code challenge method is not allowed. Use S256.", + Detector: func(r *http.Request) bool { + return r.URL.Query().Get("code_challenge_method") == "plain" + }, + }) + + // 6. Missing PKCE on /authorize — invalid profile usage. + reg.Register(UnsupportedFeature{ + Name: "missing_pkce", + ErrorType: profileerrors.ErrInvalidProfileUsage, + Description: "Requests to /authorize must include a code_challenge (PKCE S256 required).", + Detector: func(r *http.Request) bool { + return strings.HasSuffix(r.URL.Path, "/authorize") && + r.URL.Query().Get("code_challenge") == "" + }, + }) + + // 7. Unknown grant type on /token. + reg.Register(UnsupportedFeature{ + Name: "unknown_grant_type", + ErrorType: profileerrors.ErrFeatureNotSupported, + Description: "Only authorization_code and refresh_token grant types are supported.", + Detector: func(r *http.Request) bool { + if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/token") { + return false + } + gt := r.URL.Query().Get("grant_type") + if gt == "" { + // Also check form body if already parsed — callers may pre-parse. + gt = r.FormValue("grant_type") + } + if gt == "" { + return false // no grant_type present; let the handler decide + } + return gt != "authorization_code" && gt != "refresh_token" + }, + }) + + return reg +} diff --git a/src/internal/server/errors/enforcement_test.go b/src/internal/server/errors/enforcement_test.go new file mode 100644 index 0000000..6fe84e8 --- /dev/null +++ b/src/internal/server/errors/enforcement_test.go @@ -0,0 +1,299 @@ +package errors_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + profileerrors "keycape/internal/errors" + serverrors "keycape/internal/server/errors" + "keycape/internal/server/telemetry" +) + +// --------------------------------------------------------------------------- +// recEmitter records emitted events for assertions. +// --------------------------------------------------------------------------- + +type recEmitter struct { + events []telemetry.Event +} + +func (r *recEmitter) Emit(_ context.Context, ev telemetry.Event) { + r.events = append(r.events, ev) +} + +func newRecEmitter() *recEmitter { return &recEmitter{} } + +// --------------------------------------------------------------------------- +// Helper: build request with emitter in context. +// --------------------------------------------------------------------------- + +func reqWithEmitter(method, target string, em telemetry.Emitter) *http.Request { + req := httptest.NewRequest(method, target, nil) + ctx := telemetry.WithEmitter(req.Context(), em) + return req.WithContext(ctx) +} + +// --------------------------------------------------------------------------- +// Tests — default registry features triggered. +// --------------------------------------------------------------------------- + +func TestDefaultRegistry_DynamicClientRegistration_PostConnect(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodPost, "/connect/register", em)) + + assertProfileError(t, w, profileerrors.ErrFeatureNotSupported, "dynamic_client_registration") + assertTelemetryEmitted(t, em, "dynamic_client_registration") +} + +func TestDefaultRegistry_DynamicClientRegistration_PathContainsRegistration(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/oauth/registration/info", em)) + + assertProfileError(t, w, profileerrors.ErrFeatureNotSupported, "dynamic_client_registration") + assertTelemetryEmitted(t, em, "dynamic_client_registration") +} + +func TestDefaultRegistry_ImplicitFlow_Token(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?response_type=token", em)) + + assertProfileError(t, w, profileerrors.ErrRejectedForSafety, "implicit_flow") + assertTelemetryEmitted(t, em, "implicit_flow") +} + +func TestDefaultRegistry_ImplicitFlow_IDToken(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?response_type=id_token", em)) + + assertProfileError(t, w, profileerrors.ErrRejectedForSafety, "implicit_flow") + assertTelemetryEmitted(t, em, "implicit_flow") +} + +func TestDefaultRegistry_WildcardRedirectURI(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?redirect_uri=https%3A%2F%2Fexample.com%2F*%2Fcb", em)) + + assertProfileError(t, w, profileerrors.ErrRejectedForSafety, "wildcard_redirect_uri") + assertTelemetryEmitted(t, em, "wildcard_redirect_uri") +} + +func TestDefaultRegistry_IdentityBroker(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/auth/realms/master/broker/github/endpoint", em)) + + assertProfileError(t, w, profileerrors.ErrKeycloakModeOnly, "identity_broker") + assertTelemetryEmitted(t, em, "identity_broker") +} + +func TestDefaultRegistry_MissingPKCE(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + // /authorize without code_challenge + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?response_type=code&client_id=myapp", em)) + + assertProfileError(t, w, profileerrors.ErrInvalidProfileUsage, "missing_pkce") + assertTelemetryEmitted(t, em, "missing_pkce") +} + +func TestDefaultRegistry_MissingPKCE_WithCodeChallenge_PassesThrough(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + handler := reg.Middleware(next) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?response_type=code&code_challenge=abc&code_challenge_method=S256", em)) + + if !called { + t.Fatal("expected next handler to be called when code_challenge is present") + } + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestDefaultRegistry_PKCEPlainMethod(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/authorize?code_challenge=abc&code_challenge_method=plain", em)) + + assertProfileError(t, w, profileerrors.ErrRejectedForSafety, "pkce_plain_method") + assertTelemetryEmitted(t, em, "pkce_plain_method") +} + +func TestDefaultRegistry_UnknownGrantType(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodPost, "/token?grant_type=client_credentials", em)) + + assertProfileError(t, w, profileerrors.ErrFeatureNotSupported, "unknown_grant_type") + assertTelemetryEmitted(t, em, "unknown_grant_type") +} + +func TestDefaultRegistry_UnknownGrantType_AllowedTypes(t *testing.T) { + reg := serverrors.DefaultRegistry() + handler := reg.Middleware(alwaysOK()) + + for _, gt := range []string{"authorization_code", "refresh_token"} { + req := reqWithEmitter(http.MethodPost, "/token?grant_type="+gt, newRecEmitter()) + w := serve(handler, req) + if w.Code != http.StatusOK { + t.Fatalf("grant_type=%q: expected 200 (pass-through), got %d: %s", gt, w.Code, w.Body.String()) + } + } +} + +// --------------------------------------------------------------------------- +// Tests — no feature triggered: passes through. +// --------------------------------------------------------------------------- + +func TestDefaultRegistry_NoMatchPassesThrough(t *testing.T) { + reg := serverrors.DefaultRegistry() + em := newRecEmitter() + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + handler := reg.Middleware(next) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/userinfo", em)) + + if !called { + t.Fatal("expected next handler to be called for unmatched request") + } + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if len(em.events) != 0 { + t.Fatalf("expected no telemetry events, got %d", len(em.events)) + } +} + +// --------------------------------------------------------------------------- +// Tests — custom feature registration. +// --------------------------------------------------------------------------- + +func TestRegistry_CustomFeature(t *testing.T) { + reg := serverrors.NewRegistry() + reg.Register(serverrors.UnsupportedFeature{ + Name: "test_feature", + ErrorType: profileerrors.ErrFeatureNotSupported, + Description: "test feature blocked", + Detector: func(r *http.Request) bool { return strings.Contains(r.URL.Path, "/test-blocked") }, + }) + em := newRecEmitter() + handler := reg.Middleware(alwaysOK()) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/test-blocked/foo", em)) + + assertProfileError(t, w, profileerrors.ErrFeatureNotSupported, "test_feature") + assertTelemetryEmitted(t, em, "test_feature") +} + +func TestRegistry_CustomFeature_NoMatch_PassesThrough(t *testing.T) { + reg := serverrors.NewRegistry() + reg.Register(serverrors.UnsupportedFeature{ + Name: "test_feature", + ErrorType: profileerrors.ErrFeatureNotSupported, + Description: "test feature blocked", + Detector: func(r *http.Request) bool { return strings.Contains(r.URL.Path, "/test-blocked") }, + }) + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + handler := reg.Middleware(next) + + w := serve(handler, reqWithEmitter(http.MethodGet, "/safe-path", newRecEmitter())) + + if !called { + t.Fatal("expected next to be called when no feature matches") + } + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func alwaysOK() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} + +func serve(h http.Handler, r *http.Request) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + return w +} + +func assertProfileError(t *testing.T, w *httptest.ResponseRecorder, errType profileerrors.ErrorType, feature string) { + t.Helper() + if w.Code == http.StatusOK { + t.Fatalf("expected non-200 status, got 200") + } + ct := w.Header().Get("Content-Type") + if !strings.Contains(ct, "application/json") { + t.Fatalf("expected application/json content type, got %q", ct) + } + var pe profileerrors.ProfileError + if err := json.NewDecoder(w.Body).Decode(&pe); err != nil { + t.Fatalf("failed to decode ProfileError: %v", err) + } + if pe.Error != errType { + t.Errorf("expected error type %q, got %q", errType, pe.Error) + } + if pe.Feature != feature { + t.Errorf("expected feature %q, got %q", feature, pe.Feature) + } +} + +func assertTelemetryEmitted(t *testing.T, em *recEmitter, feature string) { + t.Helper() + if len(em.events) == 0 { + t.Fatalf("expected telemetry event for feature %q, got none", feature) + } + last := em.events[len(em.events)-1] + if last.EventType != telemetry.EventUnsupportedFeature { + t.Errorf("expected event type %q, got %q", telemetry.EventUnsupportedFeature, last.EventType) + } + if last.Feature != feature { + t.Errorf("expected feature %q in event, got %q", feature, last.Feature) + } +}