Skip to content

Commit

Permalink
fix: Correct msgpack deserialization of ResourceSlot (#2754)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <joongi@lablup.com>
Backported-from: main (24.09)
Backported-to: 24.03
Backport-of: 2754
  • Loading branch information
jopemachine and achimnol committed Aug 22, 2024
1 parent 186e1d1 commit 9deece7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
1 change: 1 addition & 0 deletions changes/2754.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correct `msgpack` deserialization of `ResourceSlot`.
7 changes: 6 additions & 1 deletion src/ai/backend/common/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import msgpack as _msgpack
import temporenc

from .types import BinarySize
from .types import BinarySize, ResourceSlot

__all__ = ("packb", "unpackb")

Expand All @@ -27,6 +27,7 @@ class ExtTypes(enum.IntEnum):
POSIX_PATH = 4
PURE_POSIX_PATH = 5
ENUM = 6
RESOURCE_SLOT = 8
BACKENDAI_BINARY_SIZE = 16


Expand All @@ -46,6 +47,8 @@ def _default(obj: object) -> Any:
return _msgpack.ExtType(ExtTypes.POSIX_PATH, os.fsencode(obj))
case PurePosixPath():
return _msgpack.ExtType(ExtTypes.PURE_POSIX_PATH, os.fsencode(obj))
case ResourceSlot():
return _msgpack.ExtType(ExtTypes.RESOURCE_SLOT, pickle.dumps(obj, protocol=5))
case enum.Enum():
return _msgpack.ExtType(ExtTypes.ENUM, pickle.dumps(obj, protocol=5))
raise TypeError(f"Unknown type: {obj!r} ({type(obj)})")
Expand All @@ -65,6 +68,8 @@ def _ext_hook(code: int, data: bytes) -> Any:
return PurePosixPath(os.fsdecode(data))
case ExtTypes.ENUM:
return pickle.loads(data)
case ExtTypes.RESOURCE_SLOT:
return pickle.loads(data)
case ExtTypes.BACKENDAI_BINARY_SIZE:
return pickle.loads(data)
return _msgpack.ExtType(code, data)
Expand Down
19 changes: 18 additions & 1 deletion tests/common/test_msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dateutil.tz import gettz, tzutc

from ai.backend.common import msgpack
from ai.backend.common.types import BinarySize, SlotTypes
from ai.backend.common.types import BinarySize, ResourceSlot, SlotTypes


def test_msgpack_with_unicode():
Expand Down Expand Up @@ -125,3 +125,20 @@ def test_msgpack_posixpath():
unpacked = msgpack.unpackb(packed)
assert isinstance(unpacked["path"], PosixPath)
assert unpacked["path"] == path


def test_msgpack_resource_slot():
resource_slot = ResourceSlot({"cpu": 1, "mem": 1024})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot

resource_slot = ResourceSlot({"cpu": 2, "mem": Decimal(1024**5)})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot

resource_slot = ResourceSlot({"cpu": 3, "mem": "1125899906842624"})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot

0 comments on commit 9deece7

Please sign in to comment.