001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
013import java.io.IOException;
014import java.net.InetAddress;
015import java.net.InetSocketAddress;
016import java.nio.ByteBuffer;
017import java.nio.channels.Channel;
018import java.nio.channels.ClosedChannelException;
019import java.nio.channels.DatagramChannel;
020import java.nio.channels.SelectableChannel;
021import java.nio.channels.SelectionKey;
022import java.nio.channels.SocketChannel;
023import java.util.ArrayList;
024import java.util.List;
025import java.util.concurrent.Future;
026import java.util.logging.Level;
027import java.util.logging.Logger;
028
029import org.minidns.MiniDnsException;
030import org.minidns.MiniDnsFuture;
031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture;
032import org.minidns.dnsmessage.DnsMessage;
033import org.minidns.dnsqueryresult.DnsQueryResult;
034import org.minidns.dnsqueryresult.DnsQueryResult.QueryMethod;
035import org.minidns.dnsqueryresult.StandardDnsQueryResult;
036import org.minidns.source.DnsDataSource.OnResponseCallback;
037import org.minidns.source.AbstractDnsDataSource.QueryMode;
038import org.minidns.util.MultipleIoException;
039
040/**
041 * A DNS request that is performed asynchronously.
042 */
043public class AsyncDnsRequest {
044
045    private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName());
046
047    private final InternalMiniDnsFuture<DnsQueryResult, IOException> future = new InternalMiniDnsFuture<DnsQueryResult, IOException>() {
048        @SuppressWarnings("UnsynchronizedOverridesSynchronized")
049        @Override
050        public boolean cancel(boolean mayInterruptIfRunning) {
051            boolean res = super.cancel(mayInterruptIfRunning);
052            cancelAsyncDnsRequest();
053            return res;
054        }
055    };
056
057    private final DnsMessage request;
058
059    private final int udpPayloadSize;
060
061    private final InetSocketAddress socketAddress;
062
063    private final AsyncNetworkDataSource asyncNds;
064
065    private final OnResponseCallback onResponseCallback;
066
067    private final boolean skipUdp;
068
069    private ByteBuffer writeBuffer;
070
071    private List<IOException> exceptions;
072
073    private SelectionKey selectionKey;
074
075    final long deadline;
076
077    /**
078     * Creates a new AsyncDnsRequest instance.
079     *
080     * @param request the DNS message of the request.
081     * @param inetAddress The IP address of the DNS server to ask.
082     * @param port The port of the DNS server to ask.
083     * @param udpPayloadSize The configured UDP payload size.
084     * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests.
085     * @param onResponseCallback the optional callback when a response was received.
086     */
087    AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) {
088        this.request = request;
089        this.udpPayloadSize = udpPayloadSize;
090        this.asyncNds = asyncNds;
091        this.onResponseCallback = onResponseCallback;
092
093        final QueryMode queryMode = asyncNds.getQueryMode();
094        switch (queryMode) {
095        case dontCare:
096        case udpTcp:
097            skipUdp = false;
098            break;
099        case tcp:
100            skipUdp = true;
101            break;
102        default:
103            throw new IllegalStateException("Unsupported query mode: " + queryMode);
104
105        }
106        deadline = System.currentTimeMillis() + asyncNds.getTimeout();
107        socketAddress = new InetSocketAddress(inetAddress, port);
108    }
109
110    private void ensureWriteBufferIsInitialized() {
111        if (writeBuffer != null) {
112            if (!writeBuffer.hasRemaining()) {
113                writeBuffer.rewind();
114            }
115            return;
116        }
117        writeBuffer = request.getInByteBuffer();
118    }
119
120    private synchronized void cancelAsyncDnsRequest() {
121        if (selectionKey != null) {
122            selectionKey.cancel();
123        }
124        asyncNds.cancelled(this);
125    }
126
127    private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler)
128            throws ClosedChannelException {
129        if (future.isCancelled()) {
130            return;
131        }
132        selectionKey = asyncNds.registerWithSelector(channel, ops, handler);
133    }
134
135    private void addException(IOException e) {
136        if (exceptions == null) {
137            exceptions = new ArrayList<>(4);
138        }
139        exceptions.add(e);
140    }
141
142    private void gotResult(DnsQueryResult result) {
143        if (onResponseCallback != null) {
144            onResponseCallback.onResponse(request, result);
145        }
146        asyncNds.finished(this);
147        future.setResult(result);
148    }
149
150    MiniDnsFuture<DnsQueryResult, IOException> getFuture() {
151        return future;
152    }
153
154    boolean wasDeadlineMissedAndFutureNotified() {
155        if (System.currentTimeMillis() < deadline) {
156            return false;
157        }
158
159        future.setException(new IOException("Timeout"));
160        return true;
161    }
162
163    void startHandling() {
164        if (!skipUdp) {
165            startUdpRequest();
166        } else {
167            startTcpRequest();
168        }
169    }
170
171    private void abortRequestAndCleanup(Channel channel, String errorMessage, IOException exception) {
172        if (exception == null) {
173            // TODO: Can this case be removed? Is 'exception' ever null?
174            LOGGER.info("Exception was null in abortRequestAndCleanup()");
175            exception = new IOException(errorMessage);
176        }
177        LOGGER.log(Level.SEVERE, "Error connecting " + channel + ": " + errorMessage, exception);
178        addException(exception);
179
180        if (selectionKey != null) {
181            selectionKey.cancel();
182        }
183
184        if (channel != null && channel.isOpen()) {
185            try {
186                channel.close();
187            } catch (IOException e) {
188                LOGGER.log(Level.SEVERE, "Exception closing socket channel", e);
189                addException(e);
190            }
191        }
192    }
193
194    private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) {
195        abortRequestAndCleanup(datagramChannel, errorMessage, exception);
196        startTcpRequest();
197    }
198
199    private void startUdpRequest() {
200        if (future.isCancelled()) {
201            return;
202        }
203
204        DatagramChannel datagramChannel;
205        try {
206            datagramChannel = DatagramChannel.open();
207        } catch (IOException e) {
208            LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e);
209            addException(e);
210            startTcpRequest();
211            return;
212        }
213
214        try {
215            datagramChannel.configureBlocking(false);
216        } catch (IOException e) {
217            abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e);
218            return;
219        }
220
221        try {
222            datagramChannel.connect(socketAddress);
223        } catch (IOException e) {
224            abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel to " + socketAddress, e);
225            return;
226        }
227
228        try {
229            registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future));
230        } catch (ClosedChannelException e) {
231            abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
232            return;
233        }
234    }
235
236    class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler {
237
238        UdpWritableChannelSelectedHandler(Future<?> future) {
239            super(future);
240        }
241
242        @Override
243        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
244            DatagramChannel datagramChannel = (DatagramChannel) channel;
245
246            ensureWriteBufferIsInitialized();
247
248            try {
249                datagramChannel.write(writeBuffer);
250            } catch (IOException e) {
251                abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e);
252                return;
253            }
254
255            if (writeBuffer.hasRemaining()) {
256                try {
257                    registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this);
258                } catch (ClosedChannelException e) {
259                    abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
260                }
261                return;
262            }
263
264            try {
265                registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future));
266            } catch (ClosedChannelException e) {
267                abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e);
268                return;
269            }
270        }
271
272    }
273
274    class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler {
275
276        UdpReadableChannelSelectedHandler(Future<?> future) {
277            super(future);
278        }
279
280        final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize);
281
282        @Override
283        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
284            DatagramChannel datagramChannel = (DatagramChannel) channel;
285
286            try {
287                datagramChannel.read(byteBuffer);
288            } catch (IOException e) {
289                abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e);
290                return;
291            }
292
293            selectionKey.cancel();
294            try {
295                datagramChannel.close();
296            } catch (IOException e) {
297                LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e);
298                addException(e);
299            }
300
301            DnsMessage response;
302            try {
303                response = new DnsMessage(byteBuffer.array());
304            } catch (IOException e) {
305                abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e);
306                return;
307            }
308
309            if (response.id != request.id) {
310                addException(new MiniDnsException.IdMismatch(request, response));
311                startTcpRequest();
312                return;
313            }
314
315            if (response.truncated) {
316                startTcpRequest();
317                return;
318            }
319
320            DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(),
321                    QueryMethod.asyncUdp, request, response);
322            gotResult(result);
323        }
324    }
325
326    private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) {
327        abortRequestAndCleanup(socketChannel, errorMessage, exception);
328        future.setException(MultipleIoException.toIOException(exceptions));
329    }
330
331    private void startTcpRequest() {
332        SocketChannel socketChannel = null;
333        try {
334            socketChannel = SocketChannel.open();
335        } catch (IOException e) {
336            abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e);
337            return;
338        }
339
340        try {
341            socketChannel.configureBlocking(false);
342        } catch (IOException e) {
343            abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e);
344            return;
345        }
346
347        try {
348            registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future));
349        } catch (ClosedChannelException e) {
350            abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e);
351            return;
352        }
353
354        try {
355            socketChannel.connect(socketAddress);
356        } catch (IOException e) {
357            abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel to " + socketAddress, e);
358            return;
359        }
360    }
361
362    class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler {
363
364        TcpConnectedChannelSelectedHandler(Future<?> future) {
365            super(future);
366        }
367
368        @Override
369        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
370            SocketChannel socketChannel = (SocketChannel) channel;
371
372            boolean connected;
373            try {
374                connected = socketChannel.finishConnect();
375            } catch (IOException e) {
376                abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e);
377                return;
378            }
379
380            assert connected;
381
382            try {
383                registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future));
384            } catch (ClosedChannelException e) {
385                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
386                return;
387            }
388        }
389
390    }
391
392    class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler {
393
394        TcpWritableChannelSelectedHandler(Future<?> future) {
395            super(future);
396        }
397
398        /**
399         * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message.
400         */
401        private ByteBuffer[] writeBuffers;
402
403        @Override
404        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
405            SocketChannel socketChannel = (SocketChannel) channel;
406
407            if (writeBuffers == null) {
408                ensureWriteBufferIsInitialized();
409
410                ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
411                int messageLength = writeBuffer.capacity();
412                assert messageLength <= Short.MAX_VALUE;
413                messageLengthByteBuffer.putShort((short) (messageLength & 0xffff));
414                messageLengthByteBuffer.rewind();
415
416                writeBuffers = new ByteBuffer[2];
417                writeBuffers[0] = messageLengthByteBuffer;
418                writeBuffers[1] = writeBuffer;
419            }
420
421            try {
422                socketChannel.write(writeBuffers);
423            } catch (IOException e) {
424                abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e);
425                return;
426            }
427
428            if (moreToWrite()) {
429                try {
430                    registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this);
431                } catch (ClosedChannelException e) {
432                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
433                }
434                return;
435            }
436
437            try {
438                registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future));
439            } catch (ClosedChannelException e) {
440                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
441                return;
442            }
443        }
444
445        private boolean moreToWrite() {
446            for (int i = 0; i < writeBuffers.length; i++) {
447                if (writeBuffers[i].hasRemaining()) {
448                    return true;
449                }
450            }
451            return false;
452        }
453    }
454
455    class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler {
456
457        TcpReadableChannelSelectedHandler(Future<?> future) {
458            super(future);
459        }
460
461        final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
462
463        ByteBuffer byteBuffer;
464
465        @Override
466        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
467            SocketChannel socketChannel = (SocketChannel) channel;
468
469            int bytesRead;
470            if (byteBuffer == null) {
471                try {
472                    bytesRead = socketChannel.read(messageLengthByteBuffer);
473                } catch (IOException e) {
474                    abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e);
475                    return;
476                }
477
478                if (bytesRead < 0) {
479                    abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
480                    return;
481                }
482
483                if (messageLengthByteBuffer.hasRemaining()) {
484                    try {
485                        registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
486                    } catch (ClosedChannelException e) {
487                        abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
488                    }
489                    return;
490                }
491
492                messageLengthByteBuffer.rewind();
493                short messageLengthSignedShort = messageLengthByteBuffer.getShort();
494                int messageLength = messageLengthSignedShort & 0xffff;
495                byteBuffer = ByteBuffer.allocate(messageLength);
496            }
497
498            try {
499                bytesRead = socketChannel.read(byteBuffer);
500            } catch (IOException e) {
501                throw new Error(e);
502            }
503
504            if (bytesRead < 0) {
505                abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
506                return;
507            }
508
509            if (byteBuffer.hasRemaining()) {
510                try {
511                    registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
512                } catch (ClosedChannelException e) {
513                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
514                }
515                return;
516            }
517
518            selectionKey.cancel();
519            try {
520                socketChannel.close();
521            } catch (IOException e) {
522                addException(e);
523            }
524
525            DnsMessage response;
526            try {
527                response = new DnsMessage(byteBuffer.array());
528            } catch (IOException e) {
529                abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e);
530                return;
531            }
532
533            if (request.id != response.id) {
534                MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response);
535                addException(idMismatchException);
536                AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions));
537                return;
538            }
539
540            DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(),
541                    QueryMethod.asyncTcp, request, response);
542            gotResult(result);
543        }
544
545    }
546
547}