Skip to content

Commit

Permalink
Fix and extend cloud backup support (#5464)
Browse files Browse the repository at this point in the history
* Fix and extend cloud backup support

* Clean up task for cloud backup and remove by location

* Args to kwargs on backup methods

* Fix backup remove error test and typing clean up
  • Loading branch information
mdegat01 authored Dec 5, 2024
1 parent 9b52fee commit 6e32144
Show file tree
Hide file tree
Showing 22 changed files with 587 additions and 335 deletions.
130 changes: 98 additions & 32 deletions supervisor/api/backups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Backups RESTful API."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
import errno
Expand All @@ -14,7 +16,7 @@
import voluptuous as vol

from ..backups.backup import Backup
from ..backups.const import LOCATION_CLOUD_BACKUP
from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE
from ..backups.validate import ALL_FOLDERS, FOLDER_HOMEASSISTANT, days_until_stale
from ..const import (
ATTR_ADDONS,
Expand All @@ -23,7 +25,7 @@
ATTR_CONTENT,
ATTR_DATE,
ATTR_DAYS_UNTIL_STALE,
ATTR_FILENAME,
ATTR_EXTRA,
ATTR_FOLDERS,
ATTR_HOMEASSISTANT,
ATTR_HOMEASSISTANT_EXCLUDE_DATABASE,
Expand All @@ -48,7 +50,12 @@
from ..jobs import JobSchedulerOptions
from ..mounts.const import MountUsage
from ..resolution.const import UnhealthyReason
from .const import ATTR_BACKGROUND, ATTR_LOCATIONS, CONTENT_TYPE_TAR
from .const import (
ATTR_ADDITIONAL_LOCATIONS,
ATTR_BACKGROUND,
ATTR_LOCATIONS,
CONTENT_TYPE_TAR,
)
from .utils import api_process, api_validate

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand All @@ -60,6 +67,14 @@
# Remove: 2022.08
_ALL_FOLDERS = ALL_FOLDERS + [FOLDER_HOMEASSISTANT]


def _ensure_list(item: Any) -> list:
"""Ensure value is a list."""
if not isinstance(item, list):
return [item]
return item


# pylint: disable=no-value-for-parameter
SCHEMA_RESTORE_FULL = vol.Schema(
{
Expand All @@ -81,9 +96,12 @@
vol.Optional(ATTR_NAME): str,
vol.Optional(ATTR_PASSWORD): vol.Maybe(str),
vol.Optional(ATTR_COMPRESSED): vol.Maybe(vol.Boolean()),
vol.Optional(ATTR_LOCATION): vol.Maybe(str),
vol.Optional(ATTR_LOCATION): vol.All(
_ensure_list, [vol.Maybe(str)], vol.Unique()
),
vol.Optional(ATTR_HOMEASSISTANT_EXCLUDE_DATABASE): vol.Boolean(),
vol.Optional(ATTR_BACKGROUND, default=False): vol.Boolean(),
vol.Optional(ATTR_EXTRA): dict,
}
)

Expand All @@ -106,12 +124,6 @@
vol.Optional(ATTR_TIMEOUT): vol.All(int, vol.Range(min=1)),
}
)
SCHEMA_RELOAD = vol.Schema(
{
vol.Inclusive(ATTR_LOCATION, "file"): vol.Maybe(str),
vol.Inclusive(ATTR_FILENAME, "file"): vol.Match(RE_BACKUP_FILENAME),
}
)


class APIBackups(CoreSysAttributes):
Expand Down Expand Up @@ -177,13 +189,10 @@ async def options(self, request):
self.sys_backups.save_data()

@api_process
async def reload(self, request: web.Request):
async def reload(self, _):
"""Reload backup list."""
body = await api_validate(SCHEMA_RELOAD, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
backup = self._location_to_mount(body)

return await asyncio.shield(self.sys_backups.reload(**backup))
await asyncio.shield(self.sys_backups.reload())
return True

@api_process
async def backup_info(self, request):
Expand Down Expand Up @@ -217,27 +226,35 @@ async def backup_info(self, request):
ATTR_REPOSITORIES: backup.repositories,
ATTR_FOLDERS: backup.folders,
ATTR_HOMEASSISTANT_EXCLUDE_DATABASE: backup.homeassistant_exclude_database,
ATTR_EXTRA: backup.extra,
}

def _location_to_mount(self, body: dict[str, Any]) -> dict[str, Any]:
"""Change location field to mount if necessary."""
if not body.get(ATTR_LOCATION) or body[ATTR_LOCATION] == LOCATION_CLOUD_BACKUP:
return body
def _location_to_mount(self, location: str | None) -> LOCATION_TYPE:
"""Convert a single location to a mount if possible."""
if not location or location == LOCATION_CLOUD_BACKUP:
return location

body[ATTR_LOCATION] = self.sys_mounts.get(body[ATTR_LOCATION])
if body[ATTR_LOCATION].usage != MountUsage.BACKUP:
mount = self.sys_mounts.get(location)
if mount.usage != MountUsage.BACKUP:
raise APIError(
f"Mount {body[ATTR_LOCATION].name} is not used for backups, cannot backup to there"
f"Mount {mount.name} is not used for backups, cannot backup to there"
)

