View Javadoc
1   /*
2   * The contents of this file are subject to the terms of the Common Development and
3   * Distribution License (the License). You may not use this file except in compliance with the
4   * License.
5   *
6   * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
7   * specific language governing permission and limitations under the License.
8   *
9   * When distributing Covered Software, include this CDDL Header Notice in each file and include
10  * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
11  * Header, with the fields enclosed by brackets [] replaced by your own identifying
12  * information: "Portions copyright [year] [name of copyright owner]".
13  *
14  * Copyright 2014-2017 ForgeRock AS.
15  */
16  package org.forgerock.json.jose.jwk.store;
17  
18  
19  import java.net.URL;
20  import java.util.concurrent.TimeUnit;
21  
22  import org.forgerock.json.jose.exceptions.FailedToLoadJWKException;
23  import org.forgerock.json.jose.jwk.JWK;
24  import org.forgerock.json.jose.jwk.JWKSet;
25  import org.forgerock.json.jose.jwk.JWKSetParser;
26  import org.forgerock.json.jose.jwk.KeyUse;
27  import org.forgerock.json.jose.jwt.Algorithm;
28  import org.forgerock.util.Reject;
29  import org.forgerock.util.SimpleHTTPClient;
30  import org.forgerock.util.time.Duration;
31  import org.slf4j.Logger;
32  import org.slf4j.LoggerFactory;
33  
34  /** Store JWKs into a jwkSet from a JWKs_URI and refresh the jwkSet when necessary. */
35  public class JwksStore {
36      private static final Logger logger = LoggerFactory.getLogger(JwksStore.class);
37  
38      private final String uid;
39      private final JWKSetParser jwkParser;
40  
41      /** To prevent attackers reloading the cache too often. */
42      private long cacheMissCacheTimeInMs;
43      private long cacheTimeoutInMs;
44      private URL jwkUrl;
45  
46      private JWKSet jwksSet;
47      private long lastReloadJwksSet;
48  
49      /**
50       * Create a new JWKs store.
51       *
52       * @param uid the unique identifier for this store
53       * @param cacheTimeout a cache timeout to avoid reloading the cache all the time when doing encryption
54       * @param cacheMissCacheTime the cache time before reload the cache in case of a cache miss.
55       *                           This avoid polling the client application too often.
56       * @param jwkUrl the jwk url of the JWKs hosted by the client application
57       * @param httpClient The http client through which we will attempt to read the jwkUrl
58       * @throws FailedToLoadJWKException if the jwks can't be reloaded.
59       */
60      JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
61              final URL jwkUrl, final SimpleHTTPClient httpClient) throws FailedToLoadJWKException {
62          this(uid, cacheTimeout, cacheMissCacheTime, jwkUrl, new JWKSetParser(httpClient));
63      }
64  
65      /**
66       * Create a new JWKs store.
67       *
68       * @param uid the unique identifier for this store
69       * @param cacheTimeout a cache timeout to avoid reloading the cache all the time when doing encryption
70       * @param cacheMissCacheTime the cache time before reload the cache in case of a cache miss.
71       *                           This avoid polling the client application too often.
72       * @param jwkUrl the jwk url  of the JWKs hosted by the client application
73       * @param jwkSetParser the jwks set parser
74       * @throws FailedToLoadJWKException if the jwks can't be reloaded.
75       */
76      JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
77              final URL jwkUrl, JWKSetParser jwkSetParser) throws FailedToLoadJWKException {
78          this.uid = uid;
79          this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
80          this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
81          this.jwkUrl = jwkUrl;
82          this.jwkParser = jwkSetParser;
83  
84          try {
85              reloadJwks();
86          } catch (FailedToLoadJWKException e) {
87              logger.debug("Unable to load keys from the JWK over HTTP");
88              throw new FailedToLoadJWKException("Unable to load keys from the JWK over HTTP", e);
89          }
90      }
91  
92      /**
93       * Communicates with the configured server, attempting to download the latest JWKs for use.
94       *
95       * @throws FailedToLoadJWKException if there were issues parsing the supplied URL
96       */
97      private synchronized void reloadJwks() throws FailedToLoadJWKException {
98          jwksSet = jwkParser.jwkSet(jwkUrl);
99          lastReloadJwksSet = System.currentTimeMillis();
100     }
101 
102     /**
103      * Search for a JWK that matches the algorithm and the key usage.
104      *
105      * @param algorithm the algorithm needed
106      * @param keyUse the key usage. If null, only the algorithm will be used as a search criteria.
107      * @return A jwk that matches the search criteria. If no JWK found for the key usage, then it searches for a JWK
108      * without key usage defined. If still no JWK found, then returns null.
109      * @throws FailedToLoadJWKException if the jwks can't be reloaded.
110      */
111     public JWK findJwk(Algorithm algorithm, KeyUse keyUse) throws FailedToLoadJWKException {
112         if (keyUse == KeyUse.ENC && hasJwksCacheTimedOut()) {
113             reloadJwks();
114         }
115 
116         JWK jwk = jwksSet.findJwk(algorithm, keyUse);
117         if (jwk == null && isCacheMissCacheTimeExpired()) {
118             reloadJwks();
119             return jwksSet.findJwk(algorithm, keyUse);
120         }
121         return jwk;
122     }
123 
124     /**
125      * Search for a JWK that matches the kid.
126      *
127      * @param kid Key ID
128      * @return A jwk that matches the kid. If no JWK found, returns null
129      * @throws FailedToLoadJWKException if the jwks can't be reloaded.
130      */
131     public JWK findJwk(String kid) throws FailedToLoadJWKException {
132         JWK jwk = jwksSet.findJwk(kid);
133         if (jwk == null && isCacheMissCacheTimeExpired()) {
134             reloadJwks();
135             return jwksSet.findJwk(kid);
136         }
137         return jwk;
138     }
139 
140     /**
141      * Get the UID.
142      * @return the uid.
143      */
144     public String getUid() {
145         return uid;
146     }
147 
148     /**
149      * Get the cache timeout.
150      * @return the cache timeout.
151      */
152     public Duration getCacheTimeout() {
153         return Duration.duration(cacheTimeoutInMs, TimeUnit.MILLISECONDS);
154     }
155 
156     /**
157      * Get the cache time before reload the cache in case of cache miss.
158      * @return the cache miss cache time.
159      */
160     public Duration getCacheMissCacheTime() {
161         return Duration.duration(cacheMissCacheTimeInMs, TimeUnit.MILLISECONDS);
162     }
163 
164     /**
165      * The JWKs URI.
166      * @return the jwk uri.
167      */
168     public URL getJwkUrl() {
169         return jwkUrl;
170     }
171 
172     /**
173      * Update the cache timeout.
174      * @param cacheTimeout the cache timeout.
175      */
176     public void setCacheTimeout(Duration cacheTimeout) {
177         this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
178     }
179 
180     /**
181      * Update the cache time before reload the cache in case of cache miss.
182      * @param cacheMissCacheTime the cache miss cache time.
183      */
184     public void setCacheMissCacheTime(Duration cacheMissCacheTime) {
185         this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
186     }
187 
188     /**
189      * Update the JWKs URI.
190      * @param jwkUrl the jwks uri.
191      * @throws FailedToLoadJWKException If the URI has changed and the JWK set cannot be loaded.
192      */
193     public void setJwkUrl(URL jwkUrl) throws FailedToLoadJWKException {
194         Reject.ifNull(jwkUrl);
195         URL originalJwkUrl = this.jwkUrl;
196         this.jwkUrl = jwkUrl;
197         if (!jwkUrl.equals(originalJwkUrl)) {
198             reloadJwks();
199         }
200     }
201 
202     private boolean hasJwksCacheTimedOut() {
203         return (System.currentTimeMillis() - lastReloadJwksSet) > cacheTimeoutInMs;
204     }
205 
206     /**
207      * When we have a cache miss, we don't refresh the cache straight away. We check first if the cache miss cache
208      * time is expired out or not
209      * @return true is we  can reload the cache
210      */
211     private boolean isCacheMissCacheTimeExpired() {
212         return (System.currentTimeMillis() - lastReloadJwksSet) >= cacheMissCacheTimeInMs;
213     }
214 }