/*
 * Copyright 2015-2018 the original author or authors
 *
 * This software is licensed under the Apache License, Version 2.0,
 * the GNU Lesser General Public License version 2 or later ("LGPL")
 * and the WTFPL.
 * You may choose either license to govern your use of this software only
 * upon the condition that you accept all of the terms of either
 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
 */
011package org.minidns;
012
013import org.minidns.MiniDnsFuture.InternalMiniDnsFuture;
014import org.minidns.cache.LruCache;
015import org.minidns.dnsmessage.DnsMessage;
016import org.minidns.dnsmessage.Question;
017import org.minidns.dnsname.DnsName;
018import org.minidns.record.A;
019import org.minidns.record.AAAA;
020import org.minidns.record.Data;
021import org.minidns.record.NS;
022import org.minidns.record.Record;
023import org.minidns.record.Record.CLASS;
024import org.minidns.record.Record.TYPE;
025import org.minidns.source.DnsDataSource;
026import org.minidns.source.NetworkDataSource;
027
028import java.io.IOException;
029import java.net.InetAddress;
030import java.security.NoSuchAlgorithmException;
031import java.security.SecureRandom;
032import java.util.Collections;
033import java.util.HashSet;
034import java.util.Random;
035import java.util.Set;
036import java.util.logging.Level;
037import java.util.logging.Logger;
038
039/**
040 * A minimal DNS client for SRV/A/AAAA/NS and CNAME lookups, with IDN support.
041 * This circumvents the missing javax.naming package on android.
042 */
043public abstract class AbstractDnsClient {
044
045    protected static final LruCache DEFAULT_CACHE = new LruCache();
046
047    protected static final Logger LOGGER = Logger.getLogger(AbstractDnsClient.class.getName());
048
049    /**
050     * This callback is used by the synchronous query() method <b>and</b> by the asynchronous queryAync() method in order to update the
051     * cache. In the asynchronous case, hand this callback into the async call, so that it can get called once the result is available.
052     */
053    private final DnsDataSource.OnResponseCallback onResponseCallback = new DnsDataSource.OnResponseCallback() {
054        @Override
055        public void onResponse(DnsMessage requestMessage, DnsMessage responseMessage) {
056            final Question q = requestMessage.getQuestion();
057            if (cache != null && isResponseCacheable(q, responseMessage)) {
058                cache.put(requestMessage.asNormalizedVersion(), responseMessage);
059            }
060        }
061    };
062
063    /**
064     * The internal random class for sequence generation.
065     */
066    protected final Random random;
067
068    protected final Random insecureRandom = new Random();
069
070    /**
071     * The internal DNS cache.
072     */
073    protected final DnsCache cache;
074
075    protected DnsDataSource dataSource = new NetworkDataSource();
076
077    public enum IpVersionSetting {
078
079        v4only(true, false),
080        v6only(false, true),
081        v4v6(true, true),
082        v6v4(true, true),
083        ;
084
085        public final boolean v4;
086        public final boolean v6;
087
088        IpVersionSetting(boolean v4, boolean v6) {
089            this.v4 = v4;
090            this.v6 = v6;
091        }
092
093    }
094
095    protected static IpVersionSetting DEFAULT_IP_VERSION_SETTING = IpVersionSetting.v4v6;
096
097    public static void setDefaultIpVersion(IpVersionSetting preferedIpVersion) {
098        if (preferedIpVersion == null) {
099            throw new IllegalArgumentException();
100        }
101        AbstractDnsClient.DEFAULT_IP_VERSION_SETTING = preferedIpVersion;
102    }
103
104    protected IpVersionSetting ipVersionSetting = DEFAULT_IP_VERSION_SETTING;
105
106    public void setPreferedIpVersion(IpVersionSetting preferedIpVersion) {
107        if (preferedIpVersion == null) {
108            throw new IllegalArgumentException();
109        }
110        ipVersionSetting = preferedIpVersion;
111    }
112
113    public IpVersionSetting getPreferedIpVersion() {
114        return ipVersionSetting;
115    }
116
117    /**
118     * Create a new DNS client with the given DNS cache.
119     *
120     * @param cache The backend DNS cache.
121     */
122    protected AbstractDnsClient(DnsCache cache) {
123        Random random;
124        try {
125            random = SecureRandom.getInstance("SHA1PRNG");
126        } catch (NoSuchAlgorithmException e1) {
127            random = new SecureRandom();
128        }
129        this.random = random;
130        this.cache = cache;
131    }
132
133    /**
134     * Create a new DNS client using the global default cache.
135     */
136    protected AbstractDnsClient() {
137        this(DEFAULT_CACHE);
138    }
139
140    /**
141     * Query the system nameservers for a single entry of any class.
142     *
143     * This can be used to determine the name server version, if name
144     * is version.bind, type is TYPE.TXT and clazz is CLASS.CH.
145     *
146     * @param name  The DNS name to request.
147     * @param type  The DNS type to request (SRV, A, AAAA, ...).
148     * @param clazz The class of the request (usually IN for Internet).
149     * @return The response (or null on timeout/error).
150     * @throws IOException if an IO error occurs.
151     */
152    public final DnsMessage query(String name, TYPE type, CLASS clazz) throws IOException {
153        Question q = new Question(name, type, clazz);
154        return query(q);
155    }
156
157    /**
158     * Query the system nameservers for a single entry of the class IN
159     * (which is used for MX, SRV, A, AAAA and most other RRs).
160     *
161     * @param name The DNS name to request.
162     * @param type The DNS type to request (SRV, A, AAAA, ...).
163     * @return The response (or null on timeout/error).
164     * @throws IOException if an IO error occurs.
165     */
166    public final DnsMessage query(DnsName name, TYPE type) throws IOException {
167        Question q = new Question(name, type, CLASS.IN);
168        return query(q);
169    }
170
171    /**
172     * Query the system nameservers for a single entry of the class IN
173     * (which is used for MX, SRV, A, AAAA and most other RRs).
174     *
175     * @param name The DNS name to request.
176     * @param type The DNS type to request (SRV, A, AAAA, ...).
177     * @return The response (or null on timeout/error).
178     * @throws IOException if an IO error occurs.
179     */
180    public final DnsMessage query(CharSequence name, TYPE type) throws IOException {
181        Question q = new Question(name, type, CLASS.IN);
182        return query(q);
183    }
184
185    public DnsMessage query(Question q) throws IOException {
186        DnsMessage.Builder query = buildMessage(q);
187        return query(query);
188    }
189
190    /**
191     * Send a query request to the DNS system.
192     *
193     * @param query The query to send to the server.
194     * @return The response (or null).
195     * @throws IOException if an IO error occurs.
196     */
197    protected abstract DnsMessage query(DnsMessage.Builder query) throws IOException;
198
199    public final MiniDnsFuture<DnsMessage, IOException> queryAsync(CharSequence name, TYPE type) {
200        Question q = new Question(name, type, CLASS.IN);
201        return queryAsync(q);
202    }
203
204    public final MiniDnsFuture<DnsMessage, IOException> queryAsync(Question q) {
205        DnsMessage.Builder query = buildMessage(q);
206        return queryAsync(query);
207    }
208
209    /**
210     * Default implementation of an asynchronous DNS query which just wraps the synchronous case.
211     * <p>
212     * Subclasses override this method to support true asynchronous queries.
213     * </p>
214     *
215     * @param query the query.
216     * @return a future for this query.
217     */
218    protected MiniDnsFuture<DnsMessage, IOException> queryAsync(DnsMessage.Builder query) {
219        InternalMiniDnsFuture<DnsMessage, IOException> future = new InternalMiniDnsFuture<>();
220        DnsMessage result;
221        try {
222            result = query(query);
223        } catch (IOException e) {
224            future.setException(e);
225            return future;
226        }
227        future.setResult(result);
228        return future;
229    }
230
231    public final DnsMessage query(Question q, InetAddress server, int port) throws IOException {
232        DnsMessage query = getQueryFor(q);
233        return query(query, server, port);
234    }
235
236    public final DnsMessage query(DnsMessage requestMessage, InetAddress address, int port) throws IOException {
237        // See if we have the answer to this question already cached
238        DnsMessage responseMessage = (cache == null) ? null : cache.get(requestMessage);
239        if (responseMessage != null) {
240            return responseMessage;
241        }
242
243        final Question q = requestMessage.getQuestion();
244
245        final Level TRACE_LOG_LEVEL = Level.FINE;
246        LOGGER.log(TRACE_LOG_LEVEL, "Asking {0} on {1} for {2} with:\n{3}", new Object[] { address, port, q, requestMessage });
247
248        try {
249            responseMessage = dataSource.query(requestMessage, address, port);
250        } catch (IOException e) {
251            LOGGER.log(TRACE_LOG_LEVEL, "IOException {0} on {1} while resolving {2}: {3}", new Object[] { address, port, q, e});
252            throw e;
253        }
254        if (responseMessage != null) {
255            LOGGER.log(TRACE_LOG_LEVEL, "Response from {0} on {1} for {2}:\n{3}", new Object[] { address, port, q, responseMessage });
256        } else {
257            // TODO When should this ever happen?
258            LOGGER.log(Level.SEVERE, "NULL response from " + address + " on " + port + " for " + q);
259        }
260
261        if (responseMessage == null) return null;
262
263        onResponseCallback.onResponse(requestMessage, responseMessage);
264
265        return responseMessage;
266    }
267
268    public final MiniDnsFuture<DnsMessage, IOException> queryAsync(DnsMessage requestMessage, InetAddress address, int port) {
269        // See if we have the answer to this question already cached
270        DnsMessage responseMessage = (cache == null) ? null : cache.get(requestMessage);
271        if (responseMessage != null) {
272            return MiniDnsFuture.from(responseMessage);
273        }
274
275        final Question q = requestMessage.getQuestion();
276
277        final Level TRACE_LOG_LEVEL = Level.FINE;
278        LOGGER.log(TRACE_LOG_LEVEL, "Asynchronusly asking {0} on {1} for {2} with:\n{3}", new Object[] { address, port, q, requestMessage });
279
280        return dataSource.queryAsync(requestMessage, address, port, onResponseCallback);
281    }
282
283    /**
284     * Whether a response from the DNS system should be cached or not.
285     *
286     * @param q          The question the response message should answer.
287     * @param dnsMessage The response message received using the DNS client.
288     * @return True, if the response should be cached, false otherwise.
289     */
290    protected boolean isResponseCacheable(Question q, DnsMessage dnsMessage) {
291        for (Record<? extends Data> record : dnsMessage.answerSection) {
292            if (record.isAnswer(q)) {
293                return true;
294            }
295        }
296        return false;
297    }
298
299    /**
300     * Builds a {@link DnsMessage} object carrying the given Question.
301     *
302     * @param question {@link Question} to be put in the DNS request.
303     * @return A {@link DnsMessage} requesting the answer for the given Question.
304     */
305    final DnsMessage.Builder buildMessage(Question question) {
306        DnsMessage.Builder message = DnsMessage.builder();
307        message.setQuestion(question);
308        message.setId(random.nextInt());
309        message = newQuestion(message);
310        return message;
311    }
312
313    protected abstract DnsMessage.Builder newQuestion(DnsMessage.Builder questionMessage);
314
315    /**
316     * Query a nameserver for a single entry.
317     *
318     * @param name    The DNS name to request.
319     * @param type    The DNS type to request (SRV, A, AAAA, ...).
320     * @param clazz   The class of the request (usually IN for Internet).
321     * @param address The DNS server address.
322     * @param port    The DNS server port.
323     * @return The response (or null on timeout / failure).
324     * @throws IOException On IO Errors.
325     */
326    public DnsMessage query(String name, TYPE type, CLASS clazz, InetAddress address, int port)
327            throws IOException {
328        Question q = new Question(name, type, clazz);
329        return query(q, address, port);
330    }
331
332    /**
333     * Query a nameserver for a single entry.
334     *
335     * @param name    The DNS name to request.
336     * @param type    The DNS type to request (SRV, A, AAAA, ...).
337     * @param clazz   The class of the request (usually IN for Internet).
338     * @param address The DNS server host.
339     * @return The response (or null on timeout / failure).
340     * @throws IOException On IO Errors.
341     */
342    public DnsMessage query(String name, TYPE type, CLASS clazz, InetAddress address)
343            throws IOException {
344        Question q = new Question(name, type, clazz);
345        return query(q, address);
346    }
347
348    /**
349     * Query a nameserver for a single entry of class IN.
350     *
351     * @param name    The DNS name to request.
352     * @param type    The DNS type to request (SRV, A, AAAA, ...).
353     * @param address The DNS server host.
354     * @return The response (or null on timeout / failure).
355     * @throws IOException On IO Errors.
356     */
357    public DnsMessage query(String name, TYPE type, InetAddress address)
358            throws IOException {
359        Question q = new Question(name, type, CLASS.IN);
360        return query(q, address);
361    }
362
363    public final DnsMessage query(DnsMessage query, InetAddress host) throws IOException {
364        return query(query, host, 53);
365    }
366
367    /**
368     * Query a specific server for one entry.
369     *
370     * @param q       The question section of the DNS query.
371     * @param address The dns server address.
372     * @return The response (or null on timeout/error).
373     * @throws IOException On IOErrors.
374     */
375    public DnsMessage query(Question q, InetAddress address) throws IOException {
376        return query(q, address, 53);
377    }
378
379    public final MiniDnsFuture<DnsMessage, IOException> queryAsync(DnsMessage query, InetAddress dnsServer) {
380        return queryAsync(query, dnsServer, 53);
381    }
382
383    /**
384     * Returns the currently used {@link DnsDataSource}. See {@link #setDataSource(DnsDataSource)} for details.
385     *
386     * @return The currently used {@link DnsDataSource}
387     */
388    public DnsDataSource getDataSource() {
389        return dataSource;
390    }
391
392    /**
393     * Set a {@link DnsDataSource} to be used by the DnsClient.
394     * The default implementation will direct all queries directly to the Internet.
395     *
396     * This can be used to define a non-default handling for outgoing data. This can be useful to redirect the requests
397     * to a proxy or to modify requests after or responses before they are handled by the DnsClient implementation.
398     *
399     * @param dataSource An implementation of DNSDataSource that shall be used.
400     */
401    public void setDataSource(DnsDataSource dataSource) {
402        if (dataSource == null) {
403            throw new IllegalArgumentException();
404        }
405        this.dataSource = dataSource;
406    }
407
408    /**
409     * Get the cache used by this DNS client.
410     *
411     * @return the cached used by this DNS client or <code>null</code>.
412     */
413    public DnsCache getCache() {
414        return cache;
415    }
416
417    protected DnsMessage getQueryFor(Question q) {
418        DnsMessage.Builder messageBuilder = buildMessage(q);
419        DnsMessage query = messageBuilder.build();
420        return query;
421    }
422
423    private <D extends Data> Set<D> getCachedRecordsFor(DnsName dnsName, TYPE type) {
424        Question dnsNameNs = new Question(dnsName, type);
425        DnsMessage queryDnsNameNs = getQueryFor(dnsNameNs);
426        DnsMessage cachedResult = cache.get(queryDnsNameNs);
427
428        if (cachedResult == null)
429            return Collections.emptySet();
430
431        return cachedResult.getAnswersFor(dnsNameNs);
432    }
433
434    public Set<NS> getCachedNameserverRecordsFor(DnsName dnsName) {
435        return getCachedRecordsFor(dnsName, TYPE.NS);
436    }
437
438    public Set<A> getCachedIPv4AddressesFor(DnsName dnsName) {
439        return getCachedRecordsFor(dnsName, TYPE.A);
440    }
441
442    public Set<AAAA> getCachedIPv6AddressesFor(DnsName dnsName) {
443        return getCachedRecordsFor(dnsName, TYPE.AAAA);
444    }
445
446    @SuppressWarnings("unchecked")
447    private <D extends Data> Set<D> getCachedIPNameserverAddressesFor(DnsName dnsName, TYPE type) {
448        Set<NS> nsSet = getCachedNameserverRecordsFor(dnsName);
449        if (nsSet.isEmpty())
450            return Collections.emptySet();
451
452        Set<D> res = new HashSet<>(3 * nsSet.size());
453        for (NS ns : nsSet) {
454            Set<D> addresses;
455            switch (type) {
456            case A:
457                addresses = (Set<D>) getCachedIPv4AddressesFor(ns.target);
458                break;
459            case AAAA:
460                addresses = (Set<D>) getCachedIPv6AddressesFor(ns.target);
461                break;
462            default:
463                throw new AssertionError();
464            }
465            res.addAll(addresses);
466        }
467
468        return res;
469    }
470
471    public Set<A> getCachedIPv4NameserverAddressesFor(DnsName dnsName) {
472        return getCachedIPNameserverAddressesFor(dnsName, TYPE.A);
473    }
474
475    public Set<AAAA> getCachedIPv6NameserverAddressesFor(DnsName dnsName) {
476        return getCachedIPNameserverAddressesFor(dnsName, TYPE.AAAA);
477    }
478}