001/*
002 * Copyright 2015-2018 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
013import java.io.IOException;
014import java.net.InetAddress;
015import java.net.InetSocketAddress;
016import java.net.SocketAddress;
017import java.nio.ByteBuffer;
018import java.nio.channels.ClosedChannelException;
019import java.nio.channels.DatagramChannel;
020import java.nio.channels.SelectableChannel;
021import java.nio.channels.SelectionKey;
022import java.nio.channels.SocketChannel;
023import java.util.ArrayList;
024import java.util.List;
025import java.util.concurrent.Future;
026import java.util.logging.Level;
027import java.util.logging.Logger;
028
029import org.minidns.MiniDnsException;
030import org.minidns.MiniDnsFuture;
031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture;
032import org.minidns.dnsmessage.DnsMessage;
033import org.minidns.source.DnsDataSource.OnResponseCallback;
034import org.minidns.source.DnsDataSource.QueryMode;
035import org.minidns.util.MultipleIoException;
036
037public class AsyncDnsRequest {
038
039    private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName());
040
041    private final InternalMiniDnsFuture<DnsMessage, IOException> future = new InternalMiniDnsFuture<DnsMessage, IOException>() {
042        @Override
043        public boolean cancel(boolean mayInterruptIfRunning) {
044            boolean res = super.cancel(mayInterruptIfRunning);
045            cancelAsyncDnsRequest();
046            return res;
047        }
048    };
049
050    private final DnsMessage request;
051
052    private final int udpPayloadSize;
053
054    private final SocketAddress socketAddress;
055
056    private final AsyncNetworkDataSource asyncNds;
057
058    private final OnResponseCallback onResponseCallback;
059
060    private final boolean skipUdp;
061
062    private ByteBuffer writeBuffer;
063
064    private List<IOException> exceptions;
065
066    private SelectionKey selectionKey;
067
068    final long deadline;
069
070    /**
071     * Creates a new AsyncDnsRequest instance.
072     *
073     * @param request the DNS message of the request.
074     * @param inetAddress The IP address of the DNS server to ask.
075     * @param port The port of the DNS server to ask.
076     * @param udpPayloadSize The configured UDP payload size.
077     * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests.
078     * @param onResponseCallback the optional callback when a response was received.
079     */
080    AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) {
081        this.request = request;
082        this.udpPayloadSize = udpPayloadSize;
083        this.asyncNds = asyncNds;
084        this.onResponseCallback = onResponseCallback;
085
086        final QueryMode queryMode = asyncNds.getQueryMode();
087        switch (queryMode) {
088        case dontCare:
089        case udpTcp:
090            skipUdp = false;
091            break;
092        case tcp:
093            skipUdp = true;
094            break;
095        default:
096            throw new IllegalStateException("Unsupported query mode: " + queryMode);
097
098        }
099        deadline = System.currentTimeMillis() + asyncNds.getTimeout();
100        socketAddress = new InetSocketAddress(inetAddress, port);
101    }
102
103    private void ensureWriteBufferIsInitialized() {
104        if (writeBuffer != null) {
105            if (!writeBuffer.hasRemaining()) {
106                writeBuffer.rewind();
107            }
108            return;
109        }
110        writeBuffer = request.getInByteBuffer();
111    }
112
113    private synchronized void cancelAsyncDnsRequest() {
114        if (selectionKey != null) {
115            selectionKey.cancel();
116        }
117        asyncNds.cancelled(this);
118    }
119
120    private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler)
121            throws ClosedChannelException {
122        if (future.isCancelled()) {
123            return;
124        }
125        selectionKey = asyncNds.registerWithSelector(channel, ops, handler);
126    }
127
128    private void addException(IOException e) {
129        if (exceptions == null) {
130            exceptions = new ArrayList<>(4);
131        }
132        exceptions.add(e);
133    }
134
135    private final void gotResult(DnsMessage result) {
136        if (onResponseCallback != null) {
137            onResponseCallback.onResponse(request, result);
138        }
139        asyncNds.finished(this);
140        future.setResult(result);
141    }
142
143    MiniDnsFuture<DnsMessage, IOException> getFuture() {
144        return future;
145    }
146
147    boolean wasDeadlineMissedAndFutureNotified() {
148        if (System.currentTimeMillis() < deadline) {
149            return false;
150        }
151
152        future.setException(new IOException("Timeout"));
153        return true;
154    }
155
156    void startHandling() {
157        if (!skipUdp) {
158            startUdpRequest();
159        } else {
160            startTcpRequest();
161        }
162    }
163
164    private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) {
165        LOGGER.log(Level.SEVERE, errorMessage, exception);
166        addException(exception);
167
168        if (selectionKey != null) {
169            selectionKey.cancel();
170        }
171
172        try {
173            datagramChannel.close();
174        } catch (IOException e) {
175            LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e);
176            addException(e);
177        }
178
179        startTcpRequest();
180    }
181
182    private void startUdpRequest() {
183        if (future.isCancelled()) {
184            return;
185        }
186
187        DatagramChannel datagramChannel;
188        try {
189            datagramChannel = DatagramChannel.open();
190        } catch (IOException e) {
191            LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e);
192            addException(e);
193            startTcpRequest();
194            return;
195        }
196
197        try {
198            datagramChannel.configureBlocking(false);
199        } catch (IOException e) {
200            abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e);
201            return;
202        }
203
204        try {
205            datagramChannel.connect(socketAddress);
206        } catch (IOException e) {
207            abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel", e);
208            return;
209        }
210
211        try {
212            registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future));
213        } catch (ClosedChannelException e) {
214            abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
215            return;
216        }
217    }
218
219    class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler {
220
221        UdpWritableChannelSelectedHandler(Future<?> future) {
222            super(future);
223        }
224
225        @Override
226        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
227            DatagramChannel datagramChannel = (DatagramChannel) channel;
228
229            ensureWriteBufferIsInitialized();
230
231            try {
232                datagramChannel.write(writeBuffer);
233            } catch (IOException e) {
234                abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e);
235                return;
236            }
237
238            if (writeBuffer.hasRemaining()) {
239                try {
240                    registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this);
241                } catch (ClosedChannelException e) {
242                    abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e);
243                }
244                return;
245            }
246
247            try {
248                registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future));
249            } catch (ClosedChannelException e) {
250                abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e);
251                return;
252            }
253        }
254
255    }
256
257    class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler {
258
259        UdpReadableChannelSelectedHandler(Future<?> future) {
260            super(future);
261        }
262
263        final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize);
264
265        @Override
266        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
267            DatagramChannel datagramChannel = (DatagramChannel) channel;
268
269            try {
270                datagramChannel.read(byteBuffer);
271            } catch (IOException e) {
272                abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e);
273                return;
274            }
275
276            selectionKey.cancel();
277            try {
278                datagramChannel.close();
279            } catch (IOException e) {
280                LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e);
281                addException(e);
282            }
283
284            DnsMessage response;
285            try {
286                response = new DnsMessage(byteBuffer.array());
287            } catch (IOException e) {
288                abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e);
289                return;
290            }
291
292            if (response.id != request.id) {
293                addException(new MiniDnsException.IdMismatch(request, response));
294                startTcpRequest();
295                return;
296            }
297
298            if (response.truncated) {
299                startTcpRequest();
300                return;
301            }
302
303            gotResult(response);
304        }
305    }
306
307    private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) {
308        if (exception == null) {
309            exception = new IOException(errorMessage);
310        }
311        LOGGER.log(Level.SEVERE, errorMessage, exception);
312        addException(exception);
313
314        if (selectionKey != null) {
315            selectionKey.cancel();
316        }
317
318        if (socketChannel != null && socketChannel.isOpen()) {
319            try {
320                socketChannel.close();
321            } catch (IOException e) {
322                LOGGER.log(Level.SEVERE, "Exception closing socket channel", e);
323                addException(e);
324            }
325        }
326
327        future.setException(MultipleIoException.toIOException(exceptions));
328    }
329
330    private void startTcpRequest() {
331        SocketChannel socketChannel = null;
332        try {
333            socketChannel = SocketChannel.open();
334        } catch (IOException e) {
335            abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e);
336            return;
337        }
338
339        try {
340            socketChannel.configureBlocking(false);
341        } catch (IOException e) {
342            abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e);
343            return;
344        }
345
346        try {
347            registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future));
348        } catch (ClosedChannelException e) {
349            abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e);
350            return;
351        }
352
353        try {
354            socketChannel.connect(socketAddress);
355        } catch (IOException e) {
356            abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel", e);
357            return;
358        }
359    }
360
361    class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler {
362
363        TcpConnectedChannelSelectedHandler(Future<?> future) {
364            super(future);
365        }
366
367        @Override
368        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
369            SocketChannel socketChannel = (SocketChannel) channel;
370
371            boolean connected;
372            try {
373                connected = socketChannel.finishConnect();
374            } catch (IOException e) {
375                abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e);
376                return;
377            }
378
379            assert connected;
380
381            try {
382                registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future));
383            } catch (ClosedChannelException e) {
384                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
385                return;
386            }
387        }
388
389    }
390
391    class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler {
392
393        TcpWritableChannelSelectedHandler(Future<?> future) {
394            super(future);
395        }
396
397        /**
398         * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message.
399         */
400        private ByteBuffer[] writeBuffers;
401
402        @Override
403        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
404            SocketChannel socketChannel = (SocketChannel) channel;
405
406            if (writeBuffers == null) {
407                ensureWriteBufferIsInitialized();
408
409                ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
410                int messageLength = writeBuffer.capacity();
411                assert messageLength <= Short.MAX_VALUE;
412                messageLengthByteBuffer.putShort((short) (messageLength & 0xffff));
413                messageLengthByteBuffer.rewind();
414
415                writeBuffers = new ByteBuffer[2];
416                writeBuffers[0] = messageLengthByteBuffer;
417                writeBuffers[1] = writeBuffer;
418            }
419
420            try {
421                socketChannel.write(writeBuffers);
422            } catch (IOException e) {
423                abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e);
424                return;
425            }
426
427            if (moreToWrite()) {
428                try {
429                    registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this);
430                } catch (ClosedChannelException e) {
431                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e);
432                }
433                return;
434            }
435
436            try {
437                registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future));
438            } catch (ClosedChannelException e) {
439                abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
440                return;
441            }
442        }
443
444        private boolean moreToWrite() {
445            for (int i = 0; i < writeBuffers.length; i++) {
446                if (writeBuffers[i].hasRemaining()) {
447                    return true;
448                }
449            }
450            return false;
451        }
452    }
453
454    class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler {
455
456        TcpReadableChannelSelectedHandler(Future<?> future) {
457            super(future);
458        }
459
460        final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2);
461
462        ByteBuffer byteBuffer;
463
464        @Override
465        public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) {
466            SocketChannel socketChannel = (SocketChannel) channel;
467
468            int bytesRead;
469            if (byteBuffer == null) {
470                try {
471                    bytesRead = socketChannel.read(messageLengthByteBuffer);
472                } catch (IOException e) {
473                    abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e);
474                    return;
475                }
476
477                if (bytesRead < 0) {
478                    abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
479                    return;
480                }
481
482                if (messageLengthByteBuffer.hasRemaining()) {
483                    try {
484                        registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
485                    } catch (ClosedChannelException e) {
486                        abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
487                    }
488                    return;
489                }
490
491                messageLengthByteBuffer.rewind();
492                short messageLengthSignedShort = messageLengthByteBuffer.getShort();
493                int messageLength = messageLengthSignedShort & 0xffff;
494                byteBuffer = ByteBuffer.allocate(messageLength);
495            }
496
497            try {
498                bytesRead = socketChannel.read(byteBuffer);
499            } catch (IOException e) {
500                throw new Error(e);
501            }
502
503            if (bytesRead < 0) {
504                abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null);
505                return;
506            }
507
508            if (byteBuffer.hasRemaining()) {
509                try {
510                    registerWithSelector(socketChannel, SelectionKey.OP_READ, this);
511                } catch (ClosedChannelException e) {
512                    abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e);
513                }
514                return;
515            }
516
517            selectionKey.cancel();
518            try {
519                socketChannel.close();
520            } catch (IOException e) {
521                addException(e);
522            }
523
524            DnsMessage response;
525            try {
526                response = new DnsMessage(byteBuffer.array());
527            } catch (IOException e) {
528                abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e);
529                return;
530            }
531
532            if (request.id != response.id) {
533                MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response);
534                addException(idMismatchException);
535                AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions));
536                return;
537            }
538
539            gotResult(response);
540        }
541
542    }
543
544}