Skip to content

Commit

Permalink
Move visitors to a shared module
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Aug 6, 2023
1 parent d7cccf4 commit 7d02cda
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 202 deletions.
15 changes: 6 additions & 9 deletions serdect/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Serialization primitives for arrays.

// Unfortunately, we currently cannot assert generically that we are serializing
// Unfortunately, we currently cannot tell `serde` in a uniform fashion that we are serializing
// a fixed-size byte array.
// See https://github.com/serde-rs/serde/issues/2120 for the discussion.
// Therefore we have to fall back to the slice methods,
Expand All @@ -13,7 +13,7 @@ use core::marker::PhantomData;

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::slice;
use crate::common::{self, ExactLength, SliceVisitor, StrIntoBufVisitor};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand All @@ -25,7 +25,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
slice::serialize_hex_lower_or_bin(value, serializer)
common::serialize_hex_lower_or_bin(value, serializer)
}

/// Serialize the given type as upper case hex when using human-readable
Expand All @@ -35,7 +35,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
slice::serialize_hex_upper_or_bin(value, serializer)
common::serialize_hex_upper_or_bin(value, serializer)
}

/// Deserialize from hex when using human-readable formats or binary if the
Expand All @@ -46,12 +46,9 @@ where
D: Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(slice::StrVisitor::<slice::ExactLength>(buffer, PhantomData))
deserializer.deserialize_str(StrIntoBufVisitor::<ExactLength>(buffer, PhantomData))
} else {
deserializer.deserialize_byte_buf(slice::SliceVisitor::<slice::ExactLength>(
buffer,
PhantomData,
))
deserializer.deserialize_byte_buf(SliceVisitor::<ExactLength>(buffer, PhantomData))
}
}

Expand Down
30 changes: 1 addition & 29 deletions serdect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,35 +131,7 @@
extern crate alloc;

pub mod array;
mod common;
pub mod slice;

pub use serde;

use serde::Serializer;

#[cfg(not(feature = "alloc"))]
use serde::ser::Error;

#[cfg(feature = "alloc")]
use serde::Serialize;

fn serialize_hex<S, T, const UPPERCASE: bool>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: AsRef<[u8]>,
{
#[cfg(feature = "alloc")]
if UPPERCASE {
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
} else {
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
}
#[cfg(not(feature = "alloc"))]
{
let _ = value;
let _ = serializer;
return Err(S::Error::custom(
"serializer is human readable, which requires the `alloc` crate feature",
));
}
}
175 changes: 11 additions & 164 deletions serdect/src/slice.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
//! Serialization primitives for slices.

use core::fmt;
use core::marker::PhantomData;

use serde::de::{Error, Visitor};
use serde::{Deserializer, Serializer};

use crate::common::{self, SliceVisitor, StrIntoBufVisitor, UpperBound};

#[cfg(feature = "alloc")]
use serde::Serialize;
use ::{
alloc::vec::Vec,
serde::{Deserialize, Serialize},
};

#[cfg(feature = "alloc")]
use ::{alloc::vec::Vec, serde::Deserialize};
use crate::common::{StrIntoVecVisitor, VecVisitor};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand All @@ -22,11 +25,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
if serializer.is_human_readable() {
crate::serialize_hex::<_, _, false>(value, serializer)
} else {
serializer.serialize_bytes(value.as_ref())
}
common::serialize_hex_lower_or_bin(value, serializer)
}

/// Serialize the given type as upper case hex when using human-readable
Expand All @@ -36,118 +35,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
if serializer.is_human_readable() {
crate::serialize_hex::<_, _, true>(value, serializer)
} else {
serializer.serialize_bytes(value.as_ref())
}
}

pub(crate) trait LengthCheck {
fn length_check(buffer_length: usize, data_length: usize) -> bool;
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result;
}

pub(crate) struct ExactLength;

impl LengthCheck for ExactLength {
fn length_check(buffer_length: usize, data_length: usize) -> bool {
buffer_length == data_length
}
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result {
write!(formatter, "{} of length {}", data_type, data_length)
}
}

struct UpperBound;

impl LengthCheck for UpperBound {
fn length_check(buffer_length: usize, data_length: usize) -> bool {
buffer_length >= data_length
}
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result {
write!(
formatter,
"{} with a maximum length of {}",
data_type, data_length
)
}
}

pub(crate) struct StrVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);

impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrVisitor<'b, T> {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
T::expecting(formatter, "a string", self.0.len() * 2)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if !T::length_check(self.0.len() * 2, v.len()) {
return Err(Error::invalid_length(v.len(), &self));
}
// TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`.
base16ct::mixed::decode(v, self.0)
.map(|_| ())
.map_err(E::custom)
}
}

pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);

impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
T::expecting(formatter, "an array", self.0.len())
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
// Workaround for
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
if T::length_check(self.0.len(), v.len()) {
let buffer = &mut self.0[..v.len()];
buffer.copy_from_slice(v);
return Ok(());
}

Err(E::invalid_length(v.len(), &self))
}

#[cfg(feature = "alloc")]
fn visit_byte_buf<E>(self, mut v: Vec<u8>) -> Result<Self::Value, E>
where
E: Error,
{
// Workaround for
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
if T::length_check(self.0.len(), v.len()) {
let buffer = &mut self.0[..v.len()];
buffer.swap_with_slice(&mut v);
return Ok(());
}

Err(E::invalid_length(v.len(), &self))
}
common::serialize_hex_upper_or_bin(value, serializer)
}

/// Deserialize from hex when using human-readable formats or binary if the
Expand All @@ -158,7 +46,7 @@ where
D: Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(StrVisitor::<UpperBound>(buffer, PhantomData))
deserializer.deserialize_str(StrIntoBufVisitor::<UpperBound>(buffer, PhantomData))
} else {
deserializer.deserialize_byte_buf(SliceVisitor::<UpperBound>(buffer, PhantomData))
}
Expand All @@ -172,49 +60,8 @@ where
D: Deserializer<'de>,
{
if deserializer.is_human_readable() {
struct StrVisitor;

impl<'de> Visitor<'de> for StrVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a string")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
base16ct::mixed::decode_vec(v).map_err(E::custom)
}
}

deserializer.deserialize_str(StrVisitor)
deserializer.deserialize_str(StrIntoVecVisitor)
} else {
struct VecVisitor;

impl<'de> Visitor<'de> for VecVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a bytestring")
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
Ok(v.into())
}

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: Error,
{
Ok(v)
}
}

deserializer.deserialize_byte_buf(VecVisitor)
}
}
Expand Down

0 comments on commit 7d02cda

Please sign in to comment.