diff --git a/hash.go b/hash.go index d01dd68..e545ae5 100644 --- a/hash.go +++ b/hash.go @@ -227,27 +227,42 @@ func newEvpHash(ch crypto.Hash) *evpHash { if alg == nil { panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch))) } - ctx := C.go_openssl_EVP_MD_CTX_new() - if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 { - C.go_openssl_EVP_MD_CTX_free(ctx) - panic(newOpenSSLError("EVP_DigestInit_ex")) - } - ctx2 := C.go_openssl_EVP_MD_CTX_new() - h := &evpHash{ - alg: alg, - ctx: ctx, - ctx2: ctx2, - } - runtime.SetFinalizer(h, (*evpHash).finalize) + h := &evpHash{alg: alg} + // Don't call init() yet, it would be wasteful + // if the caller only wants to know the hash type. This + // is a common pattern in this package, as some functions + // accept a `func() hash.Hash` parameter and call it just + // to know the hash type. return h } func (h *evpHash) finalize() { - C.go_openssl_EVP_MD_CTX_free(h.ctx) - C.go_openssl_EVP_MD_CTX_free(h.ctx2) + if h.ctx != nil { + C.go_openssl_EVP_MD_CTX_free(h.ctx) + } + if h.ctx2 != nil { + C.go_openssl_EVP_MD_CTX_free(h.ctx2) + } +} + +func (h *evpHash) init() { + if h.ctx != nil { + return + } + h.ctx = C.go_openssl_EVP_MD_CTX_new() + if C.go_openssl_EVP_DigestInit_ex(h.ctx, h.alg.md, nil) != 1 { + C.go_openssl_EVP_MD_CTX_free(h.ctx) + panic(newOpenSSLError("EVP_DigestInit_ex")) + } + h.ctx2 = C.go_openssl_EVP_MD_CTX_new() + runtime.SetFinalizer(h, (*evpHash).finalize) } func (h *evpHash) Reset() { + if h.ctx == nil { + // The hash is not initialized yet, no need to reset. + return + } // There is no need to reset h.ctx2 because it is always reset after // use in evpHash.sum. if C.go_openssl_EVP_DigestInit_ex(h.ctx, nil, nil) != 1 { @@ -257,7 +272,11 @@ func (h *evpHash) Reset() { } func (h *evpHash) Write(p []byte) (int, error) { - if len(p) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 { + if len(p) == 0 { + return 0, nil + } + h.init() + if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 { panic(newOpenSSLError("EVP_DigestUpdate")) } runtime.KeepAlive(h) @@ -265,7 +284,11 @@ func (h *evpHash) Write(p []byte) (int, error) { } func (h *evpHash) WriteString(s string) (int, error) { - if len(s) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 { + if len(s) == 0 { + return 0, nil + } + h.init() + if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 { panic("openssl: EVP_DigestUpdate failed") } runtime.KeepAlive(h) @@ -273,6 +296,7 @@ func (h *evpHash) WriteString(s string) (int, error) { } func (h *evpHash) WriteByte(c byte) error { + h.init() if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&c), 1) == 0 { panic("openssl: EVP_DigestUpdate failed") } @@ -289,11 +313,12 @@ func (h *evpHash) BlockSize() int { } func (h *evpHash) Sum(in []byte) []byte { - defer runtime.KeepAlive(h) + h.init() out := make([]byte, h.Size(), maxHashSize) // explicit cap to allow stack allocation if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 { panic(newOpenSSLError("go_hash_sum")) } + runtime.KeepAlive(h) return append(in, out...) } @@ -301,26 +326,25 @@ func (h *evpHash) Sum(in []byte) []byte { // The duplicate object contains all state and data contained in the // original object at the point of duplication. func (h *evpHash) Clone() (hash.Hash, error) { - ctx := C.go_openssl_EVP_MD_CTX_new() - if ctx == nil { - return nil, newOpenSSLError("EVP_MD_CTX_new") - } - if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 { - C.go_openssl_EVP_MD_CTX_free(ctx) - return nil, newOpenSSLError("EVP_MD_CTX_copy") - } - ctx2 := C.go_openssl_EVP_MD_CTX_new() - if ctx2 == nil { - C.go_openssl_EVP_MD_CTX_free(ctx) - return nil, newOpenSSLError("EVP_MD_CTX_new") - } - cloned := &evpHash{ - alg: h.alg, - ctx: ctx, - ctx2: ctx2, + h2 := &evpHash{alg: h.alg} + if h.ctx != nil { + h2.ctx = C.go_openssl_EVP_MD_CTX_new() + if h2.ctx == nil { + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + if C.go_openssl_EVP_MD_CTX_copy_ex(h2.ctx, h.ctx) != 1 { + C.go_openssl_EVP_MD_CTX_free(h2.ctx) + return nil, newOpenSSLError("EVP_MD_CTX_copy") + } + h2.ctx2 = C.go_openssl_EVP_MD_CTX_new() + if h2.ctx2 == nil { + C.go_openssl_EVP_MD_CTX_free(h2.ctx) + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + runtime.SetFinalizer(h2, (*evpHash).finalize) } - runtime.SetFinalizer(cloned, (*evpHash).finalize) - return cloned, nil + runtime.KeepAlive(h) + return h2, nil } // hashState returns a pointer to the internal hash structure. @@ -360,6 +384,8 @@ func (d *evpHash) MarshalBinary() ([]byte, error) { } func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { + defer runtime.KeepAlive(d) + d.init() if !d.alg.marshallable { return nil, errors.New("openssl: hash state is not marshallable") } @@ -391,6 +417,8 @@ func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { } func (d *evpHash) UnmarshalBinary(b []byte) error { + defer runtime.KeepAlive(d) + d.init() if !d.alg.marshallable { return errors.New("openssl: hash state is not marshallable") } diff --git a/hash_test.go b/hash_test.go index 948be94..a1875ce 100644 --- a/hash_test.go +++ b/hash_test.go @@ -377,6 +377,13 @@ func BenchmarkSHA256(b *testing.B) { } } +func BenchmarkNewSHA256(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + openssl.NewSHA256() + } +} + // stubHash is a hash.Hash implementation that does nothing. type stubHash struct{}