Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering immutable struct types that will not be deep-copied #7

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions deepcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ type copier func(interface{}, map[uintptr]interface{}) (interface{}, error)

var copiers map[Kind]copier

var immutableTypes map[Type]struct{}

func init() {
copiers = map[Kind]copier{
Bool: _primitive,
Expand All @@ -34,6 +36,8 @@ func init() {
String: _primitive,
Struct: _struct,
}

immutableTypes = map[Type]struct{}{}
}

// MustAnything does a deep copy and panics on any errors.
Expand All @@ -45,6 +49,16 @@ func MustAnything(x interface{}) interface{} {
return dc
}

// RegisterImmutableType registers a type as immutable. This means that when a deep copy is made,
// if the type of the value being copied is the same as the type passed in, the value will not be
// copied. Instead, the original value will be used. This is useful for types that are immutable.
//
// It is intended to be called at init time.

func RegisterImmutableType(t Type) {
immutableTypes[t] = struct{}{}
}

// Primitive makes a copy of a primitive type...which just means it returns the input value.
// This is wholly uninteresting, but I included it for consistency's sake.
func _primitive(x interface{}, ptrs map[uintptr]interface{}) (interface{}, error) {
Expand Down Expand Up @@ -132,7 +146,7 @@ func _pointer(x interface{}, ptrs map[uintptr]interface{}) (interface{}, error)

if v.IsNil() {
t := TypeOf(x)
return Zero(t).Interface(),nil
return Zero(t).Interface(), nil
}

addr := v.Pointer()
Expand All @@ -142,7 +156,7 @@ func _pointer(x interface{}, ptrs map[uintptr]interface{}) (interface{}, error)
t := TypeOf(x)
dc := New(t.Elem())
ptrs[addr] = dc.Interface()

item, err := _anything(v.Elem().Interface(), ptrs)
if err != nil {
return nil, fmt.Errorf("failed to copy the value under the pointer %v: %v", v, err)
Expand All @@ -151,7 +165,7 @@ func _pointer(x interface{}, ptrs map[uintptr]interface{}) (interface{}, error)
if iv.IsValid() {
dc.Elem().Set(ValueOf(item))
}

return dc.Interface(), nil
}

Expand All @@ -161,6 +175,10 @@ func _struct(x interface{}, ptrs map[uintptr]interface{}) (interface{}, error) {
return nil, fmt.Errorf("must pass a value with kind of Struct; got %v", v.Kind())
}
t := TypeOf(x)
if _, ok := immutableTypes[t]; ok {
// This is an immutable type, so we can just return it.
return x, nil
}
dc := New(t)
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
Expand Down
37 changes: 32 additions & 5 deletions deepcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
. "reflect"
"testing"
"time"
)

func ExampleAnything() {
Expand Down Expand Up @@ -46,8 +47,8 @@ type Foo struct {

func ExampleMap() {
x := map[string]*Foo{
"foo": &Foo{Bar: 1},
"bar": &Foo{Bar: 2},
"foo": {Foo: &Foo{Bar: 1}, Bar: 1},
"bar": {Foo: &Foo{Bar: 2}, Bar: 2},
}
y := MustAnything(x).(map[string]*Foo)
for _, k := range []string{"foo", "bar"} { // to ensure consistent order
Expand Down Expand Up @@ -161,8 +162,8 @@ func TestTwoNils(t *testing.T) {
B int
}
type FooBar struct {
Foo *Foo
Bar *Bar
Foo *Foo
Bar *Bar
Foo2 *Foo
Bar2 *Bar
}
Expand All @@ -178,4 +179,30 @@ func TestTwoNils(t *testing.T) {
t.Errorf("expect %v == %v; ", src, dst)
}

}
}

func TestImmutableTypes(t *testing.T) {
type Foo struct {
Time time.Time
TimePtr *time.Time
}

now := time.Now()

src := &Foo{
Time: time.Now(),
TimePtr: &now,
}

RegisterImmutableType(TypeOf(time.Time{}))

dst := MustAnything(src)

if src.TimePtr == dst.(*Foo).TimePtr {
t.Error("expect pointers to different time structs")
}

if !DeepEqual(src, dst) {
t.Errorf("expect %v == %v; ", src, dst)
}
}