Merge pull request #325 from TogetherCrew/feat/301-website-ingestion
Adding website ingestion!
amindadgar authored Nov 19, 2024
2 parents 8b859b8 + 6542ae5 commit 023ec51
Showing 9 changed files with 419 additions and 4 deletions.
116 changes: 116 additions & 0 deletions dags/hivemind_etl_helpers/src/db/website/crawlee_client.py
@@ -0,0 +1,116 @@
import asyncio
from typing import Any

from crawlee.playwright_crawler import PlaywrightCrawler, PlaywrightCrawlingContext
from defusedxml import ElementTree as ET


class CrawleeClient:
def __init__(
self,
max_requests: int = 20,
headless: bool = True,
browser_type: str = "chromium",
) -> None:
self.crawler = PlaywrightCrawler(
max_requests_per_crawl=max_requests,
headless=headless,
browser_type=browser_type,
)

# do not persist crawled data to local storage
self.crawler._configuration.persist_storage = False
self.crawler._configuration.write_metadata = False

@self.crawler.router.default_handler
async def request_handler(context: PlaywrightCrawlingContext) -> None:
context.log.info(f"Processing {context.request.url} ...")

inner_text = await context.page.inner_text(selector="body")

if "sitemap.xml" in context.request.url:
links = self._extract_links_from_sitemap(inner_text)
await context.add_requests(requests=list(set(links)))
else:
await context.enqueue_links()

data = {
"url": context.request.url,
"title": await context.page.title(),
"inner_text": inner_text,
}

await context.push_data(data)

def _extract_links_from_sitemap(self, sitemap_content: str) -> list[str]:
"""
        Extract URLs from sitemap XML content.
Parameters
----------
sitemap_content : str
The XML content of the sitemap
Raises
------
ET.ParseError
If the XML content is malformed
Returns
-------
links : list[str]
list of valid URLs extracted from the sitemap
"""
links = []
try:
root = ET.fromstring(sitemap_content)
namespace = {"ns": "http://www.sitemaps.org/schemas/sitemap/0.9"}
for element in root.findall("ns:url/ns:loc", namespace):
url = element.text.strip() if element.text else None
if url and url.startswith(("http://", "https://")):
links.append(url)
except ET.ParseError as e:
raise ValueError(f"Invalid sitemap XML: {str(e)}")

return links

async def crawl(self, links: list[str]) -> list[dict[str, Any]]:
"""
        Crawl the given websites and extract data from all inner links under their domain routes.
Parameters
----------
links : list[str]
List of valid URLs to crawl
Returns
-------
crawled_data : list[dict[str, Any]]
List of dictionaries containing crawled data with keys:
- url: str
- title: str
- inner_text: str
Raises
------
ValueError
If any of the input URLs is invalid (not starting with http or https)
TimeoutError
If the crawl operation times out
"""
# Validate input URLs
valid_links = []
for url in links:
if url and isinstance(url, str) and url.startswith(("http://", "https://")):
valid_links.append(url)
else:
raise ValueError(f"Invalid URL: {url}")

try:
await self.crawler.add_requests(requests=valid_links)
await asyncio.wait_for(self.crawler.run(), timeout=3600) # 1 hour timeout
crawled_data = await self.crawler.get_data()
return crawled_data.items
except asyncio.TimeoutError:
raise TimeoutError("Crawl operation timed out")
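A minimal usage sketch for the client above, assuming it is driven from an async entry point and that Playwright browsers are installed; the URL is a placeholder for illustration.

```python
import asyncio

from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient


async def main() -> None:
    # placeholder URL; any http(s) link or a sitemap.xml URL would work
    client = CrawleeClient(max_requests=20, headless=True, browser_type="chromium")
    pages = await client.crawl(["https://example.com"])

    for page in pages:
        # each item carries the keys pushed by the request handler above
        print(page["url"], page["title"], len(page["inner_text"]))


if __name__ == "__main__":
    asyncio.run(main())
```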
1 change: 1 addition & 0 deletions dags/hivemind_etl_helpers/src/utils/modules/__init__.py
@@ -5,3 +5,4 @@
 from .github import ModulesGitHub
 from .mediawiki import ModulesMediaWiki
 from .notion import ModulesNotion
+from .website import ModulesWebsite
6 changes: 3 additions & 3 deletions dags/hivemind_etl_helpers/src/utils/modules/modules_base.py
@@ -98,7 +98,7 @@ def get_token(self, platform_id: ObjectId, token_type: str) -> str:

    def get_platform_metadata(
        self, platform_id: ObjectId, metadata_name: str
-    ) -> str | dict:
+    ) -> str | dict | list:
        """
        get the userid that belongs to a platform
@@ -111,8 +111,8 @@
        Returns
        ---------
-        user_id : str
-            the user id that the platform belongs to
+        metadata_value : Any
+            the values that the metadata belongs to
        """
        client = MongoSingleton.get_instance().get_client()
63 changes: 63 additions & 0 deletions dags/hivemind_etl_helpers/src/utils/modules/website.py
@@ -0,0 +1,63 @@
import logging

from .modules_base import ModulesBase


class ModulesWebsite(ModulesBase):
def __init__(self) -> None:
self.platform_name = "website"
super().__init__()

def get_learning_platforms(
self,
) -> list[dict[str, str | list[str]]]:
"""
        Get all communities that have website platforms configured, along with their resource URLs.
Returns
---------
        communities_data : list[dict[str, str | list[str]]]
            a list of website platform information
example data output:
```
[{
"community_id": "6579c364f1120850414e0dc5",
"platform_id": "6579c364f1120850414e0dc6",
"urls": ["link1", "link2"],
}]
```
"""
modules = self.query(platform=self.platform_name, projection={"name": 0})
communities_data: list[dict[str, str | list[str]]] = []

for module in modules:
community = module["community"]

# each platform of the community
for platform in module["options"]["platforms"]:
if platform["name"] != self.platform_name:
continue

platform_id = platform["platform"]

try:
website_links = self.get_platform_metadata(
platform_id=platform_id,
metadata_name="resources",
)

communities_data.append(
{
"community_id": str(community),
"platform_id": platform_id,
"urls": website_links,
}
)
except Exception as exp:
logging.error(
"Exception while fetching website modules "
f"for platform: {platform_id} | exception: {exp}"
)

return communities_data
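A short sketch of how this module query might be consumed, assuming a reachable MongoDB instance configured for MongoSingleton; the printed keys follow the docstring example above.

```python
from hivemind_etl_helpers.src.utils.modules import ModulesWebsite

# assumes MongoSingleton can connect to the configured MongoDB instance
website_modules = ModulesWebsite()
for platform in website_modules.get_learning_platforms():
    print(platform["community_id"], platform["platform_id"], platform["urls"])
```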
81 changes: 81 additions & 0 deletions dags/hivemind_etl_helpers/tests/unit/test_website_etl.py
@@ -0,0 +1,81 @@
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock

from dotenv import load_dotenv
from hivemind_etl_helpers.website_etl import WebsiteETL
from llama_index.core import Document


class TestWebsiteETL(IsolatedAsyncioTestCase):
def setUp(self):
"""
Setup for the test cases. Initializes a WebsiteETL instance with mocked dependencies.
"""
load_dotenv()
self.community_id = "test_community"
self.website_etl = WebsiteETL(self.community_id)
self.website_etl.crawlee_client = AsyncMock()
self.website_etl.ingestion_pipeline = MagicMock()

async def test_extract(self):
"""
Test the extract method.
"""
urls = ["https://example.com"]
mocked_data = [
{
"url": "https://example.com",
"inner_text": "Example text",
"title": "Example",
}
]
self.website_etl.crawlee_client.crawl.return_value = mocked_data

extracted_data = await self.website_etl.extract(urls)

self.assertEqual(extracted_data, mocked_data)
self.website_etl.crawlee_client.crawl.assert_awaited_once_with(urls)

def test_transform(self):
"""
Test the transform method.
"""
raw_data = [
{
"url": "https://example.com",
"inner_text": "Example text",
"title": "Example",
}
]
expected_documents = [
Document(
doc_id="https://example.com",
text="Example text",
metadata={"title": "Example", "url": "https://example.com"},
)
]

documents = self.website_etl.transform(raw_data)

self.assertEqual(len(documents), len(expected_documents))
self.assertEqual(documents[0].doc_id, expected_documents[0].doc_id)
self.assertEqual(documents[0].text, expected_documents[0].text)
self.assertEqual(documents[0].metadata, expected_documents[0].metadata)

def test_load(self):
"""
Test the load method.
"""
documents = [
Document(
doc_id="https://example.com",
text="Example text",
metadata={"title": "Example", "url": "https://example.com"},
)
]

self.website_etl.load(documents)

self.website_etl.ingestion_pipeline.run_pipeline.assert_called_once_with(
docs=documents
)
94 changes: 94 additions & 0 deletions dags/hivemind_etl_helpers/website_etl.py
@@ -0,0 +1,94 @@
from typing import Any

from hivemind_etl_helpers.ingestion_pipeline import CustomIngestionPipeline
from hivemind_etl_helpers.src.db.website.crawlee_client import CrawleeClient
from llama_index.core import Document


class WebsiteETL:
def __init__(
self,
community_id: str,
) -> None:
"""
Parameters
-----------
community_id : str
            the community to save the data for
"""
self.community_id = community_id
collection_name = "website"

# preparing the data extractor and ingestion pipelines
self.crawlee_client = CrawleeClient()
self.ingestion_pipeline = CustomIngestionPipeline(
self.community_id, collection_name=collection_name
)

async def extract(
self,
urls: list[str],
) -> list[dict[str, Any]]:
"""
        Extract data from the given URLs
Parameters
-----------
urls : list[str]
a list of urls
Returns
---------
extracted_data : list[dict[str, Any]]
The crawled data from urls
"""
if not urls:
raise ValueError("No URLs provided for crawling")
extracted_data = await self.crawlee_client.crawl(urls)

if not extracted_data:
raise ValueError(f"No data extracted from URLs: {urls}")

return extracted_data

def transform(self, raw_data: list[dict[str, Any]]) -> list[Document]:
"""
transform raw data to llama-index documents
Parameters
------------
raw_data : list[dict[str, Any]]
crawled data
Returns
---------
documents : list[llama_index.Document]
list of llama-index documents
"""
documents: list[Document] = []

for data in raw_data:
doc_id = data["url"]
doc = Document(
doc_id=doc_id,
text=data["inner_text"],
metadata={
"title": data["title"],
"url": data["url"],
},
)
documents.append(doc)

return documents

def load(self, documents: list[Document]) -> None:
"""
load the documents into the vector db
Parameters
-------------
documents: list[llama_index.Document]
the llama-index documents to be ingested
"""
# loading data into db
self.ingestion_pipeline.run_pipeline(docs=documents)
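A hedged end-to-end sketch tying the steps above together, assuming the ingestion pipeline's vector-store and database credentials are available in the environment; the community id and URL are placeholders for illustration.

```python
import asyncio

from hivemind_etl_helpers.website_etl import WebsiteETL


async def ingest_website(community_id: str, urls: list[str]) -> None:
    etl = WebsiteETL(community_id)
    raw_data = await etl.extract(urls)    # crawl the given links
    documents = etl.transform(raw_data)   # convert to llama-index Documents
    etl.load(documents)                   # ingest into the vector store


if __name__ == "__main__":
    # placeholder values for illustration
    asyncio.run(ingest_website("6579c364f1120850414e0dc5", ["https://example.com"]))
```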