Skip to content

Commit

Permalink
Split next_dagruns_to_examine function into two (apache#42386)
Browse files Browse the repository at this point in the history
The behavior is different enough to merit two separate functions. In fact, I noticed that we were actually using a bad index hint for the QUEUED case. This becomes more apparent with the introduction of backfill handling into the scheduler, which is forthcoming.
  • Loading branch information
dstandish authored Sep 24, 2024
1 parent 4c8c72f commit 9ec8753
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 35 deletions.
10 changes: 3 additions & 7 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,9 +1218,10 @@ def _do_scheduling(self, session: Session) -> int:

self._start_queued_dagruns(session)
guard.commit()
dag_runs = self._get_next_dagruns_to_examine(DagRunState.RUNNING, session)

# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun
dag_runs = DagRun.get_running_dag_runs_to_examine(session=session)

callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

Expand Down Expand Up @@ -1274,11 +1275,6 @@ def _do_scheduling(self, session: Session) -> int:

return num_queued_tis

@retry_db_transaction
def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session) -> Query:
"""Get Next DagRuns to Examine with retries."""
return DagRun.next_dagruns_to_examine(state, session)

@retry_db_transaction
def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
"""Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
Expand Down Expand Up @@ -1512,7 +1508,7 @@ def _should_update_dag_next_dagruns(
def _start_queued_dagruns(self, session: Session) -> None:
"""Find DagRuns in queued state and decide moving them to running state."""
# added all() to save runtime, otherwise query is executed more than once
dag_runs: Collection[DagRun] = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session).all()
dag_runs: Collection[DagRun] = DagRun.get_queued_dag_runs_to_set_running(session).all()

active_runs_of_dags = Counter(
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
Expand Down
84 changes: 59 additions & 25 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from airflow.utils.dates import datetime_to_nano
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
Expand Down Expand Up @@ -388,55 +389,88 @@ def active_runs_of_dags(
return dict(iter(session.execute(query)))

@classmethod
def next_dagruns_to_examine(
cls,
state: DagRunState,
session: Session,
max_number: int | None = None,
) -> Query:
@retry_db_transaction
def get_running_dag_runs_to_examine(cls, session: Session) -> Query:
"""
Return the next DagRuns that the scheduler should attempt to schedule.
This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
the transaction is committed it will be unlocked.
:meta private:
"""
from airflow.models.dag import DagModel

if max_number is None:
max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

# TODO: Bake this query, it is run _A lot_
query = (
select(cls)
.with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
.where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
.where(cls.state == DagRunState.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
.join(DagModel, DagModel.dag_id == cls.dag_id)
.where(DagModel.is_paused == false(), DagModel.is_active == true())
.order_by(
nulls_first(cls.last_scheduling_decision, session=session),
cls.execution_date,
)
)
if state == DagRunState.QUEUED:
# For dag runs in the queued state, we check if they have reached the max_active_runs limit
# and if so we drop them
running_drs = (
select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
.where(DagRun.state == DagRunState.RUNNING)
.group_by(DagRun.dag_id)
.subquery()

if not settings.ALLOW_FUTURE_EXEC_DATES:
query = query.where(DagRun.execution_date <= func.now())

return session.scalars(
with_row_locks(
query.limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE),
of=cls,
session=session,
skip_locked=True,
)
query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
)

@classmethod
@retry_db_transaction
def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query:
"""
Return the next queued DagRuns that the scheduler should attempt to schedule.
This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
the transaction is committed it will be unlocked.
:meta private:
"""
from airflow.models.dag import DagModel

# For dag runs in the queued state, we check if they have reached the max_active_runs limit
# and if so we drop them
running_drs = (
select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
.where(DagRun.state == DagRunState.RUNNING)
.group_by(DagRun.dag_id)
.subquery()
)
query = (
select(cls)
.where(cls.state == DagRunState.QUEUED, cls.run_type != DagRunType.BACKFILL_JOB)
.join(DagModel, DagModel.dag_id == cls.dag_id)
.where(DagModel.is_paused == false(), DagModel.is_active == true())
.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id)
.where(func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs)
.order_by(
nulls_first(cls.last_scheduling_decision, session=session),
cls.execution_date,
)
query = query.order_by(
nulls_first(cls.last_scheduling_decision, session=session),
cls.execution_date,
)

if not settings.ALLOW_FUTURE_EXEC_DATES:
query = query.where(DagRun.execution_date <= func.now())

return session.scalars(
with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
with_row_locks(
query.limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE),
of=cls,
session=session,
skip_locked=True,
)
)

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6036,7 +6036,9 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d
self.job_runner.processor_agent = mock_agent

with assert_queries_count(expected_query_count, margin=15):
with mock.patch.object(DagRun, "next_dagruns_to_examine") as mock_dagruns:
with mock.patch.object(
DagRun, DagRun.get_running_dag_runs_to_examine.__name__
) as mock_dagruns:
query = MagicMock()
query.all.return_value = dagruns
mock_dagruns.return_value = query
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,18 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state):
**triggered_by_kwargs,
)

runs = DagRun.next_dagruns_to_examine(state, session).all()
if state == DagRunState.RUNNING:
func = DagRun.get_running_dag_runs_to_examine
else:
func = DagRun.get_queued_dag_runs_to_set_running
runs = func(session).all()

assert runs == [dr]

orm_dag.is_paused = True
session.flush()

runs = DagRun.next_dagruns_to_examine(state, session).all()
runs = func(session).all()
assert runs == []

@mock.patch.object(Stats, "timing")
Expand Down

0 comments on commit 9ec8753

Please sign in to comment.