scuffle_bytes_util/
nal_emulation_prevention.rs

/// A wrapper around a [`std::io::Read`] or [`std::io::Write`] that automatically removes or inserts
/// [NAL](https://en.wikipedia.org/wiki/Network_Abstraction_Layer) emulation prevention bytes, when
/// reading or writing respectively.
///
/// When writing, a `0x03` byte is inserted after two consecutive `0x00` bytes whenever the next
/// byte is `0x00..=0x03`. When reading, a `0x03` byte that follows two or more consecutive `0x00`
/// bytes is dropped.
///
/// Defined by:
/// - ISO/IEC 14496-10 - 7.4.1.1
/// - ISO/IEC 23008-2 - 7.4.2.3
pub struct EmulationPreventionIo<I> {
    /// The wrapped reader or writer.
    inner: I,
    /// Number of consecutive `0x00` bytes that most recently passed through the wrapper.
    /// The value saturates instead of overflowing; only "at least two" matters.
    zero_count: u8,
}

impl<I> EmulationPreventionIo<I> {
    /// Creates a new `EmulationPreventionIo` wrapper around the given [`std::io::Read`] or [`std::io::Write`].
    /// This should be a buffered reader or writer because we will only read or write one byte at a time.
    /// If the underlying io is not buffered this will result in poor performance.
    pub fn new(inner: I) -> Self {
        Self { inner, zero_count: 0 }
    }
}

impl<I: std::io::Write> std::io::Write for EmulationPreventionIo<I> {
    /// Writes all of `buf` to the inner writer, inserting a `0x03` emulation prevention byte
    /// before any byte in `0x00..=0x03` that directly follows two `0x00` bytes.
    ///
    /// Returns `buf.len()` on success (the inserted bytes are not counted).
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by the inner writer's `write_all`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        for &byte in buf {
            if self.zero_count >= 2 && byte <= 0x03 {
                self.inner.write_all(&[0x03])?;
                self.zero_count = 0;
            }

            self.inner.write_all(&[byte])?;
            // The reset above keeps the counter at most 2 on this path; the saturating add
            // simply mirrors the read side.
            self.zero_count = if byte == 0x00 {
                self.zero_count.saturating_add(1)
            } else {
                0
            };
        }

        Ok(buf.len())
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}

impl<I: std::io::Read> std::io::Read for EmulationPreventionIo<I> {
    /// Reads from the inner reader one byte at a time until `buf` is full or the inner reader
    /// reports end of input, dropping every `0x03` byte that follows two or more `0x00` bytes.
    ///
    /// Returns the number of bytes stored in `buf` (dropped bytes are not counted).
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner reader, except [`std::io::ErrorKind::Interrupted`],
    /// which is retried so that bytes already placed in `buf` are not lost.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut read_size = 0;
        let mut one_byte = [0; 1];
        while read_size < buf.len() {
            let size = match self.inner.read(&mut one_byte) {
                Ok(size) => size,
                // Returning this error would discard the bytes already decoded into `buf`.
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            };
            if size == 0 {
                break;
            }

            let byte = one_byte[0];
            match byte {
                0x03 if self.zero_count >= 2 => {
                    // Emulation prevention byte: drop it and start counting zeros afresh.
                    self.zero_count = 0;
                    continue;
                }
                0x00 => {
                    // Saturate: a run of 256+ zero bytes must not overflow the counter.
                    self.zero_count = self.zero_count.saturating_add(1);
                }
                _ => {
                    self.zero_count = 0;
                }
            }

            buf[read_size] = byte;
            read_size += 1;
        }

        Ok(read_size)
    }
}
76
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::io::{Read, Write};

    use crate::EmulationPreventionIo;

    /// Runs `payload` through the writing side of the wrapper and returns what reached the sink.
    fn encode(payload: &[u8]) -> Vec<u8> {
        let mut sink = Vec::new();
        {
            let mut writer = EmulationPreventionIo::new(&mut sink);
            writer.write_all(payload).unwrap();
            writer.flush().unwrap();
        }
        sink
    }

    /// Runs `stream` through the reading side of the wrapper and returns everything it yields.
    fn decode(stream: &[u8]) -> Vec<u8> {
        let mut reader = EmulationPreventionIo::new(stream);
        let mut collected = Vec::new();
        reader.read_to_end(&mut collected).unwrap();
        collected
    }

    /// A single start-code-like sequence gets one emulation prevention byte inserted.
    #[test]
    fn test_write_emulation_prevention_single() {
        let written = encode(&[0x00, 0x00, 0x01]);

        assert_eq!(written, vec![0x00, 0x00, 0x03, 0x01]);
    }

    /// Every start-code-like sequence in the input gets its own emulation prevention byte.
    #[test]
    fn test_write_emulation_prevention_multiple() {
        let written = encode(&[0x00, 0x00, 0x01, 0x00, 0x00, 0x02]);

        assert_eq!(written, vec![0x00, 0x00, 0x03, 0x01, 0x00, 0x00, 0x03, 0x02]);
    }

    /// A single emulation prevention byte is stripped while reading.
    #[test]
    fn test_read_emulation_prevention() {
        let stream = [0x00, 0x00, 0x03, 0x01];

        let yielded = decode(&stream[..]);

        assert_eq!(yielded, vec![0x00, 0x00, 0x01]);
    }

    /// All emulation prevention bytes in the input are stripped while reading.
    #[test]
    fn test_read_emulation_prevention_multiple() {
        let stream = [0x00, 0x00, 0x03, 0x01, 0x00, 0x00, 0x03, 0x02];

        let yielded = decode(&stream[..]);

        assert_eq!(yielded, vec![0x00, 0x00, 0x01, 0x00, 0x00, 0x02]);
    }

    /// Writing and then reading the result gives back the original bytes.
    #[test]
    fn test_roundtrip() {
        let original = vec![0x00, 0x00, 0x01, 0x00, 0x00, 0x02];

        // Insert emulation prevention bytes, then strip them again.
        let encoded = encode(&original);
        let decoded = decode(&encoded[..]);

        // Should match original after roundtrip
        assert_eq!(original, decoded);
    }
}