diff --git a/tools/web-fuzzing-introspection/app/webapp/routes.py b/tools/web-fuzzing-introspection/app/webapp/routes.py index 6550c608..ddf798c4 100644 --- a/tools/web-fuzzing-introspection/app/webapp/routes.py +++ b/tools/web-fuzzing-introspection/app/webapp/routes.py @@ -15,6 +15,7 @@ import os import random +import re import json import signal from typing import Dict, List, Optional @@ -2399,18 +2400,13 @@ def extract_project_tests(project_name, refined_list.append(test_file) tests_file_list = refined_list - # If filtering is to be done, do thsi now. + # If filtering is to be done, do this now. if try_ignore_irrelevant: - has_repo_match = False - for test_file in tests_file_list: - if f'/src/{project_name.lower()}' in test_file.lower(): - has_repo_match = True - if has_repo_match: - tmp = [] - for test_file in tests_file_list: - if f'/src/{project_name.lower()}' in test_file.lower(): - tmp.append(test_file) - return tmp + target_project = get_project_with_name(project_name) + tests_file_list = _ignore_irrelevant_tests(tests_file_list, + target_project, + project_name) + return tests_file_list @@ -2430,21 +2426,110 @@ def _light_project_tests(project_name, try_ignore_irrelevant=True): continue returner_list.append(test_file) - # If filtering is to be done, do thsi now. + # If filtering is to be done, do this now. if try_ignore_irrelevant: - has_repo_match = False - for test_file in returner_list: - if f'/src/{project_name.lower()}' in test_file.lower(): - has_repo_match = True - if has_repo_match: - tmp = [] - for test_file in returner_list: - if f'/src/{project_name.lower()}' in test_file.lower(): - tmp.append(test_file) - return tmp + returner_list = _ignore_irrelevant_tests(returner_list, target_project, + project_name) + return returner_list +def _ignore_irrelevant_tests(tests_file_list, project, project_name): + """Helper function to ignore irrelevant tests""" + repo_match = [] + + for test_file in tests_file_list: + if f'/src/{project_name.lower()}' in test_file.lower(): + repo_match.append(test_file) + + # Extra filtering for Java project + # This is to filter irrelevant java test/example sources that does + # not call any public classes of the project + if project and project.language == 'java': + result_list = [] + + # Determine a list of relevant import statements + target_list = data_storage.get_functions_by_project( + project_name) + data_storage.get_constructors_by_project( + project_name) + public_classes = list( + function_helper.get_public_class_list(target_list)) + relevant_import = _determine_relevant_imports(public_classes) + + # Determine if the test files are relevant by checking if the relevant + # import statements exist in the test files + for test_file in repo_match: + if _contains_relevant_import(test_file, relevant_import, + project_name): + result_list.append(test_file) + + return result_list + + return repo_match + + +def _determine_relevant_imports(public_classes): + """Helper function to determine list of relevant imports""" + + # Group classes by package + package_map = {} + for public_class in public_classes: + if '.' in public_class: + package, class_name = public_class.rsplit('.', 1) + else: + package = '' + class_name = public_class + + if package not in package_map: + package_map[package] = [] + + package_map[package].append(class_name) + + # Generate separate import statements + imports = set() + for package, classes in package_map.items(): + # Import with no package + if not package: + for class_name in classes: + imports.add(f"import {class_name};") + continue + + # Specific imports + for class_name in classes: + imports.add(f"import {package}.{class_name};") + + # Wildcard imports + levels = package.split('.') + for i in range(1, len(levels) + 1): + imports.add(f"import {'.'.join(levels[:i])}.*;") + + # Combine separate and wildcard imports + return list(imports) + + +def _contains_relevant_import(test_file, imports, project_name): + """Helper function to determine if imports exists in test_file.""" + + import_pattern = re.compile(r'^\s*import\s+([\w.*]+);\s*(//.*)?$', + re.MULTILINE) + + # Extract source from test file path + datestr = get_latest_introspector_date(project_name) + source = extract_lines_from_source_code(project_name, datestr, test_file, + 0, 10000) + + if not source: + # Always return true if we failed to extract the test source + return True + + # Extract all import statements from the source code + source_imports = set() + for match in import_pattern.finditer(source): + source_imports.add(' '.join(match.group(0).strip().split())) + + return bool(source_imports.intersection(imports)) + + @api_blueprint.route('/api/project-tests') @api_blueprint.arguments(ProjectSchema, location='query') def project_tests(args):