001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions Copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2016 ForgeRock AS.
015 */
016package org.forgerock.opendj.ldap;
017
018import static java.nio.charset.StandardCharsets.UTF_8;
019
020import java.security.MessageDigest;
021import java.security.NoSuchAlgorithmException;
022import java.util.Collection;
023import java.util.IdentityHashMap;
024import java.util.LinkedHashMap;
025import java.util.Map;
026import java.util.NavigableMap;
027import java.util.TreeMap;
028import java.util.concurrent.locks.ReentrantLock;
029
030import org.forgerock.util.Function;
031import org.forgerock.util.Reject;
032import org.forgerock.util.annotations.VisibleForTesting;
033import org.forgerock.util.promise.NeverThrowsException;
034
035/**
036 * An implementation of "consistent hashing" supporting per-partition weighting. This implementation is thread safe
037 * and allows partitions to be added and removed during use.
038 * <p>
039 * This implementation maps partitions to one or more points on a circle ranging from {@link Integer#MIN_VALUE} to
040 * {@link Integer#MAX_VALUE}. The number of points per partition is dictated by the partition's weight. A partition
041 * with a weight which is higher than another partition will receive a proportionally higher load.
042 *
043 * @param <P> The type of partition object.
044 *
045 * @see <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.23.3738">Consistent Hashing and Random
046 * Trees</a>
047 * @see <a href="http://www8.org/w8-papers/2a-webserver/caching/paper2.html">Web Caching with Consistent Hashing</a>
048 */
049public final class ConsistentHashMap<P> {
050    // TODO: add methods for determining which partitions will need to be rebalanced when a partition is added or
051    // removed.
052    /** The default weight. The value is relatively high in order to minimize the risk of imbalances. */
053    private static final int DEFAULT_WEIGHT = 200;
054    /** Default hash function based on MD5. */
055    @VisibleForTesting
056    static final Function<Object, Integer, NeverThrowsException> MD5 =
057            new Function<Object, Integer, NeverThrowsException>() {
058                @Override
059                public Integer apply(final Object key) {
060                    final byte[] bytes = key.toString().getBytes(UTF_8);
061                    final byte[] digest = getMD5Digest().digest(bytes);
062                    return ByteString.wrap(digest).toInt();
063                }
064
065                private MessageDigest getMD5Digest() {
066                    // TODO: we may want to cache these.
067                    try {
068                        return MessageDigest.getInstance("MD5");
069                    } catch (NoSuchAlgorithmException e) {
070                        throw new RuntimeException(e);
071                    }
072                }
073            };
074    /** Synchronizes updates. Reads are protected by copy on write. */
075    private final ReentrantLock writeLock = new ReentrantLock();
076    /** Consistent hash map circle. */
077    private volatile NavigableMap<Integer, Node<P>> circle = new TreeMap<>();
078    /** Maps partition IDs to their partition. */
079    private volatile Map<String, P> partitions = new LinkedHashMap<>();
080    /** Function used for hashing keys. */
081    private final Function<Object, Integer, NeverThrowsException> hashFunction;
082
083    /** Creates a new consistent hash map which will hash keys using MD5. */
084    public ConsistentHashMap() {
085        this(MD5);
086    }
087
088    /**
089     * Creates a new consistent hash map which will hash keys using the provided hash function.
090     *
091     * @param hashFunction
092     *         The function which should be used for hashing keys.
093     */
094    public ConsistentHashMap(final Function<Object, Integer, NeverThrowsException> hashFunction) {
095        this.hashFunction = hashFunction;
096    }
097
098    /**
099     * Puts a partition into this consistent hash map using the default weight which is sufficiently high to ensure a
100     * reasonably uniform distribution among all partitions having the same weight.
101     *
102     * @param partitionId
103     *         The partition ID.
104     * @param partition
105     *         The partition.
106     * @return This consistent hash map.
107     */
108    public ConsistentHashMap<P> put(final String partitionId, final P partition) {
109        return put(partitionId, partition, DEFAULT_WEIGHT);
110    }
111
112    /**
113     * Puts a partition into this consistent hash map using the specified weight. If all partitions have the same
114     * weight then they will each receive a similar amount of load. A partition having a weight which is twice that
115     * of another will receive twice the load. Weight values should generally be great than 200 in order to minimize
116     * the risk of unexpected imbalances due to the way in which logical partitions are mapped to real partitions.
117     *
118     * @param partitionId
119     *         The partition ID.
120     * @param partition
121     *         The partition.
122     * @param weight
123     *         The partition's weight, which should typically be over 200 and never negative.
124     * @return This consistent hash map.
125     */
126    public ConsistentHashMap<P> put(final String partitionId, final P partition, final int weight) {
127        Reject.ifNull(partitionId, "partitionId must be non-null");
128        Reject.ifNull(partition, "partition must be non-null");
129        Reject.ifTrue(weight < 0, "Weight must be a positive integer");
130
131        final Node<P> node = new Node<>(partitionId, partition, weight);
132        writeLock.lock();
133        try {
134            final TreeMap<Integer, Node<P>> newCircle = new TreeMap<>(circle);
135            for (int i = 0; i < weight; i++) {
136                newCircle.put(hashFunction.apply(partitionId + i), node);
137            }
138
139            final Map<String, P> newPartitions = new LinkedHashMap<>(partitions);
140            newPartitions.put(partitionId, partition);
141
142            // It doesn't matter that these assignments are not atomic.
143            circle = newCircle;
144            partitions = newPartitions;
145        } finally {
146            writeLock.unlock();
147        }
148        return this;
149    }
150
151    /**
152     * Removes the partition that was previously added using the provided partition ID.
153     *
154     * @param partitionId
155     *         The partition ID.
156     * @return This consistent hash map.
157     */
158    public ConsistentHashMap<P> remove(final String partitionId) {
159        Reject.ifNull(partitionId, "partitionId must be non-null");
160
161        writeLock.lock();
162        try {
163            if (partitions.containsKey(partitionId)) {
164                final TreeMap<Integer, Node<P>> newCircle = new TreeMap<>(circle);
165                final Node<P> node = newCircle.remove(hashFunction.apply(partitionId + 0));
166                for (int i = 1; i < node.weight; i++) {
167                    newCircle.remove(hashFunction.apply(partitionId + i));
168                }
169
170                final Map<String, P> newPartitions = new LinkedHashMap<>(partitions);
171                newPartitions.remove(partitionId);
172
173                // It doesn't matter that these assignments are not atomic.
174                circle = newCircle;
175                partitions = newPartitions;
176            }
177        } finally {
178            writeLock.unlock();
179        }
180        return this;
181    }
182
183    /**
184     * Returns the partition from this map corresponding to the provided key's hash, or {@code null} if this map is
185     * empty.
186     *
187     * @param key
188     *         The key for which a corresponding partition is to be returned.
189     * @return The partition from this map corresponding to the provided key's hash, or {@code null} if this map is
190     * empty.
191     */
192    P get(final Object key) {
193        final NavigableMap<Integer, Node<P>> circleSnapshot = circle;
194        final Map.Entry<Integer, Node<P>> ceilingEntry = circleSnapshot.ceilingEntry(hashFunction.apply(key));
195        if (ceilingEntry != null) {
196            return ceilingEntry.getValue().partition;
197        }
198        final Map.Entry<Integer, Node<P>> firstEntry = circleSnapshot.firstEntry();
199        return firstEntry != null ? firstEntry.getValue().partition : null;
200    }
201
202    /**
203     * Returns a collection containing all of the partitions contained in this consistent hash map.
204     *
205     * @return A collection containing all of the partitions contained in this consistent hash map.
206     */
207    Collection<P> getAll() {
208        return partitions.values();
209    }
210
211    /**
212     * Returns the number of partitions in this consistent hash map.
213     *
214     * @return The number of partitions in this consistent hash map.
215     */
216    int size() {
217        return partitions.size();
218    }
219
220    /**
221     * Returns {@code true} if there are no partitions in this consistent hash map.
222     *
223     * @return {@code true} if there are no partitions in this consistent hash map.
224     */
225    boolean isEmpty() {
226        return partitions.isEmpty();
227    }
228
229    /**
230     * Returns a map whose keys are the partitions stored in this map and whose values are the actual weights associated
231     * with each partition. The sum of the weights will be equal to 2^32.
232     * <p/>
233     * This method is intended for testing, but may one day be used in order to query the current status of the
234     * load-balancer and, in particular, the weighting associated with each partition as a percentage.
235     *
236     * @return A map whose keys are the partitions stored in this map and whose values are the actual weights associated
237     * with each partition.
238     */
239    @VisibleForTesting
240    Map<P, Long> getWeights() {
241        final NavigableMap<Integer, Node<P>> circleSnapshot = circle;
242        final IdentityHashMap<P, Long> weights = new IdentityHashMap<>();
243        Map.Entry<Integer, Node<P>> previousEntry = null;
244        for (final Map.Entry<Integer, Node<P>> entry : circleSnapshot.entrySet()) {
245            final long index = entry.getKey();
246            final P partition = entry.getValue().partition;
247            if (previousEntry == null) {
248                // Special case for first value since the range begins with the last entry.
249                final long range1 = (long) Integer.MAX_VALUE - circleSnapshot.lastEntry().getKey();
250                final long range2 = index - Integer.MIN_VALUE;
251                weights.put(partition, range1 + range2 + 1);
252            } else {
253                final long start = previousEntry.getKey();
254                final long end = entry.getKey();
255                if (weights.containsKey(partition)) {
256                    weights.put(partition, weights.get(partition) + (end - start));
257                } else {
258                    weights.put(partition, end - start);
259                }
260            }
261            previousEntry = entry;
262        }
263        return weights;
264    }
265
266    @Override
267    public String toString() {
268        return getWeights().toString();
269    }
270
271    /** A partition stored in the consistent hash map circle. */
272    private static final class Node<P> {
273        private final String partitionId;
274        private final P partition;
275        private final int weight;
276
277        private Node(final String partitionId, final P partition, final int weight) {
278            this.partitionId = partitionId;
279            this.partition = partition;
280            this.weight = weight;
281        }
282
283        @Override
284        public String toString() {
285            return partitionId;
286        }
287    }
288}