From a466e367486997d4b4d9725416a15583f2f8ed6b Mon Sep 17 00:00:00 2001 From: Isabel Zimmerman <54685329+isabelizimm@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:34:01 -0500 Subject: [PATCH] fix: pinned `bool` columns show up as `None` (#225) * dont generalize falsiness * add test --- ...vetiver_model.py => test_vetiver_model.py} | 35 +++++++++++++++++++ vetiver/utils.py | 4 ++- 2 files changed, 38 insertions(+), 1 deletion(-) rename vetiver/tests/{test_build_vetiver_model.py => test_vetiver_model.py} (86%) diff --git a/vetiver/tests/test_build_vetiver_model.py b/vetiver/tests/test_vetiver_model.py similarity index 86% rename from vetiver/tests/test_build_vetiver_model.py rename to vetiver/tests/test_vetiver_model.py index 7f61739..ba84541 100644 --- a/vetiver/tests/test_build_vetiver_model.py +++ b/vetiver/tests/test_vetiver_model.py @@ -110,6 +110,41 @@ def test_vetiver_model_dict_like_prototype(prototype_data): assert json_schema == expected +@pytest.mark.skipif( + pydantic.__version__.startswith("1"), reason="only run for pydantic v2" +) +@pytest.mark.parametrize( + "prototype_data,expected", + [ + ( + {"B": 0, "C": False, "D": None}, + { + "properties": { + "B": {"example": 0, "title": "B", "type": "integer"}, + "C": {"example": False, "title": "C", "type": "boolean"}, + "D": {"example": None, "title": "D", "type": "null"}, + }, + "required": ["B", "C", "D"], + "title": "prototype", + "type": "object", + }, + ) + ], +) +def test_falsy_prototypes(prototype_data, expected): + v = VetiverModel( + model=model, + prototype_data=prototype_data, + model_name="model", + versioned=None, + description=None, + metadata=None, + ) + + assert isinstance(v.prototype.construct(), pydantic.BaseModel) + assert v.prototype.model_json_schema() == expected + + @pytest.mark.parametrize("prototype_data", [MockPrototype(B=4, C=0, D=0), None]) def test_vetiver_model_prototypes(prototype_data): v = VetiverModel( diff --git a/vetiver/utils.py b/vetiver/utils.py index 923dd79..c727cd5 100644 --- a/vetiver/utils.py +++ b/vetiver/utils.py @@ -64,6 +64,8 @@ def serialize_prototype(prototype): serialized_schema = dict() for key, value in schema.items(): - serialized_schema[key] = value.get("example") or value.get("default") + example = value.get("example", None) + default = value.get("default", None) + serialized_schema[key] = example if example is not None else default return json.dumps(serialized_schema)