001/*
002 * Copyright 2015-2018 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.cache;
012
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
018
019import org.minidns.dnsmessage.DnsMessage;
020import org.minidns.dnsmessage.Question;
021import org.minidns.dnsname.DnsName;
022import org.minidns.record.Data;
023import org.minidns.record.Record;
024
025/**
026 * A variant of {@link LruCache} also using the data found in the sections for caching.
027 */
028public class ExtendedLruCache extends LruCache {
029
030    public ExtendedLruCache() {
031        this(DEFAULT_CACHE_SIZE);
032    }
033
034    public ExtendedLruCache(int capacity) {
035        super(capacity);
036    }
037
038    public ExtendedLruCache(int capacity, long maxTTL) {
039        super(capacity, maxTTL);
040    }
041
042    @Override
043    protected void putNormalized(DnsMessage q, DnsMessage message) {
044        super.putNormalized(q, message);
045        Map<DnsMessage, List<Record<? extends Data>>> extraCaches = new HashMap<>(message.additionalSection.size());
046
047        gather(extraCaches, q, message.answerSection, null);
048        gather(extraCaches, q, message.authoritySection, null);
049        gather(extraCaches, q, message.additionalSection, null);
050
051        putExtraCaches(message, extraCaches);
052    }
053
054    @Override
055    public void offer(DnsMessage query, DnsMessage reply, DnsName authoritativeZone) {
056        // The reply shouldn't be an authoritative answers when offer() is used. That would be a case for put().
057        assert(!reply.authoritativeAnswer);
058
059        Map<DnsMessage, List<Record<? extends Data>>> extraCaches = new HashMap<>(reply.additionalSection.size());
060
061        // N.B. not gathering from reply.answerSection here. Since it is a non authoritativeAnswer it shouldn't contain anything.
062        gather(extraCaches, query, reply.authoritySection, authoritativeZone);
063        gather(extraCaches, query, reply.additionalSection, authoritativeZone);
064
065        putExtraCaches(reply, extraCaches);
066    }
067
068    private final void gather(Map<DnsMessage, List<Record<?extends Data>>> extraCaches, DnsMessage q, List<Record<? extends Data>> records, DnsName authoritativeZone) {
069        for (Record<? extends Data> extraRecord : records) {
070            if (!shouldGather(extraRecord, q.getQuestion(), authoritativeZone))
071                continue;
072
073            DnsMessage.Builder additionalRecordQuestionBuilder = extraRecord.getQuestionMessage();
074            if (additionalRecordQuestionBuilder == null)
075                continue;
076
077            additionalRecordQuestionBuilder.copyFlagsFrom(q);
078
079            additionalRecordQuestionBuilder.setAdditionalResourceRecords(q.additionalSection);
080
081            DnsMessage additionalRecordQuestion = additionalRecordQuestionBuilder.build();
082            if (additionalRecordQuestion.equals(q)) {
083                // No need to cache the additional question if it is the same as the original question.
084                continue;
085            }
086
087            List<Record<? extends Data>> additionalRecords = extraCaches.get(additionalRecordQuestion);
088            if (additionalRecords == null) {
089                 additionalRecords = new LinkedList<>();
090                 extraCaches.put(additionalRecordQuestion, additionalRecords);
091            }
092            additionalRecords.add(extraRecord);
093        }
094    }
095
096    private final void putExtraCaches(DnsMessage reply, Map<DnsMessage, List<Record<? extends Data>>> extraCaches) {
097        for (Entry<DnsMessage, List<Record<? extends Data>>> entry : extraCaches.entrySet()) {
098            DnsMessage question = entry.getKey();
099            DnsMessage answer = reply.asBuilder()
100                    .setQuestion(question.getQuestion())
101                    .setAuthoritativeAnswer(true)
102                    .addAnswers(entry.getValue())
103                    .build();
104            super.putNormalized(question, answer);
105        }
106    }
107
108    protected boolean shouldGather(Record<? extends Data> extraRecord, Question question, DnsName authoritativeZone) {
109        boolean extraRecordIsChildOfQuestion = extraRecord.name.isChildOf(question.name);
110
111        boolean extraRecordIsChildOfAuthoritativeZone = false;
112        if (authoritativeZone != null) {
113            extraRecordIsChildOfAuthoritativeZone = extraRecord.name.isChildOf(authoritativeZone);
114        }
115
116        return extraRecordIsChildOfQuestion || extraRecordIsChildOfAuthoritativeZone;
117    }
118
119}