Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable patchwise training and prediction #135

Open
wants to merge 124 commits into
base: main
Choose a base branch
from

Conversation

davidwilby
Copy link
Collaborator

Hey @tom-andersson - at long last, the long-awaited patchwise training and prediction feature that @nilsleh and @MartinSJRogers have been working on.

This PR adds patching capabilities to DeepSensor during training and inference.

Training

Optional args patching_strategy, patch_size, stride and num_samples_per_date are added to TaskLoader.__call__.

There are two available patching strategies: random_window and sliding_window. The random_window option randomly selects points in the x1 and x2 extent as the centroid of the patch. The number of patches is defined by the num_samples_per_date argument. The sliding_window function starts in the top left of the dataset and convolves from left to right and top to bottom over the data using the user-defined patch_size and stride.

TaskLoader.__call__ now contains additional conditional logic depending upon the patching strategy selected. If no patching strategy is selected, task_generator() runs exactly as before. If random_window (sliding_window) is selected the bounding boxes for the patches are generated using the sample_random_window() (sample_sliding_window()) methods. The bounding boxes are appended to the list bboxes, and passed to task_generator().

Within task_generator() after the sampling strategies are applied, the data is spatially sliced using each bbox in bboxes using the self.spatial_slice_variable() function.

When using a patching strategy, TaskLoader produces a list of tasks per date, rather than an individual task per date. A small change has been made to Task's summarise_str method to avoid an error when printing patched Tasks and to output more meaningful information.

Inference

To run patchwise predictions, a new method has been created in model.py called predict_patch(). This method iterates through and applies the pre-exisiting predict() method to each patched task. The predict() method has not been changed. Within each iteration, prior to running predict() for each patch, the bounding box of each patch is unnormalized, so the X_t of each patch can be passed to the predict() function. The patchwise predictions are stored in the list preds for subsequent stitching.

It is only possible to use the sliding_window patching function during inference, and the stride and patch size are defined when the user generates the test tasks within the task_loader() call. The data_processor must also be passed to predict_patch() method to enable unnormalisation of the coordinates of the bboxes in model.py.

Once the list of patchwise predictions are generated, stitch_clipped_predictions() is used to form a prediction at the original X_t extent. Currently, functionality is provided to subset or clip each patchwise prediction so there is no overlap between adjacent patches and then merge the patches using xr.combine_by_coords(). The modular nature of the code means there is scope for additional stitching strategies to be added after this PR, for example applying a weighting function to overlapping predictions. To ensure the patches are clipped by the correct amount, get_patch_overlap() calculates the overlap between adjacent patches. stitch_clipped_predictions() also contains code to handle patches at the edge or bottom of the dataset, where the overlap may be different.

The output from predict_patch() is the identical DeepSensor object produced in model.predict(), hence DeepSensor’s plotting functionality can subsequently be used in the same way.

Documentation and Testing

New notebook(s) are added illustrating the usage of both patchwise training and prediction.

New tests are added to verify the new behaviour.

Limitations

  • Patchwise prediction does not currently support predicting at more than one timestamp - calling predict_patch with more than one date raises a NotImplementedError.
  • predict_patch is a new, distinct function due to all the pre-processing it needs to do, the patchwise behaviour may be better served as an option in predict - let me know what you think.
  • Patched tasks don't exactly follow the proportions from patch_size, e.g. for a 'square' patch patch_size=(0.5,0.5) the exact dimensions won't be exactly square, this is accounted for in stitching of patches, but is slightly inelegant at the moment so we may want to come back and find a more refined solution in the future.
  • In test_model.test_patchwise_prediction I've temporarily commented-out the asserts checking for correct prediction shape, these fail with test datasets for now, but with real datasets the shapes are correct, see the patchwise_training_and_prediction.ipynb notebook.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Collaborator

@tom-andersson tom-andersson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very exciting! Great to see this finally out for review! Thanks for kicking this off with a description of the feature, a new documentation page, and some unit tests.

As it's a very large PR, we'll probably have to go through a few review cycles. With that in mind, I've just skimmed and left a few high-level comments with the assumption that there will be some more iteration and tidying before I take another closer look.

Before sending back for review, please:

  1. Fix the failing unit tets. I think there is a type hint error.
  2. Generate the documentation locally and check they make sense. See https://github.com/alan-turing-institute/deepsensor/blob/main/CONTRIBUTING.md#contributing-to-documentation

progress_bar: int = 0,
verbose: bool = False,
) -> Prediction:
"""Predict on a regular grid or at off-grid locations.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring to explain the patching procedure.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

325de6d

I'm also trying out using autodoc's versionadded directive here, in yesterday's community meeting users were keen to have new features highlighted in the docs. What do you think?

image

        .. versionadded:: 0.4.3
            :py:func:`predict_patchwise()` method.

## start with first patch top left hand corner at x1_min, x2_min
patch_list = []

# Todo: simplify these elif statements
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do this in this PR? Would prefer something flatter and more modular rather than heavily indented elif statements.

deepsensor/model/model.py Outdated Show resolved Hide resolved
Comment on lines +1026 to +1038
"""
Do not remove border for the patches along top and left of dataset and change overlap size for last patch in each row and column.

At end of row (when patch_x2_index = data_x2_index), to calculate the number of pixels to remove from left hand side of patch:
If x2 is ascending, subtract previous patch x2 max value from current patch x2 min value to get bespoke overlap in column pixels.
To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
to get the number of pixels to remove from left hand side of patch.

If x2 is descending. Subtract current patch max x2 value from previous patch min x2 value to get bespoke overlap in column pixels.
To account for the clipping done to the previous patch, then subtract patch_overlap value in pixels
to get the number of pixels to remove from left hand side of patch.

"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably make more sense in the method docstring as part of a general description of the method, at the lowest indentation level.

]
return (x1_index, x2_index)

def stitch_clipped_predictions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably make more sense to put this method in deepsensor.model.pred, since it does not use self, and just operates on Prediction objects. WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I had wondered about where to separate out some of the methods used for patching, maybe the right compromise is to move this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking more on this, there are a number of methods specifically for stitching patches together, constituting a few hundred lines of code. A few options for organising are: moving into a specific module as functions; moving into the Prediction class as static methods; or moving into a child class e.g. PatchwisePrediction(Prediction); moving to the pred module as functions outside of the Prediction class. Which would you prefer?

)

## Cast prediction into DeepSensor.Prediction object.
# TODO make this into seperate method.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you like to do this in this PR, and add it to deepsensor.model.pred?

deepsensor/model/model.py Outdated Show resolved Hide resolved
Comment on lines +1115 to +1118
combined = {
var_name: xr.combine_by_coords(patches, compat="no_conflicts")
for var_name, patches in patches_clipped.items()
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expose an argument for the method used to combine patches (currently only "remove_overlap" supported, or whatever you think is more appropriate. This will make it more clear how to add new combining methods (like weighted averaging) in future.

deepsensor/model/model.py Outdated Show resolved Hide resolved

# gridded predictions
assert [isinstance(ds, xr.Dataset) for ds in pred.values()]
# TODO come back to this, for artificial datasets here, shapes of predictions don't match inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer if we get to the bottom of this and uncomment the test before submitting.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants