Skip to content

Commit

Permalink
Fix string aliases in deriveCompare (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
gavv authored Jun 20, 2024
1 parent e3f2fdf commit f4ac349
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 14 deletions.
12 changes: 10 additions & 2 deletions plugin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ func (g *gen) genStatement(typ types.Type, this, that string) error {
case *types.Basic:
switch ttyp.Kind() {
case types.String:
p.P("return %s.Compare(%s, %s)", g.stringsPkg(), this, that)
if types.Identical(typ, ttyp) {
p.P("return %s.Compare(%s, %s)", g.stringsPkg(), this, that)
} else {
p.P("return %s.Compare(string(%s), string(%s))", g.stringsPkg(), this, that)
}
case types.Complex128, types.Complex64:
p.P("if thisr, thatr := real(%s), real(%s); thisr == thatr {", this, that)
p.In()
Expand Down Expand Up @@ -473,7 +477,11 @@ func (g *gen) field(thisField, thatField string, fieldType types.Type) (string,
switch typ := fieldType.Underlying().(type) {
case *types.Basic:
if typ.Kind() == types.String {
return fmt.Sprintf("%s.Compare(%s, %s)", g.stringsPkg(), thisField, thatField), nil
if types.Identical(fieldType, typ) {
return fmt.Sprintf("%s.Compare(%s, %s)", g.stringsPkg(), thisField, thatField), nil
} else {
return fmt.Sprintf("%s.Compare(string(%s), string(%s))", g.stringsPkg(), thisField, thatField), nil
}
}
return fmt.Sprintf("%s(%s, %s)", g.GetFuncName(fieldType, fieldType), thisField, thatField), nil
case *types.Pointer:
Expand Down
28 changes: 28 additions & 0 deletions test/normal/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,31 @@ func TestCompareCurry(t *testing.T) {
t.Fatalf("compare: got %d want %d", c, 1)
}
}

func TestCompareStringAlias(t *testing.T) {
this := stringAlias("aaa")
that := stringAlias("bbb")
if c := deriveCompareStringAlias(this, this); c != 0 {
t.Fatalf("compare: got %d want %d", c, 0)
}
if c := deriveCompareStringAlias(this, that); c != -1 {
t.Fatalf("compare: got %d want %d", c, 0)
}
if c := deriveCompareStringAlias(that, this); c != 1 {
t.Fatalf("compare: got %d want %d", c, 0)
}
}

func TestCompareStringAliasField(t *testing.T) {
this := StructWithStringAlias{stringAlias("aaa")}
that := StructWithStringAlias{stringAlias("bbb")}
if c := deriveCompareStructWithStringAlias(this, this); c != 0 {
t.Fatalf("compare: got %d want %d", c, 0)
}
if c := deriveCompareStructWithStringAlias(this, that); c != -1 {
t.Fatalf("compare: got %d want %d", c, 0)
}
if c := deriveCompareStructWithStringAlias(that, this); c != 1 {
t.Fatalf("compare: got %d want %d", c, 0)
}
}
60 changes: 48 additions & 12 deletions test/normal/derived.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions test/normal/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func true5(lt *LocalType) (*LocalType, bool) {
return lt, true
}

type StructWithStringAlias struct {
Field stringAlias
}

type DeriveTheDerived struct {
Field int
}
Expand Down

0 comments on commit f4ac349

Please sign in to comment.