Skip to content

Commit

Permalink
Fix: Corrected aload function to be asynchronous
Browse files (browse the repository at this point in the history)
  • Loading branch information
yeounhak committed Nov 25, 2024
1 parent 6ed2d38 commit b62b5d1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -271,11 +271,11 @@ def _check_parser(parser: str) -> None:
"`parser` must be one of " + ", ".join(valid_parsers) + "."
)

def scrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
async def scrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
"""Fetch all urls, then return soups for all results."""
from bs4 import BeautifulSoup

results = asyncio.run(self.fetch_all(urls))
results = await self.fetch_all(urls)
final_results = []
for i, result in enumerate(results):
url = urls[i]
Expand Down Expand Up @@ -331,10 +331,10 @@ def lazy_load(self) -> Iterator[Document]:
metadata = _build_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)

def aload(self) -> List[Document]: # type: ignore
async def aload(self) -> List[Document]: # type: ignore
"""Load text from the urls in web_path async into Documents."""

results = self.scrape_all(self.web_paths)
results = await self.scrape_all(self.web_paths)
docs = []
for path, soup in zip(self.web_paths, results):
text = soup.get_text(**self.bs_get_text_kwargs)
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_lazy_load(mock_get: Any) -> None:

@pytest.mark.requires("bs4")
@patch("aiohttp.ClientSession.get")
def test_aload(mock_get: Any) -> None:
async def test_aload(mock_get: Any) -> None:
async def mock_text() -> str:
return "<html><body><p>Test content</p></body></html>"

Expand All @@ -76,7 +76,7 @@ async def mock_text() -> str:
web_paths=["https://www.example.com"],
header_template={"User-Agent": "test-user-agent"},
)
results = loader.aload()
results = await loader.aload()
assert len(results) == 1
assert results[0].page_content == "Test content"
mock_get.assert_called_with(
Expand Down

0 comments on commit b62b5d1

Please sign in to comment.