diff --git a/tests_e2e/test_cors.py b/tests_e2e/test_cors.py index 67630e67..2d977ccf 100644 --- a/tests_e2e/test_cors.py +++ b/tests_e2e/test_cors.py @@ -5,16 +5,71 @@ def test_cors(urls, auth_cookies): origin_host = "https://something.asf.alaska.edu" url = urls.join(urls.METADATA_FILE_CH) - origin_headers = {"origin": origin_host} + request_headers = {"origin": origin_host} - r = requests.get(url, cookies=auth_cookies, headers=origin_headers, allow_redirects=False) + r = requests.get( + url, + cookies=auth_cookies, + headers=request_headers, + allow_redirects=False, + ) headers = dict(r.headers) assert headers.get("Access-Control-Allow-Origin") == origin_host assert headers.get("Access-Control-Allow-Credentials") == "true" - headers = {"origin": "null"} - r = requests.get(url, cookies=auth_cookies, headers=headers, allow_redirects=False) + +def test_cors_origin_null(urls, auth_cookies): + url = urls.join(urls.METADATA_FILE_CH) + request_headers = {"origin": "null"} + r = requests.get( + url, + cookies=auth_cookies, + headers=request_headers, + allow_redirects=False, + ) + headers = dict(r.headers) + + assert headers.get("Access-Control-Allow-Origin") == "null" + + +def test_cors_preflight_options(urls, auth_cookies): + origin_host = "https://something.asf.alaska.edu" + + url = urls.join(urls.METADATA_FILE_CH) + request_headers = { + "Origin": origin_host, + "Access-Control-Request-Method": "GET" + } + + r = requests.options( + url, + cookies=auth_cookies, + headers=request_headers, + allow_redirects=False, + ) + headers = dict(r.headers) + + assert r.status_code == 204 + assert headers.get("Access-Control-Allow-Origin") == origin_host + assert "GET" in headers.get("Access-Control-Allow-Methods") + + +def test_cors_preflight_options_origin_null(urls, auth_cookies): + url = urls.join(urls.METADATA_FILE_CH) + request_headers = { + "Origin": "null", + "Access-Control-Request-Method": "GET" + } + + r = requests.options( + url, + cookies=auth_cookies, + headers=request_headers, + allow_redirects=False, + ) headers = dict(r.headers) + assert r.status_code == 204 assert headers.get("Access-Control-Allow-Origin") == "null" + assert "GET" in headers.get("Access-Control-Allow-Methods") diff --git a/thin_egress_app/app.py b/thin_egress_app/app.py index ee8216a7..88642e1b 100644 --- a/thin_egress_app/app.py +++ b/thin_egress_app/app.py @@ -459,17 +459,33 @@ def add_cors_headers(headers): # send CORS headers if we're configured to use them origin_header = app.current_request.headers.get("origin") - if origin_header is not None: + if is_cors_allowed(): + headers["Access-Control-Allow-Origin"] = origin_header + headers["Access-Control-Allow-Credentials"] = "true" + else: cors_origin = os.getenv("CORS_ORIGIN") - if cors_origin and (origin_header.endswith(cors_origin) or origin_header.lower() == "null"): - headers["Access-Control-Allow-Origin"] = origin_header - headers["Access-Control-Allow-Credentials"] = "true" - else: - log.warning( - "Origin %s is not an approved CORS host: %s", - origin_header, - cors_origin, - ) + log.warning( + "Origin %s is not an approved CORS host: %s", + origin_header, + cors_origin, + ) + + +def is_cors_allowed(): + assert app.current_request is not None + + # send CORS headers if we're configured to use them + origin_header = app.current_request.headers.get("origin") + cors_origin = os.getenv("CORS_ORIGIN") + + return bool( + origin_header + and cors_origin + and ( + origin_header.endswith(cors_origin) + or origin_header.lower() == "null" + ) + ) @with_trace() @@ -942,6 +958,35 @@ def try_download_head(bucket, filename): return make_redirect(presigned_url, {}, 303) +@app.route("/{proxy+}", methods=["OPTIONS"]) +@with_trace(context={}) +def dynamic_url_options(): + allowed_methods = [ + "GET", + "HEAD", + "OPTIONS", + ] + request_method = app.current_request.headers.get( + "Access-Control-Request-Method", + "", + ).strip() + if is_cors_allowed() and request_method in allowed_methods: + headers = { + "Access-Control-Allow-Methods": ", ".join(allowed_methods) + } + add_cors_headers(headers) + return Response( + body="", + headers=headers, + status_code=204, + ) + + return Response( + body="Method Not Allowed", + status_code=405, + ) + + # Attempt to validate HEAD request @app.route("/{proxy+}", methods=["HEAD"]) @with_trace(context={})