/*
 * Copyright 2015-2018 the original author or authors
 *
 * This software is licensed under the Apache License, Version 2.0,
 * the GNU Lesser General Public License version 2 or later ("LGPL")
 * and the WTFPL.
 * You may choose either license to govern your use of this software only
 * upon the condition that you accept all of the terms of either
 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
 */
package org.minidns.source.async;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.InetAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.minidns.MiniDnsFuture;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.source.DnsDataSource;

/**
 * A {@link DnsDataSource} that performs its queries asynchronously. All instances share one
 * static {@link Selector}, which is driven by {@value #REACTOR_THREAD_COUNT} daemon reactor
 * thread(s) started when this class is initialized.
 */
public class AsyncNetworkDataSource extends DnsDataSource {

    protected static final Logger LOGGER = Logger.getLogger(AsyncNetworkDataSource.class.getName());

    // NOTE(review): the work-splitting code in Reactor divides queue sizes by this value using
    // integer division. That only guarantees progress while the value is 1; revisit the splitting
    // (e.g. round up) before increasing it.
    private static final int REACTOR_THREAD_COUNT = 1;

    /** Requests submitted via queryAsync() which a reactor thread has not yet started to handle. */
    private static final Queue<AsyncDnsRequest> INCOMING_REQUESTS = new ConcurrentLinkedQueue<>();

    private static final Selector SELECTOR;

    /**
     * Held around the wakeup()/register() pair in registerWithSelector(). A reactor thread acquires
     * and releases it right before calling select().
     */
    private static final Lock REGISTRATION_LOCK = new ReentrantLock();

    /** Selected keys which were not claimed by the reactor thread that performed the select(). */
    private static final Queue<SelectionKey> PENDING_SELECTION_KEYS = new ConcurrentLinkedQueue<>();

    private static final Thread[] REACTOR_THREADS = new Thread[REACTOR_THREAD_COUNT];

    /**
     * The requests ordered by their deadline, nearest deadline first. {@link PriorityQueue} is not
     * thread-safe, hence all access is guarded by synchronizing on the queue itself.
     */
    private static final PriorityQueue<AsyncDnsRequest> DEADLINE_QUEUE = new PriorityQueue<>(16, new Comparator<AsyncDnsRequest>() {
        @Override
        public int compare(AsyncDnsRequest o1, AsyncDnsRequest o2) {
            return Long.compare(o1.deadline, o2.deadline);
        }
    });

    static {
        try {
            SELECTOR = Selector.open();
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }

        for (int i = 0; i < REACTOR_THREAD_COUNT; i++) {
            Thread reactorThread = new Thread(new Reactor());
            reactorThread.setDaemon(true);
            reactorThread.setName("MiniDNS Reactor Thread #" + i);
            reactorThread.start();
            REACTOR_THREADS[i] = reactorThread;
        }
    }

    /**
     * Enqueue an asynchronous query. The request is added to the incoming requests queue and to the
     * deadline queue, then the selector is woken up so that a reactor thread picks the request up.
     *
     * @param message the DNS message to send.
     * @param address the address of the DNS server.
     * @param port the port of the DNS server.
     * @param onResponseCallback the callback handed to the created request, may be <code>null</code>.
     * @return the future of the created request.
     */
    @Override
    public MiniDnsFuture<DnsMessage, IOException> queryAsync(DnsMessage message, InetAddress address, int port, OnResponseCallback onResponseCallback) {
        AsyncDnsRequest asyncDnsRequest = new AsyncDnsRequest(message, address, port, udpPayloadSize, this, onResponseCallback);
        INCOMING_REQUESTS.add(asyncDnsRequest);
        synchronized (DEADLINE_QUEUE) {
            DEADLINE_QUEUE.add(asyncDnsRequest);
        }
        SELECTOR.wakeup();
        return asyncDnsRequest.getFuture();
    }

    /**
     * Perform a query synchronously by submitting it via
     * {@link #queryAsync(DnsMessage, InetAddress, int, OnResponseCallback)} and blocking until the
     * future completes.
     *
     * @param message the DNS message to send.
     * @param address the address of the DNS server.
     * @param port the port of the DNS server.
     * @return the result of the future.
     * @throws IOException if the future failed with an IOException, or, as
     *         {@link InterruptedIOException}, if the calling thread was interrupted while waiting.
     */
    @Override
    public DnsMessage query(DnsMessage message, InetAddress address, int port) throws IOException {
        MiniDnsFuture<DnsMessage, IOException> future = queryAsync(message, address, port, null);
        try {
            return future.get();
        } catch (InterruptedException e) {
            // The calling thread may be interrupted at any time while it blocks in get(). Restore the
            // interrupt status for the callers further up the stack and report the condition as an
            // I/O failure instead of an AssertionError.
            Thread.currentThread().interrupt();
            InterruptedIOException interruptedIoException = new InterruptedIOException(
                    "Interrupted while waiting for the DNS response");
            interruptedIoException.initCause(e);
            throw interruptedIoException;
        } catch (ExecutionException e) {
            Throwable wrappedThrowable = e.getCause();
            if (wrappedThrowable instanceof IOException) {
                throw (IOException) wrappedThrowable;
            }
            // This should never happen.
            throw new AssertionError(e);
        }
    }

