diff --git a/ardere/aws.py b/ardere/aws.py index e72ec20..38a59d5 100644 --- a/ardere/aws.py +++ b/ardere/aws.py @@ -636,6 +636,27 @@ def all_services_ready(self, steps): results = executer.map(self.service_ready, steps) return all(results) + def service_done(self, step): + # type: (Dict[str, Any]) -> bool + """Query a service to return whether its fully drained and back to + INACTIVE""" + service_name = step["name"] + response = self._ecs_client.describe_services( + cluster=self._ecs_name, + services=[service_name] + ) + + service = response["services"][0] + return service["status"] == "INACTIVE" + + def all_services_done(self, steps): + # type: (List[Dict[str, Any]]) -> bool + """Queries all service ARN's in the plan to see if they're fully + DRAINED and now INACTIVE""" + with ThreadPoolExecutor(max_workers=8) as executer: + results = executer.map(self.service_done, steps) + return all(results) + def stop_finished_service(self, start_time, step): # type: (start_time, Dict[str, Any]) -> None """Stops a service if it needs to shutdown""" diff --git a/ardere/step_functions.py b/ardere/step_functions.py index 445231e..94a2730 100644 --- a/ardere/step_functions.py +++ b/ardere/step_functions.py @@ -345,9 +345,7 @@ def check_for_cluster_done(self): return self.event def cleanup_cluster(self): - """Shutdown all ECS services and deregister all task definitions - - """ + """Shutdown all ECS services and deregister all task definitions""" self.ecs.shutdown_plan(self.event["steps"]) # Attempt to remove the S3 object @@ -363,34 +361,8 @@ def cleanup_cluster(self): return self.event def check_drained(self): - """Ensure that all services are shut down before allowing restart - - """ - client = self.boto.client('ecs') - actives = client.list_container_instances( - cluster=self.event["ecs_name"], - maxResults=1, - status="ACTIVE", - ).get('containerInstanceArns', []) - # filter out metric servers - if self.event["metrics_options"]["enabled"]: - metrics = self.ecs.locate_metrics_service() - if metrics: - metrics_arn = metrics.get("serviceArn") - try: - actives.remove(metrics_arn) - except ValueError: - pass - if len(actives): - raise UndrainedInstancesException( - "Still active: {}.".format(actives)) - draining = len( - client.list_container_instances( - cluster=self.event["ecs_name"], - maxResults=1, - status="DRAINING", - ).get('containerInstanceArns', [])) - if draining: - raise UndrainedInstancesException( - "Still draining: {}.".format(draining)) - return self.event + """Ensure that all services are shut down before allowing restart""" + if self.ecs.all_services_done(self.event["steps"]): + return self.event + else: + raise UndrainedInstancesException("Services still draining") diff --git a/tests/test_aws.py b/tests/test_aws.py index 6c5d2ff..d9b1bb2 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -244,6 +244,38 @@ def test_all_services_ready(self): ecs.all_services_ready(ecs._plan["steps"]) ecs.service_ready.assert_called() + def test_service_done_true(self): + ecs = self._make_FUT() + step = ecs._plan["steps"][0] + + ecs._ecs_client.describe_services.return_value = { + "services": [{ + "status": "INACTIVE" + }] + } + + result = ecs.service_done(step) + eq_(result, True) + + def test_service_not_known(self): + ecs = self._make_FUT() + step = ecs._plan["steps"][0] + + ecs._ecs_client.describe_services.return_value = { + "services": [{ + "status": "DRAINING" + }] + } + + result = ecs.service_done(step) + eq_(result, False) + + def test_all_services_done(self): + ecs = self._make_FUT() + ecs.service_done = mock.Mock() + ecs.all_services_done(ecs._plan["steps"]) + ecs.service_done.assert_called() + def test_stop_finished_service_stopped(self): ecs = self._make_FUT() ecs._ecs_client.update_service = mock.Mock() diff --git a/tests/test_step_functions.py b/tests/test_step_functions.py index 669fc91..0f51cce 100644 --- a/tests/test_step_functions.py +++ b/tests/test_step_functions.py @@ -293,62 +293,14 @@ def test_cleanup_cluster_error(self): self.runner.cleanup_cluster() mock_s3.Object.assert_called() - def test_drain_check_active(self): - from ardere.exceptions import UndrainedInstancesException - - mock_client = mock.Mock() - mock_client.list_container_instances.return_value = { - 'containerInstanceArns': [ - 'Some-Arn-01234567890', - 'Metric-Arn-01234567890', - ], - "nextToken": "token-8675309" - } - self.mock_boto.client.return_value = mock_client - assert_raises(UndrainedInstancesException, - self.runner.check_drained) - def test_drain_check_draining(self): from ardere.exceptions import UndrainedInstancesException - - mock_client = mock.Mock() - mock_client.list_container_instances.side_effect = [ - {}, - { - 'containerInstanceArns': [ - 'Some-Arn-01234567890', - ], - "nextToken": "token-8675309" - } - ] - self.mock_boto.client.return_value = mock_client + self.mock_ecs.all_services_done.return_value = True + self.runner.check_drained() + self.mock_ecs.all_services_done.return_value = False assert_raises(UndrainedInstancesException, self.runner.check_drained) - def test_drain_check(self): - # Include a "metrics" instance to show that we ignore it. - self.plan["metrics_options"] = dict(enabled=True) - self.mock_ecs.locate_metrics_service.return_value = { - "deployments": [{ - "desiredCount": 1, - "runningCount": 1 - }], - "serviceArn": "Metric-Arn-01234567890" - } - - mock_client = mock.Mock() - mock_client.list_container_instances.side_effect = [ - { # Actives - 'containerInstanceArns': [ - 'Metric-Arn-01234567890', - ], - "nextToken": "token-8675309" - }, - {} # Draining - ] - self.mock_boto.client.return_value = mock_client - self.runner.check_drained() - class TestValidation(unittest.TestCase): def _make_FUT(self):