diff --git a/examples/sp.py b/examples/sp.py index 6409321..d344a06 100755 --- a/examples/sp.py +++ b/examples/sp.py @@ -5,16 +5,9 @@ from tests.idp.base import CERTIFICATE as IDP_CERTIFICATE from tests.sp.base import CERTIFICATE, PRIVATE_KEY - -class ExampleServiceProvider(ServiceProvider): - def get_logout_return_url(self): - return url_for('index', _external=True) - - def get_default_login_return_url(self): - return url_for('index', _external=True) - - -sp = ExampleServiceProvider() +sp = ServiceProvider() +sp.default_login_return_endpoint = 'index' +sp.logout_return_endpoint = 'index' app = Flask(__name__) app.debug = True diff --git a/flask_saml2/sp/sp.py b/flask_saml2/sp/sp.py index 7535958..f57b545 100644 --- a/flask_saml2/sp/sp.py +++ b/flask_saml2/sp/sp.py @@ -34,6 +34,21 @@ class ServiceProvider: #: The name of the blueprint to generate. blueprint_name = 'flask_saml2_sp' + #: Set this to http or https + scheme = 'http' + + #: Set these to your desired endpoints + logout_return_endpoint = None + default_login_return_endpoint = None + acs_redirect_endpoint = None + + """ + Set this value to override the default metadata return value + of :meth: `get_sp_entity_id`. By setting this, you can return + only the entity_id value, rather than the url to the full metadata xml. + """ + entity_id = None + def login_successful( self, auth_data: AuthData, @@ -83,7 +98,10 @@ def get_sp_entity_id(self) -> str: See :func:`get_metadata_url`. """ - return self.get_metadata_url() + if self.entity_id is None: + return self.get_metadata_url() + else: + return self.entity_id def get_sp_certificate(self) -> Optional[X509]: """Get the public certificate for this SP.""" @@ -156,6 +174,8 @@ def get_metadata_url(self) -> str: def get_default_login_return_url(self) -> Optional[str]: """The default URL to redirect users to once the have logged in. """ + if self.default_login_return_endpoint is not None: + return url_for(self.default_login_return_endpoint, _external=True) return None def get_login_return_url(self) -> Optional[str]: @@ -177,6 +197,8 @@ def get_login_return_url(self) -> Optional[str]: def get_logout_return_url(self) -> Optional[str]: """The URL to redirect users to once they have logged out. """ + if self.logout_return_endpoint is not None: + return url_for(self.logout_return_endpoint, _external=True) return None def is_valid_redirect_url(self, url: str) -> str: @@ -306,22 +328,35 @@ def get_metadata_context(self) -> dict: 'contacts': [], } - def create_blueprint(self) -> Blueprint: + def get_scheme(self) -> str: + return self.scheme + + def get_acs_redirect_endpoint(self) -> str: + return self.acs_redirect_endpoint + + # With acs_redirect_url, you can set the url that the Access Consumer Service redirects to upon successful login + # This is unnecessary if you expect a "relay_state" parameter in the SAML request to the ACS + def create_blueprint(self, login_endpoint='/login/', login_idp_endpoint='/login/idp/', + logout_endpoint='/logout/', acs_endpoint='/acs/', sls_endpoint='/sls/', + metadata_endpoint='/metadata.xml', scheme='http') -> Blueprint: + """Create a Flask :class:`flask.Blueprint` for this Service Provider. """ + self.scheme = scheme + idp_bp = Blueprint(self.blueprint_name, 'flask_saml2.sp', template_folder='templates') - idp_bp.add_url_rule('/login/', view_func=Login.as_view( + idp_bp.add_url_rule(login_endpoint, view_func=Login.as_view( 'login', sp=self)) - idp_bp.add_url_rule('/login/idp/', view_func=LoginIdP.as_view( + idp_bp.add_url_rule(login_idp_endpoint, view_func=LoginIdP.as_view( 'login_idp', sp=self)) - idp_bp.add_url_rule('/logout/', view_func=Logout.as_view( + idp_bp.add_url_rule(logout_endpoint, view_func=Logout.as_view( 'logout', sp=self)) - idp_bp.add_url_rule('/acs/', view_func=AssertionConsumer.as_view( + idp_bp.add_url_rule(acs_endpoint, view_func=AssertionConsumer.as_view( 'acs', sp=self)) - idp_bp.add_url_rule('/sls/', view_func=SingleLogout.as_view( + idp_bp.add_url_rule(sls_endpoint, view_func=SingleLogout.as_view( 'sls', sp=self)) - idp_bp.add_url_rule('/metadata.xml', view_func=Metadata.as_view( + idp_bp.add_url_rule(metadata_endpoint, view_func=Metadata.as_view( 'metadata', sp=self)) idp_bp.register_error_handler(CannotHandleAssertion, CannotHandleAssertionView.as_view( diff --git a/flask_saml2/sp/views.py b/flask_saml2/sp/views.py index 1c116b6..622a310 100644 --- a/flask_saml2/sp/views.py +++ b/flask_saml2/sp/views.py @@ -28,7 +28,8 @@ def get(self): handler = self.sp.get_default_idp_handler() login_next = self.sp.get_login_return_url() if handler: - return redirect(url_for('.login_idp', entity_id=handler.entity_id, next=login_next)) + return redirect(url_for('.login_idp', entity_id=handler.entity_id, next=login_next, + _scheme=self.sp.get_scheme(), _external=True)) return self.sp.render_template( 'flask_saml2_sp/choose_idp.html', login_next=login_next, @@ -79,7 +80,10 @@ def do_logout(self, handler): class AssertionConsumer(SAML2View): def post(self): saml_request = request.form['SAMLResponse'] - relay_state = request.form['RelayState'] + if self.sp.get_acs_redirect_endpoint() is None: + relay_state = request.form['RelayState'] + else: + relay_state = self.sp.make_absolute_url(self.sp.get_acs_redirect_endpoint()) for handler in self.sp.get_idp_handlers(): try: