Skip to content

Commit

Permalink
Python bindings for AMReX-MPMD (#271)
Browse files Browse the repository at this point in the history
* First changes related to MPMD feature in pyAMReX

* Two MultiFab tests with MPMD::Copier::send from both sides.

* Importing mpi4py not needed when amrex::MPMD is leveraged.

* First working conversion of mpi4py comm to C MPI_Comm

* MPMD now depends on mpi4py for python scripts.

* Input file for lammps_tests/bench/ included

* Create README.md in tests/test_MPMD

* A two-way data transfer example is included.

* First torch.distributed CPU example on a single node.

* Update README.md of test_MPMD

* Second update of README.md in test_MPMD

* Including an environment.yml

* Removal of test_MPMD subfolder in tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update .gitignore

Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja>

* MPI4Py Comm Wrapper

Simplify Cross-Compiles. Using same solution as in openPMD-api to
test and convert MPI Communicators.

* Import `mpi4py`

* test_1 added in tests/test_MPMD/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed unused variable bx in tests/test_MPMD/test_1/main.py

* mpi4py.rc settings in tests were changing thread support level from 3 to 0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Bhargav Siddani <bsiddani@gigan.lbl.gov>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Axel Huebl <axel.huebl@plasma.ninja>
  • Loading branch information
4 people authored Apr 30, 2024
1 parent c2c1224 commit d75140f
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ foreach(D IN LISTS AMReX_SPACEDIM)
Utility.cpp
Vector.cpp
Version.cpp
MPMD.cpp
)
endforeach()
180 changes: 180 additions & 0 deletions src/Base/MPMD.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/* Copyright 2021-2022 The AMReX Community
*
* Authors: Axel Huebl
* License: BSD-3-Clause-LBNL
*/
#include "pyAMReX.H"
#include <AMReX_BoxArray.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_FArrayBox.H>
#include <AMReX_FabArray.H>
#include <AMReX_MPMD.H>
#include <AMReX_ParallelDescriptor.H>

#ifdef AMREX_USE_MPI
#include <mpi.h>

/** mpi4py communicator wrapper
*
* Hand-written stand-in for the Python object that mpi4py uses to hold an
* MPI communicator. init_MPMD() reinterprets the `PyObject*` of a Python
* communicator as this struct and reads `ob_mpi` from it, so the member
* order and types below must not be changed.
*
* NOTE(review): the layout is presumably meant to match mpi4py's own
* definition in the references below - confirm against the mpi4py version
* that is actually installed.
*
* refs:
* - https://github.com/mpi4py/mpi4py/blob/3.0.0/src/mpi4py/libmpi.pxd#L35-L36
* - https://github.com/mpi4py/mpi4py/blob/3.0.0/src/mpi4py/MPI.pxd#L100-L105
* - installed: include/mpi4py/mpi4py.MPI.h
*/
struct pyAMReX_PyMPICommObject
{
// Standard Python object header (macro), followed by the wrapped MPI handle.
PyObject_HEAD MPI_Comm ob_mpi;
// Not read anywhere in this file; kept so the struct keeps its full layout.
unsigned int flags;
};
// Alias used for intra-communicators: same layout as the generic communicator.
using pyAMReX_PyMPIIntracommObject = pyAMReX_PyMPICommObject;


void init_MPMD(py::module &m) {
using namespace amrex;

// Several functions here are copied from AMReX.cpp
m.def("MPMD_Initialize_without_split",
[](const py::list args) {
Vector<std::string> cargs{"amrex"};
Vector<char*> argv;

// Populate the "command line"
for (const auto& v: args)
cargs.push_back(v.cast<std::string>());
for (auto& v: cargs)
argv.push_back(&v[0]);
int argc = argv.size();

// note: +1 since there is an extra char-string array element,
// that ANSII C requires to be a simple NULL entry
// https://stackoverflow.com/a/39096006/2719194
argv.push_back(NULL);
char** tmp = argv.data();
MPMD::Initialize_without_split(argc, tmp);
});

// This is AMReX::Initialize when MPMD exists
m.def("initialize_when_MPMD",
[](const py::list args, py::object &app_comm_py) {
Vector<std::string> cargs{"amrex"};
Vector<char*> argv;

// Populate the "command line"
for (const auto& v: args)
cargs.push_back(v.cast<std::string>());
for (auto& v: cargs)
argv.push_back(&v[0]);
int argc = argv.size();

// note: +1 since there is an extra char-string array element,
// that ANSII C requires to be a simple NULL entry
// https://stackoverflow.com/a/39096006/2719194
argv.push_back(NULL);
char** tmp = argv.data();

const bool build_parm_parse = (cargs.size() > 1);

//! TODO perform mpi4py import test and check min-version
//! careful: double MPI_Init risk? only import mpi4py.MPI?
//! required C-API init? probably just checks:
//! refs:
//! -
//! https://bitbucket.org/mpi4py/mpi4py/src/3.0.0/demo/wrap-c/helloworld.c
//! - installed: include/mpi4py/mpi4py.MPI_api.h
//auto m_mpi4py = py::module::import("mpi4py");
//amrex::ignore_unused(m_mpi4py);

if (app_comm_py.ptr() == Py_None)
throw std::runtime_error(
"MPMD: MPI communicator cannot be None.");
if (app_comm_py.ptr() == nullptr)
throw std::runtime_error(
"MPMD: MPI communicator is a nullptr.");

// check type string to see if this is mpi4py
// __str__ (pretty)
// __repr__ (unambiguous)
// mpi4py: <mpi4py.MPI.Intracomm object at 0x7f998e6e28d0>
// pyMPI: ... (TODO)
py::str const comm_pystr = py::repr(app_comm_py);
std::string const comm_str = comm_pystr.cast<std::string>();
if (comm_str.substr(0, 12) != std::string("<mpi4py.MPI."))
throw std::runtime_error(
"MPMD: comm is not an mpi4py communicator: " +
comm_str);
// only checks same layout, e.g. an `int` in `PyObject` could
// pass this
if (!py::isinstance<py::class_<pyAMReX_PyMPIIntracommObject> >(
app_comm_py.get_type()))
// TODO add mpi4py version from above import check to error
// message
throw std::runtime_error(
"MPMD: comm has unexpected type layout in " +
comm_str +
" (Mismatched MPI at compile vs. runtime? "
"Breaking mpi4py release?)");

// todo other possible implementations:
// - pyMPI (inactive since 2008?): import mpi; mpi.WORLD

// reimplementation of mpi4py's:
// MPI_Comm* mpiCommPtr = PyMPIComm_Get(app_comm_py.ptr());
MPI_Comm *mpiCommPtr =
&((pyAMReX_PyMPIIntracommObject *)(app_comm_py.ptr()))->ob_mpi;

if (PyErr_Occurred())
throw std::runtime_error(
"MPMD: MPI communicator access error.");
if (mpiCommPtr == nullptr)
{
throw std::runtime_error(
"MPMD: MPI communicator cast failed. "
"(Mismatched MPI at compile vs. runtime?)");
}

return Initialize(argc, tmp, build_parm_parse, *mpiCommPtr);
}, py::return_value_policy::reference);

constexpr auto run_gc = []() {
// explicitly run the garbage collector, so deleted objects
// get freed.
// This is a convenience helper/bandage for making work with Python
// garbage collectors in various implementations more easy.
// https://github.com/AMReX-Codes/pyamrex/issues/81
auto m_gc = py::module::import("gc");
auto collect = m_gc.attr("collect");
collect();
};
m.def("MPMD_Finalize",
[run_gc]() {
run_gc();
MPMD::Finalize();
});
m.def("MPMD_Initialized",&MPMD::Initialized);
m.def("MPMD_MyProc",&MPMD::MyProc);
m.def("MPMD_NProcs",&MPMD::NProcs);
m.def("MPMD_AppNum",&MPMD::AppNum);
m.def("MPMD_MyProgId",&MPMD::MyProgId);

// Binding MPMD::Copier class
py::class_< MPMD::Copier >(m, "MPMD_Copier")
//! Construct an MPMD::Copier without BoxArray and DistributionMApping
.def(py::init <bool>())
//! Construct an MPMD::Copier with BoxArray and DistributionMApping
.def(py::init< BoxArray const&, DistributionMapping const&,bool>(),
py::arg("ba"),py::arg("dm"),py::arg("send_ba")=false)
// Copier function to send data
.def("send",&MPMD::Copier::send<FArrayBox>)
// Copier function to receive data
.def("recv",&MPMD::Copier::recv<FArrayBox>)
// Copier's BoxArray
.def("box_array",&MPMD::Copier::boxArray,
py::return_value_policy::reference_internal)
// Copier's DistributionMapping
.def("distribution_map",&MPMD::Copier::DistributionMap,
py::return_value_policy::reference_internal)
;

}

#endif
7 changes: 6 additions & 1 deletion src/pyAMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ void init_PODVector(py::module &);
void init_Utility(py::module &);
void init_Vector(py::module &);
void init_Version(py::module &);

#ifdef AMREX_USE_MPI
void init_MPMD(py::module &);
#endif

