Skip to content

Commit 31b92e9

Browse files
committed
Fix on_cancel behavior and add test
Signed-off-by: Michael X. Grey <[email protected]>
1 parent ac7467b commit 31b92e9

File tree

7 files changed

+362
-123
lines changed

7 files changed

+362
-123
lines changed

src/builder.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,171 @@ impl<'w, 's, 'a> Builder<'w, 's, 'a> {
296296
}
297297
}
298298
}
299+
300+
#[cfg(test)]
301+
mod tests {
302+
use crate::{*, testing::*};
303+
304+
#[test]
305+
fn test_fork_clone() {
306+
let mut context = TestingContext::minimal_plugins();
307+
308+
let workflow = context.spawn_io_workflow(|scope, builder| {
309+
let fork = scope.input.fork_clone(builder);
310+
let branch_a = fork.clone_output(builder);
311+
let branch_b = fork.clone_output(builder);
312+
builder.connect(branch_a, scope.terminate);
313+
builder.connect(branch_b, scope.terminate);
314+
});
315+
316+
let mut promise = context.command(|commands| {
317+
commands
318+
.request(5.0, workflow)
319+
.take_response()
320+
});
321+
322+
context.run_with_conditions(&mut promise, Duration::from_secs(1));
323+
assert!(promise.take().available().is_some_and(|v| v == 5.0));
324+
assert!(context.no_unhandled_errors());
325+
326+
let workflow = context.spawn_io_workflow(|scope, builder| {
327+
scope.input.chain(builder)
328+
.fork_clone((
329+
|chain: Chain<f64>| chain.connect(scope.terminate),
330+
|chain: Chain<f64>| chain.connect(scope.terminate),
331+
));
332+
});
333+
334+
let mut promise = context.command(|commands| {
335+
commands
336+
.request(3.0, workflow)
337+
.take_response()
338+
});
339+
340+
context.run_with_conditions(&mut promise, Duration::from_secs(1));
341+
assert!(promise.take().available().is_some_and(|v| v == 3.0));
342+
assert!(context.no_unhandled_errors());
343+
344+
let workflow = context.spawn_io_workflow(|scope, builder| {
345+
scope.input.chain(builder)
346+
.fork_clone((
347+
|chain: Chain<f64>| chain
348+
.map_block(|t| WaitRequest { duration: Duration::from_secs_f64(10.0*t), value: 10.0*t })
349+
.map(|r: AsyncMap<WaitRequest<f64>>| {
350+
wait(r.request)
351+
})
352+
.connect(scope.terminate),
353+
|chain: Chain<f64>| chain
354+
.map_block(|t| WaitRequest { duration: Duration::from_secs_f64(t/100.0), value: t/100.0 })
355+
.map(|r: AsyncMap<WaitRequest<f64>>| {
356+
wait(r.request)
357+
})
358+
.connect(scope.terminate),
359+
));
360+
});
361+
362+
let mut promise = context.command(|commands| {
363+
commands
364+
.request(1.0, workflow)
365+
.take_response()
366+
});
367+
368+
context.run_with_conditions(&mut promise, Duration::from_secs_f64(0.5));
369+
assert!(promise.take().available().is_some_and(|v| v == 0.01));
370+
assert!(context.no_unhandled_errors());
371+
}
372+
373+
#[test]
374+
fn test_stream_reachability() {
375+
let mut context = TestingContext::minimal_plugins();
376+
377+
// Test for streams from a blocking node
378+
let workflow = context.spawn_io_workflow(|scope, builder| {
379+
let stream_node = builder.create_map(|_: BlockingMap<(), StreamOf<u32>>| {
380+
// Do nothing. The purpose of this node is to just return without
381+
// sending off any streams.
382+
});
383+
384+
builder.connect(scope.input, stream_node.input);
385+
stream_node.streams.chain(builder)
386+
.inner()
387+
.map_block(|value| 2 * value)
388+
.connect(scope.terminate);
389+
});
390+
391+
let mut promise = context.command(|commands| {
392+
commands.request((), workflow).take_response()
393+
});
394+
395+
context.run_with_conditions(&mut promise, Duration::from_secs(2));
396+
assert!(promise.peek().is_cancelled());
397+
assert!(context.no_unhandled_errors());
398+
399+
// Test for streams from an async node
400+
let workflow = context.spawn_io_workflow(|scope, builder| {
401+
let stream_node = builder.create_map(|_: AsyncMap<(), StreamOf<u32>>| {
402+
async { /* Do nothing */ }
403+
});
404+
405+
builder.connect(scope.input, stream_node.input);
406+
stream_node.streams.chain(builder)
407+
.inner()
408+
.map_block(|value| 2 * value)
409+
.connect(scope.terminate);
410+
});
411+
412+
let mut promise = context.command(|commands| {
413+
commands.request((), workflow).take_response()
414+
});
415+
416+
context.run_with_conditions(&mut promise, Duration::from_secs(2));
417+
assert!(promise.peek().is_cancelled());
418+
assert!(context.no_unhandled_errors());
419+
}
420+
421+
use crossbeam::channel::unbounded;
422+
423+
#[test]
424+
fn test_on_cancel() {
425+
let (sender, receiver) = unbounded();
426+
427+
let mut context = TestingContext::minimal_plugins();
428+
let workflow = context.spawn_io_workflow(|scope, builder| {
429+
430+
let input = scope.input.fork_clone(builder);
431+
432+
let buffer = builder.create_buffer(BufferSettings::default());
433+
let input_to_buffer = input.clone_output(builder);
434+
builder.connect(input_to_buffer, buffer.input_slot());
435+
436+
let none_node = builder.create_map_block(produce_none);
437+
let input_to_node = input.clone_output(builder);
438+
builder.connect(input_to_node, none_node.input);
439+
none_node.output.chain(builder)
440+
.cancel_on_none()
441+
.connect(scope.terminate);
442+
443+
// The chain coming out of the none_node will result in the scope
444+
// being cancelled. After that, this scope should run, and the value
445+
// that went into the buffer should get sent over the channel.
446+
builder.on_cancel(buffer, |scope, builder| {
447+
scope.input.chain(builder)
448+
.map_block(move |value| {
449+
sender.send(value).ok();
450+
})
451+
.connect(scope.terminate);
452+
});
453+
});
454+
455+
let mut promise = context.command(|commands| {
456+
commands.request(5, workflow).take_response()
457+
});
458+
459+
context.run_with_conditions(&mut promise, Duration::from_secs(2));
460+
assert!(promise.peek().is_cancelled());
461+
let channel_output = receiver.try_recv().unwrap();
462+
assert_eq!(channel_output, 5);
463+
assert!(context.no_unhandled_errors());
464+
assert!(context.confirm_buffers_empty().is_ok());
465+
}
466+
}

