AIP-72: Add "update TI state" endpoint for Execution API (apache#43602)
Part of apache#43586

This PR adds a new endpoint, `/execution/task_instance/{task_instance_id}/state`, that allows the worker to update the state of a TI.

Some of the interesting changes / TILs were:

(hat tip to @ashb for this)

To streamline the data exchange between workers and the Task Execution API, this PR adds minified schemas for Task Instance updates, i.e. schemas that contain only the fields necessary for a specific state transition. This reduces payload size and the number of validations. Since our TaskInstance model is huge, it also keeps the code clean by focusing on only the fields that matter for this case.

The endpoint added in this PR also leverages Pydantic’s [discriminated unions](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions) to handle varying payload structures for each target state. This allows a single endpoint to receive different payloads (with different validations). For example:

- `TIEnterRunningPayload`: Requires fields such as hostname, unixname, pid, and start_date to mark a task as RUNNING.
- `TITerminalStatePayload`: Supports the terminal states SUCCESS, FAILED, and SKIPPED, and requires an end_date.
- `TITargetStatePayload`: Allows for other non-terminal, non-running states that a task may transition to.

This way we avoid invalid payloads. For example, sending a start_date when a task is marked as SUCCESS doesn't make sense, and it might be an error from the client!
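To make the idea concrete, here is a minimal, self-contained sketch of a callable-discriminated union in Pydantic v2. The model names (`EnterRunning`, `TerminalState`, `OtherState`) and the `pick_schema` function are simplified stand-ins invented for illustration, not the schemas added in this PR.

```python
from typing import Annotated, Literal, Union

from pydantic import AwareDatetime, BaseModel, Discriminator, Tag, TypeAdapter, ValidationError


# Hypothetical, trimmed-down payload models (not the ones from this PR).
class EnterRunning(BaseModel):
    state: Literal["running"]
    hostname: str
    pid: int
    start_date: AwareDatetime


class TerminalState(BaseModel):
    state: Literal["success", "failed", "skipped"]
    end_date: AwareDatetime


class OtherState(BaseModel):
    state: str


def pick_schema(v) -> str:
    """Return the tag of the model that should validate this payload."""
    state = v.get("state") if isinstance(v, dict) else getattr(v, "state", None)
    if state == "running":
        return "running"
    if state in ("success", "failed", "skipped"):
        return "_terminal_"
    return "_other_"


StateUpdate = Annotated[
    Union[
        Annotated[EnterRunning, Tag("running")],
        Annotated[TerminalState, Tag("_terminal_")],
        Annotated[OtherState, Tag("_other_")],
    ],
    Discriminator(pick_schema),
]

adapter = TypeAdapter(StateUpdate)

# A terminal payload is routed to TerminalState; the stray start_date is not part of that model.
ok = adapter.validate_python(
    {"state": "success", "end_date": "2024-11-05T12:00:00Z", "start_date": "2024-11-05T11:00:00Z"}
)

# A "running" payload without pid and start_date fails validation.
try:
    adapter.validate_python({"state": "running", "hostname": "worker-1"})
    rejected = False
except ValidationError:
    rejected = True
```

Note that with Pydantic's default model configuration an unknown field such as the stray `start_date` above is ignored rather than rejected, so it simply never reaches the dumped data. Rejecting it outright would need `extra="forbid"` on the model, which this sketch does not set.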

![Nov-04-2024 20-00-26](https://github.com/user-attachments/assets/07c1a197-0238-4c1a-9783-f23dd74a8d3e)

`fastapi` allows importing a handy `status` module from starlette, whose constant names carry both the status code and its reason. Reference: https://fastapi.tiangolo.com/reference/status/
Example:

`status.HTTP_204_NO_CONTENT` and `status.HTTP_409_CONFLICT` explain a lot more than a bare `204`, which doesn't tell the reader much. I plan to change the current integers in our public API to these constants in the coming days.

For now, I have assumed that we/the user don't care about `end_date` for the `REMOVED` & `UPSTREAM_FAILED` states, since they should be handled by the scheduler and shouldn't even show up on the worker. For the `SKIPPED` state there are 2 scenarios: 1) a user runs the task and raises an `AirflowSkipException`; 2) the task is skipped by the scheduler itself. For (1) we could set an end date, but (2) doesn't have one.

- [ ] Pass a [RFC 9457](https://datatracker.ietf.org/doc/html/rfc9457) compliant error message in "detail" field of `HTTPException` to provide more information about the error
- [ ] Add a separate heartbeat endpoint to track the TI’s active state.
- [ ] Replace handling of `SQLAlchemyError` with FastAPI's [Custom Exception handling](https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers) across the Execution API endpoints. That way we don't need duplicate code across multiple endpoints.
- [ ] Replace `None` state on TaskInstance with a `Created` state. ([link](https://github.com/orgs/apache/projects/405/views/1?pane=issue&itemId=85900878))
- [ ] Remove redundant code that also sets task type once we remove DB access from the worker. This assumes that the Webserver or the new FastAPI endpoints don't use this endpoint.
kaxil authored Nov 5, 2024
1 parent 9ede38a commit 3939d13
Showing 10 changed files with 540 additions and 1 deletion.
25 changes: 25 additions & 0 deletions airflow/api_fastapi/common/types.py
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pydantic import AfterValidator, AwareDatetime
from typing_extensions import Annotated

from airflow.utils import timezone

UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))]
"""UtcDateTime is a timezone-aware datetime that is normalized to UTC after validation."""
1 change: 1 addition & 0 deletions airflow/api_fastapi/execution_api/app.py
@@ -24,6 +24,7 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
    """Create FastAPI app for task execution API."""
    from airflow.api_fastapi.execution_api.routes import execution_api_router

    # TODO: Add versioning to the API
    task_exec_api_app = FastAPI(
        title="Airflow Task Execution API",
        description="The private Airflow Task Execution API.",
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/__init__.py
@@ -18,6 +18,8 @@

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes.health import health_router
from airflow.api_fastapi.execution_api.routes.task_instance import ti_router

execution_api_router = AirflowRouter()
execution_api_router.include_router(health_router)
execution_api_router.include_router(ti_router)
2 changes: 1 addition & 1 deletion airflow/api_fastapi/execution_api/routes/health.py
@@ -19,7 +19,7 @@

from airflow.api_fastapi.common.router import AirflowRouter

-health_router = AirflowRouter(tags=["Task SDK"])
+health_router = AirflowRouter(tags=["Health"])


@health_router.get("/health")
131 changes: 131 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instance.py
@@ -0,0 +1,131 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from uuid import UUID

from fastapi import Body, Depends, HTTPException, status
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.sql import select
from typing_extensions import Annotated

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import schemas
from airflow.models.taskinstance import TaskInstance as TI
from airflow.utils.state import State

# TODO: Add dependency on JWT token
ti_router = AirflowRouter(
prefix="/task_instance",
tags=["Task Instance"],
)


log = logging.getLogger(__name__)


@ti_router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add Operation ID to control the function name in the OpenAPI spec
# TODO: Do we need to use create_openapi_http_exception_doc here?
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
},
)
async def ti_update_state(
task_instance_id: UUID,
ti_patch_payload: Annotated[schemas.TIStateUpdate, Body()],
session: Annotated[Session, Depends(get_session)],
):
"""
Update the state of a TaskInstance.
Not all state transitions are valid, and transitioning to some states required extra information to be
passed along. (Check our the schemas for details, the rendered docs might not reflect this accurately)
"""
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
try:
(previous_state,) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": "Task Instance not found",
},
)

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude_unset=True)

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, schemas.TIEnterRunningPayload):
if previous_state != State.QUEUED:
log.warning(
"Can not start Task Instance ('%s') in invalid state: %s",
ti_id_str,
previous_state,
)

# TODO: Pass a RFC 9457 compliant error message in "detail" field
# https://datatracker.ietf.org/doc/html/rfc9457
# to provide more information about the error
# FastAPI will automatically convert this to a JSON response
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": previous_state,
},
)
log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname)
# Ensure there is no end date set.
query = query.values(
end_date=None,
hostname=ti_patch_payload.hostname,
unixname=ti_patch_payload.unixname,
pid=ti_patch_payload.pid,
state=State.RUNNING,
)
elif isinstance(ti_patch_payload, schemas.TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)

# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)
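For orientation, here is a hedged sketch of what a worker-side call to this route could look like, using only the standard library. The base URL, the `/execution` mount prefix as used here, and the field values are assumptions for illustration; the request is built but deliberately not sent.

```python
import json
from datetime import datetime, timezone
from urllib.request import Request
from uuid import uuid4

# Hypothetical location of the Execution API; adjust to the real deployment.
BASE_URL = "http://localhost:8080/execution"
ti_id = uuid4()

# Body matching the fields TIEnterRunningPayload asks for.
payload = {
    "state": "running",
    "hostname": "worker-1",
    "unixname": "airflow",
    "pid": 12345,
    "start_date": datetime.now(timezone.utc).isoformat(),
}

request = Request(
    url=f"{BASE_URL}/task_instance/{ti_id}/state",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
    method="PATCH",
)
# urllib.request.urlopen(request) would send it. Per the route definition above,
# 204 means the update was applied, 404 an unknown TI, 409 an invalid transition.
```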
114 changes: 114 additions & 0 deletions airflow/api_fastapi/execution_api/schemas.py
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Annotated, Literal, Union

from pydantic import (
    BaseModel,
    ConfigDict,
    Discriminator,
    Field,
    Tag,
    WithJsonSchema,
)

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.utils.state import State, TaskInstanceState as TIState


class TIEnterRunningPayload(BaseModel):
    """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields."""

    model_config = ConfigDict(from_attributes=True)

    state: Annotated[
        Literal[TIState.RUNNING],
        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
        WithJsonSchema({"enum": [TIState.RUNNING], "default": TIState.RUNNING}),
    ]
    hostname: str
    """Hostname where this task has started"""
    unixname: str
    """Local username of the process where this task has started"""
    pid: int
    """Process Identifier on `hostname`"""
    start_date: UtcDateTime
    """When the task started executing"""


class TITerminalStatePayload(BaseModel):
    """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED)."""

    state: Annotated[
        Literal[TIState.SUCCESS, TIState.FAILED, TIState.SKIPPED],
        Field(title="TerminalState"),
        WithJsonSchema({"enum": list(State.ran_and_finished_states)}),
    ]

    end_date: UtcDateTime
    """When the task completed executing"""


class TITargetStatePayload(BaseModel):
    """Schema for updating TaskInstance to a target state, excluding terminal and running states."""

    state: Annotated[
        TIState,
        # For the OpenAPI schema generation,
        # make sure we do not include RUNNING or the terminal states as valid states here
        WithJsonSchema(
            {
                "enum": [
                    state for state in TIState if state not in (State.ran_and_finished_states | {State.RUNNING})
                ]
            }
        ),
    ]


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
    """
    Determine the discriminator key for TaskInstance state transitions.
    This function serves as a discriminator for the TIStateUpdate union schema,
    categorizing the payload based on the ``state`` attribute in the input data.
    It returns a key that directs FastAPI to the appropriate subclass (schema)
    based on the requested state.
    """
    if isinstance(v, dict):
        state = v.get("state")
    else:
        state = getattr(v, "state", None)
    if state == TIState.RUNNING:
        return str(state)
    elif state in State.ran_and_finished_states:
        return "_terminal_"
    return "_other_"


# It is called "_terminal_" to avoid future conflicts if we added an actual state named "terminal"
# and "_other_" is a catch-all for all other states that are not covered by the other schemas.
TIStateUpdate = Annotated[
    Union[
        Annotated[TIEnterRunningPayload, Tag("running")],
        Annotated[TITerminalStatePayload, Tag("_terminal_")],
        Annotated[TITargetStatePayload, Tag("_other_")],
    ],
    Discriminator(ti_state_discriminator),
]
36 changes: 36 additions & 0 deletions airflow/models/taskinstance.py
@@ -54,6 +54,7 @@
    UniqueConstraint,
    and_,
    delete,
    extract,
    false,
    func,
    inspect,
@@ -151,7 +152,9 @@
    from pathlib import PurePath
    from types import TracebackType

    from sqlalchemy.engine import Connection as SAConnection, Engine
    from sqlalchemy.orm.session import Session
    from sqlalchemy.sql import Update
    from sqlalchemy.sql.elements import BooleanClauseList
    from sqlalchemy.sql.expression import ColumnOperators

@@ -3843,6 +3846,39 @@ def clear_db_references(self, session: Session):
                )
            )

    @classmethod
    def duration_expression_update(
        cls, end_date: datetime, query: Update, bind: Engine | SAConnection
    ) -> Update:
        """Return a SQL expression for calculating the duration of this TI, based on the start and end date columns."""
        # TODO: Compare it with self._set_duration method

        if bind.dialect.name == "sqlite":
            return query.values(
                {
                    "end_date": end_date,
                    "duration": (func.julianday(end_date) - func.julianday(cls.start_date)) * 86400,
                }
            )
        elif bind.dialect.name == "postgresql":
            return query.values(
                {
                    "end_date": end_date,
                    "duration": extract("EPOCH", end_date - cls.start_date),
                }
            )

        return query.values(
            {
                "end_date": end_date,
                "duration": (
                    func.timestampdiff(text("MICROSECOND"), cls.start_date, end_date)
                    # Turn microseconds into floating point seconds.
                    / 1_000_000
                ),
            }
        )


def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
    """Given two operators, find their innermost common mapped task group."""
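The SQLite branch of `duration_expression_update` relies on `julianday()` returning fractional days, so the difference multiplied by 86400 gives seconds. Below is a small standalone check of that arithmetic with the standard-library `sqlite3` module; the timestamps are made up, and this is only a sketch of the SQL expression, not of the SQLAlchemy query itself.

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# Made-up start and end timestamps that are 30 seconds apart.
start_date = "2024-11-05 12:00:00"
end_date = "2024-11-05 12:00:30"

# julianday() yields days as a float, so multiply the difference by 86400 to get seconds.
(duration,) = conn.execute(
    "SELECT (julianday(?) - julianday(?)) * 86400",
    (end_date, start_date),
).fetchone()
conn.close()
```

The result is a float close to 30 seconds. Because Julian day numbers are large floats, expect small rounding error rather than an exact value.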
9 changes: 9 additions & 0 deletions airflow/utils/state.py
@@ -199,3 +199,12 @@ def color_fg(cls, state):
    A list of states indicating that a task can be adopted or reset by a scheduler job
    if it was queued by another scheduler job that is not running anymore.
    """

    ran_and_finished_states = frozenset(
        [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED]
    )
    """
    A set of states indicating that a task has run and finished. This excludes states like
    removed and upstream_failed. Skipped is included because a user can raise an
    AirflowSkipException in a task and it will be marked as skipped.
    """
27 changes: 27 additions & 0 deletions tests/api_fastapi/execution_api/conftest.py
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import cached_app


@pytest.fixture
def client():
    return TestClient(cached_app(apps="execution"))