Skip to content

Commit

Permalink
Fix broken sql table name validation example
Browse files Browse the repository at this point in the history
  • Loading branch information
dekatzenel committed Nov 18, 2024
1 parent c71ea3b commit 801a611
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
16 changes: 14 additions & 2 deletions panther_analysis_tool/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def contains_invalid_table_names(
invalid_table_names = []
query = lookup_snowflake_query(analysis_spec)
if query is not None:
parsed_query = {}
try:
parsed_query = parse(query, "snowflake")
except Exception: # pylint: disable=broad-except
Expand All @@ -45,7 +44,7 @@ def contains_invalid_table_names(
logging.info("Failed to parse query %s. Skipping table name validation", analysis_id)
return []
tables = nested_lookup("table_reference", parsed_query)
aliases = [alias[0] for alias in nested_lookup("common_table_expression", parsed_query)]
aliases = get_aliases(parsed_query)
for table in tables:
if table in aliases:
continue
Expand Down Expand Up @@ -77,6 +76,19 @@ def contains_invalid_table_names(
return invalid_table_names


def get_aliases(parsed_query: dict[str, Any]) -> list[Any]:
aliases = []
for alias in nested_lookup("common_table_expression", parsed_query):
if isinstance(alias, list):
aliases.append(alias[0])
elif isinstance(alias, dict):
dict_table_key = "naked_identifier"
aliases.append({dict_table_key: alias.get(dict_table_key)})
else:
logging.info("Unrecognized alias type: %s", type(alias))
return aliases


def lookup_snowflake_query(analysis_spec: Any) -> Optional[str]:
query_keys = ["Query", "SnowflakeQuery"]
for key in query_keys:
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/panther_analysis_tool/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestContainsInvalidTableNames(unittest.TestCase):
GROUP BY reported_client_type, user_name
HAVING counts >= 3"""

def test_complex_sql(self):
def test_complex_sql_list_pattern(self):
sql = """
WITH login_attempts as (
SELECT
Expand Down Expand Up @@ -67,6 +67,61 @@ def test_complex_sql(self):
output = contains_invalid_table_names(analysis_spec, analysis_id, [])
self.assertFalse(output)

def test_complex_sql_dict_pattern(self):
sql = """
WITH date_ranges AS(
-- Generate the last 12 months of date ranges with midnight start and end times
SELECT
DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL), CURRENT_DATE)) AS start_date,
DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL) + 1, CURRENT_DATE)) - INTERVAL '1 second' AS end_date
FROM TABLE(GENERATOR(ROWCOUNT => 12))
),
table_counts AS (
-- Query each table and count rows per month
SELECT
'ATLASSIAN_AUDIT' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.ATLASSIAN_AUDIT AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
UNION ALL
SELECT
'AUTH0_EVENTS' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.AUTH0_EVENTS AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
UNION ALL
SELECT
'AWS_CLOUDTRAIL' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.AWS_CLOUDTRAIL AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
)
-- Final selection of table name, row count, start, and end date
SELECT table_name, row_count, start_date, end_date
FROM table_counts
ORDER BY table_name, start_date;
"""
analysis_spec = {"Query": sql}
analysis_id = "analysis_id_1"

output = contains_invalid_table_names(analysis_spec, analysis_id, [])
self.assertFalse(output)

def test_simple_sql(self):
sql = self.invalid_sql
analysis_spec = {"Query": sql}
Expand Down

0 comments on commit 801a611

Please sign in to comment.