001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source.async;
012
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.InetAddress;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
034
035import org.minidns.MiniDnsFuture;
036import org.minidns.dnsmessage.DnsMessage;
037import org.minidns.dnsqueryresult.DnsQueryResult;
038import org.minidns.source.AbstractDnsDataSource;
039
040/**
041 * A DNS data sources that resolves requests via the network asynchronously.
042 */
043public class AsyncNetworkDataSource extends AbstractDnsDataSource {
044
045    /**
046     * The logger of this data source.
047     */
048    protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName());
049
050    private static final int REACTOR_THREAD_COUNT = 1;
051
052    private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>();
053
054    private static final Selector SELECTOR;
055
056    private static final Lock REGISTRATION_LOCK = new ReentrantLock();
057
058    private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>();
059
060    private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT];
061
062    private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() {
063        @Override
064        public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) {
065            if (o1.deadline > o2.deadline) {
066                return 1;
067            } else if (o1.deadline < o2.deadline) {
068                return -1;
069            }
070            return 0;
071        }
072    });
073
074    static {
075        try {
076            SELECTOR = Selector.open();
077        } catch (IOException e) {
078            throw new IllegalStateException(e);
079        }
080
081        for (int i = 0; i < REACTOR_THREAD_COUNT; i++) {
082            Thread reactorThread = new Thread(new Reactor());
083            reactorThread.setDaemon(true);
084            reactorThread.setName("MiniDNS Reactor Thread #" + i);
085            reactorThread.start();
086            REACTOR_THREADS[i] = reactorThread;
087        }
088    }
089
090    @Override
091    public MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) {
092        AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback);
093        INCOMING_REQUESTS.add(asyncDnsRequest);
094        synchronized (DEADLINE_QUEUE) {
095            DEADLINE_QUEUE.add(asyncDnsRequest);
096        }
097        SELECTOR.wakeup();
098        return asyncDnsRequest.getFuture();
099    }
100
101    @Override
102    public DnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException {
103        MiniDnsFuture<DnsQueryResult, IOException> future = queryAsync(message, address, port, null);
104        try {
105            return future.get();
106        } catch (InterruptedException e) {
107            // This should never happen.
108            throw new AssertionError(e);
109        } catch (ExecutionException e) {
110            Throwable wrappedThrowable = e.getCause();
111            if (wrappedThrowable instanceof IOException) {
112                throw (IOException) wrappedThrowable;
113            }
114            // This should never happen.
115            throw new AssertionError(e);
116        }
117    }
118
119    SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException {
120        REGISTRATION_LOCK.lock();
121        try {
122            SELECTOR.wakeup();
123            return channel.register(SELECTOR, ops, attachment);
124        } finally {
125            REGISTRATION_LOCK.unlock();
126        }
127    }
128
129    void finished(AsyncDnsRequest asyncDnsRequest) {
130        synchronized (DEADLINE_QUEUE) {
131            DEADLINE_QUEUE.remove(asyncDnsRequest);
132        }
133    }
134
135    void cancelled(AsyncDnsRequest asyncDnsRequest) {
136        finished(asyncDnsRequest);
137        // Wakeup since the async DNS request was removed from the deadline queue.
138        SELECTOR.wakeup();
139    }
140
141    private static class Reactor implements Runnable {
142        @Override
143        public void run() {
144            while (!Thread.interrupted()) {
145                Collection<SelectionKey> mySelectedKeys = performSelect();
146                handleSelectedKeys(mySelectedKeys);
147
148                handlePendingSelectionKeys();
149
150                handleIncomingRequests();
151            }
152        }
153
154        private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) {
155            for (SelectionKey selectionKey : selectedKeys) {
156                ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment();
157                SelectableChannel channel = selectionKey.channel();
158                channelSelectedHandler.handleChannelSelected(channel, selectionKey);
159            }
160        }
161
162        @SuppressWarnings("LockNotBeforeTry")
163        private static Collection<SelectionKey> performSelect() {
164            AsyncDnsRequest nearestDeadline = null;
165            AsyncDnsRequest nextInQueue;
166
167            synchronized (DEADLINE_QUEUE) {
168                while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) {
169                    if (nextInQueue.wasDeadlineMissedAndFutureNotified()) {
170                        // We notified the future, associated with the AsyncDnsRequest nearestDeadline,
171                        // that the deadline has passed, hence remove it from the queue.
172                        DEADLINE_QUEUE.poll();
173                    } else {
174                        // We found a nearest deadline that has not yet passed, break out of the loop.
175                        nearestDeadline = nextInQueue;
176                        break;
177                    }
178                }
179
180            }
181
182            long selectWait;
183            if (nearestDeadline == null) {
184                // There is no deadline, wait indefinitely in select().
185                selectWait = 0;
186            } else {
187                // There is a deadline in the future, only block in select() until the deadline.
188                selectWait = nextInQueue.deadline - System.currentTimeMillis();
189                if (selectWait < 0) {
190                    // We already have a missed deadline. Do not call select() and handle the tasks which are past their
191                    // deadline.
192                    return Collections.emptyList();
193                }
194            }
195
196            List<SelectionKey> selectedKeys;
197            int newSelectedKeysCount;
198            synchronized (SELECTOR) {
199                // Ensure that a wakeup() in registerWithSelector() gives the corresponding
200                // register() in the same method the chance to actually register the channel. In
201                // other words: This construct ensure that there is never another select()
202                // between a corresponding wakeup() and register() calls.
203                // See also https://stackoverflow.com/a/1112809/194894
204                REGISTRATION_LOCK.lock();
205                REGISTRATION_LOCK.unlock();
206
207                try {
208                    newSelectedKeysCount = SELECTOR.select(selectWait);
209                } catch (IOException e) {
210                    LOGGER.log(Level.WARNING, "IOException while using select()", e);
211                    return Collections.emptyList();
212                }
213
214                if (newSelectedKeysCount == 0) {
215                    return Collections.emptyList();
216                }
217
218                Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys();
219                for (SelectionKey selectionKey : selectedKeySet) {
220                    selectionKey.interestOps(0);
221                }
222
223                selectedKeys = new ArrayList<>(selectedKeySet.size());
224                selectedKeys.addAll(selectedKeySet);
225                selectedKeySet.clear();
226            }
227
228            int selectedKeysCount = selectedKeys.size();
229
230            final Level LOG_LEVEL = Level.FINER;
231            if (LOGGER.isLoggable(LOG_LEVEL)) {
232                LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count "
233                        + selectedKeysCount);
234            }
235
236            int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT;
237            Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount);
238            Iterator<SelectionKey> it = selectedKeys.iterator();
239            for (int i = 0; i < myKeyCount; i++) {
240                SelectionKey selectionKey = it.next();
241                mySelectedKeys.add(selectionKey);
242            }
243            while (it.hasNext()) {
244                // Drain to PENDING_SELECTION_KEYS
245                SelectionKey selectionKey = it.next();
246                PENDING_SELECTION_KEYS.add(selectionKey);
247            }
248            return mySelectedKeys;
249        }
250
251        private static void handlePendingSelectionKeys() {
252            int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size();
253            if (pendingSelectionKeysSize == 0) {
254                return;
255            }
256
257            int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT;
258            Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount);
259            for (int i = 0; i < myKeyCount; i++) {
260                SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll();
261                if (selectionKey == null) {
262                    // We lost a race :)
263                    break;
264                }
265                selectedKeys.add(selectionKey);
266            }
267
268            if (!PENDING_SELECTION_KEYS.isEmpty()) {
269                // There is more work in the pending selection keys queue, wakeup another thread to handle it.
270                SELECTOR.wakeup();
271            }
272
273            handleSelectedKeys(selectedKeys);
274        }
275
276        private static void handleIncomingRequests() {
277            int incomingRequestsSize = INCOMING_REQUESTS.size();
278            if (incomingRequestsSize == 0) {
279                return;
280            }
281
282            int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT;
283            // The division could result in myRequestCount being zero despite pending incoming
284            // requests. Therefore, ensure this thread tries to get at least one incoming
285            // request by invoking poll(). Otherwise, we might end up in a busy loop
286            // where myRequestCount is zero, and this thread invokes a selector.wakeup() below
287            // because incomingRequestsSize is not empty, but the woken-up reactor thread
288            // will end up with myRequestCount being zero again, restarting the busy-loop cycle.
289            if (myRequestsCount == 0) myRequestsCount = 1;
290            Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount);
291            for (int i = 0; i < myRequestsCount; i++) {
292                AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll();
293                if (asyncDnsRequest == null) {
294                    // We lost a race :)
295                    break;
296                }
297                requests.add(asyncDnsRequest);
298            }
299
300            if (!INCOMING_REQUESTS.isEmpty()) {
301                SELECTOR.wakeup();
302            }
303
304            for (AsyncDnsRequest asyncDnsRequest : requests) {
305                asyncDnsRequest.startHandling();
306            }
307        }
308
309    }
310
311}