diff --git a/ast/ast.go b/ast/ast.go index 2d57f87..48d4229 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -40,7 +40,7 @@ type AST interface { UpdateLabel(n *Node, newLabel NodeLabelType) error // MakeHashMemo creates a new hash memo for the entire AST. - MakeHashMemo() *NodeHashMemo + MakeHashMemo() NodeHashMemo } type astConcrete struct { @@ -74,8 +74,8 @@ func (a *astConcrete) Add(parent *Node, i int, label NodeLabelType, value NodeVa a.root = newNode } - a.nodes[newNode.id] = newNode - return a.nodes[newNode.id], nil + a.nodes[newNode.Id] = newNode + return a.nodes[newNode.Id], nil } func (a *astConcrete) Move(n, newParent *Node, i int) error { @@ -90,7 +90,7 @@ func (a *astConcrete) Delete(n *Node) error { } n.DestroySubtree() - delete(a.nodes, n.id) + delete(a.nodes, n.Id) return nil } @@ -120,14 +120,14 @@ func (a *astConcrete) UpdateLabel(n *Node, newLabel NodeLabelType) error { return nil } -func (a *astConcrete) MakeHashMemo() *NodeHashMemo { +func (a *astConcrete) MakeHashMemo() NodeHashMemo { if a.root == nil { return nil } memo := make(NodeHashMemo) _ = a.root.HashValue(&memo) - return &memo + return memo } // NewAST creates a new AST. diff --git a/ast/hash_memo.go b/ast/hash_memo.go deleted file mode 100644 index 4ecd95c..0000000 --- a/ast/hash_memo.go +++ /dev/null @@ -1,7 +0,0 @@ -package ast - -type NodeHashMemo map[NodeIdType]uint64 - -func (m *NodeHashMemo) IsIsomorphicBetween(n1, n2 NodeIdType) bool { - return (*m)[n1] == (*m)[n2] -} diff --git a/ast/node.go b/ast/node.go index 0638bf1..a54bc0a 100644 --- a/ast/node.go +++ b/ast/node.go @@ -17,10 +17,12 @@ type Node struct { Value NodeValueType Parent *Node Children map[int]*Node - id NodeIdType + Id NodeIdType idxToParent int } +type NodeHashMemo map[NodeIdType]uint64 + type NodeParentInfo struct { // The Parent of a Node. Parent *Node @@ -46,7 +48,7 @@ func NewNode(parentInfo NodeParentInfo, label NodeLabelType, value NodeValueType Value: value, Parent: nil, Children: make(map[int]*Node), - id: newIdStr, + Id: newIdStr, idxToParent: -1, } @@ -121,6 +123,14 @@ func (n *Node) ValueOfOrder() int { return n.Height() } +func (n *Node) OrderedChildren() []*Node { + children := maps.Values(n.Children) + sort.Slice(children, func(i, j int) bool { + return children[i].idxToParent < children[j].idxToParent + }) + return children +} + // Isomorphic returns true if the Node is isomorphic to the other Node. func (n *Node) Isomorphic(other *Node) bool { if n == nil || other == nil { @@ -138,15 +148,12 @@ func (n *Node) HashValue(memo *NodeHashMemo) uint64 { if memo != nil { lock.Lock() defer lock.Unlock() - (*memo)[n.id] = propertyHash + (*memo)[n.Id] = propertyHash } return propertyHash } - children := maps.Values(n.Children) - sort.Slice(children, func(i, j int) bool { - return children[i].idxToParent < children[j].idxToParent - }) + children := n.OrderedChildren() var combinedChildrenHash xxhash.Digest _, _ = combinedChildrenHash.WriteString(strconv.FormatUint(propertyHash, 10)) @@ -160,7 +167,7 @@ func (n *Node) HashValue(memo *NodeHashMemo) uint64 { if memo != nil { lock.Lock() defer lock.Unlock() - (*memo)[n.id] = result + (*memo)[n.Id] = result } return result }