diff --git a/src/olm/account/mod.rs b/src/olm/account/mod.rs index 20a97d09..18e7b05e 100644 --- a/src/olm/account/mod.rs +++ b/src/olm/account/mod.rs @@ -129,8 +129,8 @@ impl Account { } /// Sign the given message using our Ed25519 fingerprint key. - pub fn sign(&self, message: &str) -> Ed25519Signature { - self.signing_key.sign(message.as_bytes()) + pub fn sign(&self, message: impl AsRef<[u8]>) -> Ed25519Signature { + self.signing_key.sign(message.as_ref()) } /// Get the maximum number of one-time keys the client should keep on the @@ -1076,7 +1076,7 @@ mod test { #[allow(clippy::redundant_clone)] let signing_key_clone = account_with_expanded_key.signing_key.clone(); signing_key_clone.sign("You met with a terrible fate, haven’t you?".as_bytes()); - account_with_expanded_key.sign("You met with a terrible fate, haven’t you?"); + account_with_expanded_key.sign("You met with a terrible fate, haven’t you?".as_bytes()); Ok(()) } @@ -1146,7 +1146,7 @@ mod test { let vodozemac_pickle = account.to_libolm_pickle(key).unwrap(); let _ = Account::from_libolm_pickle(&vodozemac_pickle, key).unwrap(); - let vodozemac_signature = account.sign(message); + let vodozemac_signature = account.sign(message.as_bytes()); let olm_signature = Ed25519Signature::from_base64(&olm_signature) .expect("We should be able to parse a signature produced by libolm"); account diff --git a/src/olm/messages/mod.rs b/src/olm/messages/mod.rs index 88b70f26..90f8bae2 100644 --- a/src/olm/messages/mod.rs +++ b/src/olm/messages/mod.rs @@ -19,7 +19,7 @@ pub use message::Message; pub use pre_key::PreKeyMessage; use serde::{Deserialize, Serialize}; -use crate::DecodeError; +use crate::{base64_decode, base64_encode, DecodeError}; /// Enum over the different Olm message types. /// @@ -67,9 +67,8 @@ impl Serialize for OlmMessage { where S: serde::Serializer, { - let (message_type, ciphertext) = self.clone().to_parts(); - - let message = MessageSerdeHelper { message_type, ciphertext }; + let (message_type, ciphertext) = self.to_parts(); + let message = MessageSerdeHelper { message_type, ciphertext: base64_encode(ciphertext) }; message.serialize(serializer) } @@ -78,18 +77,19 @@ impl Serialize for OlmMessage { impl<'de> Deserialize<'de> for OlmMessage { fn deserialize>(d: D) -> Result { let value = MessageSerdeHelper::deserialize(d)?; + let ciphertext_bytes = base64_decode(value.ciphertext).map_err(serde::de::Error::custom)?; - OlmMessage::from_parts(value.message_type, &value.ciphertext) + OlmMessage::from_parts(value.message_type, ciphertext_bytes.as_slice()) .map_err(serde::de::Error::custom) } } impl OlmMessage { /// Create a `OlmMessage` from a message type and a ciphertext. - pub fn from_parts(message_type: usize, ciphertext: &str) -> Result { + pub fn from_parts(message_type: usize, ciphertext: &[u8]) -> Result { match message_type { - 0 => Ok(Self::PreKey(PreKeyMessage::try_from(ciphertext)?)), - 1 => Ok(Self::Normal(Message::try_from(ciphertext)?)), + 0 => Ok(Self::PreKey(PreKeyMessage::from_bytes(ciphertext)?)), + 1 => Ok(Self::Normal(Message::from_bytes(ciphertext)?)), m => Err(DecodeError::MessageType(m)), } } @@ -110,14 +110,13 @@ impl OlmMessage { } } - /// Convert the `OlmMessage` into a message type, and base64 encoded message - /// tuple. - pub fn to_parts(self) -> (usize, String) { + /// Convert the `OlmMessage` into a message type, and message bytes tuple. + pub fn to_parts(&self) -> (usize, Vec) { let message_type = self.message_type(); match self { - OlmMessage::Normal(m) => (message_type.into(), m.to_base64()), - OlmMessage::PreKey(m) => (message_type.into(), m.to_base64()), + OlmMessage::Normal(m) => (message_type.into(), m.to_bytes()), + OlmMessage::PreKey(m) => (message_type.into(), m.to_bytes()), } } } @@ -156,8 +155,10 @@ use olm_rs::session::OlmMessage as LibolmMessage; impl From for OlmMessage { fn from(other: LibolmMessage) -> Self { let (message_type, ciphertext) = other.to_tuple(); + let ciphertext_bytes = base64_decode(ciphertext).expect("Can't decode base64"); - Self::from_parts(message_type.into(), &ciphertext).expect("Can't decode a libolm message") + Self::from_parts(message_type.into(), ciphertext_bytes.as_slice()) + .expect("Can't decode a libolm message") } } @@ -247,7 +248,7 @@ mod tests { #[test] fn from_parts() -> Result<()> { - let message = OlmMessage::from_parts(0, PRE_KEY_MESSAGE)?; + let message = OlmMessage::from_parts(0, base64_decode(PRE_KEY_MESSAGE)?.as_slice())?; assert_matches!(message, OlmMessage::PreKey(_)); assert_eq!( message.message_type(), @@ -255,9 +256,13 @@ mod tests { "Expected message to be recognized as a pre-key Olm message." ); assert_eq!(message.message(), PRE_KEY_MESSAGE_CIPHERTEXT); - assert_eq!(message.to_parts(), (0, PRE_KEY_MESSAGE.to_string()), "Roundtrip not identity."); + assert_eq!( + message.to_parts(), + (0, base64_decode(PRE_KEY_MESSAGE)?), + "Roundtrip not identity." + ); - let message = OlmMessage::from_parts(1, MESSAGE)?; + let message = OlmMessage::from_parts(1, base64_decode(MESSAGE)?.as_slice())?; assert_matches!(message, OlmMessage::Normal(_)); assert_eq!( message.message_type(), @@ -265,9 +270,9 @@ mod tests { "Expected message to be recognized as a normal Olm message." ); assert_eq!(message.message(), MESSAGE_CIPHERTEXT); - assert_eq!(message.to_parts(), (1, MESSAGE.to_string()), "Roundtrip not identity."); + assert_eq!(message.to_parts(), (1, base64_decode(MESSAGE)?), "Roundtrip not identity."); - OlmMessage::from_parts(3, PRE_KEY_MESSAGE) + OlmMessage::from_parts(3, base64_decode(PRE_KEY_MESSAGE)?.as_slice()) .expect_err("Unknown message types can't be parsed"); Ok(())