001/* 002 * Copyright 2015-2020 the original author or authors 003 * 004 * This software is licensed under the Apache License, Version 2.0, 005 * the GNU Lesser General Public License version 2 or later ("LGPL") 006 * and the WTFPL. 007 * You may choose either license to govern your use of this software only 008 * upon the condition that you accept all of the terms of either 009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL. 010 */ 011package org.minidns.source.async; 012 013import java.io.IOException; 014import java.net.InetAddress; 015import java.net.InetSocketAddress; 016import java.nio.ByteBuffer; 017import java.nio.channels.Channel; 018import java.nio.channels.ClosedChannelException; 019import java.nio.channels.DatagramChannel; 020import java.nio.channels.SelectableChannel; 021import java.nio.channels.SelectionKey; 022import java.nio.channels.SocketChannel; 023import java.util.ArrayList; 024import java.util.List; 025import java.util.concurrent.Future; 026import java.util.logging.Level; 027import java.util.logging.Logger; 028 029import org.minidns.MiniDnsException; 030import org.minidns.MiniDnsFuture; 031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture; 032import org.minidns.dnsmessage.DnsMessage; 033import org.minidns.dnsqueryresult.DnsQueryResult; 034import org.minidns.dnsqueryresult.DnsQueryResult.QueryMethod; 035import org.minidns.dnsqueryresult.StandardDnsQueryResult; 036import org.minidns.source.DnsDataSource.OnResponseCallback; 037import org.minidns.source.AbstractDnsDataSource.QueryMode; 038import org.minidns.util.MultipleIoException; 039 040public class AsyncDnsRequest { 041 042 private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName()); 043 044 private final InternalMiniDnsFuture<DnsQueryResult, IOException> future = new InternalMiniDnsFuture<DnsQueryResult, IOException>() { 045 @SuppressWarnings("UnsynchronizedOverridesSynchronized") 046 @Override 047 public boolean cancel(boolean mayInterruptIfRunning) { 048 boolean res = super.cancel(mayInterruptIfRunning); 049 cancelAsyncDnsRequest(); 050 return res; 051 } 052 }; 053 054 private final DnsMessage request; 055 056 private final int udpPayloadSize; 057 058 private final InetSocketAddress socketAddress; 059 060 private final AsyncNetworkDataSource asyncNds; 061 062 private final OnResponseCallback onResponseCallback; 063 064 private final boolean skipUdp; 065 066 private ByteBuffer writeBuffer; 067 068 private List<IOException> exceptions; 069 070 private SelectionKey selectionKey; 071 072 final long deadline; 073 074 /** 075 * Creates a new AsyncDnsRequest instance. 076 * 077 * @param request the DNS message of the request. 078 * @param inetAddress The IP address of the DNS server to ask. 079 * @param port The port of the DNS server to ask. 080 * @param udpPayloadSize The configured UDP payload size. 081 * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests. 082 * @param onResponseCallback the optional callback when a response was received. 083 */ 084 AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) { 085 this.request = request; 086 this.udpPayloadSize = udpPayloadSize; 087 this.asyncNds = asyncNds; 088 this.onResponseCallback = onResponseCallback; 089 090 final QueryMode queryMode = asyncNds.getQueryMode(); 091 switch (queryMode) { 092 case dontCare: 093 case udpTcp: 094 skipUdp = false; 095 break; 096 case tcp: 097 skipUdp = true; 098 break; 099 default: 100 throw new IllegalStateException("Unsupported query mode: " + queryMode); 101 102 } 103 deadline = System.currentTimeMillis() + asyncNds.getTimeout(); 104 socketAddress = new InetSocketAddress(inetAddress, port); 105 } 106 107 private void ensureWriteBufferIsInitialized() { 108 if (writeBuffer != null) { 109 if (!writeBuffer.hasRemaining()) { 110 writeBuffer.rewind(); 111 } 112 return; 113 } 114 writeBuffer = request.getInByteBuffer(); 115 } 116 117 private synchronized void cancelAsyncDnsRequest() { 118 if (selectionKey != null) { 119 selectionKey.cancel(); 120 } 121 asyncNds.cancelled(this); 122 } 123 124 private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler) 125 throws ClosedChannelException { 126 if (future.isCancelled()) { 127 return; 128 } 129 selectionKey = asyncNds.registerWithSelector(channel, ops, handler); 130 } 131 132 private void addException(IOException e) { 133 if (exceptions == null) { 134 exceptions = new ArrayList<>(4); 135 } 136 exceptions.add(e); 137 } 138 139 private void gotResult(DnsQueryResult result) { 140 if (onResponseCallback != null) { 141 onResponseCallback.onResponse(request, result); 142 } 143 asyncNds.finished(this); 144 future.setResult(result); 145 } 146 147 MiniDnsFuture<DnsQueryResult, IOException> getFuture() { 148 return future; 149 } 150 151 boolean wasDeadlineMissedAndFutureNotified() { 152 if (System.currentTimeMillis() < deadline) { 153 return false; 154 } 155 156 future.setException(new IOException("Timeout")); 157 return true; 158 } 159 160 void startHandling() { 161 if (!skipUdp) { 162 startUdpRequest(); 163 } else { 164 startTcpRequest(); 165 } 166 } 167 168 private void abortRequestAndCleanup(Channel channel, String errorMessage, IOException exception) { 169 if (exception == null) { 170 // TODO: Can this case be removed? Is 'exception' ever null? 171 LOGGER.info("Exception was null in abortRequestAndCleanup()"); 172 exception = new IOException(errorMessage); 173 } 174 LOGGER.log(Level.SEVERE, "Error connecting " + channel + ": " + errorMessage, exception); 175 addException(exception); 176 177 if (selectionKey != null) { 178 selectionKey.cancel(); 179 } 180 181 if (channel != null && channel.isOpen()) { 182 try { 183 channel.close(); 184 } catch (IOException e) { 185 LOGGER.log(Level.SEVERE, "Exception closing socket channel", e); 186 addException(e); 187 } 188 } 189 } 190 191 private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) { 192 abortRequestAndCleanup(datagramChannel, errorMessage, exception); 193 startTcpRequest(); 194 } 195 196 private void startUdpRequest() { 197 if (future.isCancelled()) { 198 return; 199 } 200 201 DatagramChannel datagramChannel; 202 try { 203 datagramChannel = DatagramChannel.open(); 204 } catch (IOException e) { 205 LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e); 206 addException(e); 207 startTcpRequest(); 208 return; 209 } 210 211 try { 212 datagramChannel.configureBlocking(false); 213 } catch (IOException e) { 214 abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e); 215 return; 216 } 217 218 try { 219 datagramChannel.connect(socketAddress); 220 } catch (IOException e) { 221 abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel to " + socketAddress, e); 222 return; 223 } 224 225 try { 226 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future)); 227 } catch (ClosedChannelException e) { 228 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 229 return; 230 } 231 } 232 233 class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler { 234 235 UdpWritableChannelSelectedHandler(Future<?> future) { 236 super(future); 237 } 238 239 @Override 240 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 241 DatagramChannel datagramChannel = (DatagramChannel) channel; 242 243 ensureWriteBufferIsInitialized(); 244 245 try { 246 datagramChannel.write(writeBuffer); 247 } catch (IOException e) { 248 abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e); 249 return; 250 } 251 252 if (writeBuffer.hasRemaining()) { 253 try { 254 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this); 255 } catch (ClosedChannelException e) { 256 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 257 } 258 return; 259 } 260 261 try { 262 registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future)); 263 } catch (ClosedChannelException e) { 264 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e); 265 return; 266 } 267 } 268 269 } 270 271 class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler { 272 273 UdpReadableChannelSelectedHandler(Future<?> future) { 274 super(future); 275 } 276 277 final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize); 278 279 @Override 280 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 281 DatagramChannel datagramChannel = (DatagramChannel) channel; 282 283 try { 284 datagramChannel.read(byteBuffer); 285 } catch (IOException e) { 286 abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e); 287 return; 288 } 289 290 selectionKey.cancel(); 291 try { 292 datagramChannel.close(); 293 } catch (IOException e) { 294 LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e); 295 addException(e); 296 } 297 298 DnsMessage response; 299 try { 300 response = new DnsMessage(byteBuffer.array()); 301 } catch (IOException e) { 302 abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e); 303 return; 304 } 305 306 if (response.id != request.id) { 307 addException(new MiniDnsException.IdMismatch(request, response)); 308 startTcpRequest(); 309 return; 310 } 311 312 if (response.truncated) { 313 startTcpRequest(); 314 return; 315 } 316 317 DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(), 318 QueryMethod.asyncUdp, request, response); 319 gotResult(result); 320 } 321 } 322 323 private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) { 324 abortRequestAndCleanup(socketChannel, errorMessage, exception); 325 future.setException(MultipleIoException.toIOException(exceptions)); 326 } 327 328 private void startTcpRequest() { 329 SocketChannel socketChannel = null; 330 try { 331 socketChannel = SocketChannel.open(); 332 } catch (IOException e) { 333 abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e); 334 return; 335 } 336 337 try { 338 socketChannel.configureBlocking(false); 339 } catch (IOException e) { 340 abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e); 341 return; 342 } 343 344 try { 345 registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future)); 346 } catch (ClosedChannelException e) { 347 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e); 348 return; 349 } 350 351 try { 352 socketChannel.connect(socketAddress); 353 } catch (IOException e) { 354 abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel to " + socketAddress, e); 355 return; 356 } 357 } 358 359 class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler { 360 361 TcpConnectedChannelSelectedHandler(Future<?> future) { 362 super(future); 363 } 364 365 @Override 366 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 367 SocketChannel socketChannel = (SocketChannel) channel; 368 369 boolean connected; 370 try { 371 connected = socketChannel.finishConnect(); 372 } catch (IOException e) { 373 abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e); 374 return; 375 } 376 377 assert connected; 378 379 try { 380 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future)); 381 } catch (ClosedChannelException e) { 382 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 383 return; 384 } 385 } 386 387 } 388 389 class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler { 390 391 TcpWritableChannelSelectedHandler(Future<?> future) { 392 super(future); 393 } 394 395 /** 396 * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message. 397 */ 398 private ByteBuffer[] writeBuffers; 399 400 @Override 401 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 402 SocketChannel socketChannel = (SocketChannel) channel; 403 404 if (writeBuffers == null) { 405 ensureWriteBufferIsInitialized(); 406 407 ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 408 int messageLength = writeBuffer.capacity(); 409 assert messageLength <= Short.MAX_VALUE; 410 messageLengthByteBuffer.putShort((short) (messageLength & 0xffff)); 411 messageLengthByteBuffer.rewind(); 412 413 writeBuffers = new ByteBuffer[2]; 414 writeBuffers[0] = messageLengthByteBuffer; 415 writeBuffers[1] = writeBuffer; 416 } 417 418 try { 419 socketChannel.write(writeBuffers); 420 } catch (IOException e) { 421 abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e); 422 return; 423 } 424 425 if (moreToWrite()) { 426 try { 427 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this); 428 } catch (ClosedChannelException e) { 429 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 430 } 431 return; 432 } 433 434 try { 435 registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future)); 436 } catch (ClosedChannelException e) { 437 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 438 return; 439 } 440 } 441 442 private boolean moreToWrite() { 443 for (int i = 0; i < writeBuffers.length; i++) { 444 if (writeBuffers[i].hasRemaining()) { 445 return true; 446 } 447 } 448 return false; 449 } 450 } 451 452 class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler { 453 454 TcpReadableChannelSelectedHandler(Future<?> future) { 455 super(future); 456 } 457 458 final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 459 460 ByteBuffer byteBuffer; 461 462 @Override 463 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 464 SocketChannel socketChannel = (SocketChannel) channel; 465 466 int bytesRead; 467 if (byteBuffer == null) { 468 try { 469 bytesRead = socketChannel.read(messageLengthByteBuffer); 470 } catch (IOException e) { 471 abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e); 472 return; 473 } 474 475 if (bytesRead < 0) { 476 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 477 return; 478 } 479 480 if (messageLengthByteBuffer.hasRemaining()) { 481 try { 482 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 483 } catch (ClosedChannelException e) { 484 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 485 } 486 return; 487 } 488 489 messageLengthByteBuffer.rewind(); 490 short messageLengthSignedShort = messageLengthByteBuffer.getShort(); 491 int messageLength = messageLengthSignedShort & 0xffff; 492 byteBuffer = ByteBuffer.allocate(messageLength); 493 } 494 495 try { 496 bytesRead = socketChannel.read(byteBuffer); 497 } catch (IOException e) { 498 throw new Error(e); 499 } 500 501 if (bytesRead < 0) { 502 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 503 return; 504 } 505 506 if (byteBuffer.hasRemaining()) { 507 try { 508 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 509 } catch (ClosedChannelException e) { 510 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 511 } 512 return; 513 } 514 515 selectionKey.cancel(); 516 try { 517 socketChannel.close(); 518 } catch (IOException e) { 519 addException(e); 520 } 521 522 DnsMessage response; 523 try { 524 response = new DnsMessage(byteBuffer.array()); 525 } catch (IOException e) { 526 abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e); 527 return; 528 } 529 530 if (request.id != response.id) { 531 MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response); 532 addException(idMismatchException); 533 AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions)); 534 return; 535 } 536 537 DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(), 538 QueryMethod.asyncTcp, request, response); 539 gotResult(result); 540 } 541 542 } 543 544}