From f6ccc74797f4a2b8e11644f64098c61756e6e204 Mon Sep 17 00:00:00 2001
From: Dana Katzenelson
Date: Mon, 18 Nov 2024 14:35:07 -0800
Subject: [PATCH] Fix broken sql table name validation example

---
 panther_analysis_tool/validation.py           | 10 +++-
 .../panther_analysis_tool/test_validation.py  | 57 ++++++++++++++++++-
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/panther_analysis_tool/validation.py b/panther_analysis_tool/validation.py
index cf7a9cb6..a2000546 100644
--- a/panther_analysis_tool/validation.py
+++ b/panther_analysis_tool/validation.py
@@ -45,7 +45,15 @@ def contains_invalid_table_names(
         logging.info("Failed to parse query %s. Skipping table name validation", analysis_id)
         return []
     tables = nested_lookup("table_reference", parsed_query)
-    aliases = [alias[0] for alias in nested_lookup("common_table_expression", parsed_query)]
+    aliases = []
+    for alias in nested_lookup("common_table_expression", parsed_query):
+        if isinstance(alias, list):
+            aliases.append(alias[0])
+        elif isinstance(alias, dict):
+            dict_table_key = "naked_identifier"
+            aliases.append({dict_table_key: alias.get(dict_table_key)})
+        else:
+            logging.info("Unrecognized alias type: %s", type(alias))
     for table in tables:
         if table in aliases:
             continue
diff --git a/tests/unit/panther_analysis_tool/test_validation.py b/tests/unit/panther_analysis_tool/test_validation.py
index 02c30f27..97bf409f 100644
--- a/tests/unit/panther_analysis_tool/test_validation.py
+++ b/tests/unit/panther_analysis_tool/test_validation.py
@@ -19,7 +19,7 @@ class TestContainsInvalidTableNames(unittest.TestCase):
     GROUP BY reported_client_type, user_name
     HAVING counts >= 3"""
 
-    def test_complex_sql(self):
+    def test_complex_sql_list_pattern(self):
         sql = """
         WITH login_attempts as (
             SELECT
@@ -67,6 +67,61 @@
         output = contains_invalid_table_names(analysis_spec, analysis_id, [])
         self.assertFalse(output)
 
+    def test_complex_sql_dict_pattern(self):
+        sql = """
+        WITH date_ranges AS(
+            -- Generate the last 12 months of date ranges with midnight start and end times
+            SELECT
+                DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL), CURRENT_DATE)) AS start_date,
+                DATE_TRUNC('month', DATEADD('month', -ROW_NUMBER() OVER (ORDER BY NULL) + 1, CURRENT_DATE)) - INTERVAL '1 second' AS end_date
+            FROM TABLE(GENERATOR(ROWCOUNT => 12))
+        ),
+        table_counts AS (
+            -- Query each table and count rows per month
+            SELECT
+                'ATLASSIAN_AUDIT' AS table_name,
+                COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
+                d.start_date,
+                d.end_date
+            FROM date_ranges AS d
+            LEFT JOIN panther_logs.public.ATLASSIAN_AUDIT AS t
+                ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
+            GROUP BY d.start_date, d.end_date
+
+            UNION ALL
+            SELECT
+                'AUTH0_EVENTS' AS table_name,
+                COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
+                d.start_date,
+                d.end_date
+            FROM date_ranges AS d
+            LEFT JOIN panther_logs.public.AUTH0_EVENTS AS t
+                ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
+            GROUP BY d.start_date, d.end_date
+
+            UNION ALL
+            SELECT
+                'AWS_CLOUDTRAIL' AS table_name,
+                COALESCE(COUNT(t.p_parse_time), 0) AS row_count,
+                d.start_date,
+                d.end_date
+            FROM date_ranges AS d
+            LEFT JOIN panther_logs.public.AWS_CLOUDTRAIL AS t
+                ON t.p_parse_time >= d.start_date AND t.p_parse_time < d.end_date
+            GROUP BY d.start_date, d.end_date
+        )
+
+        -- Final selection of table name, row count, start, and end date
+        SELECT table_name, row_count, start_date, end_date
+        FROM table_counts
+        ORDER BY table_name, start_date;
+        """
+        analysis_spec = {"Query": sql}
+        analysis_id = "analysis_id_1"
+
+        output = contains_invalid_table_names(analysis_spec, analysis_id, [])
+        self.assertFalse(output)
+
     def test_simple_sql(self):
         sql = self.invalid_sql
         analysis_spec = {"Query": sql}
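
Note on the validation.py hunk above: nested_lookup("common_table_expression", parsed_query) can surface each CTE alias either as a list of parse segments (with the identifier in the first element) or as a bare mapping keyed by "naked_identifier", which is why the loop now branches on the value's type. The standalone sketch below mirrors that normalization against hand-written parse fragments. The exact fragment shapes are illustrative assumptions rather than captured parser output, and normalize_cte_aliases is a hypothetical helper that does not exist in panther_analysis_tool.

    # Minimal sketch of the alias normalization introduced in the patch.
    # The two input shapes below are assumptions for illustration only.
    import logging
    from typing import Any, List

    def normalize_cte_aliases(cte_entries: List[Any]) -> List[Any]:
        """Collect one alias entry per CTE, whether the parser returned a
        list of segments or a bare dict keyed by "naked_identifier"."""
        aliases: List[Any] = []
        for alias in cte_entries:
            if isinstance(alias, list):
                # List-style entry: the first element carries the alias identifier.
                aliases.append(alias[0])
            elif isinstance(alias, dict):
                # Dict-style entry: keep only the identifier key so the result
                # can compare equal to a matching "table_reference" entry.
                aliases.append({"naked_identifier": alias.get("naked_identifier")})
            else:
                logging.info("Unrecognized alias type: %s", type(alias))
        return aliases

    if __name__ == "__main__":
        # Hypothetical fragments for a list-style and a dict-style CTE entry.
        list_style = [[{"naked_identifier": "login_attempts"}, {"keyword": "as"}]]
        dict_style = [{"naked_identifier": "date_ranges", "keyword": "AS"}]
        print(normalize_cte_aliases(list_style))  # [{'naked_identifier': 'login_attempts'}]
        print(normalize_cte_aliases(dict_style))  # [{'naked_identifier': 'date_ranges'}]

The intent mirrored here is that an alias in either shape ends up comparable to the corresponding table_reference entry, so the "if table in aliases: continue" check in contains_invalid_table_names skips CTE names regardless of how the parser rendered them.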