001/* 002 * Copyright 2015-2024 the original author or authors 003 * 004 * This software is licensed under the Apache License, Version 2.0, 005 * the GNU Lesser General Public License version 2 or later ("LGPL") 006 * and the WTFPL. 007 * You may choose either license to govern your use of this software only 008 * upon the condition that you accept all of the terms of either 009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL. 010 */ 011package org.minidns.source.async; 012 013import java.io.IOException; 014import java.net.InetAddress; 015import java.net.InetSocketAddress; 016import java.nio.ByteBuffer; 017import java.nio.channels.Channel; 018import java.nio.channels.ClosedChannelException; 019import java.nio.channels.DatagramChannel; 020import java.nio.channels.SelectableChannel; 021import java.nio.channels.SelectionKey; 022import java.nio.channels.SocketChannel; 023import java.util.ArrayList; 024import java.util.List; 025import java.util.concurrent.Future; 026import java.util.logging.Level; 027import java.util.logging.Logger; 028 029import org.minidns.MiniDnsException; 030import org.minidns.MiniDnsFuture; 031import org.minidns.MiniDnsFuture.InternalMiniDnsFuture; 032import org.minidns.dnsmessage.DnsMessage; 033import org.minidns.dnsqueryresult.DnsQueryResult; 034import org.minidns.dnsqueryresult.DnsQueryResult.QueryMethod; 035import org.minidns.dnsqueryresult.StandardDnsQueryResult; 036import org.minidns.source.DnsDataSource.OnResponseCallback; 037import org.minidns.source.AbstractDnsDataSource.QueryMode; 038import org.minidns.util.MultipleIoException; 039 040/** 041 * A DNS request that is performed asynchronously. 042 */ 043public class AsyncDnsRequest { 044 045 private static final Logger LOGGER = Logger.getLogger(AsyncDnsRequest.class.getName()); 046 047 private final InternalMiniDnsFuture<DnsQueryResult, IOException> future = new InternalMiniDnsFuture<DnsQueryResult, IOException>() { 048 @SuppressWarnings("UnsynchronizedOverridesSynchronized") 049 @Override 050 public boolean cancel(boolean mayInterruptIfRunning) { 051 boolean res = super.cancel(mayInterruptIfRunning); 052 cancelAsyncDnsRequest(); 053 return res; 054 } 055 }; 056 057 private final DnsMessage request; 058 059 private final int udpPayloadSize; 060 061 private final InetSocketAddress socketAddress; 062 063 private final AsyncNetworkDataSource asyncNds; 064 065 private final OnResponseCallback onResponseCallback; 066 067 private final boolean skipUdp; 068 069 private ByteBuffer writeBuffer; 070 071 private List<IOException> exceptions; 072 073 private SelectionKey selectionKey; 074 075 final long deadline; 076 077 /** 078 * Creates a new AsyncDnsRequest instance. 079 * 080 * @param request the DNS message of the request. 081 * @param inetAddress The IP address of the DNS server to ask. 082 * @param port The port of the DNS server to ask. 083 * @param udpPayloadSize The configured UDP payload size. 084 * @param asyncNds A reference to the {@link AsyncNetworkDataSource} instance manageing the requests. 085 * @param onResponseCallback the optional callback when a response was received. 086 */ 087 AsyncDnsRequest(DnsMessage request, InetAddress inetAddress, int port, int udpPayloadSize, AsyncNetworkDataSource asyncNds, OnResponseCallback onResponseCallback) { 088 this.request = request; 089 this.udpPayloadSize = udpPayloadSize; 090 this.asyncNds = asyncNds; 091 this.onResponseCallback = onResponseCallback; 092 093 final QueryMode queryMode = asyncNds.getQueryMode(); 094 switch (queryMode) { 095 case dontCare: 096 case udpTcp: 097 skipUdp = false; 098 break; 099 case tcp: 100 skipUdp = true; 101 break; 102 default: 103 throw new IllegalStateException("Unsupported query mode: " + queryMode); 104 105 } 106 deadline = System.currentTimeMillis() + asyncNds.getTimeout(); 107 socketAddress = new InetSocketAddress(inetAddress, port); 108 } 109 110 private void ensureWriteBufferIsInitialized() { 111 if (writeBuffer != null) { 112 if (!writeBuffer.hasRemaining()) { 113 writeBuffer.rewind(); 114 } 115 return; 116 } 117 writeBuffer = request.getInByteBuffer(); 118 } 119 120 private synchronized void cancelAsyncDnsRequest() { 121 if (selectionKey != null) { 122 selectionKey.cancel(); 123 } 124 asyncNds.cancelled(this); 125 } 126 127 private synchronized void registerWithSelector(SelectableChannel channel, int ops, ChannelSelectedHandler handler) 128 throws ClosedChannelException { 129 if (future.isCancelled()) { 130 return; 131 } 132 selectionKey = asyncNds.registerWithSelector(channel, ops, handler); 133 } 134 135 private void addException(IOException e) { 136 if (exceptions == null) { 137 exceptions = new ArrayList<>(4); 138 } 139 exceptions.add(e); 140 } 141 142 private void gotResult(DnsQueryResult result) { 143 if (onResponseCallback != null) { 144 onResponseCallback.onResponse(request, result); 145 } 146 asyncNds.finished(this); 147 future.setResult(result); 148 } 149 150 MiniDnsFuture<DnsQueryResult, IOException> getFuture() { 151 return future; 152 } 153 154 boolean wasDeadlineMissedAndFutureNotified() { 155 if (System.currentTimeMillis() < deadline) { 156 return false; 157 } 158 159 future.setException(new IOException("Timeout")); 160 return true; 161 } 162 163 void startHandling() { 164 if (!skipUdp) { 165 startUdpRequest(); 166 } else { 167 startTcpRequest(); 168 } 169 } 170 171 private void abortRequestAndCleanup(Channel channel, String errorMessage, IOException exception) { 172 if (exception == null) { 173 // TODO: Can this case be removed? Is 'exception' ever null? 174 LOGGER.info("Exception was null in abortRequestAndCleanup()"); 175 exception = new IOException(errorMessage); 176 } 177 LOGGER.log(Level.SEVERE, "Error connecting " + channel + ": " + errorMessage, exception); 178 addException(exception); 179 180 if (selectionKey != null) { 181 selectionKey.cancel(); 182 } 183 184 if (channel != null && channel.isOpen()) { 185 try { 186 channel.close(); 187 } catch (IOException e) { 188 LOGGER.log(Level.SEVERE, "Exception closing socket channel", e); 189 addException(e); 190 } 191 } 192 } 193 194 private void abortUdpRequestAndCleanup(DatagramChannel datagramChannel, String errorMessage, IOException exception) { 195 abortRequestAndCleanup(datagramChannel, errorMessage, exception); 196 startTcpRequest(); 197 } 198 199 private void startUdpRequest() { 200 if (future.isCancelled()) { 201 return; 202 } 203 204 DatagramChannel datagramChannel; 205 try { 206 datagramChannel = DatagramChannel.open(); 207 } catch (IOException e) { 208 LOGGER.log(Level.SEVERE, "Exception opening datagram channel", e); 209 addException(e); 210 startTcpRequest(); 211 return; 212 } 213 214 try { 215 datagramChannel.configureBlocking(false); 216 } catch (IOException e) { 217 abortUdpRequestAndCleanup(datagramChannel, "Exception configuring datagram channel", e); 218 return; 219 } 220 221 try { 222 datagramChannel.connect(socketAddress); 223 } catch (IOException e) { 224 abortUdpRequestAndCleanup(datagramChannel, "Exception connecting datagram channel to " + socketAddress, e); 225 return; 226 } 227 228 try { 229 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, new UdpWritableChannelSelectedHandler(future)); 230 } catch (ClosedChannelException e) { 231 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 232 return; 233 } 234 } 235 236 class UdpWritableChannelSelectedHandler extends ChannelSelectedHandler { 237 238 UdpWritableChannelSelectedHandler(Future<?> future) { 239 super(future); 240 } 241 242 @Override 243 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 244 DatagramChannel datagramChannel = (DatagramChannel) channel; 245 246 ensureWriteBufferIsInitialized(); 247 248 try { 249 datagramChannel.write(writeBuffer); 250 } catch (IOException e) { 251 abortUdpRequestAndCleanup(datagramChannel, "Exception writing to datagram channel", e); 252 return; 253 } 254 255 if (writeBuffer.hasRemaining()) { 256 try { 257 registerWithSelector(datagramChannel, SelectionKey.OP_WRITE, this); 258 } catch (ClosedChannelException e) { 259 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_WRITE", e); 260 } 261 return; 262 } 263 264 try { 265 registerWithSelector(datagramChannel, SelectionKey.OP_READ, new UdpReadableChannelSelectedHandler(future)); 266 } catch (ClosedChannelException e) { 267 abortUdpRequestAndCleanup(datagramChannel, "Exception registering datagram channel for OP_READ", e); 268 return; 269 } 270 } 271 272 } 273 274 class UdpReadableChannelSelectedHandler extends ChannelSelectedHandler { 275 276 UdpReadableChannelSelectedHandler(Future<?> future) { 277 super(future); 278 } 279 280 final ByteBuffer byteBuffer = ByteBuffer.allocate(udpPayloadSize); 281 282 @Override 283 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 284 DatagramChannel datagramChannel = (DatagramChannel) channel; 285 286 try { 287 datagramChannel.read(byteBuffer); 288 } catch (IOException e) { 289 abortUdpRequestAndCleanup(datagramChannel, "Exception reading from datagram channel", e); 290 return; 291 } 292 293 selectionKey.cancel(); 294 try { 295 datagramChannel.close(); 296 } catch (IOException e) { 297 LOGGER.log(Level.SEVERE, "Exception closing datagram channel", e); 298 addException(e); 299 } 300 301 DnsMessage response; 302 try { 303 response = new DnsMessage(byteBuffer.array()); 304 } catch (IOException e) { 305 abortUdpRequestAndCleanup(datagramChannel, "Exception constructing dns message from datagram channel", e); 306 return; 307 } 308 309 if (response.id != request.id) { 310 addException(new MiniDnsException.IdMismatch(request, response)); 311 startTcpRequest(); 312 return; 313 } 314 315 if (response.truncated) { 316 startTcpRequest(); 317 return; 318 } 319 320 DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(), 321 QueryMethod.asyncUdp, request, response); 322 gotResult(result); 323 } 324 } 325 326 private void abortTcpRequestAndCleanup(SocketChannel socketChannel, String errorMessage, IOException exception) { 327 abortRequestAndCleanup(socketChannel, errorMessage, exception); 328 future.setException(MultipleIoException.toIOException(exceptions)); 329 } 330 331 private void startTcpRequest() { 332 SocketChannel socketChannel = null; 333 try { 334 socketChannel = SocketChannel.open(); 335 } catch (IOException e) { 336 abortTcpRequestAndCleanup(socketChannel, "Exception opening socket channel", e); 337 return; 338 } 339 340 try { 341 socketChannel.configureBlocking(false); 342 } catch (IOException e) { 343 abortTcpRequestAndCleanup(socketChannel, "Exception configuring socket channel", e); 344 return; 345 } 346 347 try { 348 registerWithSelector(socketChannel, SelectionKey.OP_CONNECT, new TcpConnectedChannelSelectedHandler(future)); 349 } catch (ClosedChannelException e) { 350 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel", e); 351 return; 352 } 353 354 try { 355 socketChannel.connect(socketAddress); 356 } catch (IOException e) { 357 abortTcpRequestAndCleanup(socketChannel, "Exception connecting socket channel to " + socketAddress, e); 358 return; 359 } 360 } 361 362 class TcpConnectedChannelSelectedHandler extends ChannelSelectedHandler { 363 364 TcpConnectedChannelSelectedHandler(Future<?> future) { 365 super(future); 366 } 367 368 @Override 369 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 370 SocketChannel socketChannel = (SocketChannel) channel; 371 372 boolean connected; 373 try { 374 connected = socketChannel.finishConnect(); 375 } catch (IOException e) { 376 abortTcpRequestAndCleanup(socketChannel, "Exception finish connecting socket channel", e); 377 return; 378 } 379 380 assert connected; 381 382 try { 383 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, new TcpWritableChannelSelectedHandler(future)); 384 } catch (ClosedChannelException e) { 385 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 386 return; 387 } 388 } 389 390 } 391 392 class TcpWritableChannelSelectedHandler extends ChannelSelectedHandler { 393 394 TcpWritableChannelSelectedHandler(Future<?> future) { 395 super(future); 396 } 397 398 /** 399 * ByteBuffer array of length 2. First buffer is for the length of the DNS message, second one is the actual DNS message. 400 */ 401 private ByteBuffer[] writeBuffers; 402 403 @Override 404 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 405 SocketChannel socketChannel = (SocketChannel) channel; 406 407 if (writeBuffers == null) { 408 ensureWriteBufferIsInitialized(); 409 410 ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 411 int messageLength = writeBuffer.capacity(); 412 assert messageLength <= Short.MAX_VALUE; 413 messageLengthByteBuffer.putShort((short) (messageLength & 0xffff)); 414 messageLengthByteBuffer.rewind(); 415 416 writeBuffers = new ByteBuffer[2]; 417 writeBuffers[0] = messageLengthByteBuffer; 418 writeBuffers[1] = writeBuffer; 419 } 420 421 try { 422 socketChannel.write(writeBuffers); 423 } catch (IOException e) { 424 abortTcpRequestAndCleanup(socketChannel, "Exception writing to socket channel", e); 425 return; 426 } 427 428 if (moreToWrite()) { 429 try { 430 registerWithSelector(socketChannel, SelectionKey.OP_WRITE, this); 431 } catch (ClosedChannelException e) { 432 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_WRITE", e); 433 } 434 return; 435 } 436 437 try { 438 registerWithSelector(socketChannel, SelectionKey.OP_READ, new TcpReadableChannelSelectedHandler(future)); 439 } catch (ClosedChannelException e) { 440 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 441 return; 442 } 443 } 444 445 private boolean moreToWrite() { 446 for (int i = 0; i < writeBuffers.length; i++) { 447 if (writeBuffers[i].hasRemaining()) { 448 return true; 449 } 450 } 451 return false; 452 } 453 } 454 455 class TcpReadableChannelSelectedHandler extends ChannelSelectedHandler { 456 457 TcpReadableChannelSelectedHandler(Future<?> future) { 458 super(future); 459 } 460 461 final ByteBuffer messageLengthByteBuffer = ByteBuffer.allocate(2); 462 463 ByteBuffer byteBuffer; 464 465 @Override 466 public void handleChannelSelectedAndNotCancelled(SelectableChannel channel, SelectionKey selectionKey) { 467 SocketChannel socketChannel = (SocketChannel) channel; 468 469 int bytesRead; 470 if (byteBuffer == null) { 471 try { 472 bytesRead = socketChannel.read(messageLengthByteBuffer); 473 } catch (IOException e) { 474 abortTcpRequestAndCleanup(socketChannel, "Exception reading from socket channel", e); 475 return; 476 } 477 478 if (bytesRead < 0) { 479 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 480 return; 481 } 482 483 if (messageLengthByteBuffer.hasRemaining()) { 484 try { 485 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 486 } catch (ClosedChannelException e) { 487 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 488 } 489 return; 490 } 491 492 messageLengthByteBuffer.rewind(); 493 short messageLengthSignedShort = messageLengthByteBuffer.getShort(); 494 int messageLength = messageLengthSignedShort & 0xffff; 495 byteBuffer = ByteBuffer.allocate(messageLength); 496 } 497 498 try { 499 bytesRead = socketChannel.read(byteBuffer); 500 } catch (IOException e) { 501 throw new Error(e); 502 } 503 504 if (bytesRead < 0) { 505 abortTcpRequestAndCleanup(socketChannel, "Socket closed by remote host " + socketAddress, null); 506 return; 507 } 508 509 if (byteBuffer.hasRemaining()) { 510 try { 511 registerWithSelector(socketChannel, SelectionKey.OP_READ, this); 512 } catch (ClosedChannelException e) { 513 abortTcpRequestAndCleanup(socketChannel, "Exception registering socket channel for OP_READ", e); 514 } 515 return; 516 } 517 518 selectionKey.cancel(); 519 try { 520 socketChannel.close(); 521 } catch (IOException e) { 522 addException(e); 523 } 524 525 DnsMessage response; 526 try { 527 response = new DnsMessage(byteBuffer.array()); 528 } catch (IOException e) { 529 abortTcpRequestAndCleanup(socketChannel, "Exception creating DNS message form socket channel bytes", e); 530 return; 531 } 532 533 if (request.id != response.id) { 534 MiniDnsException idMismatchException = new MiniDnsException.IdMismatch(request, response); 535 addException(idMismatchException); 536 AsyncDnsRequest.this.future.setException(MultipleIoException.toIOException(exceptions)); 537 return; 538 } 539 540 DnsQueryResult result = new StandardDnsQueryResult(socketAddress.getAddress(), socketAddress.getPort(), 541 QueryMethod.asyncTcp, request, response); 542 gotResult(result); 543 } 544 545 } 546 547}