/*
 * Copyright 2015-2024 the original author or authors
 *
 * This software is licensed under the Apache License, Version 2.0,
 * the GNU Lesser General Public License version 2 or later ("LGPL")
 * and the WTFPL.
 * You may choose either license to govern your use of this software only
 * upon the condition that you accept all of the terms of either
 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
 */
package org.minidns.source;

import org.minidns.MiniDnsException;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.dnsqueryresult.DnsQueryResult.QueryMethod;
import org.minidns.dnsqueryresult.StandardDnsQueryResult;
import org.minidns.util.MultipleIoException;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

public class NetworkDataSource extends AbstractDnsDataSource {

    protected static final Logger LOGGER = Logger.getLogger(NetworkDataSource.class.getName());

    // TODO: Rename 'message' parameter to query.
    /**
     * Send {@code message} to the DNS server at {@code address}:{@code port} and return its response.
     * <p>
     * In the {@code dontCare} and {@code udpTcp} query modes the query is sent via UDP first. TCP is used as a
     * fallback when the UDP attempt throws an {@link IOException}, yields no response, or yields a truncated
     * response. In the {@code tcp} query mode only TCP is used.
     * </p>
     *
     * @param message the DNS query to send
     * @param address the address of the DNS server
     * @param port the port of the DNS server
     * @return the result, recording whether the response was obtained via UDP or TCP
     * @throws IOException if the TCP query fails; the exceptions of all failed attempts are passed to
     *         {@code MultipleIoException.throwIfRequired}
     * @throws IllegalStateException if the configured query mode is not supported
     */
    @Override
    public StandardDnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException {
        final QueryMode queryMode = getQueryMode();
        boolean doUdpFirst;
        switch (queryMode) {
        case dontCare:
        case udpTcp:
            doUdpFirst = true;
            break;
        case tcp:
            doUdpFirst = false;
            break;
        default:
            throw new IllegalStateException("Unsupported query mode: " + queryMode);
        }

        // Collects the exceptions of every failed attempt so they can be reported together.
        List<IOException> ioExceptions = new ArrayList<>(2);
        DnsMessage dnsMessage = null;

        if (doUdpFirst) {
            try {
                dnsMessage = queryUdp(message, address, port);
            } catch (IOException e) {
                ioExceptions.add(e);
            }

            // queryUdp() is protected and may be overridden, so a null result is tolerated here.
            if (dnsMessage != null && !dnsMessage.truncated) {
                return new StandardDnsQueryResult(address, port, QueryMethod.udp, message, dnsMessage);
            }

            // Determine why we fall back without assuming an exception was recorded: an overriding
            // queryUdp() may return null without throwing, and ioExceptions would then be empty.
            Object fallbackReason;
            if (dnsMessage != null) {
                fallbackReason = "response is truncated";
            } else if (!ioExceptions.isEmpty()) {
                fallbackReason = ioExceptions.get(0);
            } else {
                fallbackReason = "no UDP response";
            }
            LOGGER.log(Level.FINE, "Fallback to TCP because {0}", new Object[] { fallbackReason });
        }

        try {
            dnsMessage = queryTcp(message, address, port);
        } catch (IOException e) {
            ioExceptions.add(e);
            // The list is non-empty at this point, so this presumably always throws - TODO confirm.
            MultipleIoException.throwIfRequired(ioExceptions);
        }

        return new StandardDnsQueryResult(address, port, QueryMethod.tcp, message, dnsMessage);
    }

    /**
     * Send {@code message} via UDP and wait for a single response datagram.
     * <p>
     * A fresh socket obtained from {@link #createDatagramSocket()} is used for every query and closed afterwards.
     * The receive buffer holds at most {@code udpPayloadSize} bytes, and {@code timeout} is applied as the socket
     * read timeout (interpreted in milliseconds by {@link DatagramSocket#setSoTimeout(int)}).
     * </p>
     *
     * @param message the DNS query to send
     * @param address the address of the DNS server
     * @param port the port of the DNS server
     * @return the parsed response
     * @throws IOException if sending or receiving fails
     * @throws MiniDnsException.IdMismatch if the response ID differs from the query ID
     */
    protected DnsMessage queryUdp(DnsMessage message, InetAddress address, int port) throws IOException {
        // TODO Use a try-with-resource statement here once miniDNS minimum
        // required Android API level is >= 19
        DatagramSocket socket = null;
        DatagramPacket packet = message.asDatagram(address, port);
        byte[] buffer = new byte[udpPayloadSize];
        try {
            socket = createDatagramSocket();
            socket.setSoTimeout(timeout);
            socket.send(packet);
            packet = new DatagramPacket(buffer, buffer.length);
            socket.receive(packet);
            DnsMessage dnsMessage = new DnsMessage(packet.getData());
            if (dnsMessage.id != message.id) {
                throw new MiniDnsException.IdMismatch(message, dnsMessage);
            }
            return dnsMessage;
        } finally {
            if (socket != null) {
                socket.close();
            }
        }
    }

    /**
     * Send {@code message} via TCP and read the length-prefixed response.
     * <p>
     * A fresh socket obtained from {@link #createSocket()} is used for every query and closed afterwards.
     * {@code timeout} is used both as the connect timeout and as the read timeout.
     * </p>
     *
     * @param message the DNS query to send
     * @param address the address of the DNS server
     * @param port the port of the DNS server
     * @return the parsed response
     * @throws IOException if connecting, sending or receiving fails, including when the server closes the
     *         connection before the complete response has been received
     * @throws MiniDnsException.IdMismatch if the response ID differs from the query ID
     */
    protected DnsMessage queryTcp(DnsMessage message, InetAddress address, int port) throws IOException {
        // TODO Use a try-with-resource statement here once miniDNS minimum
        // required Android API level is >= 19
        Socket socket = null;
        try {
            socket = createSocket();
            SocketAddress socketAddress = new InetSocketAddress(address, port);
            socket.connect(socketAddress, timeout);
            socket.setSoTimeout(timeout);
            DataOutputStream dos = new DataOutputStream(socket.getOutputStream());
            message.writeTo(dos);
            dos.flush();
            DataInputStream dis = new DataInputStream(socket.getInputStream());
            // Over TCP the response is preceded by a two-byte length field (RFC 1035, section 4.2.2).
            int length = dis.readUnsignedShort();
            byte[] data = new byte[length];
            // readFully() keeps reading until 'length' bytes have arrived and throws EOFException if the
            // stream ends early. A bare read() loop would instead add read()'s -1 end-of-stream marker
            // to the offset.
            dis.readFully(data);
            DnsMessage dnsMessage = new DnsMessage(data);
            if (dnsMessage.id != message.id) {
                throw new MiniDnsException.IdMismatch(message, dnsMessage);
            }
            return dnsMessage;
        } finally {
            if (socket != null) {
                socket.close();
            }
        }
    }

    /**
     * Create a new, unconnected {@link Socket} using the platform default socket implementation.
     * Subclasses may override this to customize how TCP sockets are created.
     *
     * @return The new {@link Socket} instance
     */
    protected Socket createSocket() {
        return new Socket();
    }

    /**
     * Create a {@link DatagramSocket} using the system defaults.
     * Subclasses may override this to customize how UDP sockets are created.
     *
     * @return The new {@link DatagramSocket} instance
     * @throws SocketException If creation of the {@link DatagramSocket} fails
     */
    protected DatagramSocket createDatagramSocket() throws SocketException {
        return new DatagramSocket();
    }
}