001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.iterative;
012
013import static org.minidns.constants.DnsRootServer.getIpv4RootServerById;
014import static org.minidns.constants.DnsRootServer.getIpv6RootServerById;
015import static org.minidns.constants.DnsRootServer.getRandomIpv4RootServer;
016import static org.minidns.constants.DnsRootServer.getRandomIpv6RootServer;
017
018import org.minidns.AbstractDnsClient;
019import org.minidns.DnsCache;
020import org.minidns.dnsmessage.DnsMessage;
021import org.minidns.dnsmessage.Question;
022import org.minidns.dnsname.DnsName;
023import org.minidns.dnsqueryresult.DnsQueryResult;
024import org.minidns.iterative.IterativeClientException.LoopDetected;
025import org.minidns.iterative.IterativeClientException.NotAuthoritativeNorGlueRrFound;
026import org.minidns.record.A;
027import org.minidns.record.AAAA;
028import org.minidns.record.RRWithTarget;
029import org.minidns.record.Record;
030import org.minidns.record.Record.TYPE;
031import org.minidns.record.Data;
032import org.minidns.record.InternetAddressRR;
033import org.minidns.record.NS;
034import org.minidns.util.MultipleIoException;
035
036import java.io.IOException;
037import java.net.Inet4Address;
038import java.net.Inet6Address;
039import java.net.InetAddress;
040import java.net.UnknownHostException;
041import java.util.ArrayList;
042import java.util.Collection;
043import java.util.Collections;
044import java.util.Iterator;
045import java.util.List;
046import java.util.Random;
047import java.util.logging.Level;
048
049public class IterativeDnsClient extends AbstractDnsClient {
050
051    int maxSteps = 128;
052
053    /**
054     * Create a new recursive DNS client using the global default cache.
055     */
056    public IterativeDnsClient() {
057        super();
058    }
059
060    /**
061     * Create a new recursive DNS client with the given DNS cache.
062     *
063     * @param cache The backend DNS cache.
064     */
065    public IterativeDnsClient(DnsCache cache) {
066        super(cache);
067    }
068
069    /**
070     * Recursively query the DNS system for one entry.
071     *
072     * @param queryBuilder The query DNS message builder.
073     * @return The response (or null on timeout/error).
074     * @throws IOException if an IO error occurs.
075     */
076    @Override
077    protected DnsQueryResult query(DnsMessage.Builder queryBuilder) throws IOException {
078        DnsMessage q = queryBuilder.build();
079        ResolutionState resolutionState = new ResolutionState(this);
080        DnsQueryResult result = queryRecursive(resolutionState, q);
081        return result;
082    }
083
084    private static InetAddress[] getTargets(Collection<? extends InternetAddressRR<? extends InetAddress>> primaryTargets,
085            Collection<? extends InternetAddressRR<? extends InetAddress>> secondaryTargets) {
086        InetAddress[] res = new InetAddress[2];
087
088        for (InternetAddressRR<? extends InetAddress> arr : primaryTargets) {
089            if (res[0] == null) {
090                res[0] = arr.getInetAddress();
091                // If secondaryTargets is empty, then try to get the second target out of the set of primaryTargets.
092                if (secondaryTargets.isEmpty()) {
093                    continue;
094                }
095            }
096            if (res[1] == null) {
097                res[1] = arr.getInetAddress();
098            }
099            break;
100        }
101
102        for (InternetAddressRR<? extends InetAddress> arr : secondaryTargets) {
103            if (res[0] == null) {
104                res[0] = arr.getInetAddress();
105                continue;
106            }
107            if (res[1] == null) {
108                res[1] = arr.getInetAddress();
109            }
110            break;
111        }
112
113        return res;
114    }
115
116    private DnsQueryResult queryRecursive(ResolutionState resolutionState, DnsMessage q) throws IOException {
117        InetAddress primaryTarget = null, secondaryTarget = null;
118
119        Question question = q.getQuestion();
120        DnsName parent = question.name.getParent();
121
122        switch (ipVersionSetting) {
123        case v4only:
124            for (A a : getCachedIPv4NameserverAddressesFor(parent)) {
125                if (primaryTarget == null) {
126                    primaryTarget = a.getInetAddress();
127                    continue;
128                }
129                secondaryTarget = a.getInetAddress();
130                break;
131            }
132            break;
133        case v6only:
134            for (AAAA aaaa : getCachedIPv6NameserverAddressesFor(parent)) {
135                if (primaryTarget == null) {
136                    primaryTarget = aaaa.getInetAddress();
137                    continue;
138                }
139                secondaryTarget = aaaa.getInetAddress();
140                break;
141            }
142            break;
143        case v4v6:
144            InetAddress[] v4v6targets = getTargets(getCachedIPv4NameserverAddressesFor(parent), getCachedIPv6NameserverAddressesFor(parent));
145            primaryTarget = v4v6targets[0];
146            secondaryTarget = v4v6targets[1];
147            break;
148        case v6v4:
149            InetAddress[] v6v4targets = getTargets(getCachedIPv6NameserverAddressesFor(parent), getCachedIPv4NameserverAddressesFor(parent));
150            primaryTarget = v6v4targets[0];
151            secondaryTarget = v6v4targets[1];
152            break;
153        default:
154            throw new AssertionError();
155        }
156
157        DnsName authoritativeZone = parent;
158        if (primaryTarget == null) {
159            authoritativeZone = DnsName.ROOT;
160            switch (ipVersionSetting) {
161            case v4only:
162                primaryTarget = getRandomIpv4RootServer(insecureRandom);
163                break;
164            case v6only:
165                primaryTarget = getRandomIpv6RootServer(insecureRandom);
166                break;
167            case v4v6:
168                primaryTarget = getRandomIpv4RootServer(insecureRandom);
169                secondaryTarget = getRandomIpv6RootServer(insecureRandom);
170                break;
171            case v6v4:
172                primaryTarget = getRandomIpv6RootServer(insecureRandom);
173                secondaryTarget = getRandomIpv4RootServer(insecureRandom);
174                break;
175            }
176        }
177
178        List<IOException> ioExceptions = new ArrayList<>();
179
180        try {
181            return queryRecursive(resolutionState, q, primaryTarget, authoritativeZone);
182        } catch (IOException ioException) {
183            abortIfFatal(ioException);
184            ioExceptions.add(ioException);
185        }
186
187        if (secondaryTarget != null) {
188            try {
189                return queryRecursive(resolutionState, q, secondaryTarget, authoritativeZone);
190            } catch (IOException ioException) {
191                ioExceptions.add(ioException);
192            }
193        }
194
195        MultipleIoException.throwIfRequired(ioExceptions);
196        return null;
197    }
198
199    private DnsQueryResult queryRecursive(ResolutionState resolutionState, DnsMessage q, InetAddress address, DnsName authoritativeZone) throws IOException {
200        resolutionState.recurse(address, q);
201
202        DnsQueryResult dnsQueryResult = query(q, address);
203
204        DnsMessage resMessage = dnsQueryResult.response;
205        if (resMessage.authoritativeAnswer) {
206            return dnsQueryResult;
207        }
208
209        if (cache != null) {
210            cache.offer(q, dnsQueryResult, authoritativeZone);
211        }
212
213        List<Record<? extends Data>> authorities = resMessage.copyAuthority();
214
215        List<IOException> ioExceptions = new ArrayList<>();
216
217        // Glued NS first
218        for (Iterator<Record<? extends Data>> iterator = authorities.iterator(); iterator.hasNext(); ) {
219            Record<NS> record = iterator.next().ifPossibleAs(NS.class);
220            if (record == null) {
221                iterator.remove();
222                continue;
223            }
224            DnsName name = record.payloadData.target;
225            IpResultSet gluedNs = searchAdditional(resMessage, name);
226            for (Iterator<InetAddress> addressIterator = gluedNs.addresses.iterator(); addressIterator.hasNext(); ) {
227                InetAddress target = addressIterator.next();
228                DnsQueryResult recursive = null;
229                try {
230                    recursive = queryRecursive(resolutionState, q, target, record.name);
231                } catch (IOException e) {
232                   abortIfFatal(e);
233                   LOGGER.log(Level.FINER, "Exception while recursing", e);
234                   resolutionState.decrementSteps();
235                   ioExceptions.add(e);
236                   if (!addressIterator.hasNext()) {
237                       iterator.remove();
238                   }
239                   continue;
240                }
241                return recursive;
242            }
243        }
244
245        // Try non-glued NS
246        for (Record<? extends Data> record : authorities) {
247            final Question question = q.getQuestion();
248            DnsName name = ((NS) record.payloadData).target;
249
250            // Loop prevention: If this non-glued NS equals the name we question for and if the question is about a A or
251            // AAAA RR, then we should not continue here as it would result in an endless loop.
252            if (question.name.equals(name) && (question.type == TYPE.A || question.type == TYPE.AAAA))
253                continue;
254
255            IpResultSet res = null;
256            try {
257                res = resolveIpRecursive(resolutionState, name);
258            } catch (IOException e) {
259                resolutionState.decrementSteps();
260                ioExceptions.add(e);
261            }
262            if (res == null) {
263                continue;
264            }
265
266            for (InetAddress target : res.addresses) {
267                DnsQueryResult recursive = null;
268                try {
269                    recursive = queryRecursive(resolutionState, q, target, record.name);
270                } catch (IOException e) {
271                    resolutionState.decrementSteps();
272                    ioExceptions.add(e);
273                    continue;
274                }
275                return recursive;
276            }
277        }
278
279        MultipleIoException.throwIfRequired(ioExceptions);
280
281        // Reaching this point means we did not receive an authoritative answer, nor
282        // where we able to find glue records or the IPs of the next nameservers.
283        throw new NotAuthoritativeNorGlueRrFound(q, dnsQueryResult, authoritativeZone);
284    }
285
286    private IpResultSet resolveIpRecursive(ResolutionState resolutionState, DnsName name) throws IOException {
287        IpResultSet.Builder res = newIpResultSetBuilder();
288
289        if (ipVersionSetting.v4) {
290            // TODO Try to retrieve A records for name out from cache.
291            Question question = new Question(name, TYPE.A);
292            final DnsMessage query = getQueryFor(question);
293            DnsQueryResult aDnsQueryResult = queryRecursive(resolutionState, query);
294            // TODO: queryRecurisve() should probably never return null. Verify that and then remove the follwing null check.
295            DnsMessage aMessage = aDnsQueryResult != null ? aDnsQueryResult.response : null;
296            if (aMessage != null) {
297                for (Record<? extends Data> answer : aMessage.answerSection) {
298                    if (answer.isAnswer(question)) {
299                        InetAddress inetAddress = inetAddressFromRecord(name.ace, (A) answer.payloadData);
300                        res.ipv4Addresses.add(inetAddress);
301                    } else if (answer.type == TYPE.CNAME && answer.name.equals(name)) {
302                        return resolveIpRecursive(resolutionState, ((RRWithTarget) answer.payloadData).target);
303                    }
304                }
305            }
306        }
307
308        if (ipVersionSetting.v6) {
309            // TODO Try to retrieve AAAA records for name out from cache.
310            Question question = new Question(name, TYPE.AAAA);
311            final DnsMessage query = getQueryFor(question);
312            DnsQueryResult aDnsQueryResult = queryRecursive(resolutionState, query);
313            // TODO: queryRecurisve() should probably never return null. Verify that and then remove the follwing null check.
314            DnsMessage aMessage = aDnsQueryResult != null ? aDnsQueryResult.response : null;
315            if (aMessage != null) {
316                for (Record<? extends Data> answer : aMessage.answerSection) {
317                    if (answer.isAnswer(question)) {
318                        InetAddress inetAddress = inetAddressFromRecord(name.ace, (AAAA) answer.payloadData);
319                        res.ipv6Addresses.add(inetAddress);
320                    } else if (answer.type == TYPE.CNAME && answer.name.equals(name)) {
321                        return resolveIpRecursive(resolutionState, ((RRWithTarget) answer.payloadData).target);
322                    }
323                }
324            }
325        }
326
327        return res.build();
328    }
329
330    @SuppressWarnings("incomplete-switch")
331    private IpResultSet searchAdditional(DnsMessage message, DnsName name) {
332        IpResultSet.Builder res = newIpResultSetBuilder();
333        for (Record<? extends Data> record : message.additionalSection) {
334            if (!record.name.equals(name)) {
335                continue;
336            }
337            switch (record.type) {
338            case A:
339                res.ipv4Addresses.add(inetAddressFromRecord(name.ace, (A) record.payloadData));
340                break;
341            case AAAA:
342                res.ipv6Addresses.add(inetAddressFromRecord(name.ace, (AAAA) record.payloadData));
343                break;
344            default:
345                break;
346            }
347        }
348        return res.build();
349    }
350
351    private static InetAddress inetAddressFromRecord(String name, A recordPayload) {
352        try {
353            return InetAddress.getByAddress(name, recordPayload.getIp());
354        } catch (UnknownHostException e) {
355            // This will never happen
356            throw new RuntimeException(e);
357        }
358    }
359
360    private static InetAddress inetAddressFromRecord(String name, AAAA recordPayload) {
361        try {
362            return InetAddress.getByAddress(name, recordPayload.getIp());
363        } catch (UnknownHostException e) {
364            // This will never happen
365            throw new RuntimeException(e);
366        }
367    }
368
369    public static List<InetAddress> getRootServer(char rootServerId) {
370        return getRootServer(rootServerId, DEFAULT_IP_VERSION_SETTING);
371    }
372
373    public static List<InetAddress> getRootServer(char rootServerId, IpVersionSetting setting) {
374        Inet4Address ipv4Root = getIpv4RootServerById(rootServerId);
375        Inet6Address ipv6Root = getIpv6RootServerById(rootServerId);
376        List<InetAddress> res = new ArrayList<>(2);
377        switch (setting) {
378        case v4only:
379            if (ipv4Root != null) {
380                res.add(ipv4Root);
381            }
382            break;
383        case v6only:
384            if (ipv6Root != null) {
385                res.add(ipv6Root);
386            }
387            break;
388        case v4v6:
389            if (ipv4Root != null) {
390                res.add(ipv4Root);
391            }
392            if (ipv6Root != null) {
393                res.add(ipv6Root);
394            }
395            break;
396        case v6v4:
397            if (ipv6Root != null) {
398                res.add(ipv6Root);
399            }
400            if (ipv4Root != null) {
401                res.add(ipv4Root);
402            }
403            break;
404        }
405        return res;
406    }
407
408    @Override
409    protected boolean isResponseCacheable(Question q, DnsQueryResult result) {
410        return result.response.authoritativeAnswer;
411    }
412
413    @Override
414    protected DnsMessage.Builder newQuestion(DnsMessage.Builder message) {
415        message.setRecursionDesired(false);
416        message.getEdnsBuilder().setUdpPayloadSize(dataSource.getUdpPayloadSize());
417        return message;
418    }
419
420    private IpResultSet.Builder newIpResultSetBuilder() {
421        return new IpResultSet.Builder(this.insecureRandom);
422    }
423
424    private static final class IpResultSet {
425
426        final List<InetAddress> addresses;
427
428        private IpResultSet(List<InetAddress> ipv4Addresses, List<InetAddress> ipv6Addresses, Random random) {
429            int size;
430            switch (DEFAULT_IP_VERSION_SETTING) {
431            case v4only:
432                size = ipv4Addresses.size();
433                break;
434            case v6only:
435                size = ipv6Addresses.size();
436                break;
437            case v4v6:
438            case v6v4:
439            default:
440                size = ipv4Addresses.size() + ipv6Addresses.size();
441                break;
442            }
443
444            if (size == 0) {
445                // Fast-path in case there were no addresses, which could happen e.g., if the NS records where not
446                // glued.
447                addresses = Collections.emptyList();
448            } else {
449                // Shuffle the addresses first, so that the load is better balanced.
450                if (DEFAULT_IP_VERSION_SETTING.v4) {
451                    Collections.shuffle(ipv4Addresses, random);
452                }
453                if (DEFAULT_IP_VERSION_SETTING.v6) {
454                    Collections.shuffle(ipv6Addresses, random);
455                }
456
457                List<InetAddress> addresses = new ArrayList<>(size);
458
459                // Now add the shuffled addresses to the result list.
460                switch (DEFAULT_IP_VERSION_SETTING) {
461                case v4only:
462                    addresses.addAll(ipv4Addresses);
463                    break;
464                case v6only:
465                    addresses.addAll(ipv6Addresses);
466                    break;
467                case v4v6:
468                    addresses.addAll(ipv4Addresses);
469                    addresses.addAll(ipv6Addresses);
470                    break;
471                case v6v4:
472                    addresses.addAll(ipv6Addresses);
473                    addresses.addAll(ipv4Addresses);
474                    break;
475                }
476
477                this.addresses = Collections.unmodifiableList(addresses);
478            }
479        }
480
481        private static final class Builder {
482            private final Random random;
483            private final List<InetAddress> ipv4Addresses = new ArrayList<>(8);
484            private final List<InetAddress> ipv6Addresses = new ArrayList<>(8);
485
486            private Builder(Random random) {
487                this.random = random;
488            }
489
490            public IpResultSet build() {
491                return new IpResultSet(ipv4Addresses, ipv6Addresses, random);
492            }
493        }
494    }
495
496    protected static void abortIfFatal(IOException ioException) throws IOException {
497        if (ioException instanceof LoopDetected) {
498            throw ioException;
499        }
500    }
501
502}