extern crate bytes;
extern crate tokio;
extern crate tokio_codec;
extern crate futures;
extern crate lightning;
extern crate secp256k1;

use bytes::BufMut;

use futures::future;
use futures::future::Future;
use futures::{AsyncSink, Stream, Sink};
use futures::sync::mpsc;

use secp256k1::key::PublicKey;

use tokio::timer::Delay;
use tokio::net::TcpStream;

use lightning::ln::peer_handler;
use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;

use std::mem;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use std::vec::Vec;
use std::hash::Hash;

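// Monotonically increasing counter handing out a unique id per Connection; the id backs the
// Eq/Hash implementations on SocketDescriptor below.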
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// A connection to a remote peer. Can be constructed either as an outbound connection using
/// Connection::setup_outbound (or Connection::connect_outbound), or as an inbound connection
/// accepted on a listening socket using Connection::setup_inbound.
pub struct Connection {
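	// Sender half of the channel draining into the socket's write half (see `new`); left as
	// `None` while a blocked `flush` is in flight in `send_data`.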
	writer: Option<mpsc::Sender<bytes::Bytes>>,
	event_notify: mpsc::Sender<()>,
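	// Bytes which arrived from the socket while reads were paused, together with the oneshot
	// the reader future parks itself on until send_data() is called with resume_read set.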
	pending_read: Vec<u8>,
	read_blocker: Option<futures::sync::oneshot::Sender<Result<(), ()>>>,
	read_paused: bool,
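	// Whether disconnect_event still needs to be delivered to the PeerManager once the reader
	// future completes.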
	need_disconnect: bool,
	id: u64,
}
impl Connection {
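	// Spawns a future which forwards each chunk read from the socket into
	// PeerManager::read_event, parking itself on the read_blocker oneshot whenever reading has
	// been paused. When the reader stream ends (or read_event errors), the completion handler
	// fires disconnect_event if we still owe the PeerManager a disconnect notification.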
	fn schedule_read(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>, us: Arc<Mutex<Self>>, reader: futures::stream::SplitStream<tokio_codec::Framed<TcpStream, tokio_codec::BytesCodec>>) {
		let us_ref = us.clone();
		let us_close_ref = us.clone();
		let peer_manager_ref = peer_manager.clone();
		tokio::spawn(reader.for_each(move |b| {
			let pending_read = b.to_vec();
			{
				let mut lock = us_ref.lock().unwrap();
				assert!(lock.pending_read.is_empty());
				if lock.read_paused {
					lock.pending_read = pending_read;
					let (sender, blocker) = futures::sync::oneshot::channel();
					lock.read_blocker = Some(sender);
					return future::Either::A(blocker.then(|_| { Ok(()) }));
				}
			}
			//TODO: There's a race where we don't meet the requirements of disconnect_socket if it's
			//called right here, after we release the us_ref lock in the scope above but before we
			//call read_event!
			match peer_manager.read_event(&mut SocketDescriptor::new(us_ref.clone(), peer_manager.clone()), pending_read) {
				Ok(pause_read) => {
					if pause_read {
						let mut lock = us_ref.lock().unwrap();
						lock.read_paused = true;
					}
				},
				Err(e) => {
					us_ref.lock().unwrap().need_disconnect = false;
					return future::Either::B(future::result(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))));
				}
			}

			if let Err(e) = us_ref.lock().unwrap().event_notify.try_send(()) {
				// Ignore full errors as we just need them to poll after this point, so if the user
				// hasn't received the last send yet, it doesn't matter.
				assert!(e.is_full());
			}

			future::Either::B(future::result(Ok(())))
		}).then(move |_| {
			if us_close_ref.lock().unwrap().need_disconnect {
				peer_manager_ref.disconnect_event(&SocketDescriptor::new(us_close_ref, peer_manager_ref.clone()));
				println!("Peer disconnected!");
			} else {
				println!("We disconnected peer!");
			}
			Ok(())
		}));
	}

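	// Splits the TcpStream into framed read/write halves and spawns a task which forwards
	// everything sent on the writer mpsc channel into the socket's write half. Returns the
	// read half alongside the newly constructed Connection.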
	fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (futures::stream::SplitStream<tokio_codec::Framed<TcpStream, tokio_codec::BytesCodec>>, Arc<Mutex<Self>>) {
		let (writer, reader) = tokio_codec::Framed::new(stream, tokio_codec::BytesCodec::new()).split();
		let (send_sink, send_stream) = mpsc::channel(3);
		tokio::spawn(writer.send_all(send_stream.map_err(|_| -> std::io::Error {
			unreachable!();
		})).then(|_| {
			future::result(Ok(()))
		}));
		let us = Arc::new(Mutex::new(Self { writer: Some(send_sink), event_notify, pending_read: Vec::new(), read_blocker: None, read_paused: false, need_disconnect: true, id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) }));

		(reader, us)
	}

	/// Process incoming messages and feed outgoing messages on the provided socket, which was
	/// generated by accepting an incoming connection (by scheduling futures with tokio::spawn).
	///
	/// You should poll the receiving end of event_notify and call get_and_clear_pending_events()
	/// on the ChannelManager and ChannelMonitor objects.
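	///
	/// A rough accept-loop sketch (a tokio 0.1 runtime is assumed to be running, and
	/// `peer_manager` is assumed to be an existing `Arc<PeerManager<SocketDescriptor>>`;
	/// the address and channel size are placeholders):
	///
	/// ```ignore
	/// let (event_notify, event_receiver) = mpsc::channel(2);
	/// // event_receiver should be polled as described above
	/// let listener = tokio::net::TcpListener::bind(&"0.0.0.0:9735".parse().unwrap()).unwrap();
	/// tokio::spawn(listener.incoming().for_each(move |stream| {
	///     Connection::setup_inbound(peer_manager.clone(), event_notify.clone(), stream);
	///     Ok(())
	/// }).map_err(|_| ()));
	/// ```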
	pub fn setup_inbound(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>, event_notify: mpsc::Sender<()>, stream: TcpStream) {
		let (reader, us) = Self::new(event_notify, stream);

		if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(us.clone(), peer_manager.clone())) {
			Self::schedule_read(peer_manager, us, reader);
		}
	}

	/// Process incoming messages and feed outgoing messages on the provided socket, which was
	/// generated by making an outbound connection which is expected to be accepted by a peer
	/// with the given public key (by scheduling futures with tokio::spawn).
	///
	/// You should poll the receiving end of event_notify and call get_and_clear_pending_events()
	/// on the ChannelManager and ChannelMonitor objects.
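	///
	/// A rough sketch (assuming a running tokio 0.1 runtime, an existing
	/// `Arc<PeerManager<SocketDescriptor>>` named `peer_manager`, a connected `TcpStream`
	/// named `stream`, and the peer's `PublicKey` in `their_node_id`):
	///
	/// ```ignore
	/// let (event_notify, event_receiver) = mpsc::channel(2);
	/// Connection::setup_outbound(peer_manager.clone(), event_notify.clone(), their_node_id, stream);
	/// ```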
	pub fn setup_outbound(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) {
		let (reader, us) = Self::new(event_notify, stream);

		if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone(), peer_manager.clone())) {
			if SocketDescriptor::new(us.clone(), peer_manager.clone()).send_data(&initial_send, 0, true) == initial_send.len() {
				Self::schedule_read(peer_manager, us, reader);
			} else {
				println!("Failed to write first full message to socket!");
			}
		}
	}

	/// Process incoming messages and feed outgoing messages on a new connection made to the
	/// given socket address, which is expected to be accepted by a peer with the given public
	/// key (by scheduling futures with tokio::spawn).
	///
	/// You should poll the receiving end of event_notify and call get_and_clear_pending_events()
	/// on the ChannelManager and ChannelMonitor objects.
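	///
	/// A rough sketch (assuming a running tokio 0.1 runtime, an existing
	/// `Arc<PeerManager<SocketDescriptor>>` named `peer_manager`, and the peer's `PublicKey`
	/// in `their_node_id`; the address is a placeholder):
	///
	/// ```ignore
	/// let (event_notify, event_receiver) = mpsc::channel(2);
	/// Connection::connect_outbound(peer_manager.clone(), event_notify.clone(),
	///                              their_node_id, "127.0.0.1:9735".parse().unwrap());
	/// ```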
	pub fn connect_outbound(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, addr: SocketAddr) {
		let connect_timeout = Delay::new(Instant::now() + Duration::from_secs(10)).then(|_| {
			future::err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout reached"))
		});
		tokio::spawn(TcpStream::connect(&addr).select(connect_timeout)
			.and_then(move |stream| {
				Connection::setup_outbound(peer_manager, event_notify, their_node_id, stream.0);
				future::ok(())
			}).or_else(|_| {
				//TODO: return errors somehow
				future::ok(())
			}));
	}
}

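/// The descriptor handed to the PeerManager. Clones share the same underlying Connection; the
/// id (copied out of the Connection) backs the Eq/Hash implementations below.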
#[derive(Clone)]
pub struct SocketDescriptor {
	conn: Arc<Mutex<Connection>>,
	id: u64,
	peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>,
}
impl SocketDescriptor {
	fn new(conn: Arc<Mutex<Connection>>, peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor>>) -> Self {
		let id = conn.lock().unwrap().id;
		Self { conn, id, peer_manager }
	}
}
impl peer_handler::SocketDescriptor for SocketDescriptor {
	fn send_data(&mut self, data: &Vec<u8>, write_offset: usize, resume_read: bool) -> usize {
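		// schedule_read! spawns a task which replays any bytes buffered in pending_read
		// through read_event, then (unless the PeerManager asked to pause reads again)
		// unparks the blocked reader via read_blocker and clears read_paused.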
		macro_rules! schedule_read {
			($us_ref: expr) => {
				tokio::spawn(future::lazy(move || -> Result<(), ()> {
					let mut read_data = Vec::new();
					{
						let mut us = $us_ref.conn.lock().unwrap();
						mem::swap(&mut read_data, &mut us.pending_read);
					}
					if !read_data.is_empty() {
						let mut us_clone = $us_ref.clone();
						match $us_ref.peer_manager.read_event(&mut us_clone, read_data) {
							Ok(pause_read) => {
								if pause_read { return Ok(()); }
							},
							Err(_) => {
								//TODO: Not actually sure how to do this
								return Ok(());
							}
						}
					}
					let mut us = $us_ref.conn.lock().unwrap();
					if let Some(sender) = us.read_blocker.take() {
						sender.send(Ok(())).unwrap();
					}
					us.read_paused = false;
					if let Err(e) = us.event_notify.try_send(()) {
						// Ignore full errors as we just need them to poll after this point, so if the user
						// hasn't received the last send yet, it doesn't matter.
						assert!(e.is_full());
					}
					Ok(())
				}));
			}
		}

		let mut us = self.conn.lock().unwrap();
		if resume_read {
			let us_ref = self.clone();
			schedule_read!(us_ref);
		}
		if data.len() == write_offset { return 0; }
		if us.writer.is_none() {
			us.read_paused = true;
			return 0;
		}

		let mut bytes = bytes::BytesMut::with_capacity(data.len() - write_offset);
		bytes.put(&data[write_offset..]);
		let write_res = us.writer.as_mut().unwrap().start_send(bytes.freeze());
		match write_res {
			Ok(res) => {
				match res {
					AsyncSink::Ready => {
						data.len() - write_offset
					},
					AsyncSink::NotReady(_) => {
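						// The outbound channel is full: pause reads for backpressure, hand the
						// writer off to a spawned flush, and restore it (resuming reads via
						// schedule_read!) once the flush completes. Until then we report
						// nothing written.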
						us.read_paused = true;
						let us_ref = self.clone();
						tokio::spawn(us.writer.take().unwrap().flush().then(move |writer_res| -> Result<(), ()> {
							if let Ok(writer) = writer_res {
								{
									let mut us = us_ref.conn.lock().unwrap();
									us.writer = Some(writer);
								}
								schedule_read!(us_ref);
							} // we'll fire the disconnect event on the socket reader end
							Ok(())
						}));
						0
					}
				}
			},
			Err(_) => {
				// We'll fire the disconnected event on the socket reader end
				0
			},
		}
	}

	fn disconnect_socket(&mut self) {
		let mut us = self.conn.lock().unwrap();
		us.need_disconnect = true;
		us.read_paused = true;
	}
}
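// SocketDescriptors need to be usable as map keys by the PeerManager, so equality and hashing
// delegate to the per-connection id assigned from ID_COUNTER.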
impl Eq for SocketDescriptor {}
impl PartialEq for SocketDescriptor {
	fn eq(&self, o: &Self) -> bool {
		self.id == o.id
	}
}
impl Hash for SocketDescriptor {
	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
		self.id.hash(state);
	}
}