001/*
002 * Copyright 2015-2018 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.InetAddress;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.minidns.MiniDnsFuture;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.source.DnsDataSource;
038
039public class AsyncNetworkDataSource extends DnsDataSource {
040
041    protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName());
042
043    private static final int REACTOR_THREAD_COUNT = 1;
044
045    private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>();
046
047    private static final Selector SELECTOR;
048
049    private static final Lock REGISTRATION_LOCK = new ReentrantLock();
050
051    private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>();
052
053    private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT];
054
055    private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() {
056        @Override
057        public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) {
058            if (o1.deadline > o2.deadline) {
059                return 1;
060            } else if (o1.deadline < o2.deadline) {
061                return -1;
062            }
063            return 0;
064        }
065    });
066
067    static {
068        try {
069            SELECTOR = Selector.open();
070        } catch (IOException e) {
071            throw new IllegalStateException(e);
072        }
073
074        for (int i = 0; i < REACTOR_THREAD_COUNT; i++) {
075            Thread reactorThread = new Thread(new Reactor());
076            reactorThread.setDaemon(true);
077            reactorThread.setName("MiniDNS Reactor Thread #" + i);
078            reactorThread.start();
079            REACTOR_THREADS[i] = reactorThread;
080        }
081    }
082
083    @Override
084    public MiniDnsFuture<DnsMessage, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) {
085        AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback);
086        INCOMING_REQUESTS.add(asyncDnsRequest);
087        synchronized (DEADLINE_QUEUE) {
088            DEADLINE_QUEUE.add(asyncDnsRequest);
089        }
090        SELECTOR.wakeup();
091        return asyncDnsRequest.getFuture();
092    }
093
094    @Override
095    public DnsMessage query(DnsMessage message, InetAddress address, int port) throws IOException {
096        MiniDnsFuture<DnsMessage, IOException> future = queryAsync(message, address, port, null);
097        try {
098            return future.get();
099        } catch (InterruptedException e) {
100            // This should never happen.
101            throw new AssertionError(e);
102        } catch (ExecutionException e) {
103            Throwable wrappedThrowable = e.getCause();
104            if (wrappedThrowable instanceof IOException) {
105                throw (IOException) wrappedThrowable;
106            }
107            // This should never happen.
108            throw new AssertionError(e);
109        }
110    }
111
112    SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException {
113        REGISTRATION_LOCK.lock();
114        try {
115            SELECTOR.wakeup();
116            return channel.register(SELECTOR, ops, attachment);
117        } finally {
118            REGISTRATION_LOCK.unlock();
119        }
120    }
121
122    void finished(AsyncDnsRequest asyncDnsRequest) {
123        synchronized (DEADLINE_QUEUE) {
124            DEADLINE_QUEUE.remove(asyncDnsRequest);
125        }
126    }
127
128    void cancelled(AsyncDnsRequest asyncDnsRequest) {
129        finished(asyncDnsRequest);
130        // Wakeup since the async DNS request was removed from the deadline queue.
131        SELECTOR.wakeup();
132    }
133
134    private static class Reactor implements Runnable {
135        @Override
136        public void run() {
137            while (!Thread.interrupted()) {
138                Collection<SelectionKey> mySelectedKeys = performSelect();
139                handleSelectedKeys(mySelectedKeys);
140
141                handlePendingSelectionKeys();
142
143                handleIncomingRequests();
144            }
145        }
146
147        private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) {
148            for (SelectionKey selectionKey : selectedKeys) {
149                ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment();
150                SelectableChannel channel = selectionKey.channel();
151                channelSelectedHandler.handleChannelSelected(channel, selectionKey);
152            }
153        }
154
155        private static Collection<SelectionKey> performSelect() {
156            AsyncDnsRequest nearestDeadline = null;
157            AsyncDnsRequest nextInQueue;
158
159            synchronized (DEADLINE_QUEUE) {
160                while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) {
161                    if (nextInQueue.wasDeadlineMissedAndFutureNotified()) {
162                        // We notified the future, associated with the AsyncDnsRequest nearestDeadline,
163                        // that the deadline has passed, hence remove it from the queue.
164                        DEADLINE_QUEUE.poll();
165                    } else {
166                        // We found a nearest deadline that has not yet passed, break out of the loop.
167                        nearestDeadline = nextInQueue;
168                        break;
169                    }
170                }
171
172            }
173
174            long selectWait;
175            if (nearestDeadline == null) {
176                // There is no deadline, wait indefinitely in select().
177                selectWait = 0;
178            } else {
179                // There is a deadline in the future, only block in select() until the deadline.
180                selectWait = nextInQueue.deadline - System.currentTimeMillis();
181                if (selectWait < 0) {
182                    // We already have a missed deadline. Do not call select() and handle the tasks which are past their
183                    // deadline.
184                    return Collections.emptyList();
185                }
186            }
187
188            List<SelectionKey> selectedKeys;
189            int newSelectedKeysCount;
190            synchronized (SELECTOR) {
191                // Ensure that a wakeup() in registerWithSelector() gives the corresponding
192                // register() in the same method the chance to actually register the channel. In
193                // other words: This construct ensure that there is never another select()
194                // between a corresponding wakeup() and register() calls.
195                // See also https://stackoverflow.com/a/1112809/194894
196                REGISTRATION_LOCK.lock();
197                REGISTRATION_LOCK.unlock();
198
199                try {
200                    newSelectedKeysCount = SELECTOR.select(selectWait);
201                } catch (IOException e) {
202                    LOGGER.log(Level.WARNING, "IOException while using select()", e);
203                    return Collections.emptyList();
204                }
205
206                if (newSelectedKeysCount == 0) {
207                    return Collections.emptyList();
208                }
209
210                Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys();
211                for (SelectionKey selectionKey : selectedKeySet) {
212                    selectionKey.interestOps(0);
213                }
214
215                selectedKeys = new ArrayList<>(selectedKeySet.size());
216                selectedKeys.addAll(selectedKeySet);
217                selectedKeySet.clear();
218            }
219
220            int selectedKeysCount = selectedKeys.size();
221
222            final Level LOG_LEVEL = Level.FINER;
223            if (LOGGER.isLoggable(LOG_LEVEL)) {
224                LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count "
225                        + selectedKeysCount);
226            }
227
228            int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT;
229            Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount);
230            Iterator<SelectionKey> it = selectedKeys.iterator();
231            for (int i = 0; i < myKeyCount; i++) {
232                SelectionKey selectionKey = it.next();
233                mySelectedKeys.add(selectionKey);
234            }
235            while (it.hasNext()) {
236                // Drain to PENDING_SELECTION_KEYS
237                SelectionKey selectionKey = it.next();
238                PENDING_SELECTION_KEYS.add(selectionKey);
239            }
240            return mySelectedKeys;
241        }
242
243        private static void handlePendingSelectionKeys() {
244            int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size();
245            if (pendingSelectionKeysSize == 0) {
246                return;
247            }
248
249            int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT;
250            Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount);
251            for (int i = 0; i < myKeyCount; i++) {
252                SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll();
253                if (selectionKey == null) {
254                    // We lost a race :)
255                    break;
256                }
257                selectedKeys.add(selectionKey);
258            }
259
260            if (!PENDING_SELECTION_KEYS.isEmpty()) {
261                // There is more work in the pending selection keys queue, wakeup another thread to handle it.
262                SELECTOR.wakeup();
263            }
264
265            handleSelectedKeys(selectedKeys);
266        }
267
268        private static void handleIncomingRequests() {
269            int incomingRequestsSize = INCOMING_REQUESTS.size();
270            if (incomingRequestsSize == 0) {
271                return;
272            }
273
274            int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT;
275            Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount);
276            for (int i = 0; i < myRequestsCount; i++) {
277                AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll();
278                if (asyncDnsRequest == null) {
279                    // We lost a race :)
280                    break;
281                }
282                requests.add(asyncDnsRequest);
283            }
284
285            if (!INCOMING_REQUESTS.isEmpty()) {
286                SELECTOR.wakeup();
287            }
288
289            for (AsyncDnsRequest asyncDnsRequest : requests) {
290                asyncDnsRequest.startHandling();
291            }
292        }
293
294    }
295
296}