diff --git a/src/dispatch/incident/severity/service.py b/src/dispatch/incident/severity/service.py index 1ee317e4d0bd..b9de04442a65 100644 --- a/src/dispatch/incident/severity/service.py +++ b/src/dispatch/incident/severity/service.py @@ -106,7 +106,7 @@ def get_all(*, db_session, project_id: int = None) -> List[Optional[IncidentSeve if project_id: return db_session.query(IncidentSeverity).filter(IncidentSeverity.project_id == project_id) - return db_session.query(IncidentSeverity) + return db_session.query(IncidentSeverity).all() def get_all_enabled(*, db_session, project_id: int = None) -> List[Optional[IncidentSeverity]]: diff --git a/tests/conftest.py b/tests/conftest.py index 9f36dc7b37fa..6d5055ce1e4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ IncidentFactory, IncidentPriorityFactory, IncidentRoleFactory, + IncidentSeverityFactory, IncidentTypeFactory, IndividualContactFactory, NotificationFactory, @@ -337,6 +338,11 @@ def incident_priorities(session): return [IncidentPriorityFactory(), IncidentPriorityFactory()] +@pytest.fixture +def incident_severity(session): + return IncidentSeverityFactory() + + @pytest.fixture def incident_role(session): return IncidentRoleFactory() diff --git a/tests/incident_severity/test_incident_severity_service.py b/tests/incident_severity/test_incident_severity_service.py new file mode 100644 index 000000000000..029702e8b87d --- /dev/null +++ b/tests/incident_severity/test_incident_severity_service.py @@ -0,0 +1,153 @@ +def test_get(session, incident_severity): + from dispatch.incident.severity.service import get + + t_incident_severity = get(db_session=session, incident_severity_id=incident_severity.id) + assert t_incident_severity.id == incident_severity.id + + +def test_get_default(session, incident_severity): + from dispatch.incident.severity.service import get_default + + incident_severity.default = True + + t_incident_severity = get_default(db_session=session, project_id=incident_severity.project.id) + assert t_incident_severity.id == incident_severity.id + + +def test_get_default_or_raise__fail(session, incident_severity): + from pydantic.error_wrappers import ValidationError + from dispatch.incident.severity.service import get_default_or_raise + + incident_severity.default = False + validation_error = False + + try: + get_default_or_raise(db_session=session, project_id=incident_severity.project.id) + except ValidationError: + validation_error = True + + assert validation_error + + +def test_get_by_name(session, incident_severity): + from dispatch.incident.severity.service import get_by_name + + assert get_by_name( + db_session=session, project_id=incident_severity.project_id, name=incident_severity.name + ) + + +def get_by_name_or_raise__fail(session, incident_severity): + """Returns the incident severity specified or raises ValidationError.""" + from pydantic.error_wrappers import ValidationError + from dispatch.incident.severity.models import IncidentSeverityRead + from dispatch.incident.severity.service import get_by_name_or_raise + + incident_severity_in = IncidentSeverityRead.from_orm(incident_severity) + incident_severity_in.name += "_fail" + validation_error = False + + try: + get_by_name_or_raise( + db_session=session, + project_id=incident_severity.project.id, + incident_severity_in=incident_severity_in, + ) + except ValidationError: + validation_error = True + + assert validation_error + + +def test_get_by_name_or_default__name(session, incident_severity): + from dispatch.incident.severity.models import IncidentSeverityRead + from dispatch.incident.severity.service import get_by_name_or_default + + incident_severity_in = IncidentSeverityRead.from_orm(incident_severity) + + assert get_by_name_or_default( + db_session=session, + project_id=incident_severity.project_id, + incident_severity_in=incident_severity_in, + ) + + +def test_get_by_name_or_default__default(session, incident_severity): + from dispatch.incident.severity.service import get_by_name_or_default + + incident_severity.default = True + + assert get_by_name_or_default( + db_session=session, + project_id=incident_severity.project_id, + incident_severity_in=None, + ) + + +def test_get_all(session, incident_severity): + from dispatch.incident.severity.service import get_all + + assert get_all( + db_session=session, + project_id=incident_severity.project_id, + ) + + +def test_get_all_enabled(session, incident_severity): + from dispatch.incident.severity.service import get_all_enabled + + incident_severity.enabled = True + + assert get_all_enabled(db_session=session, project_id=incident_severity.project_id) + + +def test_get_all_enabled__empty(session, incident_severity): + from dispatch.incident.severity.service import get_all, get_all_enabled + + for severity in get_all( + db_session=session, + project_id=incident_severity.project_id, + ): + severity.enabled = False + + assert not get_all_enabled(db_session=session, project_id=incident_severity.project_id).all() + + +def test_create(session, incident_severity): + from dispatch.project.models import ProjectRead + from dispatch.incident.severity.models import IncidentSeverityCreate + from dispatch.incident.severity.service import create + + incident_severity_in = IncidentSeverityCreate( + name="new_name", + description="new_description", + color="FFFFFF", + project=ProjectRead.from_orm(incident_severity.project), + ) + + assert create(db_session=session, incident_severity_in=incident_severity_in) + + +def test_update(session, incident_severity): + from dispatch.incident.severity.models import IncidentSeverityUpdate + from dispatch.incident.severity.service import update + + expected_name = incident_severity.name + "_updated" + incident_severity_in = IncidentSeverityUpdate.from_orm(incident_severity) + + incident_severity_in.name = expected_name + + t_incident_severity = update( + db_session=session, + incident_severity=incident_severity, + incident_severity_in=incident_severity_in, + ) + + assert t_incident_severity.name == expected_name + + +def test_delete(session, incident_severity): + from dispatch.incident.severity.service import get, delete + + delete(db_session=session, incident_severity_id=incident_severity.id) + assert not get(db_session=session, incident_severity_id=incident_severity.id) diff --git a/tests/report/test_report_flows.py b/tests/report/test_report_flows.py new file mode 100644 index 000000000000..22c94bebc1c2 --- /dev/null +++ b/tests/report/test_report_flows.py @@ -0,0 +1,17 @@ +def test_create_tactical_report(session, incident, participant): + from dispatch.report.flows import create_tactical_report + from dispatch.report.models import TacticalReportCreate + + participant.incident = incident + + tactical_report_in = TacticalReportCreate( + conditions="sample conditions", actions="sample actions", needs="sample needs" + ) + + assert create_tactical_report( + user_email=participant.individual.email, + incident_id=participant.incident.id, + tactical_report_in=tactical_report_in, + organization_slug=incident.project.organization.slug, + db_session=session, + )