Skip to content

Commit

Permalink
fix: extend limit on number of regular arrays to concatenate (#3312)
Browse files Browse the repository at this point in the history
* fix limit on number of regular arrays

* style: pre-commit fixes

* sort specializations

* sort one more

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ianna and pre-commit-ci[bot] authored Nov 22, 2024
1 parent 7887b96 commit 4eebc08
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
39 changes: 39 additions & 0 deletions awkward-cpp/src/cpu-kernels/awkward_UnionArray_regular_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,45 @@ ERROR awkward_UnionArray_regular_index(
}
return success();
}
ERROR awkward_UnionArray64_32_regular_index(
int32_t* toindex,
int32_t* current,
int64_t size,
const int64_t* fromtags,
int64_t length) {
return awkward_UnionArray_regular_index<int64_t, int32_t>(
toindex,
current,
size,
fromtags,
length);
}
ERROR awkward_UnionArray64_U32_regular_index(
uint32_t* toindex,
uint32_t* current,
int64_t size,
const int64_t* fromtags,
int64_t length) {
return awkward_UnionArray_regular_index<int64_t, uint32_t>(
toindex,
current,
size,
fromtags,
length);
}
ERROR awkward_UnionArray64_64_regular_index(
int64_t* toindex,
int64_t* current,
int64_t size,
const int64_t* fromtags,
int64_t length) {
return awkward_UnionArray_regular_index<int64_t, int64_t>(
toindex,
current,
size,
fromtags,
length);
}
ERROR awkward_UnionArray8_32_regular_index(
int32_t* toindex,
int32_t* current,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ ERROR awkward_UnionArray_regular_index_getsize(
*size = *size + 1;
return success();
}
ERROR awkward_UnionArray64_regular_index_getsize(
int64_t* size,
const int64_t* fromtags,
int64_t length) {
return awkward_UnionArray_regular_index_getsize<int64_t>(
size,
fromtags,
length);
}
ERROR awkward_UnionArray8_regular_index_getsize(
int64_t* size,
const int8_t* fromtags,
Expand Down
26 changes: 26 additions & 0 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3458,6 +3458,27 @@ kernels:

- name: awkward_UnionArray_regular_index
specializations:
- name: awkward_UnionArray64_32_regular_index
args:
- {name: toindex, type: "List[int32_t]", dir: out}
- {name: current, type: "List[int32_t]", dir: out}
- {name: size, type: "int64_t", dir: in, role: default}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray64_64_regular_index
args:
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: current, type: "List[int64_t]", dir: out}
- {name: size, type: "int64_t", dir: in, role: default}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray64_U32_regular_index
args:
- {name: toindex, type: "List[uint32_t]", dir: out}
- {name: current, type: "List[uint32_t]", dir: out}
- {name: size, type: "int64_t", dir: in, role: default}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_regular_index
args:
- {name: toindex, type: "List[int32_t]", dir: out}
Expand Down Expand Up @@ -3493,6 +3514,11 @@ kernels:

- name: awkward_UnionArray_regular_index_getsize
specializations:
- name: awkward_UnionArray64_regular_index_getsize
args:
- {name: size, type: "List[int64_t]", dir: out}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_regular_index_getsize
args:
- {name: size, type: "List[int64_t]", dir: out}
Expand Down
11 changes: 11 additions & 0 deletions tests/test_3312_concatenate_regular_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

import numpy as np

import awkward as ak


def test():
ak.concatenate([ak.Array([i])[:, np.newaxis] for i in range(127)], axis=1)

ak.concatenate([ak.Array([i])[:, np.newaxis] for i in range(128)], axis=1)

0 comments on commit 4eebc08

Please sign in to comment.