Merge branch 'apache:main' into default_target
MNGanesan authored Oct 24, 2024
2 parents 23b9cb6 + d973b33 commit c297b73
Showing 77 changed files with 3,291 additions and 141 deletions.
3 changes: 2 additions & 1 deletion conda/build-environment.yaml
@@ -26,7 +26,8 @@ channels:
 # The packages to install to the environment
 dependencies:
   - python=3.9
-  - conda-build
+  - conda < 24.9.0
+  - conda-build < 24.9.0
   - git
   - llvmdev >=11
   - numpy
2 changes: 1 addition & 1 deletion conda/recipe/meta.yaml
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-{% set version = '0.18.dev0' %}
+{% set version = '0.19.dev0' %}
 {% set pkg_name = 'tvm' %}
 {% set cuda_tag = cuda_version | replace('.', '') %}  # [cuda]
 {% set pkg_name = pkg_name + '-cu' + cuda_tag %}  # [cuda]
2 changes: 1 addition & 1 deletion docker/Dockerfile.ci_lint
@@ -38,7 +38,7 @@ ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI.

 RUN apt-get update && apt-install-and-clear -y doxygen graphviz curl shellcheck

-RUN pip3 install cpplint pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3
+RUN pip3 install cpplint==1.6.1 pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3

 # Rust env (build early; takes a while)
 COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh
2 changes: 1 addition & 1 deletion docker/install/ubuntu2004_install_python_package.sh
@@ -35,7 +35,7 @@ pip3 install --upgrade \
     psutil \
     pytest \
     git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \
-    pytest-profiling \
+    pytest-profiling==1.7.0 \
     pytest-xdist \
     pytest-rerunfailures==10.2 \
     requests \
18 changes: 8 additions & 10 deletions docker/install/ubuntu_install_jax.sh
@@ -20,18 +20,16 @@ set -e
 set -u
 set -o pipefail

-JAX_VERSION=0.4.30
-
-# Install jaxlib
+# Install jax and jaxlib
 if [ "$1" == "cuda" ]; then
-    pip install -U \
-        "jax[cuda12]~=${JAX_VERSION}" \
-        jaxlib~=${JAX_VERSION}
+    pip3 install --upgrade \
+        jaxlib~=0.4.9 \
+        "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 else
-    pip3 install -U \
-        jax~=${JAX_VERSION} \
-        jaxlib~=${JAX_VERSION}
+    pip3 install --upgrade \
+        jaxlib~=0.4.9 \
+        "jax[cpu]~=0.4.9"
 fi

 # Install flax
-pip3 install flax~=0.8.5
+pip3 install flax~=0.6.9
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
@@ -35,7 +35,7 @@ pip3 install --upgrade \
     psutil \
     pytest \
     git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \
-    pytest-profiling \
+    pytest-profiling!=1.8.0 \
     pytest-xdist \
     pytest-rerunfailures==10.2 \
     requests \
4 changes: 2 additions & 2 deletions docker/install/ubuntu_install_tensorflow.sh
@@ -21,5 +21,5 @@ set -u
 set -o pipefail

 pip3 install \
-    keras==3.5 \
-    tensorflow==2.17.0
+    keras==2.9 \
+    tensorflow==2.9.1
4 changes: 2 additions & 2 deletions docker/install/ubuntu_install_tensorflow_aarch64.sh
@@ -25,5 +25,5 @@ apt-install-and-clear -y --no-install-recommends libhdf5-dev
 # h5py wheel tries to use the wrong .so file
 pip3 install \
     numpy==1.23.5 \
-    keras==3.5 \
-    tensorflow-aarch64~=2.16.1
+    keras==2.9 \
+    tensorflow-aarch64~=2.9.3
40 changes: 20 additions & 20 deletions docker/install/ubuntu_install_tflite.sh
@@ -26,11 +26,11 @@ set -o pipefail
 TENSORFLOW_VERSION=$(python3 -c "import tensorflow; print(tensorflow.__version__)" 2> /dev/null)

 # Download, build and install flatbuffers
-git clone --branch=v24.3.25 --depth=1 --recursive https://github.com/google/flatbuffers.git
-pushd flatbuffers
-cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess"
-ninja install -j8
-popd
+git clone --branch=v1.12.0 --depth=1 --recursive https://github.com/google/flatbuffers.git
+cd flatbuffers
+cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess"
+make install -j8
+cd ..

 # Install flatbuffers python packages.
 pip3 install flatbuffers
@@ -41,22 +41,22 @@ pip3 install flatbuffers
 git clone https://github.com/tensorflow/tensorflow --branch=v${TENSORFLOW_VERSION} --depth 1

 mkdir -p /opt/tflite
-pushd /opt/tflite
-cmake -G Ninja \
-    -DTFLITE_ENABLE_XNNPACK=OFF \
-    /tensorflow/tensorflow/lite
+cd /opt/tflite
+cmake \
+    -DTFLITE_ENABLE_XNNPACK=OFF \
+    /tensorflow/tensorflow/lite

 cmake --build .
-popd
+cd -

 # Setup tflite from schema
 mkdir tflite
-find / -name "schema.fbs"
-cp /tensorflow/tensorflow/lite/stablehlo/schema/schema.fbs tflite
-pushd tflite
+cp tensorflow/tensorflow/lite/schema/schema.fbs tflite
+cd tflite
 flatc --python schema.fbs

 cat <<EOM >setup.py
 import setuptools
 setuptools.setup(
@@ -77,12 +77,12 @@ setuptools.setup(
 )
 EOM

 cat <<EOM >__init__.py
 name = "tflite"
 EOM

 # Install tflite over python3
 python3 setup.py install

-popd
+cd ..
 rm -rf tflite
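
The heredoc-built setup.py above packages the flatc-generated schema bindings as an installable tflite module. A hedged smoke test of the result; the GetRootAsModel entry point is generated by flatc, and its exact import path is an assumption here:

# Hedged smoke test for the tflite package installed above; the
# flatc-generated module layout (tflite/Model.py) is assumed.
import tflite.Model

with open("model.tflite", "rb") as f:  # any TFLite flatbuffer file
    buf = f.read()

model = tflite.Model.Model.GetRootAsModel(buf, 0)
print(model.Version())  # schema version, typically 3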
23 changes: 23 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -164,6 +164,29 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
         "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
   }
 };  // struct ScatterElementsAttrs
+
+/*! \brief Attributes used in scatter_nd operators */
+struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
+  String reduction;
+
+  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
+    TVM_ATTR_FIELD(reduction).set_default("update").describe(
+        "Accumulation mode of the ScatterND, "
+        "either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
+  }
+};  // struct ScatterNDAttrs
+
+/*! \brief Attributes used in one_hot operator */
+struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
+  int depth;
+  int axis;
+
+  TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
+    TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
+    TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
+  }
+};  // struct OneHotAttrs
+
 }  // namespace relax
 }  // namespace tvm
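
These attribute structs back the new relax scatter_nd and one_hot operators. A minimal Python-side sketch, assuming the ops are exposed under tvm.relax.op with the signatures the attrs suggest (python/tvm/relax/op holds the authoritative API):

# Hedged sketch of the ops these attrs configure; signatures assumed.
from tvm import relax
from tvm.script import relax as R

data = relax.Var("data", R.Tensor((8,), "float32"))
indices = relax.Var("indices", R.Tensor((4, 1), "int64"))
updates = relax.Var("updates", R.Tensor((4,), "float32"))

# "reduction" maps to ScatterNDAttrs: "update", "add", "mul", "min" or "max"
out = relax.op.scatter_nd(data, indices, updates, reduction="add")

# depth/axis map to OneHotAttrs; on/off values are relax PrimValues
labels = relax.Var("labels", R.Tensor((4,), "int32"))
onehot = relax.op.one_hot(labels, relax.PrimValue(1.0), relax.PrimValue(0.0), 10, axis=-1)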
21 changes: 21 additions & 0 deletions include/tvm/relax/transform.h
@@ -253,6 +253,21 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning);
  */
 TVM_DLL Pass RealizeVDevice();

+/*!
+ * \brief Attach layout free buffers to the tir::PrimFunc.
+ *
+ * This pass is used to attach layout free buffers to the tir::PrimFunc according to
+ * the function usage in the relax function. Currently, the layout free buffers are the model
+ * weights and relax constants.
+ *
+ * \note We recommend applying CanonicalizeBindings before this pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachAttrLayoutFreeBuffers();
+
+/*!
+ * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc.
+ *
+ * This pass is used in the prepack weight after meta_schedule tuning.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SplitLayoutRewritePreproc();
+
 /*!
  * \brief Lift transformation of the parameters of a function.
  *
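
A sketch of how the two new passes might be composed from Python, assuming they are exposed under tvm.relax.transform like the surrounding passes; per the \note above, CanonicalizeBindings is recommended first:

# Hedged sketch: composing the new passes around a tuning step.
import tvm
from tvm import relax

seq = tvm.transform.Sequential(
    [
        relax.transform.CanonicalizeBindings(),        # recommended prerequisite
        relax.transform.AttachAttrLayoutFreeBuffers(), # mark weights/constants layout-free
        # ... meta_schedule tuning of the resulting PrimFuncs here ...
        relax.transform.SplitLayoutRewritePreproc(),   # split weight-prepack preproc
    ]
)
# mod = seq(mod)  # apply to an IRModule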
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_runtime_api.h
@@ -73,7 +73,7 @@
 #endif

 // TVM version
-#define TVM_VERSION "0.18.dev0"
+#define TVM_VERSION "0.19.dev0"

 // TVM Runtime is DLPack compatible.
 #include <dlpack/dlpack.h>
11 changes: 11 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;

+  /*!
+   * \brief Annotate the buffer access of a block
+   * \param block_rv The block to be annotated
+   * \param buffer_index The index of the buffer in block's read or write region
+   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
+   * \param index_map The index map that defines the new read or write region
+   */
+  virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
+                                    BufferIndexType buffer_index_type,
+                                    const IndexMap& index_map) = 0;
+
   /******** Schedule: Misc ********/
   /*! \brief A no-op that marks the start of postprocessing phase of scheduling */
   virtual void EnterPostproc() = 0;
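
A hedged sketch of the Python mirror of this primitive; the binding name annotate_buffer_access and the lambda-based index map are assumptions, not taken from this commit:

# Assumes sch.annotate_buffer_access(block, buffer_index, buf_type, index_map_lambda).
import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def before(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
    for i, j in T.grid(16, 16):
        with T.block("compute"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

sch = tir.Schedule(before)
block = sch.get_block("compute")
# Override the inferred read region of buffer 0 (the "read" side); the
# lambda maps block iters to one (begin, end) pair per buffer dimension.
sch.annotate_buffer_access(block, 0, "read",
                           lambda vi, vj: ((vi, vi + 1), (vj, vj + 1)))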
10 changes: 10 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution";
 /*! \brief Mark that a block is disallowed in auto inline. */
 constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";

+/*! \brief Mark that a block has an explicitly specified read region.
+ * This is used to override the default read region inference in TIR.
+ */
+constexpr const char* explicit_read_region = "explicit_read_region";
+
+/*! \brief Mark that a block has an explicitly specified write region.
+ * This is used to override the default write region inference in TIR.
+ */
+constexpr const char* explicit_write_region = "explicit_write_region";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
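
In TVMScript these keys would surface as block annotations. A hedged illustration; the annotation value shown here (a list of buffer indices) is an assumption, not taken from this commit:

# Illustration only: marking block "copy" as having an explicit read region.
from tvm.script import tir as T

@T.prim_func
def copy(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in T.serial(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            T.reads(A[vi])
            T.writes(B[vi])
            T.block_attr({"explicit_read_region": [0]})  # buffer index 0, i.e. A
            B[vi] = A[vi]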
12 changes: 12 additions & 0 deletions jvm/core/src/main/java/org/apache/tvm/Function.java
@@ -222,6 +222,16 @@ public Function pushArg(byte[] arg) {
     return this;
   }

+  /**
+   * Push argument to the function.
+   * @param arg Device.
+   * @return this
+   */
+  public Function pushArg(Device arg) {
+    Base._LIB.tvmFuncPushArgDevice(arg);
+    return this;
+  }
+
   /**
    * Invoke function with arguments.
    * @param args Can be Integer, Long, Float, Double, String, NDArray.
@@ -255,6 +265,8 @@ private static void pushArgToStack(Object arg) {
       Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
     } else if (arg instanceof Function) {
       Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
+    } else if (arg instanceof Device) {
+      Base._LIB.tvmFuncPushArgDevice((Device) arg);
     } else if (arg instanceof TVMValue) {
       TVMValue tvmArg = (TVMValue) arg;
       switch (tvmArg.typeCode) {
2 changes: 2 additions & 0 deletions jvm/core/src/main/java/org/apache/tvm/LibInfo.java
@@ -37,6 +37,8 @@ class LibInfo {

   native void tvmFuncPushArgHandle(long arg, int argType);

+  native void tvmFuncPushArgDevice(Device device);
+
   native int tvmFuncListGlobalNames(List<String> funcNames);

   native int tvmFuncFree(long handle);
@@ -41,7 +41,7 @@ public class GraphModule {
   private Function fdebugGetOutput;
   private Function floadParams;

-  GraphModule(Module module, Device dev) {
+  public GraphModule(Module module, Device dev) {
     this.module = module;
     this.device = dev;
     fsetInput = module.getFunction("set_input");
21 changes: 21 additions & 0 deletions jvm/native/src/main/native/jni_helper_func.h
@@ -214,4 +214,25 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) {
   return NULL;
 }

+// Helper function to pack two int32_t values into an int64_t
+inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) {
+  int64_t result;
+  int32_t* parts = reinterpret_cast<int32_t*>(&result);
+
+  // Lambda function to check endianness
+  const auto isLittleEndian = []() -> bool {
+    uint32_t i = 1;
+    return *reinterpret_cast<char*>(&i) == 1;
+  };
+
+  if (isLittleEndian()) {
+    parts[0] = device_type;
+    parts[1] = device_id;
+  } else {
+    parts[1] = device_type;
+    parts[0] = device_id;
+  }
+  return result;
+}
+
 #endif  // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_
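
Net effect of deviceToInt64 on either byte order: device_type lands in the low 32 bits of the resulting value and device_id in the high 32 bits. The same packing in pure Python:

# What deviceToInt64 computes, independent of host endianness:
# low 32 bits = device_type, high 32 bits = device_id.
def device_to_int64(device_type: int, device_id: int) -> int:
    return ((device_id & 0xFFFFFFFF) << 32) | (device_type & 0xFFFFFFFF)

assert device_to_int64(2, 0) == 2              # kDLCUDA (2), device 0
assert device_to_int64(2, 1) == (1 << 32) | 2  # kDLCUDA, device 1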
15 changes: 15 additions & 0 deletions jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
@@ -112,6 +112,21 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv*

   e->tvmFuncArgTypes.push_back(static_cast<int>(argType));
 }

+JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj,
+                                                                        jobject arg) {
+  jclass deviceClass = env->FindClass("org/apache/tvm/Device");
+  jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I");
+  jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I");
+  jint deviceType = env->GetIntField(arg, deviceTypeField);
+  jint deviceId = env->GetIntField(arg, deviceIdField);
+
+  TVMValue value;
+  value.v_int64 = deviceToInt64(deviceType, deviceId);
+  TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get();
+  e->tvmFuncArgValues.push_back(value);
+  e->tvmFuncArgTypes.push_back(kDLDevice);
+}
+
 JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj,
                                                                        jbyteArray arg) {
   jbyteArray garg = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg));
2 changes: 1 addition & 1 deletion python/tvm/_ffi/libinfo.py
@@ -247,4 +247,4 @@ def find_include_path(name=None, search_path=None, optional=False):
 # We use the version of the incoming release for code
 # that is under development.
 # The following line is set by tvm/python/update_version.py
-__version__ = "0.18.dev0"
+__version__ = "0.19.dev0"
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/nn/__init__.py
@@ -23,6 +23,8 @@
 from .modules import (
     GELU,
     Conv1D,
+    Conv2D,
+    Conv3D,
     ConvTranspose1D,
     Embedding,
     GroupNorm,
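
Conv2D and Conv3D are now importable directly from tvm.relax.frontend.nn. A hedged usage sketch, assuming the constructors mirror Conv1D's torch-style arguments (nn/modules.py holds the exact signatures):

# Hedged sketch; constructor argument names are assumed.
from tvm.relax.frontend import nn

class TinyConv(nn.Module):
    def __init__(self):
        self.conv = nn.Conv2D(in_channels=3, out_channels=8, kernel_size=3)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return self.conv(x)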
