@@ -4,12 +4,11 @@ package srtgo
4
4
#cgo LDFLAGS: -lsrt
5
5
#include <srt/srt.h>
6
6
#include <srt/access_control.h>
7
+ #include "callback.h"
7
8
static const SRTSOCKET get_srt_invalid_sock() { return SRT_INVALID_SOCK; };
8
9
static const int get_srt_error() { return SRT_ERROR; };
9
10
static const int get_srt_error_reject_predefined() { return SRT_REJC_PREDEFINED; };
10
11
static const int get_srt_error_reject_userdefined() { return SRT_REJC_USERDEFINED; };
11
-
12
- extern int srtListenCB(void* opaque, SRTSOCKET ns, int hs_version, const struct sockaddr* peeraddr, const char* streamid);
13
12
*/
14
13
import "C"
15
14
@@ -55,8 +54,9 @@ type SrtSocket struct {
55
54
}
56
55
57
56
var (
58
- listenCallbackMutex sync.Mutex
59
- listenCallbackMap map [C.int ]unsafe.Pointer = make (map [C.int ]unsafe.Pointer )
57
+ callbackMutex sync.Mutex
58
+ listenCallbackMap map [C.int ]unsafe.Pointer = make (map [C.int ]unsafe.Pointer )
59
+ connectCallbackMap map [C.int ]unsafe.Pointer = make (map [C.int ]unsafe.Pointer )
60
60
)
61
61
62
62
// Static consts from library
@@ -243,6 +243,25 @@ func (s SrtSocket) Accept() (*SrtSocket, *net.UDPAddr, error) {
243
243
return newSocket , udpAddr , nil
244
244
}
245
245
246
+ func errcodeToError (errorcode C.int ) error {
247
+ switch errorcode {
248
+ case C .SRT_EINVSOCK :
249
+ return & SrtInvalidSock {}
250
+ case C .SRT_ERDVUNBOUND :
251
+ return & SrtRendezvousUnbound {}
252
+ case C .SRT_ECONNSOCK :
253
+ return & SrtSockConnected {}
254
+ case C .SRT_ECONNREJ :
255
+ return & SrtConnectionRejected {}
256
+ case C .SRT_ENOSERVER :
257
+ return & SrtConnectTimeout {}
258
+ case C .SRT_ESCLOSED :
259
+ return & SrtSocketClosed {}
260
+ default :
261
+ return fmt .Errorf ("unknown error" )
262
+ }
263
+ }
264
+
246
265
// Connect to a remote endpoint
247
266
func (s SrtSocket ) Connect () error {
248
267
sa , salen , err := CreateAddrInet (s .host , s .port )
@@ -256,21 +275,7 @@ func (s SrtSocket) Connect() error {
256
275
C .srt_close (s .socket )
257
276
srt_errno := C .srt_getlasterror (nil )
258
277
runtime .UnlockOSThread ()
259
- switch srt_errno {
260
- case C .SRT_EINVSOCK :
261
- return & SrtInvalidSock {}
262
- case C .SRT_ERDVUNBOUND :
263
- return & SrtRendezvousUnbound {}
264
- case C .SRT_ECONNSOCK :
265
- return & SrtSockConnected {}
266
- case C .SRT_ECONNREJ :
267
- return & SrtConnectionRejected {}
268
- case C .SRT_ENOSERVER :
269
- return & SrtConnectTimeout {}
270
- case C .SRT_ESCLOSED :
271
- return & SrtSocketClosed {}
272
- }
273
- return fmt .Errorf ("Error in srt_connect" )
278
+ return errcodeToError (srt_errno )
274
279
}
275
280
runtime .UnlockOSThread ()
276
281
@@ -401,18 +406,21 @@ func (s *SrtSocket) Close() {
401
406
}
402
407
}
403
408
C .srt_close (s .socket )
404
- listenCallbackMutex .Lock ()
409
+ callbackMutex .Lock ()
405
410
if ptr , exists := listenCallbackMap [s .socket ]; exists {
406
411
gopointer .Unref (ptr )
407
412
}
408
- listenCallbackMutex .Unlock ()
413
+ if ptr , exists := connectCallbackMap [s .socket ]; exists {
414
+ gopointer .Unref (ptr )
415
+ }
416
+ callbackMutex .Unlock ()
409
417
}
410
418
411
419
// ListenCallbackFunc specifies a function to be called before a connecting socket is passed to accept
412
420
type ListenCallbackFunc func (socket * SrtSocket , version int , addr * net.UDPAddr , streamid string ) bool
413
421
414
422
//export srtListenCBWrapper
415
- func srtListenCBWrapper (arg unsafe.Pointer , socket C.int , hsVersion C.int , peeraddr * C.struct_sockaddr , streamid * C.char ) C.int {
423
+ func srtListenCBWrapper (arg unsafe.Pointer , socket C.SRTSOCKET , hsVersion C.int , peeraddr * C.struct_sockaddr , streamid * C.char ) C.int {
416
424
userCB := gopointer .Restore (arg ).(ListenCallbackFunc )
417
425
418
426
s := new (SrtSocket )
@@ -433,14 +441,42 @@ func (s SrtSocket) SetListenCallback(cb ListenCallbackFunc) {
433
441
ptr := gopointer .Save (cb )
434
442
C .srt_listen_callback (s .socket , (* C .srt_listen_callback_fn )(C .srtListenCB ), ptr )
435
443
436
- listenCallbackMutex .Lock ()
437
- defer listenCallbackMutex .Unlock ()
444
+ callbackMutex .Lock ()
445
+ defer callbackMutex .Unlock ()
438
446
if listenCallbackMap [s .socket ] != nil {
439
447
gopointer .Unref (listenCallbackMap [s .socket ])
440
448
}
441
449
listenCallbackMap [s .socket ] = ptr
442
450
}
443
451
452
+ // ConnectCallbackFunc specifies a function to be called after a socket or connection in a group has failed.
453
+ type ConnectCallbackFunc func (socket * SrtSocket , err error , addr * net.UDPAddr , token int )
454
+
455
+ //export srtConnectCBWrapper
456
+ func srtConnectCBWrapper (arg unsafe.Pointer , socket C.SRTSOCKET , errcode C.int , peeraddr * C.struct_sockaddr , token C.int ) {
457
+ userCB := gopointer .Restore (arg ).(ConnectCallbackFunc )
458
+
459
+ s := new (SrtSocket )
460
+ s .socket = socket
461
+ udpAddr , _ := udpAddrFromSockaddr ((* syscall .RawSockaddrAny )(unsafe .Pointer (peeraddr )))
462
+
463
+ userCB (s , errcodeToError (errcode ), udpAddr , int (token ))
464
+ }
465
+
466
+ // SetConnectCallback - set a function to be called after a socket or connection in a group has failed
467
+ // Note that the function is not guaranteed to be called if the socket is set to blocking mode.
468
+ func (s SrtSocket ) SetConnectCallback (cb ConnectCallbackFunc ) {
469
+ ptr := gopointer .Save (cb )
470
+ C .srt_connect_callback (s .socket , (* C .srt_connect_callback_fn )(C .srtConnectCB ), ptr )
471
+
472
+ callbackMutex .Lock ()
473
+ defer callbackMutex .Unlock ()
474
+ if connectCallbackMap [s .socket ] != nil {
475
+ gopointer .Unref (connectCallbackMap [s .socket ])
476
+ }
477
+ connectCallbackMap [s .socket ] = ptr
478
+ }
479
+
444
480
// Rejection reasons
445
481
var (
446
482
// Start of range for predefined rejection reasons
0 commit comments