001/* 002 * Copyright 2015-2018 the original author or authors 003 * 004 * This software is licensed under the Apache License, Version 2.0, 005 * the GNU Lesser General Public License version 2 or later ("LGPL") 006 * and the WTFPL. 007 * You may choose either license to govern your use of this software only 008 * upon the condition that you accept all of the terms of either 009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL. 010 */ 011package org.minidns.source.async; 012 013import java.io.IOException; 014import java.net.InetAddress; 015import java.net.InetSocketAddress; 016import java.net.SocketAddress; 017import java.nio.ByteBuffer; 018import java.nio.channels.ClosedChannelException; 019import java.nio.channels.DatagramChannel; 020import java.nio.channels.SelectableChannel; 021import java.nio.channels.SelectionKey; 022import java.nio.channels.SocketChannel; 023import java.util.ArrayList; 024import java.util.List; 025import java.util.concurrent.Future; 026import java.util.logging.Level; 027import java.util.logging.Logger; 028 029import org.minidns.MiniDnsException; 030import org.minidns.MiniDnsFuture; 031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture; 032import org.minidns.dnsmessage.DnsMessage; 033import org.minidns.source.DnsDataSource.OnResponseCallback; 034import org.minidns.source.DnsDataSource.QueryMode; 035import org.minidns.util.MultipleIoException; 036 037public class AsyncDnsRequest { 038 039 private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName()); 040 041 private final InternalMiniDnsFuture<DnsMessage, IOException> future = new InternalMiniDnsFuture<DnsMessage, IOException>() { 042 @Override 043 public boolean cancel(boolean mayInterruptIfRunning) { 044 boolean res = super.cancel(mayInterruptIfRunning); 045 cancelAsyncDnsRequest(); 046 return res; 047 } 048 }; 049 050 private final DnsMessage request; 051 052 private final int udpPayloadSize; 053 054 private final SocketAddress socketAddress; 055 056 private final AsyncNetworkDataSource asyncNds; 057 058 private final OnResponseCallback onResponseCallback; 059 060 private final boolean skipUdp; 061 062 private ByteBuffer writeBuffer; 063 064 private List<IOException> exceptions; 065 066 private SelectionKey selectionKey; 067 068 final long deadline; 069 070 /** 071 * Creates a new AsyncDnsRequest instance. 072 * 073 * @param request the DNS message of the request. 074 * @param inetAddress The IP address of the DNS server to ask. 075 * @param port The port of the DNS server to ask. 076 * @param udpPayloadSize The configured UDP payload size. 077 * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests. 078 * @param onResponseCallback the optional callback when a response was received. 079 */ 080 AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) { 081 this.request = request; 082 this.udpPayloadSize = udpPayloadSize; 083 this.asyncNds = asyncNds; 084 this.onResponseCallback = onResponseCallback; 085 086 final QueryMode queryMode = asyncNds.getQueryMode(); 087 switch (queryMode) { 088 case dontCare: 089 case udpTcp: 090 skipUdp = false; 091 break; 092 case tcp: 093 skipUdp = true; 094 break; 095 default: 096 throw new IllegalStateException("Unsupported query mode: " + queryMode); 097 098 } 099 deadline = System.currentTimeMillis() + asyncNds.getTimeout(); 100 socketAddress = new InetSocketAddress(inetAddress, port); 101 } 102 103 private void ensureWriteBufferIsInitialized() { 104 if (writeBuffer != null) { 105 if (!writeBuffer.hasRemaining()) { 106 writeBuffer.rewind(); 107 } 108 return; 109 } 110 writeBuffer = request.getInByteBuffer(); 111 } 112 113 private synchronized void cancelAsyncDnsRequest() { 114 if (selectionKey != null) { 115 selectionKey.cancel(); 116 } 117 asyncNds.cancelled(this); 118 } 119 120 private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler) 121 throws ClosedChannelException { 122 if (future.isCancelled()) { 123 return; 124 } 125 selectionKey = asyncNds.registerWithSelector(channel, ops, handler); 126 } 127 128 private void addException(IOException e) { 129 if (exceptions == null) { 130 exceptions = new ArrayList<>(4); 131 } 132 exceptions.add(e); 133 } 134 135 private final void gotResult(DnsMessage result) { 136 if (onResponseCallback != null) { 137 onResponseCallback.onResponse(request, result); 138 } 139 asyncNds.finished(this); 140 future.setResult(result); 141 } 142 143 MiniDnsFuture<DnsMessage, IOException> getFuture() { 144 return future; 145 } 146 147 boolean wasDeadlineMissedAndFutureNotified() { 148 if (System.currentTimeMillis() < deadline) { 149 return false; 150 } 151 152 future.setException(new IOException("Timeout")); 153 return true; 154 } 155 156 void startHandling() { 157 if (!skipUdp) { 158 startUdpRequest(); 159 } else { 160 startTcpRequest(); 161 } 162 } 163 164 private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) { 165 LOGGER.log(Level.SEVERE, errorMessage, exception); 166 addException(exception); 167 168 if (selectionKey != null) { 169 selectionKey.cancel(); 170 } 171 172 try { 173 datagramChannel.close(); 174 } catch (IOException e) { 175 LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e); 176 addException(e); 177 } 178 179 startTcpRequest(); 180 } 181 182 private void startUdpRequest() { 183 if (future.isCancelled()) { 184 return; 185 } 186 187 DatagramChannel datagramChannel; 188 try { 189 datagramChannel = DatagramChannel.open(); 190 } catch (IOException e) { 191 LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e); 192 addException(e); 193 startTcpRequest(); 194 return; 195 } 196 197 try { 198 datagramChannel.configureBlocking(false); 199 } catch (IOException e) { 200 abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e); 201 return; 202 } 203 204 try { 205 datagramChannel.connect(socketAddress); 206 } catch (IOException e) { 207 abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel", e); 208 return; 209 } 210 211 try { 212 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future)); 213 } catch (ClosedChannelException e) { 214 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 215 return; 216 } 217 } 218 219 class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler { 220 221 UdpWritableChannelSelectedHandler(Future<?> future) { 222 super(future); 223 } 224 225 @Override 226 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 227 DatagramChannel datagramChannel = (DatagramChannel) channel; 228 229 ensureWriteBufferIsInitialized(); 230 231 try { 232 datagramChannel.write(writeBuffer); 233 } catch (IOException e) { 234 abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e); 235 return; 236 } 237 238 if (writeBuffer.hasRemaining()) { 239 try { 240 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this); 241 } catch (ClosedChannelException e) { 242 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 243 } 244 return; 245 } 246 247 try { 248 registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future)); 249 } catch (ClosedChannelException e) { 250 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e); 251 return; 252 } 253 } 254 255 } 256 257 class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler { 258 259 UdpReadableChannelSelectedHandler(Future<?> future) { 260 super(future); 261 } 262 263 final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize); 264 265 @Override 266 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 267 DatagramChannel datagramChannel = (DatagramChannel) channel; 268 269 try { 270 datagramChannel.read(byteBuffer); 271 } catch (IOException e) { 272 abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e); 273 return; 274 } 275 276 selectionKey.cancel(); 277 try { 278 datagramChannel.close(); 279 } catch (IOException e) { 280 LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e); 281 addException(e); 282 } 283 284 DnsMessage response; 285 try { 286 response = new DnsMessage(byteBuffer.array()); 287 } catch (IOException e) { 288 abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e); 289 return; 290 } 291 292 if (response.id != request.id) { 293 addException(new MiniDnsException.IdMismatch(request, response)); 294 startTcpRequest(); 295 return; 296 } 297 298 if (response.truncated) { 299 startTcpRequest(); 300 return; 301 } 302 303 gotResult(response); 304 } 305 } 306 307 private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) { 308 if (exception == null) { 309 exception = new IOException(errorMessage); 310 } 311 LOGGER.log(Level.SEVERE, errorMessage, exception); 312 addException(exception); 313 314 if (selectionKey != null) { 315 selectionKey.cancel(); 316 } 317 318 if (socketChannel != null && socketChannel.isOpen()) { 319 try { 320 socketChannel.close(); 321 } catch (IOException e) { 322 LOGGER.log(Level.SEVERE, "Exception closing socket channel", e); 323 addException(e); 324 } 325 } 326 327 future.setException(MultipleIoException.toIOException(exceptions)); 328 } 329 330 private void startTcpRequest() { 331 SocketChannel socketChannel = null; 332 try { 333 socketChannel = SocketChannel.open(); 334 } catch (IOException e) { 335 abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e); 336 return; 337 } 338 339 try { 340 socketChannel.configureBlocking(false); 341 } catch (IOException e) { 342 abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e); 343 return; 344 } 345 346 try { 347 registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future)); 348 } catch (ClosedChannelException e) { 349 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e); 350 return; 351 } 352 353 try { 354 socketChannel.connect(socketAddress); 355 } catch (IOException e) { 356 abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel", e); 357 return; 358 } 359 } 360 361 class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler { 362 363 TcpConnectedChannelSelectedHandler(Future<?> future) { 364 super(future); 365 } 366 367 @Override 368 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 369 SocketChannel socketChannel = (SocketChannel) channel; 370 371 boolean connected; 372 try { 373 connected = socketChannel.finishConnect(); 374 } catch (IOException e) { 375 abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e); 376 return; 377 } 378 379 assert connected; 380 381 try { 382 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future)); 383 } catch (ClosedChannelException e) { 384 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 385 return; 386 } 387 } 388 389 } 390 391 class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler { 392 393 TcpWritableChannelSelectedHandler(Future<?> future) { 394 super(future); 395 } 396 397 /** 398 * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message. 399 */ 400 private ByteBuffer[] writeBuffers; 401 402 @Override 403 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 404 SocketChannel socketChannel = (SocketChannel) channel; 405 406 if (writeBuffers == null) { 407 ensureWriteBufferIsInitialized(); 408 409 ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 410 int messageLength = writeBuffer.capacity(); 411 assert messageLength <= Short.MAX_VALUE; 412 messageLengthByteBuffer.putShort((short) (messageLength & 0xffff)); 413 messageLengthByteBuffer.rewind(); 414 415 writeBuffers = new ByteBuffer[2]; 416 writeBuffers[0] = messageLengthByteBuffer; 417 writeBuffers[1] = writeBuffer; 418 } 419 420 try { 421 socketChannel.write(writeBuffers); 422 } catch (IOException e) { 423 abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e); 424 return; 425 } 426 427 if (moreToWrite()) { 428 try { 429 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this); 430 } catch (ClosedChannelException e) { 431 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 432 } 433 return; 434 } 435 436 try { 437 registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future)); 438 } catch (ClosedChannelException e) { 439 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 440 return; 441 } 442 } 443 444 private boolean moreToWrite() { 445 for (int i = 0; i < writeBuffers.length; i++) { 446 if (writeBuffers[i].hasRemaining()) { 447 return true; 448 } 449 } 450 return false; 451 } 452 } 453 454 class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler { 455 456 TcpReadableChannelSelectedHandler(Future<?> future) { 457 super(future); 458 } 459 460 final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 461 462 ByteBuffer byteBuffer; 463 464 @Override 465 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 466 SocketChannel socketChannel = (SocketChannel) channel; 467 468 int bytesRead; 469 if (byteBuffer == null) { 470 try { 471 bytesRead = socketChannel.read(messageLengthByteBuffer); 472 } catch (IOException e) { 473 abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e); 474 return; 475 } 476 477 if (bytesRead < 0) { 478 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 479 return; 480 } 481 482 if (messageLengthByteBuffer.hasRemaining()) { 483 try { 484 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 485 } catch (ClosedChannelException e) { 486 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 487 } 488 return; 489 } 490 491 messageLengthByteBuffer.rewind(); 492 short messageLengthSignedShort = messageLengthByteBuffer.getShort(); 493 int messageLength = messageLengthSignedShort & 0xffff; 494 byteBuffer = ByteBuffer.allocate(messageLength); 495 } 496 497 try { 498 bytesRead = socketChannel.read(byteBuffer); 499 } catch (IOException e) { 500 throw new Error(e); 501 } 502 503 if (bytesRead < 0) { 504 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 505 return; 506 } 507 508 if (byteBuffer.hasRemaining()) { 509 try { 510 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 511 } catch (ClosedChannelException e) { 512 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 513 } 514 return; 515 } 516 517 selectionKey.cancel(); 518 try { 519 socketChannel.close(); 520 } catch (IOException e) { 521 addException(e); 522 } 523 524 DnsMessage response; 525 try { 526 response = new DnsMessage(byteBuffer.array()); 527 } catch (IOException e) { 528 abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e); 529 return; 530 } 531 532 if (request.id != response.id) { 533 MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response); 534 addException(idMismatchException); 535 AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions)); 536 return; 537 } 538 539 gotResult(response); 540 } 541 542 } 543 544}