Skip to content

Commit

Permalink
Merge pull request #17 from moriyoshi/prepared-certificates
Browse files Browse the repository at this point in the history
Support for prepared certificates
  • Loading branch information
moriyoshi authored Aug 9, 2020
2 parents 51a9eee + fee58d1 commit f1de758
Show file tree
Hide file tree
Showing 9 changed files with 498 additions and 66 deletions.
190 changes: 150 additions & 40 deletions certcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,32 +56,80 @@ type CertCache struct {

const certificateFileName = "cert.pem"
const certificateBlockName = "CERTIFICATE"
const privateKeyFileName = "key.pem"
const privateKeyBlockName = "PRIVATE KEY"

func buildKeyString(hosts []string) string {
key := strings.Join(hosts, ";")
return key
}

func (c *CertCache) writeCertificate(key string, cert *tls.Certificate) (err error) {
leadingDirs, ok := c.buildPathToCachedCert(key)
leadingDir, ok := c.buildPathToCachedCert(key)
if !ok {
return
}
err = os.MkdirAll(leadingDirs, os.FileMode(0777))

leadingDirTmp := leadingDir + "$tmp$"
err = os.MkdirAll(leadingDirTmp, os.FileMode(0700))
if err != nil {
return
}
path := filepath.Join(leadingDirs, certificateFileName)
w, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0666))
defer func() {
if err != nil {
os.RemoveAll(leadingDirTmp)
}
}()

certFilePath := filepath.Join(leadingDirTmp, certificateFileName)
privKeyFilePath := filepath.Join(leadingDirTmp, privateKeyFileName)

err = func(certFilePath string) (err error) {
w, err := os.OpenFile(certFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0600))
if err != nil {
return
}
defer w.Close()

for _, x509Cert := range cert.Certificate {
err = pem.Encode(w, &pem.Block{Type: certificateBlockName, Bytes: x509Cert})
if err != nil {
return
}
_, err = w.Write([]byte{'\n'})
if err != nil {
return
}
}
return
}(certFilePath)
if err != nil {
return
}
defer w.Close()
err = pem.Encode(w, &pem.Block{Type: certificateBlockName, Bytes: cert.Certificate[0]})

err = func(privKeyFilePath string) (err error) {
privKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
if err != nil {
return
}

w, err := os.OpenFile(privKeyFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0666))
if err != nil {
return
}
defer w.Close()
err = pem.Encode(w, &pem.Block{Type: privateKeyBlockName, Bytes: privKeyBytes})
if err != nil {
return
}
return
}(privKeyFilePath)
if err != nil {
return
}
return nil

err = os.Rename(leadingDirTmp, leadingDir)
return
}

