Skip to content

Commit

Permalink
Modifies cells to make function handles pickleable (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji authored Oct 4, 2024
1 parent dfef174 commit 072f954
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 55 deletions.
58 changes: 47 additions & 11 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,52 @@ def check_indices(problem: ProblemDefinition) -> None:
)


class FileHandles:
"""Class to defer creation of custom file handles.
Parameters
----------
files : ClassList[CustomFile]
A list of custom file models.
"""

def __init__(self, files):
self.index = 0
self.files = [*files]

def __iter__(self):
self.index = 0
return self

def get_handle(self, index):
"""Returns file handle for a given custom file.
Parameters
----------
index : int
The index of the custom file.
"""
custom_file = self.files[index]
full_path = os.path.join(custom_file.path, custom_file.filename)
if custom_file.language == Languages.Python:
file_handle = get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path)
elif custom_file.language == Languages.Matlab:
file_handle = RATapi.wrappers.MatlabWrapper(full_path).getHandle()
elif custom_file.language == Languages.Cpp:
file_handle = RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle()

return file_handle

def __next__(self):
if self.index < len(self.files):
custom_file = self.get_handle(self.index)
self.index += 1
return custom_file
else:
raise StopIteration


def make_cells(project: RATapi.Project) -> Cells:
"""Constructs the cells input required for the compiled RAT code.
Expand Down Expand Up @@ -344,16 +390,6 @@ def make_cells(project: RATapi.Project) -> Cells:
else:
simulation_limits.append([0.0, 0.0])

file_handles = []
for custom_file in project.custom_files:
full_path = os.path.join(custom_file.path, custom_file.filename)
if custom_file.language == Languages.Python:
file_handles.append(get_python_handle(custom_file.filename, custom_file.function_name, custom_file.path))
elif custom_file.language == Languages.Matlab:
file_handles.append(RATapi.wrappers.MatlabWrapper(full_path).getHandle())
elif custom_file.language == Languages.Cpp:
file_handles.append(RATapi.wrappers.DylibWrapper(full_path, custom_file.function_name).getHandle())

# Populate the set of cells
cells = Cells()
cells.f1 = [[0, 1]] * len(project.contrasts) # This is marked as "to do" in RAT
Expand All @@ -369,7 +405,7 @@ def make_cells(project: RATapi.Project) -> Cells:
cells.f11 = [param.name for param in project.bulk_in]
cells.f12 = [param.name for param in project.bulk_out]
cells.f13 = [param.name for param in project.resolution_parameters]
cells.f14 = file_handles
cells.f14 = FileHandles(project.custom_files)
cells.f15 = [param.type for param in project.backgrounds]
cells.f16 = [param.type for param in project.resolutions]

Expand Down
11 changes: 6 additions & 5 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ struct Cells {
py::list f11;
py::list f12;
py::list f13;
py::list f14;
py::object f14;
py::list f15;
py::list f16;
py::list f17;
Expand Down Expand Up @@ -844,12 +844,13 @@ coder::array<RAT::cell_wrap_6, 2U> pyListToRatCellWrap6(py::list values)
return result;
}

coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::list values)
coder::array<RAT::cell_wrap_6, 2U> py_function_array_to_rat_cell_wrap_6(py::object values)
{
auto handles = py::cast<py::list>(values);
coder::array<RAT::cell_wrap_6, 2U> result;
result.set_size(1, values.size());
result.set_size(1, handles.size());
int32_T idx {0};
for (py::handle array: values)
for (py::handle array: handles)
{
auto func = py::cast<py::function>(array);
std::string func_ptr = convertPtr2String<CallbackInterface>(new Library(func));
Expand Down Expand Up @@ -1585,7 +1586,7 @@ PYBIND11_MODULE(rat_core, m) {
cell.f11 = t[10].cast<py::list>();
cell.f12 = t[11].cast<py::list>();
cell.f13 = t[12].cast<py::list>();
cell.f14 = t[13].cast<py::list>();
cell.f14 = t[13].cast<py::object>();
cell.f15 = t[14].cast<py::list>();
cell.f16 = t[15].cast<py::list>();
cell.f17 = t[16].cast<py::list>();
Expand Down
60 changes: 21 additions & 39 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,25 +624,7 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr
"domainRatio",
]

mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine

with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())

problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls())
problem = pickle.loads(pickle.dumps(problem))
check_problem_equal(problem, test_problem)
cells = pickle.loads(pickle.dumps(cells))
Expand Down Expand Up @@ -768,25 +750,7 @@ def test_make_cells(test_project, test_cells, request) -> None:
"""The cells object should be populated according to the input project object."""
test_project = request.getfixturevalue(test_project)
test_cells = request.getfixturevalue(test_cells)

mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine
with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
cells = make_cells(test_project)

cells = make_cells(test_project)
check_cells_equal(cells, test_cells)


Expand Down Expand Up @@ -865,7 +829,25 @@ def check_cells_equal(actual_cells, expected_cells) -> None:
"NaN" if np.isnan(el) else el for entry in actual_cells.f6 for el in entry
] == ["NaN" if np.isnan(el) else el for entry in expected_cells.f6 for el in entry]

for index in chain(range(3, 6), range(7, 21)):
mocked_matlab_future = mock.MagicMock()
mocked_engine = mock.MagicMock()
mocked_matlab_future.result.return_value = mocked_engine
with mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"loader",
mocked_matlab_future,
), mock.patch.object(RATapi.rat_core, "DylibEngine", mock.MagicMock()), mock.patch.object(
RATapi.inputs,
"get_python_handle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(
RATapi.wrappers.MatlabWrapper,
"getHandle",
mock.MagicMock(return_value=dummy_function),
), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)):
assert list(actual_cells.f14) == expected_cells.f14

for index in chain(range(3, 6), range(7, 14), range(15, 21)):
field = f"f{index}"
assert getattr(actual_cells, field) == getattr(expected_cells, field)

Expand Down

0 comments on commit 072f954

Please sign in to comment.