001/*
002 * Copyright 2015-2020 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.InetAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
034
035import org.minidns.MiniDnsFuture;
036import org.minidns.dnsmessage.DnsMessage;
037import org.minidns.dnsqueryresult.DnsQueryResult;
038import org.minidns.source.AbstractDnsDataSource;
039
040public class AsyncNetworkDataSource extends AbstractDnsDataSource {
041
042    protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName());
043
044    private static final int REACTOR_THREAD_COUNT = 1;
045
046    private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>();
047
048    private static final Selector SELECTOR;
049
050    private static final Lock REGISTRATION_LOCK = new ReentrantLock();
051
052    private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>();
053
054    private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT];
055
056    private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() {
057        @Override
058        public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) {
059            if (o1.deadline > o2.deadline) {
060                return 1;
061            } else if (o1.deadline < o2.deadline) {
062                return -1;
063            }
064            return 0;
065        }
066    });
067
068    static {
069        try {
070            SELECTOR = Selector.open();
071        } catch (IOException e) {
072            throw new IllegalStateException(e);
073        }
074
075        for (int i = 0; i < REACTOR_THREAD_COUNT; i++) {
076            Thread reactorThread = new Thread(new Reactor());
077            reactorThread.setDaemon(true);
078            reactorThread.setName("MiniDNS Reactor Thread #" + i);
079            reactorThread.start();
080            REACTOR_THREADS[i] = reactorThread;
081        }
082    }
083
084    @Override
085    public MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) {
086        AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback);
087        INCOMING_REQUESTS.add(asyncDnsRequest);
088        synchronized (DEADLINE_QUEUE) {
089            DEADLINE_QUEUE.add(asyncDnsRequest);
090        }
091        SELECTOR.wakeup();
092        return asyncDnsRequest.getFuture();
093    }
094
095    @Override
096    public DnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException {
097        MiniDnsFuture<DnsQueryResult, IOException> future = queryAsync(message, address, port, null);
098        try {
099            return future.get();
100        } catch (InterruptedException e) {
101            // This should never happen.
102            throw new AssertionError(e);
103        } catch (ExecutionException e) {
104            Throwable wrappedThrowable = e.getCause();
105            if (wrappedThrowable instanceof IOException) {
106                throw (IOException) wrappedThrowable;
107            }
108            // This should never happen.
109            throw new AssertionError(e);
110        }
111    }
112
113    SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException {
114        REGISTRATION_LOCK.lock();
115        try {
116            SELECTOR.wakeup();
117            return channel.register(SELECTOR, ops, attachment);
118        } finally {
119            REGISTRATION_LOCK.unlock();
120        }
121    }
122
123    void finished(AsyncDnsRequest asyncDnsRequest) {
124        synchronized (DEADLINE_QUEUE) {
125            DEADLINE_QUEUE.remove(asyncDnsRequest);
126        }
127    }
128
129    void cancelled(AsyncDnsRequest asyncDnsRequest) {
130        finished(asyncDnsRequest);
131        // Wakeup since the async DNS request was removed from the deadline queue.
132        SELECTOR.wakeup();
133    }
134
135    private static class Reactor implements Runnable {
136        @Override
137        public void run() {
138            while (!Thread.interrupted()) {
139                Collection<SelectionKey> mySelectedKeys = performSelect();
140                handleSelectedKeys(mySelectedKeys);
141
142                handlePendingSelectionKeys();
143
144                handleIncomingRequests();
145            }
146        }
147
148        private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) {
149            for (SelectionKey selectionKey : selectedKeys) {
150                ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment();
151                SelectableChannel channel = selectionKey.channel();
152                channelSelectedHandler.handleChannelSelected(channel, selectionKey);
153            }
154        }
155
156        @SuppressWarnings("LockNotBeforeTry")
157        private static Collection<SelectionKey> performSelect() {
158            AsyncDnsRequest nearestDeadline = null;
159            AsyncDnsRequest nextInQueue;
160
161            synchronized (DEADLINE_QUEUE) {
162                while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) {
163                    if (nextInQueue.wasDeadlineMissedAndFutureNotified()) {
164                        // We notified the future, associated with the AsyncDnsRequest nearestDeadline,
165                        // that the deadline has passed, hence remove it from the queue.
166                        DEADLINE_QUEUE.poll();
167                    } else {
168                        // We found a nearest deadline that has not yet passed, break out of the loop.
169                        nearestDeadline = nextInQueue;
170                        break;
171                    }
172                }
173
174            }
175
176            long selectWait;
177            if (nearestDeadline == null) {
178                // There is no deadline, wait indefinitely in select().
179                selectWait = 0;
180            } else {
181                // There is a deadline in the future, only block in select() until the deadline.
182                selectWait = nextInQueue.deadline - System.currentTimeMillis();
183                if (selectWait < 0) {
184                    // We already have a missed deadline. Do not call select() and handle the tasks which are past their
185                    // deadline.
186                    return Collections.emptyList();
187                }
188            }
189
190            List<SelectionKey> selectedKeys;
191            int newSelectedKeysCount;
192            synchronized (SELECTOR) {
193                // Ensure that a wakeup() in registerWithSelector() gives the corresponding
194                // register() in the same method the chance to actually register the channel. In
195                // other words: This construct ensure that there is never another select()
196                // between a corresponding wakeup() and register() calls.
197                // See also https://stackoverflow.com/a/1112809/194894
198                REGISTRATION_LOCK.lock();
199                REGISTRATION_LOCK.unlock();
200
201                try {
202                    newSelectedKeysCount = SELECTOR.select(selectWait);
203                } catch (IOException e) {
204                    LOGGER.log(Level.WARNING, "IOException while using select()", e);
205                    return Collections.emptyList();
206                }
207
208                if (newSelectedKeysCount == 0) {
209                    return Collections.emptyList();
210                }
211
212                Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys();
213                for (SelectionKey selectionKey : selectedKeySet) {
214                    selectionKey.interestOps(0);
215                }
216
217                selectedKeys = new ArrayList<>(selectedKeySet.size());
218                selectedKeys.addAll(selectedKeySet);
219                selectedKeySet.clear();
220            }
221
222            int selectedKeysCount = selectedKeys.size();
223
224            final Level LOG_LEVEL = Level.FINER;
225            if (LOGGER.isLoggable(LOG_LEVEL)) {
226                LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count "
227                        + selectedKeysCount);
228            }
229
230            int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT;
231            Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount);
232            Iterator<SelectionKey> it = selectedKeys.iterator();
233            for (int i = 0; i < myKeyCount; i++) {
234                SelectionKey selectionKey = it.next();
235                mySelectedKeys.add(selectionKey);
236            }
237            while (it.hasNext()) {
238                // Drain to PENDING_SELECTION_KEYS
239                SelectionKey selectionKey = it.next();
240                PENDING_SELECTION_KEYS.add(selectionKey);
241            }
242            return mySelectedKeys;
243        }
244
245        private static void handlePendingSelectionKeys() {
246            int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size();
247            if (pendingSelectionKeysSize == 0) {
248                return;
249            }
250
251            int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT;
252            Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount);
253            for (int i = 0; i < myKeyCount; i++) {
254                SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll();
255                if (selectionKey == null) {
256                    // We lost a race :)
257                    break;
258                }
259                selectedKeys.add(selectionKey);
260            }
261
262            if (!PENDING_SELECTION_KEYS.isEmpty()) {
263                // There is more work in the pending selection keys queue, wakeup another thread to handle it.
264                SELECTOR.wakeup();
265            }
266
267            handleSelectedKeys(selectedKeys);
268        }
269
270        private static void handleIncomingRequests() {
271            int incomingRequestsSize = INCOMING_REQUESTS.size();
272            if (incomingRequestsSize == 0) {
273                return;
274            }
275
276            int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT;
277            Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount);
278            for (int i = 0; i < myRequestsCount; i++) {
279                AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll();
280                if (asyncDnsRequest == null) {
281                    // We lost a race :)
282                    break;
283                }
284                requests.add(asyncDnsRequest);
285            }
286
287            if (!INCOMING_REQUESTS.isEmpty()) {
288                SELECTOR.wakeup();
289            }
290
291            for (AsyncDnsRequest asyncDnsRequest : requests) {
292                asyncDnsRequest.startHandling();
293            }
294        }
295
296    }
297
298}