diff --git a/assets/examples/https-wrench-response-certificates-filter.yaml b/assets/examples/https-wrench-response-certificates-filter.yaml new file mode 100644 index 0000000..efe48fe --- /dev/null +++ b/assets/examples/https-wrench-response-certificates-filter.yaml @@ -0,0 +1,33 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/xenOs76/https-wrench/refs/heads/main/https-wrench.schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj +--- +## HTTPS Wrench — Response Certificates Filter Example +## +## This example demonstrates how to use the 'responseCertificatesFilter' option +## to selectively print certificate chains in your HTTP responses. +## +## Note: This option depends on 'printResponseCertificates: true' being enabled. + +debug: false +verbose: true + +requests: + - name: selective-certificate-display + printResponseCertificates: true + responseCertificatesFilter: + # Filter for the leaf certificate (index 0) in the peer chain + - 0: + - Subject + - DNSNames + - Issuer + - NotAfter + - Expiration + # Filter for the intermediate/CA certificate (index 1) in the peer chain + - 1: + - Subject + - Issuer + - IsCA + + hosts: + - name: google.com + - name: github.com diff --git a/https-wrench.schema.json b/https-wrench.schema.json index 8217d07..c951a9b 100644 --- a/https-wrench.schema.json +++ b/https-wrench.schema.json @@ -64,6 +64,37 @@ "printResponseCertificates": { "type": "boolean" }, + "responseCertificatesFilter": { + "type": "array", + "description": "Filter to display only specific certificates from the peer chain and/or only subset of fields for each certificate. Each item in the array is a map of certificate index (0-indexed, where 0 is the leaf certificate) to a list of certificate fields to render (e.g. Subject, DNSNames, Issuer, NotBefore, NotAfter, Expiration). If the list of fields is empty, all fields for that certificate are printed.", + "items": { + "type": "object", + "patternProperties": { + "^[0-9]+$": { + "type": "array", + "items": { + "type": "string", + "enum": [ + "Subject", + "DNSNames", + "Issuer", + "NotBefore", + "NotAfter", + "Expiration", + "IsCA", + "AuthorityKeyId", + "SubjectKeyId", + "PublicKeyAlgorithm", + "SignatureAlgorithm", + "SerialNumber", + "Fingerprint SHA-256" + ] + } + } + }, + "additionalProperties": false + } + }, "enableProxyProtocolV2": { "type": "boolean" }, @@ -107,6 +138,9 @@ "dependencies": { "enableProxyProtocolV2": [ "transportOverrideUrl" + ], + "responseCertificatesFilter": [ + "printResponseCertificates" ] }, "title": "Request" diff --git a/internal/certinfo/certinfo_handlers.go b/internal/certinfo/certinfo_handlers.go index ec58a52..e6b1cd2 100644 --- a/internal/certinfo/certinfo_handlers.go +++ b/internal/certinfo/certinfo_handlers.go @@ -231,59 +231,164 @@ func (c *Config) GetRemoteCerts() error { } // CertsToTables formats and prints a list of x509 certificates as tables to the provided writer. -func CertsToTables(w io.Writer, certs []*x509.Certificate) { +// An optional filter slice of maps can be provided to filter printed output by certificate index and field names. +// +//nolint:gocognit,funlen,gocyclo,wsl,revive,cyclop +func CertsToTables(w io.Writer, certs []*x509.Certificate, filter ...[]map[int][]string) { sl := style.CertKeyP4.Render sv := style.CertValue.Render svn := style.CertValueNotice.Render + var f []map[int][]string + if len(filter) > 0 { + f = filter[0] + } + + // If a filter is provided, map it for fast lookup by certificate index + requestedCerts := make(map[int][]string) + + hasFilter := len(f) > 0 + if hasFilter { + for _, m := range f { + for k, fields := range m { + requestedCerts[k] = fields + } + } + } + for i := range certs { + var fields []string + + if hasFilter { + var ok bool + + fields, ok = requestedCerts[i] + if !ok { + // Certificate index not in filter list, skip displaying it + continue + } + } + header := style.LgSprintf( style.CertKeyP4.Bold(true), "Certificate %d", i) cert := certs[i] - subject := cert.Subject.String() - dnsNames := "[" + strings.Join(cert.DNSNames, ", ") + "]" - issuer := cert.Issuer.String() + // Helper to check if a specific field is requested (case-insensitive) + hasField := func(fieldName string) bool { + if !hasFilter || len(fields) == 0 { + return true // Print all fields if no filter is active or if field list is empty for this cert + } + + for _, field := range fields { + if strings.EqualFold(field, fieldName) { + return true + } + } + + return false + } + + t := table.New().Border(style.LGDefBorder).Headers(header) + hasRows := false + addRow := func(k, v string) { + t.Row(k, v) + + hasRows = true + } - notBefore := cert.NotBefore - notAfter := cert.NotAfter - expiration := humanize.Time(notAfter) - daysUntilExpiration := time.Until(notAfter).Hours() / 24 + if hasField("Subject") { + subject := cert.Subject.String() + addRow(sl("Subject"), sv(subject)) + } - expStyle := sv - if (0 < daysUntilExpiration) && (daysUntilExpiration < CertExpWarnDays) { - expStyle = style.Warn.Render + if hasField("DNSNames") { + dnsNames := strings.Join(cert.DNSNames, "\n") + addRow(sl("DNSNames"), sv(dnsNames)) } - if daysUntilExpiration <= 0 { - expStyle = style.Crit.Render + if hasField("Issuer") { + issuer := cert.Issuer.String() + addRow(sl("Issuer"), sv(issuer)) } - isCA := strconv.FormatBool(cert.IsCA) - publicKeyAlgorithm := cert.PublicKeyAlgorithm.String() - authorityKeyID := hex.EncodeToString(cert.AuthorityKeyId) - subjectKeyID := hex.EncodeToString(cert.SubjectKeyId) - signatureAlgorithm := cert.SignatureAlgorithm.String() - fingerprintSha256 := fmt.Sprintf("%x", sha256.Sum256(cert.Raw)) - serialNumber := cert.SerialNumber.String() + if hasField("NotBefore") { + notBefore := cert.NotBefore + addRow(sl("NotBefore"), sv(notBefore.String())) + } + + // Calculate expiration colors if needed + var expStyle func(...string) string + + getExpStyle := func() func(...string) string { + if expStyle != nil { + return expStyle + } + + daysUntilExpiration := time.Until(cert.NotAfter).Hours() / 24 + + expStyle = sv + if (0 < daysUntilExpiration) && (daysUntilExpiration < CertExpWarnDays) { + expStyle = style.Warn.Render + } + + if daysUntilExpiration <= 0 { + expStyle = style.Crit.Render + } + + return expStyle + } + + if hasField("NotAfter") { + notAfter := cert.NotAfter + addRow(sl("NotAfter"), getExpStyle()(notAfter.String())) + } + + if hasField("Expiration") { + expiration := humanize.Time(cert.NotAfter) + addRow(sl("Expiration"), getExpStyle()(expiration)) + } + + if hasField("IsCA") { + isCA := strconv.FormatBool(cert.IsCA) + addRow(sl("IsCA"), svn(isCA)) + } + + if hasField("AuthorityKeyId") { + authorityKeyID := hex.EncodeToString(cert.AuthorityKeyId) + addRow(sl("AuthorityKeyId"), svn(authorityKeyID)) + } + + if hasField("SubjectKeyId") { + subjectKeyID := hex.EncodeToString(cert.SubjectKeyId) + addRow(sl("SubjectKeyId"), svn(subjectKeyID)) + } + + if hasField("PublicKeyAlgorithm") { + publicKeyAlgorithm := cert.PublicKeyAlgorithm.String() + addRow(sl("PublicKeyAlgorithm"), sv(publicKeyAlgorithm)) + } + + if hasField("SignatureAlgorithm") { + signatureAlgorithm := cert.SignatureAlgorithm.String() + addRow(sl("SignatureAlgorithm"), sv(signatureAlgorithm)) + } + + if hasField("SerialNumber") { + serialNumber := cert.SerialNumber.String() + addRow(sl("SerialNumber"), sv(serialNumber)) + } + + if hasField("Fingerprint SHA-256") || hasField("Fingerprint") { + fingerprintSha256 := fmt.Sprintf("%x", sha256.Sum256(cert.Raw)) + addRow(sl("Fingerprint SHA-256"), sv(fingerprintSha256)) + } + + if hasRows { + fmt.Fprintln(w, t.Render()) + } - t := table.New().Border(style.LGDefBorder).Headers(header) - t.Row(sl("Subject"), sv(subject)) - t.Row(sl("DNSNames"), sv(dnsNames)) - t.Row(sl("Issuer"), sv(issuer)) - t.Row(sl("NotBefore"), sv(notBefore.String())) - t.Row(sl("NotAfter"), expStyle(notAfter.String())) - t.Row(sl("Expiration"), expStyle(expiration)) - t.Row(sl("IsCA"), svn(isCA)) - t.Row(sl("AuthorityKeyId"), svn(authorityKeyID)) - t.Row(sl("SubjectKeyId"), svn(subjectKeyID)) - t.Row(sl("PublicKeyAlgorithm"), sv(publicKeyAlgorithm)) - t.Row(sl("SignatureAlgorithm"), sv(signatureAlgorithm)) - t.Row(sl("SerialNumber"), sv(serialNumber)) - t.Row(sl("Fingerprint SHA-256"), sv(fingerprintSha256)) - fmt.Fprintln(w, t.Render()) t.ClearRows() } } diff --git a/internal/certinfo/certinfo_handlers_test.go b/internal/certinfo/certinfo_handlers_test.go index 51b58cd..186ec83 100644 --- a/internal/certinfo/certinfo_handlers_test.go +++ b/internal/certinfo/certinfo_handlers_test.go @@ -3,7 +3,10 @@ package certinfo import ( "bytes" "crypto/x509" + "crypto/x509/pkix" + "math/big" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -204,7 +207,7 @@ func TestCertinfo_CertsToTables(t *testing.T) { subject: "Subject CN=RSA Testing CA", isCA: "IsCA true", expiration: "Expiration 23 hours from now", - dnsNames: "DNSNames []", + dnsNames: "DNSNames", publicKeyAlgorithm: "PublicKeyAlgorithm RSA", signatureAlgorithm: "SignatureAlgorithm SHA256-RSA", }, @@ -214,7 +217,7 @@ func TestCertinfo_CertsToTables(t *testing.T) { subject: "Subject CN=RSA Testing Sample Certificate", isCA: "IsCA false", expiration: "Expiration 23 hours from now", - dnsNames: "DNSNames [example.com, example.net, example.de]", + dnsNames: "example.com", publicKeyAlgorithm: "PublicKeyAlgorithm RSA", signatureAlgorithm: "SignatureAlgorithm SHA256-RSA", }, @@ -224,7 +227,7 @@ func TestCertinfo_CertsToTables(t *testing.T) { subject: "Subject CN=example.com,O=example Ltd,L=Berlin,ST=Some-State,C=DE", isCA: "IsCA false", expiration: "ago", - dnsNames: "DNSNames []", + dnsNames: "DNSNames", publicKeyAlgorithm: "PublicKeyAlgorithm RSA", signatureAlgorithm: "SignatureAlgorithm SHA256-RSA", }, @@ -234,7 +237,7 @@ func TestCertinfo_CertsToTables(t *testing.T) { cert: ecdsaCert[0], subject: "Subject CN=example.com,O=Example Org", isCA: "IsCA true", - dnsNames: "DNSNames []", + dnsNames: "DNSNames", publicKeyAlgorithm: "PublicKeyAlgorithm ECDSA", signatureAlgorithm: "SignatureAlgorithm ECDSA-SHA256", }, @@ -243,7 +246,7 @@ func TestCertinfo_CertsToTables(t *testing.T) { cert: ed25519Cert[0], subject: "Subject CN=example.com,O=Example Org", isCA: "IsCA true", - dnsNames: "DNSNames []", + dnsNames: "DNSNames", publicKeyAlgorithm: "PublicKeyAlgorithm Ed25519", signatureAlgorithm: "SignatureAlgorithm Ed25519", }, @@ -289,7 +292,75 @@ func TestCertinfo_CertsToTables(t *testing.T) { } } -//nolint:revive +func TestCertinfo_CertsToTables_FilteringAndWarning(t *testing.T) { + warnCert := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "warning.example.com", + Organization: []string{"Warning Corp"}, + }, + Issuer: pkix.Name{ + CommonName: "warning.example.com", + Organization: []string{"Warning Corp"}, + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(10 * 24 * time.Hour), // 10 days -> warning (< 40 days)! + Raw: []byte("dummy raw cert bytes"), + SerialNumber: big.NewInt(999), + DNSNames: []string{"warning.example.com", "alt.warning.example.com"}, + } + + t.Run("Warning Style", func(t *testing.T) { + var buf bytes.Buffer + CertsToTables(&buf, []*x509.Certificate{warnCert}) + got := buf.String() + require.Contains(t, got, "warning.example.com") + require.Contains(t, got, "1 week from now") + }) + + t.Run("Filtered Output", func(t *testing.T) { + var buf bytes.Buffer + + filter := []map[int][]string{ + { + 0: []string{"Subject", "DNSNames"}, + }, + } + CertsToTables(&buf, []*x509.Certificate{warnCert}, filter) + got := buf.String() + require.Contains(t, got, "Subject") + require.Contains(t, got, "warning.example.com") + require.Contains(t, got, "alt.warning.example.com") + require.NotContains(t, got, "Issuer") + require.NotContains(t, got, "NotAfter") + }) + + t.Run("Filtered Empty List", func(t *testing.T) { + var buf bytes.Buffer + + filter := []map[int][]string{ + { + 0: []string{}, // empty list means print all fields for this cert + }, + } + CertsToTables(&buf, []*x509.Certificate{warnCert}, filter) + got := buf.String() + require.Contains(t, got, "Subject") + require.Contains(t, got, "Issuer") + }) + + t.Run("Filtered Mismatched Index", func(t *testing.T) { + var buf bytes.Buffer + + filter := []map[int][]string{ + { + 1: []string{"Subject"}, // index 1 doesn't exist for single cert, so cert 0 skipped + }, + } + CertsToTables(&buf, []*x509.Certificate{warnCert}, filter) + got := buf.String() + require.NotContains(t, got, "warning.example.com") + }) +} //nolint:revive func TestCertinfo_PrintData(t *testing.T) { diff --git a/internal/cmd/embedded/config-example.yaml b/internal/cmd/embedded/config-example.yaml index 032a45c..f3a82ee 100644 --- a/internal/cmd/embedded/config-example.yaml +++ b/internal/cmd/embedded/config-example.yaml @@ -62,6 +62,15 @@ requests: ## printResponseCertificates - If true, prints the TLS certificates returned in the response. printResponseCertificates: true + ## responseCertificatesFilter - Filter the printed TLS certificates to show only specific certificates in the chain (0-indexed) and/or a subset of fields. + # responseCertificatesFilter: + # - 0: + # - Subject + # - DNSNames + # - Issuer + # - NotAfter + # - Expiration + ## requestMethod - The HTTP method to use for the request (e.g., GET, POST, PUT, DELETE). Defaults to GET. requestMethod: POST diff --git a/internal/requests/requests.go b/internal/requests/requests.go index 8a89862..e03a9c8 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -122,6 +122,12 @@ type RequestConfig struct { PrintResponseHeaders bool `mapstructure:"printResponseHeaders"` // PrintResponseCertificates indicates if the response TLS certificates should be printed. PrintResponseCertificates bool `mapstructure:"printResponseCertificates"` + // ResponseCertificatesFilter is a list of filters mapping certificate chain indices + // (0 for leaf, 1, 2, etc. for intermediates/roots) to specific fields that should be printed. + // Valid fields include: "Subject", "DNSNames", "Issuer", "NotBefore", "NotAfter", + // "Expiration", "IsCA", "AuthorityKeyId", "SubjectKeyId", "PublicKeyAlgorithm", + // "SignatureAlgorithm", "SerialNumber", and "Fingerprint SHA-256". + ResponseCertificatesFilter []map[int][]string `mapstructure:"responseCertificatesFilter"` // Hosts is a list of target hosts and their URIs. Hosts []Host `mapstructure:"hosts"` } diff --git a/internal/requests/requests_filter_test.go b/internal/requests/requests_filter_test.go new file mode 100644 index 0000000..2884d36 --- /dev/null +++ b/internal/requests/requests_filter_test.go @@ -0,0 +1,120 @@ +package requests + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createTestCert(t *testing.T) *x509.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Corp"}, + CommonName: "example.com", + }, + Issuer: pkix.Name{ + Organization: []string{"Acme Authority"}, + CommonName: "CA", + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"example.com", "www.example.com"}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(derBytes) + require.NoError(t, err) + + return cert +} + +func TestRenderTLSData_Filtering(t *testing.T) { + cert := createTestCert(t) + + resp := &http.Response{ + TLS: &tls.ConnectionState{ + Version: tls.VersionTLS13, + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + PeerCertificates: []*x509.Certificate{cert}, + }, + } + + t.Run("No Filter", func(t *testing.T) { + var buf bytes.Buffer + RenderTLSData(&buf, resp) + + output := buf.String() + assert.Contains(t, output, "TLS:") + assert.Contains(t, output, "Version") + assert.Contains(t, output, "TLS 1.3") + assert.Contains(t, output, "Subject") + assert.Contains(t, output, "example.com") + assert.Contains(t, output, "Issuer") + assert.Contains(t, output, "Acme Corp") + assert.Contains(t, output, "DNSNames") + assert.Contains(t, output, "www.example.com") + }) + + t.Run("Filter Cert 0 with specific fields", func(t *testing.T) { + var buf bytes.Buffer + + filter := []map[int][]string{ + { + 0: []string{"Subject", "DNSNames"}, + }, + } + RenderTLSData(&buf, resp, filter) + + output := buf.String() + assert.Contains(t, output, "TLS:") + assert.Contains(t, output, "Subject") + assert.Contains(t, output, "example.com") + assert.Contains(t, output, "DNSNames") + assert.Contains(t, output, "www.example.com") + + // Filtered-out fields should NOT be in the output + assert.NotContains(t, output, "Issuer") + assert.NotContains(t, output, "NotBefore") + assert.NotContains(t, output, "Expiration") + }) + + t.Run("Filter Non-existent Cert Index", func(t *testing.T) { + var buf bytes.Buffer + + filter := []map[int][]string{ + { + 1: []string{"Subject"}, // Certificate 1 does not exist in chain + }, + } + RenderTLSData(&buf, resp, filter) + + output := buf.String() + assert.Contains(t, output, "TLS:") + // Output should NOT contain Certificate 0 information + assert.NotContains(t, output, "Certificate 0") + assert.NotContains(t, output, "Subject") + }) +} diff --git a/internal/requests/requests_handlers.go b/internal/requests/requests_handlers.go index a89c401..6e76a77 100644 --- a/internal/requests/requests_handlers.go +++ b/internal/requests/requests_handlers.go @@ -315,7 +315,7 @@ func (rd ResponseData) PrintResponseData(isVerbose bool) { style.StatusCodeParse(rd.Response.StatusCode))) if rd.Request.PrintResponseCertificates { - RenderTLSData(os.Stdout, rd.Response) + RenderTLSData(os.Stdout, rd.Response, rd.Request.ResponseCertificatesFilter) } if rd.Request.PrintResponseHeaders { @@ -342,7 +342,8 @@ func (rd ResponseData) PrintResponseData(isVerbose bool) { } // RenderTLSData prints TLS version, cipher suite, and peer certificates for an HTTP response. -func RenderTLSData(w io.Writer, r *http.Response) { +// An optional filter can be provided to only print specific certificate indices and fields. +func RenderTLSData(w io.Writer, r *http.Response, filter ...[]map[int][]string) { respTLS := r.TLS sl := style.CertKeyP4.Render sv := style.CertValue.Render @@ -373,5 +374,10 @@ func RenderTLSData(w io.Writer, r *http.Response) { fmt.Fprintln(w, t.Render()) t.ClearRows() - certinfo.CertsToTables(w, respTLS.PeerCertificates) + var f []map[int][]string + if len(filter) > 0 { + f = filter[0] + } + + certinfo.CertsToTables(w, respTLS.PeerCertificates, f) }