Skip to content

Commit

Permalink
PR review: add unit tests for certificates that fail name constraints…
Browse files Browse the repository at this point in the history
… verification.
  • Loading branch information
victorr committed Dec 20, 2024
1 parent f9cf7a9 commit 4b496b6
Showing 1 changed file with 164 additions and 95 deletions.
259 changes: 164 additions & 95 deletions builtin/logical/pki/cert_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ func TestPki_PermitFQDNs(t *testing.T) {
}

type parseCertificateTestCase struct {
name string
data map[string]interface{}
roleData map[string]interface{} // if a role is to be created
ttl time.Duration
wantParams certutil.CreationParameters
wantFields map[string]interface{}
wantErr bool
name string
data map[string]interface{}
roleData map[string]interface{} // if a role is to be created
ttl time.Duration
wantParams certutil.CreationParameters
wantFields map[string]interface{}
wantIssuanceErr string // If not empty, require.ErrorContains will be used on this string
}

// TestDisableVerifyCertificateEnvVar verifies that env var VAULT_DISABLE_PKI_CONSTRAINTS_VERIFICATION
Expand Down Expand Up @@ -468,7 +468,6 @@ func TestParseCertificate(t *testing.T) {
"key_bits": 384,
"skid": "We'll assert that it is not nil as an special case",
},
wantErr: false,
},
{
// Note that this test's data is used to create the internal CA used by test "full non CA cert"
Expand All @@ -483,13 +482,13 @@ func TestParseCertificate(t *testing.T) {
"ttl": "2h",
"max_path_length": 2,
"permitted_dns_domains": "example.com,.example.com,.www.example.com",
"excluded_dns_domains": "bad.org,reallybad.com",
"excluded_dns_domains": "bad.example.com,reallybad.com",
"permitted_ip_ranges": "192.0.2.1/24,76.76.21.21/24,2001:4860:4860::8889/32", // Note that while an IP address if specified here, it is the network address that will be stored
"excluded_ip_ranges": "127.0.0.1/16,2001:4860:4860::8888/32",
"permitted_email_addresses": "info@example.com,user@example.com,admin@example.com",
"excluded_email_addresses": "root@example.com,robots@example.com",
"permitted_uri_domains": "example.com,www.example.com",
"excluded_uri_domains": "ftp://example.com,gopher://www.example.com",
"excluded_uri_domains": "ftp.example.com,gopher.www.example.com",
"ou": "unit1, unit2",
"organization": "org1, org2",
"country": "US, CA",
Expand Down Expand Up @@ -535,13 +534,13 @@ func TestParseCertificate(t *testing.T) {
ForceAppendCaChain: false,
UseCSRValues: false,
PermittedDNSDomains: []string{"example.com", ".example.com", ".www.example.com"},
ExcludedDNSDomains: []string{"bad.org", "reallybad.com"},
ExcludedDNSDomains: []string{"bad.example.com", "reallybad.com"},
PermittedIPRanges: convertIps("192.0.2.0/24", "76.76.21.0/24", "2001:4860::/32"), // Note that we stored the network address rather than the specific IP address
ExcludedIPRanges: convertIps("127.0.0.0/16", "2001:4860::/32"),
PermittedEmailAddresses: []string{"info@example.com", "user@example.com", "admin@example.com"},
ExcludedEmailAddresses: []string{"root@example.com", "robots@example.com"},
PermittedURIDomains: []string{"example.com", "www.example.com"},
ExcludedURIDomains: []string{"ftp://example.com", "gopher://www.example.com"},
ExcludedURIDomains: []string{"ftp.example.com", "gopher.www.example.com"},
URLs: nil,
MaxPathLength: 2,
NotBeforeDuration: 45 * time.Second,
Expand All @@ -566,19 +565,18 @@ func TestParseCertificate(t *testing.T) {
"ttl": "2h0m45s",
"max_path_length": 2,
"permitted_dns_domains": "example.com,.example.com,.www.example.com",
"excluded_dns_domains": "bad.org,reallybad.com",
"excluded_dns_domains": "bad.example.com,reallybad.com",
"permitted_ip_ranges": "192.0.2.0/24,76.76.21.0/24,2001:4860::/32",
"excluded_ip_ranges": "127.0.0.0/16,2001:4860::/32",
"permitted_email_addresses": "info@example.com,user@example.com,admin@example.com",
"excluded_email_addresses": "root@example.com,robots@example.com",
"permitted_uri_domains": "example.com,www.example.com",
"excluded_uri_domains": "ftp://example.com,gopher://www.example.com",
"excluded_uri_domains": "ftp.example.com,gopher.www.example.com",
"use_pss": true,
"key_type": "rsa",
"key_bits": 2048,
"skid": "We'll assert that it is not nil as an special case",
},
wantErr: false,
},
{
// Note that we use the data of test "full CA" to create the internal CA needed for this test
Expand Down Expand Up @@ -674,7 +672,90 @@ func TestParseCertificate(t *testing.T) {
"key_bits": 2048,
"skid": "We'll assert that it is not nil as an special case",
},
wantErr: false,
},
{
name: "DNS domain not permitted",
data: map[string]interface{}{
"common_name": "the common name non ca",
"alt_names": "badexample.com",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `DNS name "badexample.com" is not permitted by any constraint`,
},
{
name: "DNS domain explicitly excluded",
data: map[string]interface{}{
"common_name": "the common name non ca",
"alt_names": "bad.example.com",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `DNS name "bad.example.com" is excluded by constraint "bad.example.com"`,
},
{
name: "IP address not permitted",
data: map[string]interface{}{
"common_name": "the common name non ca",
"ip_sans": "192.0.3.1",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `IP address "192.0.3.1" is not permitted by any constraint`,
},
{
name: "IP address explicitly excluded",
data: map[string]interface{}{
"common_name": "the common name non ca",
"ip_sans": "127.0.0.123",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `IP address "127.0.0.123" is excluded by constraint "127.0.0.0/16"`,
},
{
name: "email address not permitted",
data: map[string]interface{}{
"common_name": "the common name non ca",
"alt_names": "random@example.com",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `email address "random@example.com" is not permitted by any constraint`,
},
{
name: "email address explicitly excluded",
data: map[string]interface{}{
"common_name": "the common name non ca",
"alt_names": "root@example.com",
"ttl": "2h",
},
ttl: 2 * time.Hour,
roleData: map[string]interface{}{
"allow_any_name": true,
"cn_validations": "disabled",
},
wantIssuanceErr: `email address "root@example.com" is excluded by constraint "root@example.com"`,
},
}
for _, tt := range tests {
Expand Down Expand Up @@ -707,15 +788,22 @@ func TestParseCertificate(t *testing.T) {

// create the cert
resp, err = CBWrite(b, s, "issue/test", tt.data)
require.NoError(t, err)
require.NotNil(t, resp)

certData := resp.Data["certificate"].(string)
cert, err = parsing.ParseCertificateFromString(certData)
require.NoError(t, err)
require.NotNil(t, cert)
if tt.wantIssuanceErr != "" {
require.ErrorContains(t, err, tt.wantIssuanceErr)
} else {
require.NoError(t, err)
require.NotNil(t, resp)

certData := resp.Data["certificate"].(string)
cert, err = parsing.ParseCertificateFromString(certData)
require.NoError(t, err)
require.NotNil(t, cert)
}
}

if tt.wantIssuanceErr != "" {
return
}
t.Run(tt.name+" parameters", func(t *testing.T) {
testParseCertificateToCreationParameters(t, issueTime, tt, cert)
})
Expand All @@ -729,72 +817,64 @@ func TestParseCertificate(t *testing.T) {
func testParseCertificateToCreationParameters(t *testing.T, issueTime time.Time, tt *parseCertificateTestCase, cert *x509.Certificate) {
params, err := certutil.ParseCertificateToCreationParameters(*cert)

if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NoError(t, err)

ignoreBasicConstraintsValidForNonCA := tt.wantParams.IsCA

var diff []string
for _, d := range deep.Equal(tt.wantParams, params) {
switch {
case strings.HasPrefix(d, "SKID"):
continue
case strings.HasPrefix(d, "BasicConstraintsValidForNonCA") && ignoreBasicConstraintsValidForNonCA:
continue
case strings.HasPrefix(d, "NotBeforeDuration"):
continue
case strings.HasPrefix(d, "NotAfter"):
continue
}
diff = append(diff, d)
}
if diff != nil {
t.Errorf("testParseCertificateToCreationParameters() diff: %s", strings.Join(diff, "\n"))
ignoreBasicConstraintsValidForNonCA := tt.wantParams.IsCA

var diff []string
for _, d := range deep.Equal(tt.wantParams, params) {
switch {
case strings.HasPrefix(d, "SKID"):
continue
case strings.HasPrefix(d, "BasicConstraintsValidForNonCA") && ignoreBasicConstraintsValidForNonCA:
continue
case strings.HasPrefix(d, "NotBeforeDuration"):
continue
case strings.HasPrefix(d, "NotAfter"):
continue
}
diff = append(diff, d)
}
if diff != nil {
t.Errorf("testParseCertificateToCreationParameters() diff: %s", strings.Join(diff, "\n"))
}

require.NotNil(t, params.SKID)
require.GreaterOrEqual(t, params.NotBeforeDuration, tt.wantParams.NotBeforeDuration,
"NotBeforeDuration want: %s got: %s", tt.wantParams.NotBeforeDuration, params.NotBeforeDuration)
require.NotNil(t, params.SKID)
require.GreaterOrEqual(t, params.NotBeforeDuration, tt.wantParams.NotBeforeDuration,
"NotBeforeDuration want: %s got: %s", tt.wantParams.NotBeforeDuration, params.NotBeforeDuration)

require.GreaterOrEqual(t, params.NotAfter, issueTime.Add(tt.ttl).Add(-1*time.Minute),
"NotAfter want: %s got: %s", tt.wantParams.NotAfter, params.NotAfter)
require.LessOrEqual(t, params.NotAfter, issueTime.Add(tt.ttl).Add(1*time.Minute),
"NotAfter want: %s got: %s", tt.wantParams.NotAfter, params.NotAfter)
}
require.GreaterOrEqual(t, params.NotAfter, issueTime.Add(tt.ttl).Add(-1*time.Minute),
"NotAfter want: %s got: %s", tt.wantParams.NotAfter, params.NotAfter)
require.LessOrEqual(t, params.NotAfter, issueTime.Add(tt.ttl).Add(1*time.Minute),
"NotAfter want: %s got: %s", tt.wantParams.NotAfter, params.NotAfter)
}

func testParseCertificateToFields(t *testing.T, issueTime time.Time, tt *parseCertificateTestCase, cert *x509.Certificate) {
fields, err := certutil.ParseCertificateToFields(*cert)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NoError(t, err)

require.NotNil(t, fields["skid"])
delete(fields, "skid")
delete(tt.wantFields, "skid")
require.NotNil(t, fields["skid"])
delete(fields, "skid")
delete(tt.wantFields, "skid")

{
// Sometimes TTL comes back as 1s off, so we'll allow that
expectedTTL, err := parseutil.ParseDurationSecond(tt.wantFields["ttl"].(string))
require.NoError(t, err)
actualTTL, err := parseutil.ParseDurationSecond(fields["ttl"].(string))
require.NoError(t, err)

diff := expectedTTL - actualTTL
require.LessOrEqual(t, actualTTL, expectedTTL, // NotAfter is generated before NotBefore so the time.Now of notBefore may be later, shrinking our calculated TTL during very slow tests
"ttl should be, if off, smaller than expected want: %s got: %s", tt.wantFields["ttl"], fields["ttl"])
require.LessOrEqual(t, diff, 30*time.Second, // Test can be slow, allow more off in the other direction
"ttl must be at most 30s off, want: %s got: %s", tt.wantFields["ttl"], fields["ttl"])
delete(fields, "ttl")
delete(tt.wantFields, "ttl")
}
{
// Sometimes TTL comes back as 1s off, so we'll allow that
expectedTTL, err := parseutil.ParseDurationSecond(tt.wantFields["ttl"].(string))
require.NoError(t, err)
actualTTL, err := parseutil.ParseDurationSecond(fields["ttl"].(string))
require.NoError(t, err)

if diff := deep.Equal(tt.wantFields, fields); diff != nil {
t.Errorf("testParseCertificateToFields() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
diff := expectedTTL - actualTTL
require.LessOrEqual(t, actualTTL, expectedTTL, // NotAfter is generated before NotBefore so the time.Now of notBefore may be later, shrinking our calculated TTL during very slow tests
"ttl should be, if off, smaller than expected want: %s got: %s", tt.wantFields["ttl"], fields["ttl"])
require.LessOrEqual(t, diff, 30*time.Second, // Test can be slow, allow more off in the other direction
"ttl must be at most 30s off, want: %s got: %s", tt.wantFields["ttl"], fields["ttl"])
delete(fields, "ttl")
delete(tt.wantFields, "ttl")
}

if diff := deep.Equal(tt.wantFields, fields); diff != nil {
t.Errorf("testParseCertificateToFields() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
}

Expand Down Expand Up @@ -870,7 +950,6 @@ func TestParseCsr(t *testing.T) {
"serial_number": "",
"add_basic_constraints": false,
},
wantErr: false,
},
{
name: "full CSR with basic constraints",
Expand Down Expand Up @@ -957,7 +1036,6 @@ func TestParseCsr(t *testing.T) {
"serial_number": "37:60:16:e4:85:d5:96:38:3a:ed:31:06:8d:ed:7a:46:d4:22:63:d8",
"add_basic_constraints": true,
},
wantErr: false,
},
{
name: "full CSR without basic constraints",
Expand Down Expand Up @@ -1044,7 +1122,6 @@ func TestParseCsr(t *testing.T) {
"serial_number": "37:60:16:e4:85:d5:96:38:3a:ed:31:06:8d:ed:7a:46:d4:22:63:d8",
"add_basic_constraints": false,
},
wantErr: false,
},
}
for _, tt := range tests {
Expand Down Expand Up @@ -1073,26 +1150,18 @@ func TestParseCsr(t *testing.T) {
func testParseCsrToCreationParameters(t *testing.T, issueTime time.Time, tt *parseCertificateTestCase, csr *x509.CertificateRequest) {
params, err := certutil.ParseCsrToCreationParameters(*csr)

if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NoError(t, err)

if diff := deep.Equal(tt.wantParams, params); diff != nil {
t.Errorf("testParseCertificateToCreationParameters() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
if diff := deep.Equal(tt.wantParams, params); diff != nil {
t.Errorf("testParseCertificateToCreationParameters() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
}

func testParseCsrToFields(t *testing.T, issueTime time.Time, tt *parseCertificateTestCase, csr *x509.CertificateRequest) {
fields, err := certutil.ParseCsrToFields(*csr)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NoError(t, err)

if diff := deep.Equal(tt.wantFields, fields); diff != nil {
t.Errorf("testParseCertificateToFields() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
if diff := deep.Equal(tt.wantFields, fields); diff != nil {
t.Errorf("testParseCertificateToFields() diff: %s", strings.ReplaceAll(strings.Join(diff, "\n"), "map", "\nmap"))
}
}

0 comments on commit 4b496b6

Please sign in to comment.