Skip to content

Commit

Permalink
Implement OsStr::slice_encoded_bytes() proof of concept
Browse files Browse the repository at this point in the history
  • Loading branch information
blyxxyz committed Nov 26, 2023
1 parent 2ed9095 commit 4c9ebb6
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 38 deletions.
59 changes: 21 additions & 38 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
//! - If we don't know what to do with an argument we use [`return Err(arg.unexpected())`][Arg::unexpected] to turn it into an error message.
//! - Strings can be promoted to errors for custom error messages.
#![feature(slice_range)]
#![deny(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations, elided_lifetimes_in_paths)]
#![allow(clippy::should_implement_trait)]

Expand All @@ -84,6 +86,10 @@ use std::{
str::{FromStr, Utf8Error},
};

mod os_str_slice;

use os_str_slice::OsStrSlice;

type InnerIter = std::vec::IntoIter<OsString>;

fn make_iter(iter: impl Iterator<Item = OsString>) -> InnerIter {
Expand All @@ -109,10 +115,9 @@ enum State {
PendingValue(OsString),
/// We're in the middle of -abc.
///
/// In order to satisfy OsString::from_encoded_bytes_unchecked() we make
/// sure that the usize always point to the end of a valid UTF-8 substring.
/// This is a safety invariant!
Shorts(Vec<u8>, usize),
/// In order to satisfy OsStr::slice_encoded_bytes() we make sure that the
/// usize always points to the end of a valid UTF-8 substring.
Shorts(OsString, usize),
/// We saw -- and know no more options are coming.
FinishedOpts,
}
Expand Down Expand Up @@ -170,7 +175,7 @@ impl Parser {
State::Shorts(ref arg, ref mut pos) => {
// We're somewhere inside a -abc chain. Because we're in .next(),
// not .value(), we can assume that the next character is another option.
match first_codepoint(&arg[*pos..]) {
match first_codepoint(&arg.as_encoded_bytes()[*pos..]) {
Ok(None) => {
self.state = State::None;
}
Expand All @@ -186,14 +191,13 @@ impl Parser {
});
}
Ok(Some(ch)) => {
// SAFETY: pos still points to the end of a valid UTF-8 codepoint.
*pos += ch.len_utf8();
self.last_option = LastOption::Short(ch);
return Ok(Some(Arg::Short(ch)));
}
Err(_) => {
// Skip the rest of the argument. This makes it easy to maintain the
// OsString invariants, and the caller is almost certainly going to
// OsString invariant, and the caller is almost certainly going to
// abort anyway.
self.state = State::None;
self.last_option = LastOption::Short('�');
Expand All @@ -212,7 +216,7 @@ impl Parser {
ref state => panic!("unexpected state {:?}", state),
}

let arg = match self.source.next() {
let mut arg = match self.source.next() {
Some(arg) => arg,
None => return Ok(None),
};
Expand All @@ -222,33 +226,19 @@ impl Parser {
return self.next();
}

if arg.as_encoded_bytes().starts_with(b"--") {
let mut arg = arg.into_encoded_bytes();

let arg_bytes = arg.as_encoded_bytes();
if arg_bytes.starts_with(b"--") {
// Long options have two forms: --option and --option=value.
if let Some(ind) = arg.iter().position(|&b| b == b'=') {
if let Some(ind) = arg_bytes.iter().position(|&b| b == b'=') {
// The value can be an OsString...
let value = arg[ind + 1..].to_vec();

// SAFETY: this substring comes immediately after a valid UTF-8 sequence
// (i.e. the equals sign), and it originates from bytes we obtained from
// an OsString just now.
let value = unsafe { OsString::from_encoded_bytes_unchecked(value) };
let value = arg.slice_encoded_bytes(ind + 1..).to_owned();

self.state = State::PendingValue(value);
arg.truncate(ind);
arg = arg.slice_encoded_bytes(..ind).to_owned();
}

// ...but the option has to be a string.

// Transform arg back into an OsString so we can use the platform-specific
// to_string_lossy() implementation.
// (In particular: String::from_utf8_lossy() turns a WTF-8 lone surrogate
// into three replacement characters instead of one.)
// SAFETY: arg is either an unmodified OsString or one we truncated
// right before a valid UTF-8 sequence ("=").
let arg = unsafe { OsString::from_encoded_bytes_unchecked(arg) };

// Calling arg.to_string_lossy().into_owned() would work, but because
// the return type is Cow this would perform an unnecessary copy in
// the common case where arg is already UTF-8.
Expand All @@ -259,9 +249,7 @@ impl Parser {
Err(arg) => arg.to_string_lossy().into_owned(),
};
Ok(Some(self.set_long(option)))
} else if arg.as_encoded_bytes().len() > 1 && arg.as_encoded_bytes()[0] == b'-' {
let arg = arg.into_encoded_bytes();
// SAFETY: 1 points at the end of the dash.
} else if arg_bytes.len() > 1 && arg_bytes[0] == b'-' {
self.state = State::Shorts(arg, 1);
self.next()
} else {
Expand Down Expand Up @@ -528,24 +516,19 @@ impl Parser {
fn raw_optional_value(&mut self) -> Option<(OsString, bool)> {
match replace(&mut self.state, State::None) {
State::PendingValue(value) => Some((value, true)),
State::Shorts(mut arg, mut pos) => {
State::Shorts(arg, mut pos) => {
if pos >= arg.len() {
return None;
}
let mut had_eq_sign = false;
if arg[pos] == b'=' {
if arg.as_encoded_bytes()[pos] == b'=' {
// -o=value.
// clap actually strips out all leading '='s, but that seems silly.
// We allow `-xo=value`. Python's argparse doesn't strip the = in that case.
// SAFETY: pos now points to the end of the '='.
pos += 1;
had_eq_sign = true;
}
arg.drain(..pos); // Reuse allocation

// SAFETY: arg originates from an OsString. We ensure that pos always
// points to a valid UTF-8 boundary.
let value = unsafe { OsString::from_encoded_bytes_unchecked(arg) };
let value = arg.slice_encoded_bytes(pos..).to_owned();
Some((value, had_eq_sign))
}
State::FinishedOpts => {
Expand Down
77 changes: 77 additions & 0 deletions src/os_str_slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#![allow(unsafe_code)]
use std::ffi::OsStr;
use std::ops::RangeBounds;

/// Extension trait that adds boundary-checked substring slicing to [`OsStr`].
pub(crate) trait OsStrSlice {
    /// Takes a substring based on a range that corresponds to the return value of
    /// [`OsStr::as_encoded_bytes`].
    ///
    /// The range's start and end must lie on valid `OsStr` boundaries, meaning one of:
    /// - The start of the string
    /// - The end of the string
    /// - Immediately before a valid non-empty UTF-8 substring
    /// - Immediately after a valid non-empty UTF-8 substring
    ///
    /// This requirement holds even on platforms where the underlying encoding is more
    /// permissive.
    ///
    /// # Panics
    ///
    /// Panics if the range is inverted or extends past the end of the string, or if
    /// either end of the range does not lie on a valid `OsStr` boundary.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use std::ffi::OsStr;
    ///
    /// let os_str = OsStr::new("foo=bar");
    /// let bytes = os_str.as_encoded_bytes();
    /// if let Some(index) = bytes.iter().position(|b| *b == b'=') {
    ///     let key = os_str.slice_encoded_bytes(..index);
    ///     let value = os_str.slice_encoded_bytes(index + 1..);
    ///     assert_eq!(key, "foo");
    ///     assert_eq!(value, "bar");
    /// }
    /// ```
    fn slice_encoded_bytes<R: RangeBounds<usize>>(&self, range: R) -> &Self;
}

impl OsStrSlice for OsStr {
    #[track_caller]
    fn slice_encoded_bytes<R: RangeBounds<usize>>(&self, range: R) -> &Self {
        // Reports whether `index` (which must be `<= bytes.len()`) sits at the start
        // or end of the string, or directly before/after a valid non-empty UTF-8
        // sequence.
        fn is_valid_boundary(bytes: &[u8], index: usize) -> bool {
            if index == 0 || index == bytes.len() {
                return true;
            }

            // Fast path: an ASCII byte on either side is itself a complete UTF-8
            // sequence, so the index is necessarily a boundary.
            if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
                return true;
            }

            let (before, after) = bytes.split_at(index);

            // UTF-8 takes at most 4 bytes per codepoint, so we don't
            // need to check more than that.
            let after = after.get(..4).unwrap_or(after);
            match std::str::from_utf8(after) {
                Ok(_) => return true,
                // At least one complete codepoint starts at `index`; whatever
                // follows it (possibly cut off by the 4-byte window) is irrelevant.
                Err(err) if err.valid_up_to() != 0 => return true,
                Err(_) => (),
            }

            // Otherwise look for a complete codepoint that ends at `index`. The
            // one-byte case is ASCII and was handled by the fast path above.
            (2..=index.min(4)).any(|len| std::str::from_utf8(&before[index - len..]).is_ok())
        }

        let bytes = self.as_encoded_bytes();
        let len = bytes.len();

        // Resolve the generic range into concrete `start..end` offsets using only
        // the stable `RangeBounds` API.
        let start = match range.start_bound() {
            std::ops::Bound::Included(&start) => start,
            std::ops::Bound::Excluded(&start) => start
                .checked_add(1)
                .expect("attempted to slice from after the maximum usize"),
            std::ops::Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            std::ops::Bound::Included(&end) => end
                .checked_add(1)
                .expect("attempted to slice up to the maximum usize"),
            std::ops::Bound::Excluded(&end) => end,
            std::ops::Bound::Unbounded => len,
        };

        assert!(
            start <= end,
            "slice index starts at {start} but ends at {end}"
        );
        assert!(
            end <= len,
            "range end index {end} out of range for OsStr of length {len}"
        );
        assert!(
            is_valid_boundary(bytes, start),
            "byte index {start} is not an OsStr boundary"
        );
        assert!(
            is_valid_boundary(bytes, end),
            "byte index {end} is not an OsStr boundary"
        );

        // SAFETY: bytes was obtained from an OsStr just now, and we validated
        // that we only slice immediately before or after a valid non-empty
        // UTF-8 substring (or at the very start/end of the string).
        unsafe { Self::from_encoded_bytes_unchecked(&bytes[start..end]) }
    }
}

0 comments on commit 4c9ebb6

Please sign in to comment.