Skip to content

Commit

Permalink
refactoring api-handling (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
ansibleguy committed Mar 12, 2024
1 parent a9e41c6 commit b052bab
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 55 deletions.
136 changes: 90 additions & 46 deletions plugins/module_utils/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_simple_existing, to_digit, get_matching, simplify_translate, is_unset, \
sort_param_lists
from ansible_collections.ansibleguy.opnsense.plugins.module_utils.base.handler import \
exit_bug, exit_debug, ModuleSoftError
exit_bug, ModuleSoftError


class Base:
Expand All @@ -27,6 +27,7 @@ class Base:
ATTR_TRANSLATE = 'FIELDS_TRANSLATE'
ATTR_DIFF_EXCL = 'FIELDS_DIFF_EXCLUDE'
ATTR_VALUE_MAP = 'FIELDS_VALUE_MAPPING'
ATTR_BUILD_COPY = 'FIELDS_BUILD_COPY'
ATTR_VALUE_MAP_RCV = 'FIELDS_VALUE_MAPPING_RCV'
ATTR_FIELD_ALL = 'FIELDS_ALL'
ATTR_FIELD_CH = 'FIELDS_CHANGE'
Expand All @@ -40,6 +41,7 @@ class Base:
ATTR_FIELD_ID = 'FIELD_ID' # field we use for matching
ATTR_FIELD_PK = 'FIELD_PK' # field opnsense uses as primary key
PARAM_MATCH_FIELDS = 'match_fields'
QUERY_MAX_ENTRIES = 1000

REQUIRED_ATTRS = [
ATTR_AK_PATH,
Expand All @@ -53,18 +55,13 @@ class Base:
def __init__(self, instance):
self.i = instance # module-specific object
self.e = {} # existing entry
self.raw = None # to save first raw existing entry - to resolve user input per selection

for attr in self.REQUIRED_ATTRS:
if not hasattr(self.i, attr):
exit_bug(f"Module has no '{attr}' attribute set!")

def search(self, fail_response: bool = False) -> (dict, list):
if fail_response:
# find response keys in initial development
exit_debug(self._api_get(cnf={
**self.i.call_cnf, **{'command': self.i.CMDS['search']}
}))

def search(self) -> (dict, list):
# workaround if 'get' needs to be performed using other api module/controller
cont_get, mod_get = self.i.API_CONT, self.i.API_MOD

Expand All @@ -74,13 +71,49 @@ def search(self, fail_response: bool = False) -> (dict, list):
if hasattr(self.i, self.ATTR_GET_MOD):
mod_get = getattr(self.i, self.ATTR_GET_MOD)

data = self._api_get(cnf={
**self.i.call_cnf,
**{
'module': mod_get,
'controller': cont_get,
self.i.call_cnf['controller'] = cont_get
self.i.call_cnf['module'] = mod_get

if self.i.CMDS['search'].startswith('search'):
# case for api-refactoring: https://github.com/ansibleguy/collection_opnsense/issues/51
if 'detail' not in self.i.CMDS:
exit_bug("To use the 'search' commands you need to also define the related 'detail' (get) command!")

data = []

for base_entry in self._api_post({
**self.i.call_cnf,
'command': self.i.CMDS['search'],
}
'data': {'current': 1, 'rowCount': self.QUERY_MAX_ENTRIES},
})['rows']:
# todo: perform async calls for parallel data fetching
data.append({
**self._search_path_handling(
self._api_get({
**self.i.call_cnf,
'command': self.i.CMDS['detail'],
'params': [base_entry[self.field_pk]]
})
),
**base_entry,
})
if self.raw is None:
self.raw = data[0]

if self.raw is None:
self.raw = self._search_path_handling(
self._api_get({
**self.i.call_cnf,
'command': self.i.CMDS['detail'],
})
)

return data

# legacy api handling (fewer requests needed)
data = self._api_get({
**self.i.call_cnf,
'command': self.i.CMDS['search'],
})

if hasattr(self.i, self.ATTR_GET_ADD):
Expand All @@ -91,7 +124,7 @@ def search(self, fail_response: bool = False) -> (dict, list):
self._search_path_handling(data=data, ak_path=ak_path)
)

return self._search_path_handling(data)
return self._search_path_handling(data)

def _search_path_handling(self, data: dict, ak_path: str = None) -> dict:
# resolving API_KEY_PATH's so data from nested dicts gets extracted as configured
Expand Down Expand Up @@ -137,8 +170,8 @@ def find(self, match_fields: list) -> None:
self.i.exists = True
self.i.r['diff']['before'] = self.build_diff(data=match)

if 'uuid' in match:
self.i.call_cnf['params'] = [match['uuid']]
if self.field_pk in match:
self.i.call_cnf['params'] = [match[self.field_pk]]

def process(self) -> None:
if 'state' in self.i.p and self.i.p['state'] == 'absent':
Expand Down Expand Up @@ -238,14 +271,14 @@ def _update_enabled(self) -> None:

if 'enabled' in existing:
if existing['enabled'] != self.i.p['enabled']:
BOOL_INVERT_FIELDS = []
_bool_invert_fields = []
enable = self.i.p['enabled']
invert = False

if hasattr(self.i, self.ATTR_BOOL_INVERT):
BOOL_INVERT_FIELDS = getattr(self.i, self.ATTR_BOOL_INVERT)
_bool_invert_fields = getattr(self.i, self.ATTR_BOOL_INVERT)

if 'enabled' in BOOL_INVERT_FIELDS:
if 'enabled' in _bool_invert_fields:
invert = True
enable = not enable

Expand Down Expand Up @@ -307,7 +340,7 @@ def _change_enabled_state(self, value: int) -> dict:
return self._api_post(cnf={
**self.i.call_cnf, **{
'command': self.i.CMDS['toggle'],
'params': [getattr(self.i, self.i.EXIST_ATTR)['uuid'], value],
'params': [getattr(self.i, self.i.EXIST_ATTR)[self.field_pk], value],
}
})

Expand Down Expand Up @@ -341,23 +374,19 @@ def build_diff(self, data: dict) -> dict:
if not isinstance(data, dict):
exit_bug('The diff-source object must be of type dict!')

EXCLUDE_FIELDS = []
_exclude_fields = []

if hasattr(self.i, self.ATTR_DIFF_EXCL):
EXCLUDE_FIELDS = getattr(self.i, self.ATTR_DIFF_EXCL)
_exclude_fields = getattr(self.i, self.ATTR_DIFF_EXCL)

self._set_existing()

field_pk = 'uuid'
if hasattr(self.i, self.ATTR_FIELD_PK):
field_pk = getattr(self.i, self.ATTR_FIELD_PK)

diff = {
field_pk: self.e[field_pk] if field_pk in self.e else None
self.field_pk: self.e[self.field_pk] if self.field_pk in self.e else None
}

for field in self.i.FIELDS_ALL:
if field in EXCLUDE_FIELDS:
if field in _exclude_fields:
continue

stringify = True
Expand All @@ -370,7 +399,12 @@ def build_diff(self, data: dict) -> dict:
diff[field] = self.i.p[field]

if isinstance(diff[field], list):
diff[field].sort()
try:
diff[field].sort()

except TypeError:
raise exit_bug(f"Field not defined as 'select_opt_list' type: {diff[field]}")

stringify = False

elif isinstance(diff[field], str) and diff[field].isnumeric:
Expand All @@ -381,8 +415,8 @@ def build_diff(self, data: dict) -> dict:
except (TypeError, ValueError):
pass

elif isinstance(diff[field], dict) and 'uuid' in diff[field]:
diff[field] = diff[field]['uuid']
elif isinstance(diff[field], dict) and self.field_pk in diff[field]:
diff[field] = diff[field][self.field_pk]

elif isinstance(diff[field], (bool, int)):
stringify = False
Expand All @@ -402,9 +436,9 @@ def build_diff(self, data: dict) -> dict:

def build_request(self, ignore_fields: list = None) -> dict:
request = {}
TRANSLATE_FIELDS = {}
TRANSLATE_VALUES = {}
BOOL_INVERT_FIELDS = []
_translate_fields = {}
_translate_values = {}
_bool_invert_fields = []

if ignore_fields is None:
ignore_fields = []
Expand All @@ -413,37 +447,40 @@ def build_request(self, ignore_fields: list = None) -> dict:
self.e = getattr(self.i, self.i.EXIST_ATTR)

if hasattr(self.i, self.ATTR_TRANSLATE):
TRANSLATE_FIELDS = getattr(self.i, self.ATTR_TRANSLATE)

if hasattr(self.i, self.ATTR_BOOL_INVERT):
BOOL_INVERT_FIELDS = getattr(self.i, self.ATTR_BOOL_INVERT)
_translate_fields = getattr(self.i, self.ATTR_TRANSLATE)

if hasattr(self.i, self.ATTR_VALUE_MAP):
TRANSLATE_VALUES = getattr(self.i, self.ATTR_VALUE_MAP)
_translate_values = getattr(self.i, self.ATTR_VALUE_MAP)

if hasattr(self.i, self.ATTR_BOOL_INVERT):
_bool_invert_fields = getattr(self.i, self.ATTR_BOOL_INVERT)

for field in self.i.FIELDS_ALL:
if field in ignore_fields:
continue

opn_field = field
if field in TRANSLATE_FIELDS:
opn_field = TRANSLATE_FIELDS[field]
if field in _translate_fields:
opn_field = _translate_fields[field]

if field in self.i.p:
opn_data = self.i.p[field]

else:
elif field in self.e:
opn_data = self.e[field]

if field in TRANSLATE_VALUES:
else:
opn_data = ''

if field in _translate_values:
try:
opn_data = TRANSLATE_VALUES[field][opn_data]
opn_data = _translate_values[field][opn_data]

except KeyError:
pass

if isinstance(opn_data, bool):
if field in BOOL_INVERT_FIELDS:
if field in _bool_invert_fields:
opn_data = not opn_data

request[opn_field] = to_digit(opn_data)
Expand Down Expand Up @@ -588,6 +625,13 @@ def simplify_existing(self, existing: dict) -> dict:
value_map=value_map,
)

@property
def field_pk(self) -> str:
if hasattr(self.i, self.ATTR_FIELD_PK):
return getattr(self.i, self.ATTR_FIELD_PK)

return 'uuid'

def _call_simple(self) -> Callable:
if hasattr(self.i, 'simplify_existing'):
return self.i.simplify_existing
Expand Down
14 changes: 5 additions & 9 deletions plugins/module_utils/base/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(self, m: AnsibleModule, r: dict, s: Session = None):
'controller': self.b.i.API_CONT,
}

def _search_call(self, fail_response: bool = False) -> list:
return self.b.search(fail_response=fail_response)
def _search_call(self) -> list:
return self.b.search()

def _base_check(self, match_fields: list = None):
if match_fields is None:
Expand All @@ -43,7 +43,7 @@ def _base_check(self, match_fields: list = None):
if match_fields is not None:
self.b.find(match_fields=match_fields)
if self.exists:
self.call_cnf['params'] = [getattr(self, self.EXIST_ATTR)['uuid']]
self.call_cnf['params'] = [getattr(self, self.EXIST_ATTR)[self.b.field_pk]]

if self.p['state'] == 'present':
self.r['diff']['after'] = self.b.build_diff(data=self.p)
Expand Down Expand Up @@ -115,12 +115,8 @@ def check(self) -> None:
self.settings = self._search_call()
self._build_diff()

def _search_call(self, fail_response: bool = False) -> dict:
return self.b.simplify_existing(
self.b.search(
fail_response=fail_response
)
)
def _search_call(self) -> dict:
return self.b.simplify_existing(self.b.search())

def get_existing(self) -> dict:
return self._search_call()
Expand Down
13 changes: 13 additions & 0 deletions plugins/module_utils/helper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ def get_selected_opt_list(data: (dict, list)) -> (str, None):
return get_selected_value(data)


def get_selected_opt_list_idx(data: list) -> int:
idx = 0
for values in data:
if is_true(values['selected']):
return idx

idx += 1

return 0

def get_selected_list(data: dict, remove_empty: bool = False) -> list:
if isinstance(data, list):
# if function is re-applied
Expand Down Expand Up @@ -412,6 +422,9 @@ def simplify_translate(
elif t == 'select_opt_list':
simple[f] = get_selected_opt_list(simple[f])

elif t == 'select_opt_list_idx':
simple[f] = get_selected_opt_list_idx(simple[f])

for f, vmap in value_map.items():
try:
for pretty_value, opn_value in vmap.items():
Expand Down

0 comments on commit b052bab

Please sign in to comment.