Skip to content

Commit 6089b5e

Browse files
committed
Add publisher method to Client
1 parent d04f4e7 commit 6089b5e

File tree

2 files changed

+162
-102
lines changed

2 files changed

+162
-102
lines changed

src/client/client.rs

+158-101
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,41 @@ pub struct Client {
102102
///
103103
/// This field uses a Mutex for interior mutability so that
104104
/// `Client` is `Send`. It's not expected to be `Sync`.
105-
free_write_pids: Mutex<FreePidList>,
105+
free_write_pids: Arc<Mutex<FreePidList>>,
106+
}
107+
108+
/// A clonable structure which can be used to publish messages from multiple
109+
/// threads concurrently.
110+
///
111+
/// Publisher instances are invalidated when the client disconnects (after retrying).
112+
/// After calling `connect` again, publishers will have to be recreated by calling `Client::publisher`.
113+
///
114+
/// ```
115+
/// # use mqtt_async_client::client::Client;
116+
///
117+
/// let client = ..;
118+
///
119+
/// let publisher1 = client.publisher();
120+
/// let publisher2 = client.publisher(); // Or use publisher1.clone()
121+
///
122+
/// tokio::spawn(async {
123+
/// publisher1.publish(..)
124+
/// });
125+
///
126+
/// tokio::spawn(async {
127+
/// publisher2.publish(..)
128+
/// });
129+
/// ```
130+
#[derive(Clone)]
131+
pub struct ClientPublisher {
132+
/// Sender to send IO requests to the IO task.
133+
tx_io_requests: mpsc::Sender<IoRequest>,
134+
135+
/// Time to wait for a response before timing out
136+
operation_timeout: Duration,
137+
138+
/// Tracks which Pids (MQTT packet IDs) are in use.
139+
free_write_pids: Arc<Mutex<FreePidList>>,
106140
}
107141

