Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add token support for client manager #213

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"crypto/tls"
"sync"
"time"

"github.com/sideshow/apns2/token"
)

type managerItem struct {
Expand All @@ -31,6 +33,10 @@ type ClientManager struct {
// manager.
Factory func(certificate tls.Certificate) *Client

// FactoryToken is the function which constructs clients if not found in the
// manager when token auth is used
FactoryToken func(token *token.Token) *Client

cache map[[sha1.Size]byte]*list.Element
ll *list.List
mu sync.Mutex
Expand All @@ -48,9 +54,10 @@ type ClientManager struct {
// a Client with default options.
func NewClientManager() *ClientManager {
manager := &ClientManager{
MaxSize: 64,
MaxAge: 10 * time.Minute,
Factory: NewClient,
MaxSize: 64,
MaxAge: 10 * time.Minute,
Factory: NewClient,
FactoryToken: NewTokenClient,
}

manager.initInternals()
Expand All @@ -65,7 +72,14 @@ func (m *ClientManager) Add(client *Client) {
m.mu.Lock()
defer m.mu.Unlock()

key := cacheKey(client.Certificate)
var key [sha1.Size]byte

if client.Token != nil {
key = cacheTokenKey(client.Token)
} else {
key = cacheKey(client.Certificate)
}

now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
Expand All @@ -88,16 +102,35 @@ func (m *ClientManager) Add(client *Client) {
// the ClientManager's Factory function, store the result in the manager if
// non-nil, and return it.
func (m *ClientManager) Get(certificate tls.Certificate) *Client {
key := cacheKey(certificate)

return m.get(key, func() *Client {
return m.Factory(certificate)
})
}

// Get gets a Client from the manager. If a Client is not found in the manager
// or if a Client has remained in the manager longer than MaxAge, Get will call
// the ClientManager's Factory function, store the result in the manager if
// non-nil, and return it.
func (m *ClientManager) GetByToken(token *token.Token) *Client {
key := cacheTokenKey(token)

return m.get(key, func() *Client {
return m.FactoryToken(token)
})
}

func (m *ClientManager) get(key [sha1.Size]byte, factory func() *Client) *Client {
m.initInternals()
m.mu.Lock()
defer m.mu.Unlock()

key := cacheKey(certificate)
now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
if m.MaxAge != 0 && item.lastUsed.Before(now.Add(-m.MaxAge)) {
c := m.Factory(certificate)
c := factory()
if c == nil {
return nil
}
Expand All @@ -108,7 +141,7 @@ func (m *ClientManager) Get(certificate tls.Certificate) *Client {
return item.client
}

c := m.Factory(certificate)
c := factory()
if c == nil {
return nil
}
Expand Down Expand Up @@ -160,3 +193,7 @@ func cacheKey(certificate tls.Certificate) [sha1.Size]byte {

return sha1.Sum(data)
}

func cacheTokenKey(token *token.Token) [sha1.Size]byte {
return sha1.Sum([]byte(token.Bearer))
}
65 changes: 65 additions & 0 deletions client_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/sideshow/apns2"
"github.com/sideshow/apns2/certificate"
"github.com/sideshow/apns2/token"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -147,3 +148,67 @@ func TestClientManagerAddTwice(t *testing.T) {
manager.Add(apns2.NewClient(mockCert()))
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerAddTokenClientWithoutNew(t *testing.T) {
fn := func(token *token.Token) *apns2.Client {
t.Fatal("factory should not have been called")
return nil
}

manager := apns2.NewClientManager()
manager.FactoryToken = fn
token := mockToken()
manager.Add(apns2.NewTokenClient(token))
manager.GetByToken(token)
}

func TestClientManagerAddTokenClientWithNew(t *testing.T) {
manager := apns2.NewClientManager()

t1 := mockToken()
_, err := t1.Generate()
assert.NoError(t, err)

t2 := mockToken()
_, err = t2.Generate()
assert.NoError(t, err)

manager.Add(apns2.NewTokenClient(t1))
manager.Add(apns2.NewTokenClient(t2))
assert.Equal(t, 2, manager.Len())
}

func TestClientManagerGetByTokenWithoutNew(t *testing.T) {
manager := apns2.NewClientManager()

token := mockToken()
c1 := manager.GetByToken(token)
c2 := manager.GetByToken(token)
v1 := reflect.ValueOf(c1)
v2 := reflect.ValueOf(c2)
assert.NotNil(t, c1)
assert.Equal(t, v1.Pointer(), v2.Pointer())
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerGetByTokenWithNew(t *testing.T) {
manager := apns2.NewClientManager()

t1 := mockToken()
_, err := t1.Generate()
assert.NoError(t, err)

t2 := mockToken()
_, err = t2.Generate()
assert.NoError(t, err)

c1 := manager.GetByToken(t1)
c2 := manager.GetByToken(t2)

v1 := reflect.ValueOf(c1)
v2 := reflect.ValueOf(c2)
assert.NotNil(t, c1)
assert.NotNil(t, c2)
assert.NotEqual(t, v1.Pointer(), v2.Pointer())
assert.Equal(t, 2, manager.Len())
}