Skip to content

Commit

Permalink
Refactor some code using the new stdlib and generics, and add several…
Browse files Browse the repository at this point in the history
… tests (#404)

* refactor: use stdlib atomic type and correct chan usages
* refactor: call event listeners in order they added
* refactor(test): add more tests and use generics in helpers
* test: should handle nested frames
  • Loading branch information
canstand authored Jan 9, 2024
1 parent 968ab19 commit 8b2bd7b
Show file tree
Hide file tree
Showing 18 changed files with 233 additions and 144 deletions.
1 change: 0 additions & 1 deletion channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,5 @@ func newChannel(owner *channelOwner, object interface{}) *channel {
owner: owner,
object: object,
}
channel.initEventEmitter()
return channel
}
1 change: 0 additions & 1 deletion channel_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner
}
c.channel = newChannel(c, self)
c.eventToSubscriptionMapping = map[string]string{}
c.initEventEmitter()
}

type rootChannelOwner struct {
Expand Down
29 changes: 17 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ type connection struct {
transport transport
apiZone sync.Map
objects map[string]*channelOwner
lastID int
lastIDLock sync.Mutex
lastID atomic.Uint32
rootObject *rootChannelOwner
callbacks sync.Map
afterClose func()
Expand Down Expand Up @@ -97,7 +96,7 @@ func (c *connection) Dispatch(msg *message) {
}
method := msg.Method
if msg.ID != 0 {
cb, _ := c.callbacks.LoadAndDelete(msg.ID)
cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID))
if cb.(*protocolCallback).noReply {
return
}
Expand Down Expand Up @@ -226,10 +225,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa
return nil, errors.New("The object has been collected to prevent unbounded heap growth.")
}

c.lastIDLock.Lock()
c.lastID++
id := c.lastID
c.lastIDLock.Unlock()
id := c.lastID.Add(1)
cb, _ := c.callbacks.LoadOrStore(id, newProtocolCallback(noReply, c.abort))
var (
metadata = make(map[string]interface{}, 0)
Expand Down Expand Up @@ -356,7 +352,7 @@ func fromNullableChannel(v interface{}) interface{} {
}

type protocolCallback struct {
Callback chan result
callback chan result
noReply bool
abort <-chan struct{}
}
Expand All @@ -367,8 +363,12 @@ func (pc *protocolCallback) SetResult(r result) {
}
select {
case <-pc.abort:
select {
case pc.callback <- r:
default:
}
return
case pc.Callback <- r:
case pc.callback <- r:
}
}

Expand All @@ -377,10 +377,15 @@ func (pc *protocolCallback) GetResult() (interface{}, error) {
return nil, nil
}
select {
case result := <-pc.Callback:
case result := <-pc.callback:
return result.Data, result.Error
case <-pc.abort:
return nil, errors.New("Connection closed")
select {
case result := <-pc.callback:
return result.Data, result.Error
default:
return nil, errors.New("Connection closed")
}
}
}

Expand All @@ -392,7 +397,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback
}
}
return &protocolCallback{
Callback: make(chan result),
callback: make(chan result, 1),
abort: abort,
}
}
137 changes: 78 additions & 59 deletions event_emitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"math"
"reflect"
"sync"

"golang.org/x/exp/slices"
)

