Merge pull request #325 from TogetherCrew/feat/301-website-ingestion
Adding website ingestion!
Showing 9 changed files with 419 additions and 4 deletions.
Empty file.
116 changes: 116 additions & 0 deletions
dags/hivemind_etl_helpers/src/db/website/crawlee_client.py
@@ -0,0 +1,116 @@
import asyncio
from typing import Any

from crawlee.playwright_crawler import PlaywrightCrawler, PlaywrightCrawlingContext
from defusedxml import ElementTree as ET


class CrawleeClient:
    def __init__(
        self,
        max_requests: int = 20,
        headless: bool = True,
        browser_type: str = "chromium",
    ) -> None:
        self.crawler = PlaywrightCrawler(
            max_requests_per_crawl=max_requests,
            headless=headless,
            browser_type=browser_type,
        )

        # do not persist crawled data to local storage
        self.crawler._configuration.persist_storage = False
        self.crawler._configuration.write_metadata = False

        @self.crawler.router.default_handler
        async def request_handler(context: PlaywrightCrawlingContext) -> None:
            context.log.info(f"Processing {context.request.url} ...")

            inner_text = await context.page.inner_text(selector="body")

            if "sitemap.xml" in context.request.url:
                links = self._extract_links_from_sitemap(inner_text)
                await context.add_requests(requests=list(set(links)))
            else:
                await context.enqueue_links()

            data = {
                "url": context.request.url,
                "title": await context.page.title(),
                "inner_text": inner_text,
            }

            await context.push_data(data)

    def _extract_links_from_sitemap(self, sitemap_content: str) -> list[str]:
        """
        Extract URLs from sitemap XML content.

        Parameters
        ----------
        sitemap_content : str
            The XML content of the sitemap

        Raises
        ------
        ValueError
            If the XML content is malformed

        Returns
        -------
        links : list[str]
            list of valid URLs extracted from the sitemap
        """
        links = []
        try:
            root = ET.fromstring(sitemap_content)
            namespace = {"ns": "http://www.sitemaps.org/schemas/sitemap/0.9"}
            for element in root.findall("ns:url/ns:loc", namespace):
                url = element.text.strip() if element.text else None
                if url and url.startswith(("http://", "https://")):
                    links.append(url)
        except ET.ParseError as e:
            raise ValueError(f"Invalid sitemap XML: {e}") from e

        return links

    async def crawl(self, links: list[str]) -> list[dict[str, Any]]:
        """
        Crawl websites and extract data from all inner links under the domain routes.

        Parameters
        ----------
        links : list[str]
            List of valid URLs to crawl

        Returns
        -------
        crawled_data : list[dict[str, Any]]
            List of dictionaries containing crawled data with keys:
            - url: str
            - title: str
            - inner_text: str

        Raises
        ------
        ValueError
            If any of the input URLs is invalid (not starting with http or https)
        TimeoutError
            If the crawl operation times out
        """
        # validate input URLs before handing them to the crawler
        valid_links = []
        for url in links:
            if url and isinstance(url, str) and url.startswith(("http://", "https://")):
                valid_links.append(url)
            else:
                raise ValueError(f"Invalid URL: {url}")

        try:
            await self.crawler.add_requests(requests=valid_links)
            await asyncio.wait_for(self.crawler.run(), timeout=3600)  # 1 hour timeout
            crawled_data = await self.crawler.get_data()
            return crawled_data.items
        except asyncio.TimeoutError:
            raise TimeoutError("Crawl operation timed out")
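For orientation, a minimal standalone usage sketch of the client above; the sitemap URL is illustrative, and the request cap and headless settings simply exercise the constructor defaults shown in the class:

import asyncio

from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient


async def main() -> None:
    # illustrative entry point: pointing the crawler at a sitemap makes the
    # default handler seed the queue with every <loc> URL it can parse
    client = CrawleeClient(max_requests=5, headless=True)
    pages = await client.crawl(["https://example.com/sitemap.xml"])
    for page in pages:
        print(page["url"], page["title"], len(page["inner_text"]))


if __name__ == "__main__":
    asyncio.run(main())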
@@ -0,0 +1,63 @@
import logging

from .modules_base import ModulesBase


class ModulesWebsite(ModulesBase):
    def __init__(self) -> None:
        self.platform_name = "website"
        super().__init__()

    def get_learning_platforms(
        self,
    ) -> list[dict[str, str | list[str]]]:
        """
        Get all communities that have a website platform configured,
        together with the URLs to ingest.

        Returns
        -------
        communities_data : list[dict[str, str | list[str]]]
            a list of website platform information
            example data output:
            ```
            [{
                "community_id": "6579c364f1120850414e0dc5",
                "platform_id": "6579c364f1120850414e0dc6",
                "urls": ["link1", "link2"],
            }]
            ```
        """
        modules = self.query(platform=self.platform_name, projection={"name": 0})
        communities_data: list[dict[str, str | list[str]]] = []

        for module in modules:
            community = module["community"]

            # each platform of the community
            for platform in module["options"]["platforms"]:
                if platform["name"] != self.platform_name:
                    continue

                platform_id = platform["platform"]

                try:
                    website_links = self.get_platform_metadata(
                        platform_id=platform_id,
                        metadata_name="resources",
                    )

                    communities_data.append(
                        {
                            "community_id": str(community),
                            "platform_id": platform_id,
                            "urls": website_links,
                        }
                    )
                except Exception as exp:
                    logging.error(
                        "Exception while fetching website modules "
                        f"for platform: {platform_id} | exception: {exp}"
                    )

        return communities_data
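The query above assumes module documents roughly shaped as in the sketch below; only the fields the code actually reads (community, options.platforms, name, platform) are grounded in the diff, and the identifier values are illustrative:

# illustrative shape of a module document consumed by get_learning_platforms
module_doc = {
    "community": "6579c364f1120850414e0dc5",
    "options": {
        "platforms": [
            {
                "name": "website",  # compared against self.platform_name
                "platform": "6579c364f1120850414e0dc6",  # used as platform_id
            }
        ]
    },
}
# get_platform_metadata(platform_id=..., metadata_name="resources") is then
# expected to return the configured URL list, e.g. ["https://example.com"]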
@@ -0,0 +1,81 @@
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock

from dotenv import load_dotenv
from hivemind_etl_helpers.website_etl import WebsiteETL
from llama_index.core import Document


class TestWebsiteETL(IsolatedAsyncioTestCase):
    def setUp(self):
        """
        Setup for the test cases. Initializes a WebsiteETL instance with mocked dependencies.
        """
        load_dotenv()
        self.community_id = "test_community"
        self.website_etl = WebsiteETL(self.community_id)
        self.website_etl.crawlee_client = AsyncMock()
        self.website_etl.ingestion_pipeline = MagicMock()

    async def test_extract(self):
        """
        Test the extract method.
        """
        urls = ["https://example.com"]
        mocked_data = [
            {
                "url": "https://example.com",
                "inner_text": "Example text",
                "title": "Example",
            }
        ]
        self.website_etl.crawlee_client.crawl.return_value = mocked_data

        extracted_data = await self.website_etl.extract(urls)

        self.assertEqual(extracted_data, mocked_data)
        self.website_etl.crawlee_client.crawl.assert_awaited_once_with(urls)

    def test_transform(self):
        """
        Test the transform method.
        """
        raw_data = [
            {
                "url": "https://example.com",
                "inner_text": "Example text",
                "title": "Example",
            }
        ]
        expected_documents = [
            Document(
                doc_id="https://example.com",
                text="Example text",
                metadata={"title": "Example", "url": "https://example.com"},
            )
        ]

        documents = self.website_etl.transform(raw_data)

        self.assertEqual(len(documents), len(expected_documents))
        self.assertEqual(documents[0].doc_id, expected_documents[0].doc_id)
        self.assertEqual(documents[0].text, expected_documents[0].text)
        self.assertEqual(documents[0].metadata, expected_documents[0].metadata)

    def test_load(self):
        """
        Test the load method.
        """
        documents = [
            Document(
                doc_id="https://example.com",
                text="Example text",
                metadata={"title": "Example", "url": "https://example.com"},
            )
        ]

        self.website_etl.load(documents)

        self.website_etl.ingestion_pipeline.run_pipeline.assert_called_once_with(
            docs=documents
        )
@@ -0,0 +1,94 @@
from typing import Any

from hivemind_etl_helpers.ingestion_pipeline import CustomIngestionPipeline
from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient
from llama_index.core import Document


class WebsiteETL:
    def __init__(
        self,
        community_id: str,
    ) -> None:
        """
        Parameters
        ----------
        community_id : str
            the community to save the data for
        """
        self.community_id = community_id
        collection_name = "website"

        # preparing the data extractor and ingestion pipelines
        self.crawlee_client = CrawleeClient()
        self.ingestion_pipeline = CustomIngestionPipeline(
            self.community_id, collection_name=collection_name
        )

    async def extract(
        self,
        urls: list[str],
    ) -> list[dict[str, Any]]:
        """
        Extract data from the given URLs.

        Parameters
        ----------
        urls : list[str]
            a list of URLs to crawl

        Returns
        -------
        extracted_data : list[dict[str, Any]]
            the crawled data from the URLs
        """
        if not urls:
            raise ValueError("No URLs provided for crawling")
        extracted_data = await self.crawlee_client.crawl(urls)

        if not extracted_data:
            raise ValueError(f"No data extracted from URLs: {urls}")

        return extracted_data

    def transform(self, raw_data: list[dict[str, Any]]) -> list[Document]:
        """
        Transform raw crawled data into llama-index documents.

        Parameters
        ----------
        raw_data : list[dict[str, Any]]
            crawled data

        Returns
        -------
        documents : list[llama_index.Document]
            list of llama-index documents
        """
        documents: list[Document] = []

        for data in raw_data:
            doc_id = data["url"]
            doc = Document(
                doc_id=doc_id,
                text=data["inner_text"],
                metadata={
                    "title": data["title"],
                    "url": data["url"],
                },
            )
            documents.append(doc)

        return documents

    def load(self, documents: list[Document]) -> None:
        """
        Load the documents into the vector db.

        Parameters
        ----------
        documents : list[llama_index.Document]
            the llama-index documents to be ingested
        """
        # loading data into db
        self.ingestion_pipeline.run_pipeline(docs=documents)
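Taken together, the classes in this PR compose into a crawl, transform, and ingest flow. A hedged sketch of a driver function (not part of this diff) that a DAG task could call, where community_id and urls would typically come from ModulesWebsite().get_learning_platforms():

import asyncio

from hivemind_etl_helpers.website_etl import WebsiteETL


def process_website_community(community_id: str, urls: list[str]) -> None:
    # illustrative driver: crawl the configured URLs, convert the pages to
    # llama-index Documents, and push them through the ingestion pipeline
    etl = WebsiteETL(community_id=community_id)
    raw_data = asyncio.run(etl.extract(urls=urls))
    documents = etl.transform(raw_data=raw_data)
    etl.load(documents=documents)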