From f4ac34957c075ec2017568e0072a1cd00bd7e9b0 Mon Sep 17 00:00:00 2001 From: Victor Gaydov Date: Thu, 20 Jun 2024 15:14:14 +0400 Subject: [PATCH] Fix string aliases in deriveCompare (#86) --- plugin/compare/compare.go | 12 ++++++-- test/normal/compare_test.go | 28 +++++++++++++++++ test/normal/derived.gen.go | 60 +++++++++++++++++++++++++++++-------- test/normal/types.go | 4 +++ 4 files changed, 90 insertions(+), 14 deletions(-) diff --git a/plugin/compare/compare.go b/plugin/compare/compare.go index cb721bf..3e05ce6 100644 --- a/plugin/compare/compare.go +++ b/plugin/compare/compare.go @@ -251,7 +251,11 @@ func (g *gen) genStatement(typ types.Type, this, that string) error { case *types.Basic: switch ttyp.Kind() { case types.String: - p.P("return %s.Compare(%s, %s)", g.stringsPkg(), this, that) + if types.Identical(typ, ttyp) { + p.P("return %s.Compare(%s, %s)", g.stringsPkg(), this, that) + } else { + p.P("return %s.Compare(string(%s), string(%s))", g.stringsPkg(), this, that) + } case types.Complex128, types.Complex64: p.P("if thisr, thatr := real(%s), real(%s); thisr == thatr {", this, that) p.In() @@ -473,7 +477,11 @@ func (g *gen) field(thisField, thatField string, fieldType types.Type) (string, switch typ := fieldType.Underlying().(type) { case *types.Basic: if typ.Kind() == types.String { - return fmt.Sprintf("%s.Compare(%s, %s)", g.stringsPkg(), thisField, thatField), nil + if types.Identical(fieldType, typ) { + return fmt.Sprintf("%s.Compare(%s, %s)", g.stringsPkg(), thisField, thatField), nil + } else { + return fmt.Sprintf("%s.Compare(string(%s), string(%s))", g.stringsPkg(), thisField, thatField), nil + } } return fmt.Sprintf("%s(%s, %s)", g.GetFuncName(fieldType, fieldType), thisField, thatField), nil case *types.Pointer: diff --git a/test/normal/compare_test.go b/test/normal/compare_test.go index 28ba4d0..7c73889 100644 --- a/test/normal/compare_test.go +++ b/test/normal/compare_test.go @@ -128,3 +128,31 @@ func TestCompareCurry(t *testing.T) { t.Fatalf("compare: got %d want %d", c, 1) } } + +func TestCompareStringAlias(t *testing.T) { + this := stringAlias("aaa") + that := stringAlias("bbb") + if c := deriveCompareStringAlias(this, this); c != 0 { + t.Fatalf("compare: got %d want %d", c, 0) + } + if c := deriveCompareStringAlias(this, that); c != -1 { + t.Fatalf("compare: got %d want %d", c, 0) + } + if c := deriveCompareStringAlias(that, this); c != 1 { + t.Fatalf("compare: got %d want %d", c, 0) + } +} + +func TestCompareStringAliasField(t *testing.T) { + this := StructWithStringAlias{stringAlias("aaa")} + that := StructWithStringAlias{stringAlias("bbb")} + if c := deriveCompareStructWithStringAlias(this, this); c != 0 { + t.Fatalf("compare: got %d want %d", c, 0) + } + if c := deriveCompareStructWithStringAlias(this, that); c != -1 { + t.Fatalf("compare: got %d want %d", c, 0) + } + if c := deriveCompareStructWithStringAlias(that, this); c != 1 { + t.Fatalf("compare: got %d want %d", c, 0) + } +} diff --git a/test/normal/derived.gen.go b/test/normal/derived.gen.go index a10e2a4..24d945d 100644 --- a/test/normal/derived.gen.go +++ b/test/normal/derived.gen.go @@ -3884,6 +3884,22 @@ func deriveCompareCurryComplex64(this complex128) func(complex128) int { } } +// deriveCompareStringAlias returns: +// * 0 if this and that are equal, +// * -1 is this is smaller and +// * +1 is this is bigger. +func deriveCompareStringAlias(this, that stringAlias) int { + return strings.Compare(string(this), string(that)) +} + +// deriveCompareStructWithStringAlias returns: +// * 0 if this and that are equal, +// * -1 is this is smaller and +// * +1 is this is bigger. +func deriveCompareStructWithStringAlias(this, that StructWithStringAlias) int { + return deriveCompare_138(&this, &that) +} + // deriveCompareDeriveTheDerived returns: // * 0 if this and that are equal, // * -1 is this is smaller and @@ -10516,7 +10532,7 @@ func deriveCompare_103(this, that *[4]int) int { if that == nil { return 1 } - return deriveCompare_138(*this, *that) + return deriveCompare_139(*this, *that) } // deriveCompare_104 returns: @@ -10533,7 +10549,7 @@ func deriveCompare_104(this, that *map[int]int) int { if that == nil { return 1 } - return deriveCompare_139(*this, *that) + return deriveCompare_140(*this, *that) } // deriveCompare_105 returns: @@ -11555,7 +11571,7 @@ func deriveCompare_136(this, that map[string][]*pickle.Rick) int { if thiskey == thatkey { thisvalue := this[thiskey] thatvalue := that[thatkey] - if c := deriveCompare_140(thisvalue, thatvalue); c != 0 { + if c := deriveCompare_141(thisvalue, thatvalue); c != 0 { return c } } else { @@ -11587,6 +11603,26 @@ func deriveCompare_137(this, that *privateStruct) int { return 0 } +// deriveCompare_138 returns: +// * 0 if this and that are equal, +// * -1 is this is smaller and +// * +1 is this is bigger. +func deriveCompare_138(this, that *StructWithStringAlias) int { + if this == nil { + if that == nil { + return 0 + } + return -1 + } + if that == nil { + return 1 + } + if c := strings.Compare(string(this.Field), string(that.Field)); c != 0 { + return c + } + return 0 +} + // deriveTuple returns a function, which returns the input values. // Since tuples are not first class citizens in Go, this is a way to fake it, because functions that return tuples are first class citizens. func deriveTuple(v0 int, v1 error) func() (int, error) { @@ -15141,11 +15177,11 @@ func deriveCompare_s(this, that string) int { return strings.Compare(this, that) } -// deriveCompare_138 returns: +// deriveCompare_139 returns: // * 0 if this and that are equal, // * -1 is this is smaller and // * +1 is this is bigger. -func deriveCompare_138(this, that [4]int) int { +func deriveCompare_139(this, that [4]int) int { if len(this) != len(that) { if len(this) < len(that) { return -1 @@ -15160,11 +15196,11 @@ func deriveCompare_138(this, that [4]int) int { return 0 } -// deriveCompare_139 returns: +// deriveCompare_140 returns: // * 0 if this and that are equal, // * -1 is this is smaller and // * +1 is this is bigger. -func deriveCompare_139(this, that map[int]int) int { +func deriveCompare_140(this, that map[int]int) int { if this == nil { if that == nil { return 0 @@ -15199,11 +15235,11 @@ func deriveCompare_139(this, that map[int]int) int { return 0 } -// deriveCompare_140 returns: +// deriveCompare_141 returns: // * 0 if this and that are equal, // * -1 is this is smaller and // * +1 is this is bigger. -func deriveCompare_140(this, that []*pickle.Rick) int { +func deriveCompare_141(this, that []*pickle.Rick) int { if this == nil { if that == nil { return 0 @@ -15220,7 +15256,7 @@ func deriveCompare_140(this, that []*pickle.Rick) int { return 1 } for i := 0; i < len(this); i++ { - if c := deriveCompare_141(this[i], that[i]); c != 0 { + if c := deriveCompare_142(this[i], that[i]); c != 0 { return c } } @@ -15317,11 +15353,11 @@ func deriveGoString_89(this *pickle.Rick) string { return buf.String() } -// deriveCompare_141 returns: +// deriveCompare_142 returns: // * 0 if this and that are equal, // * -1 is this is smaller and // * +1 is this is bigger. -func deriveCompare_141(this, that *pickle.Rick) int { +func deriveCompare_142(this, that *pickle.Rick) int { if this == nil { if that == nil { return 0 diff --git a/test/normal/types.go b/test/normal/types.go index d44dbea..c045f51 100644 --- a/test/normal/types.go +++ b/test/normal/types.go @@ -20,6 +20,10 @@ func true5(lt *LocalType) (*LocalType, bool) { return lt, true } +type StructWithStringAlias struct { + Field stringAlias +} + type DeriveTheDerived struct { Field int }