From 00b28610db1a7fc7c49367a2ef9e8534c4d5db53 Mon Sep 17 00:00:00 2001 From: Jimmy Jia Date: Mon, 19 Nov 2018 10:38:14 -0500 Subject: [PATCH] Inspect violated constraints --- flask_resty/view.py | 46 ++++++++++++++++++++++++++++++--------- tests/test_view_errors.py | 16 +++++++++++++- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/flask_resty/view.py b/flask_resty/view.py index a1089d3..386d8ee 100644 --- a/flask_resty/view.py +++ b/flask_resty/view.py @@ -3,6 +3,7 @@ import flask from flask.views import MethodView from marshmallow import fields +import sqlalchemy as sa from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Load from sqlalchemy.orm.exc import NoResultFound @@ -415,20 +416,45 @@ def commit(self): def resolve_integrity_error(self, error): original_error = error.orig - if ( - hasattr(original_error, 'pgcode') and - original_error.pgcode in ( + if hasattr(original_error, 'pgcode'): + if original_error.pgcode in ( '23502', # not_null_violation '23514', # check_violation - ) - ): - # Using the psycopg2 error code, we can tell that this was not from - # an integrity error that was not a conflict. This means there was - # a schema bug, so we emit an interal server error instead. - return error + ): + # Using the psycopg2 error code, we can tell that this was not + # from an integrity error that was not a conflict. This means + # there was a schema bug, so we emit an internal server error + # instead. + return error + else: + diag = original_error.diag + + schema_name = diag.schema_name + table_name = diag.table_name + constraint_name = diag.constraint_name + + insp = sa.inspect(self.session.bind) + unique_constraints = insp.get_unique_constraints( + table_name, + schema=schema_name, + ) + + for unique_constraint in unique_constraints: + if unique_constraint['name'] == constraint_name: + column_names = unique_constraint['column_names'] + break + else: + column_names = () + else: + column_names = () flask.current_app.logger.exception("handled integrity error") - return ApiError(409, {'code': 'invalid_data.conflict'}) + + error = {'code': 'invalid_data.conflict'} + if column_names: + error['source'] = {'pointer': '/data/{}'.format(column_names[-1])} + + return ApiError(409, error) def set_item_response_meta(self, item): super(ModelView, self).set_item_response_meta(item) diff --git a/tests/test_view_errors.py b/tests/test_view_errors.py index 1309045..9380c28 100644 --- a/tests/test_view_errors.py +++ b/tests/test_view_errors.py @@ -230,10 +230,24 @@ def test_integrity_error_conflict(client, path): }]) +@pytest.mark.parametrize('path', ('/widgets', '/widgets_flush')) +def test_integrity_error_conflict_with_source(db, client, path): + if db.engine.driver != 'psycopg2': + pytest.xfail("IntegrityError source detection requires psycopg2") + + response = client.post(path, data={ + 'name': "Foo", + }) + assert_response(response, 409, [{ + 'code': 'invalid_data.conflict', + 'source': {'pointer': '/data/name'}, + }]) + + @pytest.mark.parametrize('path', ('/widgets', '/widgets_flush')) def test_integrity_error_uncaught(db, app, client, path): if db.engine.driver != 'psycopg2': - pytest.xfail("IntegrityError cause detection only works with psycopg2") + pytest.xfail("IntegrityError cause detection requires psycopg2") app.testing = False