func (c *CertCache) buildPathToCachedCert(key string) (string, bool) {
Expand All @@ -92,66 +140,121 @@ func (c *CertCache) buildPathToCachedCert(key string) (string, bool) {
}

func (c *CertCache) readAndValidateCertificate(key string, hosts []string, now time.Time) (*tls.Certificate, error) {
leadingDirs, ok := c.buildPathToCachedCert(key)
leadingDir, ok := c.buildPathToCachedCert(key)
if !ok {
return nil, nil
}
path := filepath.Join(leadingDirs, certificateFileName)
pemBytes, err := ioutil.ReadFile(path)
if err != nil {
return nil, err

if _, err := os.Stat(leadingDir); os.IsNotExist(err) {
return nil, nil
}
certDerBytes := []byte(nil)
for {
var pemBlock *pem.Block
pemBlock, pemBytes = pem.Decode(pemBytes)
if pemBlock == nil {
break

var x509Cert *x509.Certificate
var certDerBytes [][]byte
{
certFilePath := filepath.Join(leadingDir, certificateFileName)
pemBytes, err := ioutil.ReadFile(certFilePath)
if err != nil {
return nil, err
}
if pemBlock.Type == certificateBlockName {
certDerBytes = pemBlock.Bytes
break

for {
var pemBlock *pem.Block
pemBlock, pemBytes = pem.Decode(pemBytes)
if pemBlock == nil {
break
}
if pemBlock.Type == certificateBlockName {
certDerBytes = append(certDerBytes, pemBlock.Bytes)
}
}
if len(certDerBytes) == 0 {
return nil, errors.Errorf("no valid certificate contained in %s", certFilePath)
}

x509Cert, err = x509.ParseCertificate(certDerBytes[0])
if err != nil {
return nil, errors.Wrapf(err, "invalid certificate found in %s", certFilePath)
}
if len(certDerBytes) == 1 && c.issuerCert != nil {
err = x509Cert.CheckSignatureFrom(c.issuerCert)
if err != nil {
return nil, errors.Wrapf(err, "invalid certificate found in %s", certFilePath)
}
}

if !now.Before(x509Cert.NotAfter) {
return nil, errors.Errorf("ceritificate no longer valid (not after: %s, now: %s)", x509Cert.NotAfter.Local().Format(time.RFC1123), now.Local().Format(time.RFC1123))
}
}
if certDerBytes == nil {
return nil, errors.Errorf("no valid certificate contained in %s", path)
}
x509Cert, err := x509.ParseCertificate(certDerBytes)
if err != nil {
return nil, errors.Wrapf(err, "invalid certificate found in %s", path)
}
x509Cert.RawIssuer = c.issuerCert.Raw
err = x509Cert.CheckSignatureFrom(c.issuerCert)
if err != nil {
return nil, errors.Wrapf(err, "invalid certificate found in %s", path)

var privKey crypto.PrivateKey
{
privKeyFilePath := filepath.Join(leadingDir, privateKeyFileName)
pemBytes, err := ioutil.ReadFile(privKeyFilePath)
if err != nil {
if !os.IsNotExist(err) {
return nil, err
}
}
if err == nil {
b, _ := pem.Decode(pemBytes)
privKey, err = x509.ParsePKCS8PrivateKey(b.Bytes)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse private key %s", privKeyFilePath)
}
} else {
if c != nil {
privKey = c.privateKey
}
err = nil
}
}
if !now.Before(x509Cert.NotAfter) {
return nil, errors.Errorf("ceritificate no longer valid (not after: %s, now: %s)", x509Cert.NotAfter.Local().Format(time.RFC1123), now.Local().Format(time.RFC1123))

if privKey == nil {
return nil, errors.Errorf("no private key is available (cache is broken)")
}

outer:
for _, a := range hosts {
dnsNameMatched := false
for _, b := range x509Cert.DNSNames {
if a == b {
if wildMatch(b, a) {
dnsNameMatched = true
break outer
}
}
return nil, errors.Errorf("certificate does not cover the host name %s", a)
if !dnsNameMatched {
dnsNameMatched = wildMatch(x509Cert.Subject.CommonName, a)
}
if !dnsNameMatched {
return nil, errors.Errorf("certificate does not cover the host name %s", a)
}
}

return &tls.Certificate{
Certificate: [][]byte{certDerBytes, c.issuerCert.Raw},
PrivateKey: c.privateKey,
Certificate: certDerBytes,
PrivateKey: privKey,
}, nil
}

func (c *CertCache) evict(key string) error {
c.Logger.Debugf("evicting cache foe %s", key)
leadingDir, ok := c.buildPathToCachedCert(key)
if !ok {
return nil
}
return os.RemoveAll(leadingDir)
}

func (c *CertCache) readCertificate(key string, hosts []string, now time.Time) (cert *tls.Certificate, err error) {
cert, err = c.readAndValidateCertificate(
key,
hosts,
now,
)
if err != nil {
c.evict(key)
c.Logger.Warn(err.Error())
err = nil
}
Expand All @@ -161,20 +264,27 @@ func (c *CertCache) readCertificate(key string, hosts []string, now time.Time) (
func (c *CertCache) Put(hosts []string, cert *tls.Certificate) error {
key := buildKeyString(hosts)
c.certs[key] = cert
return c.writeCertificate(key, cert)
err := c.writeCertificate(key, cert)
if err != nil {
c.Logger.Warn(err.Error())
err = nil
}
return err
}

func (c *CertCache) Get(hosts []string, now time.Time) (cert *tls.Certificate, err error) {
key := buildKeyString(hosts)
cert, ok := c.certs[key]
if !ok {
c.Logger.Debug("Certificate not found in in-process cache")
c.Logger.Debug("certificate not found in in-process cache")
cert, err = c.readCertificate(key, hosts, now)
if err != nil {
return
}
if cert != nil {
c.certs[key] = cert
} else {
c.Logger.Debug("certificate not found in cache directory")
}
}
return
Expand Down
34 changes: 33 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,20 @@ type PerHostConfig struct {
Patterns []Pattern
}

type PreparedCertificate struct {
Pattern *regexp.Regexp
TLSCertificate *tls.Certificate
Certificate *x509.Certificate
}

type MITMConfig struct {
ServerTLSConfigTemplate *tls.Config
ClientTLSConfigTemplate *tls.Config
SigningCertificateKeyPair struct {
Certificate *x509.Certificate
PrivateKey crypto.PrivateKey
}
Prepared []PreparedCertificate
CacheDirectory string
DisableCache bool
}
Expand Down Expand Up @@ -140,7 +147,7 @@ func (ctx *ConfigReaderContext) extractPerHostConfigs(deref dereference) (perHos
"hosts", func(urlStr string, hostMap dereference) error {
url, err := url.Parse(urlStr)
if err != nil {
return errors.Wrapf(err, "invalid value for URL (%s) under %s", urlStr)
return errors.Wrapf(err, "invalid value for URL (%s)", urlStr)
}
if url.Path != "" {
return errors.Errorf("path may not be present: %s", urlStr)
Expand Down Expand Up @@ -735,6 +742,7 @@ func (ctx *ConfigReaderContext) extractTLSConfig(deref dereference, client bool)
}

func (ctx *ConfigReaderContext) extractMITMConfig(deref dereference) (retval MITMConfig, err error) {
retval.ServerTLSConfigTemplate = new(tls.Config)
retval.ClientTLSConfigTemplate = new(tls.Config)
err = deref.multi(
"tls", func(deref dereference) error {
Expand All @@ -758,6 +766,29 @@ func (ctx *ConfigReaderContext) extractMITMConfig(deref dereference) (retval MIT
retval.SigningCertificateKeyPair.PrivateKey = tlsCert.PrivateKey
return nil
},
"prepared", func(_ int, deref dereference) error {
visited := false
return deref.iterateHomogeneousValuedMap(yamlMapType, func(hostPattern string, deref dereference) error {
if visited {
return errors.Errorf("extra item exists")
}
visited = true
hostPatternRegexp, err := regexp.Compile(hostPattern)
if err != nil {
return errors.Errorf("invalid regexp %s", hostPattern)
}
tlsCert, cert, err := ctx.extractCertPrivateKeyPairs(deref)
if err != nil {
return err
}
retval.Prepared = append(retval.Prepared, PreparedCertificate{
Pattern: hostPatternRegexp,
TLSCertificate: &tlsCert,
Certificate: cert,
})
return nil
})
},
"cache_directory", func(cacheDirectory string) error {
retval.CacheDirectory = cacheDirectory
return nil
Expand Down Expand Up @@ -862,6 +893,7 @@ func loadConfig(yamlFile string, progname string) (*Config, error) {
return nil, errors.Wrapf(err, "failed to load %s", yamlFile)
}
ctx := &ConfigReaderContext{
filename: yamlFile,
warn: func(msg string) {
fmt.Fprintf(os.Stderr, "%s: %s\n", progname, msg)
},
Expand Down
2 changes: 1 addition & 1 deletion deref.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ func (deref dereference) homogeneousMapValue(typ reflect.Type) (interface{}, err
}
mapVal, ok := deref.value.(map[interface{}]interface{})
if !ok {
return nil, deref.errorf("mapping of %s expected, got %T", deref.value)
return nil, deref.errorf("mapping of %s expected, got %T", typ.String(), deref.value)
}
hMapVal := reflect.MakeMap(reflect.MapOf(emptyInterfaceType, typ))
for k, v := range mapVal {
Expand Down
8 changes: 8 additions & 0 deletions example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ tls:
ca:
cert: testca.rsa.crt.pem
key: testca.rsa.key.pem
# MITM with prepared certificates
prepared:
- ^local\\.my-domain\\.example\\.com$:
cert: certs/my-domain-cert.pem
key: certs/my-domain-key.pem
- .*:
cert: real-certs/fallback-cert.pem
key: real-certs/fallback-key.pem

# response filters
response_filters:
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0
github.com/sirupsen/logrus v1.3.0
github.com/stretchr/testify v1.2.2
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3
golang.org/x/text v0.3.0 // indirect
gopkg.in/yaml.v2 v2.3.0
Expand Down
2 changes: 1 addition & 1 deletion httpx/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1720,7 +1720,7 @@ type connectMethodKey struct {

func (k connectMethodKey) String() string {
// Only used by tests.
return fmt.Sprintf("%s|%s|%s|%p", k.proxy, k.scheme, k.addr, k.tlsConfigAddr)
return fmt.Sprintf("%s|%s|%s|%d", k.proxy, k.scheme, k.addr, k.tlsConfigAddr)
}

// persistConn wraps a connection, usually a persistent one
Expand Down
Loading

0 comments on commit f1de758

Please sign in to comment.