108142
impl fmt::Debug for Client {
@@ -151,7 +185,7 @@ impl fmt::Debug for ClientOptions {
151185
/// The client side of the communication channels to an IO task.
152186
struct IoTaskHandle {
153187
/// Sender to send IO requests to the IO task.
154-
tx_io_requests: mpsc::Sender<IoRequest>,
188+
publisher: ClientPublisher,
155189

156190
/// Receiver to receive Publish packets from the IO task.
157191
rx_recv_published: mpsc::Receiver<Result<Packet>>,
@@ -257,7 +291,7 @@ impl Client {
257291
Ok(Client {
258292
options: opts,
259293
io_task_handle: None,
260-
free_write_pids: Mutex::new(FreePidList::new()),
294+
free_write_pids: Arc::new(Mutex::new(FreePidList::new())),
261295
})
262296
}
263297

@@ -276,7 +310,11 @@ impl Client {
276310
mpsc::channel::<Result<Packet>>(self.options.packet_buffer_len);
277311
let halt = Arc::new(AtomicBool::new(false));
278312
self.io_task_handle = Some(IoTaskHandle {
279-
tx_io_requests,
313+
publisher: ClientPublisher {
314+
tx_io_requests,
315+
operation_timeout: self.options.operation_timeout,
316+
free_write_pids: self.free_write_pids.clone()
317+
},
280318
rx_recv_published,
281319
halt: halt.clone(),
282320
});
@@ -298,56 +336,22 @@ impl Client {
298336
/// create several publish futures to publish several payloads of
299337
/// data simultaneously without waiting for responses.
300338
pub async fn publish(&self, p: &Publish) -> Result<()> {
301-
let qos = p.qos();
302-
if qos == QoS::ExactlyOnce {
303-
return Err("QoS::ExactlyOnce is not supported".into());
304-
}
305-
let p2 = Packet::Publish(mqttrs::Publish {
306-
dup: false, // TODO.
307-
qospid: match qos {
308-
QoS::AtMostOnce => QosPid::AtMostOnce,
309-
QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?),
310-
QoS::ExactlyOnce => panic!("Not reached"),
311-
},
312-
retain: p.retain(),
313-
topic_name: p.topic().to_owned(),
314-
payload: p.payload().to_owned(),
315-
});
316-
match qos {
317-
QoS::AtMostOnce => {
318-
let res = timeout(self.options.operation_timeout,
319-
self.write_only_packet(&p2)).await;
320-
if let Err(Elapsed { .. }) = res {
321-
return Err(format!("Timeout writing publish after {}ms",
322-
self.options.operation_timeout.as_millis()).into());
323-
}
324-
res.expect("No timeout")?;
325-
}
326-
QoS::AtLeastOnce => {
327-
let res = timeout(self.options.operation_timeout,
328-
self.write_response_packet(&p2)).await;
329-
if let Err(Elapsed { .. }) = res {
330-
// We report this but can't really deal with it properly.
331-
// The protocol says we can't re-use the packet ID so we have to leak it
332-
// and potentially run out of packet IDs.
333-
return Err(format!("Timeout waiting for Puback after {}ms",
334-
self.options.operation_timeout.as_millis()).into());
335-
}
336-
let res = res.expect("No timeout")?;
337-
match res {
338-
Packet::Puback(pid) => self.free_write_pid(pid)?,
339-
_ => error!("Bad packet response for publish: {:#?}", res),
340-
}
341-
},
342-
QoS::ExactlyOnce => panic!("Not reached"),
343-
};
344-
Ok(())
339+
let c = self.check_io_task()?;
340+
c.publisher.publish(p).await
341+
}
342+
343+
/// Creates a `ClientPublisher` which can be cloned and used to publish messages
344+
/// from multiple threads concurrently.
345+
pub fn publisher(&self) -> Result<ClientPublisher> {
346+
let c = self.check_io_task()?;
347+
Ok(c.publisher.clone())
345348
}
346349

347350
/// Subscribe to some topics.`read_subscriptions` will return
348351
/// data for them.
349352
pub async fn subscribe(&mut self, s: Subscribe) -> Result<SubscribeResult> {
350-
let pid = self.alloc_write_pid()?;
353+
let c = self.check_io_task()?;
354+
let pid = c.publisher.alloc_write_pid()?;
351355
// TODO: Support subscribe to qos == ExactlyOnce.
352356
if s.topics().iter().any(|t| t.qos == QoS::ExactlyOnce) {
353357
return Err("Qos::ExactlyOnce is not supported right now".into())
@@ -356,7 +360,7 @@ impl Client {
356360
pid,
357361
topics: s.topics().to_owned(),
358362
});
359-
let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
363+
let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await;
360364
if let Err(Elapsed { .. }) = res {
361365
// We report this but can't really deal with it properly.
362366
// The protocol says we can't re-use the packet ID so we have to leak it
@@ -370,7 +374,7 @@ impl Client {
370374
pid: suback_pid,
371375
return_codes: rcs,
372376
}) if suback_pid == pid => {
373-
self.free_write_pid(pid)?;
377+
c.publisher.free_write_pid(pid)?;
374378
Ok(SubscribeResult {
375379
return_codes: rcs
376380
})
@@ -386,13 +390,14 @@ impl Client {
386390
/// Unsubscribe from some topics. `read_subscriptions` will no
387391
/// longer return data for them.
388392
pub async fn unsubscribe(&mut self, u: Unsubscribe) -> Result<()> {
389-
let pid = self.alloc_write_pid()?;
393+
let c = self.check_io_task()?;
394+
let pid = c.publisher.alloc_write_pid()?;
390395
let p = Packet::Unsubscribe(mqttrs::Unsubscribe {
391396
pid,
392397
topics: u.topics().iter().map(|ut| ut.topic_name().to_owned())
393398
.collect::<Vec<String>>(),
394399
});
395-
let res = timeout(self.options.operation_timeout, self.write_response_packet(&p)).await;
400+
let res = timeout(self.options.operation_timeout, c.publisher.write_response_packet(&p)).await;
396401
if let Err(Elapsed { .. }) = res {
397402
// We report this but can't really deal with it properly.
398403
// The protocol says we can't re-use the packet ID so we have to leak it
@@ -404,7 +409,7 @@ impl Client {
404409
match res {
405410
Packet::Unsuback(ack_pid)
406411
if ack_pid == pid => {
407-
self.free_write_pid(pid)?;
412+
c.publisher.free_write_pid(pid)?;
408413
Ok(())
409414
},
410415
_ => {
@@ -431,7 +436,7 @@ impl Client {
431436
match p.qospid {
432437
QosPid::AtMostOnce => (),
433438
QosPid::AtLeastOnce(pid) => {
434-
self.write_only_packet(&Packet::Puback(pid)).await?;
439+
h.publisher.write_only_packet(&Packet::Puback(pid)).await?;
435440
},
436441
QosPid::ExactlyOnce(_) => {
437442
error!("Received publish with unimplemented QoS: ExactlyOnce");
@@ -451,11 +456,11 @@ impl Client {
451456

452457
/// Gracefully close the connection to the server.
453458
pub async fn disconnect(&mut self) -> Result<()> {
454-
self.check_io_task()?;
459+
let c = self.check_io_task()?;
455460
debug!("Disconnecting");
456461
let p = Packet::Disconnect;
457462
let res = timeout(self.options.operation_timeout,
458-
self.write_only_packet(&p)).await;
463+
c.publisher.write_only_packet(&p)).await;
459464
if let Err(Elapsed { .. }) = res {
460465
return Err(format!("Timeout waiting for Disconnect to send after {}ms",
461466
self.options.operation_timeout.as_millis()).into());
@@ -465,57 +470,14 @@ impl Client {
465470
Ok(())
466471
}
467472

468-
fn alloc_write_pid(&self) -> Result<Pid> {
469-
match self.free_write_pids.lock().expect("not poisoned").alloc() {
470-
Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")),
471-
None => Err(Error::from("No free Pids")),
472-
}
473-
}
474-
475-
fn free_write_pid(&self, p: Pid) -> Result<()> {
476-
match self.free_write_pids.lock().expect("not poisoned").free(p.get()) {
477-
true => Err(Error::from("Pid was already free")),
478-
false => Ok(())
479-
}
480-
}
481-
482473
async fn shutdown(&mut self) -> Result <()> {
483474
let c = self.check_io_task()?;
484475
c.halt.store(true, Ordering::SeqCst);
485-
self.write_request(IoType::ShutdownConnection, None).await?;
476+
c.publisher.write_request(IoType::ShutdownConnection, None).await?;
486477
self.io_task_handle = None;
487478
Ok(())
488479
}
489480

490-
async fn write_only_packet(&self, p: &Packet) -> Result<()> {
491-
self.write_request(IoType::WriteOnly { packet: p.clone(), }, None)
492-
.await.map(|_v| ())
493-
494-
}
495-
496-
async fn write_response_packet(&self, p: &Packet) -> Result<Packet> {
497-
let io_type = IoType::WriteAndResponse {
498-
packet: p.clone(),
499-
response_pid: packet_pid(p).expect("packet_pid"),
500-
};
501-
let (tx, rx) = oneshot::channel::<IoResult>();
502-
self.write_request(io_type, Some(tx))
503-
.await?;
504-
// TODO: Add a timeout?
505-
let res = rx.await.map_err(Error::from_std_err)?;
506-
res.result.map(|v| v.expect("return packet"))
507-
}
508-
509-
async fn write_request(&self, io_type: IoType, tx_result: Option<oneshot::Sender<IoResult>>) -> Result<()> {
510-
// NB: Some duplication in IoTask::replay_subscriptions.
511-
512-
let c = self.check_io_task()?;
513-
let req = IoRequest { tx_result, io_type };
514-
c.tx_io_requests.clone().send(req).await
515-
.map_err(|e| Error::from_std_err(e))?;
516-
Ok(())
517-
}
518-
519481
fn check_io_task_mut(&mut self) -> Result<&mut IoTaskHandle> {
520482
match self.io_task_handle {
521483
Some(ref mut h) => Ok(h),
@@ -630,6 +592,101 @@ async fn connect_stream(opts: &ClientOptions) -> Result<AsyncStream> {
630592
}
631593
}
632594

595+
impl ClientPublisher {
596+
/// Publish some data on a topic.
597+
///
598+
/// Note that this method takes `&self`. This means a caller can
599+
/// create several publish futures to publish several payloads of
600+
/// data simultaneously without waiting for responses.
601+
pub async fn publish(&self, p: &Publish) -> Result<()> {
602+
let qos = p.qos();
603+
if qos == QoS::ExactlyOnce {
604+
return Err("QoS::ExactlyOnce is not supported".into());
605+
}
606+
let p2 = Packet::Publish(mqttrs::Publish {
607+
dup: false, // TODO.
608+
qospid: match qos {
609+
QoS::AtMostOnce => QosPid::AtMostOnce,
610+
QoS::AtLeastOnce => QosPid::AtLeastOnce(self.alloc_write_pid()?),
611+
QoS::ExactlyOnce => panic!("Not reached"),
612+
},
613+
retain: p.retain(),
614+
topic_name: p.topic().to_owned(),
615+
payload: p.payload().to_owned(),
616+
});
617+
match qos {
618+
QoS::AtMostOnce => {
619+
let res = timeout(self.operation_timeout,
620+
self.write_only_packet(&p2)).await;
621+
if let Err(Elapsed { .. }) = res {
622+
return Err(format!("Timeout writing publish after {}ms",
623+
self.operation_timeout.as_millis()).into());
624+
}
625+
res.expect("No timeout")?;
626+
}
627+
QoS::AtLeastOnce => {
628+
let res = timeout(self.operation_timeout,
629+
self.write_response_packet(&p2)).await;
630+
if let Err(Elapsed { .. }) = res {
631+
// We report this but can't really deal with it properly.
632+
// The protocol says we can't re-use the packet ID so we have to leak it
633+
// and potentially run out of packet IDs.
634+
return Err(format!("Timeout waiting for Puback after {}ms",
635+
self.operation_timeout.as_millis()).into());
636+
}
637+
let res = res.expect("No timeout")?;
638+
match res {
639+
Packet::Puback(pid) => self.free_write_pid(pid)?,
640+
_ => error!("Bad packet response for publish: {:#?}", res),
641+
}
642+
},
643+
QoS::ExactlyOnce => panic!("Not reached"),
644+
};
645+
Ok(())
646+
}
647+
648+
async fn write_only_packet(&self, p: &Packet) -> Result<()> {
649+
self.write_request(IoType::WriteOnly { packet: p.clone(), }, None)
650+
.await.map(|_v| ())
651+
652+
}
653+
654+
async fn write_response_packet(&self, p: &Packet) -> Result<Packet> {
655+
let io_type = IoType::WriteAndResponse {
656+
packet: p.clone(),
657+
response_pid: packet_pid(p).expect("packet_pid"),
658+
};
659+
let (tx, rx) = oneshot::channel::<IoResult>();
660+
self.write_request(io_type, Some(tx))
661+
.await?;
662+
// TODO: Add a timeout?
663+
let res = rx.await.map_err(Error::from_std_err)?;
664+
res.result.map(|v| v.expect("return packet"))
665+
}
666+
667+
async fn write_request(&self, io_type: IoType, tx_result: Option<oneshot::Sender<IoResult>>) -> Result<()> {
668+
// NB: Some duplication in IoTask::replay_subscriptions.
669+
let req = IoRequest { tx_result, io_type };
670+
self.tx_io_requests.clone().send(req).await
671+
.map_err(|e| Error::from_std_err(e))?;
672+
Ok(())
673+
}
674+
675+
fn alloc_write_pid(&self) -> Result<Pid> {
676+
match self.free_write_pids.lock().expect("not poisoned").alloc() {
677+
Some(pid) => Ok(Pid::try_from(pid).expect("Non-zero Pid")),
678+
None => Err(Error::from("No free Pids")),
679+
}
680+
}
681+
682+
fn free_write_pid(&self, p: Pid) -> Result<()> {
683+
match self.free_write_pids.lock().expect("not poisoned").free(p.get()) {
684+
true => Err(Error::from("Pid was already free")),
685+
false => Ok(())
686+
}
687+
}
688+
}
689+
633690
/// Build a connect packet from ClientOptions.
634691
fn connect_packet(opts: &ClientOptions) -> Result<Packet> {
635692
Ok(Packet::Connect(mqttrs::Connect {

src/client/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ mod builder;
44
pub use builder::ClientBuilder;
55

66
mod client;
7-
pub use client::Client;
7+
pub use client::{
8+
Client,
9+
ClientPublisher
10+
};
811
pub(crate) use client::ClientOptions;
912

1013
mod value_types;

0 commit comments

Comments
 (0)