Skip to content

Do DNS resolution before connection attempt #893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/RabbitMQ.Client/client/api/ITcpClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;

Expand All @@ -17,6 +18,7 @@ public interface ITcpClient : IDisposable
Socket Client { get; }

Task ConnectAsync(string host, int port);
Task ConnectAsync(IPAddress host, int port);

NetworkStream GetStream();

Expand Down
35 changes: 26 additions & 9 deletions projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,38 @@ public SocketFrameHandler(AmqpTcpEndpoint endpoint,
_channelReader = channel.Reader;
_channelWriter = channel.Writer;

if (ShouldTryIPv6(endpoint))
// Resolve the hostname to know if it's even possible to even try IPv6
IPAddress[] adds = Dns.GetHostAddresses(endpoint.HostName);
IPAddress ipv6 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetworkV6);

if (ipv6 == default(IPAddress))
{
if (endpoint.AddressFamily == AddressFamily.InterNetworkV6)
{
throw new ConnectFailureException("Connection failed", new ArgumentException($"No IPv6 address could be resolved for {endpoint.HostName}"));
}
}
else if (ShouldTryIPv6(endpoint))
{
try
{
_socket = ConnectUsingIPv6(endpoint, socketFactory, connectionTimeout);
_socket = ConnectUsingIPv6(new IPEndPoint(ipv6, endpoint.Port), socketFactory, connectionTimeout);
}
catch (ConnectFailureException)
{
// We resolved to a ipv6 address and tried it but it still didn't connect, try IPv4
_socket = null;
}
}

if (_socket == null && endpoint.AddressFamily != AddressFamily.InterNetworkV6)
if (_socket == null)
{
_socket = ConnectUsingIPv4(endpoint, socketFactory, connectionTimeout);
IPAddress ipv4 = TcpClientAdapterHelper.GetMatchingHost(adds, AddressFamily.InterNetwork);
if (ipv4 == default(IPAddress))
{
throw new ConnectFailureException("Connection failed", new ArgumentException($"No ip address could be resolved for {endpoint.HostName}"));
}
_socket = ConnectUsingIPv4(new IPEndPoint(ipv4, endpoint.Port), socketFactory, connectionTimeout);
}

Stream netstream = _socket.GetStream();
Expand Down Expand Up @@ -276,21 +293,21 @@ private static bool ShouldTryIPv6(AmqpTcpEndpoint endpoint)
return Socket.OSSupportsIPv6 && endpoint.AddressFamily != AddressFamily.InterNetwork;
}

private ITcpClient ConnectUsingIPv6(AmqpTcpEndpoint endpoint,
private ITcpClient ConnectUsingIPv6(IPEndPoint endpoint,
Func<AddressFamily, ITcpClient> socketFactory,
TimeSpan timeout)
{
return ConnectUsingAddressFamily(endpoint, socketFactory, timeout, AddressFamily.InterNetworkV6);
}

private ITcpClient ConnectUsingIPv4(AmqpTcpEndpoint endpoint,
private ITcpClient ConnectUsingIPv4(IPEndPoint endpoint,
Func<AddressFamily, ITcpClient> socketFactory,
TimeSpan timeout)
{
return ConnectUsingAddressFamily(endpoint, socketFactory, timeout, AddressFamily.InterNetwork);
}

private ITcpClient ConnectUsingAddressFamily(AmqpTcpEndpoint endpoint,
private ITcpClient ConnectUsingAddressFamily(IPEndPoint endpoint,
Func<AddressFamily, ITcpClient> socketFactory,
TimeSpan timeout, AddressFamily family)
{
Expand All @@ -307,11 +324,11 @@ private ITcpClient ConnectUsingAddressFamily(AmqpTcpEndpoint endpoint,
}
}

private void ConnectOrFail(ITcpClient socket, AmqpTcpEndpoint endpoint, TimeSpan timeout)
private void ConnectOrFail(ITcpClient socket, IPEndPoint endpoint, TimeSpan timeout)
{
try
{
socket.ConnectAsync(endpoint.HostName, endpoint.Port)
socket.ConnectAsync(endpoint.Address, endpoint.Port)
.TimeoutAfter(timeout)
.ConfigureAwait(false)
// this ensures exceptions aren't wrapped in an AggregateException
Expand Down
6 changes: 6 additions & 0 deletions projects/RabbitMQ.Client/client/impl/TcpClientAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public virtual async Task ConnectAsync(string host, int port)
throw new ArgumentException($"No ip address could be resolved for {host}");
}

await ConnectAsync(ep, port);
}

public virtual async Task ConnectAsync(IPAddress ep, int port)
{
AssertSocket();
#if NET461
await Task.Run(() => _sock.Connect(ep, port)).ConfigureAwait(false);
#else
Expand Down
1 change: 1 addition & 0 deletions projects/Unit/APIApproval.Approve.verified.txt
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ namespace RabbitMQ.Client
bool Connected { get; }
System.TimeSpan ReceiveTimeout { get; set; }
void Close();
System.Threading.Tasks.Task ConnectAsync(System.Net.IPAddress host, int port);
System.Threading.Tasks.Task ConnectAsync(string host, int port);
System.Net.Sockets.NetworkStream GetStream();
}
Expand Down