Skip to content

Commit

Permalink
fix issues from classes with __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Jun 25, 2024
1 parent 57fa8d6 commit 7105519
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions src/umlizer/class_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Create graphviz for classes."""
from __future__ import annotations

import ast
import copy
import dataclasses
import glob
import importlib.util
import inspect
import os
import sys
import textwrap
import types

from pathlib import Path
Expand Down Expand Up @@ -117,13 +119,48 @@ def _get_annotations(klass: Type[Any]) -> dict[str, Any]:
return {k: _get_fullname(v) for k, v in annotations.items()}


def _get_classic_class_structure(
klass: Type[Any],
) -> ClassDef:
_methods = _get_methods(klass)
def _get_init_attributes(klass: Type[Any]) -> dict[str, str]:
"""Extract attributes declared in the __init__ method using `self`."""
attributes: dict[str, str] = {}
init_method = klass.__dict__.get('__init__')

klass_anno = _get_annotations(klass)
if not init_method or not isinstance(init_method, types.FunctionType):
return attributes

source_lines, _ = inspect.getsourcelines(init_method)
source_code = textwrap.dedent(''.join(source_lines))
tree = ast.parse(source_code)

for node in ast.walk(tree):
if isinstance(node, ast.AnnAssign):
target = node.target
if (
isinstance(target, ast.Attribute)
and isinstance(target.value, ast.Name)
and target.value.id == 'self'
):
attr_name = target.attr
attr_type = 'Any' # Default type if not explicitly typed

# Try to get the type from the annotation if it exists
if isinstance(node.value, ast.Name):
attr_type = node.annotation.id # type: ignore[attr-defined]
elif isinstance(node.value, ast.Call) and isinstance(
node.value.func, ast.Name
):
attr_type = node.value.func.annotation.id # type: ignore[attr-defined]
elif isinstance(node.value, ast.Constant):
attr_type = type(node.value.value).__name__

attributes[attr_name] = attr_type

return attributes


def _get_classic_class_structure(klass: Type[Any]) -> ClassDef:
"""Get the structure of a classic (non-dataclass) class."""
_methods = _get_methods(klass)
klass_anno = _get_annotations(klass)
fields = {}

for k in list(klass.__dict__.keys()):
Expand All @@ -133,18 +170,8 @@ def _get_classic_class_structure(
fields[k] = getattr(value, '__value__', str(value))

if not fields:
# maybe the attributes are created in the `__init__` method.
try:
obj_tmp = klass()
obj_anno = _get_annotations(obj_tmp)

for k in list(obj_tmp.__dict__.keys()):
if k.startswith('__') or k in _methods:
continue
value = obj_anno.get(k, '')
fields[k] = getattr(value, '__value__', str(value))
except Exception as e:
print(e)
# Extract attributes from the `__init__` method if defined there.
fields = _get_init_attributes(klass)

return ClassDef(
fields=fields,
Expand Down

0 comments on commit 7105519

Please sign in to comment.