diff --git a/src/time_delta.rs b/src/time_delta.rs index ddab0d57a7..6cf28eecde 100644 --- a/src/time_delta.rs +++ b/src/time_delta.rs @@ -372,6 +372,40 @@ impl TimeDelta { TimeDelta::new(secs, nanos as u32) } + /// Multiply a `TimeDelta` with a i32, returning `None` if overflow occurred. + #[must_use] + pub const fn checked_mul(&self, rhs: i32) -> Option { + // Multiply nanoseconds as i64, because it cannot overflow that way. + let total_nanos = self.nanos as i64 * rhs as i64; + let (extra_secs, nanos) = div_mod_floor_64(total_nanos, NANOS_PER_SEC as i64); + // Multiply seconds as i128 to prevent overflow + let secs: i128 = self.secs as i128 * rhs as i128 + extra_secs as i128; + if secs <= i64::MIN as i128 || secs >= i64::MAX as i128 { + return None; + }; + Some(TimeDelta { secs: secs as i64, nanos: nanos as i32 }) + } + + /// Divide a `TimeDelta` with a i32, returning `None` if dividing by 0. + #[must_use] + pub const fn checked_div(&self, rhs: i32) -> Option { + if rhs == 0 { + return None; + } + let secs = self.secs / rhs as i64; + let carry = self.secs % rhs as i64; + let extra_nanos = carry * NANOS_PER_SEC as i64 / rhs as i64; + let nanos = self.nanos / rhs + extra_nanos as i32; + + let (secs, nanos) = match nanos { + i32::MIN..=-1 => (secs - 1, nanos + NANOS_PER_SEC), + NANOS_PER_SEC..=i32::MAX => (secs + 1, nanos - NANOS_PER_SEC), + _ => (secs, nanos), + }; + + Some(TimeDelta { secs, nanos }) + } + /// Returns the `TimeDelta` as an absolute (non-negative) value. #[inline] pub const fn abs(&self) -> TimeDelta { @@ -489,11 +523,7 @@ impl Mul for TimeDelta { type Output = TimeDelta; fn mul(self, rhs: i32) -> TimeDelta { - // Multiply nanoseconds as i64, because it cannot overflow that way. - let total_nanos = self.nanos as i64 * rhs as i64; - let (extra_secs, nanos) = div_mod_floor_64(total_nanos, NANOS_PER_SEC as i64); - let secs = self.secs * rhs as i64 + extra_secs; - TimeDelta { secs, nanos: nanos as i32 } + self.checked_mul(rhs).expect("`TimeDelta * i32` overflowed") } } @@ -501,19 +531,7 @@ impl Div for TimeDelta { type Output = TimeDelta; fn div(self, rhs: i32) -> TimeDelta { - let mut secs = self.secs / rhs as i64; - let carry = self.secs - secs * rhs as i64; - let extra_nanos = carry * NANOS_PER_SEC as i64 / rhs as i64; - let mut nanos = self.nanos / rhs + extra_nanos as i32; - if nanos >= NANOS_PER_SEC { - nanos -= NANOS_PER_SEC; - secs += 1; - } - if nanos < 0 { - nanos += NANOS_PER_SEC; - secs -= 1; - } - TimeDelta { secs, nanos } + self.checked_div(rhs).expect("`i32` is zero") } } @@ -1034,6 +1052,7 @@ mod tests { #[test] fn test_duration_checked_ops() { let milliseconds = |ms| TimeDelta::try_milliseconds(ms).unwrap(); + let seconds = |s| TimeDelta::try_seconds(s).unwrap(); assert_eq!( milliseconds(i64::MAX).checked_add(&milliseconds(0)), @@ -1056,6 +1075,10 @@ mod tests { ); assert!(milliseconds(-i64::MAX).checked_sub(&milliseconds(1)).is_none()); assert!(milliseconds(-i64::MAX).checked_sub(&TimeDelta::nanoseconds(1)).is_none()); + + assert!(seconds(i64::MAX / 1000).checked_mul(2000).is_none()); + assert!(seconds(i64::MIN / 1000).checked_mul(2000).is_none()); + assert!(seconds(1).checked_div(0).is_none()); } #[test]