Skip to content

Remove unnecessary borrow_parts() methods #449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 39 additions & 57 deletions lightning/src/ln/channelmanager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,24 +274,6 @@ pub(super) struct ChannelHolder<ChanSigner: ChannelKeys> {
/// for broadcast messages, where ordering isn't as strict).
pub(super) pending_msg_events: Vec<events::MessageSendEvent>,
}
pub(super) struct MutChannelHolder<'a, ChanSigner: ChannelKeys + 'a> {
pub(super) by_id: &'a mut HashMap<[u8; 32], Channel<ChanSigner>>,
pub(super) short_to_id: &'a mut HashMap<u64, [u8; 32]>,
pub(super) forward_htlcs: &'a mut HashMap<u64, Vec<HTLCForwardInfo>>,
pub(super) claimable_htlcs: &'a mut HashMap<PaymentHash, Vec<(u64, HTLCPreviousHopData)>>,
pub(super) pending_msg_events: &'a mut Vec<events::MessageSendEvent>,
}
impl<ChanSigner: ChannelKeys> ChannelHolder<ChanSigner> {
pub(super) fn borrow_parts(&mut self) -> MutChannelHolder<ChanSigner> {
MutChannelHolder {
by_id: &mut self.by_id,
short_to_id: &mut self.short_to_id,
forward_htlcs: &mut self.forward_htlcs,
claimable_htlcs: &mut self.claimable_htlcs,
pending_msg_events: &mut self.pending_msg_events,
}
}
}

#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
const ERR: () = "You need at least 32 bit pointers (well, usize, but we'll assume they're the same) for ChannelManager::latest_block_height";
Expand Down Expand Up @@ -738,7 +720,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

