001/*
002 * Copyright 2015-2020 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
013import java.io.IOException;
014import java.net.InetAddress;
015import java.net.InetSocketAddress;
016import java.nio.ByteBuffer;
017import java.nio.channels.Channel;
018import java.nio.channels.ClosedChannelException;
019import java.nio.channels.DatagramChannel;
020import java.nio.channels.SelectableChannel;
021import java.nio.channels.SelectionKey;
022import java.nio.channels.SocketChannel;
023import java.util.ArrayList;
024import java.util.List;
025import java.util.concurrent.Future;
026import java.util.logging.Level;
027import java.util.logging.Logger;
028
029import org.minidns.MiniDnsException;
030import org.minidns.MiniDnsFuture;
031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture;
032import org.minidns.dnsmessage.DnsMessage;
033import org.minidns.dnsqueryresult.DnsQueryResult;
034import org.minidns.dnsqueryresult.DnsQueryResult.QueryMethod;
035import org.minidns.dnsqueryresult.StandardDnsQueryResult;
036import org.minidns.source.DnsDataSource.OnResponseCallback;
037import org.minidns.source.AbstractDnsDataSource.QueryMode;
038import org.minidns.util.MultipleIoException;
039
040public class AsyncDnsRequest {
041
042    private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName());
043
044    private final InternalMiniDnsFuture<DnsQueryResult, IOException> future = new InternalMiniDnsFuture<DnsQueryResult, IOException>() {
045        @SuppressWarnings("UnsynchronizedOverridesSynchronized")
046        @Override
047        public boolean cancel(boolean mayInterruptIfRunning) {
048            boolean res = super.cancel(mayInterruptIfRunning);
049            cancelAsyncDnsRequest();
050            return res;
051        }
052    };
053
054    private final DnsMessage request;
055
056    private final int udpPayloadSize;
057
058    private final InetSocketAddress socketAddress;
059
060    private final AsyncNetworkDataSource asyncNds;
061
062    private final OnResponseCallback onResponseCallback;
063
064    private final boolean skipUdp;
065
066    private ByteBuffer writeBuffer;
067
068    private List<IOException> exceptions;
069
070    private SelectionKey selectionKey;
071
072    final long deadline;
073
074    /**
075     * Creates a new AsyncDnsRequest instance.
076     *
077     * @param request the DNS message of the request.
078     * @param inetAddress The IP address of the DNS server to ask.
079     * @param port The port of the DNS server to ask.
080     * @param udpPayloadSize The configured UDP payload size.
081     * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests.
082     * @param onResponseCallback the optional callback when a response was received.
083     */
084    AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) {
085        this.request = request;
086        this.udpPayloadSize = udpPayloadSize;
087        this.asyncNds = asyncNds;
088        this.onResponseCallback = onResponseCallback;
089
090        final QueryMode queryMode = asyncNds.getQueryMode();
091        switch (queryMode) {
092        case dontCare:
093        case udpTcp:
094            skipUdp = false;
095            break;
096        case tcp:
097            skipUdp = true;
098            break;
099        default:
100            throw new IllegalStateException("Unsupported query mode: " + queryMode);
101
102        }
103        deadline = System.currentTimeMillis() + asyncNds.getTimeout();
104        socketAddress = new InetSocketAddress(inetAddress, port);
105    }
106
107    private void ensureWriteBufferIsInitialized() {
108        if (writeBuffer != null) {
109            if (!writeBuffer.hasRemaining()) {
110                writeBuffer.rewind();
111            }
112            return;
113        }
114        writeBuffer = request.getInByteBuffer();
115    }
116
117    private synchronized void cancelAsyncDnsRequest() {
118        if (selectionKey != null) {
119            selectionKey.cancel();
120        }
121        asyncNds.cancelled(this);
122    }
123
124    private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler)
125            throws ClosedChannelException {
126        if (future.isCancelled()) {
127            return;
128        }
129        selectionKey = asyncNds.registerWithSelector(channel, ops, handler);
130    }
131
132    private void addException(IOException e) {
133        if (exceptions == null) {
134            exceptions = new ArrayList<>(4);
135        }
136        exceptions.add(e);
137    }
138
139    private void gotResult(DnsQueryResult result) {
140        if (onResponseCallback != null) {
141            onResponseCallback.onResponse(request, result);
142        }
143        asyncNds.finished(this);
144        future.setResult(result);
145    }
146
147    MiniDnsFuture<DnsQueryResult, IOException> getFuture() {
148        return future;
149    }
150
151    boolean wasDeadlineMissedAndFutureNotified() {
152        if (System.currentTimeMillis() < deadline) {
153            return false;
154        }
155
156        future.setException(new IOException("Timeout"));
157        return true;
158    }
159
160    void startHandling() {
161        if (!skipUdp) {
162            startUdpRequest();
163        } else {
164            startTcpRequest();
165        }
166    }
167
168    private void abortRequestAndCleanup(Channel channel, String errorMessage, IOException exception) {
169        if (exception == null) {
170            // TODO: Can this case be removed? Is 'exception' ever null?
171            LOGGER.info("Exception was null in abortRequestAndCleanup()");
172            exception = new IOException(errorMessage);
173        }
174        LOGGER.log(Level.SEVERE, "Error connecting " + channel + ": " + errorMessage, exception);
175        addException(exception);
176
177        if (selectionKey != null) {
178            selectionKey.cancel();
179        }
180
181        if (channel != null && channel.isOpen()) {
182            try {
183                channel.close();
184            } catch (IOException e) {
185                LOGGER.log(Level.SEVERE, "Exception closing socket channel", e);
186                addException(e);
187            }
188        }
189    }
190
191    private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) {
192        abortRequestAndCleanup(datagramChannel, errorMessage, exception);
193        startTcpRequest();
194    }
195
196    private void startUdpRequest() {
197        if (future.isCancelled()) {
198            return;
199        }
200
201        DatagramChannel datagramChannel;
202        try {
203            datagramChannel = DatagramChannel.open();
204        } catch (IOException e) {
205            LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e);
206            addException(e);
207            startTcpRequest();
208            return;
209        }
210
211        try {
212            datagramChannel.configureBlocking(false);
213        } catch (IOException e) {
214            abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e);
215            return;
216        }
217
218        try {
219            datagramChannel.connect(socketAddress);
220        } catch (IOException e) {
221            abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel to " + socketAddress, e);
222            return;
223        }
224
225        try {
226            registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future));
227        } catch (ClosedChannelException e) {
228            abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
229            return;
230        }
231    }
232
233    class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler {
234
235        UdpWritableChannelSelectedHandler(Future<?> future) {
236            super(future);
237        }
238
239        @Override
240        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
241            DatagramChannel datagramChannel = (DatagramChannel) channel;
242
243            ensureWriteBufferIsInitialized();
244
245            try {
246                datagramChannel.write(writeBuffer);
247            } catch (IOException e) {
248                abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e);
249                return;
250            }
251
252            if (writeBuffer.hasRemaining()) {
253                try {
254                    registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this);
255                } catch (ClosedChannelException e) {
256                    abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
257                }
258                return;
259            }
260
261            try {
262                registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future));
263            } catch (ClosedChannelException e) {
264                abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e);
265                return;
266            }
267        }
268
269    }
270
271    class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler {
272
273        UdpReadableChannelSelectedHandler(Future<?> future) {
274            super(future);
275        }
276
277        final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize);
278
279        @Override
280        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
281            DatagramChannel datagramChannel = (DatagramChannel) channel;
282
283            try {
284                datagramChannel.read(byteBuffer);
285            } catch (IOException e) {
286                abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e);
287                return;
288            }
289
290            selectionKey.cancel();
291            try {
292                datagramChannel.close();
293            } catch (IOException e) {
294                LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e);
295                addException(e);
296            }
297
298            DnsMessage response;
299            try {
300                response = new DnsMessage(byteBuffer.array());
301            } catch (IOException e) {
302                abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e);
303                return;
304            }
305
306            if (response.id != request.id) {
307                addException(new MiniDnsException.IdMismatch(request, response));
308                startTcpRequest();
309                return;
310            }
311
312            if (response.truncated) {
313                startTcpRequest();
314                return;
315            }
316
317            DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(),
318                    QueryMethod.asyncUdp, request, response);
319            gotResult(result);
320        }
321    }
322
323    private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) {
324        abortRequestAndCleanup(socketChannel, errorMessage, exception);
325        future.setException(MultipleIoException.toIOException(exceptions));
326    }
327
328    private void startTcpRequest() {
329        SocketChannel socketChannel = null;
330        try {
331            socketChannel = SocketChannel.open();
332        } catch (IOException e) {
333            abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e);
334            return;
335        }
336
337        try {
338            socketChannel.configureBlocking(false);
339        } catch (IOException e) {
340            abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e);
341            return;
342        }
343
344        try {
345            registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future));
346        } catch (ClosedChannelException e) {
347            abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e);
348            return;
349        }
350
351        try {
352            socketChannel.connect(socketAddress);
353        } catch (IOException e) {
354            abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel to " + socketAddress, e);
355            return;
356        }
357    }
358
359    class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler {
360
361        TcpConnectedChannelSelectedHandler(Future<?> future) {
362            super(future);
363        }
364
365        @Override
366        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
367            SocketChannel socketChannel = (SocketChannel) channel;
368
369            boolean connected;
370            try {
371                connected = socketChannel.finishConnect();
372            } catch (IOException e) {
373                abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e);
374                return;
375            }
376
377            assert connected;
378
379            try {
380                registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future));
381            } catch (ClosedChannelException e) {
382                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
383                return;
384            }
385        }
386
387    }
388
389    class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler {
390
391        TcpWritableChannelSelectedHandler(Future<?> future) {
392            super(future);
393        }
394
395        /**
396         * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message.
397         */
398        private ByteBuffer[] writeBuffers;
399
400        @Override
401        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
402            SocketChannel socketChannel = (SocketChannel) channel;
403
404            if (writeBuffers == null) {
405                ensureWriteBufferIsInitialized();
406
407                ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
408                int messageLength = writeBuffer.capacity();
409                assert messageLength <= Short.MAX_VALUE;
410                messageLengthByteBuffer.putShort((short) (messageLength & 0xffff));
411                messageLengthByteBuffer.rewind();
412
413                writeBuffers = new ByteBuffer[2];
414                writeBuffers[0] = messageLengthByteBuffer;
415                writeBuffers[1] = writeBuffer;
416            }
417
418            try {
419                socketChannel.write(writeBuffers);
420            } catch (IOException e) {
421                abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e);
422                return;
423            }
424
425            if (moreToWrite()) {
426                try {
427                    registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this);
428                } catch (ClosedChannelException e) {
429                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
430                }
431                return;
432            }
433
434            try {
435                registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future));
436            } catch (ClosedChannelException e) {
437                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
438                return;
439            }
440        }
441
442        private boolean moreToWrite() {
443            for (int i = 0; i < writeBuffers.length; i++) {
444                if (writeBuffers[i].hasRemaining()) {
445                    return true;
446                }
447            }
448            return false;
449        }
450    }
451
452    class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler {
453
454        TcpReadableChannelSelectedHandler(Future<?> future) {
455            super(future);
456        }
457
458        final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
459
460        ByteBuffer byteBuffer;
461
462        @Override
463        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
464            SocketChannel socketChannel = (SocketChannel) channel;
465
466            int bytesRead;
467            if (byteBuffer == null) {
468                try {
469                    bytesRead = socketChannel.read(messageLengthByteBuffer);
470                } catch (IOException e) {
471                    abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e);
472                    return;
473                }
474
475                if (bytesRead < 0) {
476                    abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
477                    return;
478                }
479
480                if (messageLengthByteBuffer.hasRemaining()) {
481                    try {
482                        registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
483                    } catch (ClosedChannelException e) {
484                        abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
485                    }
486                    return;
487                }
488
489                messageLengthByteBuffer.rewind();
490                short messageLengthSignedShort = messageLengthByteBuffer.getShort();
491                int messageLength = messageLengthSignedShort & 0xffff;
492                byteBuffer = ByteBuffer.allocate(messageLength);
493            }
494
495            try {
496                bytesRead = socketChannel.read(byteBuffer);
497            } catch (IOException e) {
498                throw new Error(e);
499            }
500
501            if (bytesRead < 0) {
502                abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
503                return;
504            }
505
506            if (byteBuffer.hasRemaining()) {
507                try {
508                    registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
509                } catch (ClosedChannelException e) {
510                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
511                }
512                return;
513            }
514
515            selectionKey.cancel();
516            try {
517                socketChannel.close();
518            } catch (IOException e) {
519                addException(e);
520            }
521
522            DnsMessage response;
523            try {
524                response = new DnsMessage(byteBuffer.array());
525            } catch (IOException e) {
526                abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e);
527                return;
528            }
529
530            if (request.id != response.id) {
531                MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response);
532                addException(idMismatchException);
533                AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions));
534                return;
535            }
536
537            DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(),
538                    QueryMethod.asyncTcp, request, response);
539            gotResult(result);
540        }
541
542    }
543
544}