1use std::sync::{Arc, Mutex};
2
3use bytes::{Buf, BytesMut};
4
/// Adapter that exposes a channel endpoint as a byte stream.
///
/// When `T` implements [`ChannelCompatRecv`] this type implements [`std::io::Read`]; when `T`
/// implements [`ChannelCompatSend`] it implements [`std::io::Write`].
///
/// Cloning shares the wrapped endpoint (the `Arc` is cloned) but copies the pending-byte
/// buffer, so two clones reading from the same receiver each keep their own copy of any
/// leftover bytes that were buffered at the time of the clone.
#[derive(Debug, Clone)]
pub struct ChannelCompat<T: Send> {
    /// The wrapped channel endpoint, behind a mutex so that clones use the same handle.
    inner: Arc<Mutex<T>>,
    /// Bytes of an already-received message that did not fit into the caller's read buffer.
    buffer: BytesMut,
}
16
17impl<T: Send> ChannelCompat<T> {
18 pub fn new(inner: T) -> Self {
20 Self {
21 inner: Arc::new(Mutex::new(inner)),
22 buffer: BytesMut::new(),
23 }
24 }
25}
26
27pub trait ChannelCompatRecv: Send {
29 type Data: AsRef<[u8]>;
31
32 fn channel_recv(&mut self) -> Option<Self::Data>;
34
35 fn try_channel_recv(&mut self) -> Option<Self::Data>;
37
38 fn into_compat(self) -> ChannelCompat<Self>
40 where
41 Self: Sized,
42 {
43 ChannelCompat::new(self)
44 }
45}
46
47pub trait ChannelCompatSend: Send {
49 type Data: From<Vec<u8>>;
51
52 fn channel_send(&mut self, data: Self::Data) -> bool;
54
55 fn into_compat(self) -> ChannelCompat<Self>
57 where
58 Self: Sized,
59 {
60 ChannelCompat::new(self)
61 }
62}
63
#[cfg(feature = "tokio-channel")]
impl<D: AsRef<[u8]> + Send> ChannelCompatRecv for tokio::sync::mpsc::Receiver<D> {
    type Data = D;

    /// Waits on the current thread via tokio's `blocking_recv`.
    ///
    /// NOTE: tokio documents that `blocking_*` methods must not be called from inside an
    /// async runtime.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        tokio::sync::mpsc::Receiver::blocking_recv(self)
    }

    /// Polls without waiting; any `try_recv` error (empty or disconnected) becomes `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        Result::ok(self.try_recv())
    }
}
76
#[cfg(feature = "tokio-channel")]
impl<D: From<Vec<u8>> + Send> ChannelCompatSend for tokio::sync::mpsc::Sender<D> {
    type Data = D;

    /// Sends on the current thread via tokio's `blocking_send`; `false` if tokio reports an error.
    ///
    /// NOTE: tokio documents that `blocking_*` methods must not be called from inside an
    /// async runtime.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = tokio::sync::mpsc::Sender::blocking_send(self, data);
        outcome.is_ok()
    }
}
85
#[cfg(feature = "tokio-channel")]
impl<D: AsRef<[u8]> + Send> ChannelCompatRecv for tokio::sync::mpsc::UnboundedReceiver<D> {
    type Data = D;

    /// Waits on the current thread via tokio's `blocking_recv`.
    ///
    /// NOTE: tokio documents that `blocking_*` methods must not be called from inside an
    /// async runtime.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        tokio::sync::mpsc::UnboundedReceiver::blocking_recv(self)
    }

    /// Polls without waiting; any `try_recv` error (empty or disconnected) becomes `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        Result::ok(self.try_recv())
    }
}
98
#[cfg(feature = "tokio-channel")]
impl<D: From<Vec<u8>> + Send> ChannelCompatSend for tokio::sync::mpsc::UnboundedSender<D> {
    type Data = D;

