001/*
002 * Copyright 2015-2024 the original author or authors
003 *
004 * This software is licensed under the Apache License, Version 2.0,
005 * the GNU Lesser General Public License version 2 or later ("LGPL")
006 * and the WTFPL.
007 * You may choose either license to govern your use of this software only
008 * upon the condition that you accept all of the terms of either
009 * the Apache License 2.0, the LGPL 2.1+ or the WTFPL.
010 */
011package org.minidns.cache;
012
013import java.util.ArrayList;
014import java.util.HashMap;
015import java.util.List;
016import java.util.Map;
017import java.util.Map.Entry;
018
019import org.minidns.dnsmessage.DnsMessage;
020import org.minidns.dnsmessage.Question;
021import org.minidns.dnsname.DnsName;
022import org.minidns.dnsqueryresult.CachedDnsQueryResult;
023import org.minidns.dnsqueryresult.DnsQueryResult;
024import org.minidns.dnsqueryresult.SynthesizedCachedDnsQueryResult;
025import org.minidns.record.Data;
026import org.minidns.record.Record;
027
028/**
029 * A variant of {@link LruCache} also using the data found in the sections for caching.
030 */
031public class ExtendedLruCache extends LruCache {
032
033    public ExtendedLruCache() {
034        this(DEFAULT_CACHE_SIZE);
035    }
036
037    public ExtendedLruCache(int capacity) {
038        super(capacity);
039    }
040
041    public ExtendedLruCache(int capacity, long maxTTL) {
042        super(capacity, maxTTL);
043    }
044
045    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
046    @Override
047    protected void putNormalized(DnsMessage q, DnsQueryResult result) {
048        super.putNormalized(q, result);
049        DnsMessage message = result.response;
050        Map<DnsMessage, List<Record<? extends Data>>> extraCaches = new HashMap<>(message.additionalSection.size());
051
052        gather(extraCaches, q, message.answerSection, null);
053        gather(extraCaches, q, message.authoritySection, null);
054        gather(extraCaches, q, message.additionalSection, null);
055
056        putExtraCaches(result, extraCaches);
057    }
058
059    @Override
060    public void offer(DnsMessage query, DnsQueryResult result, DnsName authoritativeZone) {
061        DnsMessage reply = result.response;
062        // The reply shouldn't be an authoritative answers when offer() is used. That would be a case for put().
063        assert !reply.authoritativeAnswer;
064
065        Map<DnsMessage, List<Record<? extends Data>>> extraCaches = new HashMap<>(reply.additionalSection.size());
066
067        // N.B. not gathering from reply.answerSection here. Since it is a non authoritativeAnswer it shouldn't contain anything.
068        gather(extraCaches, query, reply.authoritySection, authoritativeZone);
069        gather(extraCaches, query, reply.additionalSection, authoritativeZone);
070
071        putExtraCaches(result, extraCaches);
072    }
073
074    private void gather(Map<DnsMessage, List<Record<?extends Data>>> extraCaches, DnsMessage q, List<Record<? extends Data>> records, DnsName authoritativeZone) {
075        for (Record<? extends Data> extraRecord : records) {
076            if (!shouldGather(extraRecord, q.getQuestion(), authoritativeZone))
077                continue;
078
079            DnsMessage.Builder additionalRecordQuestionBuilder = extraRecord.getQuestionMessage();
080            if (additionalRecordQuestionBuilder == null)
081                continue;
082
083            additionalRecordQuestionBuilder.copyFlagsFrom(q);
084
085            additionalRecordQuestionBuilder.setAdditionalResourceRecords(q.additionalSection);
086
087            DnsMessage additionalRecordQuestion = additionalRecordQuestionBuilder.build();
088            if (additionalRecordQuestion.equals(q)) {
089                // No need to cache the additional question if it is the same as the original question.
090                continue;
091            }
092
093            List<Record<? extends Data>> additionalRecords = extraCaches.get(additionalRecordQuestion);
094            if (additionalRecords == null) {
095                 additionalRecords = new ArrayList<>();
096                 extraCaches.put(additionalRecordQuestion, additionalRecords);
097            }
098            additionalRecords.add(extraRecord);
099        }
100    }
101
102    private void putExtraCaches(DnsQueryResult synthesynthesizationSource, Map<DnsMessage, List<Record<? extends Data>>> extraCaches) {
103        DnsMessage reply = synthesynthesizationSource.response;
104        for (Entry<DnsMessage, List<Record<? extends Data>>> entry : extraCaches.entrySet()) {
105            DnsMessage question = entry.getKey();
106            DnsMessage answer = reply.asBuilder()
107                    .setQuestion(question.getQuestion())
108                    .setAuthoritativeAnswer(true)
109                    .addAnswers(entry.getValue())
110                    .build();
111            CachedDnsQueryResult cachedDnsQueryResult = new SynthesizedCachedDnsQueryResult(question, answer, synthesynthesizationSource);
112            synchronized (this) {
113                backend.put(question, cachedDnsQueryResult);
114            }
115        }
116    }
117
118    protected boolean shouldGather(Record<? extends Data> extraRecord, Question question, DnsName authoritativeZone) {
119        boolean extraRecordIsChildOfQuestion = extraRecord.name.isChildOf(question.name);
120
121        boolean extraRecordIsChildOfAuthoritativeZone = false;
122        if (authoritativeZone != null) {
123            extraRecordIsChildOfAuthoritativeZone = extraRecord.name.isChildOf(authoritativeZone);
124        }
125
126        return extraRecordIsChildOfQuestion || extraRecordIsChildOfAuthoritativeZone;
127    }
128
129}