    /**
     * Register the given channel with the shared selector.
     *
     * @param channel the channel to register.
     * @param ops the interest set for the resulting key.
     * @param attachment the attachment for the resulting key.
     * @return the selection key representing the registration.
     * @throws ClosedChannelException if the channel is closed.
     */
    SelectionKey registerWithSelector(SelectableChannel channel, int ops, Object attachment) throws ClosedChannelException {
        REGISTRATION_LOCK.lock();
        try {
            // Wake up a select() which may be in progress. The reactor can not enter the next
            // select() before the registration lock is released, see Reactor.performSelect().
            SELECTOR.wakeup();
            return channel.register(SELECTOR, ops, attachment);
        } finally {
            REGISTRATION_LOCK.unlock();
        }
    }

    /**
     * Remove the given request from the deadline queue.
     *
     * @param asyncDnsRequest the request to remove.
     */
    void finished(AsyncDnsRequest asyncDnsRequest) {
        synchronized (DEADLINE_QUEUE) {
            DEADLINE_QUEUE.remove(asyncDnsRequest);
        }
    }

    /**
     * Remove the given request from the deadline queue and wake up the selector.
     *
     * @param asyncDnsRequest the cancelled request.
     */
    void cancelled(AsyncDnsRequest asyncDnsRequest) {
        finished(asyncDnsRequest);
        // Wakeup since the async DNS request was removed from the deadline queue.
        SELECTOR.wakeup();
    }

    /**
     * The loop run by every reactor thread: select, handle the selected keys, handle the pending
     * selection keys and finally start handling the incoming requests.
     */
    private static class Reactor implements Runnable {
        @Override
        public void run() {
            while (!Thread.interrupted()) {
                Collection<SelectionKey> mySelectedKeys = performSelect();
                handleSelectedKeys(mySelectedKeys);

                handlePendingSelectionKeys();

                handleIncomingRequests();
            }
        }

        /**
         * Invoke the {@link ChannelSelectedHandler} attached to each of the given keys.
         *
         * @param selectedKeys the keys to handle.
         */
        private static void handleSelectedKeys(Collection<SelectionKey> selectedKeys) {
            for (SelectionKey selectionKey : selectedKeys) {
                ChannelSelectedHandler channelSelectedHandler = (ChannelSelectedHandler) selectionKey.attachment();
                SelectableChannel channel = selectionKey.channel();
                channelSelectedHandler.handleChannelSelected(channel, selectionKey);
            }
        }

