scuffle_amf0/
value.rs

1//! AMF0 value types.
2
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::io;
6
7use scuffle_bytes_util::StringCow;
8
9use crate::Amf0Error;
10use crate::encoder::Amf0Encoder;
11
12/// Represents any AMF0 object.
13pub type Amf0Object<'a> = HashMap<StringCow<'a>, Amf0Value<'a>>;
14/// Represents any AMF0 array.
15pub type Amf0Array<'a> = Cow<'a, [Amf0Value<'a>]>;
16
17/// Represents any AMF0 value.
18#[derive(Debug, PartialEq, Clone)]
19pub enum Amf0Value<'a> {
20    /// AMF0 Number.
21    Number(f64),
22    /// AMF0 Boolean.
23    Boolean(bool),
24    /// AMF0 String.
25    String(StringCow<'a>),
26    /// AMF0 Object.
27    Object(Amf0Object<'a>),
28    /// AMF0 Null.
29    Null,
30    /// AMF0 Array.
31    Array(Amf0Array<'a>),
32}
33
34impl Amf0Value<'_> {
35    /// Converts this AMF0 value into an owned version (static lifetime).
36    pub fn into_owned(self) -> Amf0Value<'static> {
37        match self {
38            Amf0Value::Number(v) => Amf0Value::Number(v),
39            Amf0Value::Boolean(v) => Amf0Value::Boolean(v),
40            Amf0Value::String(v) => Amf0Value::String(v.into_owned()),
41            Amf0Value::Object(v) => {
42                Amf0Value::Object(v.into_iter().map(|(k, v)| (k.into_owned(), v.into_owned())).collect())
43            }
44            Amf0Value::Null => Amf0Value::Null,
45            Amf0Value::Array(v) => Amf0Value::Array(v.into_owned().into_iter().map(|v| v.into_owned()).collect()),
46        }
47    }
48
49    /// Encode this AMF0 value with the given encoder.
50    pub fn encode<W: io::Write>(&self, encoder: &mut Amf0Encoder<W>) -> Result<(), Amf0Error> {
51        match self {
52            Amf0Value::Number(v) => encoder.encode_number(*v),
53            Amf0Value::Boolean(v) => encoder.encode_boolean(*v),
54            Amf0Value::String(v) => encoder.encode_string(v.as_str()),
55            Amf0Value::Object(v) => encoder.encode_object(v),
56            Amf0Value::Null => encoder.encode_null(),
57            Amf0Value::Array(v) => encoder.encode_array(v),
58        }
59    }
60}
61
62impl From<f64> for Amf0Value<'_> {
63    fn from(value: f64) -> Self {
64        Amf0Value::Number(value)
65    }
66}
67
68impl From<bool> for Amf0Value<'_> {
69    fn from(value: bool) -> Self {
70        Amf0Value::Boolean(value)
71    }
72}
73
74impl<'a> From<StringCow<'a>> for Amf0Value<'a> {
75    fn from(value: StringCow<'a>) -> Self {
76        Amf0Value::String(value)
77    }
78}
79
80// object
81impl<'a> From<Amf0Object<'a>> for Amf0Value<'a> {
82    fn from(value: Amf0Object<'a>) -> Self {
83        Amf0Value::Object(value)
84    }
85}
86
87// owned array
88impl<'a> From<Vec<Amf0Value<'a>>> for Amf0Value<'a> {
89    fn from(value: Vec<Amf0Value<'a>>) -> Self {
90        Amf0Value::Array(Cow::Owned(value))
91    }
92}
93
94// borrowed array
95impl<'a> From<&'a [Amf0Value<'a>]> for Amf0Value<'a> {
96    fn from(value: &'a [Amf0Value<'a>]) -> Self {
97        Amf0Value::Array(Cow::Borrowed(value))
98    }
99}
100
101// cow array
102impl<'a> From<Amf0Array<'a>> for Amf0Value<'a> {
103    fn from(value: Amf0Array<'a>) -> Self {
104        Amf0Value::Array(value)
105    }
106}
107
108impl<'a> FromIterator<Amf0Value<'a>> for Amf0Value<'a> {
109    fn from_iter<T: IntoIterator<Item = Amf0Value<'a>>>(iter: T) -> Self {
110        Amf0Value::Array(Cow::Owned(iter.into_iter().collect()))
111    }
112}
113
#[cfg(feature = "serde")]
impl<'de> serde::de::Deserialize<'de> for Amf0Value<'de> {
    /// Deserializes an [`Amf0Value`] from any self-describing format.
    ///
    /// AMF0 has a single numeric type, so every integer and float becomes an
    /// [`Amf0Value::Number`]. Integers are converted with an `as f64` cast, so
    /// values that cannot be represented exactly are rounded. Unit and `None`
    /// become [`Amf0Value::Null`], sequences become arrays, maps become objects,
    /// and byte buffers become arrays of numbers (one per byte).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct Amf0ValueVisitor;

        impl<'de> serde::de::Visitor<'de> for Amf0ValueVisitor {
            type Value = Amf0Value<'de>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("an AMF0 value")
            }

            #[inline]
            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Boolean(v))
            }

            #[inline]
            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            // serde's default `visit_i128` / `visit_u128` reject every value with a
            // type error, even small ones. Convert them like the 64-bit cases so
            // formats that emit 128-bit integers still produce a `Number`.
            #[inline]
            fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Number(v))
            }

            // A transient `&str` cannot outlive this call, so it is copied into an owned string.
            #[inline]
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_string(v.to_owned())
            }

            #[inline]
            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(StringCow::from(v).into())
            }

            // Input-borrowed strings stay borrowed for `'de`.
            #[inline]
            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(StringCow::from(v).into())
            }

            #[inline]
            fn visit_unit<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Null)
            }

            #[inline]
            fn visit_none<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Null)
            }

            // `Some(x)` is transparent: the inner value is deserialized directly.
            #[inline]
            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                serde::Deserialize::deserialize(deserializer)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut vec = Vec::new();

                while let Some(value) = seq.next_element()? {
                    vec.push(value);
                }

                Ok(vec.into())
            }

            // Later duplicate keys overwrite earlier ones (plain `HashMap::insert`).
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::MapAccess<'de>,
            {
                let mut object = HashMap::new();

                while let Some((key, value)) = map.next_entry()? {
                    object.insert(key, value);
                }

                Ok(object.into())
            }

            // AMF0 has no byte-string type; bytes become an array of numbers.
            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let array = v.iter().map(|b| Amf0Value::Number(*b as f64)).collect();
                Ok(Amf0Value::Array(array))
            }
        }

        deserializer.deserialize_any(Amf0ValueVisitor)
    }
}
247
#[cfg(feature = "serde")]
impl serde::ser::Serialize for Amf0Value<'_> {
    /// Serializes this value using the closest serde data-model type.
    ///
    /// Mapping:
    /// - numbers → `f64`;
    /// - booleans → `bool`;
    /// - strings → `StringCow`'s own `Serialize` impl;
    /// - arrays → sequences;
    /// - objects → maps;
    /// - null → `None`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Self::Null => serializer.serialize_none(),
            Self::Boolean(flag) => serializer.serialize_bool(*flag),
            Self::Number(number) => serializer.serialize_f64(*number),
            Self::String(text) => text.serialize(serializer),
            Self::Array(items) => {
                use serde::ser::SerializeSeq;

                let mut seq = serializer.serialize_seq(Some(items.len()))?;
                for item in items.iter() {
                    seq.serialize_element(item)?;
                }
                seq.end()
            }
            Self::Object(properties) => {
                use serde::ser::SerializeMap;

                let mut map = serializer.serialize_map(Some(properties.len()))?;
                for (key, value) in properties.iter() {
                    map.serialize_entry(key, value)?;
                }
                map.end()
            }
        }
    }
}
280
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::borrow::Cow;

    use scuffle_bytes_util::StringCow;

    use super::Amf0Value;
    use crate::{Amf0Array, Amf0Decoder, Amf0Encoder, Amf0Error, Amf0Marker, Amf0Object};

    /// Every `From` / `FromIterator` conversion produces the expected variant.
    #[test]
    fn from() {
        let value: Amf0Value = 1.0.into();
        assert_eq!(value, Amf0Value::Number(1.0));

        let value: Amf0Value = true.into();
        assert_eq!(value, Amf0Value::Boolean(true));

        let value: Amf0Value = StringCow::from("abc").into();
        assert_eq!(value, Amf0Value::String("abc".into()));

        let object: Amf0Object = [("a".into(), Amf0Value::Boolean(true))].into_iter().collect();
        let value: Amf0Value = object.clone().into();
        assert_eq!(value, Amf0Value::Object(object));

        // Owned vector -> owned array.
        let array: Vec<Amf0Value> = vec![Amf0Value::Boolean(true)];
        let value: Amf0Value = array.clone().into();
        assert_eq!(value, Amf0Value::Array(Cow::Owned(array)));

        // Borrowed slice -> borrowed array.
        let array: &[Amf0Value] = &[Amf0Value::Boolean(true)];
        let value: Amf0Value = array.into();
        assert_eq!(value, Amf0Value::Array(Cow::Borrowed(array)));

        // An existing `Amf0Array` is wrapped as-is.
        let array: Amf0Array = Cow::Borrowed(&[Amf0Value::Boolean(true)]);
        let value: Amf0Value = array.clone().into();
        assert_eq!(value, Amf0Value::Array(array));

        // Collecting values yields an owned array.
        let iter = std::iter::once(Amf0Value::Boolean(true));
        let value: Amf0Value = iter.collect();
        assert_eq!(value, Amf0Value::Array(Cow::Owned(vec![Amf0Value::Boolean(true)])));
    }

    /// Decoding the MovieClip marker fails with `Amf0Error::UnsupportedMarker`.
    #[test]
    fn unsupported_marker() {
        let bytes = [Amf0Marker::MovieClipMarker as u8];

        let err = Amf0Decoder::from_slice(&bytes).decode_value().unwrap_err();
        assert!(matches!(err, Amf0Error::UnsupportedMarker(Amf0Marker::MovieClipMarker)));
    }

    /// A length-prefixed string decodes to `Amf0Value::String`.
    #[test]
    fn string() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::String as u8,
            0, 3, // length
            b'a', b'b', b'c',
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::String("abc".into()));
    }

    /// A boolean byte of `0` decodes to `false`.
    #[test]
    fn bool() {
        let bytes = [Amf0Marker::Boolean as u8, 0];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Boolean(false));
    }

    /// An object with one boolean property decodes to a one-entry map.
    #[test]
    fn object() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::Object as u8,
            0, 1, // key length
            b'a',
            Amf0Marker::Boolean as u8,
            1,
            0, 0, Amf0Marker::ObjectEnd as u8, // empty key followed by the end marker
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(
            value,
            Amf0Value::Object([("a".into(), Amf0Value::Boolean(true))].into_iter().collect())
        );
    }

    /// A strict array decodes to `Amf0Value::Array` and re-encodes to the same bytes.
    #[test]
    fn array() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::StrictArray as u8,
            0, 0, 0, 1, // element count
            Amf0Marker::Boolean as u8,
            1,
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Array(Cow::Borrowed(&[Amf0Value::Boolean(true)])));

        let mut serialized = vec![];
        value.encode(&mut Amf0Encoder::new(&mut serialized)).unwrap();
        assert_eq!(serialized, bytes);
    }

    /// Null decodes to `Amf0Value::Null` and re-encodes to the same bytes.
    #[test]
    fn null() {
        let bytes = [Amf0Marker::Null as u8];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Null);

        let mut serialized = vec![];
        value.encode(&mut Amf0Encoder::new(&mut serialized)).unwrap();
        assert_eq!(serialized, bytes);
    }

    /// `into_owned` preserves the value for every variant.
    #[test]
    fn into_owned() {
        let value = Amf0Value::Number(1.0);
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Boolean(true);
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::String("abc".into());
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Object([("a".into(), Amf0Value::Boolean(true))].into_iter().collect());
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Null;
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Array(Cow::Borrowed(&[Amf0Value::Boolean(true)]));
        let owned_value = value.clone().into_owned();
        assert_eq!(owned_value, value);
    }

    /// Drives each visitor method of the serde `Deserialize` impl through a hand-rolled
    /// deserializer and checks the resulting `Amf0Value`.
    #[cfg(feature = "serde")]
    #[test]
    fn deserialize() {
        use std::fmt::Display;

        use serde::Deserialize;
        use serde::de::{IntoDeserializer, MapAccess, SeqAccess};

        #[derive(Debug)]
        struct TestError;

        impl Display for TestError {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Test error")
            }
        }

        impl std::error::Error for TestError {}

        // NOTE(review): none of the cases below are expected to produce a custom error
        // (each `test_de` unwraps); this message assertion looks like a leftover — confirm.
        impl serde::de::Error for TestError {
            fn custom<T: std::fmt::Display>(msg: T) -> Self {
                assert_eq!(msg.to_string(), "invalid type: Option value, expected a byte slice");
                Self
            }
        }

        /// Selects which `Visitor` method `TestDeserializer::deserialize_any` calls.
        enum Mode {
            Bool,
            I64,
            U64,
            F64,
            Str,
            String,
            BorrowedStr,
            Unit,
            None,
            Some,
            Seq,
            Map,
            Bytes,
            End,
        }

        struct TestDeserializer {
            mode: Mode,
        }

        // Yields exactly one `I64` element, then ends.
        impl<'de> SeqAccess<'de> for TestDeserializer {
            type Error = TestError;

            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
            where
                T: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Seq => {
                        self.mode = Mode::End;
                        Ok(Some(seed.deserialize(TestDeserializer { mode: Mode::I64 })?))
                    }
                    Mode::End => Ok(None),
                    _ => Err(TestError),
                }
            }
        }

        // Yields exactly one `"hello" => 1` entry, then ends.
        impl<'de> MapAccess<'de> for TestDeserializer {
            type Error = TestError;

            fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
            where
                K: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Map => Ok(Some(seed.deserialize(TestDeserializer { mode: Mode::Str })?)),
                    Mode::End => Ok(None),
                    _ => Err(TestError),
                }
            }

            fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
            where
                V: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Map => {
                        self.mode = Mode::End;
                        Ok(seed.deserialize(TestDeserializer { mode: Mode::I64 })?)
                    }
                    _ => Err(TestError),
                }
            }
        }

        impl<'de> serde::Deserializer<'de> for TestDeserializer {
            type Error = TestError;

            serde::forward_to_deserialize_any! {
                bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf
                option unit unit_struct newtype_struct seq tuple tuple_struct
                map struct enum identifier ignored_any
            }

            fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
            where
                V: serde::de::Visitor<'de>,
            {
                match self.mode {
                    Mode::Bool => visitor.visit_bool(true),
                    Mode::I64 => visitor.visit_i64(1),
                    Mode::U64 => visitor.visit_u64(1),
                    Mode::F64 => visitor.visit_f64(1.0),
                    Mode::Str => visitor.visit_str("hello"),
                    Mode::String => visitor.visit_string("hello".to_owned()),
                    Mode::BorrowedStr => visitor.visit_borrowed_str("hello"),
                    Mode::Unit => visitor.visit_unit(),
                    Mode::None => visitor.visit_none(),
                    Mode::Some => visitor.visit_some(1.into_deserializer()),
                    Mode::Seq => visitor.visit_seq(self),
                    Mode::Map => visitor.visit_map(self),
                    Mode::Bytes => visitor.visit_bytes(b"hello"),
                    Mode::End => unreachable!(),
                }
            }
        }

        fn test_de(mode: Mode, expected: Amf0Value) {
            let deserializer = TestDeserializer { mode };
            let deserialized_value: Amf0Value = Amf0Value::deserialize(deserializer).unwrap();
            assert_eq!(deserialized_value, expected);
        }

        test_de(Mode::Bool, Amf0Value::Boolean(true));
        test_de(Mode::I64, Amf0Value::Number(1.0));
        test_de(Mode::U64, Amf0Value::Number(1.0));
        test_de(Mode::F64, Amf0Value::Number(1.0));
        test_de(Mode::Str, Amf0Value::String("hello".into()));
        test_de(Mode::String, Amf0Value::String("hello".into()));
        test_de(Mode::BorrowedStr, Amf0Value::String("hello".into()));
        test_de(Mode::Unit, Amf0Value::Null);
        test_de(Mode::None, Amf0Value::Null);
        test_de(Mode::Some, Amf0Value::Number(1.0));
        test_de(Mode::Seq, Amf0Value::Array(Cow::Owned(vec![Amf0Value::Number(1.0)])));
        test_de(
            Mode::Map,
            Amf0Value::Object([("hello".into(), Amf0Value::Number(1.0))].into_iter().collect()),
        );
        // Bytes become one number per byte of "hello".
        test_de(
            Mode::Bytes,
            Amf0Value::Array(Cow::Owned(vec![
                Amf0Value::Number(104.0),
                Amf0Value::Number(101.0),
                Amf0Value::Number(108.0),
                Amf0Value::Number(108.0),
                Amf0Value::Number(111.0),
            ])),
        );
    }
}