001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns;
012
013import org.minidns.MiniDnsFuture.InternalMiniDnsFuture;
014import org.minidns.cache.LruCache;
015import org.minidns.dnsmessage.DnsMessage;
016import org.minidns.dnsmessage.Question;
017import org.minidns.dnsname.DnsName;
018import org.minidns.dnsqueryresult.DnsQueryResult;
019import org.minidns.record.A;
020import org.minidns.record.AAAA;
021import org.minidns.record.Data;
022import org.minidns.record.NS;
023import org.minidns.record.Record;
024import org.minidns.record.Record.CLASS;
025import org.minidns.record.Record.TYPE;
026import org.minidns.source.DnsDataSource;
027import org.minidns.source.NetworkDataSource;
028
029import java.io.IOException;
030import java.net.InetAddress;
031import java.security.NoSuchAlgorithmException;
032import java.security.SecureRandom;
033import java.util.Collections;
034import java.util.HashSet;
035import java.util.Random;
036import java.util.Set;
037import java.util.logging.Level;
038import java.util.logging.Logger;
039
040/**
041 * A minimal DNS client for SRV/A/AAAA/NS and CNAME lookups, with IDN support.
042 * This circumvents the missing javax.naming package on android.
043 */
044public abstract class AbstractDnsClient {
045
046    protected static final LruCache DEFAULT_CACHE = new LruCache();
047
048    protected static final Logger LOGGER = Logger.getLogger(AbstractDnsClient.class.getName());
049
050    /**
051     * This callback is used by the synchronous query() method <b>and</b> by the asynchronous queryAync() method in order to update the
052     * cache. In the asynchronous case, hand this callback into the async call, so that it can get called once the result is available.
053     */
054    private final DnsDataSource.OnResponseCallback onResponseCallback = new DnsDataSource.OnResponseCallback() {
055        @Override
056        public void onResponse(DnsMessage requestMessage, DnsQueryResult responseMessage) {
057            final Question q = requestMessage.getQuestion();
058            if (cache != null && isResponseCacheable(q, responseMessage)) {
059                cache.put(requestMessage.asNormalizedVersion(), responseMessage);
060            }
061        }
062    };
063
064    /**
065     * The internal random class for sequence generation.
066     */
067    protected final Random random;
068
069    protected final Random insecureRandom = new Random();
070
071    /**
072     * The internal DNS cache.
073     */
074    protected final DnsCache cache;
075
076    protected DnsDataSource dataSource = new NetworkDataSource();
077
078    public enum IpVersionSetting {
079
080        v4only(true, false),
081        v6only(false, true),
082        v4v6(true, true),
083        v6v4(true, true),
084        ;
085
086        public final boolean v4;
087        public final boolean v6;
088
089        IpVersionSetting(boolean v4, boolean v6) {
090            this.v4 = v4;
091            this.v6 = v6;
092        }
093
094    }
095
096    protected static IpVersionSetting DEFAULT_IP_VERSION_SETTING = IpVersionSetting.v4v6;
097
098    public static void setDefaultIpVersion(IpVersionSetting preferedIpVersion) {
099        if (preferedIpVersion == null) {
100            throw new IllegalArgumentException();
101        }
102        AbstractDnsClient.DEFAULT_IP_VERSION_SETTING = preferedIpVersion;
103    }
104
105    protected IpVersionSetting ipVersionSetting = DEFAULT_IP_VERSION_SETTING;
106
107    public void setPreferedIpVersion(IpVersionSetting preferedIpVersion) {
108        if (preferedIpVersion == null) {
109            throw new IllegalArgumentException();
110        }
111        ipVersionSetting = preferedIpVersion;
112    }
113
114    public IpVersionSetting getPreferedIpVersion() {
115        return ipVersionSetting;
116    }
117
118    /**
119     * Create a new DNS client with the given DNS cache.
120     *
121     * @param cache The backend DNS cache.
122     */
123    protected AbstractDnsClient(DnsCache cache) {
124        Random random;
125        try {
126            random = SecureRandom.getInstance("SHA1PRNG");
127        } catch (NoSuchAlgorithmException e1) {
128            random = new SecureRandom();
129        }
130        this.random = random;
131        this.cache = cache;
132    }
133
134    /**
135     * Create a new DNS client using the global default cache.
136     */
137    protected AbstractDnsClient() {
138        this(DEFAULT_CACHE);
139    }
140
141    /**
142     * Query the system nameservers for a single entry of any class.
143     *
144     * This can be used to determine the name server version, if name
145     * is version.bind, type is TYPE.TXT and clazz is CLASS.CH.
146     *
147     * @param name  The DNS name to request.
148     * @param type  The DNS type to request (SRV, A, AAAA, ...).
149     * @param clazz The class of the request (usually IN for Internet).
150     * @return The response (or null on timeout/error).
151     * @throws IOException if an IO error occurs.
152     */
153    public final DnsQueryResult query(String name, TYPE type, CLASS clazz) throws IOException {
154        Question q = new Question(name, type, clazz);
155        return query(q);
156    }
157
158    /**
159     * Query the system nameservers for a single entry of the class IN
160     * (which is used for MX, SRV, A, AAAA and most other RRs).
161     *
162     * @param name The DNS name to request.
163     * @param type The DNS type to request (SRV, A, AAAA, ...).
164     * @return The response (or null on timeout/error).
165     * @throws IOException if an IO error occurs.
166     */
167    public final DnsQueryResult query(DnsName name, TYPE type) throws IOException {
168        Question q = new Question(name, type, CLASS.IN);
169        return query(q);
170    }
171
172    /**
173     * Query the system nameservers for a single entry of the class IN
174     * (which is used for MX, SRV, A, AAAA and most other RRs).
175     *
176     * @param name The DNS name to request.
177     * @param type The DNS type to request (SRV, A, AAAA, ...).
178     * @return The response (or null on timeout/error).
179     * @throws IOException if an IO error occurs.
180     */
181    public final DnsQueryResult query(CharSequence name, TYPE type) throws IOException {
182        Question q = new Question(name, type, CLASS.IN);
183        return query(q);
184    }
185
186    public DnsQueryResult query(Question q) throws IOException {
187        DnsMessage.Builder query = buildMessage(q);
188        return query(query);
189    }
190
191    /**
192     * Send a query request to the DNS system.
193     *
194     * @param query The query to send to the server.
195     * @return The response (or null).
196     * @throws IOException if an IO error occurs.
197     */
198    protected abstract DnsQueryResult query(DnsMessage.Builder query) throws IOException;
199
200    public final MiniDnsFuture<DnsQueryResult, IOException> queryAsync(CharSequence name, TYPE type) {
201        Question q = new Question(name, type, CLASS.IN);
202        return queryAsync(q);
203    }
204
205    public final MiniDnsFuture<DnsQueryResult, IOException> queryAsync(Question q) {
206        DnsMessage.Builder query = buildMessage(q);
207        return queryAsync(query);
208    }
209
210    /**
211     * Default implementation of an asynchronous DNS query which just wraps the synchronous case.
212     * <p>
213     * Subclasses override this method to support true asynchronous queries.
214     * </p>
215     *
216     * @param query the query.
217     * @return a future for this query.
218     */
219    protected MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage.Builder query) {
220        InternalMiniDnsFuture<DnsQueryResult, IOException> future = new InternalMiniDnsFuture<>();
221        DnsQueryResult result;
222        try {
223            result = query(query);
224        } catch (IOException e) {
225            future.setException(e);
226            return future;
227        }
228        future.setResult(result);
229        return future;
230    }
231
232    public final DnsQueryResult query(Question q, InetAddress server, int port) throws IOException {
233        DnsMessage query = getQueryFor(q);
234        return query(query, server, port);
235    }
236
237    public final DnsQueryResult query(DnsMessage requestMessage, InetAddress address, int port) throws IOException {
238        // See if we have the answer to this question already cached
239        DnsQueryResult responseMessage = (cache == null) ? null : cache.get(requestMessage);
240        if (responseMessage != null) {
241            return responseMessage;
242        }
243
244        final Question q = requestMessage.getQuestion();
245
246        final Level TRACE_LOG_LEVEL = Level.FINE;
247        LOGGER.log(TRACE_LOG_LEVEL, "Asking {0} on {1} for {2} with:\n{3}", new Object[] { address, port, q, requestMessage });
248
249        try {
250            responseMessage = dataSource.query(requestMessage, address, port);
251        } catch (IOException e) {
252            LOGGER.log(TRACE_LOG_LEVEL, "IOException {0} on {1} while resolving {2}: {3}", new Object[] { address, port, q, e});
253            throw e;
254        }
255
256        LOGGER.log(TRACE_LOG_LEVEL, "Response from {0} on {1} for {2}:\n{3}", new Object[] { address, port, q, responseMessage });
257
258        onResponseCallback.onResponse(requestMessage, responseMessage);
259
260        return responseMessage;
261    }
262
263    public final MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage requestMessage, InetAddress address, int port) {
264        // See if we have the answer to this question already cached
265        DnsQueryResult responseMessage = (cache == null) ? null : cache.get(requestMessage);
266        if (responseMessage != null) {
267            return MiniDnsFuture.from(responseMessage);
268        }
269
270        final Question q = requestMessage.getQuestion();
271
272        final Level TRACE_LOG_LEVEL = Level.FINE;
273        LOGGER.log(TRACE_LOG_LEVEL, "Asynchronusly asking {0} on {1} for {2} with:\n{3}", new Object[] { address, port, q, requestMessage });
274
275        return dataSource.queryAsync(requestMessage, address, port, onResponseCallback);
276    }
277
278    /**
279     * Whether a response from the DNS system should be cached or not.
280     *
281     * @param q          The question the response message should answer.
282     * @param result The DNS query result.
283     * @return True, if the response should be cached, false otherwise.
284     */
285    protected boolean isResponseCacheable(Question q, DnsQueryResult result) {
286        DnsMessage dnsMessage = result.response;
287        for (Record<? extends Data> record : dnsMessage.answerSection) {
288            if (record.isAnswer(q)) {
289                return true;
290            }
291        }
292        return false;
293    }
294
295    /**
296     * Builds a {@link DnsMessage} object carrying the given Question.
297     *
298     * @param question {@link Question} to be put in the DNS request.
299     * @return A {@link DnsMessage} requesting the answer for the given Question.
300     */
301    final DnsMessage.Builder buildMessage(Question question) {
302        DnsMessage.Builder message = DnsMessage.builder();
303        message.setQuestion(question);
304        message.setId(random.nextInt());
305        message = newQuestion(message);
306        return message;
307    }
308
309    protected abstract DnsMessage.Builder newQuestion(DnsMessage.Builder questionMessage);
310
311    /**
312     * Query a nameserver for a single entry.
313     *
314     * @param name    The DNS name to request.
315     * @param type    The DNS type to request (SRV, A, AAAA, ...).
316     * @param clazz   The class of the request (usually IN for Internet).
317     * @param address The DNS server address.
318     * @param port    The DNS server port.
319     * @return The response (or null on timeout / failure).
320     * @throws IOException On IO Errors.
321     */
322    public DnsQueryResult query(String name, TYPE type, CLASS clazz, InetAddress address, int port)
323            throws IOException {
324        Question q = new Question(name, type, clazz);
325        return query(q, address, port);
326    }
327
328    /**
329     * Query a nameserver for a single entry.
330     *
331     * @param name    The DNS name to request.
332     * @param type    The DNS type to request (SRV, A, AAAA, ...).
333     * @param clazz   The class of the request (usually IN for Internet).
334     * @param address The DNS server host.
335     * @return The response (or null on timeout / failure).
336     * @throws IOException On IO Errors.
337     */
338    public DnsQueryResult query(String name, TYPE type, CLASS clazz, InetAddress address)
339            throws IOException {
340        Question q = new Question(name, type, clazz);
341        return query(q, address);
342    }
343
344    /**
345     * Query a nameserver for a single entry of class IN.
346     *
347     * @param name    The DNS name to request.
348     * @param type    The DNS type to request (SRV, A, AAAA, ...).
349     * @param address The DNS server host.
350     * @return The response (or null on timeout / failure).
351     * @throws IOException On IO Errors.
352     */
353    public DnsQueryResult query(String name, TYPE type, InetAddress address)
354            throws IOException {
355        Question q = new Question(name, type, CLASS.IN);
356        return query(q, address);
357    }
358
359    public final DnsQueryResult query(DnsMessage query, InetAddress host) throws IOException {
360        return query(query, host, 53);
361    }
362
363    /**
364     * Query a specific server for one entry.
365     *
366     * @param q       The question section of the DNS query.
367     * @param address The dns server address.
368     * @return The a DNS query result.
369     * @throws IOException On IOErrors.
370     */
371    public DnsQueryResult query(Question q, InetAddress address) throws IOException {
372        return query(q, address, 53);
373    }
374
375    public final MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage query, InetAddress dnsServer) {
376        return queryAsync(query, dnsServer, 53);
377    }
378
379    /**
380     * Returns the currently used {@link DnsDataSource}. See {@link #setDataSource(DnsDataSource)} for details.
381     *
382     * @return The currently used {@link DnsDataSource}
383     */
384    public DnsDataSource getDataSource() {
385        return dataSource;
386    }
387
388    /**
389     * Set a {@link DnsDataSource} to be used by the DnsClient.
390     * The default implementation will direct all queries directly to the Internet.
391     *
392     * This can be used to define a non-default handling for outgoing data. This can be useful to redirect the requests
393     * to a proxy or to modify requests after or responses before they are handled by the DnsClient implementation.
394     *
395     * @param dataSource An implementation of DNSDataSource that shall be used.
396     */
397    public void setDataSource(DnsDataSource dataSource) {
398        if (dataSource == null) {
399            throw new IllegalArgumentException();
400        }
401        this.dataSource = dataSource;
402    }
403
404    /**
405     * Get the cache used by this DNS client.
406     *
407     * @return the cached used by this DNS client or <code>null</code>.
408     */
409    public DnsCache getCache() {
410        return cache;
411    }
412
413    protected DnsMessage getQueryFor(Question q) {
414        DnsMessage.Builder messageBuilder = buildMessage(q);
415        DnsMessage query = messageBuilder.build();
416        return query;
417    }
418
419    private <D extends Data> Set<D> getCachedRecordsFor(DnsName dnsName, TYPE type) {
420        if (cache == null)
421            return Collections.emptySet();
422
423        Question dnsNameNs = new Question(dnsName, type);
424        DnsMessage queryDnsNameNs = getQueryFor(dnsNameNs);
425        DnsQueryResult cachedResult = cache.get(queryDnsNameNs);
426
427        if (cachedResult == null)
428            return Collections.emptySet();
429
430        return cachedResult.response.getAnswersFor(dnsNameNs);
431    }
432
433    public Set<NS> getCachedNameserverRecordsFor(DnsName dnsName) {
434        return getCachedRecordsFor(dnsName, TYPE.NS);
435    }
436
437    public Set<A> getCachedIPv4AddressesFor(DnsName dnsName) {
438        return getCachedRecordsFor(dnsName, TYPE.A);
439    }
440
441    public Set<AAAA> getCachedIPv6AddressesFor(DnsName dnsName) {
442        return getCachedRecordsFor(dnsName, TYPE.AAAA);
443    }
444
445    @SuppressWarnings("unchecked")
446    private <D extends Data> Set<D> getCachedIPNameserverAddressesFor(DnsName dnsName, TYPE type) {
447        Set<NS> nsSet = getCachedNameserverRecordsFor(dnsName);
448        if (nsSet.isEmpty())
449            return Collections.emptySet();
450
451        Set<D> res = new HashSet<>(3 * nsSet.size());
452        for (NS ns : nsSet) {
453            Set<D> addresses;
454            switch (type) {
455            case A:
456                addresses = (Set<D>) getCachedIPv4AddressesFor(ns.target);
457                break;
458            case AAAA:
459                addresses = (Set<D>) getCachedIPv6AddressesFor(ns.target);
460                break;
461            default:
462                throw new AssertionError();
463            }
464            res.addAll(addresses);
465        }
466
467        return res;
468    }
469
470    public Set<A> getCachedIPv4NameserverAddressesFor(DnsName dnsName) {
471        return getCachedIPNameserverAddressesFor(dnsName, TYPE.A);
472    }
473
474    public Set<AAAA> getCachedIPv6NameserverAddressesFor(DnsName dnsName) {
475        return getCachedIPNameserverAddressesFor(dnsName, TYPE.AAAA);
476    }
477}