-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
npy: first stab at an n-dim array with support for ragged arrays
Fixes #20. Signed-off-by: Sebastien Binet <binet@cern.ch>
- Loading branch information
Showing
15 changed files
with
3,291 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY= | ||
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= | ||
github.com/nlpodyssey/gopickle v0.2.1-0.20231124153821-2139434d2287 h1:cqTk2IOiApRFn/e3YoAVQDgvOc6yU1zihhrH1w14WTg= | ||
github.com/nlpodyssey/gopickle v0.2.1-0.20231124153821-2139434d2287/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0= | ||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= | ||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= | ||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= | ||
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= | ||
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
// Copyright 2023 The npyio Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package npy | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
py "github.com/nlpodyssey/gopickle/types" | ||
) | ||
|
||
// Array is a multidimensional, homogeneous array of fixed-size items. | ||
type Array struct { | ||
descr ArrayDescr | ||
shape []int | ||
strides []int | ||
fortran bool | ||
|
||
data any | ||
} | ||
|
||
var ( | ||
_ py.Callable = (*Array)(nil) | ||
_ py.PyNewable = (*Array)(nil) | ||
_ py.PyStateSettable = (*Array)(nil) | ||
) | ||
|
||
func (*Array) Call(args ...any) (any, error) { | ||
switch sz := len(args); { | ||
case sz < 1, sz > 3: | ||
return nil, fmt.Errorf("invalid tuple length (got=%d)", sz) | ||
} | ||
|
||
return &Array{}, nil | ||
} | ||
|
||
func (*Array) PyNew(args ...any) (any, error) { | ||
var ( | ||
subtype = args[0] | ||
descr = args[1].(*ArrayDescr) | ||
shape = args[2].([]int) | ||
strides = args[3].([]int) | ||
data = args[4].([]byte) | ||
flags = args[5].(int) | ||
) | ||
|
||
return newArray(subtype, *descr, shape, strides, data, flags) | ||
} | ||
|
||
func newArray(subtype any, descr ArrayDescr, shape, strides []int, data []byte, flags int) (*Array, error) { | ||
switch subtype := subtype.(type) { | ||
case *Array: | ||
// ok. | ||
default: | ||
return nil, fmt.Errorf("subtyping ndarray with %T is not (yet?) supported", subtype) | ||
} | ||
|
||
arr := &Array{ | ||
descr: descr, | ||
shape: shape, | ||
strides: strides, | ||
data: data, | ||
} | ||
return arr, nil | ||
} | ||
|
||
func (arr *Array) PySetState(arg any) error { | ||
tuple, ok := arg.(*py.Tuple) | ||
if !ok { | ||
return fmt.Errorf("invalid argument type %T", arg) | ||
} | ||
|
||
var ( | ||
vers = 0 | ||
shape py.Tuple | ||
raw any | ||
) | ||
switch tuple.Len() { | ||
case 5: | ||
err := parseTuple(tuple, &vers, &shape, &arr.descr, &arr.fortran, nil) | ||
if err != nil { | ||
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err) | ||
} | ||
raw = tuple.Get(4) | ||
case 4: | ||
err := parseTuple(tuple, &shape, &arr.descr, &arr.fortran, nil) | ||
if err != nil { | ||
return fmt.Errorf("could not parse ndarray.__setstate__ tuple: %w", err) | ||
} | ||
raw = tuple.Get(3) | ||
default: | ||
return fmt.Errorf("invalid length (%d) for ndarray.__setstate__ tuple", tuple.Len()) | ||
} | ||
|
||
arr.shape = nil | ||
for i := range shape { | ||
v, ok := shape.Get(i).(int) | ||
if !ok { | ||
return fmt.Errorf("invalid shape[%d]: got=%T, want=int", i, shape.Get(i)) | ||
} | ||
arr.shape = append(arr.shape, v) | ||
} | ||
|
||
err := arr.setupStrides() | ||
if err != nil { | ||
return fmt.Errorf("ndarray.__setstate__ could not infer strides: %w", err) | ||
} | ||
|
||
switch raw := raw.(type) { | ||
case *py.List: | ||
arr.data = raw | ||
|
||
case []byte: | ||
data, err := arr.descr.unmarshal(raw, arr.shape) | ||
if err != nil { | ||
return fmt.Errorf("ndarray.__setstate__ could not unmarshal raw data: %w", err) | ||
} | ||
arr.data = data | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (arr *Array) setupStrides() error { | ||
// TODO(sbinet): complete implementation. | ||
// see: _array_fill_strides in numpy/_core/multiarray/ctors.c | ||
|
||
if arr.shape == nil { | ||
arr.strides = nil | ||
return nil | ||
} | ||
|
||
strides := make([]int, len(arr.shape)) | ||
// FIXME(sbinet): handle non-contiguous arrays | ||
// FIXME(sbinet): handle FORTRAN arrays | ||
|
||
var ( | ||
// notCFContig bool | ||
noDim bool // a dimension != 1 was found | ||
) | ||
|
||
// check if array is both FORTRAN- and C-contiguous | ||
for _, dim := range arr.shape { | ||
if dim != 1 { | ||
if noDim { | ||
// notCFContig = true | ||
break | ||
} | ||
noDim = true | ||
} | ||
} | ||
|
||
itemsize := arr.descr.itemsize() | ||
switch { | ||
case arr.fortran: | ||
for i, dim := range arr.shape { | ||
strides[i] = itemsize | ||
switch { | ||
case dim != 0: | ||
itemsize *= dim | ||
default: | ||
// notCFContig = false | ||
} | ||
} | ||
|
||
default: | ||
for i := len(arr.shape) - 1; i >= 0; i-- { | ||
dim := arr.shape[i] | ||
strides[i] = itemsize | ||
switch { | ||
case dim != 0: | ||
itemsize *= dim | ||
default: | ||
// notCFContig = false | ||
} | ||
} | ||
} | ||
|
||
arr.strides = strides | ||
return nil | ||
} | ||
|
||
// Descr returns the array's data type descriptor. | ||
func (arr Array) Descr() ArrayDescr { | ||
return arr.descr | ||
} | ||
|
||
// Shape returns the array's shape. | ||
func (arr Array) Shape() []int { | ||
return arr.shape | ||
} | ||
|
||
// Strides returns the array's strides in bytes. | ||
func (arr Array) Strides() []int { | ||
return arr.strides | ||
} | ||
|
||
// Fortran returns whether the array's data is stored in FORTRAN-order | ||
// (ie: column-major) instead of C-order (ie: row-major.) | ||
func (arr Array) Fortran() bool { | ||
return arr.fortran | ||
} | ||
|
||
// Data returns the array's underlying data. | ||
func (arr Array) Data() any { | ||
return arr.data | ||
} | ||
|
||
func (arr Array) String() string { | ||
o := new(strings.Builder) | ||
fmt.Fprintf(o, "Array{descr: %v, ", arr.descr) | ||
switch arr.shape { | ||
case nil: | ||
fmt.Fprintf(o, "shape: nil, ") | ||
default: | ||
fmt.Fprintf(o, "shape: %v, ", arr.shape) | ||
} | ||
switch arr.strides { | ||
case nil: | ||
fmt.Fprintf(o, "strides: nil, ") | ||
default: | ||
fmt.Fprintf(o, "strides: %v, ", arr.strides) | ||
} | ||
fmt.Fprintf(o, "fortran: %v, data: %+v}", | ||
arr.fortran, | ||
arr.data, | ||
) | ||
return o.String() | ||
} |
Oops, something went wrong.