1use core::future::Future;
4use core::marker::PhantomData;
5use core::pin::Pin;
6use core::task::{Context, Poll, Waker};
7use std::sync::Arc;
8use std::task::Wake;
9
10use futures_util::task::AtomicWaker;
11use futures_util::{Sink, Stream, ready};
12use pin_project_lite::pin_project;
13
14#[derive(Default)]
15struct DualWaker {
16 sink: AtomicWaker,
17 stream: AtomicWaker,
18}
19
20impl DualWaker {
21 fn new() -> (Arc<Self>, Waker) {
22 let dual_waker = Arc::new(Self::default());
23 let waker = Waker::from(dual_waker.clone());
24 (dual_waker, waker)
25 }
26}
27
28impl Wake for DualWaker {
29 fn wake(self: Arc<Self>) {
30 self.wake_by_ref();
31 }
32
33 fn wake_by_ref(self: &Arc<Self>) {
34 self.sink.wake();
35 self.stream.wake();
36 }
37}
38
39pin_project! {
40 #[project = SharedStateProj]
41 enum SharedState<Fut, St, Si, Item> {
42 Uninit {
43 future: Option<Fut>,
45 },
46 Thunkulating {
47 #[pin]
48 future: Fut,
49 item: Option<Item>,
50 dual_waker_state: Arc<DualWaker>,
51 dual_waker_waker: Waker,
52 },
53 Done {
54 #[pin]
55 stream: St,
56 #[pin]
57 sink: Si,
58 buf: Option<Item>,
59 },
60 }
61}
62
63pin_project! {
64 pub struct LazySinkSource<Fut, St, Si, Item, Error> {
68 #[pin]
69 state: SharedState<Fut, St, Si, Item>,
70 _phantom: PhantomData<Error>,
71 }
72}
73
74impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
75 pub fn new(future: Fut) -> Self {
77 Self {
78 state: SharedState::Uninit {
79 future: Some(future),
80 },
81 _phantom: PhantomData,
82 }
83 }
84}
85
86impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error>
87where
88 Fut: Future<Output = Result<(St, Si), Error>>,
89 St: Stream,
90 Si: Sink<Item>,
91 Error: From<Si::Error>,
92{
93 fn poll_sink_op(
94 self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 sink_op: impl FnOnce(Pin<&mut Si>, &mut Context<'_>) -> Poll<Result<(), Si::Error>>,
97 ) -> Poll<Result<(), Error>> {
98 let mut this = self.project();
99
100 if let SharedStateProj::Uninit { .. } = this.state.as_mut().project() {
101 return Poll::Ready(Ok(()));
102 }
103
104 if let SharedStateProj::Thunkulating {
105 future,
106 item,
107 dual_waker_state,
108 dual_waker_waker,
109 } = this.state.as_mut().project()
110 {
111 dual_waker_state.sink.register(cx.waker());
112
113 let mut dual_context = Context::from_waker(dual_waker_waker);
114
115 match future.poll(&mut dual_context) {
116 Poll::Ready(Ok((stream, sink))) => {
117 let buf = item.take();
118 this.state
119 .as_mut()
120 .set(SharedState::Done { stream, sink, buf });
121 }
122 Poll::Ready(Err(e)) => {
123 return Poll::Ready(Err(e));
124 }
125 Poll::Pending => {
126 return Poll::Pending;
127 }
128 }
129 }
130
131 if let SharedStateProj::Done { mut sink, buf, .. } = this.state.as_mut().project() {
132 if buf.is_some() {
133 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
134 sink.as_mut().start_send(buf.take().unwrap())?;
135 }
136 return (sink_op)(sink, cx).map_err(From::from);
137 }
138
139 panic!("LazySinkSource in invalid state.");
140 }
141}
142
143impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
144where
145 Fut: Future<Output = Result<(St, Si), Error>>,
146 St: Stream,
147 Si: Sink<Item>,
148 Error: From<Si::Error>,
149{
150 type Error = Error;
151
152 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
153 self.poll_sink_op(cx, Sink::poll_ready)
154 }
155
156 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
157 let mut this = self.project();
158
159 if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
160 let future = future.take().unwrap();
161 let (dual_waker_state, dual_waker_waker) = DualWaker::new();
162 this.state.as_mut().set(SharedState::Thunkulating {
163 future,
164 item: Some(item),
165 dual_waker_state,
166 dual_waker_waker,
167 });
168 return Ok(());
169 }
170
171 if let SharedStateProj::Thunkulating { .. } = this.state.as_mut().project() {
172 panic!("LazySinkSource not ready.");
173 }
174
175 if let SharedStateProj::Done { sink, buf, .. } = this.state.as_mut().project() {
176 debug_assert!(buf.is_none());
177 return sink.start_send(item).map_err(From::from);
178 }
179
180 panic!("LazySinkSource not ready.");
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 self.poll_sink_op(cx, Sink::poll_flush)
185 }
186
187 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188 self.poll_sink_op(cx, Sink::poll_close)
189 }
190}
191
192impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
193where
194 Fut: Future<Output = Result<(St, Si), Error>>,
195 St: Stream,
196 Si: Sink<Item>,
197{
198 type Item = St::Item;
199
200 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201 let mut this = self.project();
202
203 if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
204 let future = future.take().unwrap();
205 let (dual_waker_state, dual_waker_waker) = DualWaker::new();
206 this.state.as_mut().set(SharedState::Thunkulating {
207 future,
208 item: None,
209 dual_waker_state,
210 dual_waker_waker,
211 });
212 }
213
214 if let SharedStateProj::Thunkulating {
215 future,
216 item,
217 dual_waker_state,
218 dual_waker_waker,
219 } = this.state.as_mut().project()
220 {
221 dual_waker_state.stream.register(cx.waker());
222
223 let mut new_context = Context::from_waker(dual_waker_waker);
224
225 match future.poll(&mut new_context) {
226 Poll::Ready(Ok((stream, sink))) => {
227 let buf = item.take();
228 this.state
229 .as_mut()
230 .set(SharedState::Done { stream, sink, buf });
231 }
232
233 Poll::Ready(Err(_)) => {
234 return Poll::Ready(None);
235 }
236
237 Poll::Pending => {
238 return Poll::Pending;
239 }
240 }
241 }
242
243 if let SharedStateProj::Done { stream, .. } = this.state.as_mut().project() {
244 return stream.poll_next(cx);
245 }
246
247 panic!("LazySinkSource in invalid state.");
248 }
249}
250
#[cfg(test)]
mod test {
    // Each test splits a `LazySinkSource` into its two halves and checks that
    // initialization works whichever half polls the initialization future first.
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    // The stream task runs first and starts initialization; the sink task only begins
    // sending after the stream task has signalled that it is running.
    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gates the initialization future until the test releases it below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                // Initialization yields an in-memory channel: items sent into the sink
                // come back out of the stream.
                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    // Wait for the stream task so that it is the half that starts
                    // initialization.
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    // Mirror of `stream_drives_initialization`: here the sink task runs first and
    // starts initialization through its first send.
    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gates the initialization future until the test releases it below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    // Wait for the sink task so that it is the half that starts
                    // initialization.
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    // Over a real TCP connection: the stream task starts initialization, which blocks
    // on `accept` until a client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled by the initialization future as soon as it starts running.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                // Initialization accepts one connection and frames both directions.
                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Only the stream task exists when initialization starts, so its poll
                // is the one that starts it.
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                initialization_rx.await.unwrap();
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Connected, but the client has not sent a frame yet.
                assert!(!stream_task.is_finished());
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    // Over a real TCP connection: the sink task starts initialization through its
    // send, which cannot complete until a client connects.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Signalled by the initialization future as soon as it starts running.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                // Initialization accepts one connection and frames both directions.
                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Only the sink task exists when initialization starts, so its send is
                // what starts it.
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                initialization_rx.await.unwrap();
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(!sink_task.is_finished(), "We haven't sent anything yet, so the sink should definitely not be resolved now.");

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Give the local tasks a chance to accept the connection and let the
                // sink task finish its send.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                assert!(sink_task.is_finished());
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}
473}