Skip to content
This repository has been archived by the owner on Jun 24, 2022. It is now read-only.

Commit

Permalink
Add support for HttpDNS.
Browse files Browse the repository at this point in the history
Allow to use regex group match result in rules.cfg "Address" field.
Make use of C# 6.0 new grammar to simplify code.
Update README.md.
  • Loading branch information
stackia committed Oct 30, 2015
1 parent 041c711 commit 7c7614e
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 110 deletions.
139 changes: 96 additions & 43 deletions DNSAgent/DnsAgent.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
Expand Down Expand Up @@ -103,7 +104,9 @@ public bool Start()
else
Logger.Error("[Listener] Unexpected exception:\n{0}", e);
}
catch (ObjectDisposedException) {} // Force closing _udpListener will cause this exception
catch (ObjectDisposedException)
{
} // Force closing _udpListener will cause this exception
catch (Exception e)
{
Logger.Error("[Listener] Unexpected exception:\n{0}", e);
Expand Down Expand Up @@ -138,47 +141,45 @@ public bool Start()
if (Options.CacheResponse)
Cache.Update(message.Questions[0], message, Options.CacheAge);
}
catch (ParsingException) {}
catch (ParsingException)
{
}
catch (SocketException e)
{
if (e.SocketErrorCode != SocketError.ConnectionReset)
Logger.Error("[Forwarder.Send] Name server unreachable.");
else
Logger.Error("[Forwarder.Receive] Unexpected socket error:\n{0}", e);
}
catch (ObjectDisposedException) {} // Force closing _udpListener will cause this exception
catch (ObjectDisposedException)
{
} // Force closing _udpListener will cause this exception
catch (Exception e)
{
Logger.Error("[Forwarder] Unexpected exception:\n{0}", e);
}
}
}, _stopTokenSource.Token);

Logger.Info("Listening on {0}...", endPoint);
OnStarted();
return true;
}