type EventEmitter interface {
Expand All @@ -15,44 +17,33 @@ type EventEmitter interface {
}

type (
eventRegister struct {
once []interface{}
on []interface{}
}
eventEmitter struct {
eventsMutex sync.Mutex
events map[string]*eventRegister
hasInit bool
}
eventRegister struct {
listeners []listener
}
listener struct {
handler interface{}
once bool
}
)

func (e *eventEmitter) Emit(name string, payload ...interface{}) (handled bool) {
func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
if _, ok := e.events[name]; !ok {
return
}

if len(e.events[name].once) > 0 || len(e.events[name].on) > 0 {
handled = true
}
e.init()

payloadV := make([]reflect.Value, 0)

for _, p := range payload {
payloadV = append(payloadV, reflect.ValueOf(p))
}

callHandlers := func(handlers []interface{}) {
for _, handler := range handlers {
handlerV := reflect.ValueOf(handler)
handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
}
evt, ok := e.events[name]
if !ok {
return
}

callHandlers(e.events[name].on)
callHandlers(e.events[name].once)
hasListener = evt.count() > 0

e.events[name].once = make([]interface{}, 0)
evt.callHandlers(payload...)
return
}

Expand All @@ -67,60 +58,88 @@ func (e *eventEmitter) On(name string, handler interface{}) {
func (e *eventEmitter) RemoveListener(name string, handler interface{}) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

if _, ok := e.events[name]; !ok {
return
}
handlerPtr := reflect.ValueOf(handler).Pointer()
e.events[name].removeHandler(handler)
}

onHandlers := []interface{}{}
for idx := range e.events[name].on {
eventPtr := reflect.ValueOf(e.events[name].on[idx]).Pointer()
if eventPtr != handlerPtr {
onHandlers = append(onHandlers, e.events[name].on[idx])
}
}
e.events[name].on = onHandlers
// ListenerCount count the listeners by name, count all if name is empty
func (e *eventEmitter) ListenerCount(name string) int {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

onceHandlers := []interface{}{}
for idx := range e.events[name].once {
eventPtr := reflect.ValueOf(e.events[name].once[idx]).Pointer()
if eventPtr != handlerPtr {
onceHandlers = append(onceHandlers, e.events[name].once[idx])
if name != "" {
evt, ok := e.events[name]
if !ok {
return 0
}
return evt.count()
}

e.events[name].once = onceHandlers
}

// ListenerCount count the listeners by name, count all if name is empty
func (e *eventEmitter) ListenerCount(name string) int {
count := 0
e.eventsMutex.Lock()
for key := range e.events {
if name == "" || name == key {
count += len(e.events[key].on) + len(e.events[key].once)
}
count += e.events[key].count()
}
e.eventsMutex.Unlock()

return count
}

func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
e.eventsMutex.Lock()
e.init()

if _, ok := e.events[name]; !ok {
e.events[name] = &eventRegister{
on: make([]interface{}, 0),
once: make([]interface{}, 0),
listeners: make([]listener, 0),
}
}
if once {
e.events[name].once = append(e.events[name].once, handler)
} else {
e.events[name].on = append(e.events[name].on, handler)
}
e.events[name].addHandler(handler, once)
e.eventsMutex.Unlock()
}

func (e *eventEmitter) initEventEmitter() {
e.events = make(map[string]*eventRegister)
func (e *eventEmitter) init() {
if !e.hasInit {
e.events = make(map[string]*eventRegister, 0)
e.hasInit = true
}
}

func (e *eventRegister) addHandler(handler interface{}, once bool) {
e.listeners = append(e.listeners, listener{handler: handler, once: once})
}

func (e *eventRegister) count() int {
return len(e.listeners)
}

func (e *eventRegister) removeHandler(handler interface{}) {
handlerPtr := reflect.ValueOf(handler).Pointer()

e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool {
return reflect.ValueOf(l.handler).Pointer() == handlerPtr
})
}

func (e *eventRegister) callHandlers(payloads ...interface{}) {
payloadV := make([]reflect.Value, 0)

for _, p := range payloads {
payloadV = append(payloadV, reflect.ValueOf(p))
}

handle := func(l listener) {
handlerV := reflect.ValueOf(l.handler)
handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
}

for _, l := range e.listeners {
if l.once {
defer e.removeHandler(l.handler)
}
handle(l)
}
}
7 changes: 0 additions & 7 deletions event_emitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ const (

func TestEventEmitterListenerCount(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
myHandler := func(payload ...interface{}) {
wasCalled <- payload[0]
Expand All @@ -32,7 +31,6 @@ func TestEventEmitterListenerCount(t *testing.T) {

func TestEventEmitterOn(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
handler.On(testEventName, func(payload ...interface{}) {
Expand All @@ -48,7 +46,6 @@ func TestEventEmitterOn(t *testing.T) {

func TestEventEmitterOnce(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
handler.Once(testEventName, func(payload ...interface{}) {
Expand All @@ -64,7 +61,6 @@ func TestEventEmitterOnce(t *testing.T) {

func TestEventEmitterRemove(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
myHandler := func(payload ...interface{}) {
Expand All @@ -84,14 +80,12 @@ func TestEventEmitterRemove(t *testing.T) {

func TestEventEmitterRemoveEmpty(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
handler.RemoveListener(testEventName, func(...interface{}) {})
require.Equal(t, 0, handler.ListenerCount(testEventName))
}

func TestEventEmitterRemoveKeepExisting(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
handler.On(testEventName, func(...interface{}) {})
handler.Once(testEventName, func(...interface{}) {})
handler.RemoveListener("abc123", func(...interface{}) {})
Expand All @@ -101,7 +95,6 @@ func TestEventEmitterRemoveKeepExisting(t *testing.T) {

func TestEventEmitterOnLessArgsAcceptingReceiver(t *testing.T) {
handler := &eventEmitter{}
handler.initEventEmitter()
wasCalled := make(chan bool, 1)
require.Nil(t, handler.events[testEventName])
handler.Once(testEventName, func(ev ...interface{}) {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.0
go.uber.org/multierr v1.11.0
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc h1:ao2WRsKSzW6KuUY9IWPwWahcHCgR0s52IfwutMfEbdM=
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
Expand Down
2 changes: 1 addition & 1 deletion local_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (l *localUtilsImpl) TraceDiscarded(stacksId string) error {
return err
}

func (l *localUtilsImpl) AddStackToTracingNoReply(id int, stack []map[string]interface{}) {
func (l *localUtilsImpl) AddStackToTracingNoReply(id uint32, stack []map[string]interface{}) {
l.channel.SendNoReply("addStackToTracingNoReply", map[string]interface{}{
"callData": map[string]interface{}{
"id": id,
Expand Down
15 changes: 15 additions & 0 deletions tests/browser_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,18 @@ func TestPageErrorEventShouldWork(t *testing.T) {
require.Equal(t, page, weberror.Page())
require.ErrorContains(t, weberror.Error(), "boom")
}

func TestBrowserContextOnResponse(t *testing.T) {
BeforeEach(t)
defer AfterEach(t)
responseChan := make(chan playwright.Response, 1)
context.OnResponse(func(response playwright.Response) {
responseChan <- response
})
_, err := page.Goto(fmt.Sprintf("%s/title.html", server.PREFIX))
require.NoError(t, err)
response := <-responseChan
body, err := response.Body()
require.NoError(t, err)
require.Equal(t, "<title>Woof-Woof</title>\n", string(body))
}
Loading

0 comments on commit 8b2bd7b

Please sign in to comment.