diff --git a/ARSoft.Tools.Net/Dns/DnsClient.cs b/ARSoft.Tools.Net/Dns/DnsClient.cs index 05c94b9..f1f5145 100644 --- a/ARSoft.Tools.Net/Dns/DnsClient.cs +++ b/ARSoft.Tools.Net/Dns/DnsClient.cs @@ -52,15 +52,35 @@ static DnsClient() public DnsClient(IPAddress dnsServer, int queryTimeout) : this(new List { dnsServer }, queryTimeout) {} - /// - /// Provides a new instance with custom dns servers and query timeout + /// + /// Provides a new instance with custom dns server and query timeout /// - /// The IPAddresses of the dns servers to use + /// The IPAddress of the dns server to use /// Query timeout in milliseconds - public DnsClient(List dnsServers, int queryTimeout) - : base(dnsServers, queryTimeout, 53) {} - - protected override int MaximumQueryMessageSize + /// /// The dns server port + public DnsClient(IPAddress dnsServer, int queryTimeout, int port) + : this(new List { dnsServer }, queryTimeout, port) + { } + + /// + /// Provides a new instance with custom dns servers and query timeout + /// + /// The IPAddresses of the dns servers to use + /// Query timeout in milliseconds + public DnsClient(List dnsServers, int queryTimeout) + : this(dnsServers, queryTimeout, 53) {} + + /// + /// Provides a new instance with custom dns servers and query timeout + /// + /// The IPAddresses of the dns servers to use + /// Query timeout in milliseconds + /// The dns server port + public DnsClient(List dnsServers, int queryTimeout, int port) + : base(dnsServers, queryTimeout, port) + { } + + protected override int MaximumQueryMessageSize { get { return 512; } } diff --git a/ARSoft.Tools.Net/Dns/DnsMessageEntryBase.cs b/ARSoft.Tools.Net/Dns/DnsMessageEntryBase.cs index 2f75b4d..b7dadda 100644 --- a/ARSoft.Tools.Net/Dns/DnsMessageEntryBase.cs +++ b/ARSoft.Tools.Net/Dns/DnsMessageEntryBase.cs @@ -31,17 +31,17 @@ public abstract class DnsMessageEntryBase /// /// Domain name /// - public string Name { get; internal set; } + public string Name { get; set; } /// /// Type of the record /// - public RecordType RecordType { get; internal set; } + public RecordType RecordType { get; set; } /// /// Class of the record /// - public RecordClass RecordClass { get; internal set; } + public RecordClass RecordClass { get; set; } internal abstract int MaximumLength { get; } diff --git a/DNSAgent/DnsAgent.cs b/DNSAgent/DnsAgent.cs index f075d78..9774124 100644 --- a/DNSAgent/DnsAgent.cs +++ b/DNSAgent/DnsAgent.cs @@ -262,35 +262,56 @@ await Task.Run(async () => if (!Regex.IsMatch(question.Name, Rules[i].Pattern)) continue; // Domain name matched + + if (Rules[i].NameServer != null) // Name server override + targetNameServer = Rules[i].NameServer; + + if (Rules[i].QueryTimeout != null) // Query timeout override + queryTimeout = Rules[i].QueryTimeout.Value; + + if (Rules[i].CompressionMutation != null) // Compression pointer mutation override + useCompressionMutation = Rules[i].CompressionMutation.Value; + if (Rules[i].Address != null) { IPAddress ip; IPAddress.TryParse(Rules[i].Address, out ip); - if (ip == null) continue; // Invalid rule - - if (question.RecordType == RecordType.A && + if (ip == null) // Invalid IP, may be a domain name + { + var serverEndpoint = Utils.CreateIpEndPoint(targetNameServer, 53); + var dnsClient = new DnsClient(serverEndpoint.Address, queryTimeout, serverEndpoint.Port); + var response = await Task.Factory.FromAsync(dnsClient.BeginResolve, dnsClient.EndResolve, + Rules[i].Address, question.RecordType, question.RecordClass, null); + if (response == null) + { + Logger.Warning($"Remote resolve failed for {Rules[i].Address}."); + return; + } + message.ReturnCode = response.ReturnCode; + foreach (var answerRecord in response.AnswerRecords) + { + answerRecord.Name = question.Name; + message.AnswerRecords.Add(answerRecord); + } + message.IsQuery = false; + } + else + { + if (question.RecordType == RecordType.A && ip.AddressFamily == AddressFamily.InterNetwork) - message.AnswerRecords.Add(new ARecord(question.Name, 600, ip)); - else if (question.RecordType == RecordType.Aaaa && - ip.AddressFamily == AddressFamily.InterNetworkV6) - message.AnswerRecords.Add(new AaaaRecord(question.Name, 600, ip)); - else // Type mismatch - continue; - - message.ReturnCode = ReturnCode.NoError; - message.IsQuery = false; + message.AnswerRecords.Add(new ARecord(question.Name, 600, ip)); + else if (question.RecordType == RecordType.Aaaa && + ip.AddressFamily == AddressFamily.InterNetworkV6) + message.AnswerRecords.Add(new AaaaRecord(question.Name, 600, ip)); + else // Type mismatch + continue; + + message.ReturnCode = ReturnCode.NoError; + message.IsQuery = false; + } } - else - { - if (Rules[i].NameServer != null) // Name server override - targetNameServer = Rules[i].NameServer; - if (Rules[i].QueryTimeout != null) // Query timeout override - queryTimeout = Rules[i].QueryTimeout.Value; - - if (Rules[i].CompressionMutation != null) // Compression pointer mutation override - useCompressionMutation = Rules[i].CompressionMutation.Value; - } + break; } } diff --git a/DNSAgent/Properties/AssemblyInfo.cs b/DNSAgent/Properties/AssemblyInfo.cs index 0b072d1..6364622 100644 --- a/DNSAgent/Properties/AssemblyInfo.cs +++ b/DNSAgent/Properties/AssemblyInfo.cs @@ -35,5 +35,5 @@ // by using the '*' as shown below: // [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.3.*")] +[assembly: AssemblyVersion("1.4.*")] [assembly: AssemblyFileVersion("1.0.0.0")] \ No newline at end of file