    /// Queues `data` with `send`; `false` if tokio reports the value could not be delivered.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = tokio::sync::mpsc::UnboundedSender::send(self, data);
        outcome.is_ok()
    }
}
107
#[cfg(feature = "tokio-channel")]
impl<D: AsRef<[u8]> + Clone + Send> ChannelCompatRecv for tokio::sync::broadcast::Receiver<D> {
    type Data = D;

    /// Blocks the current thread until a value is received or the channel is closed.
    ///
    /// A broadcast receiver that falls behind gets `RecvError::Lagged`: the overwritten values
    /// are lost, but the channel is still open. Mapping that error to `None` would make
    /// `ChannelCompat` report a premature end of stream, so this keeps receiving from the
    /// oldest value the channel still retains. Only a closed channel yields `None`.
    ///
    /// NOTE: tokio documents that `blocking_*` methods must not be called from inside an
    /// async runtime.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        loop {
            match self.blocking_recv() {
                Ok(data) => return Some(data),
                // The receiver has been fast-forwarded past the overwritten values; retry.
                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
                Err(_) => return None,
            }
        }
    }

    /// Returns a value only if one is immediately available.
    ///
    /// `TryRecvError::Lagged` is retried so the caller gets the oldest retained value rather
    /// than an empty result. `None` means nothing is queued right now or the channel is closed.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        loop {
            match self.try_recv() {
                Ok(data) => return Some(data),
                Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => continue,
                Err(_) => return None,
            }
        }
    }
}
120
#[cfg(feature = "tokio-channel")]
impl<D: From<Vec<u8>> + Clone + Send> ChannelCompatSend for tokio::sync::broadcast::Sender<D> {
    type Data = D;

    /// Broadcasts `data`; `false` if tokio reports it could not be sent (e.g. no receiver is
    /// subscribed).
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = tokio::sync::broadcast::Sender::send(self, data);
        outcome.is_ok()
    }
}
129
#[cfg(feature = "crossbeam-channel")]
impl<D: AsRef<[u8]> + Send> ChannelCompatRecv for crossbeam_channel::Receiver<D> {
    type Data = D;

    /// Blocks in `recv`; any receive error becomes `None`.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        let received = crossbeam_channel::Receiver::recv(self);
        received.ok()
    }

    /// Polls with `try_recv`; any error (empty or disconnected) becomes `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        let polled = crossbeam_channel::Receiver::try_recv(self);
        polled.ok()
    }
}
142
#[cfg(feature = "crossbeam-channel")]
impl<D: From<Vec<u8>> + Send> ChannelCompatSend for crossbeam_channel::Sender<D> {
    type Data = D;

