Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add metadata to QuAM #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 78 additions & 22 deletions quam/core/quam_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def convert_dict_and_list(value, cls_or_obj=None, attr=None):
return value


def sort_quam_components(components: List["QuamComponent"], max_attempts=5) -> List["QuamComponent"]:
def sort_quam_components(
components: List["QuamComponent"], max_attempts=5
) -> List["QuamComponent"]:
"""Sort QuamComponent objects based on their config_settings.

Args:
Expand Down Expand Up @@ -211,7 +213,8 @@ def __set__(self, instance, value):
if "parent" in instance.__dict__ and instance.__dict__["parent"] is not value:
cls = instance.__class__.__name__
raise AttributeError(
f"Cannot overwrite parent attribute of {cls}. " f"To modify {cls}.parent, first set {cls}.parent = None"
f"Cannot overwrite parent attribute of {cls}. "
f"To modify {cls}.parent, first set {cls}.parent = None"
)
instance.__dict__["parent"] = value

Expand Down Expand Up @@ -249,7 +252,10 @@ def __init__(self):
"Please create a subclass and make it a dataclass."
)
else:
raise TypeError(f"Cannot instantiate {self.__class__.__name__}. " "Please make it a dataclass.")
raise TypeError(
f"Cannot instantiate {self.__class__.__name__}. "
"Please make it a dataclass."
)

def _get_attr_names(self) -> List[str]:
"""Get names of all dataclass attributes of this object.
Expand Down Expand Up @@ -280,7 +286,9 @@ def get_attr_name(self, attr_val: Any) -> str:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _attr_val_is_default(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -334,6 +342,17 @@ def _val_matches_attr_annotation(cls, attr: str, val: Any) -> bool:
return isinstance(val, (list, QuamList))
return type(val) == required_type

def get_metadata(self, attr=None):
if isinstance(self, QuamRoot):
return self._metadata.setdefault("#/", {})

reference = self.get_reference(attr=attr)
if reference is None:
raise AttributeError(
"Unable to extract reference path. Parent must be defined for {self}"
)
return self._root._metadata.setdefault(reference, {})

def get_reference(self, attr=None) -> Optional[str]:
"""Get the reference path of this object.

Expand All @@ -346,13 +365,17 @@ def get_reference(self, attr=None) -> Optional[str]:
"""

if self.parent is None:
raise AttributeError("Unable to extract reference path. Parent must be defined for {self}")
raise AttributeError(
"Unable to extract reference path. Parent must be defined for {self}"
)
reference = f"{self.parent.get_reference()}/{self.parent.get_attr_name(self)}"
if attr is not None:
reference = f"{reference}/{attr}"
return reference

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
"""Get all attributes and corresponding values of this object.

Args:
Expand All @@ -376,10 +399,16 @@ def get_attrs(self, follow_references: bool = False, include_defaults: bool = Tr
attrs = {attr: getattr(self, attr) for attr in attr_names}

if not include_defaults:
attrs = {attr: val for attr, val in attrs.items() if not self._attr_val_is_default(attr, val)}
attrs = {
attr: val
for attr, val in attrs.items()
if not self._attr_val_is_default(attr, val)
}
return attrs

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.

Args:
Expand All @@ -396,7 +425,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
`"__class__"` key will be added to the dictionary. This is to ensure
that the object can be reconstructed when loading from a file.
"""
attrs = self.get_attrs(follow_references=follow_references, include_defaults=include_defaults)
attrs = self.get_attrs(
follow_references=follow_references, include_defaults=include_defaults
)
quam_dict = {}
for attr, val in attrs.items():
if isinstance(val, QuamBase):
Expand All @@ -411,7 +442,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_dict[attr] = val
return quam_dict

def iterate_components(self, skip_elems: bool = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: bool = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand Down Expand Up @@ -473,12 +506,15 @@ def _get_referenced_value(self, reference: str) -> Any:

if string_reference.is_absolute_reference(reference) and self._root is None:
warnings.warn(
f"No QuamRoot initialized, cannot retrieve reference {reference}" f" from {self.__class__.__name__}"
f"No QuamRoot initialized, cannot retrieve reference {reference}"
f" from {self.__class__.__name__}"
)
return reference

try:
return string_reference.get_referenced_value(self, reference, root=self._root)
return string_reference.get_referenced_value(
self, reference, root=self._root
)
except ValueError as e:
try:
ref = f"{self.__class__.__name__}: {self.get_reference()}"
Expand Down Expand Up @@ -546,6 +582,7 @@ class QuamRoot(QuamBase):

def __post_init__(self):
QuamBase._root = self
self._metadata: Dict[str, Dict[str, Any]] = {}
super().__post_init__()

def __setattr__(self, name, value):
Expand Down Expand Up @@ -586,7 +623,9 @@ def save(
ignore=ignore,
)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> Dict[str, Any]:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> Dict[str, Any]:
"""Convert this object to a dictionary.

Args:
Expand Down Expand Up @@ -750,7 +789,9 @@ def __getitem__(self, i):
repr = f"{self.__class__.__name__}: {self.get_reference()}"
except Exception:
repr = self.__class__.__name__
raise KeyError(f"Could not get referenced value {elem} from {repr}") from e
raise KeyError(
f"Could not get referenced value {elem} from {repr}"
) from e
return elem

# Overriding methods from UserDict
Expand All @@ -774,7 +815,9 @@ def __repr__(self) -> str:
def _get_attr_names(self):
return list(self.data.keys())

def get_attrs(self, follow_references=False, include_defaults=True) -> Dict[str, Any]:
def get_attrs(
self, follow_references=False, include_defaults=True
) -> Dict[str, Any]:
# TODO implement reference kwargs
return self.data

Expand All @@ -795,7 +838,9 @@ def get_attr_name(self, attr_val: Any) -> Union[str, int]:
return attr_name
else:
raise AttributeError(
"Could not find name corresponding to attribute.\n" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute.\n"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def _val_matches_attr_annotation(self, attr: str, val: Any) -> bool:
Expand Down Expand Up @@ -841,10 +886,13 @@ def get_unreferenced_value(self, attr: str) -> bool:
return self.__dict__["data"][attr]
except KeyError as e:
raise AttributeError(
"Cannot get unreferenced value from attribute {attr} that does not" " exist in {self}"
"Cannot get unreferenced value from attribute {attr} that does not"
" exist in {self}"
) from e

def iterate_components(self, skip_elems: Sequence[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: Sequence[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand Down Expand Up @@ -969,10 +1017,14 @@ def get_attr_name(self, attr_val: Any) -> str:
return str(k)
else:
raise AttributeError(
"Could not find name corresponding to attribute" f"attribute: {attr_val}\n" f"obj: {self}"
"Could not find name corresponding to attribute"
f"attribute: {attr_val}\n"
f"obj: {self}"
)

def to_dict(self, follow_references: bool = False, include_defaults: bool = False) -> list:
def to_dict(
self, follow_references: bool = False, include_defaults: bool = False
) -> list:
"""Convert this object to a list, usually as part of a dictionary representation.

Args:
Expand Down Expand Up @@ -1006,7 +1058,9 @@ def to_dict(self, follow_references: bool = False, include_defaults: bool = Fals
quam_list.append(val)
return quam_list

def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["QuamBase", None, None]:
def iterate_components(
self, skip_elems: List[QuamBase] = None
) -> Generator["QuamBase", None, None]:
"""Iterate over all QuamBase objects in this object, including nested objects.

Args:
Expand All @@ -1027,7 +1081,9 @@ def iterate_components(self, skip_elems: List[QuamBase] = None) -> Generator["Qu
if isinstance(attr_val, QuamBase):
yield from attr_val.iterate_components(skip_elems=skip_elems)

def get_attrs(self, follow_references: bool = False, include_defaults: bool = True) -> Dict[str, Any]:
def get_attrs(
self, follow_references: bool = False, include_defaults: bool = True
) -> Dict[str, Any]:
raise NotImplementedError("QuamList does not have attributes")

def print_summary(self, indent: int = 0):
Expand Down