Skip to content

Commit

Permalink
Add the changes to handle upsert to remove document and context manag… (
Browse files Browse the repository at this point in the history
#177)

- Use context manager for batch ingestion in
[weaviate](https://weaviate.io/developers/weaviate/manage-data/import).
- Fix the logic to remove successfully upserted documents.

closes: [#174](#174)
closes: [#173](#173)
  • Loading branch information
sunank200 authored Nov 28, 2023
1 parent 515f338 commit 87cce4f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 71 deletions.
4 changes: 2 additions & 2 deletions airflow/dags/ingestion/ask-astro-load.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def extract_astro_blogs():
task(ask_astro_weaviate_hook.ingest_data, retries=10)
.partial(
class_name=WEAVIATE_CLASS,
existing="upsert",
existing="skip",
doc_key="docLink",
batch_params={"batch_size": 1000},
verbose=True,
Expand All @@ -276,7 +276,7 @@ def extract_astro_blogs():
_import_baseline = task(ask_astro_weaviate_hook.import_baseline, trigger_rule="none_failed")(
seed_baseline_url=seed_baseline_url,
class_name=WEAVIATE_CLASS,
existing="upsert",
existing="error",
doc_key="docLink",
uuid_column="id",
vector_column="vector",
Expand Down
190 changes: 121 additions & 69 deletions airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AskAstroWeaviateHook(WeaviateHook):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.batch_errors = []
self.logger = logging.getLogger("airflow.task")
self.client = self.get_client()

Expand Down Expand Up @@ -211,93 +212,101 @@ def batch_ingest(
vector_column: str | None = None,
batch_params: dict = {},
verbose: bool = False,
tenant: str | None = None,
) -> (list, Any):
"""
Processes the DataFrame and batches the data for ingestion into Weaviate.
:param df: DataFrame containing the data to be ingested.
:param class_name: The name of the class in Weaviate to which data will be ingested.
:param uuid_column: Name of the column containing the UUID.
:param existing: Strategy to handle existing data ('skip', 'replace', 'upsert' or 'error').
:param vector_column: Name of the column containing the vector data.
:param batch_params: Parameters for batch configuration.
:param existing: Strategy to handle existing data ('skip', 'replace', 'upsert').
:param verbose: Whether to print verbose output.
:param verbose: Whether to log verbose output.
:param tenant: The tenant to which the object will be added.
"""
batch = self.client.batch.configure(**batch_params)
batch_errors = []

for row_id, row in df.iterrows():
data_object = row.to_dict()
uuid = data_object.pop(uuid_column)
vector = data_object.pop(vector_column, None)

try:
if self.client.data_object.exists(uuid=uuid, class_name=class_name) is True:
if existing == "skip":
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Skipping.")
continue
elif existing == "replace":
# Default for weaviate is replace existing
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Overwriting.")

except Exception as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
continue

try:
added_row = batch.add_data_object(
class_name=class_name, uuid=uuid, data_object=data_object, vector=vector
)
if verbose is True:
self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")

except Exception as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
# configuration for context manager for __exit__ method to callback on errors for weaviate batch ingestion.
if not batch_params.get("callback"):
batch_params.update({"callback": self.process_batch_errors})

self.client.batch.configure(**batch_params)

with self.client.batch as batch:
for row_id, row in df.iterrows():
data_object = row.to_dict()
uuid = data_object.pop(uuid_column)
vector = data_object.pop(vector_column, None)

try:
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
if existing == "error":
raise AirflowException(f"Ingest of UUID {uuid} failed. Object exists.")

if existing == "skip":
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Skipping.")
continue
elif existing == "replace":
# Default for weaviate is replace existing
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Overwriting.")
except AirflowException as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
break
except Exception as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
continue

results = batch.create_objects()
try:
added_row = batch.add_data_object(
class_name=class_name, uuid=uuid, data_object=data_object, vector=vector, tenant=tenant
)
if verbose is True:
self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")

if len(results) > 0:
batch_errors += self.process_batch_errors(results=results, verbose=verbose)
except Exception as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
self.batch_errors.append({"uuid": uuid, "result": {"errors": str(e)}})

return batch_errors
return self.batch_errors

def process_batch_errors(self, results: list, verbose: bool) -> list:
def process_batch_errors(self, results: list, verbose: bool = True) -> None:
"""
Processes the results from batch operation and collects any errors.
:param results: Results from the batch operation.
:param verbose: Flag to enable verbose logging.
"""
errors = []
for item in results:
if "errors" in item["result"]:
item_error = {"uuid": item["id"], "errors": item["result"]["errors"]}
if verbose:
self.logger.info(
f"Error occurred in batch process for {item['id']} with error {item['result']['errors']}"
)
errors.append(item_error)
return errors
self.batch_errors.append(item_error)

def handle_upsert_rollback(
self, objects_to_upsert: pd.DataFrame, batch_errors: list, class_name: str, verbose: bool
) -> list:
self, objects_to_upsert: pd.DataFrame, class_name: str, verbose: bool, tenant: str | None = None
) -> tuple[list, set]:
"""
Handles rollback of inserts in case of errors during upsert operation.
:param objects_to_upsert: Dictionary of objects to upsert.
:param class_name: Name of the class in Weaviate.
:param verbose: Flag to enable verbose logging.
:param tenant: The tenant to which the object will be added.
"""
rollback_errors = []

error_uuids = {error["uuid"] for error in batch_errors}
error_uuids = {error["uuid"] for error in self.batch_errors}

objects_to_upsert["rollback_doc"] = objects_to_upsert.objects_to_insert.apply(
lambda x: any(error_uuids.intersection(x))
Expand All @@ -315,30 +324,48 @@ def handle_upsert_rollback(

for uuid in rollback_objects:
try:
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant):
self.logger.info(f"Removing id {uuid} for rollback.")
self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL")
self.client.data_object.delete(
uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL"
)
elif verbose:
self.logger.info(f"UUID {uuid} does not exist. Skipping deletion during rollback.")
except Exception as e:
rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
if verbose:
self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}")

for uuid in delete_objects:
return rollback_errors, delete_objects

def handle_successful_upsert(
self, objects_to_remove: list, class_name: str, verbose: bool, tenant: str | None = None
) -> list:
"""
Handles removal of previous objects after successful upsert.
:param objects_to_remove: If there were errors rollback will generate a list of successfully inserted objects.
If not set, assume all objects inserted successfully and delete all objects_to_upsert['objects_to_delete']
:param class_name: Name of the class in Weaviate.
:param verbose: Flag to enable verbose logging.
:param tenant: The tenant to which the object will be added.
"""
deletion_errors = []
for uuid in objects_to_remove:
try:
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
if self.client.data_object.exists(uuid=uuid, class_name=class_name, tenant=tenant):
if verbose:
self.logger.info(f"Deleting id {uuid} for successful upsert.")
self.client.data_object.delete(uuid=uuid, class_name=class_name)
self.client.data_object.delete(
uuid=uuid, class_name=class_name, tenant=tenant, consistency_level="ALL"
)
elif verbose:
self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.")
except Exception as e:
rollback_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
deletion_errors.append({"uuid": uuid, "result": {"errors": str(e)}})
if verbose:
self.logger.info(f"Error in rolling back id {uuid}. Error: {str(e)}")

return rollback_errors
return deletion_errors

def ingest_data(
self,
Expand All @@ -350,6 +377,7 @@ def ingest_data(
vector_column: str = None,
batch_params: dict = None,
verbose: bool = True,
tenant: str | None = None,
) -> list:
"""
Ingests data into Weaviate, handling upserts and rollbacks, and returns a list of objects that failed to import.
Expand All @@ -367,11 +395,14 @@ def ingest_data(
:param vector_column: Column with embedding vectors for pre-embedded data.
:param batch_params: Additional parameters for Weaviate batch configuration.
:param verbose: Flag to enable verbose output during the ingestion process.
:param tenant: The tenant to which the object will be added.
"""

global objects_to_upsert
if existing not in ["skip", "replace", "upsert"]:
raise AirflowException("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert'")
if existing not in ["skip", "replace", "upsert", "error"]:
raise AirflowException(
"Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert', 'error'."
)

df = pd.concat(dfs, ignore_index=True)

Expand All @@ -380,7 +411,7 @@ def ingest_data(
df=df, class_name=class_name, vector_column=vector_column, uuid_column=uuid_column
)

if existing == "upsert":
if existing == "upsert" or existing == "skip":
objects_to_upsert = self.identify_upsert_targets(
df=df, class_name=class_name, doc_key=doc_key, uuid_column=uuid_column
)
Expand All @@ -392,28 +423,49 @@ def ingest_data(

self.logger.info(f"Passing {len(df)} objects for ingest.")

batch_errors = self.batch_ingest(
self.batch_ingest(
df=df,
class_name=class_name,
uuid_column=uuid_column,
vector_column=vector_column,
batch_params=batch_params,
existing=existing,
verbose=verbose,
tenant=tenant,
)

if existing == "upsert" and batch_errors:
self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.")
rollback_errors = self.handle_upsert_rollback(
objects_to_upsert=objects_to_upsert, batch_errors=batch_errors, class_name=class_name, verbose=verbose
)
if existing == "upsert":
if self.batch_errors:
self.logger.warning("Error during upsert. Rolling back all inserts for docs with errors.")
rollback_errors, objects_to_remove = self.handle_upsert_rollback(
objects_to_upsert=objects_to_upsert, class_name=class_name, verbose=verbose
)

deletion_errors = self.handle_successful_upsert(
objects_to_remove=objects_to_remove, class_name=class_name, verbose=verbose
)

if len(rollback_errors) > 0:
self.logger.error("Errors encountered during rollback.")
raise AirflowException("Errors encountered during rollback.")
rollback_errors += deletion_errors

if rollback_errors:
self.logger.error("Errors encountered during rollback.")
self.logger.error("\n".join(rollback_errors))
raise AirflowException("Errors encountered during rollback.")
else:
removal_errors = self.handle_successful_upsert(
objects_to_remove={item for sublist in objects_to_upsert.objects_to_delete for item in sublist},
class_name=class_name,
verbose=verbose,
tenant=tenant,
)
if removal_errors:
self.logger.error("Errors encountered during removal.")
self.logger.error("\n".join(removal_errors))
raise AirflowException("Errors encountered during removal.")

if batch_errors:
if self.batch_errors:
self.logger.error("Errors encountered during ingest.")
self.logger.error("\n".join(self.batch_errors))
raise AirflowException("Errors encountered during ingest.")

def _query_objects(self, value: Any, doc_key: str, class_name: str, uuid_column: str) -> set:
Expand Down

0 comments on commit 87cce4f

Please sign in to comment.