#if AMREX_SPACEDIM == 1
PYBIND11_MODULE(amrex_1d_pybind, m) {
Expand Down Expand Up @@ -108,6 +110,9 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
init_ParticleContainer(m);
init_AmrMesh(m);

#ifdef AMREX_USE_MPI
init_MPMD(m);
#endif
// Wrappers around standalone functions
init_PlotFileUtil(m);
init_Utility(m);
Expand Down
20 changes: 20 additions & 0 deletions tests/test_MPMD/test_1/GNUmakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Build settings for the C++ side of tests/test_MPMD/test_1.

# Location of the AMReX source tree; `?=` keeps a value that is already set
# (e.g. in the environment) and otherwise falls back to this relative path.
AMREX_HOME ?= ../../../../amrex

# Options read by the AMReX GNUMake files included below.
# NOTE(review): exact meaning is defined by AMReX's Make.defs - presumably
# debug build, 3 spatial dimensions, GCC toolchain, MPI on, OpenMP/CUDA/HIP off.
DEBUG = TRUE

DIM = 3

COMP = gcc

USE_MPI = TRUE

USE_OMP = FALSE
USE_CUDA = FALSE
USE_HIP = FALSE

# AMReX build-system definitions; included after the options above are set.
include $(AMREX_HOME)/Tools/GNUMake/Make.defs

# Source lists: this test's own package file, then AMReX's Base package.
include ./Make.package
include $(AMREX_HOME)/Src/Base/Make.package

# AMReX build rules; included last.
include $(AMREX_HOME)/Tools/GNUMake/Make.rules
1 change: 1 addition & 0 deletions tests/test_MPMD/test_1/Make.package
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Append this test's only source file to CEXE_sources.
# NOTE(review): presumably the list of C++ sources compiled by the AMReX
# GNUMake rules - confirm in AMReX's build documentation.
CEXE_sources += main.cpp
72 changes: 72 additions & 0 deletions tests/test_MPMD/test_1/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@

#include <AMReX.H>
#include <AMReX_Print.H>
#include <AMReX_MultiFab.H>
#include <AMReX_PlotFileUtil.H>
#include <mpi.h>
#include <AMReX_MPMD.H>

int main(int argc, char* argv[])
{
// Initialize amrex::MPMD to establish communication across the two apps
MPI_Comm comm = amrex::MPMD::Initialize(argc, argv);
amrex::Initialize(argc,argv,true,comm);
{
amrex::Print() << "Hello world from AMReX version " << amrex::Version() << "\n";
// Number of data components at each grid point in the MultiFab
int ncomp = 2;
// how many grid cells in each direction over the problem domain
int n_cell = 32;
// how many grid cells are allowed in each direction over each box
int max_grid_size = 16;
//BoxArray -- Abstract Domain Setup
// integer vector indicating the lower coordindate bounds
amrex::IntVect dom_lo(0,0,0);
// integer vector indicating the upper coordindate bounds
amrex::IntVect dom_hi(n_cell-1, n_cell-1, n_cell-1);
// box containing the coordinates of this domain
amrex::Box domain(dom_lo, dom_hi);
// will contain a list of boxes describing the problem domain
amrex::BoxArray ba(domain);
// chop the single grid into many small boxes
ba.maxSize(max_grid_size);
// Distribution Mapping
amrex::DistributionMapping dm(ba);
// Create an MPMD Copier that
// sends the BoxArray information to the other (python) application
auto copr = amrex::MPMD::Copier(ba,dm,true);
//Define MuliFab
amrex::MultiFab mf(ba, dm, ncomp, 0);
//Geometry -- Physical Properties for data on our domain
amrex::RealBox real_box ({0., 0., 0.}, {1. , 1., 1.});
amrex::Geometry geom(domain, &real_box);
//Calculate Cell Sizes
amrex::GpuArray<amrex::Real,3> dx = geom.CellSizeArray(); //dx[0] = dx dx[1] = dy dx[2] = dz
//Fill only the first component of the MultiFab
for(amrex::MFIter mfi(mf); mfi.isValid(); ++mfi){
const amrex::Box& bx = mfi.validbox();
const amrex::Array4<amrex::Real>& mf_array = mf.array(mfi);

amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k){

amrex::Real x = (i+0.5) * dx[0];
amrex::Real y = (j+0.5) * dx[1];
amrex::Real z = (k+0.5) * dx[2];
amrex::Real r_squared = ((x-0.5)*(x-0.5)+(y-0.5)*(y-0.5)+(z-0.5)*(z-0.5))/0.01;

mf_array(i,j,k,0) = 1.0 + std::exp(-r_squared);

});
}
// Send ONLY the first populated MultiFab component to the other app
copr.send(mf,0,1);
// Receive ONLY the second MultiFab component from the other app
copr.recv(mf,1,1);
//Plot MultiFab Data
WriteSingleLevelPlotfile("plt_cpp_001", mf, {"comp0","comp1"}, geom, 0., 0);

}
amrex::Finalize();
amrex::MPMD::Finalize();

}
Loading

0 comments on commit d75140f

Please sign in to comment.