forked from apache/airflow
AIP-72: Add "update TI state" endpoint for Execution API (apache#43602)
Part of apache#43586

This PR adds a new endpoint, `/execution/{task_instance_id}/state`, that allows the worker to update the state of a TI. Some of the interesting changes / TILs were:

(Hat tip to @ashb for this.) To streamline the data exchange between workers and the Task Execution API, this PR adds minified schemas for Task Instance updates, i.e. schemas that contain only the fields necessary for a specific state transition, which reduces payload size and the number of validations. Since our TaskInstance model is huge, this also keeps each schema focused on only the fields that matter for its case.

The endpoint added in this PR also leverages Pydantic’s [discriminated unions](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions) to handle the different payload structure for each target state. This allows a single endpoint to receive different payloads (with different validations). For example:

- `TIEnterRunningPayload`: Requires fields such as hostname, unixname, pid, and start_date to mark a task as RUNNING.
- `TITerminalStatePayload`: Supports terminal states like SUCCESS, FAILED, and SKIPPED.
- `TITargetStatePayload`: Allows for other non-terminal, non-running states that a task may transition to.

This is better because it keeps invalid payloads out: for example, adding a `start_date` when a task is marked as SUCCESS doesn't make sense and might be an error from the client!

![Nov-04-2024 20-00-26](https://github.com/user-attachments/assets/07c1a197-0238-4c1a-9783-f23dd74a8d3e)

`fastapi` allows importing a handy `status` module from starlette whose constant names carry both the status code and the reason. Reference: https://fastapi.tiangolo.com/reference/status/ Example: `status.HTTP_204_NO_CONTENT` and `status.HTTP_409_CONFLICT` explain a lot more than a bare "204", which doesn't tell much. I plan to change the current integers in our public API to these in the coming days.
For now, I have assumed that we/the user don't care about `end_date` for the `REMOVED` and `UPSTREAM_FAILED` states, since they should be handled by the scheduler and shouldn't even show up on the worker. For the `SKIPPED` state there are two scenarios: (1) a user can run the task and raise an `AirflowSkipException`, or (2) the task is skipped by the scheduler itself. For (1) we could set an end date, but (2) doesn't have one.

- [ ] Pass a [RFC 9457](https://datatracker.ietf.org/doc/html/rfc9457) compliant error message in the "detail" field of `HTTPException` to provide more information about the error.
- [ ] Add a separate heartbeat endpoint to track the TI’s active state.
- [ ] Replace handling of `SQLAlchemyError` with FastAPI's [Custom Exception handling](https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers) across the Execution API endpoints. That way we don't need duplicate code across multiple endpoints.
- [ ] Replace the `None` state on TaskInstance with a `Created` state. ([link](https://github.com/orgs/apache/projects/405/views/1?pane=issue&itemId=85900878))
- [ ] Remove redundant code that also sets the task type once we remove DB access from the worker.

This is assuming that the Webserver or the new FastAPI endpoints don't use this endpoint.
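As a rough illustration of the first to-do item, the sketch below builds an RFC 9457 "problem details" object that could be passed as the `detail` of an error response. The helper name `problem_detail` and the extension members `reason` and `previous_state` are assumptions made for this example (the latter two simply reuse the keys the handler in this PR already returns); this is not the format the project has settled on.

```python
import json


def problem_detail(status: int, title: str, detail: str, **extensions) -> dict:
    """Hypothetical helper: build an RFC 9457 problem-details object.

    "type", "title", "status" and "detail" are members defined by the RFC;
    any extra keyword arguments are added as extension members.
    """
    body = {
        # "about:blank" is the RFC's default type when no specific URI is defined.
        "type": "about:blank",
        "title": title,
        "status": status,
        "detail": detail,
    }
    body.update(extensions)
    return body


# Example: the conflict raised when a TI cannot be marked as running.
conflict = problem_detail(
    409,
    "Conflict",
    "TI was not in a state where it could be marked as running",
    reason="invalid_state",
    previous_state="running",
)
print(json.dumps(conflict, indent=2))
```

RFC 9457 registers the `application/problem+json` media type for such bodies; whether the Execution API would set that content type is left open here.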
Showing 10 changed files with 540 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
from pydantic import AfterValidator, AwareDatetime | ||
from typing_extensions import Annotated | ||
|
||
from airflow.utils import timezone | ||
|
||
UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))] | ||
"""UtcDateTime is a timezone-aware datetime that is normalized to UTC after validation.""" |
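The `UtcDateTime` annotation above combines two steps: `AwareDatetime` rejects naive datetimes, and the `AfterValidator` converts whatever offset was supplied to UTC. The sketch below reproduces only that behaviour with the standard library; `to_utc` is a hypothetical stand-in for the validator, and `datetime.timezone.utc` is used in place of `airflow.utils.timezone.utc` on the assumption that both denote UTC.

```python
from datetime import datetime, timedelta, timezone


def to_utc(d: datetime) -> datetime:
    """Hypothetical stand-in for AwareDatetime plus the AfterValidator lambda."""
    if d.tzinfo is None or d.utcoffset() is None:
        # AwareDatetime would reject a naive value before the validator runs.
        raise ValueError("datetime must include timezone information")
    # Same instant in time, expressed with a zero UTC offset.
    return d.astimezone(timezone.utc)


# A timestamp sent with a +05:30 offset.
plus_0530 = timezone(timedelta(hours=5, minutes=30))
local = datetime(2024, 11, 4, 20, 0, 26, tzinfo=plus_0530)

normalized = to_utc(local)
print(normalized.isoformat())
```

The conversion changes only the representation: `normalized` and `local` compare equal because they refer to the same instant.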
131 changes: 131 additions & 0 deletions
airflow/api_fastapi/execution_api/routes/task_instance.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from uuid import UUID | ||
|
||
from fastapi import Body, Depends, HTTPException, status | ||
from sqlalchemy import update | ||
from sqlalchemy.exc import NoResultFound, SQLAlchemyError | ||
from sqlalchemy.orm import Session | ||
from sqlalchemy.sql import select | ||
from typing_extensions import Annotated | ||
|
||
from airflow.api_fastapi.common.db.common import get_session | ||
from airflow.api_fastapi.common.router import AirflowRouter | ||
from airflow.api_fastapi.execution_api import schemas | ||
from airflow.models.taskinstance import TaskInstance as TI | ||
from airflow.utils.state import State | ||
|
||
# TODO: Add dependency on JWT token | ||
ti_router = AirflowRouter( | ||
    prefix="/task_instance", | ||
    tags=["Task Instance"], | ||
) | ||
|
||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@ti_router.patch( | ||
    "/{task_instance_id}/state", | ||
    status_code=status.HTTP_204_NO_CONTENT, | ||
    # TODO: Add Operation ID to control the function name in the OpenAPI spec | ||
    # TODO: Do we need to use create_openapi_http_exception_doc here? | ||
    responses={ | ||
        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, | ||
        status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, | ||
        status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, | ||
    }, | ||
) | ||
async def ti_update_state( | ||
    task_instance_id: UUID, | ||
    ti_patch_payload: Annotated[schemas.TIStateUpdate, Body()], | ||
    session: Annotated[Session, Depends(get_session)], | ||
): | ||
    """ | ||
    Update the state of a TaskInstance. | ||
    Not all state transitions are valid, and transitioning to some states requires extra information to be | ||
    passed along. (Check out the schemas for details; the rendered docs might not reflect this accurately.) | ||
    """ | ||
    # We only use UUID above for validation purposes | ||
    ti_id_str = str(task_instance_id) | ||
|
||
    old = select(TI.state).where(TI.id == ti_id_str).with_for_update() | ||
    try: | ||
        (previous_state,) = session.execute(old).one() | ||
    except NoResultFound: | ||
        log.error("Task Instance %s not found", ti_id_str) | ||
        raise HTTPException( | ||
            status_code=status.HTTP_404_NOT_FOUND, | ||
            detail={ | ||
                "reason": "not_found", | ||
                "message": "Task Instance not found", | ||
            }, | ||
        ) | ||
|
||
    # We exclude_unset to avoid updating fields that are not set in the payload | ||
    data = ti_patch_payload.model_dump(exclude_unset=True) | ||
|
||
    query = update(TI).where(TI.id == ti_id_str).values(data) | ||
|
||
    if isinstance(ti_patch_payload, schemas.TIEnterRunningPayload): | ||
        if previous_state != State.QUEUED: | ||
            log.warning( | ||
                "Cannot start Task Instance ('%s') in invalid state: %s", | ||
                ti_id_str, | ||
                previous_state, | ||
            ) | ||
|
||
            # TODO: Pass a RFC 9457 compliant error message in "detail" field | ||
            # https://datatracker.ietf.org/doc/html/rfc9457 | ||
            # to provide more information about the error | ||
            # FastAPI will automatically convert this to a JSON response | ||
            # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 | ||
            raise HTTPException( | ||
                status_code=status.HTTP_409_CONFLICT, | ||
                detail={ | ||
                    "reason": "invalid_state", | ||
                    "message": "TI was not in a state where it could be marked as running", | ||
                    "previous_state": previous_state, | ||
                }, | ||
            ) | ||
        log.info("Task with %s state started on %s", previous_state, ti_patch_payload.hostname) | ||
        # Ensure there is no end date set. | ||
        query = query.values( | ||
            end_date=None, | ||
            hostname=ti_patch_payload.hostname, | ||
            unixname=ti_patch_payload.unixname, | ||
            pid=ti_patch_payload.pid, | ||
            state=State.RUNNING, | ||
        ) | ||
    elif isinstance(ti_patch_payload, schemas.TITerminalStatePayload): | ||
        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) | ||
|
||
    # TODO: Replace this with FastAPI's Custom Exception handling: | ||
    # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers | ||
    try: | ||
        result = session.execute(query) | ||
        log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount) | ||
    except SQLAlchemyError as e: | ||
        log.error("Error updating Task Instance state: %s", e) | ||
        raise HTTPException( | ||
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" | ||
        ) |
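The branches of the handler above can be summarised as a small decision function. This is a simplified sketch, not Airflow code: `decide_status` is a hypothetical helper, states are plain strings rather than `State` members, and the 500 path for database errors is left out.

```python
from http import HTTPStatus
from typing import Optional


def decide_status(ti_exists: bool, previous_state: Optional[str], requested_state: str) -> HTTPStatus:
    """Hypothetical helper mirroring the handler's branches; not part of Airflow."""
    if not ti_exists:
        # Mirrors the NoResultFound branch.
        return HTTPStatus.NOT_FOUND
    if requested_state == "running" and previous_state != "queued":
        # Only a queued TI may be marked as running.
        return HTTPStatus.CONFLICT
    # Every other accepted payload leads to the UPDATE and an empty response.
    return HTTPStatus.NO_CONTENT


for args in [
    (False, None, "running"),
    (True, "queued", "running"),
    (True, "running", "running"),
    (True, "running", "success"),
]:
    print(args, "->", int(decide_status(*args)))
```

Note that in the handler only the transition to RUNNING is guarded by the previous state; terminal and other target states are written without a check against the current state.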
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Annotated, Literal, Union | ||
|
||
from pydantic import ( | ||
    BaseModel, | ||
    ConfigDict, | ||
    Discriminator, | ||
    Field, | ||
    Tag, | ||
    WithJsonSchema, | ||
) | ||
|
||
from airflow.api_fastapi.common.types import UtcDateTime | ||
from airflow.utils.state import State, TaskInstanceState as TIState | ||
|
||
|
||
class TIEnterRunningPayload(BaseModel): | ||
    """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.""" | ||
|
||
    model_config = ConfigDict(from_attributes=True) | ||
|
||
    state: Annotated[ | ||
        Literal[TIState.RUNNING], | ||
        # Specify a default in the schema, but not in code, so Pydantic marks it as required. | ||
        WithJsonSchema({"enum": [TIState.RUNNING], "default": TIState.RUNNING}), | ||
    ] | ||
    hostname: str | ||
    """Hostname where this task has started""" | ||
    unixname: str | ||
    """Local username of the process where this task has started""" | ||
    pid: int | ||
    """Process Identifier on `hostname`""" | ||
    start_date: UtcDateTime | ||
    """When the task started executing""" | ||
|
||
|
||
class TITerminalStatePayload(BaseModel): | ||
    """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).""" | ||
|
||
    state: Annotated[ | ||
        Literal[TIState.SUCCESS, TIState.FAILED, TIState.SKIPPED], | ||
        Field(title="TerminalState"), | ||
        WithJsonSchema({"enum": list(State.ran_and_finished_states)}), | ||
    ] | ||
|
||
    end_date: UtcDateTime | ||
    """When the task completed executing""" | ||
|
||
|
||
class TITargetStatePayload(BaseModel): | ||
    """Schema for updating TaskInstance to a target state, excluding terminal and running states.""" | ||
|
||
    state: Annotated[ | ||
        TIState, | ||
        # For the OpenAPI schema generation, | ||
        # make sure we do not include RUNNING as a valid state here | ||
        WithJsonSchema( | ||
            { | ||
                "enum": [ | ||
                    state for state in TIState if state not in (State.ran_and_finished_states | {State.RUNNING}) | ||
                ] | ||
            } | ||
        ), | ||
    ] | ||
|
||
|
||
def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: | ||
    """ | ||
    Determine the discriminator key for TaskInstance state transitions. | ||
    This function serves as a discriminator for the TIStateUpdate union schema, | ||
    categorizing the payload based on the ``state`` attribute in the input data. | ||
    It returns a key that directs FastAPI to the appropriate subclass (schema) | ||
    based on the requested state. | ||
    """ | ||
    if isinstance(v, dict): | ||
        state = v.get("state") | ||
    else: | ||
        state = getattr(v, "state", None) | ||
    if state == TIState.RUNNING: | ||
        return str(state) | ||
    elif state in State.ran_and_finished_states: | ||
        return "_terminal_" | ||
    return "_other_" | ||
|
||
|
||
# It is called "_terminal_" to avoid future conflicts if we add an actual state named "terminal", | ||
# and "_other_" is a catch-all for all other states that are not covered by the other schemas. | ||
TIStateUpdate = Annotated[ | ||
    Union[ | ||
        Annotated[TIEnterRunningPayload, Tag("running")], | ||
        Annotated[TITerminalStatePayload, Tag("_terminal_")], | ||
        Annotated[TITargetStatePayload, Tag("_other_")], | ||
    ], | ||
    Discriminator(ti_state_discriminator), | ||
] |
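To see how the callable discriminator above routes a payload to one schema, here is a cut-down, standalone sketch. It assumes Pydantic v2 with support for callable discriminators; the model names `EnterRunning`, `Terminal` and `Other` are hypothetical simplifications, and plain strings stand in for `TIState` so the example runs without Airflow.

```python
from datetime import datetime
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter, ValidationError


class EnterRunning(BaseModel):
    state: Literal["running"]
    hostname: str
    pid: int
    start_date: datetime


class Terminal(BaseModel):
    state: Literal["success", "failed", "skipped"]
    end_date: datetime


class Other(BaseModel):
    state: str


def pick_schema(v: Any) -> str:
    """Return the tag of the schema that should validate this payload."""
    state = v.get("state") if isinstance(v, dict) else getattr(v, "state", None)
    if state == "running":
        return "running"
    if state in {"success", "failed", "skipped"}:
        return "_terminal_"
    return "_other_"


StateUpdate = Annotated[
    Union[
        Annotated[EnterRunning, Tag("running")],
        Annotated[Terminal, Tag("_terminal_")],
        Annotated[Other, Tag("_other_")],
    ],
    Discriminator(pick_schema),
]

adapter = TypeAdapter(StateUpdate)

# A terminal payload is validated by the Terminal schema only.
done = adapter.validate_python({"state": "success", "end_date": "2024-11-04T20:00:26Z"})
print(type(done).__name__)

# A "running" payload without hostname, pid and start_date fails validation.
try:
    adapter.validate_python({"state": "running"})
except ValidationError as err:
    print(err.error_count(), "validation errors")
```

Because the discriminator picks exactly one schema, error messages mention only that schema's fields instead of listing failures for every member of the union. With Pydantic's default configuration, unknown extra fields are ignored rather than rejected, so this sketch demonstrates required-field validation only.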
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from __future__ import annotations | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
|
||
from airflow.api_fastapi.app import cached_app | ||
|
||
|
||
@pytest.fixture | ||
def client(): | ||
    return TestClient(cached_app(apps="execution")) |