diff --git a/rednose/helpers/ekf_sym.py b/rednose/helpers/ekf_sym.py index 9dffe50..2a0b416 100644 --- a/rednose/helpers/ekf_sym.py +++ b/rednose/helpers/ekf_sym.py @@ -391,7 +391,7 @@ def augment(self): assert self.P.shape == (self.dim_err, self.dim_err) def state(self): - return np.array(self.x).flatten() + return np.array(self.x).ravel() def covs(self): return self.P @@ -461,27 +461,26 @@ def predict(self, t): self.normalize_quaternions() self.filter_time = t - def predict_and_update_batch(self, t, kind, z, R, extra_args=[[]], augment=False): # pylint: disable=dangerous-default-value + def predict_and_update_batch(self, t, kind, z, R, xk_km1=None, Pk_km1=None, xk_k=None, Pk_k=None, y=None, extra_args=[[]], augment=False): # pylint: disable=dangerous-default-value # TODO handle rewinding at this level" # rewind if self.filter_time is not None and t < self.filter_time: if len(self.rewind_t) == 0 or t < self.rewind_t[0] or t < self.rewind_t[-1] - self.max_rewind_age: self.logger.error(f"observation too old at {t:.3f} with filter at {self.filter_time:.3f}, ignoring") - return None + return False rewound = self.rewind(t) else: rewound = [] - ret = self._predict_and_update_batch(t, kind, z, R, extra_args, augment) + self._predict_and_update_batch(t, kind, z, R, extra_args, xk_km1, Pk_km1, xk_k, Pk_k, y, augment) # optional fast forward for r in rewound: self._predict_and_update_batch(*r) + return True - return ret - - def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False): + def _predict_and_update_batch(self, t, kind, z, R, extra_args, xk_km1=None, Pk_km1=None, xk_k=None, Pk_k=None, y=None, augment=False): """The main kalman filter function Predicts the state and then updates a batch of observations dim_x: dimensionality of the state space @@ -497,6 +496,8 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False): assert z.shape[0] == R.shape[0] assert z.shape[1] == R.shape[1] assert z.shape[1] == R.shape[2] + assert y is None or z.shape == y.shape + # assert stuff # initialize time if self.filter_time is None: @@ -507,10 +508,13 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False): assert dt >= 0 self.x, self.P = self._predict(self.x, self.P, dt) self.filter_time = t - xk_km1, Pk_km1 = np.copy(self.x).flatten(), np.copy(self.P) + + if xk_km1 is not None: + np.copyto(xk_km1, self.x.ravel()) + if Pk_km1 is not None: + np.copyto(Pk_km1, self.P) # update batch - y = [] for i in range(len(z)): # these are from the user, so we canonicalize them z_i = np.array(z[i], dtype=np.float64, order='F') @@ -519,8 +523,13 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False): # update self.x, self.P, y_i = self._update(self.x, self.P, kind, z_i, R_i, extra_args=extra_args_i) self.normalize_quaternions() - y.append(y_i) - xk_k, Pk_k = np.copy(self.x).flatten(), np.copy(self.P) + if y is not None: + np.copyto(y_i, y[i]) + + if xk_k is not None: + np.copyto(xk_k, self.x.ravel()) + if Pk_k is not None: + np.copyto(Pk_k, self.P) if augment: self.augment() @@ -528,8 +537,6 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False): # checkpoint self.checkpoint((t, kind, z, R, extra_args)) - return xk_km1, xk_k, Pk_km1, Pk_k, t, kind, y, z, extra_args - def _predict_python(self, x, P, dt): x_new = np.zeros(x.shape, dtype=np.float64) self.f(x, dt, x_new)