Skip to main content

sinktools/
lazy_sink_source.rs

1//! [`LazySinkSource`], and related items.
2
3use core::future::Future;
4use core::marker::PhantomData;
5use core::pin::Pin;
6use core::task::{Context, Poll, Waker};
7use std::sync::Arc;
8use std::task::Wake;
9
10use futures_util::task::AtomicWaker;
11use futures_util::{Sink, Stream, ready};
12use pin_project_lite::pin_project;
13
14#[derive(Default)]
15struct DualWaker {
16    sink: AtomicWaker,
17    stream: AtomicWaker,
18}
19
20impl DualWaker {
21    fn new() -> (Arc<Self>, Waker) {
22        let dual_waker = Arc::new(Self::default());
23        let waker = Waker::from(dual_waker.clone());
24        (dual_waker, waker)
25    }
26}
27
28impl Wake for DualWaker {
29    fn wake(self: Arc<Self>) {
30        self.wake_by_ref();
31    }
32
33    fn wake_by_ref(self: &Arc<Self>) {
34        self.sink.wake();
35        self.stream.wake();
36    }
37}
38
39pin_project! {
40    #[project = SharedStateProj]
41    enum SharedState<Fut, St, Si, Item> {
42        Uninit {
43            // The future, always `Some` in this state.
44            future: Option<Fut>,
45        },
46        Thunkulating {
47            #[pin]
48            future: Fut,
49            item: Option<Item>,
50            dual_waker_state: Arc<DualWaker>,
51            dual_waker_waker: Waker,
52        },
53        Done {
54            #[pin]
55            stream: St,
56            #[pin]
57            sink: Si,
58            buf: Option<Item>,
59        },
60    }
61}
62
63pin_project! {
64    /// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the
65    /// source, or when the first item is sent to the sink. To split into separate source and sink halves, use
66    /// [`futures_util::StreamExt::split`].
67    pub struct LazySinkSource<Fut, St, Si, Item, Error> {
68        #[pin]
69        state: SharedState<Fut, St, Si, Item>,
70        _phantom: PhantomData<Error>,
71    }
72}
73
74impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
75    /// Creates a new `LazySinkSource` with the given initialization future.
76    pub fn new(future: Fut) -> Self {
77        Self {
78            state: SharedState::Uninit {
79                future: Some(future),
80            },
81            _phantom: PhantomData,
82        }
83    }
84}
85
86impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error>
87where
88    Fut: Future<Output = Result<(St, Si), Error>>,
89    St: Stream,
90    Si: Sink<Item>,
91    Error: From<Si::Error>,
92{
93    fn poll_sink_op(
94        self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        sink_op: impl FnOnce(Pin<&mut Si>, &mut Context<'_>) -> Poll<Result<(), Si::Error>>,
97    ) -> Poll<Result<(), Error>> {
98        let mut this = self.project();
99
100        if let SharedStateProj::Uninit { .. } = this.state.as_mut().project() {
101            return Poll::Ready(Ok(()));
102        }
103
104        if let SharedStateProj::Thunkulating {
105            future,
106            item,
107            dual_waker_state,
108            dual_waker_waker,
109        } = this.state.as_mut().project()
110        {
111            dual_waker_state.sink.register(cx.waker());
112
113            let mut dual_context = Context::from_waker(dual_waker_waker);
114
115            match future.poll(&mut dual_context) {
116                Poll::Ready(Ok((stream, sink))) => {
117                    let buf = item.take();
118                    this.state
119                        .as_mut()
120                        .set(SharedState::Done { stream, sink, buf });
121                }
122                Poll::Ready(Err(e)) => {
123                    return Poll::Ready(Err(e));
124                }
125                Poll::Pending => {
126                    return Poll::Pending;
127                }
128            }
129        }
130
131        if let SharedStateProj::Done { mut sink, buf, .. } = this.state.as_mut().project() {
132            if buf.is_some() {
133                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
134                sink.as_mut().start_send(buf.take().unwrap())?;
135            }
136            return (sink_op)(sink, cx).map_err(From::from);
137        }
138
139        panic!("LazySinkSource in invalid state.");
140    }
141}
142
143impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
144where
145    Fut: Future<Output = Result<(St, Si), Error>>,
146    St: Stream,
147    Si: Sink<Item>,
148    Error: From<Si::Error>,
149{
150    type Error = Error;
151
152    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
153        self.poll_sink_op(cx, Sink::poll_ready)
154    }
155
156    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
157        let mut this = self.project();
158
159        if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
160            let future = future.take().unwrap();
161            let (dual_waker_state, dual_waker_waker) = DualWaker::new();
162            this.state.as_mut().set(SharedState::Thunkulating {
163                future,
164                item: Some(item),
165                dual_waker_state,
166                dual_waker_waker,
167            });
168            return Ok(());
169        }
170
171        if let SharedStateProj::Thunkulating { .. } = this.state.as_mut().project() {
172            panic!("LazySinkSource not ready.");
173        }
174
175        if let SharedStateProj::Done { sink, buf, .. } = this.state.as_mut().project() {
176            debug_assert!(buf.is_none());
177            return sink.start_send(item).map_err(From::from);
178        }
179
180        panic!("LazySinkSource not ready.");
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        self.poll_sink_op(cx, Sink::poll_flush)
185    }
186
187    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188        self.poll_sink_op(cx, Sink::poll_close)
189    }
190}
191
192impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
193where
194    Fut: Future<Output = Result<(St, Si), Error>>,
195    St: Stream,
196    Si: Sink<Item>,
197{
198    type Item = St::Item;
199
200    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201        let mut this = self.project();
202
203        if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
204            let future = future.take().unwrap();
205            let (dual_waker_state, dual_waker_waker) = DualWaker::new();
206            this.state.as_mut().set(SharedState::Thunkulating {
207                future,
208                item: None,
209                dual_waker_state,
210                dual_waker_waker,
211            });
212        }
213
214        if let SharedStateProj::Thunkulating {
215            future,
216            item,
217            dual_waker_state,
218            dual_waker_waker,
219        } = this.state.as_mut().project()
220        {
221            dual_waker_state.stream.register(cx.waker());
222
223            let mut new_context = Context::from_waker(dual_waker_waker);
224
225            match future.poll(&mut new_context) {
226                Poll::Ready(Ok((stream, sink))) => {
227                    let buf = item.take();
228                    this.state
229                        .as_mut()
230                        .set(SharedState::Done { stream, sink, buf });
231                }
232
233                Poll::Ready(Err(_)) => {
234                    return Poll::Ready(None);
235                }
236
237                Poll::Pending => {
238                    return Poll::Pending;
239                }
240            }
241        }
242
243        if let SharedStateProj::Done { stream, .. } = this.state.as_mut().project() {
244            return stream.poll_next(cx);
245        }
246
247        panic!("LazySinkSource in invalid state.");
248    }
249}
250
// Tests that exercise both halves of a split `LazySinkSource` on a single-threaded `LocalSet`.
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    /// The stream half is polled first and starts initialization. Once initialization completes, items sent through
    /// the sink half arrive on the stream half.
    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                // Initialization cannot finish until `init_lazy_send` fires.
                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Ensures stream starts the lazy.
                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // finish the future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// The sink half sends first and starts initialization. Once initialization completes, the buffered items arrive
    /// on the stream half.
    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                // Initialization cannot finish until `init_lazy_send` fires.
                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Ensures sink starts the lazy.
                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // finish the future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// Over TCP: the stream half starts initialization, which then waits in `accept`. A sink send issued meanwhile
    /// waits on the same initialization, and both halves work once a client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the stream.next() call.

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // try to be really sure that the above sink_task is waiting on the same future to be resolved.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // trigger further initialization of the future.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!stream_task.is_finished()); // We haven't sent anything yet, so the stream should definitely not be resolved now.

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// Over TCP: the sink half starts initialization, which then waits in `accept`. A stream poll issued meanwhile
    /// waits on the same initialization. The sink's buffered item is delivered once a client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // initialization is at least partially started now.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                initialization_rx.await.unwrap(); // ensure that the runtime starts driving initialization via the sink send.

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // try to be really sure that both tasks are waiting on the same future to be resolved.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!sink_task.is_finished(), "We haven't sent anything yet, so the sink should definitely not be resolved now.");

                // trigger further initialization of the future.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // try to be really sure that the effects of the above initialization completing are propagated.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                assert!(sink_task.is_finished()); // Sink should have sent its item.

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                // Now actually send an item so that the stream will wake up and have an item ready to pull from it.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}