scuffle_rtmp/handshake/complex/digest.rs

1//! Digest processing for complex handshakes.
2
3use std::io;
4
5use bytes::Bytes;
6use hmac::{Hmac, Mac};
7use sha2::Sha256;
8
9use super::error::ComplexHandshakeError;
10use super::{RTMP_DIGEST_LENGTH, SchemaVersion};
11use crate::handshake::{CHUNK_LENGTH, TIME_VERSION_LENGTH};
12
13/// A digest processor.
14///
15/// This is used to process the digest of a message.
16pub struct DigestProcessor<'a> {
17    data: Bytes,
18    key: &'a [u8],
19}
20
21/// The result of a digest.
22///
23/// Use [`DigestProcessor::generate_and_fill_digest`] to create a `DigestResult`
24/// and [`DigestResult::write_to`] to write the result to a buffer.
25pub struct DigestResult {
26    /// The left part.
27    pub left: Bytes,
28    /// The digest.
29    pub digest: [u8; 32],
30    /// The right part.
31    pub right: Bytes,
32}
33
34impl DigestResult {
35    /// Write the digest result to a given buffer.
36    pub fn write_to(&self, writer: &mut impl io::Write) -> io::Result<()> {
37        writer.write_all(&self.left)?;
38        writer.write_all(&self.digest)?;
39        writer.write_all(&self.right)?;
40
41        Ok(())
42    }
43}
44
45impl<'a> DigestProcessor<'a> {
46    /// Create a new digest processor.
47    pub const fn new(data: Bytes, key: &'a [u8]) -> Self {
48        Self { data, key }
49    }
50
51    /// Read digest from message
52    ///
53    /// According the the spec the schema can either be in the order of
54    /// - time, version, key, digest (schema 0) or
55    /// - time, version, digest, key (schema 1)
56    pub fn read_digest(&self) -> Result<(Bytes, SchemaVersion), ComplexHandshakeError> {
57        if let Ok(digest) = self.generate_and_validate(SchemaVersion::Schema0) {
58            Ok((digest, SchemaVersion::Schema0))
59        } else {
60            let digest = self.generate_and_validate(SchemaVersion::Schema1)?;
61            Ok((digest, SchemaVersion::Schema1))
62        }
63    }
64
65    /// Generate and fill digest based on the schema version.
66    pub fn generate_and_fill_digest(&self, version: SchemaVersion) -> Result<DigestResult, ComplexHandshakeError> {
67        let (left_part, _, right_part) = self.split_message(version)?;
68        let computed_digest = self.make_digest(&left_part, &right_part)?;
69
70        // The reason we return 3 parts vs 1 is because if we return 1 part we need to
71        // copy the memory But this is unnecessary because we are just going to write it
72        // into a buffer.
73        Ok(DigestResult {
74            left: left_part,
75            digest: computed_digest,
76            right: right_part,
77        })
78    }
79
80    fn find_digest_offset(&self, version: SchemaVersion) -> Result<usize, ComplexHandshakeError> {
81        const OFFSET_LENGTH: usize = 4;
82
83        // in schema 0 the digest is after the key (which is after the time and version)
84        // in schema 1 the digest is after the time and version
85        let schema_offset = match version {
86            SchemaVersion::Schema0 => CHUNK_LENGTH + TIME_VERSION_LENGTH,
87            SchemaVersion::Schema1 => TIME_VERSION_LENGTH,
88        };
89
90        // No idea why this isn't a be u32.
91        // It seems to be 4 x 8bit values we add together.
92        // We then mod it by the chunk length - digest length - offset length
93        // Then add the schema offset and offset length to get the digest offset
94        Ok((*self.data.get(schema_offset).unwrap() as usize
95            + *self.data.get(schema_offset + 1).unwrap() as usize
96            + *self.data.get(schema_offset + 2).unwrap() as usize
97            + *self.data.get(schema_offset + 3).unwrap() as usize)
98            % (CHUNK_LENGTH - RTMP_DIGEST_LENGTH - OFFSET_LENGTH)
99            + schema_offset
100            + OFFSET_LENGTH)
101    }
102
103    fn split_message(&self, version: SchemaVersion) -> Result<(Bytes, Bytes, Bytes), ComplexHandshakeError> {
104        let digest_offset = self.find_digest_offset(version)?;
105
106        // We split the message into 3 parts:
107        // 1. The part before the digest
108        // 2. The digest
109        // 3. The part after the digest
110        // This is so we can calculate the digest.
111        // We then compare it to the digest we read from the message.
112        // If they are the same we have a valid message.
113
114        // Slice is a O(1) operation and does not copy the memory.
115        let left_part = self.data.slice(0..digest_offset);
116        let digest_data = self.data.slice(digest_offset..digest_offset + RTMP_DIGEST_LENGTH);
117        let right_part = self.data.slice(digest_offset + RTMP_DIGEST_LENGTH..);
118
119        Ok((left_part, digest_data, right_part))
120    }
121
122    /// Make a digest from the left and right parts using the key.
123    pub fn make_digest(&self, left: &[u8], right: &[u8]) -> Result<[u8; 32], ComplexHandshakeError> {
124        // New hmac from the key
125        let mut mac = Hmac::<Sha256>::new_from_slice(self.key).unwrap();
126        // Update the hmac with the left and right parts
127        mac.update(left);
128        mac.update(right);
129
130        // Finalize the hmac and get the digest
131        let result = mac.finalize().into_bytes();
132        if result.len() != RTMP_DIGEST_LENGTH {
133            return Err(ComplexHandshakeError::DigestLengthNotCorrect);
134        }
135
136        // This does a copy of the memory but its only 32 bytes so its not a big deal.
137        Ok(result.into())
138    }
139
140    fn generate_and_validate(&self, version: SchemaVersion) -> Result<Bytes, ComplexHandshakeError> {
141        // We need the 3 parts so we can calculate the digest and compare it to the
142        // digest we read from the message.
143        let (left_part, digest_data, right_part) = self.split_message(version)?;
144
145        // If the digest we calculated is the same as the digest we read from the
146        // message we have a valid message.
147        if digest_data == self.make_digest(&left_part, &right_part)?.as_ref() {
148            Ok(digest_data)
149        } else {
150            // This does not mean the message is invalid, it just means we need to try the
151            // other schema. If both schemas fail then the message is invalid and its likely
152            // a simple handshake.
153            Err(ComplexHandshakeError::CannotGenerate)
154        }
155    }
156}