Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hash: implement Clone() and speed up NewShaX #31

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions xcrypto/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ func SHA512(p []byte) (sum [64]byte) {
return
}

// cloneHash is an interface that defines a Clone method.
//
// hash.CloneHash will probably be added in Go 1.25, see https://golang.org/issue/69521,
// but we need it now.
type cloneHash interface {
hash.Hash
// Clone returns a separate Hash instance with the same state as h.
Clone() hash.Hash
}

var _ hash.Hash = (*evpHash)(nil)
var _ cloneHash = (*evpHash)(nil)

// evpHash implements generic hash methods.
type evpHash struct {
ctx unsafe.Pointer
// ctx2 is used in evpHash.sum to avoid changing
Expand All @@ -107,12 +121,7 @@ type evpHash struct {
}

func newEvpHash(init func(ctx unsafe.Pointer) C.int, update func(ctx unsafe.Pointer, data []byte) C.int, final func(ctx unsafe.Pointer, digest []byte) C.int, ctxSize, blockSize, size int) *evpHash {
ctx := C.malloc(C.size_t(ctxSize))
ctx2 := C.malloc(C.size_t(ctxSize))
init(ctx)
h := &evpHash{
ctx: ctx,
ctx2: ctx2,
init: init,
update: update,
final: final,
Expand All @@ -125,8 +134,24 @@ func newEvpHash(init func(ctx unsafe.Pointer) C.int, update func(ctx unsafe.Poin
}

func (h *evpHash) finalize() {
C.free(h.ctx)
C.free(h.ctx2)
if h.ctx != nil {
C.free(h.ctx)
}
if h.ctx2 != nil {
C.free(h.ctx2)
}
}

func (h *evpHash) initialize() {
if h.ctx == nil {
h.ctx = C.malloc(C.size_t(h.ctxSize))
h.ctx2 = C.malloc(C.size_t(h.ctxSize))
if h.init(h.ctx) != 1 {
C.free(h.ctx)
C.free(h.ctx2)
panic("commoncrypto: initialization failed")
}
}
}

func (h *evpHash) Reset() {
Expand All @@ -137,6 +162,7 @@ func (h *evpHash) Reset() {
}

func (h *evpHash) Write(p []byte) (int, error) {
h.initialize()
if len(p) > 0 {
// Use a local variable to prevent the compiler from misinterpreting the pointer
data := p
Expand All @@ -149,6 +175,7 @@ func (h *evpHash) Write(p []byte) (int, error) {
}

func (h *evpHash) WriteString(s string) (int, error) {
h.initialize()
if len(s) > 0 {
data := []byte(s)
if h.update(h.ctx, data) != 1 {
Expand All @@ -160,6 +187,7 @@ func (h *evpHash) WriteString(s string) (int, error) {
}

func (h *evpHash) WriteByte(c byte) error {
h.initialize()
if h.update(h.ctx, []byte{c}) != 1 {
return errors.New("commoncrypto: Update function failed")
}
Expand All @@ -175,12 +203,35 @@ func (h *evpHash) BlockSize() int {
}

func (h *evpHash) Sum(b []byte) []byte {
h.initialize()
digest := make([]byte, h.size)
C.memcpy(h.ctx2, h.ctx, C.size_t(h.ctxSize))
h.final(h.ctx2, digest)
return append(b, digest...)
}

// Clone returns a new evpHash object that is a deep clone of itself.
// The duplicate object contains all state and data contained in the
// original object at the point of duplication.
func (h *evpHash) Clone() hash.Hash {
h.initialize()
cloned := &evpHash{
init: h.init,
update: h.update,
final: h.final,
blockSize: h.blockSize,
size: h.size,
ctxSize: h.ctxSize,
}
cloned.ctx = C.malloc(C.size_t(h.ctxSize))
cloned.ctx2 = C.malloc(C.size_t(h.ctxSize))
C.memcpy(cloned.ctx, h.ctx, C.size_t(h.ctxSize))
C.memcpy(cloned.ctx2, h.ctx2, C.size_t(h.ctxSize))
runtime.SetFinalizer(cloned, (*evpHash).finalize)
runtime.KeepAlive(h)
return cloned
}

type md4Hash struct {
*evpHash
}
Expand Down
15 changes: 8 additions & 7 deletions xcrypto/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,13 @@ func TestHash_Clone(t *testing.T) {
t.Skip("not supported")
}
h := cryptoToHash(ch)()
if _, ok := h.(encoding.BinaryMarshaler); !ok {
t.Skip("not supported")
}
_, err := h.Write(msg)
if err != nil {
t.Fatal(err)
}
// We don't define an interface for the Clone method to avoid other
// packages from depending on it. Use type assertion to call it.
h2, err := h.(interface{ Clone() (hash.Hash, error) }).Clone()
if err != nil {
t.Fatal(err)
gdams marked this conversation as resolved.
Show resolved Hide resolved
}
h2 := h.(interface{ Clone() hash.Hash }).Clone()
h.Write(msg)
h2.Write(msg)
if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) {
Expand Down Expand Up @@ -494,6 +488,13 @@ func BenchmarkSHA256(b *testing.B) {
}
}

func BenchmarkNewSHA256(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
xcrypto.NewSHA256()
}
}

// stubHash is a hash.Hash implementation that does nothing.
type stubHash struct{}

Expand Down
Loading