diff --git a/providers/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/src/airflow/providers/amazon/aws/operators/ecs.py index 51dde9f75a30..be4da8d6f559 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/ecs.py +++ b/providers/src/airflow/providers/amazon/aws/operators/ecs.py @@ -598,10 +598,10 @@ def _start_task(self): if self.capacity_provider_strategy: run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy - if self.volume_configurations is not None: - run_opts["volumeConfigurations"] = self.volume_configurations elif self.launch_type: run_opts["launchType"] = self.launch_type + if self.volume_configurations is not None: + run_opts["volumeConfigurations"] = self.volume_configurations if self.platform_version is not None: run_opts["platformVersion"] = self.platform_version if self.group is not None: diff --git a/providers/tests/amazon/aws/operators/test_ecs.py b/providers/tests/amazon/aws/operators/test_ecs.py index ed900acb7364..2129228c0a5e 100644 --- a/providers/tests/amazon/aws/operators/test_ecs.py +++ b/providers/tests/amazon/aws/operators/test_ecs.py @@ -192,13 +192,14 @@ def test_template_fields_overrides(self): ) @pytest.mark.parametrize( - "launch_type, capacity_provider_strategy,platform_version,tags,expected_args", + "launch_type, capacity_provider_strategy,platform_version,tags,volume_configurations,expected_args", [ [ "EC2", None, None, None, + None, {"launchType": "EC2"}, ], [ @@ -206,6 +207,7 @@ def test_template_fields_overrides(self): None, None, None, + None, {"launchType": "EXTERNAL"}, ], [ @@ -213,6 +215,7 @@ def test_template_fields_overrides(self): None, "LATEST", None, + None, {"launchType": "FARGATE", "platformVersion": "LATEST"}, ], [ @@ -220,6 +223,7 @@ def test_template_fields_overrides(self): None, None, {"testTagKey": "testTagValue"}, + None, {"launchType": "EC2", "tags": [{"key": "testTagKey", "value": "testTagValue"}]}, ], [ @@ -227,6 +231,7 @@ def test_template_fields_overrides(self): None, None, {"testTagKey": "testTagValue"}, + None, {"tags": [{"key": "testTagKey", "value": "testTagValue"}]}, ], [ @@ -234,6 +239,7 @@ def test_template_fields_overrides(self): {"capacityProvider": "FARGATE_SPOT"}, "LATEST", None, + None, { "capacityProviderStrategy": {"capacityProvider": "FARGATE_SPOT"}, "platformVersion": "LATEST", @@ -244,6 +250,7 @@ def test_template_fields_overrides(self): {"capacityProvider": "FARGATE_SPOT", "weight": 123, "base": 123}, "LATEST", None, + None, { "capacityProviderStrategy": { "capacityProvider": "FARGATE_SPOT", @@ -258,11 +265,70 @@ def test_template_fields_overrides(self): {"capacityProvider": "FARGATE_SPOT"}, "LATEST", None, + None, { "capacityProviderStrategy": {"capacityProvider": "FARGATE_SPOT"}, "platformVersion": "LATEST", }, ], + [ + "FARGATE", + None, + None, + None, + [ + { + "name": "ebs-volume", + "managedEBSVolume": { + "volumeType": "gp3", + "sizeInGiB": 10, + }, + "roleArn": "arn:aws:iam:1111222333:role/ecsInfrastructureRole", + } + ], + { + "launchType": "FARGATE", + "volumeConfigurations": [ + { + "name": "ebs-volume", + "managedEBSVolume": { + "volumeType": "gp3", + "sizeInGiB": 10, + }, + "roleArn": "arn:aws:iam:1111222333:role/ecsInfrastructureRole", + } + ], + }, + ], + [ + None, + {"capacityProvider": "FARGATE_SPOT"}, + None, + None, + [ + { + "name": "ebs-volume", + "managedEBSVolume": { + "volumeType": "gp3", + "sizeInGiB": 10, + }, + "roleArn": "arn:aws:iam:1111222333:role/ecsInfrastructureRole", + } + ], + { + "capacityProviderStrategy": {"capacityProvider": "FARGATE_SPOT"}, + "volumeConfigurations": [ + { + "name": "ebs-volume", + "managedEBSVolume": { + "volumeType": "gp3", + "sizeInGiB": 10, + }, + "roleArn": "arn:aws:iam:1111222333:role/ecsInfrastructureRole", + } + ], + }, + ], ], ) @mock.patch.object(EcsRunTaskOperator, "xcom_push") @@ -279,6 +345,7 @@ def test_execute_without_failures( capacity_provider_strategy, platform_version, tags, + volume_configurations, expected_args, ): self.set_up_operator( @@ -286,6 +353,7 @@ def test_execute_without_failures( capacity_provider_strategy=capacity_provider_strategy, platform_version=platform_version, tags=tags, + volume_configurations=volume_configurations, ) client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES