001/*
002 * Copyright 2015-2018 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source;
012
import org.minidns.MiniDnsException;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.util.MultipleIoException;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
031
032public class NetworkDataSource extends DnsDataSource {
033
034    protected static final Logger LOGGER = Logger.getLogger(NetworkDataSource.class.getName());
035
036    @Override
037    public DnsMessage query(DnsMessage message, InetAddress address, int port) throws IOException {
038        List<IOException> ioExceptions = new ArrayList<>(2);
039        DnsMessage dnsMessage = null;
040        final QueryMode queryMode = getQueryMode();
041        boolean doUdpFirst;
042        switch (queryMode) {
043        case dontCare:
044        case udpTcp:
045            doUdpFirst = true;
046            break;
047        case tcp:
048            doUdpFirst = false;
049            break;
050        default:
051            throw new IllegalStateException("Unsupported query mode: " + queryMode);
052        }
053
054        if (doUdpFirst) {
055            try {
056                dnsMessage = queryUdp(message, address, port);
057            } catch (IOException e) {
058                ioExceptions.add(e);
059            }
060
061            if (dnsMessage != null && !dnsMessage.truncated) {
062                return dnsMessage;
063            }
064
065            assert (dnsMessage == null || dnsMessage.truncated || ioExceptions.size() == 1);
066            LOGGER.log(Level.FINE, "Fallback to TCP because {0}",
067                    new Object[] { dnsMessage != null ? "response is truncated" : ioExceptions.get(0) });
068        }
069
070        try {
071            dnsMessage = queryTcp(message, address, port);
072        } catch (IOException e) {
073            ioExceptions.add(e);
074            MultipleIoException.throwIfRequired(ioExceptions);
075        }
076
077        return dnsMessage;
078    }
079
080    protected DnsMessage queryUdp(DnsMessage message, InetAddress address, int port) throws IOException {
081        // TODO Use a try-with-resource statement here once miniDNS minimum
082        // required Android API level is >= 19
083        DatagramSocket socket = null;
084        DatagramPacket packet = message.asDatagram(address, port);
085        byte[] buffer = new byte[udpPayloadSize];
086        try {
087            socket = createDatagramSocket();
088            socket.setSoTimeout(timeout);
089            socket.send(packet);
090            packet = new DatagramPacket(buffer, buffer.length);
091            socket.receive(packet);
092            DnsMessage dnsMessage = new DnsMessage(packet.getData());
093            if (dnsMessage.id != message.id) {
094                throw new MiniDnsException.IdMismatch(message, dnsMessage);
095            }
096            return dnsMessage;
097        } finally {
098            if (socket != null) {
099                socket.close();
100            }
101        }
102    }
103
104    protected DnsMessage queryTcp(DnsMessage message, InetAddress address, int port) throws IOException {
105        // TODO Use a try-with-resource statement here once miniDNS minimum
106        // required Android API level is >= 19
107        Socket socket = null;
108        try {
109            socket = createSocket();
110            SocketAddress socketAddress = new InetSocketAddress(address, port);
111            socket.connect(socketAddress, timeout);
112            socket.setSoTimeout(timeout);
113            DataOutputStream dos = new DataOutputStream(socket.getOutputStream());
114            message.writeTo(dos);
115            dos.flush();
116            DataInputStream dis = new DataInputStream(socket.getInputStream());
117            int length = dis.readUnsignedShort();
118            byte[] data = new byte[length];
119            int read = 0;
120            while (read < length) {
121                read += dis.read(data, read, length-read);
122            }
123            DnsMessage dnsMessage = new DnsMessage(data);
124            if (dnsMessage.id != message.id) {
125                throw new MiniDnsException.IdMismatch(message, dnsMessage);
126            }
127            return dnsMessage;
128        } finally {
129            if (socket != null) {
130                socket.close();
131            }
132        }
133    }
134
135    /**
136     * Create a {@link Socket} using the system default {@link javax.net.SocketFactory}.
137     *
138     * @return The new {@link Socket} instance
139     */
140    protected Socket createSocket() {
141        return new Socket();
142    }
143
144    /**
145     * Create a {@link DatagramSocket} using the system defaults.
146     *
147     * @return The new {@link DatagramSocket} instance
148     * @throws SocketException If creation of the {@link DatagramSocket} fails
149     */
150    protected DatagramSocket createDatagramSocket() throws SocketException {
151        return new DatagramSocket();
152    }
153}