Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for displaying pandas DataFrame as an interactive table #1373

Merged
merged 9 commits into from
Oct 16, 2024
Merged
1 change: 1 addition & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from chainlit.element import (
Audio,
Component,
Dataframe,
File,
Image,
Pdf,
Expand Down
1 change: 0 additions & 1 deletion backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import aiofiles
import aiohttp

from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
Expand Down
30 changes: 29 additions & 1 deletion backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@
}

ElementType = Literal[
"image", "text", "pdf", "tasklist", "audio", "video", "file", "plotly", "component"
"image",
"text",
"pdf",
"tasklist",
"audio",
"video",
"file",
"plotly",
"dataframe",
"component",
]
ElementDisplay = Literal["inline", "side", "page"]
ElementSize = Literal["small", "medium", "large"]
Expand Down Expand Up @@ -358,6 +367,25 @@ def __post_init__(self) -> None:
super().__post_init__()


@dataclass
class Dataframe(Element):
    """Element that ships a pandas DataFrame to the UI as JSON content."""

    type: ClassVar[ElementType] = "dataframe"
    size: ElementSize = "large"
    # Annotated as Any so pandas is not needed at import time; the real type
    # is enforced at runtime in __post_init__.
    data: Any = None

    def __post_init__(self) -> None:
        """Validate ``data`` and store its JSON serialization in ``content``.

        Raises:
            TypeError: If ``data`` is not a ``pandas.DataFrame``.
        """
        # Imported lazily so the module can be loaded without pandas.
        import pandas as pd

        frame = self.data
        if isinstance(frame, pd.DataFrame):
            # Serialize before delegating to the parent initializer.
            self.content = frame.to_json(orient="split", date_format="iso")
            super().__post_init__()
        else:
            raise TypeError("data must be a pandas.DataFrame")


@dataclass
class Component(Element):
"""Useful to send a custom component to the UI."""
Expand Down
28 changes: 27 additions & 1 deletion backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ slack_bolt = "^1.18.1"
discord = "^2.3.2"
botbuilder-core = "^4.15.0"
aiosqlite = "^0.20.0"
pandas = "^2.2.2"
moto = "^5.0.14"

[tool.poetry.group.dev.dependencies]
Expand All @@ -94,6 +95,7 @@ mypy = "^1.7.1"
types-requests = "^2.31.0.2"
types-aiofiles = "^23.1.0.5"
mypy-boto3-dynamodb = "^1.34.113"
pandas-stubs = { version = "^2.2.2", python = ">=3.9" }

[tool.mypy]
python_version = "3.9"
Expand All @@ -120,7 +122,6 @@ ignore_missing_imports = true




[tool.poetry.group.custom-data]
optional = true

Expand Down
3 changes: 1 addition & 2 deletions backend/tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from unittest.mock import AsyncMock

import pytest
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.user import User

Expand Down
36 changes: 23 additions & 13 deletions backend/tests/data/test_sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from pathlib import Path

import pytest
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.element import Text
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

from chainlit import User
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.element import Text


