Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Correct msgpack deserialization of ResourceSlot (#2754) #2756

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2754.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correct `msgpack` deserialization of `ResourceSlot`.
7 changes: 6 additions & 1 deletion src/ai/backend/common/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import msgpack as _msgpack
import temporenc

from .types import BinarySize
from .types import BinarySize, ResourceSlot

__all__ = ("packb", "unpackb")

Expand All @@ -27,6 +27,7 @@ class ExtTypes(enum.IntEnum):
POSIX_PATH = 4
PURE_POSIX_PATH = 5
ENUM = 6
RESOURCE_SLOT = 8
BACKENDAI_BINARY_SIZE = 16


Expand All @@ -46,6 +47,8 @@ def _default(obj: object) -> Any:
return _msgpack.ExtType(ExtTypes.POSIX_PATH, os.fsencode(obj))
case PurePosixPath():
return _msgpack.ExtType(ExtTypes.PURE_POSIX_PATH, os.fsencode(obj))
case ResourceSlot():
return _msgpack.ExtType(ExtTypes.RESOURCE_SLOT, pickle.dumps(obj, protocol=5))
case enum.Enum():
return _msgpack.ExtType(ExtTypes.ENUM, pickle.dumps(obj, protocol=5))
raise TypeError(f"Unknown type: {obj!r} ({type(obj)})")
Expand All @@ -65,6 +68,8 @@ def _ext_hook(code: int, data: bytes) -> Any:
return PurePosixPath(os.fsdecode(data))
case ExtTypes.ENUM:
return pickle.loads(data)
case ExtTypes.RESOURCE_SLOT:
return pickle.loads(data)
case ExtTypes.BACKENDAI_BINARY_SIZE:
return pickle.loads(data)
return _msgpack.ExtType(code, data)
Expand Down
19 changes: 18 additions & 1 deletion tests/common/test_msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dateutil.tz import gettz, tzutc

from ai.backend.common import msgpack
from ai.backend.common.types import BinarySize, SlotTypes
from ai.backend.common.types import BinarySize, ResourceSlot, SlotTypes


def test_msgpack_with_unicode():
Expand Down Expand Up @@ -125,3 +125,20 @@ def test_msgpack_posixpath():
unpacked = msgpack.unpackb(packed)
assert isinstance(unpacked["path"], PosixPath)
assert unpacked["path"] == path


def test_msgpack_resource_slot():
resource_slot = ResourceSlot({"cpu": 1, "mem": 1024})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot

resource_slot = ResourceSlot({"cpu": 2, "mem": Decimal(1024**5)})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot

resource_slot = ResourceSlot({"cpu": 3, "mem": "1125899906842624"})
packed = msgpack.packb(resource_slot)
unpacked = msgpack.unpackb(packed)
assert unpacked == resource_slot
Loading