Skip to content

Commit

Permalink
Simplify code a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Tom committed May 25, 2024
1 parent 58a8d7c commit 2db1604
Showing 1 changed file with 57 additions and 51 deletions.
108 changes: 57 additions & 51 deletions src/pydantic_avro/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,80 +13,90 @@
}


def string_type_handler(t: str) -> str:
if t == "string":
def string_type_handler(t: dict) -> str:
if t['type'] == "string":
return "str"


def int_type_handler(t: str) -> str:
if t == "int":
def int_type_handler(t: dict) -> str:
if t['type'] == "int":
return "int"


def long_type_handler(t: str) -> str:
if t == "long":
def long_type_handler(t: dict) -> str:
if t['type'] == "long":
return "int"


def boolean_type_handler(t: str) -> str:
if t == "boolean":
def boolean_type_handler(t: dict) -> str:
if t['type'] == "boolean":
return "bool"


def double_type_handler(t: str) -> str:
if t == "double":
def double_type_handler(t: dict) -> str:
if t['type'] == "double":
return "float"


def float_type_handler(t: str) -> str:
if t == "float":
def float_type_handler(t: dict) -> str:
if t['type'] == "float":
return "float"


def bytes_type_handler(t: str) -> str:
if t == "bytes":
def bytes_type_handler(t: dict) -> str:
if t['type'] == "bytes":
return "bytes"


def list_type_handler(t: list) -> str:
if "null" in t and len(t) == 2:
c = t.copy()
def list_type_handler(t: dict) -> str:
l = t['type']
if "null" in l and len(l) == 2:
c = l.copy()
c.remove("null")
return f"Optional[{get_pydantic_type(c[0])}]"
if "null" in t:
return f"Optional[Union[{','.join([get_pydantic_type(e) for e in t if e != 'null'])}]]"
return f"Union[{','.join([get_pydantic_type(e) for e in t])}]"
if "null" in l:
return f"Optional[Union[{','.join([get_pydantic_type(e) for e in l if e != 'null'])}]]"
return f"Union[{','.join([get_pydantic_type(e) for e in l])}]"


def map_type_handler(t: dict) -> str:
if isinstance(t["type"], dict):
value_type = get_pydantic_type(t["type"].get("values"))
return f"Dict[str, {value_type}]"

value_type = get_pydantic_type(t.get("values"))
return f"Dict[str, {value_type}]"


def logical_type_handler(t: dict) -> str:
if isinstance(t["type"], dict):
return LOGICAL_TYPES.get(t["type"].get("logicalType"))
return LOGICAL_TYPES.get(t.get("logicalType"))


def enum_type_handler(t: dict) -> str:
name = t.get("name")
name = t["type"].get("name")
if not ClassRegistry().has_class(name):
enum_class = f"class {name}(str, Enum):\n"
for s in t.get("symbols"):
for s in t["type"].get("symbols"):
enum_class += f' {s} = "{s}"\n'
ClassRegistry().add_class(name, enum_class)
return name


def array_type_handler(t: dict) -> str:
sub_type = get_pydantic_type(t.get("items"))
if isinstance(t["type"], dict):
sub_type = get_pydantic_type(t["type"].get("items"))
else:
sub_type = get_pydantic_type(t.get("items"))
return f"List[{sub_type}]"


def record_type_handler(schema: dict) -> str:
name = schema["name"]
def record_type_handler(t: dict) -> str:
t = t["type"] if isinstance(t["type"], dict) else t
name = t["name"]
current = f"class {name}(BaseModel):\n"

for field in schema["fields"]:
fields = t["fields"] if "fields" in t else t["type"]["fields"]
for field in fields:
n = field["name"]
t = get_pydantic_type(field)
default = field.get("default")
Expand All @@ -102,7 +112,7 @@ def record_type_handler(schema: dict) -> str:
current += f" {n}: {t} = {default}\n"
else:
current += f" {n}: {t} = {json.dumps(default)}\n"
if len(schema["fields"]) == 0:
if len(fields) == 0:
current += " pass\n"

ClassRegistry().add_class(name, current)
Expand All @@ -125,35 +135,31 @@ def record_type_handler(schema: dict) -> str:
"record": record_type_handler,
}

TYPE_VALUE_IS_DICT = ["record", "enum", "array", "map", "logical"]


def get_pydantic_type(schema: dict) -> str:
if isinstance(schema, str):
t = schema
else:
t = schema["type"]

if isinstance(t, str) and ClassRegistry().has_class(t):
return t

handler = get_handler(t)
def get_pydantic_type(t: str | dict | list) -> str:
if isinstance(t, str):
t = {"type": t}

if handler is None:
raise NotImplementedError(f"Type {t} not supported yet")
if isinstance(t['type'], str) and ClassRegistry().has_class(t["type"]):
return t["type"]

if t in TYPE_VALUE_IS_DICT:
return handler(schema)
return get_handler(t)(t)

return handler(t)


def get_handler(t: str | dict | list) -> callable:
def get_handler(t: dict) -> callable:
h= None
t = t["type"]
if isinstance(t, str):
return TYPE_HANDLERS.get(t)
h = TYPE_HANDLERS.get(t)
elif isinstance(t, dict) and "logicalType" in t:
return TYPE_HANDLERS.get("logical")
h= TYPE_HANDLERS.get("logical")
elif isinstance(t, dict) and "type" in t:
return TYPE_HANDLERS.get(t["type"])
h= TYPE_HANDLERS.get(t["type"])
elif isinstance(t, list):
return TYPE_HANDLERS.get("list")
h= TYPE_HANDLERS.get("list")

if h:
return h

raise NotImplementedError(f"Type {t} not supported yet")

0 comments on commit 2db1604

Please sign in to comment.