Skip to content

Commit 6a13e46

Browse files
committed
Require PartialEq for wire::Message in cfg(test)
...and implement wire::Type for `()` for `feature = "_test_utils"`.
1 parent d73cbf2 commit 6a13e46

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

lightning/src/ln/wire.rs

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,26 @@ pub trait CustomMessageReader {
2828
fn read<R: io::Read>(&self, message_type: u16, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
2929
}
3030

31+
// TestEq is a dummy trait which requires PartialEq when built in testing, and otherwise is
32+
// blanket-implemented for all types.
33+
34+
#[cfg(test)]
35+
pub trait TestEq : PartialEq {}
36+
#[cfg(test)]
37+
impl<T: PartialEq> TestEq for T {}
38+
39+
#[cfg(not(test))]
40+
pub(crate) trait TestEq {}
41+
#[cfg(not(test))]
42+
impl<T> TestEq for T {}
43+
44+
3145
/// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each
3246
/// variant contains a message from [`msgs`] or otherwise the message type if unknown.
3347
#[allow(missing_docs)]
3448
#[derive(Debug)]
35-
pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
49+
#[cfg_attr(test, derive(PartialEq))]
50+
pub(crate) enum Message<T> where T: core::fmt::Debug + Type + TestEq {
3651
Init(msgs::Init),
3752
Error(msgs::ErrorMessage),
3853
Warning(msgs::WarningMessage),
@@ -69,7 +84,7 @@ pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
6984
Custom(T),
7085
}
7186

72-
impl<T> Message<T> where T: core::fmt::Debug + Type {
87+
impl<T> Message<T> where T: core::fmt::Debug + Type + TestEq {
7388
/// Returns the type that was used to decode the message payload.
7489
pub fn type_id(&self) -> u16 {
7590
match self {
@@ -252,6 +267,7 @@ mod encode {
252267

253268
pub(crate) use self::encode::Encode;
254269

270+
#[cfg(not(test))]
255271
/// Defines a type identifier for sending messages over the wire.
256272
///
257273
/// Messages implementing this trait specify a type and must be [`Writeable`].
@@ -260,10 +276,24 @@ pub trait Type: core::fmt::Debug + Writeable {
260276
fn type_id(&self) -> u16;
261277
}
262278

279+
#[cfg(test)]
280+
pub trait Type: core::fmt::Debug + Writeable + PartialEq {
281+
fn type_id(&self) -> u16;
282+
}
283+
284+
#[cfg(any(feature = "_test_utils", fuzzing, test))]
285+
impl Type for () {
286+
fn type_id(&self) -> u16 { unreachable!(); }
287+
}
288+
289+
#[cfg(test)]
290+
impl<T: core::fmt::Debug + Writeable + PartialEq> Type for T where T: Encode {
291+
fn type_id(&self) -> u16 { T::TYPE }
292+
}
293+
294+
#[cfg(not(test))]
263295
impl<T: core::fmt::Debug + Writeable> Type for T where T: Encode {
264-
fn type_id(&self) -> u16 {
265-
T::TYPE
266-
}
296+
fn type_id(&self) -> u16 { T::TYPE }
267297
}
268298

269299
impl Encode for msgs::Init {
@@ -471,10 +501,6 @@ mod tests {
471501
}
472502
}
473503

474-
impl Type for () {
475-
fn type_id(&self) -> u16 { unreachable!(); }
476-
}
477-
478504
#[test]
479505
fn is_even_message_type() {
480506
let message = Message::<()>::Unknown(42);

0 commit comments

Comments
 (0)