scuffle_bytes_util/
bit_write.rs

1use std::io;
2
/// Bit-level adapter over a byte-oriented writer.
///
/// Bits are collected most-significant-bit first into a one-byte buffer; every
/// time that buffer fills up, the byte is handed to the wrapped writer.
#[must_use]
#[derive(Debug)]
pub struct BitWriter<W> {
    /// Number of bits already stored in `current_byte`.
    bit_pos: u8,
    /// Byte under construction, filled from bit 7 downwards.
    current_byte: u8,
    /// Destination that receives each completed byte.
    writer: W,
}
11
12impl<W: Default> Default for BitWriter<W> {
13    fn default() -> Self {
14        Self::new(W::default())
15    }
16}
17
18impl<W: io::Write> BitWriter<W> {
19    /// Writes a single bit to the stream
20    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
21        if bit {
22            self.current_byte |= 1 << (7 - self.bit_pos);
23        } else {
24            self.current_byte &= !(1 << (7 - self.bit_pos));
25        }
26
27        self.bit_pos += 1;
28
29        if self.bit_pos == 8 {
30            self.writer.write_all(&[self.current_byte])?;
31            self.current_byte = 0;
32            self.bit_pos = 0;
33        }
34
35        Ok(())
36    }
37
38    /// Writes a number of bits to the stream (the most significant bit is
39    /// written first)
40    pub fn write_bits(&mut self, bits: u64, count: u8) -> io::Result<()> {
41        let count = count.min(64);
42
43        if count != 64 && bits > (1 << count as u64) - 1 {
44            return Err(io::Error::new(io::ErrorKind::InvalidData, "bits too large to write"));
45        }
46
47        for i in 0..count {
48            let bit = (bits >> (count - i - 1)) & 1 == 1;
49            self.write_bit(bit)?;
50        }
51
52        Ok(())
53    }
54
55    /// Flushes the buffer and returns the underlying writer
56    /// This will also align the writer to the byte boundary
57    pub fn finish(mut self) -> io::Result<W> {
58        self.align()?;
59        Ok(self.writer)
60    }
61
62    /// Aligns the writer to the byte boundary
63    pub fn align(&mut self) -> io::Result<()> {
64        if !self.is_aligned() {
65            self.write_bits(0, 8 - self.bit_pos())?;
66        }
67
68        Ok(())
69    }
70}
71
72impl<W> BitWriter<W> {
73    /// Creates a new BitWriter from a writer
74    pub const fn new(writer: W) -> Self {
75        Self {
76            bit_pos: 0,
77            current_byte: 0,
78            writer,
79        }
80    }
81
82    /// Returns the current bit position (0-7)
83    #[inline(always)]
84    #[must_use]
85    pub const fn bit_pos(&self) -> u8 {
86        self.bit_pos % 8
87    }
88
89    /// Checks if the writer is aligned to the byte boundary
90    #[inline(always)]
91    #[must_use]
92    pub const fn is_aligned(&self) -> bool {
93        self.bit_pos % 8 == 0
94    }
95
96    /// Returns a reference to the underlying writer
97    #[inline(always)]
98    #[must_use]
99    pub const fn get_ref(&self) -> &W {
100        &self.writer
101    }
102}
103
104impl<W: io::Write> io::Write for BitWriter<W> {
105    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
106        if self.is_aligned() {
107            return self.writer.write(buf);
108        }
109
110        for byte in buf {
111            self.write_bits(*byte as u64, 8)?;
112        }
113
114        Ok(buf.len())
115    }
116
117    fn flush(&mut self) -> io::Result<()> {
118        self.writer.flush()
119    }
120}
121
/// Unit tests for `BitWriter`: bit packing order, alignment, pass-through of
/// completed bytes to the underlying writer, and the `io::Write` integration.
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::Write;

    use super::*;

    #[test]
    fn test_bit_writer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        bit_writer.align().unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());

        bit_writer.write_bits(0b101010101010, 12).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bit(true).unwrap();
        assert_eq!(bit_writer.bit_pos(), 1);
        assert!(!bit_writer.is_aligned());

        // A value that needs five bits must be rejected for a four-bit write
        let err = bit_writer.write_bits(0b10000, 4).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "bits too large to write");

        // `finish` pads the single pending bit with zeros
        assert_eq!(
            bit_writer.finish().unwrap(),
            vec![0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b10000000]
        );
    }

    #[test]
    fn test_flush_buffer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(
            bit_writer.get_ref(),
            &[0b11111111, 0b00001010],
            "underlying writer should have two bytes"
        );
    }

    #[test]
    fn test_write_bits_count_bounds() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        // Writing zero bits of the value zero is a no-op
        bit_writer.write_bits(0, 0).unwrap();
        assert!(bit_writer.is_aligned());
        assert!(bit_writer.get_ref().is_empty());

        // No non-zero value fits in zero bits
        let err = bit_writer.write_bits(1, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(bit_writer.get_ref().is_empty());

        // The full 64-bit width accepts every value
        bit_writer.write_bits(u64::MAX, 64).unwrap();
        assert_eq!(bit_writer.get_ref().as_slice(), &[0xFFu8; 8]);

        // Counts above 64 are clamped to 64
        bit_writer.write_bits(0x0102_0304_0506_0708, 200).unwrap();
        assert_eq!(&bit_writer.get_ref()[8..], &[1u8, 2, 3, 4, 5, 6, 7, 8]);
        assert!(bit_writer.is_aligned());
    }

    #[test]
    fn test_io_write() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        // A completed byte is forwarded to the underlying writer right away
        assert_eq!(bit_writer.get_ref().as_slice(), &[255]);

        bit_writer.write_all(&[1, 2, 3]).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        // since we did an io::Write on an aligned bit_writer
        // we should have written directly to the underlying
        // writer
        assert_eq!(bit_writer.get_ref().as_slice(), &[255, 1, 2, 3]);

        bit_writer.write_bit(true).unwrap();

        bit_writer.write_bits(0b1010, 4).unwrap();

        bit_writer
            .write_all(&[0b11111111, 0b00000000, 0b11111111, 0b00000000])
            .unwrap();

        // Since the writer was not aligned, the bytes were re-packed behind the
        // five pending bits; the last five bits are still held back
        assert_eq!(
            bit_writer.get_ref().as_slice(),
            &[255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000]
        );

        bit_writer.finish().unwrap();

        // `finish` pads the held-back bits into one final byte
        assert_eq!(
            inner,
            vec![255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000, 0b00000000]
        );
    }

    #[test]
    fn test_flush() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b10100000, 8).unwrap();

        bit_writer.flush().unwrap();

        assert_eq!(bit_writer.get_ref().as_slice(), &[0b10100000]);
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
    }
}