001/* 002 * Copyright 2015-2020 the original author or authors 003 * 004 * This software is licensed under the Apache License, Version 2.0, 005 * the GNU Lesser General Public License version 2 or later ("LGPL") 006 * and the WTFPL. 007 * You may choose either license to govern your use of this software only 008 * upon the condition that you accept all of the terms of either 009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL. 010 */ 011package org.minidns.source.async; 012 013import java.io.IOException; 014import java.net.InetAddress; 015import java.nio.channels.ClosedChannelException; 016import java.nio.channels.SelectableChannel; 017import java.nio.channels.SelectionKey; 018import java.nio.channels.Selector; 019import java.util.ArrayList; 020import java.util.Collection; 021import java.util.Collections; 022import java.util.Comparator; 023import java.util.Iterator; 024import java.util.List; 025import java.util.PriorityQueue; 026import java.util.Queue; 027import java.util.Set; 028import java.util.concurrent.ConcurrentLinkedQueue; 029import java.util.concurrent.ExecutionException; 030import java.util.concurrent.locks.Lock; 031import java.util.concurrent.locks.ReentrantLock; 032import java.util.logging.Level; 033import java.util.logging.Logger; 034 035import org.minidns.MiniDnsFuture; 036import org.minidns.dnsmessage.DnsMessage; 037import org.minidns.dnsqueryresult.DnsQueryResult; 038import org.minidns.source.AbstractDnsDataSource; 039 040public class AsyncNetworkDataSource extends AbstractDnsDataSource { 041 042 protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName()); 043 044 private static final int REACTOR_THREAD_COUNT = 1; 045 046 private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>(); 047 048 private static final Selector SELECTOR; 049 050 private static final Lock REGISTRATION_LOCK = new ReentrantLock(); 051 052 private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>(); 053 054 private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT]; 055 056 private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() { 057 @Override 058 public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) { 059 if (o1.deadline > o2.deadline) { 060 return 1; 061 } else if (o1.deadline < o2.deadline) { 062 return -1; 063 } 064 return 0; 065 } 066 }); 067 068 static { 069 try { 070 SELECTOR = Selector.open(); 071 } catch (IOException e) { 072 throw new IllegalStateException(e); 073 } 074 075 for (int i = 0; i < REACTOR_THREAD_COUNT; i++) { 076 Thread reactorThread = new Thread(new Reactor()); 077 reactorThread.setDaemon(true); 078 reactorThread.setName("MiniDNS Reactor Thread #" + i); 079 reactorThread.start(); 080 REACTOR_THREADS[i] = reactorThread; 081 } 082 } 083 084 @Override 085 public MiniDnsFuture<DnsQueryResult, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) { 086 AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback); 087 INCOMING_REQUESTS.add(asyncDnsRequest); 088 synchronized (DEADLINE_QUEUE) { 089 DEADLINE_QUEUE.add(asyncDnsRequest); 090 } 091 SELECTOR.wakeup(); 092 return asyncDnsRequest.getFuture(); 093 } 094 095 @Override 096 public DnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException { 097 MiniDnsFuture<DnsQueryResult, IOException> future = queryAsync(message, address, port, null); 098 try { 099 return future.get(); 100 } catch (InterruptedException e) { 101 // This should never happen. 102 throw new AssertionError(e); 103 } catch (ExecutionException e) { 104 Throwable wrappedThrowable = e.getCause(); 105 if (wrappedThrowable instanceof IOException) { 106 throw (IOException) wrappedThrowable; 107 } 108 // This should never happen. 109 throw new AssertionError(e); 110 } 111 } 112 113 SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException { 114 REGISTRATION_LOCK.lock(); 115 try { 116 SELECTOR.wakeup(); 117 return channel.register(SELECTOR, ops, attachment); 118 } finally { 119 REGISTRATION_LOCK.unlock(); 120 } 121 } 122 123 void finished(AsyncDnsRequest asyncDnsRequest) { 124 synchronized (DEADLINE_QUEUE) { 125 DEADLINE_QUEUE.remove(asyncDnsRequest); 126 } 127 } 128 129 void cancelled(AsyncDnsRequest asyncDnsRequest) { 130 finished(asyncDnsRequest); 131 // Wakeup since the async DNS request was removed from the deadline queue. 132 SELECTOR.wakeup(); 133 } 134 135 private static class Reactor implements Runnable { 136 @Override 137 public void run() { 138 while (!Thread.interrupted()) { 139 Collection<SelectionKey> mySelectedKeys = performSelect(); 140 handleSelectedKeys(mySelectedKeys); 141 142 handlePendingSelectionKeys(); 143 144 handleIncomingRequests(); 145 } 146 } 147 148 private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) { 149 for (SelectionKey selectionKey : selectedKeys) { 150 ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment(); 151 SelectableChannel channel = selectionKey.channel(); 152 channelSelectedHandler.handleChannelSelected(channel, selectionKey); 153 } 154 } 155 156 @SuppressWarnings("LockNotBeforeTry") 157 private static Collection<SelectionKey> performSelect() { 158 AsyncDnsRequest nearestDeadline = null; 159 AsyncDnsRequest nextInQueue; 160 161 synchronized (DEADLINE_QUEUE) { 162 while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) { 163 if (nextInQueue.wasDeadlineMissedAndFutureNotified()) { 164 // We notified the future, associated with the AsyncDnsRequest nearestDeadline, 165 // that the deadline has passed, hence remove it from the queue. 166 DEADLINE_QUEUE.poll(); 167 } else { 168 // We found a nearest deadline that has not yet passed, break out of the loop. 169 nearestDeadline = nextInQueue; 170 break; 171 } 172 } 173 174 } 175 176 long selectWait; 177 if (nearestDeadline == null) { 178 // There is no deadline, wait indefinitely in select(). 179 selectWait = 0; 180 } else { 181 // There is a deadline in the future, only block in select() until the deadline. 182 selectWait = nextInQueue.deadline - System.currentTimeMillis(); 183 if (selectWait < 0) { 184 // We already have a missed deadline. Do not call select() and handle the tasks which are past their 185 // deadline. 186 return Collections.emptyList(); 187 } 188 } 189 190 List<SelectionKey> selectedKeys; 191 int newSelectedKeysCount; 192 synchronized (SELECTOR) { 193 // Ensure that a wakeup() in registerWithSelector() gives the corresponding 194 // register() in the same method the chance to actually register the channel. In 195 // other words: This construct ensure that there is never another select() 196 // between a corresponding wakeup() and register() calls. 197 // See also https://stackoverflow.com/a/1112809/194894 198 REGISTRATION_LOCK.lock(); 199 REGISTRATION_LOCK.unlock(); 200 201 try { 202 newSelectedKeysCount = SELECTOR.select(selectWait); 203 } catch (IOException e) { 204 LOGGER.log(Level.WARNING, "IOException while using select()", e); 205 return Collections.emptyList(); 206 } 207 208 if (newSelectedKeysCount == 0) { 209 return Collections.emptyList(); 210 } 211 212 Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys(); 213 for (SelectionKey selectionKey : selectedKeySet) { 214 selectionKey.interestOps(0); 215 } 216 217 selectedKeys = new ArrayList<>(selectedKeySet.size()); 218 selectedKeys.addAll(selectedKeySet); 219 selectedKeySet.clear(); 220 } 221 222 int selectedKeysCount = selectedKeys.size(); 223 224 final Level LOG_LEVEL = Level.FINER; 225 if (LOGGER.isLoggable(LOG_LEVEL)) { 226 LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count " 227 + selectedKeysCount); 228 } 229 230 int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT; 231 Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount); 232 Iterator<SelectionKey> it = selectedKeys.iterator(); 233 for (int i = 0; i < myKeyCount; i++) { 234 SelectionKey selectionKey = it.next(); 235 mySelectedKeys.add(selectionKey); 236 } 237 while (it.hasNext()) { 238 // Drain to PENDING_SELECTION_KEYS 239 SelectionKey selectionKey = it.next(); 240 PENDING_SELECTION_KEYS.add(selectionKey); 241 } 242 return mySelectedKeys; 243 } 244 245 private static void handlePendingSelectionKeys() { 246 int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size(); 247 if (pendingSelectionKeysSize == 0) { 248 return; 249 } 250 251 int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT; 252 Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount); 253 for (int i = 0; i < myKeyCount; i++) { 254 SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll(); 255 if (selectionKey == null) { 256 // We lost a race :) 257 break; 258 } 259 selectedKeys.add(selectionKey); 260 } 261 262 if (!PENDING_SELECTION_KEYS.isEmpty()) { 263 // There is more work in the pending selection keys queue, wakeup another thread to handle it. 264 SELECTOR.wakeup(); 265 } 266 267 handleSelectedKeys(selectedKeys); 268 } 269 270 private static void handleIncomingRequests() { 271 int incomingRequestsSize = INCOMING_REQUESTS.size(); 272 if (incomingRequestsSize == 0) { 273 return; 274 } 275 276 int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT; 277 Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount); 278 for (int i = 0; i < myRequestsCount; i++) { 279 AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll(); 280 if (asyncDnsRequest == null) { 281 // We lost a race :) 282 break; 283 } 284 requests.add(asyncDnsRequest); 285 } 286 287 if (!INCOMING_REQUESTS.isEmpty()) { 288 SELECTOR.wakeup(); 289 } 290 291 for (AsyncDnsRequest asyncDnsRequest : requests) { 292 asyncDnsRequest.startHandling(); 293 } 294 } 295 296 } 297 298}