001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.util;
012
013import java.util.ArrayList;
014import java.util.Collection;
015import java.util.Collections;
016import java.util.LinkedList;
017import java.util.List;
018import java.util.SortedMap;
019import java.util.TreeMap;
020
021import org.minidns.dnsname.DnsName;
022import org.minidns.record.SRV;
023
024public class SrvUtil {
025
026    /**
027     * Sort the given collection of {@link SRV} resource records by their priority and weight.
028     * <p>
029     * Sorting by priority is easy. Sorting the buckets of SRV records with the same priority by weight requires to choose those records
030     * randomly but taking the weight into account.
031     * </p>
032     *
033     * @param srvRecords
034     *            a collection of SRV records.
035     * @return a sorted list of the given records.
036     */
037    @SuppressWarnings({"MixedMutabilityReturnType", "JdkObsolete"})
038    public static List<SRV> sortSrvRecords(Collection<SRV> srvRecords) {
039        // RFC 2782, Usage rules: "If there is precisely one SRV RR, and its Target is "."
040        // (the root domain), abort."
041        if (srvRecords.size() == 1 && srvRecords.iterator().next().target.equals(DnsName.ROOT)) {
042            return Collections.emptyList();
043        }
044
045        // Create the priority buckets.
046        SortedMap<Integer, List<SRV>> buckets = new TreeMap<>();
047        for (SRV srvRecord : srvRecords) {
048            Integer priority = srvRecord.priority;
049            List<SRV> bucket = buckets.get(priority);
050            if (bucket == null) {
051                bucket = new LinkedList<>();
052                buckets.put(priority, bucket);
053            }
054            bucket.add(srvRecord);
055        }
056
057        List<SRV> sortedSrvRecords = new ArrayList<>(srvRecords.size());
058
059        for (List<SRV> bucket : buckets.values()) {
060            // The list of buckets will be sorted by priority, thanks to SortedMap. We now have determine the order of
061            // the SRV records with the same priority, i.e., within the same bucket, by their weight. This is done by
062            // creating an array 'totals' which reflects the percentage of the SRV RRs weight by the total weight of all
063            // SRV RRs in the bucket. For every entry in the bucket, we choose one using a random number and the sum of
064            // all weights left in the bucket. We then select RRs position based on the according index of the selected
065            // value in the 'total' array. This ensures that its weight is taken into account.
066            int bucketSize;
067            while ((bucketSize = bucket.size()) > 0) {
068                int[] totals = new int[bucketSize];
069
070                int zeroWeight = 1;
071                for (SRV srv : bucket) {
072                    if (srv.weight > 0) {
073                        zeroWeight = 0;
074                        break;
075                    }
076                }
077
078                int bucketWeightSum = 0, count = 0;
079                for (SRV srv : bucket) {
080                    bucketWeightSum += srv.weight + zeroWeight;
081                    totals[count++] = bucketWeightSum;
082                }
083
084                int selectedPosition;
085                if (bucketWeightSum == 0) {
086                    // If total priority is 0, then the sum of all weights in this priority bucket is 0. So we simply
087                    // select one of the weights randomly as the other algorithm performed in the else block is unable
088                    // to handle this case.
089                    selectedPosition = (int) (Math.random() * bucketSize);
090                } else {
091                    double rnd = Math.random() * bucketWeightSum;
092                    selectedPosition = bisect(totals, rnd);
093                }
094
095                SRV choosenSrvRecord = bucket.remove(selectedPosition);
096                sortedSrvRecords.add(choosenSrvRecord);
097            }
098        }
099
100        return sortedSrvRecords;
101    }
102
103    // TODO This is not yet really bisection just a stupid linear search.
104    private static int bisect(int[] array, double value) {
105        int pos = 0;
106        for (int element : array) {
107            if (value < element)
108                break;
109            pos++;
110        }
111        return pos;
112    }
113
114}