    /// Sends `data`; `false` if crossbeam reports the message could not be delivered.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = crossbeam_channel::Sender::send(self, data);
        outcome.is_ok()
    }
}
151
152impl<D: AsRef<[u8]> + Send> ChannelCompatRecv for std::sync::mpsc::Receiver<D> {
153 type Data = D;
154
155 fn channel_recv(&mut self) -> Option<Self::Data> {
156 self.recv().ok()
157 }
158
159 fn try_channel_recv(&mut self) -> Option<Self::Data> {
160 self.try_recv().ok()
161 }
162}
163
164impl<D: From<Vec<u8>> + Send> ChannelCompatSend for std::sync::mpsc::Sender<D> {
165 type Data = D;
166
167 fn channel_send(&mut self, data: Self::Data) -> bool {
168 self.send(data).is_ok()
169 }
170}
171
172impl<D: From<Vec<u8>> + Send> ChannelCompatSend for std::sync::mpsc::SyncSender<D> {
173 type Data = D;
174
175 fn channel_send(&mut self, data: Self::Data) -> bool {
176 self.send(data).is_ok()
177 }
178}
179
180impl<T: ChannelCompatRecv> std::io::Read for ChannelCompat<T> {
181 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
182 if self.buffer.len() >= buf.len() {
183 buf.copy_from_slice(&self.buffer[..buf.len()]);
184 self.buffer.advance(buf.len());
185 return Ok(buf.len());
186 }
187
188 let mut inner = self.inner.lock().unwrap();
189
190 let mut total_read = 0;
191 if self.buffer.is_empty() {
192 let Some(data) = inner.channel_recv() else {
193 return Ok(0);
194 };
195
196 let data = data.as_ref();
197 let min = data.len().min(buf.len());
198
199 buf.copy_from_slice(&data[..min]);
200 self.buffer.extend_from_slice(&data[min..]);
201 total_read += min;
202 } else {
203 buf[..self.buffer.len()].copy_from_slice(&self.buffer);
204 total_read += self.buffer.len();
205 self.buffer.clear();
206 }
207
208 while let Some(Some(data)) = (total_read < buf.len()).then(|| inner.try_channel_recv()) {
209 let data = data.as_ref();
210 let min = data.len().min(buf.len() - total_read);
211 buf[total_read..total_read + min].copy_from_slice(&data[..min]);
212 self.buffer.extend_from_slice(&data[min..]);
213 total_read += min;
214 }
215
216 Ok(total_read)
217 }
218}
219
220impl<T: ChannelCompatSend> std::io::Write for ChannelCompat<T> {
221 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
222 if !self.inner.lock().unwrap().channel_send(buf.to_vec().into()) {
223 return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Unexpected EOF"));
224 }
225
226 Ok(buf.len())
227 }
228
229 fn flush(&mut self) -> std::io::Result<()> {
230 Ok(())
231 }
232}
233
/// Behavioural tests for `ChannelCompat`, run once per supported channel flavour.
///
/// Each `make_test!` invocation stamps out one `#[test]` per `#[variant(...)]`. Variants for
/// optional backends carry a `cfg(...)` so they are only built with the matching feature.
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::io::{Read, Write};

    use rand::Rng;
    use rand::distr::StandardUniform;

    use crate::io::channel::{ChannelCompat, ChannelCompatRecv, ChannelCompatSend};

    // Generates one `#[test] fn $name()` per `#[variant($name, $channel[, cfg(...)])]`.
    // `$channel` must evaluate to a `(sender, receiver)` pair, which is bound to the
    // `|$tx, $rx|` identifiers before `$body` runs; an optional `cfg(...)` becomes a
    // `#[cfg(...)]` attribute on the generated test.
    macro_rules! make_test {
        (
            $(
                $(
                    #[variant($name:ident, $channel:expr$(, cfg($($cfg_meta:meta)*))?)]
                )*
                |$tx:ident, $rx:ident| $body:block
            )*
        ) => {
            $(
                $(
                    #[test]
                    $(#[cfg($($cfg_meta)*)])?
                    fn $name() {
                        let ($tx, $rx) = $channel;
                        $body
                    }
                )*
            )*
        };
    }

    make_test! {
        // A single message exactly as large as the read buffer is returned by one `read`.
        #[variant(
            test_read_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = rx.into_compat();

            let mut rng = rand::rng();
            let data: Vec<u8> = (0..1000).map(|_| rng.sample::<u8, _>(StandardUniform)).collect();

            let mut tx = tx;
            let write_result = tx.channel_send(data.clone());
            assert!(write_result);

            let mut buffer = vec![0u8; 1000];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), data.len());

            assert_eq!(buffer, data);
        }
    }

    make_test! {
        // A single `write` arrives on the receiving side as exactly one identical message.
        #[variant(
            test_write_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_write_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_write_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut writer = tx.into_compat();

            let mut rng = rand::rng();
            let data: Vec<u8> = (0..1000).map(|_| rng.sample::<u8, _>(StandardUniform)).collect();

            let write_result = writer.write(&data);
            assert!(write_result.is_ok(), "Failed to write data to the channel");
            assert_eq!(write_result.unwrap(), data.len(), "Written byte count mismatch");

            let mut rx = rx;
            let read_result = rx.channel_recv();
            assert!(read_result.is_some(), "No data received from the channel");

            let received_data = read_result.unwrap();
            assert_eq!(received_data.len(), data.len(), "Received byte count mismatch");

            assert_eq!(
                received_data, data,
                "Mismatch between written data and received data"
            );
        }
    }

    make_test! {
        // A message larger than the read buffer is delivered across consecutive reads.
        #[variant(
            test_read_smaller_buffer_than_data_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_smaller_buffer_than_data_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let data = b"PartialReadTest".to_vec();
            let mut tx = tx;
            let send_result = tx.channel_send(data);
            assert!(send_result);

            let mut buffer = vec![0u8; 7];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), buffer.len());
            assert_eq!(&buffer, b"Partial");

            let mut buffer = vec![0u8; 8];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), buffer.len());
            assert_eq!(&buffer, b"ReadTest");
        }
    }

    make_test! {
        // Reading after every sender is gone reports end of stream (`Ok(0)`).
        #[variant(
            test_read_no_data_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_no_data_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_no_data_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);

            drop(tx);
            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);

            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), 0);
        }
    }

    make_test! {
        // Leftover bytes are returned on their own when nothing else is queued, and a later
        // message is then read from its start.
        #[variant(
            test_read_else_case_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_else_case_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_else_case_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let mut tx = tx;

            let data1 = b"FirstChunk".to_vec();
            let write_result1 = tx.channel_send(data1);
            assert!(write_result1, "Failed to send data1");

            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the first chunk");
            let bytes_read = read_result.unwrap();
            assert_eq!(bytes_read, buffer.len(), "Mismatch in first chunk read size");
            assert_eq!(&buffer, b"First", "Buffer content mismatch for first part of FirstChunk");

            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the next 10 bytes");
            let bytes_read = read_result.unwrap();

            assert_eq!(bytes_read, 5, "Unexpected read size for the next part");
            assert_eq!(&buffer[..bytes_read], b"Chunk", "Buffer content mismatch for combined reads");

            let data2 = b"SecondChunk".to_vec();
            let write_result2 = tx.channel_send(data2);
            assert!(write_result2, "Failed to send data2");

            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read leftover data from data2");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No leftover data from data2 was available");
            // Pin the exact bytes: the buffer was drained, so this read starts at data2.
            assert_eq!(bytes_read, buffer.len(), "Mismatch in second message read size");
            assert_eq!(&buffer, b"Secon", "Buffer content mismatch for first part of SecondChunk");
        }
    }

    make_test! {
        // Leftover bytes are topped up with an already-queued message in the same `read`.
        #[variant(
            test_read_while_case_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_while_case_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_while_case_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let mut tx = tx;

            let data1 = b"FirstChunk".to_vec();
            let write_result1 = tx.channel_send(data1);
            assert!(write_result1, "Failed to send data1");

            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the first chunk");
            let bytes_read = read_result.unwrap();
            assert_eq!(bytes_read, buffer.len(), "Mismatch in first chunk read size");
            assert_eq!(&buffer, b"First", "Buffer content mismatch for first part of FirstChunk");

            let data2 = b"SecondChunk".to_vec();
            let write_result2 = tx.channel_send(data2);
            assert!(write_result2, "Failed to send data2");

            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the next chunk of data");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No data was read");
            assert_eq!(&buffer[..bytes_read], b"ChunkSecon", "Buffer content mismatch");

            let mut buffer = vec![0u8; 6];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read remaining data");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No additional data was read");
            assert_eq!(&buffer[..bytes_read], b"dChunk", "Buffer content mismatch for remaining data");
        }
    }

    make_test! {
        // Writing after the receiving side is gone fails with `UnexpectedEof`.
        #[variant(
            test_write_eof_error_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_write_eof_error_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_write_eof_error_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut writer = ChannelCompat::new(tx);

            drop(rx);

            let data = vec![42u8; 100];
            let write_result = writer.write(&data);
            assert!(write_result.is_err());
            assert_eq!(write_result.unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    make_test! {
        // `flush` always succeeds because writes are forwarded immediately.
        #[variant(
            test_flush_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_flush_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_flush_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, _rx| {
            let mut writer = ChannelCompat::new(tx);

            let flush_result = writer.flush();
            assert!(flush_result.is_ok());
        }
    }
}