Skip to content

Commit

Permalink
Merge branch 'main' into feature/standalone-docker
Browse files Browse the repository at this point in the history
  • Loading branch information
dbluhm authored Oct 25, 2022
2 parents 417f435 + 1fb0bca commit 38252c0
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 36 deletions.
133 changes: 110 additions & 23 deletions aries_cloudagent/messaging/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from abc import ABC
from collections import namedtuple
from typing import Mapping, Union
from typing import Mapping, Optional, Type, TypeVar, Union, cast, overload
from typing_extensions import Literal

from marshmallow import Schema, post_dump, pre_load, post_load, ValidationError, EXCLUDE

Expand All @@ -17,7 +18,7 @@
SerDe = namedtuple("SerDe", "ser de")


def resolve_class(the_cls, relative_cls: type = None):
def resolve_class(the_cls, relative_cls: Optional[type] = None) -> type:
"""
Resolve a class.
Expand All @@ -38,6 +39,10 @@ def resolve_class(the_cls, relative_cls: type = None):
elif isinstance(the_cls, str):
default_module = relative_cls and relative_cls.__module__
resolved = ClassLoader.load_class(the_cls, default_module)
else:
raise TypeError(
f"Could not resolve class from {the_cls}; incorrect type {type(the_cls)}"
)
return resolved


Expand All @@ -53,7 +58,10 @@ def resolve_meta_property(obj, prop_name: str, defval=None):
The meta property
"""
cls = obj.__class__
if isinstance(obj, type):
cls = obj
else:
cls = obj.__class__
found = defval
while cls:
Meta = getattr(cls, "Meta", None)
Expand All @@ -70,6 +78,9 @@ class BaseModelError(BaseError):
"""Base exception class for base model errors."""


ModelType = TypeVar("ModelType", bound="BaseModel")


class BaseModel(ABC):
"""Base model that provides convenience methods."""

Expand All @@ -94,18 +105,24 @@ def __init__(self):
)

@classmethod
def _get_schema_class(cls):
def _get_schema_class(cls) -> Type["BaseModelSchema"]:
"""
Get the schema class.
Returns:
The resolved schema class
"""
return resolve_class(cls.Meta.schema_class, cls)
resolved = resolve_class(cls.Meta.schema_class, cls)
if issubclass(resolved, BaseModelSchema):
return resolved

raise TypeError(
f"Resolved class is not a subclass of BaseModelSchema: {resolved}"
)

@property
def Schema(self) -> type:
def Schema(self) -> Type["BaseModelSchema"]:
"""
Accessor for the model's schema class.
Expand All @@ -115,8 +132,49 @@ def Schema(self) -> type:
"""
return self._get_schema_class()

@overload
@classmethod
def deserialize(
cls: Type[ModelType],
obj,
*,
unknown: Optional[str] = None,
) -> ModelType:
"""Convert from JSON representation to a model instance."""
...

@overload
@classmethod
def deserialize(cls, obj, unknown: str = None, none2none: str = False):
def deserialize(
cls: Type[ModelType],
obj,
*,
none2none: Literal[False],
unknown: Optional[str] = None,
) -> ModelType:
"""Convert from JSON representation to a model instance."""
...

@overload
@classmethod
def deserialize(
cls: Type[ModelType],
obj,
*,
none2none: Literal[True],
unknown: Optional[str] = None,
) -> Optional[ModelType]:
"""Convert from JSON representation to a model instance."""
...

