Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some improvements #123

Merged
merged 5 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions cached_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,48 @@ import (
)

type cachedReader struct {
buffer *bufio.Reader
cache []byte
cacheCap int
cacheLen int
buffer *bufio.Reader
cache []byte
caching bool
}

func newCachedReader(r *bufio.Reader) *cachedReader {
return &cachedReader{
buffer: r,
cache: make([]byte, 4096),
cacheCap: 4096,
cacheLen: 0,
caching: false,
buffer: r,
cache: make([]byte, 0, 4096),
caching: false,
}
}

func (c *cachedReader) StartCaching() {
c.cacheLen = 0
c.cache = c.cache[:0]
c.caching = true
}

func (c *cachedReader) ReadByte() (byte, error) {
if !c.caching {
return c.buffer.ReadByte()
}
b, err := c.buffer.ReadByte()
func (c *cachedReader) ReadByte() (b byte, err error) {
b, err = c.buffer.ReadByte()
if err != nil {
return b, err
return
}
if c.cacheLen < c.cacheCap {
c.cache[c.cacheLen] = b
c.cacheLen++
if c.caching {
c.cacheByte(b)
}
return b, err
return
}

func (c *cachedReader) Cache() []byte {
return c.cache[:c.cacheLen]
return c.cache
}

// CacheWithLimit returns at most the first n bytes currently held in the
// cache. It returns nil when n is less than 1. The returned slice is a view
// into the cache, not a copy.
func (c *cachedReader) CacheWithLimit(n int) []byte {
	switch size := len(c.cache); {
	case n < 1:
		return nil
	case n > size:
		// Asking for more than is cached yields everything cached.
		return c.cache[:size]
	default:
		return c.cache[:n]
	}
}

func (c *cachedReader) StopCaching() {
Expand All @@ -55,15 +58,22 @@ func (c *cachedReader) Read(p []byte) (int, error) {
if err != nil {
return n, err
}
if c.caching && c.cacheLen < c.cacheCap {
if c.caching {
for i := 0; i < n; i++ {
c.cache[c.cacheLen] = p[i]
c.cacheLen++
if c.cacheLen >= c.cacheCap {
if !c.cacheByte(p[i]) {
break
}
}
}
return n, err
}

// cacheByte stores b at the end of the cache if there is spare capacity and
// reports whether the byte was stored. A full cache is left untouched; the
// backing array is never grown.
func (c *cachedReader) cacheByte(b byte) bool {
	if len(c.cache) >= cap(c.cache) {
		return false
	}
	// Spare capacity exists here, so append writes in place without
	// reallocating.
	c.cache = append(c.cache, b)
	return true
}
15 changes: 15 additions & 0 deletions cached_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,19 @@ func TestCaching(t *testing.T) {
if !bytes.Equal(cached, []byte("BCDEF")) {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(-1)
if cached != nil {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(3)
if !bytes.Equal(cached, []byte("BCD")) {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(1000)
if !bytes.Equal(cached, []byte("BCDEF")) {
t.Fatalf("Incorrect cached buffer value")
}
}
42 changes: 36 additions & 6 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ func WithPreserveSpace() OutputOption {
}
}

// WithoutPreserveSpace returns an OutputOption that switches off the
// preserveSpaces setting of the output configuration.
func WithoutPreserveSpace() OutputOption {
	disable := func(cfg *outputConfiguration) {
		cfg.preserveSpaces = false
	}
	return disable
}

// WithIndentation sets the indentation string used for formatting the output.
func WithIndentation(indentation string) OutputOption {
return func(oc *outputConfiguration) {
Expand Down Expand Up @@ -328,7 +335,9 @@ func (n *Node) Write(writer io.Writer, self bool) error {

// WriteWithOptions writes xml with given options to given writer.
func (n *Node) WriteWithOptions(writer io.Writer, opts ...OutputOption) (err error) {
config := &outputConfiguration{}
config := &outputConfiguration{
preserveSpaces: true,
}
// Set the options
for _, opt := range opts {
opt(config)
Expand Down Expand Up @@ -400,11 +409,7 @@ func AddChild(parent, n *Node) {
parent.LastChild = n
}

// AddSibling adds a new node 'n' as a sibling of a given node 'sibling'.
// Note it is not necessarily true that the new node 'n' would be added
// immediately after 'sibling'. If 'sibling' isn't the last child of its
// parent, then the new node 'n' will be added at the end of the sibling
// chain of their parent.
// AddSibling adds a new node 'n' as the last node of the sibling chain that contains the given node 'sibling'.
func AddSibling(sibling, n *Node) {
for t := sibling.NextSibling; t != nil; t = t.NextSibling {
sibling = t
Expand All @@ -418,6 +423,19 @@ func AddSibling(sibling, n *Node) {
}
}

// AddImmediateSibling inserts node 'n' directly after the given node
// 'sibling'. 'n' takes over the parent of 'sibling'; if 'sibling' had no
// following sibling and has a parent, that parent's LastChild becomes 'n'.
func AddImmediateSibling(sibling, n *Node) {
	// Hook 'n' in between 'sibling' and whatever followed it.
	n.Parent, n.NextSibling = sibling.Parent, sibling.NextSibling
	sibling.NextSibling, n.PrevSibling = n, sibling

	if follower := n.NextSibling; follower != nil {
		follower.PrevSibling = n
		return
	}
	// 'n' is now the tail of the chain, so the parent must point at it.
	if n.Parent != nil {
		n.Parent.LastChild = n
	}
}

// RemoveFromTree removes a node and its subtree from the document
// tree it is in. If the node is the root of the tree, then it's no-op.
func RemoveFromTree(n *Node) {
Expand Down Expand Up @@ -445,3 +463,15 @@ func RemoveFromTree(n *Node) {
n.PrevSibling = nil
n.NextSibling = nil
}

// GetRoot follows the Parent links upward from 'n' and returns the topmost
// node, i.e. the one that has no parent. It returns nil when 'n' is nil.
func GetRoot(n *Node) *Node {
	if n == nil {
		return nil
	}
	top := n
	for up := n.Parent; up != nil; up = up.Parent {
		top = up
	}
	return top
}
48 changes: 36 additions & 12 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -260,7 +260,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -270,7 +270,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ggg></ggg></aaa>`)
})

Expand All @@ -280,7 +280,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd></aaa>`)
})

Expand All @@ -290,7 +290,7 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, procInst.Type, DeclarationNode)
RemoveFromTree(procInst)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -300,19 +300,44 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, commentNode.Type, CommentNode)
RemoveFromTree(commentNode)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

t.Run("remove call on root does nothing", func(t *testing.T) {
doc := parseXML()
RemoveFromTree(doc)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})
}

// TestAddImmediateSibling checks that AddImmediateSibling places a new
// element directly after the selected node (BBB) instead of at the end of
// the sibling chain.
func TestAddImmediateSibling(t *testing.T) {
	s := `<?xml version="1.0" encoding="UTF-8"?>
<AAA>
<BBB id="1"/>
<CCC id="2">
<DDD/>
</CCC>
<CCC id="3">
<DDD/>
</CCC>
</AAA>`
	root, err := Parse(strings.NewReader(s))
	if err != nil {
		// Stop immediately: continuing would operate on a document that
		// failed to parse.
		t.Fatal(err)
	}

	aaa := findNode(root, "AAA")
	n := aaa.SelectElement("BBB")
	if n == nil {
		t.Fatalf("n is nil")
	}
	AddImmediateSibling(n, &Node{Type: ElementNode, Data: "r"})
	testValue(t, root.OutputXMLWithOptions(WithoutPreserveSpace()), `<?xml version="1.0" encoding="UTF-8"?><AAA><BBB id="1"></BBB><r></r><CCC id="2"><DDD></DDD></CCC><CCC id="3"><DDD></DDD></CCC></AAA>`)
}

func TestSelectElement(t *testing.T) {
s := `<?xml version="1.0" encoding="UTF-8"?>
<AAA>
Expand Down Expand Up @@ -497,7 +522,6 @@ func TestWriteWithNamespacePrefix(t *testing.T) {
}
}


func TestQueryWithPrefix(t *testing.T) {
s := `<?xml version="1.0" encoding="UTF-8"?><S:Envelope xmlns:S="http://schemas.xmlsoap.org/soap/envelope/"><S:Body test="1"><ns2:Fault xmlns:ns2="http://schemas.xmlsoap.org/soap/envelope/" xmlns:ns3="http://www.w3.org/2003/05/soap-envelope"><faultcode>ns2:Client</faultcode><faultstring>This is a client fault</faultstring></ns2:Fault></S:Body></S:Envelope>`
doc, _ := Parse(strings.NewReader(s))
Expand Down Expand Up @@ -582,7 +606,7 @@ func TestOutputXMLWithSpaceDirect(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}

output := html.UnescapeString(doc.OutputXML(true))
output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
Expand All @@ -606,7 +630,7 @@ func TestOutputXMLWithSpaceOverwrittenToPreserve(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}

output := html.UnescapeString(doc.OutputXML(true))
output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
Expand Down Expand Up @@ -680,8 +704,8 @@ func TestOutputXMLWithPreserveSpaceOption(t *testing.T) {
</student>
</class_list>`
doc, _ := Parse(strings.NewReader(s))
resultWithSpace := doc.OutputXMLWithOptions(WithPreserveSpace())
resultWithoutSpace := doc.OutputXMLWithOptions()
resultWithSpace := doc.OutputXMLWithOptions()
resultWithoutSpace := doc.OutputXMLWithOptions(WithoutPreserveSpace())
if !strings.Contains(resultWithSpace, "> Robert <") {
t.Errorf("output was not expected. expected %v but got %v", " Robert ", resultWithSpace)
}
Expand Down
36 changes: 26 additions & 10 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package xmlquery

import (
"bufio"
"bytes"
"encoding/xml"
"fmt"
"io"
Expand Down Expand Up @@ -39,15 +40,31 @@ func Parse(r io.Reader) (*Node, error) {
func ParseWithOptions(r io.Reader, options ParserOptions) (*Node, error) {
p := createParser(r)
options.apply(p)
for {
_, err := p.parse()
if err == io.EOF {
return p.doc, nil
var err error
for err == nil {
_, err = p.parse()
}

if err == io.EOF {
// additional check for validity
// according to: https://www.w3.org/TR/xml
// the document MUST contain at least ONE element
valid := false
for doc := p.doc; doc != nil; doc = doc.NextSibling {
for node := doc.FirstChild; node != nil; node = node.NextSibling {
if node.Type == ElementNode {
valid = true
break
}
}
}
if err != nil {
return nil, err
if !valid {
return nil, fmt.Errorf("xmlquery: invalid XML document")
}
return p.doc, nil
}

return nil, err
}

type parser struct {
Expand Down Expand Up @@ -168,7 +185,7 @@ func (p *parser) parse() (*Node, error) {

if node.NamespaceURI != "" {
if v, ok := p.space2prefix[node.NamespaceURI]; ok {
cached := string(p.reader.Cache())
cached := string(p.reader.CacheWithLimit(len(v.name) + len(node.Data) + 2))
if strings.HasPrefix(cached, fmt.Sprintf("%s:%s", v.name, node.Data)) || strings.HasPrefix(cached, fmt.Sprintf("<%s:%s", v.name, node.Data)) {
node.Prefix = v.name
}
Expand Down Expand Up @@ -228,12 +245,11 @@ func (p *parser) parse() (*Node, error) {
}
case xml.CharData:
// First, normalize the cache...
cached := strings.ToUpper(string(p.reader.Cache()))
cached := bytes.ToUpper(p.reader.CacheWithLimit(9))
nodeType := TextNode
if strings.HasPrefix(cached, "<![CDATA[") || strings.HasPrefix(cached, "![CDATA[") {
if bytes.HasPrefix(cached, []byte("<![CDATA[")) || bytes.HasPrefix(cached, []byte("![CDATA[")) {
nodeType = CharDataNode
}

node := &Node{Type: nodeType, Data: string(tok), level: p.level}
if p.level == p.prev.level {
AddSibling(p.prev, node)
Expand Down
Loading
Loading