Skip to content

Commit

Permalink
[Android Java] Get rid of forwardOnes
Browse files Browse the repository at this point in the history
Differential Revision: D62327373

Pull Request resolved: #5153
  • Loading branch information
kirklandsign authored Sep 7, 2024
1 parent 8ff79ef commit 32d83b0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 24 deletions.
31 changes: 23 additions & 8 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,29 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
jinputs) {
// If no inputs is given, it will run with sample inputs (ones)
if (jinputs->size() == 0) {
if (module_->load_method(method) != Error::Ok) {
return {};
}
auto&& underlying_method = module_->methods_[method].method;
auto&& buf = prepare_input_tensors(*underlying_method);
auto result = underlying_method->execute();
if (result != Error::Ok) {
return {};
}
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
facebook::jni::JArrayClass<JEValue>::newArray(
underlying_method->outputs_size());

for (int i = 0; i < underlying_method->outputs_size(); i++) {
auto jevalue =
JEValue::newJEValueFromEValue(underlying_method->get_output(i));
jresult->setElement(i, *jevalue);
}
return jresult;
}

std::vector<EValue> evalues;
std::vector<TensorPtr> tensors;

Expand Down Expand Up @@ -352,20 +375,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
return jresult;
}

jint forward_ones() {
auto&& load_result = module_->load_method("forward");
auto&& buf = prepare_input_tensors(*(module_->methods_["forward"].method));
auto&& result = module_->methods_["forward"].method->execute();
return (jint)result;
}

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
makeNativeMethod("forward", ExecuTorchJni::forward),
makeNativeMethod("execute", ExecuTorchJni::execute),
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
makeNativeMethod("forwardOnes", ExecuTorchJni::forward_ones),
});
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,12 @@ public static Module load(final String modelPath) {
/**
* Runs the 'forward' method of this module with the specified arguments.
*
* @param inputs arguments for the ExecuTorch module's 'forward' method.
* @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward'
* requires inputs but no inputs are given, the function will not error out, but run 'forward'
* with sample inputs.
* @return return value from the 'forward' method.
*/
public EValue[] forward(EValue... inputs) {
if (inputs.length == 0) {
// forward default args (ones)
mNativePeer.forwardOnes();
// discard the return value
return null;
}
return mNativePeer.forward(inputs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,6 @@ public void resetNative() {
@DoNotStrip
public native EValue[] forward(EValue... inputs);

/**
* Run a "forward" call with the sample inputs (ones) to test a module
*
* @return the outputs of the forward call
* @apiNote This is experimental and test-only API
*/
@DoNotStrip
public native int forwardOnes();

/** Run an arbitrary method on the module */
@DoNotStrip
public native EValue[] execute(String methodName, EValue... inputs);
Expand Down

0 comments on commit 32d83b0

Please sign in to comment.