diff --git a/src/pydantic_avro/type_handlers.py b/src/pydantic_avro/type_handlers.py index f0d8679..03bafa1 100644 --- a/src/pydantic_avro/type_handlers.py +++ b/src/pydantic_avro/type_handlers.py @@ -13,80 +13,90 @@ } -def string_type_handler(t: str) -> str: - if t == "string": +def string_type_handler(t: dict) -> str: + if t['type'] == "string": return "str" -def int_type_handler(t: str) -> str: - if t == "int": +def int_type_handler(t: dict) -> str: + if t['type'] == "int": return "int" -def long_type_handler(t: str) -> str: - if t == "long": +def long_type_handler(t: dict) -> str: + if t['type'] == "long": return "int" -def boolean_type_handler(t: str) -> str: - if t == "boolean": +def boolean_type_handler(t: dict) -> str: + if t['type'] == "boolean": return "bool" -def double_type_handler(t: str) -> str: - if t == "double": +def double_type_handler(t: dict) -> str: + if t['type'] == "double": return "float" -def float_type_handler(t: str) -> str: - if t == "float": +def float_type_handler(t: dict) -> str: + if t['type'] == "float": return "float" -def bytes_type_handler(t: str) -> str: - if t == "bytes": +def bytes_type_handler(t: dict) -> str: + if t['type'] == "bytes": return "bytes" -def list_type_handler(t: list) -> str: - if "null" in t and len(t) == 2: - c = t.copy() +def list_type_handler(t: dict) -> str: + l = t['type'] + if "null" in l and len(l) == 2: + c = l.copy() c.remove("null") return f"Optional[{get_pydantic_type(c[0])}]" - if "null" in t: - return f"Optional[Union[{','.join([get_pydantic_type(e) for e in t if e != 'null'])}]]" - return f"Union[{','.join([get_pydantic_type(e) for e in t])}]" + if "null" in l: + return f"Optional[Union[{','.join([get_pydantic_type(e) for e in l if e != 'null'])}]]" + return f"Union[{','.join([get_pydantic_type(e) for e in l])}]" def map_type_handler(t: dict) -> str: + if isinstance(t["type"], dict): + value_type = get_pydantic_type(t["type"].get("values")) + return f"Dict[str, {value_type}]" + value_type = get_pydantic_type(t.get("values")) return f"Dict[str, {value_type}]" def logical_type_handler(t: dict) -> str: + if isinstance(t["type"], dict): + return LOGICAL_TYPES.get(t["type"].get("logicalType")) return LOGICAL_TYPES.get(t.get("logicalType")) - def enum_type_handler(t: dict) -> str: - name = t.get("name") + name = t["type"].get("name") if not ClassRegistry().has_class(name): enum_class = f"class {name}(str, Enum):\n" - for s in t.get("symbols"): + for s in t["type"].get("symbols"): enum_class += f' {s} = "{s}"\n' ClassRegistry().add_class(name, enum_class) return name def array_type_handler(t: dict) -> str: - sub_type = get_pydantic_type(t.get("items")) + if isinstance(t["type"], dict): + sub_type = get_pydantic_type(t["type"].get("items")) + else: + sub_type = get_pydantic_type(t.get("items")) return f"List[{sub_type}]" -def record_type_handler(schema: dict) -> str: - name = schema["name"] +def record_type_handler(t: dict) -> str: + t = t["type"] if isinstance(t["type"], dict) else t + name = t["name"] current = f"class {name}(BaseModel):\n" - - for field in schema["fields"]: + fields = t["fields"] if "fields" in t else t["type"]["fields"] + for field in fields: n = field["name"] t = get_pydantic_type(field) default = field.get("default") @@ -102,7 +112,7 @@ def record_type_handler(schema: dict) -> str: current += f" {n}: {t} = {default}\n" else: current += f" {n}: {t} = {json.dumps(default)}\n" - if len(schema["fields"]) == 0: + if len(fields) == 0: current += " pass\n" ClassRegistry().add_class(name, current) @@ -125,35 +135,31 @@ def record_type_handler(schema: dict) -> str: "record": record_type_handler, } -TYPE_VALUE_IS_DICT = ["record", "enum", "array", "map", "logical"] - -def get_pydantic_type(schema: dict) -> str: - if isinstance(schema, str): - t = schema - else: - t = schema["type"] - - if isinstance(t, str) and ClassRegistry().has_class(t): - return t - - handler = get_handler(t) +def get_pydantic_type(t: str | dict | list) -> str: + if isinstance(t, str): + t = {"type": t} - if handler is None: - raise NotImplementedError(f"Type {t} not supported yet") + if isinstance(t['type'], str) and ClassRegistry().has_class(t["type"]): + return t["type"] - if t in TYPE_VALUE_IS_DICT: - return handler(schema) + return get_handler(t)(t) - return handler(t) -def get_handler(t: str | dict | list) -> callable: +def get_handler(t: dict) -> callable: + h= None + t = t["type"] if isinstance(t, str): - return TYPE_HANDLERS.get(t) + h = TYPE_HANDLERS.get(t) elif isinstance(t, dict) and "logicalType" in t: - return TYPE_HANDLERS.get("logical") + h= TYPE_HANDLERS.get("logical") elif isinstance(t, dict) and "type" in t: - return TYPE_HANDLERS.get(t["type"]) + h= TYPE_HANDLERS.get(t["type"]) elif isinstance(t, list): - return TYPE_HANDLERS.get("list") + h= TYPE_HANDLERS.get("list") + + if h: + return h + + raise NotImplementedError(f"Type {t} not supported yet")