Skip to content

Commit

Permalink
fix maturin develop and test
Browse files Browse the repository at this point in the history
  • Loading branch information
bertiqwerty committed Feb 15, 2024
1 parent de3f6e1 commit 0e1fc13
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
1 change: 0 additions & 1 deletion rormula/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ homepage = "https://github.com/basf/rormula"

[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "rormula._rormula"
[tool.ruff]
ignore = ["E731"]
10 changes: 5 additions & 5 deletions rormula/rormula/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, NamedTuple, Tuple, Union
import numpy as np
import pandas as pd
from rormula import _rormula as ror
from .rormula import parse_wilkinson, eval_wilkinson, parse_arithmetic, eval_arithmetic


class SeparatedData(NamedTuple):
Expand Down Expand Up @@ -32,7 +32,7 @@ def separate_num_cat(

class Wilkinson:
def __init__(self, formula: str):
self.ror = ror.parse_wilkinson(formula)
self.ror = parse_wilkinson(formula)

def eval(
self, data: Union[pd.DataFrame, SeparatedData], skip_names: bool = False
Expand All @@ -47,7 +47,7 @@ def eval(
categorical_data,
) = separate_num_cat(data)

names, resulting_data = ror.eval_wilkinson(
names, resulting_data = eval_wilkinson(
self.ror,
numerical_data,
numerical_cols,
Expand All @@ -68,14 +68,14 @@ def eval_asdf(

class Arithmetic:
def __init__(self, formula: str, name: str):
self.ror = ror.parse_arithmetic(formula)
self.ror = parse_arithmetic(formula)
self.name = name

def eval(self, data: pd.DataFrame) -> np.ndarray:
numerical_cols = data.columns.to_list()
numerical_data = data.to_numpy()

resulting_data = ror.eval_arithmetic(
resulting_data = eval_arithmetic(
self.ror,
numerical_data,
numerical_cols,
Expand Down
2 changes: 1 addition & 1 deletion rormula/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ fn parse_wilkinson(s: &str) -> PyResult<Wilkinson> {
}

#[pymodule]
fn _rormula(_py: Python, m: &PyModule) -> PyResult<()> {
fn rormula(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_wilkinson, m)?)?;
m.add_function(wrap_pyfunction!(eval_wilkinson, m)?)?;
m.add_function(wrap_pyfunction!(parse_arithmetic, m)?)?;
Expand Down
8 changes: 5 additions & 3 deletions rormula/test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,18 @@ def eval_asdf():
res = rormula.eval_asdf(df)
assert np.allclose(res.to_numpy().item(), (5.0 - 2.5) / 4.0)

def test_scalar_scalar():
name = "test_scalar"
data = np.random.random((100, 6)) * 1000
df = pd.DataFrame(
data=data, columns=["alpha", "beta", "gamma", "delta", "epsilon", "phi"]
)
s = "5/3 * alpha / beta * (0.2 / 200.0 / (29.22+gamma+epsilon+phi) / 1000)"
rormula = Arithmetic(s, "testslash")
rormula = Arithmetic(s, name)
res = rormula.eval_asdf(df)
ref = df.eval(s)
np.allclose(res.values, ref.values)
np.allclose(res[name].to_numpy(), ref.values)


if __name__ == "__main__":
test_arithmetic()
test_scalar_scalar()

0 comments on commit 0e1fc13

Please sign in to comment.