let (mut failed_htlcs, chan_option) = {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(channel_id.clone()) {
hash_map::Entry::Occupied(mut chan_entry) => {
let (shutdown_msg, failed_htlcs) = chan_entry.get_mut().get_shutdown()?;
Expand Down Expand Up @@ -795,7 +777,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

let mut chan = {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
if let Some(chan) = channel_state.by_id.remove(channel_id) {
if let Some(short_id) = chan.get_short_channel_id() {
channel_state.short_to_id.remove(&short_id);
Expand Down Expand Up @@ -1127,7 +1109,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
Some(id) => id.clone(),
};

let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
if let hash_map::Entry::Occupied(mut chan) = channel_state.by_id.entry(id) {
match {
if chan.get().get_their_node_id() != route.hops.first().unwrap().pubkey {
Expand Down Expand Up @@ -1275,7 +1257,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
let mut handle_errors = Vec::new();
{
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

for (short_chan_id, mut pending_forwards) in channel_state.forward_htlcs.drain() {
if short_chan_id != 0 {
Expand Down Expand Up @@ -1473,8 +1455,8 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
pub fn timer_chan_freshness_every_min(&self) {
let _ = self.total_consistency_lock.read().unwrap();
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
for (_, chan) in channel_state.by_id {
let channel_state = &mut *channel_state_lock;
for (_, chan) in channel_state.by_id.iter_mut() {
if chan.is_disabled_staged() && !chan.is_live() {
if let Ok(update) = self.get_channel_update(&chan) {
channel_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
Expand Down Expand Up @@ -1657,7 +1639,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
},
HTLCSource::PreviousHopData(HTLCPreviousHopData { short_channel_id, htlc_id, .. }) => {
//TODO: Delay the claimed_funds relaying just like we do outbound relay!
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

let chan_id = match channel_state.short_to_id.get(&short_channel_id) {
Some(chan_id) => chan_id.clone(),
Expand Down Expand Up @@ -1729,9 +1711,9 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

{
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let short_to_id = channel_state.short_to_id;
let pending_msg_events = channel_state.pending_msg_events;
let channel_state = &mut *channel_lock;
let short_to_id = &mut channel_state.short_to_id;
let pending_msg_events = &mut channel_state.pending_msg_events;
channel_state.by_id.retain(|_, channel| {
if channel.is_awaiting_monitor_update() {
let chan_monitor = channel.channel_monitor().clone();
Expand Down Expand Up @@ -1836,7 +1818,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
let channel = Channel::new_from_req(&*self.fee_estimator, &self.keys_manager, their_node_id.clone(), their_features, msg, 0, Arc::clone(&self.logger), &self.default_configuration)
.map_err(|e| MsgHandleErrInternal::from_chan_no_close(e, msg.temporary_channel_id))?;
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(channel.channel_id()) {
hash_map::Entry::Occupied(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("temporary_channel_id collision!", msg.temporary_channel_id.clone())),
hash_map::Entry::Vacant(entry) => {
Expand All @@ -1853,7 +1835,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_accept_channel(&self, their_node_id: &PublicKey, their_features: InitFeatures, msg: &msgs::AcceptChannel) -> Result<(), MsgHandleErrInternal> {
let (value, output_script, user_id) = {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.temporary_channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -1878,7 +1860,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_funding_created(&self, their_node_id: &PublicKey, msg: &msgs::FundingCreated) -> Result<(), MsgHandleErrInternal> {
let ((funding_msg, monitor_update), mut chan) = {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.temporary_channel_id.clone()) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand Down Expand Up @@ -1910,7 +1892,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
}
}
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(funding_msg.channel_id) {
hash_map::Entry::Occupied(_) => {
return Err(MsgHandleErrInternal::send_err_msg_no_close("Already had channel with the new channel_id", funding_msg.channel_id))
Expand All @@ -1929,7 +1911,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_funding_signed(&self, their_node_id: &PublicKey, msg: &msgs::FundingSigned) -> Result<(), MsgHandleErrInternal> {
let (funding_txo, user_id) = {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -1954,7 +1936,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_funding_locked(&self, their_node_id: &PublicKey, msg: &msgs::FundingLocked) -> Result<(), MsgHandleErrInternal> {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand Down Expand Up @@ -1985,7 +1967,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_shutdown(&self, their_node_id: &PublicKey, msg: &msgs::Shutdown) -> Result<(), MsgHandleErrInternal> {
let (mut dropped_htlcs, chan_option) = {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

match channel_state.by_id.entry(msg.channel_id.clone()) {
hash_map::Entry::Occupied(mut chan_entry) => {
Expand Down Expand Up @@ -2032,7 +2014,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_closing_signed(&self, their_node_id: &PublicKey, msg: &msgs::ClosingSigned) -> Result<(), MsgHandleErrInternal> {
let (tx, chan_option) = {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(msg.channel_id.clone()) {
hash_map::Entry::Occupied(mut chan_entry) => {
if chan_entry.get().get_their_node_id() != *their_node_id {
Expand Down Expand Up @@ -2086,7 +2068,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
//but we should prevent it anyway.

let (mut pending_forward_info, mut channel_state_lock) = self.decode_update_add_htlc_onion(msg);
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
Expand Down Expand Up @@ -2135,7 +2117,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_update_fulfill_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) -> Result<(), MsgHandleErrInternal> {
let mut channel_lock = self.channel_state.lock().unwrap();
let htlc_source = {
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -2152,7 +2134,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result<(), MsgHandleErrInternal> {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -2167,7 +2149,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) -> Result<(), MsgHandleErrInternal> {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -2185,7 +2167,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_commitment_signed(&self, their_node_id: &PublicKey, msg: &msgs::CommitmentSigned) -> Result<(), MsgHandleErrInternal> {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand Down Expand Up @@ -2261,7 +2243,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
fn internal_revoke_and_ack(&self, their_node_id: &PublicKey, msg: &msgs::RevokeAndACK) -> Result<(), MsgHandleErrInternal> {
let (pending_forwards, mut pending_failures, short_channel_id) = {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand Down Expand Up @@ -2305,7 +2287,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_update_fee(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFee) -> Result<(), MsgHandleErrInternal> {
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let channel_state = &mut *channel_lock;
match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
if chan.get().get_their_node_id() != *their_node_id {
Expand All @@ -2320,7 +2302,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_announcement_signatures(&self, their_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) -> Result<(), MsgHandleErrInternal> {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
Expand Down Expand Up @@ -2362,7 +2344,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {

fn internal_channel_reestablish(&self, their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) -> Result<(), MsgHandleErrInternal> {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

match channel_state.by_id.entry(msg.channel_id) {
hash_map::Entry::Occupied(mut chan) => {
Expand Down Expand Up @@ -2440,7 +2422,7 @@ impl<ChanSigner: ChannelKeys> ChannelManager<ChanSigner> {
let mut channel_state_lock = self.channel_state.lock().unwrap();
let their_node_id;
let err: Result<(), _> = loop {
let channel_state = channel_state_lock.borrow_parts();
let channel_state = &mut *channel_state_lock;

match channel_state.by_id.entry(channel_id) {
hash_map::Entry::Vacant(_) => return Err(APIError::APIMisuseError{err: "Failed to find corresponding channel"}),
Expand Down Expand Up @@ -2543,9 +2525,9 @@ impl<ChanSigner: ChannelKeys> ChainListener for ChannelManager<ChanSigner> {
let mut failed_channels = Vec::new();
{
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let short_to_id = channel_state.short_to_id;
let pending_msg_events = channel_state.pending_msg_events;
let channel_state = &mut *channel_lock;
let short_to_id = &mut channel_state.short_to_id;
let pending_msg_events = &mut channel_state.pending_msg_events;
channel_state.by_id.retain(|_, channel| {
let chan_res = channel.block_connected(header, height, txn_matched, indexes_of_txn_matched);
if let Ok(Some(funding_locked)) = chan_res {
Expand Down Expand Up @@ -2621,9 +2603,9 @@ impl<ChanSigner: ChannelKeys> ChainListener for ChannelManager<ChanSigner> {
let mut failed_channels = Vec::new();
{
let mut channel_lock = self.channel_state.lock().unwrap();
let channel_state = channel_lock.borrow_parts();
let short_to_id = channel_state.short_to_id;
let pending_msg_events = channel_state.pending_msg_events;
let channel_state = &mut *channel_lock;
let short_to_id = &mut channel_state.short_to_id;
let pending_msg_events = &mut channel_state.pending_msg_events;
channel_state.by_id.retain(|_, v| {
if v.block_disconnected(header) {
if let Some(short_id) = v.get_short_channel_id() {
Expand Down Expand Up @@ -2800,9 +2782,9 @@ impl<ChanSigner: ChannelKeys> ChannelMessageHandler for ChannelManager<ChanSigne
let mut failed_payments = Vec::new();
{
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let short_to_id = channel_state.short_to_id;
let pending_msg_events = channel_state.pending_msg_events;
let channel_state = &mut *channel_state_lock;
let short_to_id = &mut channel_state.short_to_id;
let pending_msg_events = &mut channel_state.pending_msg_events;
if no_connection_possible {
log_debug!(self, "Failing all channels with {} due to no_connection_possible", log_pubkey!(their_node_id));
channel_state.by_id.retain(|_, chan| {
Expand Down Expand Up @@ -2876,8 +2858,8 @@ impl<ChanSigner: ChannelKeys> ChannelMessageHandler for ChannelManager<ChanSigne

let _ = self.total_consistency_lock.read().unwrap();
let mut channel_state_lock = self.channel_state.lock().unwrap();
let channel_state = channel_state_lock.borrow_parts();
let pending_msg_events = channel_state.pending_msg_events;
let channel_state = &mut *channel_state_lock;
let pending_msg_events = &mut channel_state.pending_msg_events;
channel_state.by_id.retain(|_, chan| {
if chan.get_their_node_id() == *their_node_id {
if !chan.have_received_message() {
Expand Down
4 changes: 2 additions & 2 deletions lightning/src/ln/functional_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5019,7 +5019,7 @@ fn test_onion_failure() {
}, || {}, true, Some(17), None);

run_onion_failure_test("final_incorrect_cltv_expiry", 1, &nodes, &route, &payment_hash, |_| {}, || {
for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().borrow_parts().forward_htlcs.iter_mut() {
for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
for f in pending_forwards.iter_mut() {
match f {
&mut HTLCForwardInfo::AddHTLC { ref mut forward_info, .. } =>
Expand All @@ -5032,7 +5032,7 @@ fn test_onion_failure() {

run_onion_failure_test("final_incorrect_htlc_amount", 1, &nodes, &route, &payment_hash, |_| {}, || {
// violate amt_to_forward > msg.amount_msat
for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().borrow_parts().forward_htlcs.iter_mut() {
for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
for f in pending_forwards.iter_mut() {
match f {
&mut HTLCForwardInfo::AddHTLC { ref mut forward_info, .. } =>
Expand Down
Loading