001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2015-2017 ForgeRock AS.
015 */
016
017package org.forgerock.json.jose.tokenhandler;
018
019import static org.forgerock.json.JsonValue.json;
020
021import java.io.IOException;
022import java.security.KeyPair;
023import java.util.Date;
024import java.util.Map;
025
026import org.wrensecurity.guava.common.base.Optional;
027import org.forgerock.json.JsonValue;
028import org.forgerock.json.jose.builders.JwtBuilderFactory;
029import org.forgerock.json.jose.builders.JwtClaimsSetBuilder;
030import org.forgerock.json.jose.jwe.EncryptionMethod;
031import org.forgerock.json.jose.jwe.JweAlgorithm;
032import org.forgerock.json.jose.jws.JwsAlgorithm;
033import org.forgerock.json.jose.jws.SignedEncryptedJwt;
034import org.forgerock.json.jose.jws.handlers.SigningHandler;
035import org.forgerock.json.jose.jwt.JwtClaimsSet;
036import org.forgerock.tokenhandler.ExpiredTokenException;
037import org.forgerock.tokenhandler.InvalidTokenException;
038import org.forgerock.tokenhandler.TokenHandler;
039import org.forgerock.tokenhandler.TokenHandlerException;
040import org.forgerock.util.Reject;
041
042import com.fasterxml.jackson.databind.ObjectMapper;
043
044/**
045 * Token handler for creating tokens using a JWT as the store.
046 */
047public final class JwtTokenHandler implements TokenHandler {
048
049    private static final ObjectMapper MAPPER = new ObjectMapper();
050
051    private final JwtBuilderFactory jwtBuilderFactory;
052    private final JweAlgorithm jweAlgorithm;
053    private final EncryptionMethod jweMethod;
054    private final KeyPair jweKeyPair;
055    private final JwsAlgorithm jwsAlgorithm;
056    private final SigningHandler jwsHandler;
057    private final Optional<Long> tokenLifeTimeInSeconds;
058
059    /**
060     * Constructs a new JWT token handler that never expires.
061     *
062     * @param jweAlgorithm
063     *         the JWE algorithm use to construct the key pair
064     * @param jweMethod
065     *         the encryption method to use
066     * @param jweKeyPair
067     *         key pair for the purpose of encryption
068     * @param jwsAlgorithm
069     *         the JWS algorithm to use
070     * @param jwsHandler
071     *         the signing handler
072     */
073    public JwtTokenHandler(JweAlgorithm jweAlgorithm, EncryptionMethod jweMethod, KeyPair jweKeyPair,
074            JwsAlgorithm jwsAlgorithm, SigningHandler jwsHandler) {
075        this(jweAlgorithm, jweMethod, jweKeyPair, jwsAlgorithm, jwsHandler, Optional.<Long>absent());
076    }
077
078    /**
079     * Constructs a new JWT token handler.
080     *
081     * @param jweAlgorithm
082     *         the JWE algorithm use to construct the key pair
083     * @param jweMethod
084     *         the encryption method to use
085     * @param jweKeyPair
086     *         key pair for the purpose of encryption
087     * @param jwsAlgorithm
088     *         the JWS algorithm to use
089     * @param jwsHandler
090     *         the signing handler
091     * @param tokenLifeTimeInSeconds
092     *         token life time in seconds
093     */
094    public JwtTokenHandler(JweAlgorithm jweAlgorithm, EncryptionMethod jweMethod, KeyPair jweKeyPair,
095            JwsAlgorithm jwsAlgorithm, SigningHandler jwsHandler, Optional<Long> tokenLifeTimeInSeconds) {
096        Reject.ifNull(jweAlgorithm, jweMethod, jweKeyPair, jwsAlgorithm, jwsHandler);
097        Reject.ifTrue(tokenLifeTimeInSeconds.isPresent() && tokenLifeTimeInSeconds.get() <= 0);
098        jwtBuilderFactory = new JwtBuilderFactory();
099        this.jweAlgorithm = jweAlgorithm;
100        this.jweMethod = jweMethod;
101        this.jweKeyPair = jweKeyPair;
102        this.jwsAlgorithm = jwsAlgorithm;
103        this.jwsHandler = jwsHandler;
104        this.tokenLifeTimeInSeconds = tokenLifeTimeInSeconds;
105    }
106
107    @Override
108    public String generate(JsonValue state) throws TokenHandlerException {
109        Reject.ifNull(state);
110
111        try {
112            JwtClaimsSetBuilder claimsSetBuilder = jwtBuilderFactory
113                    .claims()
114                    .claim("state", MAPPER.writeValueAsString(state.getObject()));
115
116            final JwtClaimsSet claimsSet;
117            if (tokenLifeTimeInSeconds.isPresent()) {
118                claimsSet = claimsSetBuilder
119                        .exp(new Date(System.currentTimeMillis() + (tokenLifeTimeInSeconds.get() * 1000L)))
120                        .build();
121            } else {
122                claimsSet = claimsSetBuilder.build();
123            }
124
125            return jwtBuilderFactory
126                    .jwe(jweKeyPair.getPublic())
127                    .headers()
128                        .alg(jweAlgorithm)
129                        .enc(jweMethod)
130                        .done()
131                    .claims(claimsSet)
132                    .sign(jwsHandler, jwsAlgorithm)
133                    .build();
134        } catch (IOException e) {
135            throw new TokenHandlerException("Error serializing token state", e);
136        } catch (RuntimeException e) {
137            throw new TokenHandlerException("Error constructing token", e);
138        }
139    }
140
141    @Override
142    public void validate(String snapshotToken) throws TokenHandlerException {
143        validateAndExtractClaims(snapshotToken);
144    }
145
146    @Override
147    public JsonValue validateAndExtractState(String snapshotToken) throws TokenHandlerException {
148        Reject.ifNull(snapshotToken);
149
150        try {
151            JwtClaimsSet claimsSet = validateAndExtractClaims(snapshotToken);
152            return json(MAPPER.readValue(claimsSet.getClaim("state").toString(), Map.class));
153        } catch (IOException e) {
154            throw new InvalidTokenException("Failed to parse token state as JSON", e);
155        }
156    }
157
158    private JwtClaimsSet validateAndExtractClaims(String snapshotToken) throws TokenHandlerException {
159        try {
160            SignedEncryptedJwt signedEncryptedJwt = jwtBuilderFactory
161                    .reconstruct(snapshotToken, SignedEncryptedJwt.class);
162
163            if (!signedEncryptedJwt.verify(jwsHandler)) {
164                throw new InvalidTokenException("Invalid token");
165            }
166
167            signedEncryptedJwt.decrypt(jweKeyPair.getPrivate());
168
169            JwtClaimsSet claimsSet = signedEncryptedJwt.getClaimsSet();
170            Date expirationTime = claimsSet.getExpirationTime();
171
172            if (expirationTime != null && expirationTime.before(new Date())) {
173                throw new ExpiredTokenException("Token has expired");
174            }
175
176            return claimsSet;
177        } catch (RuntimeException e) {
178            throw new InvalidTokenException("Invalid token", e);
179        }
180    }
181
182}