@@ -11,7 +11,7 @@ use std::{
11
11
convert:: TryFrom , env, future:: Future , net:: SocketAddr , path:: PathBuf , sync:: Arc ,
12
12
time:: Duration ,
13
13
} ;
14
- use temporal_client:: WorkflowOptions ;
14
+ use temporal_client:: { RetryGateway , ServerGateway , ServerGatewayApis , WorkflowOptions } ;
15
15
use temporal_sdk:: TestRustWorker ;
16
16
use temporal_sdk_core:: {
17
17
init_replay_worker, init_worker, replay:: mock_gateway_from_history, telemetry_init,
@@ -29,6 +29,7 @@ use temporal_sdk_core_protos::{
29
29
} ,
30
30
temporal:: api:: history:: v1:: History ,
31
31
} ;
32
+ use tokio:: sync:: OnceCell ;
32
33
use url:: Url ;
33
34
34
35
pub const NAMESPACE : & str = "default" ;
@@ -77,7 +78,11 @@ pub struct CoreWfStarter {
77
78
telemetry_options : TelemetryOptions ,
78
79
worker_config : WorkerConfig ,
79
80
wft_timeout : Option < Duration > ,
80
- initted_worker : Option < Arc < dyn Worker > > ,
81
+ initted_worker : OnceCell < InitializedWorker > ,
82
+ }
83
+ struct InitializedWorker {
84
+ worker : Arc < dyn Worker > ,
85
+ client : Arc < RetryGateway < ServerGateway > > ,
81
86
}
82
87
83
88
impl CoreWfStarter {
@@ -99,7 +104,7 @@ impl CoreWfStarter {
99
104
. build ( )
100
105
. unwrap ( ) ,
101
106
wft_timeout : None ,
102
- initted_worker : None ,
107
+ initted_worker : OnceCell :: new ( ) ,
103
108
}
104
109
}
105
110
@@ -116,16 +121,11 @@ impl CoreWfStarter {
116
121
}
117
122
118
123
pub async fn get_worker ( & mut self ) -> Arc < dyn Worker > {
119
- if self . initted_worker . is_none ( ) {
120
- telemetry_init ( & self . telemetry_options ) . expect ( "Telemetry inits cleanly" ) ;
121
- let gateway = get_integ_server_options ( )
122
- . connect ( self . worker_config . namespace . clone ( ) , None )
123
- . await
124
- . expect ( "Must connect" ) ;
125
- let worker = init_worker ( self . worker_config . clone ( ) , gateway) ;
126
- self . initted_worker = Some ( Arc :: new ( worker) ) ;
127
- }
128
- self . initted_worker . as_ref ( ) . unwrap ( ) . clone ( )
124
+ self . get_or_init ( ) . await . worker . clone ( )
125
+ }
126
+
127
+ pub async fn get_client ( & mut self ) -> Arc < RetryGateway < ServerGateway > > {
128
+ self . get_or_init ( ) . await . client . clone ( )
129
129
}
130
130
131
131
/// Start the workflow defined by the builder and return run id
@@ -137,13 +137,12 @@ impl CoreWfStarter {
137
137
pub async fn start_wf_with_id ( & self , workflow_id : String , mut opts : WorkflowOptions ) -> String {
138
138
opts. task_timeout = opts. task_timeout . or ( self . wft_timeout ) ;
139
139
self . initted_worker
140
- . as_ref ( )
140
+ . get ( )
141
141
. expect (
142
- "Core must be initted before starting a workflow.\
143
- Tests must call `get_core ` first.",
142
+ "Worker must be initted before starting a workflow.\
143
+ Tests must call `get_worker ` first.",
144
144
)
145
- . as_ref ( )
146
- . server_gateway ( )
145
+ . client
147
146
. start_workflow (
148
147
vec ! [ ] ,
149
148
self . worker_config . task_queue . clone ( ) ,
@@ -218,6 +217,25 @@ impl CoreWfStarter {
218
217
self . wft_timeout = Some ( timeout) ;
219
218
self
220
219
}
220
+
221
+ async fn get_or_init ( & mut self ) -> & InitializedWorker {
222
+ self . initted_worker
223
+ . get_or_init ( || async {
224
+ telemetry_init ( & self . telemetry_options ) . expect ( "Telemetry inits cleanly" ) ;
225
+ let gateway = Arc :: new (
226
+ get_integ_server_options ( )
227
+ . connect ( self . worker_config . namespace . clone ( ) , None )
228
+ . await
229
+ . expect ( "Must connect" ) ,
230
+ ) ;
231
+ let worker = init_worker ( self . worker_config . clone ( ) , gateway. clone ( ) ) ;
232
+ InitializedWorker {
233
+ worker : Arc :: new ( worker) ,
234
+ client : gateway,
235
+ }
236
+ } )
237
+ . await
238
+ }
221
239
}
222
240
223
241
pub fn get_integ_server_options ( ) -> ServerGatewayOptions {
0 commit comments