Skip to content

Commit

Permalink
Fix broken sql table name validation example
Browse files Browse the repository at this point in the history
  • Loading branch information
dekatzenel committed Nov 18, 2024
1 parent c71ea3b commit f6ccc74
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
10 changes: 9 additions & 1 deletion panther_analysis_tool/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def contains_invalid_table_names(
logging.info("Failed to parse query %s. Skipping table name validation", analysis_id)
return []
tables = nested_lookup("table_reference", parsed_query)
aliases = [alias[0] for alias in nested_lookup("common_table_expression", parsed_query)]
aliases = []
for alias in nested_lookup("common_table_expression", parsed_query):
if isinstance(alias, list):
aliases.append(alias[0])
elif isinstance(alias, dict):
dict_table_key = "naked_identifier"
aliases.append({dict_table_key: alias.get(dict_table_key)})
else:
logging.info("Unrecognized alias type: %s", type(alias))
for table in tables:
if table in aliases:
continue
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/panther_analysis_tool/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestContainsInvalidTableNames(unittest.TestCase):
GROUP BY reported_client_type, user_name
HAVING counts >= 3"""

def test_complex_sql(self):
def test_complex_sql_list_pattern(self):
sql = """
WITH login_attempts as (
SELECT
Expand Down Expand Up @@ -67,6 +67,61 @@ def test_complex_sql(self):
output = contains_invalid_table_names(analysis_spec, analysis_id, [])
self.assertFalse(output)

def test_complex_sql_dict_pattern(self):
    """Validate a multi-CTE query and expect no invalid table names.

    The query declares two common table expressions (``date_ranges`` and
    ``table_counts``) and joins against ``panther_logs.public`` tables.
    The validator is called with an empty third argument and must return
    a falsy result.

    NOTE(review): presumably this query makes the parser report a CTE
    alias as a dict rather than a list (hence "dict pattern") -- confirm
    against the validator implementation.
    """
    # The SQL text is kept flush-left so the literal stays exactly as given.
    query = """
WITH date_ranges AS(
-- Generate the last 12 months of date ranges with midnight start and end times
SELECT
DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL), CURRENT_DATE)) AS start_date,
DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL) + 1, CURRENT_DATE)) - INTERVAL '1 second' AS end_date
FROM TABLE(GENERATOR(ROWCOUNT => 12))
),
table_counts AS (
-- Query each table and count rows per month
SELECT
'ATLASSIAN_AUDIT' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.ATLASSIAN_AUDIT AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
UNION ALL
SELECT
'AUTH0_EVENTS' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.AUTH0_EVENTS AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
UNION ALL
SELECT
'AWS_CLOUDTRAIL' AS table_name,
COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
d.start_date,
d.end_date
FROM date_ranges AS d
LEFT JOIN panther_logs.public.AWS_CLOUDTRAIL AS t
ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
GROUP BY d.start_date, d.end_date
)
-- Final selection of table name, row count, start, and end date
SELECT table_name, row_count, start_date, end_date
FROM table_counts
ORDER BY table_name, start_date;
"""

    # Same arguments as before: spec dict, analysis id, empty third argument.
    result = contains_invalid_table_names({"Query": query}, "analysis_id_1", [])
    self.assertFalse(result)

def test_simple_sql(self):
sql = self.invalid_sql
analysis_spec = {"Query": sql}
Expand Down

0 comments on commit f6ccc74

Please sign in to comment.