Skip to content

Commit

Permalink
Sync from upstream TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
TFLM-bot committed Oct 18, 2024
1 parent e86d97b commit eb5606e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tensorflow/lite/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/lite/array.h"

#include "tensorflow/lite/c/common.h"

namespace tflite {
namespace array_internal {

Expand Down
9 changes: 7 additions & 2 deletions tensorflow/lite/kernels/internal/reference/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const float* scaling_factors,
const int32_t* input_offset, int32_t* row_sums,
const RuntimeShape& output_shape, float* output_data,
bool* compute_row_sums) {
bool* compute_row_sums,
const float* per_channel_scales) {
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
Expand Down Expand Up @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
int32_t row_sum = woff_ptr2[i];
total -= row_sum * batch_offset;
int idx = lhs_rows * j + i;
out_ptr[idx] += batch_scaling_factor * total;
float scale = batch_scaling_factor;
if (per_channel_scales) {
scale *= per_channel_scales[i];
}
out_ptr[idx] += scale * total;
}
}
}
Expand Down

0 comments on commit eb5606e

Please sign in to comment.