1use std::future::{Future, IntoFuture};
2use std::pin::Pin;
3use std::task::Poll;
4
5use futures_lite::Stream;
6use tokio_util::sync::{WaitForCancellationFuture, WaitForCancellationFutureOwned};
7
8use crate::{Context, ContextTracker};
9
/// An owned or borrowed handle to a [`Context`], accepted by the `with_context`
/// extension methods.
///
/// Obtain one via `From`/`Into` from either an owned [`Context`] or a `&Context`.
pub struct ContextRef<'a> {
    // Either the owned or the borrowed cancellation future (see `ContextRefInner`).
    inner: ContextRefInner<'a>,
}
17
18impl From<Context> for ContextRef<'_> {
19 fn from(ctx: Context) -> Self {
20 ContextRef {
21 inner: ContextRefInner::Owned {
22 fut: ctx.token.cancelled_owned(),
23 tracker: ctx.tracker,
24 },
25 }
26 }
27}
28
29impl<'a> From<&'a Context> for ContextRef<'a> {
30 fn from(ctx: &'a Context) -> Self {
31 ContextRef {
32 inner: ContextRefInner::Ref {
33 fut: ctx.token.cancelled(),
34 },
35 }
36 }
37}
38
pin_project_lite::pin_project! {
    // The cancellation future behind a `ContextRef`: either owned (built from a
    // `Context` by value) or borrowed (built from a `&Context`).
    #[project = ContextRefInnerProj]
    enum ContextRefInner<'a> {
        // Built from an owned `Context`. The tracker is stored but never read here;
        // presumably holding it lets the handler's shutdown wait for this value to be
        // dropped — NOTE(review): confirm against `ContextTracker`.
        Owned {
            #[pin] fut: WaitForCancellationFutureOwned,
            tracker: ContextTracker,
        },
        // Built from a `&Context`; borrows the context's cancellation token for `'a`.
        Ref {
            #[pin] fut: WaitForCancellationFuture<'a>,
        },
    }
}
51
52impl std::future::Future for ContextRefInner<'_> {
53 type Output = ();
54
55 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
56 match self.project() {
57 ContextRefInnerProj::Owned { fut, .. } => fut.poll(cx),
58 ContextRefInnerProj::Ref { fut } => fut.poll(cx),
59 }
60 }
61}
62
63pin_project_lite::pin_project! {
64 pub struct FutureWithContext<'a, F> {
68 #[pin]
69 future: F,
70 #[pin]
71 ctx: ContextRefInner<'a>,
72 _marker: std::marker::PhantomData<&'a ()>,
73 }
74}
75
76impl<F: Future> Future for FutureWithContext<'_, F> {
77 type Output = Option<F::Output>;
78
79 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
80 let this = self.project();
81
82 match (this.ctx.poll(cx), this.future.poll(cx)) {
83 (_, Poll::Ready(v)) => std::task::Poll::Ready(Some(v)),
84 (Poll::Ready(_), Poll::Pending) => std::task::Poll::Ready(None),
85 (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
86 }
87 }
88}
89
90pub trait ContextFutExt<Fut> {
92 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, Fut>
112 where
113 Self: Sized;
114}
115
116impl<F: IntoFuture> ContextFutExt<F::IntoFuture> for F {
117 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> FutureWithContext<'a, F::IntoFuture>
118 where
119 F: IntoFuture,
120 {
121 FutureWithContext {
122 future: self.into_future(),
123 ctx: ctx.into().inner,
124 _marker: std::marker::PhantomData,
125 }
126 }
127}
128
129pin_project_lite::pin_project! {
130 pub struct StreamWithContext<'a, F> {
134 #[pin]
135 stream: F,
136 #[pin]
137 ctx: ContextRefInner<'a>,
138 _marker: std::marker::PhantomData<&'a ()>,
139 }
140}
141
142impl<F: Stream> Stream for StreamWithContext<'_, F> {
143 type Item = F::Item;
144
145 fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
146 let this = self.project();
147
148 match (this.ctx.poll(cx), this.stream.poll_next(cx)) {
149 (Poll::Ready(_), _) => std::task::Poll::Ready(None),
150 (Poll::Pending, Poll::Ready(v)) => std::task::Poll::Ready(v),
151 (Poll::Pending, Poll::Pending) => std::task::Poll::Pending,
152 }
153 }
154
155 fn size_hint(&self) -> (usize, Option<usize>) {
156 self.stream.size_hint()
157 }
158}
159
160pub trait ContextStreamExt<Stream> {
162 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, Stream>
186 where
187 Self: Sized;
188}
189
190impl<F: Stream> ContextStreamExt<F> for F {
191 fn with_context<'a>(self, ctx: impl Into<ContextRef<'a>>) -> StreamWithContext<'a, F> {
192 StreamWithContext {
193 stream: self,
194 ctx: ctx.into().inner,
195 _marker: std::marker::PhantomData,
196 }
197 }
198}
199
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_lite::{Stream, StreamExt};
    use scuffle_future_ext::FutureExt;

    use super::{Context, ContextFutExt, ContextStreamExt};

    // An owned context is attached to a spawned 10s sleep. The test relies on
    // `handler.shutdown()` cancelling the context, so that the task ends early
    // (through the `None` path of `FutureWithContext`) instead of sleeping for 10s.
    // NOTE(review): shutdown semantics are defined outside this module — confirm.
    #[tokio::test]
    async fn future() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(
            async {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            }
            .with_context(ctx),
        );

        // Presumably gives the spawned task a chance to start polling before shutdown.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        handler.shutdown().await;

        task.await.unwrap();
    }

    // A future that is ready on its first poll yields `Some(1)` even though shutdown
    // is requested right after spawning. When the wrapped future is ready,
    // `FutureWithContext` returns its output regardless of the context's state.
    #[tokio::test]
    async fn future_result() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async { 1 }.with_context(ctx));

        handler.shutdown().await;

        assert_eq!(task.await.unwrap(), Some(1));
    }

    // Same as `future`, but the wrapped future only borrows the context (`&ctx`).
    // `ctx` is dropped explicitly once the wrapped future has finished. Presumably
    // `shutdown()` waits for outstanding contexts to be dropped — confirm.
    #[tokio::test]
    async fn future_ctx_by_ref() {
        let (ctx, handler) = Context::new();

        let task = tokio::spawn(async move {
            async {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            }
            .with_context(&ctx)
            .await;

            drop(ctx);
        });

        handler.shutdown().await;

        task.await.unwrap();
    }

    // A finite stream yields items until the handler cancels the context. After that
    // it ends with `None`, even though the inner iterator is not exhausted.
    #[tokio::test]
    async fn stream() {
        let (ctx, handler) = Context::new();

        // Scoped so that the stream, which owns `ctx`, is dropped before `shutdown()`
        // (presumably required for shutdown to complete).
        {
            let mut stream = pin!(futures_lite::stream::iter(0..10).with_context(ctx));

            // `size_hint` is forwarded unchanged from the inner stream.
            assert_eq!(stream.size_hint(), (10, Some(10)));

            assert_eq!(stream.next().await, Some(0));
            assert_eq!(stream.next().await, Some(1));
            assert_eq!(stream.next().await, Some(2));
            assert_eq!(stream.next().await, Some(3));

            handler.cancel();

            assert_eq!(stream.next().await, None);
        }

        handler.shutdown().await;
    }

    // A stream that never yields stays pending while the context is live. The
    // assertion expects `with_timeout` to return an error, presumably meaning the
    // 200ms deadline elapsed before any item was produced.
    #[tokio::test]
    async fn pending_stream() {
        let (ctx, handler) = Context::new();

        {
            let mut stream = pin!(futures_lite::stream::pending::<()>().with_context(ctx));

            assert!(
                stream
                    .next()
                    .with_timeout(std::time::Duration::from_millis(200))
                    .await
                    .is_err()
            );
        }

        handler.shutdown().await;
    }
}