src/input.rs

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use std::collections::HashMap;
3030
use backtrace::Backtrace;
3131

3232
use crate::{
33-
OperationRoster, OperationError, OrBroken,
33+
OperationRoster, OperationError, OrBroken, OperationResult,
3434
DeferredRoster, Cancel, Cancellation, CancellationCause, Broken,
3535
BufferSettings, RetentionPolicy, ForkTargetStorage, UnusedTarget,
3636
};
@@ -147,6 +147,11 @@ pub trait ManageInput {
147147
session: Entity,
148148
) -> Result<SmallVec<[T; 16]>, OperationError>;
149149

150+
fn clear_buffer<T: 'static + Send + Sync>(
151+
&mut self,
152+
session: Entity,
153+
) -> OperationResult;
154+
150155
fn cleanup_inputs<T: 'static + Send + Sync>(
151156
&mut self,
152157
session: Entity,
@@ -180,6 +185,10 @@ pub trait InspectInput {
180185
) -> Result<Option<T>, OperationError>
181186
where
182187
T: Clone;
188+
189+
fn buffered_sessions<T: 'static + Send + Sync>(
190+
&self,
191+
) -> Result<SmallVec<[Entity; 16]>, OperationError>;
183192
}
184193

185194
impl<'w> ManageInput for EntityMut<'w> {
@@ -276,10 +285,52 @@ impl<'w> ManageInput for EntityMut<'w> {
276285
Ok(buffer.consume(session))
277286
}
278287

288+
fn clear_buffer<T: 'static + Send + Sync>(
289+
&mut self,
290+
session: Entity,
291+
) -> OperationResult {
292+
let mut buffer = self.get_mut::<BufferStorage<T>>().or_broken()?;
293+
buffer.reverse_queues.remove(&session);
294+
Ok(())
295+
}
296+
279297
fn cleanup_inputs<T: 'static + Send + Sync>(
280298
&mut self,
281299
session: Entity,
282300
) {
301+
if self.contains::<BufferStorage<T>>() {
302+
// Buffers are handled in a special way because the data of some
303+
// buffers will be used during cancellation. Therefore we do not
304+
// want to just delete their contents, but instead store them in the
305+
// buffer storage until the scope gives the signal to clear all
306+
// buffer data after all the cancellation workflows are finished.
307+
if let Some(mut inputs) = self.get_mut::<InputStorage<T>>() {
308+
// Pull out only the data that
309+
let remaining_indices: SmallVec<[usize; 16]> = inputs.reverse_queue
310+
.iter()
311+
.enumerate()
312+
.filter_map(|(i, input)|
313+
if input.session == session { Some(i) } else { None }
314+
)
315+
.collect();
316+
317+
let mut reverse_remaining: SmallVec<[T; 16]> = SmallVec::new();
318+
for i in remaining_indices.into_iter().rev() {
319+
reverse_remaining.push(inputs.reverse_queue.remove(i).data);
320+
}
321+
322+
// INVARIANT: Earlier in this function we checked that the
323+
// entity contains this component, and we have not removed it
324+
// since then.
325+
let mut buffer = self.get_mut::<BufferStorage<T>>().unwrap();
326+
for data in reverse_remaining.into_iter().rev() {
327+
buffer.push(session, data);
328+
}
329+
}
330+
331+
return;
332+
}
333+
283334
if let Some(mut inputs) = self.get_mut::<InputStorage<T>>() {
284335
inputs.reverse_queue.retain(
285336
|Input { session: r, .. }| *r != session
@@ -319,6 +370,18 @@ impl<'a> InspectInput for EntityMut<'a> {
319370
let buffer = self.get::<BufferStorage<T>>().or_broken()?;
320371
Ok(buffer.reverse_queues.get(&session).map(|q| q.last().cloned()).flatten())
321372
}
373+
374+
fn buffered_sessions<T: 'static + Send + Sync>(
375+
&self,
376+
) -> Result<SmallVec<[Entity; 16]>, OperationError> {
377+
let sessions = self.get::<BufferStorage<T>>().or_broken()?
378+
.reverse_queues
379+
.iter()
380+
.map(|(e, _)| *e)
381+
.collect();
382+
383+
Ok(sessions)
384+
}
322385
}
323386

324387
impl<'a> InspectInput for EntityRef<'a> {
@@ -348,6 +411,18 @@ impl<'a> InspectInput for EntityRef<'a> {
348411
let buffer = self.get::<BufferStorage<T>>().or_broken()?;
349412
Ok(buffer.reverse_queues.get(&session).map(|q| q.last().cloned()).flatten())
350413
}
414+
415+
fn buffered_sessions<T: 'static + Send + Sync>(
416+
&self,
417+
) -> Result<SmallVec<[Entity; 16]>, OperationError> {
418+
let sessions = self.get::<BufferStorage<T>>().or_broken()?
419+
.reverse_queues
420+
.iter()
421+
.map(|(e, _)| *e)
422+
.collect();
423+
424+
Ok(sessions)
425+
}
351426
}
352427

353428
pub(crate) struct InputCommand<T> {

src/operation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ mod noop;
5656
pub(crate) use noop::*;
5757

5858
mod operate_buffer;
59-
pub(crate) use operate_buffer::*;
59+
pub use operate_buffer::*;
6060

6161
mod operate_callback;
6262
pub(crate) use operate_callback::*;

0 commit comments

Comments
 (0)