001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.iterative;
012
013import static org.minidns.constants.DnsRootServer.getIpv4RootServerById;
014import static org.minidns.constants.DnsRootServer.getIpv6RootServerById;
015import static org.minidns.constants.DnsRootServer.getRandomIpv4RootServer;
016import static org.minidns.constants.DnsRootServer.getRandomIpv6RootServer;
017
018import org.minidns.AbstractDnsClient;
019import org.minidns.DnsCache;
020import org.minidns.dnsmessage.DnsMessage;
021import org.minidns.dnsmessage.Question;
022import org.minidns.dnsname.DnsName;
023import org.minidns.dnsqueryresult.DnsQueryResult;
024import org.minidns.iterative.IterativeClientException.LoopDetected;
025import org.minidns.iterative.IterativeClientException.NotAuthoritativeNorGlueRrFound;
026import org.minidns.record.A;
027import org.minidns.record.AAAA;
028import org.minidns.record.RRWithTarget;
029import org.minidns.record.Record;
030import org.minidns.record.Record.TYPE;
031import org.minidns.record.Data;
032import org.minidns.record.InternetAddressRR;
033import org.minidns.record.NS;
034import org.minidns.util.MultipleIoException;
035
036import java.io.IOException;
037import java.net.Inet4Address;
038import java.net.Inet6Address;
039import java.net.InetAddress;
040import java.net.UnknownHostException;
041import java.util.ArrayList;
042import java.util.Collection;
043import java.util.Collections;
044import java.util.Iterator;
045import java.util.LinkedList;
046import java.util.List;
047import java.util.Random;
048import java.util.logging.Level;
049
050public class IterativeDnsClient extends AbstractDnsClient {
051
052    int maxSteps = 128;
053
054    /**
055     * Create a new recursive DNS client using the global default cache.
056     */
057    public IterativeDnsClient() {
058        super();
059    }
060
061    /**
062     * Create a new recursive DNS client with the given DNS cache.
063     *
064     * @param cache The backend DNS cache.
065     */
066    public IterativeDnsClient(DnsCache cache) {
067        super(cache);
068    }
069
070    /**
071     * Recursively query the DNS system for one entry.
072     *
073     * @param queryBuilder The query DNS message builder.
074     * @return The response (or null on timeout/error).
075     * @throws IOException if an IO error occurs.
076     */
077    @Override
078    protected DnsQueryResult query(DnsMessage.Builder queryBuilder) throws IOException {
079        DnsMessage q = queryBuilder.build();
080        ResolutionState resolutionState = new ResolutionState(this);
081        DnsQueryResult result = queryRecursive(resolutionState, q);
082        return result;
083    }
084
085    private static InetAddress[] getTargets(Collection<? extends InternetAddressRR<? extends InetAddress>> primaryTargets,
086            Collection<? extends InternetAddressRR<? extends InetAddress>> secondaryTargets) {
087        InetAddress[] res = new InetAddress[2];
088
089        for (InternetAddressRR<? extends InetAddress> arr : primaryTargets) {
090            if (res[0] == null) {
091                res[0] = arr.getInetAddress();
092                // If secondaryTargets is empty, then try to get the second target out of the set of primaryTargets.
093                if (secondaryTargets.isEmpty()) {
094                    continue;
095                }
096            }
097            if (res[1] == null) {
098                res[1] = arr.getInetAddress();
099            }
100            break;
101        }
102
103        for (InternetAddressRR<? extends InetAddress> arr : secondaryTargets) {
104            if (res[0] == null) {
105                res[0] = arr.getInetAddress();
106                continue;
107            }
108            if (res[1] == null) {
109                res[1] = arr.getInetAddress();
110            }
111            break;
112        }
113
114        return res;
115    }
116
117    private DnsQueryResult queryRecursive(ResolutionState resolutionState, DnsMessage q) throws IOException {
118        InetAddress primaryTarget = null, secondaryTarget = null;
119
120        Question question = q.getQuestion();
121        DnsName parent = question.name.getParent();
122
123        switch (ipVersionSetting) {
124        case v4only:
125            for (A a : getCachedIPv4NameserverAddressesFor(parent)) {
126                if (primaryTarget == null) {
127                    primaryTarget = a.getInetAddress();
128                    continue;
129                }
130                secondaryTarget = a.getInetAddress();
131                break;
132            }
133            break;
134        case v6only:
135            for (AAAA aaaa : getCachedIPv6NameserverAddressesFor(parent)) {
136                if (primaryTarget == null) {
137                    primaryTarget = aaaa.getInetAddress();
138                    continue;
139                }
140                secondaryTarget = aaaa.getInetAddress();
141                break;
142            }
143            break;
144        case v4v6:
145            InetAddress[] v4v6targets = getTargets(getCachedIPv4NameserverAddressesFor(parent), getCachedIPv6NameserverAddressesFor(parent));
146            primaryTarget = v4v6targets[0];
147            secondaryTarget = v4v6targets[1];
148            break;
149        case v6v4:
150            InetAddress[] v6v4targets = getTargets(getCachedIPv6NameserverAddressesFor(parent), getCachedIPv4NameserverAddressesFor(parent));
151            primaryTarget = v6v4targets[0];
152            secondaryTarget = v6v4targets[1];
153            break;
154        default:
155            throw new AssertionError();
156        }
157
158        DnsName authoritativeZone = parent;
159        if (primaryTarget == null) {
160            authoritativeZone = DnsName.ROOT;
161            switch (ipVersionSetting) {
162            case v4only:
163                primaryTarget = getRandomIpv4RootServer(insecureRandom);
164                break;
165            case v6only:
166                primaryTarget = getRandomIpv6RootServer(insecureRandom);
167                break;
168            case v4v6:
169                primaryTarget = getRandomIpv4RootServer(insecureRandom);
170                secondaryTarget = getRandomIpv6RootServer(insecureRandom);
171                break;
172            case v6v4:
173                primaryTarget = getRandomIpv6RootServer(insecureRandom);
174                secondaryTarget = getRandomIpv4RootServer(insecureRandom);
175                break;
176            }
177        }
178
179        List<IOException> ioExceptions = new LinkedList<>();
180
181        try {
182            return queryRecursive(resolutionState, q, primaryTarget, authoritativeZone);
183        } catch (IOException ioException) {
184            abortIfFatal(ioException);
185            ioExceptions.add(ioException);
186        }
187
188        if (secondaryTarget != null) {
189            try {
190                return queryRecursive(resolutionState, q, secondaryTarget, authoritativeZone);
191            } catch (IOException ioException) {
192                ioExceptions.add(ioException);
193            }
194        }
195
196        MultipleIoException.throwIfRequired(ioExceptions);
197        return null;
198    }
199
200    private DnsQueryResult queryRecursive(ResolutionState resolutionState, DnsMessage q, InetAddress address, DnsName authoritativeZone) throws IOException {
201        resolutionState.recurse(address, q);
202
203        DnsQueryResult dnsQueryResult = query(q, address);
204
205        DnsMessage resMessage = dnsQueryResult.response;
206        if (resMessage.authoritativeAnswer) {
207            return dnsQueryResult;
208        }
209
210        if (cache != null) {
211            cache.offer(q, dnsQueryResult, authoritativeZone);
212        }
213
214        List<Record<? extends Data>> authorities = resMessage.copyAuthority();
215
216        List<IOException> ioExceptions = new LinkedList<>();
217
218        // Glued NS first
219        for (Iterator<Record<? extends Data>> iterator = authorities.iterator(); iterator.hasNext(); ) {
220            Record<NS> record = iterator.next().ifPossibleAs(NS.class);
221            if (record == null) {
222                iterator.remove();
223                continue;
224            }
225            DnsName name = record.payloadData.target;
226            IpResultSet gluedNs = searchAdditional(resMessage, name);
227            for (Iterator<InetAddress> addressIterator = gluedNs.addresses.iterator(); addressIterator.hasNext(); ) {
228                InetAddress target = addressIterator.next();
229                DnsQueryResult recursive = null;
230                try {
231                    recursive = queryRecursive(resolutionState, q, target, record.name);
232                } catch (IOException e) {
233                   abortIfFatal(e);
234                   LOGGER.log(Level.FINER, "Exception while recursing", e);
235                   resolutionState.decrementSteps();
236                   ioExceptions.add(e);
237                   if (!addressIterator.hasNext()) {
238                       iterator.remove();
239                   }
240                   continue;
241                }
242                return recursive;
243            }
244        }
245
246        // Try non-glued NS
247        for (Record<? extends Data> record : authorities) {
248            final Question question = q.getQuestion();
249            DnsName name = ((NS) record.payloadData).target;
250
251            // Loop prevention: If this non-glued NS equals the name we question for and if the question is about a A or
252            // AAAA RR, then we should not continue here as it would result in an endless loop.
253            if (question.name.equals(name) && (question.type == TYPE.A || question.type == TYPE.AAAA))
254                continue;
255
256            IpResultSet res = null;
257            try {
258                res = resolveIpRecursive(resolutionState, name);
259            } catch (IOException e) {
260                resolutionState.decrementSteps();
261                ioExceptions.add(e);
262            }
263            if (res == null) {
264                continue;
265            }
266
267            for (InetAddress target : res.addresses) {
268                DnsQueryResult recursive = null;
269                try {
270                    recursive = queryRecursive(resolutionState, q, target, record.name);
271                } catch (IOException e) {
272                    resolutionState.decrementSteps();
273                    ioExceptions.add(e);
274                    continue;
275                }
276                return recursive;
277            }
278        }
279
280        MultipleIoException.throwIfRequired(ioExceptions);
281
282        // Reaching this point means we did not receive an authoritative answer, nor
283        // where we able to find glue records or the IPs of the next nameservers.
284        throw new NotAuthoritativeNorGlueRrFound(q, dnsQueryResult, authoritativeZone);
285    }
286
287    private IpResultSet resolveIpRecursive(ResolutionState resolutionState, DnsName name) throws IOException {
288        IpResultSet.Builder res = newIpResultSetBuilder();
289
290        if (ipVersionSetting.v4) {
291            // TODO Try to retrieve A records for name out from cache.
292            Question question = new Question(name, TYPE.A);
293            final DnsMessage query = getQueryFor(question);
294            DnsQueryResult aDnsQueryResult = queryRecursive(resolutionState, query);
295            // TODO: queryRecurisve() should probably never return null. Verify that and then remove the follwing null check.
296            DnsMessage aMessage = aDnsQueryResult != null ? aDnsQueryResult.response : null;
297            if (aMessage != null) {
298                for (Record<? extends Data> answer : aMessage.answerSection) {
299                    if (answer.isAnswer(question)) {
300                        InetAddress inetAddress = inetAddressFromRecord(name.ace, (A) answer.payloadData);
301                        res.ipv4Addresses.add(inetAddress);
302                    } else if (answer.type == TYPE.CNAME && answer.name.equals(name)) {
303                        return resolveIpRecursive(resolutionState, ((RRWithTarget) answer.payloadData).target);
304                    }
305                }
306            }
307        }
308
309        if (ipVersionSetting.v6) {
310            // TODO Try to retrieve AAAA records for name out from cache.
311            Question question = new Question(name, TYPE.AAAA);
312            final DnsMessage query = getQueryFor(question);
313            DnsQueryResult aDnsQueryResult = queryRecursive(resolutionState, query);
314            // TODO: queryRecurisve() should probably never return null. Verify that and then remove the follwing null check.
315            DnsMessage aMessage = aDnsQueryResult != null ? aDnsQueryResult.response : null;
316            if (aMessage != null) {
317                for (Record<? extends Data> answer : aMessage.answerSection) {
318                    if (answer.isAnswer(question)) {
319                        InetAddress inetAddress = inetAddressFromRecord(name.ace, (AAAA) answer.payloadData);
320                        res.ipv6Addresses.add(inetAddress);
321                    } else if (answer.type == TYPE.CNAME && answer.name.equals(name)) {
322                        return resolveIpRecursive(resolutionState, ((RRWithTarget) answer.payloadData).target);
323                    }
324                }
325            }
326        }
327
328        return res.build();
329    }
330
331    @SuppressWarnings("incomplete-switch")
332    private IpResultSet searchAdditional(DnsMessage message, DnsName name) {
333        IpResultSet.Builder res = newIpResultSetBuilder();
334        for (Record<? extends Data> record : message.additionalSection) {
335            if (!record.name.equals(name)) {
336                continue;
337            }
338            switch (record.type) {
339            case A:
340                res.ipv4Addresses.add(inetAddressFromRecord(name.ace, (A) record.payloadData));
341                break;
342            case AAAA:
343                res.ipv6Addresses.add(inetAddressFromRecord(name.ace, (AAAA) record.payloadData));
344                break;
345            default:
346                break;
347            }
348        }
349        return res.build();
350    }
351
352    private static InetAddress inetAddressFromRecord(String name, A recordPayload) {
353        try {
354            return InetAddress.getByAddress(name, recordPayload.getIp());
355        } catch (UnknownHostException e) {
356            // This will never happen
357            throw new RuntimeException(e);
358        }
359    }
360
361    private static InetAddress inetAddressFromRecord(String name, AAAA recordPayload) {
362        try {
363            return InetAddress.getByAddress(name, recordPayload.getIp());
364        } catch (UnknownHostException e) {
365            // This will never happen
366            throw new RuntimeException(e);
367        }
368    }
369
370    public static List<InetAddress> getRootServer(char rootServerId) {
371        return getRootServer(rootServerId, DEFAULT_IP_VERSION_SETTING);
372    }
373
374    public static List<InetAddress> getRootServer(char rootServerId, IpVersionSetting setting) {
375        Inet4Address ipv4Root = getIpv4RootServerById(rootServerId);
376        Inet6Address ipv6Root = getIpv6RootServerById(rootServerId);
377        List<InetAddress> res = new ArrayList<>(2);
378        switch (setting) {
379        case v4only:
380            if (ipv4Root != null) {
381                res.add(ipv4Root);
382            }
383            break;
384        case v6only:
385            if (ipv6Root != null) {
386                res.add(ipv6Root);
387            }
388            break;
389        case v4v6:
390            if (ipv4Root != null) {
391                res.add(ipv4Root);
392            }
393            if (ipv6Root != null) {
394                res.add(ipv6Root);
395            }
396            break;
397        case v6v4:
398            if (ipv6Root != null) {
399                res.add(ipv6Root);
400            }
401            if (ipv4Root != null) {
402                res.add(ipv4Root);
403            }
404            break;
405        }
406        return res;
407    }
408
409    @Override
410    protected boolean isResponseCacheable(Question q, DnsQueryResult result) {
411        return result.response.authoritativeAnswer;
412    }
413
414    @Override
415    protected DnsMessage.Builder newQuestion(DnsMessage.Builder message) {
416        message.setRecursionDesired(false);
417        message.getEdnsBuilder().setUdpPayloadSize(dataSource.getUdpPayloadSize());
418        return message;
419    }
420
421    private IpResultSet.Builder newIpResultSetBuilder() {
422        return new IpResultSet.Builder(this.insecureRandom);
423    }
424
425    private static final class IpResultSet {
426
427        final List<InetAddress> addresses;
428
429        private IpResultSet(List<InetAddress> ipv4Addresses, List<InetAddress> ipv6Addresses, Random random) {
430            int size;
431            switch (DEFAULT_IP_VERSION_SETTING) {
432            case v4only:
433                size = ipv4Addresses.size();
434                break;
435            case v6only:
436                size = ipv6Addresses.size();
437                break;
438            case v4v6:
439            case v6v4:
440            default:
441                size = ipv4Addresses.size() + ipv6Addresses.size();
442                break;
443            }
444
445            if (size == 0) {
446                // Fast-path in case there were no addresses, which could happen e.g., if the NS records where not
447                // glued.
448                addresses = Collections.emptyList();
449            } else {
450                // Shuffle the addresses first, so that the load is better balanced.
451                if (DEFAULT_IP_VERSION_SETTING.v4) {
452                    Collections.shuffle(ipv4Addresses, random);
453                }
454                if (DEFAULT_IP_VERSION_SETTING.v6) {
455                    Collections.shuffle(ipv6Addresses, random);
456                }
457
458                List<InetAddress> addresses = new ArrayList<>(size);
459
460                // Now add the shuffled addresses to the result list.
461                switch (DEFAULT_IP_VERSION_SETTING) {
462                case v4only:
463                    addresses.addAll(ipv4Addresses);
464                    break;
465                case v6only:
466                    addresses.addAll(ipv6Addresses);
467                    break;
468                case v4v6:
469                    addresses.addAll(ipv4Addresses);
470                    addresses.addAll(ipv6Addresses);
471                    break;
472                case v6v4:
473                    addresses.addAll(ipv6Addresses);
474                    addresses.addAll(ipv4Addresses);
475                    break;
476                }
477
478                this.addresses = Collections.unmodifiableList(addresses);
479            }
480        }
481
482        private static final class Builder {
483            private final Random random;
484            private final List<InetAddress> ipv4Addresses = new ArrayList<>(8);
485            private final List<InetAddress> ipv6Addresses = new ArrayList<>(8);
486
487            private Builder(Random random) {
488                this.random = random;
489            }
490
491            public IpResultSet build() {
492                return new IpResultSet(ipv4Addresses, ipv6Addresses, random);
493            }
494        }
495    }
496
497    protected static void abortIfFatal(IOException ioException) throws IOException {
498        if (ioException instanceof LoopDetected) {
499            throw ioException;
500        }
501    }
502
503}