1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.forgerock.json.jose.jwk.store;
17
18
19 import java.net.URL;
20 import java.util.concurrent.TimeUnit;
21
22 import org.forgerock.json.jose.exceptions.FailedToLoadJWKException;
23 import org.forgerock.json.jose.jwk.JWK;
24 import org.forgerock.json.jose.jwk.JWKSet;
25 import org.forgerock.json.jose.jwk.JWKSetParser;
26 import org.forgerock.json.jose.jwk.KeyUse;
27 import org.forgerock.json.jose.jwt.Algorithm;
28 import org.forgerock.util.Reject;
29 import org.forgerock.util.SimpleHTTPClient;
30 import org.forgerock.util.time.Duration;
31 import org.slf4j.Logger;
32 import org.slf4j.LoggerFactory;
33
34
35 public class JwksStore {
36 private static final Logger logger = LoggerFactory.getLogger(JwksStore.class);
37
38 private final String uid;
39 private final JWKSetParser jwkParser;
40
41
42 private long cacheMissCacheTimeInMs;
43 private long cacheTimeoutInMs;
44 private URL jwkUrl;
45
46 private JWKSet jwksSet;
47 private long lastReloadJwksSet;
48
49
50
51
52
53
54
55
56
57
58
59
60 JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
61 final URL jwkUrl, final SimpleHTTPClient httpClient) throws FailedToLoadJWKException {
62 this(uid, cacheTimeout, cacheMissCacheTime, jwkUrl, new JWKSetParser(httpClient));
63 }
64
65
66
67
68
69
70
71
72
73
74
75
76 JwksStore(final String uid, final Duration cacheTimeout, final Duration cacheMissCacheTime,
77 final URL jwkUrl, JWKSetParser jwkSetParser) throws FailedToLoadJWKException {
78 this.uid = uid;
79 this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
80 this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
81 this.jwkUrl = jwkUrl;
82 this.jwkParser = jwkSetParser;
83
84 try {
85 reloadJwks();
86 } catch (FailedToLoadJWKException e) {
87 logger.debug("Unable to load keys from the JWK over HTTP");
88 throw new FailedToLoadJWKException("Unable to load keys from the JWK over HTTP", e);
89 }
90 }
91
92
93
94
95
96
97 private synchronized void reloadJwks() throws FailedToLoadJWKException {
98 jwksSet = jwkParser.jwkSet(jwkUrl);
99 lastReloadJwksSet = System.currentTimeMillis();
100 }
101
102
103
104
105
106
107
108
109
110
111 public JWK findJwk(Algorithm algorithm, KeyUse keyUse) throws FailedToLoadJWKException {
112 if (keyUse == KeyUse.ENC && hasJwksCacheTimedOut()) {
113 reloadJwks();
114 }
115
116 JWK jwk = jwksSet.findJwk(algorithm, keyUse);
117 if (jwk == null && isCacheMissCacheTimeExpired()) {
118 reloadJwks();
119 return jwksSet.findJwk(algorithm, keyUse);
120 }
121 return jwk;
122 }
123
124
125
126
127
128
129
130
131 public JWK findJwk(String kid) throws FailedToLoadJWKException {
132 JWK jwk = jwksSet.findJwk(kid);
133 if (jwk == null && isCacheMissCacheTimeExpired()) {
134 reloadJwks();
135 return jwksSet.findJwk(kid);
136 }
137 return jwk;
138 }
139
140
141
142
143
144 public String getUid() {
145 return uid;
146 }
147
148
149
150
151
152 public Duration getCacheTimeout() {
153 return Duration.duration(cacheTimeoutInMs, TimeUnit.MILLISECONDS);
154 }
155
156
157
158
159
160 public Duration getCacheMissCacheTime() {
161 return Duration.duration(cacheMissCacheTimeInMs, TimeUnit.MILLISECONDS);
162 }
163
164
165
166
167
168 public URL getJwkUrl() {
169 return jwkUrl;
170 }
171
172
173
174
175
176 public void setCacheTimeout(Duration cacheTimeout) {
177 this.cacheTimeoutInMs = cacheTimeout.to(TimeUnit.MILLISECONDS);
178 }
179
180
181
182
183
184 public void setCacheMissCacheTime(Duration cacheMissCacheTime) {
185 this.cacheMissCacheTimeInMs = cacheMissCacheTime.to(TimeUnit.MILLISECONDS);
186 }
187
188
189
190
191
192
193 public void setJwkUrl(URL jwkUrl) throws FailedToLoadJWKException {
194 Reject.ifNull(jwkUrl);
195 URL originalJwkUrl = this.jwkUrl;
196 this.jwkUrl = jwkUrl;
197 if (!jwkUrl.equals(originalJwkUrl)) {
198 reloadJwks();
199 }
200 }
201
202 private boolean hasJwksCacheTimedOut() {
203 return (System.currentTimeMillis() - lastReloadJwksSet) > cacheTimeoutInMs;
204 }
205
206
207
208
209
210
211 private boolean isCacheMissCacheTimeExpired() {
212 return (System.currentTimeMillis() - lastReloadJwksSet) >= cacheMissCacheTimeInMs;
213 }
214 }