scuffle_bytes_util/
bytes_cursor.rs

1use std::io;
2
3use bytes::Bytes;
4
5/// A cursor for reading bytes.
6///
7/// This cursor is a [`io::Cursor`] where the underlying type is a [`Bytes`] object
8/// which enables zero copy decoding.
9pub type BytesCursor = io::Cursor<Bytes>;
10
11/// A helper trait to implement zero copy reads on a [`BytesCursor`] type.
12///
13/// Allowing for zero copy reads from a [`BytesCursor`] type.
14pub trait BytesCursorExt {
15    /// Extracts the remaining bytes from the cursor.
16    ///
17    /// This does not do a copy of the bytes, and is O(1) time.
18    ///
19    /// This is the same as `BytesCursor::extract_bytes(self.remaining())`.
20    ///
21    /// This is equivalent if you were to read the remaining data into a new
22    /// buffer, however this is more efficient as it does not copy the
23    /// bytes.
24    fn extract_remaining(&mut self) -> Bytes;
25
26    /// Extracts bytes from the cursor.
27    ///
28    /// This does not do a copy of the bytes, and is O(1) time.
29    /// Returns an error if the size is greater than the remaining bytes.
30    ///
31    /// This is equivalent if you were to read the remaining data into a new
32    /// buffer, however this is more efficient as it does not copy the
33    /// bytes.
34    fn extract_bytes(&mut self, size: usize) -> io::Result<Bytes>;
35}
36
/// Returns the number of unread bytes in `cursor`.
///
/// This is the distance from the cursor's position to the end of its
/// underlying buffer, or `0` when the position is at or past the end (a
/// position past the end can be set with [`io::Cursor::set_position`]).
fn remaining<T: AsRef<[u8]>>(cursor: &io::Cursor<T>) -> usize {
    let len = cursor.get_ref().as_ref().len();

    // `position() as usize` would silently truncate a large `u64` position on
    // 32-bit targets and report bytes that are not actually available. A
    // position that does not fit in `usize` is necessarily past the end.
    match usize::try_from(cursor.position()) {
        Ok(position) => len.saturating_sub(position),
        Err(_) => 0,
    }
}
40
41impl BytesCursorExt for BytesCursor {
42    fn extract_remaining(&mut self) -> Bytes {
43        // We don't really care if we fail here since the desired behavior is
44        // to return all bytes remaining in the cursor. If we fail its because
45        // there are not enough bytes left in the cursor to read.
46        self.extract_bytes(remaining(self)).unwrap_or_default()
47    }
48
49    fn extract_bytes(&mut self, size: usize) -> io::Result<Bytes> {
50        // If the size is zero we can just return an empty bytes slice.
51        if size == 0 {
52            return Ok(Bytes::new());
53        }
54
55        // If the size is greater than the remaining bytes we can just return an
56        // error.
57        if size > remaining(self) {
58            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
59        }
60
61        let position = self.position() as usize;
62
63        // We slice bytes here which is a O(1) operation as it only modifies a few
64        // reference counters and does not copy the memory.
65        let slice = self.get_ref().slice(position..position + size);
66
67        // We advance the cursor because we have now "read" the bytes.
68        self.set_position((position + size) as u64);
69
70        Ok(slice)
71    }
72}
73
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use super::*;

    /// Builds a fresh cursor over the five bytes `1..=5`, positioned at the start.
    fn sample_cursor() -> BytesCursor {
        io::Cursor::new(Bytes::from_static(&[1, 2, 3, 4, 5]))
    }

    #[test]
    fn test_bytes_cursor_extract_remaining() {
        let mut cursor = sample_cursor();
        assert_eq!(cursor.extract_remaining(), Bytes::from_static(&[1, 2, 3, 4, 5]));
    }

    #[test]
    fn test_bytes_cursor_extract_bytes() {
        let mut cursor = sample_cursor();

        // Two reads that together consume the whole buffer.
        let head = cursor.extract_bytes(3).unwrap();
        assert_eq!(head, Bytes::from_static(&[1, 2, 3]));
        assert_eq!(remaining(&cursor), 2);

        let tail = cursor.extract_bytes(2).unwrap();
        assert_eq!(tail, Bytes::from_static(&[4, 5]));
        assert_eq!(remaining(&cursor), 0);

        // Asking for more than is left fails with `UnexpectedEof`.
        let err = cursor.extract_bytes(1).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);

        // At the end, zero-length and "rest of buffer" reads yield empty buffers.
        let empty = cursor.extract_bytes(0).unwrap();
        assert_eq!(empty, Bytes::from_static(&[]));
        assert_eq!(remaining(&cursor), 0);

        let rest = cursor.extract_remaining();
        assert_eq!(rest, Bytes::from_static(&[]));
        assert_eq!(remaining(&cursor), 0);
    }

    #[test]
    fn seek_out_of_bounds() {
        let mut cursor = sample_cursor();
        cursor.set_position(10);
        assert_eq!(remaining(&cursor), 0);

        assert_eq!(cursor.extract_remaining(), Bytes::from_static(&[]));

        let past_end = cursor.extract_bytes(1);
        assert_eq!(past_end.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);

        let zero = cursor.extract_bytes(0);
        assert_eq!(zero.unwrap(), Bytes::from_static(&[]));
    }
}
124}