@pytest.fixture
Expand All @@ -23,18 +23,21 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
# Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer
async with engine.begin() as conn:
await conn.execute(
text("""
text(
"""
CREATE TABLE users (
"id" UUID PRIMARY KEY,
"identifier" TEXT NOT NULL UNIQUE,
"metadata" JSONB NOT NULL,
"createdAt" TEXT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS threads (
"id" UUID PRIMARY KEY,
"createdAt" TEXT,
Expand All @@ -45,11 +48,13 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"metadata" JSONB,
FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS steps (
"id" UUID PRIMARY KEY,
"name" TEXT NOT NULL,
Expand All @@ -72,11 +77,13 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"language" TEXT,
"indent" INT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS elements (
"id" UUID PRIMARY KEY,
"threadId" UUID,
Expand All @@ -92,19 +99,22 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
"forId" UUID,
"mime" TEXT
);
""")
"""
)
)

await conn.execute(
text("""
text(
"""
CREATE TABLE IF NOT EXISTS feedbacks (
"id" UUID PRIMARY KEY,
"forId" UUID NOT NULL,
"threadId" UUID NOT NULL,
"value" INT NOT NULL,
"comment" TEXT
);
""")
"""
)
)

# Create SQLAlchemyDataLayer instance
Expand Down
68 changes: 68 additions & 0 deletions cypress/e2e/dataframe/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pandas as pd

import chainlit as cl


@cl.on_chat_start
async def start():
    """Send a message with an inline sample DataFrame when the chat starts."""
    # Fifteen records (more than 10) so pagination can be exercised in the UI.
    people = [
        ("Alice", 25, "New York", 70000),
        ("David", 40, "Houston", 100000),
        ("Charlie", 35, "Chicago", 90000),
        ("Bob", 30, "Los Angeles", 80000),
        ("Eva", 45, "Phoenix", 110000),
        ("Grace", 55, "San Antonio", 130000),
        ("Hannah", 60, "San Diego", 140000),
        ("Jack", 70, "San Jose", 160000),
        ("Frank", 50, "Philadelphia", 120000),
        ("Kara", 75, "Austin", 170000),
        ("Liam", 80, "Fort Worth", 180000),
        ("Ivy", 65, "Dallas", 150000),
        ("Mia", 85, "Jacksonville", 190000),
        ("Noah", 90, "Columbus", 200000),
        ("Olivia", 95, "Charlotte", 210000),
    ]
    # Transpose the row-wise records into one list per column.
    names, ages, cities, salaries = (list(column) for column in zip(*people))

    df = pd.DataFrame(
        {
            "Name": names,
            "Age": ages,
            "City": cities,
            "Salary": salaries,
        }
    )

    table = cl.Dataframe(data=df, display="inline", name="Dataframe")
    message = cl.Message(content="This message has a Dataframe", elements=[table])
    await message.send()
41 changes: 41 additions & 0 deletions cypress/e2e/dataframe/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { runTestServer } from '../../support/testUtils';

// End-to-end check of the inline Dataframe element: rendering, column sorting,
// page navigation, and changing the page size.
describe('dataframe', () => {
  before(() => {
    // Start the test server once before the spec runs.
    runTestServer();
  });

  it('should be able to display an inline dataframe', () => {
    // Check that exactly one step exists and that it contains one data grid
    cy.get('.step').should('have.length', 1);
    cy.get('.step').first().find('.MuiDataGrid-main').should('have.length', 1);

    // Click the sort button in the "Age" column header to sort in ascending order.
    // `force: true` skips Cypress's actionability checks (presumably because the
    // sort button is not visible until the header is hovered; confirm).
    cy.get('.MuiDataGrid-columnHeader[aria-label="Age"]')
      .find('button')
      .first()
      .click({ force: true });
    // Verify the first row's "Age" cell contains '25' after sorting
    cy.get('.MuiDataGrid-row')
      .first()
      .find('.MuiDataGrid-cell[data-field="Age"] .MuiDataGrid-cellContent')
      .should('have.text', '25');

    // Click the "Next page" button (second button in the pagination controls)
    cy.get('.MuiTablePagination-actions').find('button').eq(1).click();
    // Verify that the next page contains exactly 5 rows
    // (assumes a 15-row fixture and an initial page size of 10; confirm).
    cy.get('.MuiDataGrid-row').should('have.length', 5);

    // Click the page-size select to open its dropdown
    cy.get('.MuiTablePagination-select').click();
    // Select the option with the value '50' from the dropdown list
    cy.get('ul.MuiMenu-list li').contains('50').click();
    // Scroll to the bottom of the virtual scroller in the MUI DataGrid
    cy.get('.MuiDataGrid-virtualScroller').scrollTo('bottom');
    // Check that the last row's "Name" cell is Olivia
    cy.get('.MuiDataGrid-row')
      .last()
      .find('.MuiDataGrid-cell[data-field="Name"] .MuiDataGrid-cellContent')
      .should('have.text', 'Olivia');
  });
});
1 change: 1 addition & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"@mui/icons-material": "^5.14.9",
"@mui/lab": "^5.0.0-alpha.122",
"@mui/material": "^5.14.10",
"@mui/x-data-grid": "^6.20.4",
"formik": "^2.4.3",
"highlight.js": "^11.9.0",
"i18next": "^23.7.16",
Expand Down
Loading
Loading