JwksStore.java
/*
* The contents of this file are subject to the terms of the Common Development and
* Distribution License (the License). You may not use this file except in compliance with the
* License.
*
* You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
* specific language governing permission and limitations under the License.
*
* When distributing Covered Software, include this CDDL Header Notice in each file and include
* the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
* Header, with the fields enclosed by brackets [] replaced by your own identifying
* information: "Portions copyright [year] [name of copyright owner]".
*
* Copyright 2014-2017 ForgeRock AS.
*/
package org.forgerock.json.jose.jwk.store;
import java.net.URL;
import java.util.concurrent.TimeUnit;
import org.forgerock.json.jose.exceptions.FailedToLoadJWKException;
import org.forgerock.json.jose.jwk.JWK;
import org.forgerock.json.jose.jwk.JWKSet;
import org.forgerock.json.jose.jwk.JWKSetParser;
import org.forgerock.json.jose.jwk.KeyUse;
import org.forgerock.json.jose.jwt.Algorithm;
import org.forgerock.util.Reject;
import org.forgerock.util.SimpleHTTPClient;
import org.forgerock.util.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Stores the JWKs downloaded from a JWKs URI in a {@link JWKSet} and refreshes that set when
 * necessary.
 *
 * <p>The set is loaded once when the store is created. It is loaded again when an encryption-key
 * lookup finds the cache timeout elapsed, when a lookup misses and the cache-miss time has
 * elapsed, or when the JWKs URI is changed.
 *
 * <p>Reloading and changing the JWKs URI are synchronized on this instance. The mutable fields
 * are {@code volatile} so that the unsynchronized lookups and getters observe the values written
 * by the latest reload or setter call.
 */
public class JwksStore {

    private static final Logger logger = LoggerFactory.getLogger(JwksStore.class);

    private final String uid;
    private final JWKSetParser jwkParser;

    /** To prevent attackers reloading the cache too often. */
    private volatile long cacheMissCacheTimeInMs;
    private volatile long cacheTimeoutInMs;
    private volatile URL jwkUrl;
    private volatile JWKSet jwksSet;
    /** Wall-clock time, in milliseconds, of the last successful reload. */
    private volatile long lastReloadJwksSet;

    /**
     * Create a new JWKs store.
     *
     * @param uid the unique identifier for this store
     * @param cacheTimeout a cache timeout to avoid reloading the cache all the time when doing encryption
     * @param cacheMissCacheTime the cache time before reloading the cache in case of a cache miss.
     *                           This avoids polling the client application too often.
     * @param jwkUrl the jwk url of the JWKs hosted by the client application
     * @param httpClient the http client through which we will attempt to read the jwkUrl
     * @throws FailedToLoadJWKException if the jwks can't be loaded.
     */
    JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
            final URL jwkUrl, final SimpleHTTPClient httpClient) throws FailedToLoadJWKException {
        this(uid, cacheTimeout, cacheMissCacheTime, jwkUrl, new JWKSetParser(httpClient));
    }

    /**
     * Create a new JWKs store and immediately load the JWK set from {@code jwkUrl}.
     *
     * @param uid the unique identifier for this store
     * @param cacheTimeout a cache timeout to avoid reloading the cache all the time when doing encryption
     * @param cacheMissCacheTime the cache time before reloading the cache in case of a cache miss.
     *                           This avoids polling the client application too often.
     * @param jwkUrl the jwk url of the JWKs hosted by the client application
     * @param jwkSetParser the jwks set parser
     * @throws FailedToLoadJWKException if the jwks can't be loaded.
     */
    JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
            final URL jwkUrl, JWKSetParser jwkSetParser) throws FailedToLoadJWKException {
        this.uid = uid;
        this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
        this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
        this.jwkUrl = jwkUrl;
        this.jwkParser = jwkSetParser;

        try {
            reloadJwks();
        } catch (FailedToLoadJWKException e) {
            // Keep the cause in the log so the underlying failure can be diagnosed.
            logger.debug("Unable to load keys from the JWK over HTTP", e);
            throw new FailedToLoadJWKException("Unable to load keys from the JWK over HTTP", e);
        }
    }

    /**
     * Communicates with the configured server, attempting to download the latest JWKs for use.
     * The reload timestamp is only updated when the download succeeds.
     *
     * @throws FailedToLoadJWKException if the JWK set could not be loaded from the configured URL
     */
    private synchronized void reloadJwks() throws FailedToLoadJWKException {
        jwksSet = jwkParser.jwkSet(jwkUrl);
        lastReloadJwksSet = System.currentTimeMillis();
    }

    /**
     * Search for a JWK that matches the algorithm and the key usage.
     *
     * <p>When the key usage is {@link KeyUse#ENC} and the cache timeout has elapsed, the JWK set is
     * reloaded before searching. If nothing is found and the cache-miss time has elapsed, the set
     * is reloaded and searched once more.
     *
     * @param algorithm the algorithm needed
     * @param keyUse the key usage. If null, only the algorithm will be used as a search criteria.
     * @return A jwk that matches the search criteria. If no JWK found for the key usage, then it searches for a JWK
     * without key usage defined. If still no JWK found, then returns null.
     * @throws FailedToLoadJWKException if the jwks can't be reloaded.
     */
    public JWK findJwk(Algorithm algorithm, KeyUse keyUse) throws FailedToLoadJWKException {
        if (keyUse == KeyUse.ENC && hasJwksCacheTimedOut()) {
            reloadJwks();
        }
        JWK jwk = jwksSet.findJwk(algorithm, keyUse);
        if (jwk == null && isCacheMissCacheTimeExpired()) {
            reloadJwks();
            return jwksSet.findJwk(algorithm, keyUse);
        }
        return jwk;
    }

    /**
     * Search for a JWK that matches the kid.
     *
     * <p>If nothing is found and the cache-miss time has elapsed, the JWK set is reloaded and
     * searched once more.
     *
     * @param kid Key ID
     * @return A jwk that matches the kid. If no JWK found, returns null
     * @throws FailedToLoadJWKException if the jwks can't be reloaded.
     */
    public JWK findJwk(String kid) throws FailedToLoadJWKException {
        JWK jwk = jwksSet.findJwk(kid);
        if (jwk == null && isCacheMissCacheTimeExpired()) {
            reloadJwks();
            return jwksSet.findJwk(kid);
        }
        return jwk;
    }

    /**
     * Get the UID.
     * @return the uid.
     */
    public String getUid() {
        return uid;
    }

    /**
     * Get the cache timeout.
     * @return the cache timeout.
     */
    public Duration getCacheTimeout() {
        return Duration.duration(cacheTimeoutInMs, TimeUnit.MILLISECONDS);
    }

    /**
     * Get the cache time before reloading the cache in case of a cache miss.
     * @return the cache miss cache time.
     */
    public Duration getCacheMissCacheTime() {
        return Duration.duration(cacheMissCacheTimeInMs, TimeUnit.MILLISECONDS);
    }

    /**
     * The JWKs URI.
     * @return the jwk uri.
     */
    public URL getJwkUrl() {
        return jwkUrl;
    }

    /**
     * Update the cache timeout.
     * @param cacheTimeout the cache timeout.
     */
    public void setCacheTimeout(Duration cacheTimeout) {
        this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
    }

    /**
     * Update the cache time before reloading the cache in case of a cache miss.
     * @param cacheMissCacheTime the cache miss cache time.
     */
    public void setCacheMissCacheTime(Duration cacheMissCacheTime) {
        this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
    }

    /**
     * Update the JWKs URI and, if it differs from the current one, reload the JWK set from it.
     *
     * <p>If the reload fails, the previous URI is restored so that the stored URI keeps matching
     * the cached JWK set, and a later call with the same new URI will attempt the reload again.
     *
     * @param jwkUrl the jwks uri. Must not be null.
     * @throws FailedToLoadJWKException If the URI has changed and the JWK set cannot be loaded.
     */
    public synchronized void setJwkUrl(URL jwkUrl) throws FailedToLoadJWKException {
        Reject.ifNull(jwkUrl);
        final URL originalJwkUrl = this.jwkUrl;
        // Compare the textual form: URL.equals resolves host names (blocking) and treats distinct
        // hosts that share an IP address as equal, which would skip a required reload.
        if (originalJwkUrl != null
                && jwkUrl.toExternalForm().equals(originalJwkUrl.toExternalForm())) {
            return;
        }
        this.jwkUrl = jwkUrl;
        try {
            reloadJwks();
        } catch (FailedToLoadJWKException e) {
            // Roll back so the store never pairs the new URI with keys loaded from the old one.
            this.jwkUrl = originalJwkUrl;
            throw e;
        }
    }

    /**
     * Check whether the cache timeout has elapsed since the last successful reload.
     * @return true if more than the cache timeout has passed since the last reload
     */
    private boolean hasJwksCacheTimedOut() {
        return (System.currentTimeMillis() - lastReloadJwksSet) > cacheTimeoutInMs;
    }

    /**
     * When we have a cache miss, we don't refresh the cache straight away. We first check whether
     * the cache miss cache time has expired.
     * @return true if we can reload the cache
     */
    private boolean isCacheMissCacheTimeExpired() {
        return (System.currentTimeMillis() - lastReloadJwksSet) >= cacheMissCacheTimeInMs;
    }
}