scuffle_bytes_util/
bit_read.rs

1use std::io;
2
/// A reader that reads individual bits from a stream
///
/// Bits are handed out most-significant bit first within each byte. One byte
/// at a time is pulled from the wrapped reader and cached until all eight of
/// its bits have been consumed.
#[derive(Debug)]
#[must_use]
pub struct BitReader<T> {
    /// The wrapped value that bytes are read from.
    data: T,
    /// Index (0-7, counted from the most significant bit) of the next bit to
    /// hand out from `current_byte`; 0 means the reader is byte aligned.
    bit_pos: u8,
    /// The most recently fetched byte, whose bits are served by `read_bit`.
    current_byte: u8,
}
11
12impl<T> BitReader<T> {
13    /// Create a new BitReader from a reader
14    pub const fn new(data: T) -> Self {
15        Self {
16            data,
17            bit_pos: 0,
18            current_byte: 0,
19        }
20    }
21}
22
23impl<T: io::Read> BitReader<T> {
24    /// Reads a single bit
25    pub fn read_bit(&mut self) -> io::Result<bool> {
26        if self.is_aligned() {
27            self.update_byte()?;
28        }
29
30        let bit = (self.current_byte >> (7 - self.bit_pos)) & 1;
31
32        self.bit_pos = (self.bit_pos + 1) % 8;
33
34        Ok(bit == 1)
35    }
36
37    fn update_byte(&mut self) -> io::Result<()> {
38        let mut buf = [0];
39        self.data.read_exact(&mut buf)?;
40        self.current_byte = buf[0];
41        Ok(())
42    }
43
44    /// Reads multiple bits
45    pub fn read_bits(&mut self, count: u8) -> io::Result<u64> {
46        let count = count.min(64);
47
48        let mut bits = 0;
49        for _ in 0..count {
50            let bit = self.read_bit()?;
51            bits <<= 1;
52            bits |= if bit { 1 } else { 0 };
53        }
54
55        Ok(bits)
56    }
57
58    /// Aligns the reader to the next byte boundary
59    #[inline(always)]
60    pub fn align(&mut self) -> io::Result<()> {
61        // This has the effect of making the next read_bit call read the next byte
62        // and is equivalent to calling read_bits(8 - self.bit_pos)
63        self.bit_pos = 0;
64        Ok(())
65    }
66}
67
68impl<T> BitReader<T> {
69    /// Returns the underlying reader
70    #[inline(always)]
71    #[must_use]
72    pub fn into_inner(self) -> T {
73        self.data
74    }
75
76    /// Returns a reference to the underlying reader
77    #[inline(always)]
78    #[must_use]
79    pub const fn get_ref(&self) -> &T {
80        &self.data
81    }
82
83    /// Returns the current bit position (0-7)
84    #[inline(always)]
85    #[must_use]
86    pub const fn bit_pos(&self) -> u8 {
87        self.bit_pos
88    }
89
90    /// Checks if the reader is aligned to the byte boundary
91    #[inline(always)]
92    #[must_use]
93    pub const fn is_aligned(&self) -> bool {
94        self.bit_pos == 0
95    }
96}
97
98impl<T: io::Read> io::Read for BitReader<T> {
99    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
100        // If we are aligned this will be essentially the same as just reading directly
101        // from the underlying reader.
102        if self.is_aligned() {
103            return self.data.read(buf);
104        }
105
106        // However if we are not aligned we need to shift all the bits into the correct
107        // position. Think of it like this
108        //
109        // 011|0110 0000000 11111111
110        //    ^---- This is the next bit to read (0) i show it with a | to make it clear
111        // the resulting read should be [01100000, 00001111]
112        // Byte 1: first 4 bits are from the first byte and 4 bits from the second byte
113        // Byte 2: first 4 bits are from the second byte and the first 4 bits from the
114        // third byte
115
116        for byte in buf.iter_mut() {
117            *byte = 0;
118            for _ in 0..8 {
119                let bit = self.read_bit()?;
120                *byte <<= 1;
121                *byte |= bit as u8;
122            }
123        }
124
125        Ok(buf.len())
126    }
127}
128
129impl<B: AsRef<[u8]>> BitReader<std::io::Cursor<B>> {
130    /// Creates a new BitReader from a slice
131    pub const fn new_from_slice(data: B) -> Self {
132        Self::new(std::io::Cursor::new(data))
133    }
134}
135
136impl<W: io::Seek + io::Read> BitReader<W> {
137    /// Returns the current stream position in bits
138    pub fn bit_stream_position(&mut self) -> io::Result<u64> {
139        let pos = self.data.stream_position()?;
140        Ok(pos * 8 + if self.is_aligned() { 8 } else { self.bit_pos as u64 } - 8)
141    }
142
143    /// Seeks a number of bits forward or backward
144    /// Returns the new stream position in bits
145    pub fn seek_bits(&mut self, count: i64) -> io::Result<u64> {
146        // We dont need to do any work here.
147        if count == 0 {
148            return self.bit_stream_position();
149        }
150
151        let count = self.bit_pos as i64 + count;
152
153        // Otherwise we need to do some work to move the bit position to the desired
154        // position
155
156        // the number of bits we should move by
157        let bit_move = count % 8;
158        // the number of bytes we should move by
159        let mut byte_move = count / 8;
160
161        // if we are not aligned we need to move back 1 byte (since we have partially
162        // read the current byte)
163        if !self.is_aligned() {
164            byte_move -= 1;
165        }
166
167        // if we are seeking back we need to move back 1 byte (since we are going to
168        // move forward a byte)
169        if bit_move < 0 {
170            byte_move -= 1;
171        }
172
173        let mut pos = self.data.seek(io::SeekFrom::Current(byte_move))? * 8;
174
175        // This works for both positive and negative bit_move
176        // If bit_move is -3 then we want to move 3 bits (8 + (-3)) % 8 = 5
177        // If bit_move is 3 then we want to move 3 bits (8 + 3) % 8 = 3
178        // Modulo arithmetic is cool!
179        self.bit_pos = ((8 + bit_move) % 8) as u8;
180
181        // If we are not unaligned we need to update the byte because we have a partial
182        // read, but the byte has not been read yet (its only read when the bit
183        // position is 0 on the next call to read_bit)
184        if !self.is_aligned() {
185            self.update_byte()?;
186            pos += self.bit_pos as u64;
187        }
188
189        Ok(pos)
190    }
191}
192
193impl<T: io::Seek + io::Read> io::Seek for BitReader<T> {
194    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
195        match pos {
196            // Otherwise if we are doing a relative seek we likely do care about the bit position
197            // So we call the seek_bits function to handle seeking the offset in bits
198            io::SeekFrom::Current(offset) if !self.is_aligned() => {
199                // This returns the new stream position in bytes rounded up to the nearest byte
200                Ok(self.seek_bits(offset * 8)?.div_ceil(8))
201            }
202            // Otherwise we are seeking to a position relative to the start or end of the stream, we dont care about the bit
203            // position Or the bit position is already 0 so we can just seek to the new position
204            _ => {
205                self.bit_pos = 0;
206                self.data.seek(pos)
207            }
208        }
209    }
210}
211
// Unit tests for `BitReader`. Most of them read from the same 32-bit pattern,
// whose big-endian bytes are 10101010 11001100 11110001 01010101.
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::{Read, Seek};

    use super::*;

    /// Reading bit by bit yields the bits of the input most significant bit
    /// first, and fails once all 32 bits are consumed.
    #[test]
    fn test_bit_reader() {
        let binary = 0b10101010110011001111000101010101u32;

        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        for i in 0..32 {
            assert_eq!(
                reader.read_bit().unwrap(),
                (binary & (1 << (31 - i))) != 0,
                "bit {i} is not correct"
            );
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// `read_bits` returns consecutive groups of bits, including groups that
    /// straddle byte boundaries.
    #[test]
    fn test_bit_reader_read_bits() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        // (number of bits to read, expected value); the counts add up to 32.
        let cases = [
            (3, 0b101),
            (4, 0b0101),
            (3, 0b011),
            (3, 0b001),
            (3, 0b100),
            (3, 0b111),
            (5, 0b10001),
            (1, 0b0),
            (7, 0b1010101),
        ];

        for (i, (count, expected)) in cases.into_iter().enumerate() {
            assert_eq!(
                reader.read_bits(count).ok(),
                Some(expected),
                "reading {count} bits ({i}) are not correct"
            );
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// `align` skips the rest of the current byte: reading one bit and then
    /// aligning consumes exactly one byte per iteration.
    #[test]
    fn test_bit_reader_align() {
        // Every byte only has its most significant bit set.
        let mut reader = BitReader::new_from_slice([0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000]);

        for i in 0..6 {
            let pos = reader.data.stream_position().unwrap();
            assert_eq!(pos, i, "stream pos");
            assert_eq!(reader.bit_pos(), 0, "bit pos");
            assert!(reader.read_bit().unwrap(), "bit {i} is not correct");
            reader.align().unwrap();
            let pos = reader.data.stream_position().unwrap();
            assert_eq!(pos, i + 1, "stream pos");
            assert_eq!(reader.bit_pos(), 0, "bit pos");
        }

        assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
    }

    /// The `io::Read` implementation works both on and off byte boundaries.
    #[test]
    fn test_bit_reader_io_read() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());

        // Aligned read (calls the underlying read directly (very fast))
        let mut buf = [0; 1];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0b10101010]);

        // Unaligned read
        // After consuming the top bit of the second byte, the next byte read is
        // made of its remaining 7 bits followed by the top bit of the third byte.
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        let mut buf = [0; 1];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0b10011001]);
    }

    /// `seek_bits` moves by an arbitrary number of bits in either direction
    /// and reports the resulting position in bits.
    #[test]
    fn test_bit_reader_seek() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());

        // Forward into the middle of the first byte: that byte gets loaded, so
        // the underlying stream is already at byte 1.
        assert_eq!(reader.seek_bits(5).unwrap(), 5);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.bit_pos(), 5);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 6);

        // A zero-length seek only reports the current position.
        assert_eq!(reader.seek_bits(0).unwrap(), 6);

        // Forward onto a byte boundary: no byte is loaded until the next read.
        assert_eq!(reader.seek_bits(10).unwrap(), 16);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 3);

        // Backward by exactly one byte keeps the bit offset within the byte.
        assert_eq!(reader.seek_bits(-8).unwrap(), 9);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Backward onto a byte boundary.
        assert_eq!(reader.seek_bits(-2).unwrap(), 8);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
    }

    /// The `io::Seek` implementation seeks in bytes; relative seeks keep the
    /// bit offset, absolute seeks reset it.
    #[test]
    fn test_bit_reader_io_seek() {
        let binary = 0b10101010110011001111000101010101u32;
        let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
        // Absolute seek on an aligned reader.
        assert_eq!(reader.seek(io::SeekFrom::Start(1)).unwrap(), 1);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.data.stream_position().unwrap(), 1);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Relative seek forward while unaligned: the bit offset is preserved.
        assert_eq!(reader.seek(io::SeekFrom::Current(1)).unwrap(), 3);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 3);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 3);

        // Relative seek backward while unaligned: the bit offset is preserved.
        assert_eq!(reader.seek(io::SeekFrom::Current(-1)).unwrap(), 2);
        assert_eq!(reader.bit_pos(), 2);
        assert_eq!(reader.data.stream_position().unwrap(), 2);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 3);
        assert_eq!(reader.data.stream_position().unwrap(), 2);

        // Seeking relative to the end resets the bit offset.
        assert_eq!(reader.seek(io::SeekFrom::End(-1)).unwrap(), 3);
        assert_eq!(reader.bit_pos(), 0);
        assert_eq!(reader.data.stream_position().unwrap(), 3);
        assert_eq!(reader.read_bits(1).unwrap(), 0b0);
        assert_eq!(reader.bit_pos(), 1);
        assert_eq!(reader.data.stream_position().unwrap(), 4);
    }
}