Codebase list python-castellan / d8fb4f1 castellan / tests / unit / key_manager / test_mock_key_manager.py
d8fb4f1

Tree @d8fb4f1 (Download .tar.gz)

test_mock_key_manager.py @d8fb4f1raw · history · blame

# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""
Test cases for the mock key manager.
"""

from cryptography.hazmat import backends
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from oslo_context import context

from castellan.common import exception
from castellan.common.objects import symmetric_key as sym_key
from castellan.tests.unit.key_manager import mock_key_manager as mock_key_mgr
from castellan.tests.unit.key_manager import test_key_manager as test_key_mgr


def get_cryptography_private_key(private_key):
    """Turn a castellan private key into a ``cryptography`` key object.

    :param private_key: castellan key object whose ``get_encoded()`` value
        is DER-encoded private key material (loaded without a password).
    :returns: the object returned by
        ``serialization.load_der_private_key`` for that material.
    """
    der_bytes = bytes(private_key.get_encoded())
    return serialization.load_der_private_key(
        der_bytes, password=None, backend=backends.default_backend())


def get_cryptography_public_key(public_key):
    """Turn a castellan public key into a ``cryptography`` key object.

    :param public_key: castellan key object whose ``get_encoded()`` value
        is DER-encoded public key material.
    :returns: the object returned by
        ``serialization.load_der_public_key`` for that material.
    """
    der_bytes = bytes(public_key.get_encoded())
    return serialization.load_der_public_key(
        der_bytes, backend=backends.default_backend())


class MockKeyManagerTestCase(test_key_mgr.KeyManagerTestCase):
    """Tests for ``MockKeyManager``.

    Inherits the shared tests from ``test_key_mgr.KeyManagerTestCase`` and
    adds checks for key and key-pair creation, store/get (including
    metadata-only retrieval), delete and list, plus the ``Forbidden`` error
    raised when the request context is ``None``.
    """

    def _create_key_manager(self):
        """Return the key manager under test: a new ``MockKeyManager``."""
        return mock_key_mgr.MockKeyManager()

    def setUp(self):
        super(MockKeyManagerTestCase, self).setUp()

        self.context = context.RequestContext('fake', 'fake')
        # Neither unittest nor testtools invokes a TestCase method named
        # ``cleanUp``, so a reset placed in such an override never runs.
        # Registering it with addCleanup makes it execute after each test.
        self.addCleanup(self._reset_keys)

    def _reset_keys(self):
        """Discard every key held by the mock key manager."""
        self.key_mgr.keys = {}

    def test_create_key(self):
        key_id_1 = self.key_mgr.create_key(self.context)
        key_id_2 = self.key_mgr.create_key(self.context)
        # ensure that the UUIDs are unique
        self.assertNotEqual(key_id_1, key_id_2)

    def test_create_key_with_length(self):
        for length in [64, 128, 256]:
            key_id = self.key_mgr.create_key(self.context, length=length)
            key = self.key_mgr.get(self.context, key_id)
            # ``length`` is in bits; the encoded key is measured in bytes.
            self.assertEqual(length // 8, len(key.get_encoded()))
            self.assertIsNotNone(key.id)

    def test_create_key_with_name(self):
        name = 'my key'
        key_id = self.key_mgr.create_key(self.context, name=name)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(name, key.name)
        self.assertIsNotNone(key.id)

    def test_create_key_with_algorithm(self):
        algorithm = 'DES'
        key_id = self.key_mgr.create_key(self.context, algorithm=algorithm)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(algorithm, key.algorithm)
        self.assertIsNotNone(key.id)

    def test_create_key_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key, None)

    def test_create_key_pair(self):
        for length in [2048, 3072, 4096]:
            name = str(length) + ' key'
            private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
                self.context, 'RSA', length, name=name)

            private_key = self.key_mgr.get(self.context, private_key_uuid)
            self.assertIsNotNone(private_key.id)
            public_key = self.key_mgr.get(self.context, public_key_uuid)
            self.assertIsNotNone(public_key.id)

            crypto_private_key = get_cryptography_private_key(private_key)
            crypto_public_key = get_cryptography_public_key(public_key)

            self.assertEqual(name, private_key.name)
            self.assertEqual(name, public_key.name)

            self.assertEqual(length, crypto_private_key.key_size)
            self.assertEqual(length, crypto_public_key.key_size)

    def test_create_key_pair_encryption(self):
        private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
            self.context, 'RSA', 2048)

        private_key = self.key_mgr.get(self.context, private_key_uuid)
        public_key = self.key_mgr.get(self.context, public_key_uuid)

        crypto_private_key = get_cryptography_private_key(private_key)
        crypto_public_key = get_cryptography_public_key(public_key)

        # Encryption and decryption must use identical OAEP parameters for
        # the round trip to succeed, so build the padding once.
        oaep = padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA1()),
            algorithm=hashes.SHA1(),
            label=None)

        message = b'secret plaintext'
        ciphertext = crypto_public_key.encrypt(message, oaep)
        plaintext = crypto_private_key.decrypt(ciphertext, oaep)

        self.assertEqual(message, plaintext)

    def test_create_key_pair_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key_pair, None, 'RSA', 2048)

    def test_create_key_pair_invalid_algorithm(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'DSA', 2048)

    def test_create_key_pair_invalid_length(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'RSA', 10)

    def test_store_and_get_key(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(_key, actual_key)

        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())

        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata_and_get_key(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())

        # A metadata-only read must not prevent a later full read.
        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=False)
        self.assertIsNotNone(actual_key.get_encoded())
        self.assertFalse(actual_key.is_metadata_only())

        self.assertIsNotNone(actual_key.id)

    def test_store_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.store, None, None)

    def test_get_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.get, None, None)

    def test_get_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.get, self.context, None)

    def test_delete_key(self):
        key_id = self.key_mgr.create_key(self.context)
        self.key_mgr.delete(self.context, key_id)

        self.assertRaises(KeyError, self.key_mgr.get, self.context,
                          key_id)

    def test_delete_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.delete, None, None)

    def test_delete_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.delete, self.context,
                          None)

    def test_list_null_context(self):
        self.assertRaises(exception.Forbidden, self.key_mgr.list, None)

    def test_list_keys(self):
        key1 = sym_key.SymmetricKey('AES', 64 * 8, b'0' * 64)
        self.key_mgr.store(self.context, key1)
        key2 = sym_key.SymmetricKey('AES', 32 * 8, b'0' * 32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context)
        self.assertEqual(2, len(keys))
        self.assertIn(key1, keys)
        self.assertIn(key2, keys)

        for key in keys:
            self.assertIsNotNone(key.id)

    def test_list_keys_metadata_only(self):
        key1 = sym_key.SymmetricKey('AES', 64 * 8, b'0' * 64)
        self.key_mgr.store(self.context, key1)
        key2 = sym_key.SymmetricKey('AES', 32 * 8, b'0' * 32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context, metadata_only=True)
        self.assertEqual(2, len(keys))
        for key in keys:
            self.assertTrue(key.is_metadata_only())
            self.assertIsNotNone(key.id)

        # Compare as multisets: each stored key's bit length must appear
        # exactly once, not merely be a member of the expected list.
        self.assertEqual(sorted([key1.bit_length, key2.bit_length]),
                         sorted(key.bit_length for key in keys))