        /**
         * Drop the requests with a missed deadline from the head of the deadline queue, then block
         * in select() until a channel becomes ready, the selector is woken up or the nearest
         * deadline is reached.
         *
         * @return the selected keys the calling reactor thread should handle, possibly empty.
         */
        private static Collection<SelectionKey> performSelect() {
            AsyncDnsRequest nearestDeadline = null;
            AsyncDnsRequest nextInQueue;

            synchronized (DEADLINE_QUEUE) {
                while ((nextInQueue = DEADLINE_QUEUE.peek()) != null) {
                    if (nextInQueue.wasDeadlineMissedAndFutureNotified()) {
                        // We notified the future, associated with the AsyncDnsRequest nextInQueue,
                        // that the deadline has passed, hence remove it from the queue.
                        DEADLINE_QUEUE.poll();
                    } else {
                        // We found a nearest deadline that has not yet passed, break out of the loop.
                        nearestDeadline = nextInQueue;
                        break;
                    }
                }
            }

            long selectWait;
            if (nearestDeadline == null) {
                // There is no deadline, wait indefinitely in select().
                selectWait = 0;
            } else {
                // There is a deadline in the future, only block in select() until the deadline.
                selectWait = nearestDeadline.deadline - System.currentTimeMillis();
                if (selectWait <= 0) {
                    // The deadline is due or already missed. Do not call select() and handle the tasks
                    // which are past their deadline. A wait of exactly zero must not reach select()
                    // either, as select(0) means "block indefinitely".
                    return Collections.emptyList();
                }
            }

            List<SelectionKey> selectedKeys;
            int newSelectedKeysCount;
            synchronized (SELECTOR) {
                // Ensure that a wakeup() in registerWithSelector() gives the corresponding
                // register() in the same method the chance to actually register the channel. In
                // other words: This construct ensures that there is never another select()
                // between a corresponding wakeup() and register() calls.
                // See also https://stackoverflow.com/a/1112809/194894
                REGISTRATION_LOCK.lock();
                REGISTRATION_LOCK.unlock();

                try {
                    newSelectedKeysCount = SELECTOR.select(selectWait);
                } catch (IOException e) {
                    LOGGER.log(Level.WARNING, "IOException while using select()", e);
                    return Collections.emptyList();
                }

                if (newSelectedKeysCount == 0) {
                    return Collections.emptyList();
                }

                Set<SelectionKey> selectedKeySet = SELECTOR.selectedKeys();
                // Clear the interest set of every selected key, so that the key is not selected
                // again until its interest set is set anew.
                for (SelectionKey selectionKey : selectedKeySet) {
                    selectionKey.interestOps(0);
                }

                // Copy the keys and clear the selector's selected-key set while holding the lock.
                selectedKeys = new ArrayList<>(selectedKeySet.size());
                selectedKeys.addAll(selectedKeySet);
                selectedKeySet.clear();
            }

            int selectedKeysCount = selectedKeys.size();

            final Level LOG_LEVEL = Level.FINER;
            if (LOGGER.isLoggable(LOG_LEVEL)) {
                LOGGER.log(LOG_LEVEL, "New selected key count: " + newSelectedKeysCount + ". Total selected key count "
                        + selectedKeysCount);
            }

            // This thread takes its share of the selected keys, the remaining keys are handed over
            // to the pending selection keys queue.
            int myKeyCount = selectedKeysCount / REACTOR_THREAD_COUNT;
            Collection<SelectionKey> mySelectedKeys = new ArrayList<>(myKeyCount);
            Iterator<SelectionKey> it = selectedKeys.iterator();
            for (int i = 0; i < myKeyCount; i++) {
                SelectionKey selectionKey = it.next();
                mySelectedKeys.add(selectionKey);
            }
            while (it.hasNext()) {
                // Drain to PENDING_SELECTION_KEYS
                SelectionKey selectionKey = it.next();
                PENDING_SELECTION_KEYS.add(selectionKey);
            }
            return mySelectedKeys;
        }

        /**
         * Take this thread's share of the pending selection keys and handle them. The selector is
         * woken up if keys remain in the queue afterwards.
         */
        private static void handlePendingSelectionKeys() {
            int pendingSelectionKeysSize = PENDING_SELECTION_KEYS.size();
            if (pendingSelectionKeysSize == 0) {
                return;
            }

            int myKeyCount = pendingSelectionKeysSize / REACTOR_THREAD_COUNT;
            Collection<SelectionKey> selectedKeys = new ArrayList<>(myKeyCount);
            for (int i = 0; i < myKeyCount; i++) {
                SelectionKey selectionKey = PENDING_SELECTION_KEYS.poll();
                if (selectionKey == null) {
                    // We lost a race :)
                    break;
                }
                selectedKeys.add(selectionKey);
            }

            if (!PENDING_SELECTION_KEYS.isEmpty()) {
                // There is more work in the pending selection keys queue, wakeup another thread to handle it.
                SELECTOR.wakeup();
            }

            handleSelectedKeys(selectedKeys);
        }

        /**
         * Take this thread's share of the incoming requests and start handling them. The selector
         * is woken up if requests remain in the queue afterwards.
         */
        private static void handleIncomingRequests() {
            int incomingRequestsSize = INCOMING_REQUESTS.size();
            if (incomingRequestsSize == 0) {
                return;
            }

            int myRequestsCount = incomingRequestsSize / REACTOR_THREAD_COUNT;
            Collection<AsyncDnsRequest> requests = new ArrayList<>(myRequestsCount);
            for (int i = 0; i < myRequestsCount; i++) {
                AsyncDnsRequest asyncDnsRequest = INCOMING_REQUESTS.poll();
                if (asyncDnsRequest == null) {
                    // We lost a race :)
                    break;
                }
                requests.add(asyncDnsRequest);
            }

            if (!INCOMING_REQUESTS.isEmpty()) {
                // There are more incoming requests, wakeup another thread to handle them.
                SELECTOR.wakeup();
            }

            for (AsyncDnsRequest asyncDnsRequest : requests) {
                asyncDnsRequest.startHandling();
            }
        }

    }

}