return mount

def _location_field_to_mount(self, body: dict[str, Any]) -> dict[str, Any]:
"""Change location field to mount if necessary."""
body[ATTR_LOCATION] = self._location_to_mount(body.get(ATTR_LOCATION))
return body

def _validate_cloud_backup_location(
self, request: web.Request, location: str | None
self, request: web.Request, location: list[str | None] | str | None
) -> None:
"""Cloud backup location is only available to Home Assistant."""
if not isinstance(location, list):
location = [location]
if (
location == LOCATION_CLOUD_BACKUP
LOCATION_CLOUD_BACKUP in location
and request.get(REQUEST_FROM) != self.sys_homeassistant
):
raise APIForbidden(
Expand Down Expand Up @@ -278,10 +295,22 @@ async def release_on_freeze(new_state: CoreState):
async def backup_full(self, request: web.Request):
"""Create full backup."""
body = await api_validate(SCHEMA_BACKUP_FULL, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
locations: list[LOCATION_TYPE] | None = None

if ATTR_LOCATION in body:
location_names: list[str | None] = body.pop(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)

locations = [
self._location_to_mount(location) for location in location_names
]
body[ATTR_LOCATION] = locations.pop(0)
if locations:
body[ATTR_ADDITIONAL_LOCATIONS] = locations

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_full, **self._location_to_mount(body)
self.sys_backups.do_backup_full, **body
)

if background and not backup_task.done():
Expand All @@ -299,10 +328,22 @@ async def backup_full(self, request: web.Request):
async def backup_partial(self, request: web.Request):
"""Create a partial backup."""
body = await api_validate(SCHEMA_BACKUP_PARTIAL, request)
self._validate_cloud_backup_location(request, body.get(ATTR_LOCATION))
locations: list[LOCATION_TYPE] | None = None

if ATTR_LOCATION in body:
location_names: list[str | None] = body.pop(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)

locations = [
self._location_to_mount(location) for location in location_names
]
body[ATTR_LOCATION] = locations.pop(0)
if locations:
body[ATTR_ADDITIONAL_LOCATIONS] = locations

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_partial, **self._location_to_mount(body)
self.sys_backups.do_backup_partial, **body
)

if background and not backup_task.done():
Expand Down Expand Up @@ -370,9 +411,11 @@ async def remove(self, request: web.Request):
self._validate_cloud_backup_location(request, backup.location)
return self.sys_backups.remove(backup)

@api_process
async def download(self, request: web.Request):
"""Download a backup file."""
backup = self._extract_slug(request)
self._validate_cloud_backup_location(request, backup.location)

_LOGGER.info("Downloading backup %s", backup.slug)
response = web.FileResponse(backup.tarfile)
Expand All @@ -385,7 +428,23 @@ async def download(self, request: web.Request):
@api_process
async def upload(self, request: web.Request):
"""Upload a backup file."""
with TemporaryDirectory(dir=str(self.sys_config.path_tmp)) as temp_dir:
location: LOCATION_TYPE = None
locations: list[LOCATION_TYPE] | None = None
tmp_path = self.sys_config.path_tmp
if ATTR_LOCATION in request.query:
location_names: list[str] = request.query.getall(ATTR_LOCATION)
self._validate_cloud_backup_location(request, location_names)
# Convert empty string to None if necessary
locations = [
self._location_to_mount(location) if location else None
for location in location_names
]
location = locations.pop(0)

if location and location != LOCATION_CLOUD_BACKUP:
tmp_path = location.local_where

with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir:
tar_file = Path(temp_dir, "backup.tar")
reader = await request.multipart()
contents = await reader.next()
Expand All @@ -398,15 +457,22 @@ async def upload(self, request: web.Request):
backup.write(chunk)

except OSError as err:
if err.errno == errno.EBADMSG:
if err.errno == errno.EBADMSG and location in {
LOCATION_CLOUD_BACKUP,
None,
}:
self.sys_resolution.unhealthy = UnhealthyReason.OSERROR_BAD_MESSAGE
_LOGGER.error("Can't write new backup file: %s", err)
return False

except asyncio.CancelledError:
return False

backup = await asyncio.shield(self.sys_backups.import_backup(tar_file))
backup = await asyncio.shield(
self.sys_backups.import_backup(
tar_file, location=location, additional_locations=locations
)
)

if backup:
return {ATTR_SLUG: backup.slug}
Expand Down
1 change: 1 addition & 0 deletions supervisor/api/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

COOKIE_INGRESS = "ingress_session"

ATTR_ADDITIONAL_LOCATIONS = "additional_locations"
ATTR_AGENT_VERSION = "agent_version"
ATTR_APPARMOR_VERSION = "apparmor_version"
ATTR_ATTRIBUTES = "attributes"
Expand Down
Loading

0 comments on commit 6e32144

Please sign in to comment.