diff --git a/Cargo.toml b/Cargo.toml index 17a086e..6a3c931 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -num-traits = "0.2" [dev-dependencies.pyo3] version = "0.21.2" diff --git a/src/bit_utils.rs b/src/bit_utils.rs deleted file mode 100644 index a60f806..0000000 --- a/src/bit_utils.rs +++ /dev/null @@ -1,116 +0,0 @@ -use num_traits::{int::PrimInt, One, WrappingAdd}; -use std::ops::{Not, ShrAssign}; - -pub fn int_to_binary_str(mut data: T) -> String { - const BLOCK_WIDTH: usize = 8; - const DELIM: &str = " "; - - let bits: usize = get_bit_width::(); - - let mut result = String::new(); - let _0001: T = T::one(); // bit mask - - for i in 0..bits { - if (i != 0) && (i % BLOCK_WIDTH == 0) { - result += DELIM; - } - let bit_is_1 = (data & _0001) == _0001; - result += if bit_is_1 { "1" } else { "0" }; - data >>= T::one(); - } - - result.chars().rev().collect() -} - -/// -/// Returns the bit count of type `T`. -/// NOTE: The implementation is only realiable for simple primitive types! -/// (see: "std::mem::size_of" for more info) -/// -fn get_bit_width() -> usize { - std::mem::size_of::() * 8 -} - -/// Calculates twos complement. -/// Possible overflows are wrapped, since they are intentional. -pub fn twos_complement(x: T) -> T -where - T: Not + WrappingAdd + One, -{ - (!x).wrapping_add(&T::one()) -} - -#[cfg(test)] -mod tests { - use super::*; - - mod test_twos_complement { - use super::*; - - #[test] - fn positive_value() { - // 0001 0100 → 1110 1100 and vica versa - assert_eq!(twos_complement(20_u8), 236); - assert_eq!(twos_complement(236_u8), 20); - } - - #[test] - fn negative_value() { - assert_eq!(twos_complement(-21_i8), 21); - assert_eq!(twos_complement(21_i8), -21); - assert_eq!(twos_complement(-127_i8), 127); - assert_eq!(twos_complement(127_i8), -127); - } - - #[test] - fn boundries() { - // 1111...1111 → 1000...0000 - assert_eq!(twos_complement(u16::MAX), 1); - assert_eq!(twos_complement(1), u16::MAX); - assert_eq!(twos_complement(-1_i16), 1); - assert_eq!(twos_complement(1), -1_i16); - } - - #[test] - fn corner_case() { - // 1000...0000 → 1000...0000 - assert_eq!(twos_complement(32_768_u16), 32_768); - assert_eq!(twos_complement(128_u8), 128); - - // Special case. Signed int overflows (into MSB) after inverting - assert_eq!(twos_complement(-128_i8), -128); - } - - #[test] - fn zero() { - assert_eq!(twos_complement(0_u8), 0); - assert_eq!(twos_complement(0_i16), 0); - } - } - - mod test_int_to_binary_str { - use super::*; - - #[test] - #[rustfmt::skip] - fn normal_values() { - assert_eq!(int_to_binary_str(0 as i8), "00000000"); - assert_eq!(int_to_binary_str(1 as i16), "00000000 00000001"); - assert_eq!(int_to_binary_str(2 as i32), "00000000 00000000 00000000 00000010"); - assert_eq!(int_to_binary_str(3 as u32), "00000000 00000000 00000000 00000011"); - assert_eq!(int_to_binary_str(4 as u32), "00000000 00000000 00000000 00000100"); - assert_eq!(int_to_binary_str(127 as i32), "00000000 00000000 00000000 01111111"); - assert_eq!(int_to_binary_str(-127 as i32), "11111111 11111111 11111111 10000001"); - assert_eq!(int_to_binary_str(2_147_483_648 as u32), "10000000 00000000 00000000 00000000"); - } - - #[test] - #[rustfmt::skip] - fn boundries() { - assert_eq!(int_to_binary_str(i32::MIN), "10000000 00000000 00000000 00000000"); - assert_eq!(int_to_binary_str(-2 as i32), "11111111 11111111 11111111 11111110"); - assert_eq!(int_to_binary_str(-1 as i32), "11111111 11111111 11111111 11111111"); - assert_eq!(int_to_binary_str(u32::MAX), "11111111 11111111 11111111 11111111"); - } - } -} diff --git a/src/lib.rs b/src/lib.rs index fdf7faf..62a6598 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,8 @@ -pub mod bit_utils; pub mod utils; pub mod mp_int { use crate::utils::{ - add_with_carry, dec_to_bit_width, div_with_rem, parse_to_digits, ParseError, TrimInPlace, + add_with_carry, dec_to_bit_width, div_with_rem, parse_to_digits, ParseError, }; use std::{ cmp::Ordering, @@ -308,10 +307,11 @@ pub mod mp_int { result.push(self.sign.into()); } - for d in self.iter().rev() { - result += &crate::bit_utils::int_to_binary_str(*d); - result += " "; - } + const BIN_WIDTH: usize = DIGIT_BITS as usize; + result = self + .iter() + .rev() + .fold(result, |acc, d| acc + &format!("{:0width$b}", d, width = BIN_WIDTH)); result } @@ -329,16 +329,15 @@ pub mod mp_int { } pub fn to_hex_string(&self) -> String { - const X_WIDTH: usize = (DIGIT_BITS / 4) as usize; let mut hex: String = String::new(); if self.is_negative() { hex.push(self.sign.into()); } + const HEX_WIDTH: usize = (DIGIT_BITS / 4) as usize; hex = self .iter() .rev() - .fold(hex, |acc, d| acc + &format!("{:0width$X} ", d, width = X_WIDTH)); - hex.trim_end_in_place(); + .fold(hex, |acc, d| acc + &format!("{:0width$X}", d, width = HEX_WIDTH)); hex } @@ -429,6 +428,14 @@ pub mod mp_int { } } + impl Sub for MPint { + type Output = Self; + fn sub(mut self, rhs: Self) -> Self::Output { + self -= rhs; + self + } + } + impl Sub for &MPint { type Output = MPint; fn sub(self, rhs: Self) -> Self::Output { @@ -764,7 +771,6 @@ pub mod mp_int { use pyo3::types::PyList; use super::*; - use crate::utils::Op; const D_MAX: DigitT = DigitT::MAX; @@ -773,7 +779,7 @@ pub mod mp_int { /// # Rules /// - `create_op_correctness_tester($fn_name, $op)` /// - `$fn_name` - The created function's name. - /// - `$op` - An operator token (e.g. `+`). Must have a corresponding `utils::Op`. + /// - `$op` - An operator token. Currently supports: `+`, `-`, `*`. /// # Examples /// ```rust /// create_op_correctness_tester!(test_addition_correctness, +); @@ -783,7 +789,7 @@ pub mod mp_int { fn $fn_name(a: MPint, b: MPint) { let result = &a $op &b; let test_result = verify_arithmetic_result( - &a, stringify!($op).try_into().unwrap(), &b, &result); + &a, stringify!($op), &b, &result); println!("{:?}", test_result); assert!(test_result.0, "{}", test_result.1); } @@ -888,9 +894,9 @@ pub mod mp_int { { let a = mpint![0, D_MAX, 2, 3]; let expected = concat!( - "0000000000000003 ", - "0000000000000002 ", - "FFFFFFFFFFFFFFFF ", + "0000000000000003", + "0000000000000002", + "FFFFFFFFFFFFFFFF", "0000000000000000" ); assert_eq!(a.to_hex_string(), expected); @@ -898,13 +904,13 @@ pub mod mp_int { { let a = mpint![42, 1 << 13, (1 as DigitT).rotate_right(1)]; let expected = - concat!("8000000000000000 ", "0000000000002000 ", "000000000000002A",); + concat!("8000000000000000", "0000000000002000", "000000000000002A",); assert_eq!(a.to_hex_string(), expected); } { let a = mpint![D_MAX, D_MAX, D_MAX]; let expected = - concat!("FFFFFFFFFFFFFFFF ", "FFFFFFFFFFFFFFFF ", "FFFFFFFFFFFFFFFF",); + concat!("FFFFFFFFFFFFFFFF", "FFFFFFFFFFFFFFFF", "FFFFFFFFFFFFFFFF",); assert_eq!(a.to_hex_string(), expected); } } @@ -915,9 +921,9 @@ pub mod mp_int { let a = -mpint![0, D_MAX, 2, 3]; let expected = concat!( "-", - "0000000000000003 ", - "0000000000000002 ", - "FFFFFFFFFFFFFFFF ", + "0000000000000003", + "0000000000000002", + "FFFFFFFFFFFFFFFF", "0000000000000000" ); assert_eq!(a.to_hex_string(), expected); @@ -925,13 +931,13 @@ pub mod mp_int { { let a = -mpint![42, 1 << 13, (1 as DigitT).rotate_right(1)]; let expected = - concat!("-", "8000000000000000 ", "0000000000002000 ", "000000000000002A",); + concat!("-", "8000000000000000", "0000000000002000", "000000000000002A",); assert_eq!(a.to_hex_string(), expected); } { let a = -mpint![D_MAX, D_MAX, D_MAX]; let expected = - concat!("-", "FFFFFFFFFFFFFFFF ", "FFFFFFFFFFFFFFFF ", "FFFFFFFFFFFFFFFF",); + concat!("-", "FFFFFFFFFFFFFFFF", "FFFFFFFFFFFFFFFF", "FFFFFFFFFFFFFFFF",); assert_eq!(a.to_hex_string(), expected); } } @@ -1316,7 +1322,7 @@ pub mod mp_int { /// - `res_to_verify` - The result to verify against python's calculations. fn verify_arithmetic_result( lhs: &MPint, - op: Op, + op: &str, rhs: &MPint, res_to_verify: &MPint, ) -> (bool, String) { @@ -1340,7 +1346,7 @@ pub mod mp_int { let fn_name = "test_operation_result"; let args: (String, &str, String, String, i32) = ( lhs.to_hex_string(), - op.into(), + op, rhs.to_hex_string(), res_to_verify.to_hex_string(), 16, //base of the number strings diff --git a/src/utils.rs b/src/utils.rs index a1aa7f6..d1267a0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,37 +3,6 @@ use std::f64::consts::{LOG10_2, LOG2_10}; use std::fmt::Display; use std::ops::{Div, Rem}; -pub enum Op { - PLUS, - MINUS, - MULT, - DIV, -} - -impl From for &'static str { - fn from(value: Op) -> Self { - match value { - Op::PLUS => "+", - Op::MINUS => "-", - Op::MULT => "*", - Op::DIV => "/", - } - } -} - -impl TryFrom<&'static str> for Op { - type Error = ParseError; - fn try_from(value: &'static str) -> Result { - match value { - "+" => Ok(Op::PLUS), - "-" => Ok(Op::MINUS), - "*" => Ok(Op::MULT), - "/" => Ok(Op::DIV), - _ => Err("unknown operator".into()), - } - } -} - /// Basically a full adder for `u64` /// /// # Explanation @@ -126,16 +95,6 @@ fn digit_char_to_value(ch: u8) -> Option { } } -pub trait TrimInPlace { - fn trim_end_in_place(&mut self); -} -impl TrimInPlace for String { - fn trim_end_in_place(&mut self) { - let new_end = self.trim_end().len(); - self.truncate(new_end); - } -} - #[derive(Debug, Clone, PartialEq)] pub struct ParseError { msg: &'static str,