001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.source;
012
import java.io.IOException;
import java.net.InetAddress;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.minidns.AbstractDnsClient;
import org.minidns.dnsmessage.DnsMessage;
import org.minidns.dnsqueryresult.StandardDnsQueryResult;
021
022public class NetworkDataSourceWithAccounting extends NetworkDataSource {
023
024    private final AtomicInteger successfulQueries = new AtomicInteger();
025    private final AtomicInteger responseSize = new AtomicInteger();
026    private final AtomicInteger failedQueries = new AtomicInteger();
027
028    private final AtomicInteger successfulUdpQueries = new AtomicInteger();
029    private final AtomicInteger udpResponseSize = new AtomicInteger();
030    private final AtomicInteger failedUdpQueries = new AtomicInteger();
031
032    private final AtomicInteger successfulTcpQueries = new AtomicInteger();
033    private final AtomicInteger tcpResponseSize = new AtomicInteger();
034    private final AtomicInteger failedTcpQueries = new AtomicInteger();
035
036    @Override
037    public StandardDnsQueryResult query(DnsMessage message, InetAddress address, int port) throws IOException {
038        StandardDnsQueryResult response;
039        try {
040            response = super.query(message, address, port);
041        } catch (IOException e) {
042            failedQueries.incrementAndGet();
043            throw e;
044        }
045
046        successfulQueries.incrementAndGet();
047        responseSize.addAndGet(response.response.toArray().length);
048
049        return response;
050    }
051
052    @Override
053    protected DnsMessage queryUdp(DnsMessage message, InetAddress address, int port) throws IOException {
054        DnsMessage response;
055        try {
056            response = super.queryUdp(message, address, port);
057        } catch (IOException e) {
058            failedUdpQueries.incrementAndGet();
059            throw e;
060        }
061
062        successfulUdpQueries.incrementAndGet();
063        udpResponseSize.addAndGet(response.toArray().length);
064
065        return response;
066    }
067
068    @Override
069    protected DnsMessage queryTcp(DnsMessage message, InetAddress address, int port) throws IOException {
070        DnsMessage response;
071        try {
072            response = super.queryTcp(message, address, port);
073        } catch (IOException e) {
074            failedTcpQueries.incrementAndGet();
075            throw e;
076        }
077
078        successfulTcpQueries.incrementAndGet();
079        tcpResponseSize.addAndGet(response.toArray().length);
080
081        return response;
082    }
083
084    public Stats getStats() {
085        return new Stats(this);
086    }
087
088    public static NetworkDataSourceWithAccounting from(AbstractDnsClient client) {
089        DnsDataSource ds = client.getDataSource();
090        if (ds instanceof NetworkDataSourceWithAccounting) {
091            return (NetworkDataSourceWithAccounting) ds;
092        }
093        return null;
094    }
095
096    public static final class Stats {
097        public final int successfulQueries;
098        public final int responseSize;
099        public final int averageResponseSize;
100        public final int failedQueries;
101
102        public final int successfulUdpQueries;
103        public final int udpResponseSize;
104        public final int averageUdpResponseSize;
105        public final int failedUdpQueries;
106
107        public final int successfulTcpQueries;
108        public final int tcpResponseSize;
109        public final int averageTcpResponseSize;
110        public final int failedTcpQueries;
111
112        private String stringCache;
113
114        private Stats(NetworkDataSourceWithAccounting ndswa) {
115            successfulQueries = ndswa.successfulQueries.get();
116            responseSize = ndswa.responseSize.get();
117            failedQueries = ndswa.failedQueries.get();
118
119            successfulUdpQueries = ndswa.successfulUdpQueries.get();
120            udpResponseSize = ndswa.udpResponseSize.get();
121            failedUdpQueries = ndswa.failedUdpQueries.get();
122
123            successfulTcpQueries = ndswa.successfulTcpQueries.get();
124            tcpResponseSize = ndswa.tcpResponseSize.get();
125            failedTcpQueries = ndswa.failedTcpQueries.get();
126
127            // Calculated stats section
128            averageResponseSize = successfulQueries > 0 ? responseSize / successfulQueries : 0;
129            averageUdpResponseSize = successfulUdpQueries > 0 ? udpResponseSize / successfulUdpQueries : 0;
130            averageTcpResponseSize = successfulTcpQueries > 0 ? tcpResponseSize / successfulTcpQueries : 0;
131        }
132
133        @Override
134        public String toString() {
135            if (stringCache != null)
136                return stringCache;
137
138            StringBuilder sb = new StringBuilder();
139
140            sb.append("Stats\t").append("# Successful").append('\t').append("# Failed").append('\t')
141                    .append("Resp. Size").append('\t').append("Avg. Resp. Size").append('\n');
142            sb.append("Total\t").append(toString(successfulQueries)).append('\t').append(toString(failedQueries))
143                    .append('\t').append(toString(responseSize)).append('\t').append(toString(averageResponseSize))
144                    .append('\n');
145            sb.append("UDP\t").append(toString(successfulUdpQueries)).append('\t').append(toString(failedUdpQueries))
146                    .append('\t').append(toString(udpResponseSize)).append('\t')
147                    .append(toString(averageUdpResponseSize)).append('\n');
148            sb.append("TCP\t").append(toString(successfulTcpQueries)).append('\t').append(toString(failedTcpQueries))
149                    .append('\t').append(toString(tcpResponseSize)).append('\t')
150                    .append(toString(averageTcpResponseSize)).append('\n');
151
152            stringCache = sb.toString();
153            return stringCache;
154        }
155
156        private static String toString(int i) {
157            return String.format(Locale.US, "%,09d", i);
158        }
159    }
160}