001/* 002 * Copyright 2015-2024 the original author or authors 003 * 004 * This software is licensed under the Apache License, Version 2.0, 005 * the GNU Lesser General Public License version 2 or later ("LGPL") 006 * and the WTFPL. 007 * You may choose either license to govern your use of this software only 008 * upon the condition that you accept all of the terms of either 009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL. 010 */ 011package org.minidns.source.async; 012 013import java.io.IOException; 014import java.net.InetAddress; 015import java.nio.channels.ClosedChannelException; 016import java.nio.channels.SelectableChannel; 017import java.nio.channels.SelectionKey; 018import java.nio.channels.Selector; 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.Comparator; 023import java.util.Iterator; 024import java.util.List; 025import java.util.PriorityQueue; 026import java.util.Queue; 027import java.util.Set; 028import java.util.concurrent.ConcurrentLinkedQueue; 029import java.util.concurrent.ExecutionException; 030import java.util.concurrent.locks.Lock; 031import java.util.concurrent.locks.ReentrantLock; 032import java.util.logging.Level; 033import java.util.logging.Logger; 034 035import org.minidns.MiniDnsFuture; 036import org.minidns.dnsmessage.DnsMessage; 037import org.minidns.dnsqueryresult.DnsQueryResult; 038import org.minidns.source.AbstractDnsDataSource; 039 040/** 041 * A DNS data sources that resolves requests via the network asynchronously. 042 */ 043public class AsyncNetworkDataSource extends AbstractDnsDataSource { 044 045 /** 046 * The logger of this data source. 047 */ 048 protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName()); 049 050 private static final int REACTOR_THREAD_COUNT = 1; 051 052 private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>(); 053 054 private static final Selector SELECTOR; 055 056 private static final Lock REGISTRATION_LOCK = new ReentrantLock(); 057 058 private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>(); 059 060 private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT]; 061 062 private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() { 063 @Override 064 public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) { 065 if (o1.deadline > o2.deadline) { 066 return 1; 067 } else if (o1.deadline < o2.deadline) { 068 return -1; 069 } 070 return 0; 071 } 072 }); 073 074 static { 075 try { 076 SELECTOR = Selector.open(); 077 } catch (IOException e) { 078 throw new IllegalStateException(e); 079 } 080 081 for (int i = 0; i < REACTOR_THREAD_COUNT; i++) { 082 Thread reactorThread = new Thread(new Reactor()); 083 reactorThread.setDaemon(true); 084 reactorThread.setName("MiniDNS Reactor Thread #" + i); 085 reactorThread.start(); 086 REACTOR_THREADS[i] = reactorThread; 087 } 088 } 089 090 @Override 091 public MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) { 092 AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback); 093 INCOMING_REQUESTS.add(asyncDnsRequest); 094 synchronized (DEADLINE_QUEUE) { 095 DEADLINE_QUEUE.add(asyncDnsRequest); 096 } 097 SELECTOR.wakeup(); 098 return asyncDnsRequest.getFuture(); 099 } 100 101 @Override 102 public DnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException { 103 MiniDnsFuture<DnsQueryResult, IOException> future = queryAsync(message, address, port, null); 104 try { 105 return future.get(); 106 } catch (InterruptedException e) { 107 // This should never happen. 108 throw new AssertionError(e); 109 } catch (ExecutionException e) { 110 Throwable wrappedThrowable = e.getCause(); 111 if (wrappedThrowable instanceof IOException) { 112 throw (IOException) wrappedThrowable; 113 } 114 // This should never happen. 115 throw new AssertionError(e); 116 } 117 } 118 119 SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException { 120 REGISTRATION_LOCK.lock(); 121 try { 122 SELECTOR.wakeup(); 123 return channel.register(SELECTOR, ops, attachment); 124 } finally { 125 REGISTRATION_LOCK.unlock(); 126 } 127 } 128 129 void finished(AsyncDnsRequest asyncDnsRequest) { 130 synchronized (DEADLINE_QUEUE) { 131 DEADLINE_QUEUE.remove(asyncDnsRequest); 132 } 133 } 134 135 void cancelled(AsyncDnsRequest asyncDnsRequest) { 136 finished(asyncDnsRequest); 137 // Wakeup since the async DNS request was removed from the deadline queue. 138 SELECTOR.wakeup(); 139 } 140 141 private static final class Reactor implements Runnable { 142 @Override 143 public void run() { 144 while (!Thread.interrupted()) { 145 Collection<SelectionKey> mySelectedKeys = performSelect(); 146 handleSelectedKeys(mySelectedKeys); 147 148 handlePendingSelectionKeys(); 149 150 handleIncomingRequests(); 151 } 152 } 153 154 private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) { 155 for (SelectionKey selectionKey : selectedKeys) { 156 ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment(); 157 SelectableChannel channel = selectionKey.channel(); 158 channelSelectedHandler.handleChannelSelected(channel, selectionKey); 159 } 160 } 161 162 @SuppressWarnings({"LockNotBeforeTry", "MixedMutabilityReturnType"}) 163 private static Collection<SelectionKey> performSelect() { 164 AsyncDnsRequest nearestDeadline = null; 165 AsyncDnsRequest nextInQueue; 166 167 synchronized (DEADLINE_QUEUE) { 168 while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) { 169 if (nextInQueue.wasDeadlineMissedAndFutureNotified()) { 170 // We notified the future, associated with the AsyncDnsRequest nearestDeadline, 171 // that the deadline has passed, hence remove it from the queue. 172 DEADLINE_QUEUE.poll(); 173 } else { 174 // We found a nearest deadline that has not yet passed, break out of the loop. 175 nearestDeadline = nextInQueue; 176 break; 177 } 178 } 179 180 } 181 182 long selectWait; 183 if (nearestDeadline == null) { 184 // There is no deadline, wait indefinitely in select(). 185 selectWait = 0; 186 } else { 187 // There is a deadline in the future, only block in select() until the deadline. 188 selectWait = nextInQueue.deadline - System.currentTimeMillis(); 189 if (selectWait < 0) { 190 // We already have a missed deadline. Do not call select() and handle the tasks which are past their 191 // deadline. 192 return Collections.emptyList(); 193 } 194 } 195 196 List<SelectionKey> selectedKeys; 197 int newSelectedKeysCount; 198 synchronized (SELECTOR) { 199 // Ensure that a wakeup() in registerWithSelector() gives the corresponding 200 // register() in the same method the chance to actually register the channel. In 201 // other words: This construct ensure that there is never another select() 202 // between a corresponding wakeup() and register() calls. 203 // See also https://stackoverflow.com/a/1112809/194894 204 REGISTRATION_LOCK.lock(); 205 REGISTRATION_LOCK.unlock(); 206 207 try { 208 newSelectedKeysCount = SELECTOR.select(selectWait); 209 } catch (IOException e) { 210 LOGGER.log(Level.WARNING, "IOException while using select()", e); 211 return Collections.emptyList(); 212 } 213 214 if (newSelectedKeysCount == 0) { 215 return Collections.emptyList(); 216 } 217 218 Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys(); 219 for (SelectionKey selectionKey : selectedKeySet) { 220 selectionKey.interestOps(0); 221 } 222 223 selectedKeys = new ArrayList<>(selectedKeySet.size()); 224 selectedKeys.addAll(selectedKeySet); 225 selectedKeySet.clear(); 226 } 227 228 int selectedKeysCount = selectedKeys.size(); 229 230 final Level LOG_LEVEL = Level.FINER; 231 if (LOGGER.isLoggable(LOG_LEVEL)) { 232 LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count " 233 + selectedKeysCount); 234 } 235 236 int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT; 237 Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount); 238 Iterator<SelectionKey> it = selectedKeys.iterator(); 239 for (int i = 0; i < myKeyCount; i++) { 240 SelectionKey selectionKey = it.next(); 241 mySelectedKeys.add(selectionKey); 242 } 243 while (it.hasNext()) { 244 // Drain to PENDING_SELECTION_KEYS 245 SelectionKey selectionKey = it.next(); 246 PENDING_SELECTION_KEYS.add(selectionKey); 247 } 248 return mySelectedKeys; 249 } 250 251 private static void handlePendingSelectionKeys() { 252 int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size(); 253 if (pendingSelectionKeysSize == 0) { 254 return; 255 } 256 257 int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT; 258 Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount); 259 for (int i = 0; i < myKeyCount; i++) { 260 SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll(); 261 if (selectionKey == null) { 262 // We lost a race :) 263 break; 264 } 265 selectedKeys.add(selectionKey); 266 } 267 268 if (!PENDING_SELECTION_KEYS.isEmpty()) { 269 // There is more work in the pending selection keys queue, wakeup another thread to handle it. 270 SELECTOR.wakeup(); 271 } 272 273 handleSelectedKeys(selectedKeys); 274 } 275 276 private static void handleIncomingRequests() { 277 int incomingRequestsSize = INCOMING_REQUESTS.size(); 278 if (incomingRequestsSize == 0) { 279 return; 280 } 281 282 int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT; 283 // The division could result in myRequestCount being zero despite pending incoming 284 // requests. Therefore, ensure this thread tries to get at least one incoming 285 // request by invoking poll(). Otherwise, we might end up in a busy loop 286 // where myRequestCount is zero, and this thread invokes a selector.wakeup() below 287 // because incomingRequestsSize is not empty, but the woken-up reactor thread 288 // will end up with myRequestCount being zero again, restarting the busy-loop cycle. 289 if (myRequestsCount == 0) myRequestsCount = 1; 290 Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount); 291 for (int i = 0; i < myRequestsCount; i++) { 292 AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll(); 293 if (asyncDnsRequest == null) { 294 // We lost a race :) 295 break; 296 } 297 requests.add(asyncDnsRequest); 298 } 299 300 if (!INCOMING_REQUESTS.isEmpty()) { 301 SELECTOR.wakeup(); 302 } 303 304 for (AsyncDnsRequest asyncDnsRequest : requests) { 305 asyncDnsRequest.startHandling(); 306 } 307 } 308 309 } 310 311}