Skip to content

Commit

Permalink
Fix Bytes serde support + add tests (#238)
Browse files Browse the repository at this point in the history
* Fix Bytes serde support + add test

* Don't blow up when a deserializer returns a Seq

* Cover byte/bytebuf cases

---------

Co-authored-by: s4h <s4h@ditto.live>
  • Loading branch information
S4H and s4h authored Sep 4, 2024
1 parent 2ad21c3 commit aaa4014
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(docs)'] }
safer-ffi.path = "."
safer-ffi.features = ["internal-tests"]
rand = "0.8.5"
serde_test = { version = "1.0.177" }

[dependencies]
async-compat.optional = true
Expand Down Expand Up @@ -111,7 +112,7 @@ paste.version = "1.0.12"
scopeguard.version = "1.1.0"
scopeguard.default-features = false

serde.version = "1.0.204"
serde.version = "1.0.171"
serde.optional = true
serde.default-features = false

Expand Down
87 changes: 85 additions & 2 deletions src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,91 @@ impl<'a> serde::Serialize for Bytes<'a> {
}

#[cfg(feature = "serde")]
impl<'a, 'de: 'a> serde::Deserialize<'de> for Bytes<'a> {
struct BytesVisitor;

#[cfg(feature = "serde")]
impl<'de> serde::de::Visitor<'de> for BytesVisitor {
type Value = Bytes<'de>;

fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
formatter.write_str("a byte array")
}

fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E> {
Ok(Bytes::from_slice(v))
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> {
Ok(Bytes::from_slice(v).upgrade())
}

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E> {
Ok(Bytes::from(v))
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut buf = Vec::with_capacity(seq.size_hint().unwrap_or(64));

while let Some(c) = seq.next_element::<u8>()? {
buf.push(c);
}

Ok(Bytes::from(buf))
}
}

#[cfg(feature = "serde")]
impl<'de: 'a, 'a> serde::Deserialize<'de> for Bytes<'a> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
serde::Deserialize::deserialize(deserializer).map(|x: &[u8]| Bytes::from(x))
deserializer.deserialize_byte_buf(BytesVisitor)
}
}

#[cfg(all(feature = "serde", test))]
mod tests {
use serde_test::{assert_de_tokens, assert_tokens, Token};

use super::*;

#[test]
fn serde() {
let bytes: Bytes<'static> = Bytes::from_static(b"Hello there");

assert_tokens(&bytes, &[Token::BorrowedBytes(b"Hello there")]);

let data = b"Hello there";
let bytes: Bytes<'_> = Bytes::from(data);

assert_tokens(&bytes, &[Token::BorrowedBytes(b"Hello there")]);

// deserialize from a sequence (like we get with serde_cbor)
assert_de_tokens(
&Bytes::from(&[0, 1, 2]),
&[
Token::Seq { len: Some(3) },
Token::U8(0),
Token::U8(1),
Token::U8(2),
Token::SeqEnd,
],
);

assert_de_tokens(
&Bytes::from(&[0, 1, 2]),
&[
Token::Seq { len: None },
Token::U8(0),
Token::U8(1),
Token::U8(2),
Token::SeqEnd,
],
);

assert_de_tokens(&Bytes::from(&[0, 1, 2]), &[Token::Bytes(&[0, 1, 2])]);

assert_de_tokens(&Bytes::from(&[0, 1, 2]), &[Token::ByteBuf(&[0, 1, 2])]);
}
}

0 comments on commit aaa4014

Please sign in to comment.