@classmethod
def deserialize(
cls: Type[ModelType],
obj,
*,
unknown: Optional[str] = None,
none2none: bool = False,
) -> Optional[ModelType]:
"""
Convert from JSON representation to a model instance.
Expand All @@ -132,18 +190,45 @@ def deserialize(cls, obj, unknown: str = None, none2none: str = False):
if obj is None and none2none:
return None

schema = cls._get_schema_class()(unknown=unknown or EXCLUDE)
schema_cls = cls._get_schema_class()
schema = schema_cls(
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
)

try:
return schema.loads(obj) if isinstance(obj, str) else schema.load(obj)
return cast(
ModelType,
schema.loads(obj) if isinstance(obj, str) else schema.load(obj),
)
except (AttributeError, ValidationError) as err:
LOGGER.exception(f"{cls.__name__} message validation error:")
raise BaseModelError(f"{cls.__name__} schema validation failed") from err

@overload
def serialize(
self,
*,
as_string: Literal[True],
unknown: Optional[str] = None,
) -> str:
"""Create a JSON-compatible dict representation of the model instance."""
...

@overload
def serialize(
self,
as_string=False,
unknown: str = None,
*,
unknown: Optional[str] = None,
) -> dict:
"""Create a JSON-compatible dict representation of the model instance."""
...

def serialize(
self,
*,
as_string: bool = False,
unknown: Optional[str] = None,
) -> Union[str, dict]:
"""
Create a JSON-compatible dict representation of the model instance.
Expand All @@ -154,7 +239,10 @@ def serialize(
A dict representation of this model, or a JSON string if as_string is True
"""
schema = self.Schema(unknown=unknown or EXCLUDE)
schema_cls = self._get_schema_class()
schema = schema_cls(
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
)
try:
return (
schema.dumps(self, separators=(",", ":"))
Expand All @@ -168,18 +256,17 @@ def serialize(
) from err

@classmethod
def serde(cls, obj: Union["BaseModel", Mapping]) -> SerDe:
def serde(cls, obj: Union["BaseModel", Mapping]) -> Optional[SerDe]:
"""Return serialized, deserialized representations of input object."""
if obj is None:
return None

return (
SerDe(obj.serialize(), obj)
if isinstance(obj, BaseModel)
else None
if obj is None
else SerDe(obj, cls.deserialize(obj))
)
if isinstance(obj, BaseModel):
return SerDe(obj.serialize(), obj)

return SerDe(obj, cls.deserialize(obj))

def validate(self, unknown: str = None):
def validate(self, unknown: Optional[str] = None):
"""Validate a constructed model."""
schema = self.Schema(unknown=unknown)
errors = schema.validate(self.serialize())
Expand All @@ -191,7 +278,7 @@ def validate(self, unknown: str = None):
def from_json(
cls,
json_repr: Union[str, bytes],
unknown: str = None,
unknown: Optional[str] = None,
):
"""
Parse a JSON string into a model instance.
Expand All @@ -218,7 +305,7 @@ def to_json(self, unknown: str = None) -> str:
A JSON representation of this message
"""
return json.dumps(self.serialize(unknown=unknown or EXCLUDE))
return json.dumps(self.serialize(unknown=unknown))

def __repr__(self) -> str:
"""
Expand Down
74 changes: 64 additions & 10 deletions aries_cloudagent/messaging/models/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
import json

from asynctest import TestCase as AsyncTestCase, mock as async_mock

from marshmallow import EXCLUDE, fields, validates_schema, ValidationError

from ....cache.base import BaseCache
from ....config.injection_context import InjectionContext
from ....storage.base import BaseStorage, StorageRecord

from ...responder import BaseResponder, MockResponder
from ...util import time_now
from marshmallow import EXCLUDE, INCLUDE, fields, validates_schema, ValidationError

from ..base import BaseModel, BaseModelError, BaseModelSchema

Expand All @@ -35,6 +26,48 @@ def validate_fields(self, data, **kwargs):
raise ValidationError("")


class ModelImplWithUnknown(BaseModel):
class Meta:
schema_class = "SchemaImplWithUnknown"

def __init__(self, *, attr=None, **kwargs):
self.attr = attr
self.extra = kwargs


class SchemaImplWithUnknown(BaseModelSchema):
class Meta:
model_class = ModelImplWithUnknown
unknown = INCLUDE

attr = fields.String(required=True)

@validates_schema
def validate_fields(self, data, **kwargs):
if data["attr"] != "succeeds":
raise ValidationError("")


class ModelImplWithoutUnknown(BaseModel):
class Meta:
schema_class = "SchemaImplWithoutUnknown"

def __init__(self, *, attr=None):
self.attr = attr


class SchemaImplWithoutUnknown(BaseModelSchema):
class Meta:
model_class = ModelImplWithoutUnknown

attr = fields.String(required=True)

@validates_schema
def validate_fields(self, data, **kwargs):
if data["attr"] != "succeeds":
raise ValidationError("")


class TestBase(AsyncTestCase):
def test_model_validate_fails(self):
model = ModelImpl(attr="string")
Expand Down Expand Up @@ -63,3 +96,24 @@ def test_from_json_x(self):
data = "{}{}"
with self.assertRaises(BaseModelError):
ModelImpl.from_json(data)

def test_model_with_unknown(self):
model = ModelImplWithUnknown(attr="succeeds")
model = model.validate()
assert model.attr == "succeeds"

model = ModelImplWithUnknown.deserialize(
{"attr": "succeeds", "another": "value"}
)
assert model.extra
assert model.extra["another"] == "value"
assert model.attr == "succeeds"

def test_model_without_unknown_default_exclude(self):
model = ModelImplWithoutUnknown(attr="succeeds")
model = model.validate()
assert model.attr == "succeeds"

assert ModelImplWithoutUnknown.deserialize(
{"attr": "succeeds", "another": "value"}
)
7 changes: 5 additions & 2 deletions aries_cloudagent/utils/classloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from importlib import import_module
from importlib.util import find_spec, resolve_name
from types import ModuleType
from typing import Sequence, Type
from typing import Optional, Sequence, Type

from ..core.error import BaseError

Expand Down Expand Up @@ -75,7 +75,10 @@ def load_module(cls, mod_path: str, package: str = None) -> ModuleType:

@classmethod
def load_class(
cls, class_name: str, default_module: str = None, package: str = None
cls,
class_name: str,
default_module: Optional[str] = None,
package: Optional[str] = None,
):
"""
Resolve a complete class path (ie. typing.Dict) to the class itself.
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ markupsafe==2.0.1
marshmallow==3.5.1
msgpack~=1.0
prompt_toolkit~=2.0.9
pynacl~=1.4.0
pynacl~=1.5.0
requests~=2.25.0
packaging~=20.4
pyld~=2.0.3
Expand Down

0 comments on commit 38252c0

Please sign in to comment.