public void Stop()
{
if (_stopTokenSource != null)
_stopTokenSource.Cancel();

if (_udpListener != null)
_udpListener.Close();

if (_udpForwarder != null)
_udpForwarder.Close();
_stopTokenSource?.Cancel();
_udpListener?.Close();
_udpForwarder?.Close();

try
{
if (_listeningTask != null)
_listeningTask.Wait();

if (_forwardingTask != null)
_forwardingTask.Wait();
_listeningTask?.Wait();
_forwardingTask?.Wait();
}
catch (AggregateException)
{
}
catch (AggregateException) {}

_stopTokenSource = null;
_udpListener = null;
Expand Down Expand Up @@ -248,6 +249,7 @@ await Task.Run(async () =>
}
var targetNameServer = Options.DefaultNameServer;
var useHttpQuery = Options.UseHttpQuery;
var queryTimeout = Options.QueryTimeout;
var useCompressionMutation = Options.CompressionMutation;
Expand All @@ -257,13 +259,17 @@ await Task.Run(async () =>
{
for (var i = Rules.Count - 1; i >= 0; i--)
{
if (!Regex.IsMatch(question.Name, Rules[i].Pattern)) continue;
var match = Regex.Match(question.Name, Rules[i].Pattern);
if (!match.Success) continue;
// Domain name matched
if (Rules[i].NameServer != null) // Name server override
targetNameServer = Rules[i].NameServer;
if (Rules[i].UseHttpQuery != null) // HTTP query override
useHttpQuery = Rules[i].UseHttpQuery.Value;
if (Rules[i].QueryTimeout != null) // Query timeout override
queryTimeout = Rules[i].QueryTimeout.Value;
Expand All @@ -276,27 +282,39 @@ await Task.Run(async () =>
IPAddress.TryParse(Rules[i].Address, out ip);
if (ip == null) // Invalid IP, may be a domain name
{
var serverEndpoint = Utils.CreateIpEndPoint(targetNameServer, 53);
var dnsClient = new DnsClient(serverEndpoint.Address, queryTimeout, serverEndpoint.Port);
var response = await Task<DnsMessage>.Factory.FromAsync(dnsClient.BeginResolve, dnsClient.EndResolve,
Rules[i].Address, question.RecordType, question.RecordClass, null);
if (response == null)
var address = string.Format(Rules[i].Address, match.Groups.Cast<object>().ToArray());
if (question.RecordType == RecordType.A && useHttpQuery)
{
Logger.Warning($"Remote resolve failed for {Rules[i].Address}.");
return;
await ResolveWithHttp(targetNameServer, address, queryTimeout, message);
}
message.ReturnCode = response.ReturnCode;
foreach (var answerRecord in response.AnswerRecords)
else
{
answerRecord.Name = question.Name;
message.AnswerRecords.Add(answerRecord);
var serverEndpoint = Utils.CreateIpEndPoint(targetNameServer, 53);
var dnsClient = new DnsClient(serverEndpoint.Address, queryTimeout,
serverEndpoint.Port);
var response =
await
Task<DnsMessage>.Factory.FromAsync(dnsClient.BeginResolve,
dnsClient.EndResolve,
address, question.RecordType, question.RecordClass, null);
if (response == null)
{
Logger.Warning($"Remote resolve failed for {address}.");
return;
}
foreach (var answerRecord in response.AnswerRecords)
{
answerRecord.Name = question.Name;
message.AnswerRecords.Add(answerRecord);
}
message.ReturnCode = response.ReturnCode;
message.IsQuery = false;
}
message.IsQuery = false;
}
else
{
if (question.RecordType == RecordType.A &&
ip.AddressFamily == AddressFamily.InterNetwork)
ip.AddressFamily == AddressFamily.InterNetwork)
message.AnswerRecords.Add(new ARecord(question.Name, 600, ip));
else if (question.RecordType == RecordType.Aaaa &&
ip.AddressFamily == AddressFamily.InterNetworkV6)
Expand Down Expand Up @@ -325,7 +343,6 @@ await Task.Run(async () =>
// message.AnswerRecords.AddRange(dnsResponse.Where(
// ip => ip.AddressFamily == AddressFamily.InterNetwork).Select(
// ip => new ARecord(question.Name, 0, ip)));
// }
// else if (question.RecordType == RecordType.Aaaa)
// {
// message.AnswerRecords.AddRange(dnsResponse.Where(
Expand All @@ -336,12 +353,16 @@ await Task.Run(async () =>
// message.IsQuery = false;
//}
if (message.IsQuery && question.RecordType == RecordType.A && useHttpQuery)
{
await ResolveWithHttp(targetNameServer, question.Name, queryTimeout, message);
}
if (message.IsQuery)
{
// Use internal forwarder to forward query to another name server
await
ForwardMessage(message, udpMessage, Utils.CreateIpEndPoint(targetNameServer, 53),
queryTimeout, useCompressionMutation);
await ForwardMessage(message, udpMessage, Utils.CreateIpEndPoint(targetNameServer, 53),
queryTimeout, useCompressionMutation);
}
else
{
Expand All @@ -359,7 +380,9 @@ await Task.Run(async () =>
}
}
}
catch (ParsingException) {}
catch (ParsingException)
{
}
catch (SocketException e)
{
Logger.Error("[Listener.Send] Unexpected socket error:\n{0}", e);
Expand Down Expand Up @@ -415,12 +438,13 @@ private async Task ForwardMessage(DnsMessage message, UdpReceiveResult originalU
out ignoreTokenSource);

var warningText = message.Questions.Count > 0
? string.Format("{0} (Type {1})", message.Questions[0].Name,
message.Questions[0].RecordType)
: string.Format("Transaction #{0}", message.TransactionID);
? $"{message.Questions[0].Name} (Type {message.Questions[0].RecordType})"
: $"Transaction #{message.TransactionID}";
Logger.Warning("Query timeout for: {0}", warningText);
}
catch (TaskCanceledException) {}
catch (TaskCanceledException)
{
}
}
catch (InfiniteForwardingException e)
{
Expand All @@ -447,18 +471,47 @@ private async Task ForwardMessage(DnsMessage message, UdpReceiveResult originalU
await _udpListener.SendAsync(responseBuffer, responseBuffer.Length, originalUdpMessage.RemoteEndPoint);
}

private static async Task ResolveWithHttp(string targetNameServer, string domainName, int timeout, DnsMessage message)
{
var request = WebRequest.Create($"http://{targetNameServer}/d?dn={domainName}&ttl=1");
request.Timeout = timeout;
var stream = (await request.GetResponseAsync()).GetResponseStream();
if (stream == null)
throw new Exception("Invalid HTTP response stream.");
using (var reader = new StreamReader(stream))
{
var result = await reader.ReadToEndAsync();
if (string.IsNullOrEmpty(result))
{
message.ReturnCode = ReturnCode.NxDomain;
message.IsQuery = false;
}
else
{
var parts = result.Split(',');
var ips = parts[0].Split(';');
foreach (var ip in ips)
{
message.AnswerRecords.Add(new ARecord(domainName, int.Parse(parts[1]), IPAddress.Parse(ip)));
}
message.ReturnCode = ReturnCode.NoError;
message.IsQuery = false;
}
}
}

#region Event Invokers

protected virtual void OnStarted()
{
var handler = Started;
if (handler != null) handler();
handler?.Invoke();
}

protected virtual void OnStopped()
{
var handler = Stopped;
if (handler != null) handler();
handler?.Invoke();
}

#endregion
Expand Down
6 changes: 3 additions & 3 deletions DNSAgent/DnsAgent.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
</PropertyGroup>
<ItemGroup>
<Reference Include="Microsoft.CSharp" />
<Reference Include="Newtonsoft.Json, Version=6.0.0.0, Culture=neutral, PublicKeyToken=30ad4fe6b2a6aeed, processorArchitecture=MSIL">
<SpecificVersion>False</SpecificVersion>
<HintPath>..\packages\Newtonsoft.Json.6.0.8\lib\net45\Newtonsoft.Json.dll</HintPath>
<Reference Include="Newtonsoft.Json, Version=7.0.0.0, Culture=neutral, PublicKeyToken=30ad4fe6b2a6aeed, processorArchitecture=MSIL">
<HintPath>..\packages\Newtonsoft.Json.7.0.1\lib\net45\Newtonsoft.Json.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System" />
<Reference Include="System.Configuration.Install" />
Expand Down
5 changes: 1 addition & 4 deletions DNSAgent/DnsMessageCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ public DnsCacheMessageEntry(DnsMessage message, int timeToLive)
public DnsMessage Message { get; set; }
public DateTime ExpireTime { get; set; }

public bool IsExpired
{
get { return DateTime.Now > ExpireTime; }
}
public bool IsExpired => DateTime.Now > ExpireTime;
}

internal class DnsMessageCache :
Expand Down
33 changes: 13 additions & 20 deletions DNSAgent/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,53 @@ namespace DNSAgent
{
internal class Options
{
public Options()
{
HideOnStart = false;
ListenOn = "127.0.0.1";
DefaultNameServer = "8.8.8.8";
QueryTimeout = 4000;
CompressionMutation = false;
CacheResponse = true;
CacheAge = 0;
NetworkWhitelist = null;
}

/// <summary>
/// Set to true to automatically hide the window on start.
/// </summary>
public bool HideOnStart { get; set; }
public bool HideOnStart { get; set; } = false;

/// <summary>
/// IP and port that DNSAgent will listen on. 0.0.0.0:53 for all interfaces and 127.0.0.1:53 for localhost. Of course
/// you can use other ports.
/// </summary>
public string ListenOn { get; set; }
public string ListenOn { get; set; } = "127.0.0.1";

/// <summary>
/// Querys that don't match any rules will be send to this server.
/// </summary>
public string DefaultNameServer { get; set; }
public string DefaultNameServer { get; set; } = "8.8.8.8";

/// <summary>
/// Whether to use DNSPod HttpDNS protocol to query the name server for A record.
/// </summary>
public bool UseHttpQuery { get; set; } = false;

/// <summary>
/// Timeout for a query, in milliseconds. This may be overridden by rules.cfg for a specific domain name.
/// </summary>
public int QueryTimeout { get; set; }
public int QueryTimeout { get; set; } = 4000;

/// <summary>
/// Whether to enable compression pointer mutation to query the default name servers. This may avoid MITM attack in
/// some network environments.
/// </summary>
public bool CompressionMutation { get; set; }
public bool CompressionMutation { get; set; } = false;

/// <summary>
/// Whether to enable caching of responses.
/// </summary>
public bool CacheResponse { get; set; }
public bool CacheResponse { get; set; } = true;

/// <summary>
/// How long will the cached response live. If a DNS record's TTL is longer than this value, it will be used instead of
/// this. Set to 0 to use the original TTL.
/// </summary>
public int CacheAge { get; set; }
public int CacheAge { get; set; } = 0;

/// <summary>
/// Source network whitelist. Only IPs from these network are accepted. Set to null to accept all IP (disable
/// whitelist), empty to deny all IP.
/// </summary>
public List<string> NetworkWhitelist { get; set; }
public List<string> NetworkWhitelist { get; set; } = null;
}
}
Loading

0 comments on commit 7c7614e

Please sign in to comment.