Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify test code by using Results instead of unwraps #39

Merged
merged 2 commits into from
Sep 3, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 71 additions & 63 deletions cryptoki/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@ use serial_test::serial;
use std::sync::Arc;
use std::thread;

#[derive(Debug)]
struct ErrorWithStacktrace;

impl<T: std::error::Error> From<T> for ErrorWithStacktrace {
fn from(p: T) -> Self {
panic!("Error: {:#?}", p);
}
}

type Result<T> = std::result::Result<T, ErrorWithStacktrace>;

#[test]
#[serial]
fn sign_verify() {
fn sign_verify() -> Result<()> {
let (pkcs11, slot) = init_pins();

// set flags
let mut flags = Flags::new();
let _ = flags.set_rw_session(true).set_serial_session(true);

// open a session
let session = pkcs11.open_session_no_callback(slot, flags).unwrap();
let session = pkcs11.open_session_no_callback(slot, flags)?;

// log in the session
session.login(UserType::User).unwrap();
session.login(UserType::User)?;

// get mechanism
let mechanism = Mechanism::RsaPkcsKeyPairGen;
Expand All @@ -44,40 +55,39 @@ fn sign_verify() {
let priv_key_template = vec![Attribute::Token(true.into())];

// generate a key pair
let (public, private) = session
.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)
.unwrap();
let (public, private) =
session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?;

// data to sign
let data = [0xFF, 0x55, 0xDD];

// sign something with it
let signature = session.sign(&Mechanism::RsaPkcs, private, &data).unwrap();
let signature = session.sign(&Mechanism::RsaPkcs, private, &data)?;

// verify the signature
session
.verify(&Mechanism::RsaPkcs, public, &data, &signature)
.unwrap();
session.verify(&Mechanism::RsaPkcs, public, &data, &signature)?;

// delete keys
session.destroy_object(public).unwrap();
session.destroy_object(private).unwrap();
session.destroy_object(public)?;
session.destroy_object(private)?;

Ok(())
}

#[test]
#[serial]
fn encrypt_decrypt() {
fn encrypt_decrypt() -> Result<()> {
let (pkcs11, slot) = init_pins();

// set flags
let mut flags = Flags::new();
let _ = flags.set_rw_session(true).set_serial_session(true);

// open a session
let session = pkcs11.open_session_no_callback(slot, flags).unwrap();
let session = pkcs11.open_session_no_callback(slot, flags)?;

// log in the session
session.login(UserType::User).unwrap();
session.login(UserType::User)?;

// get mechanism
let mechanism = Mechanism::RsaPkcsKeyPairGen;
Expand All @@ -101,43 +111,42 @@ fn encrypt_decrypt() {
];

// generate a key pair
let (public, private) = session
.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)
.unwrap();
let (public, private) =
session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?;

// data to encrypt
let data = vec![0xFF, 0x55, 0xDD];

// encrypt something with it
let encrypted_data = session.encrypt(&Mechanism::RsaPkcs, public, &data).unwrap();
let encrypted_data = session.encrypt(&Mechanism::RsaPkcs, public, &data)?;

// decrypt
let decrypted_data = session
.decrypt(&Mechanism::RsaPkcs, private, &encrypted_data)
.unwrap();
let decrypted_data = session.decrypt(&Mechanism::RsaPkcs, private, &encrypted_data)?;

// The decrypted buffer is bigger than the original one.
assert_eq!(data, decrypted_data);

// delete keys
session.destroy_object(public).unwrap();
session.destroy_object(private).unwrap();
session.destroy_object(public)?;
session.destroy_object(private)?;

Ok(())
}

#[test]
#[serial]
fn derive_key() {
fn derive_key() -> Result<()> {
let (pkcs11, slot) = init_pins();

// set flags
let mut flags = Flags::new();
let _ = flags.set_rw_session(true).set_serial_session(true);

// open a session
let session = pkcs11.open_session_no_callback(slot, flags).unwrap();
let session = pkcs11.open_session_no_callback(slot, flags)?;

// log in the session
session.login(UserType::User).unwrap();
session.login(UserType::User)?;

// get mechanism
let mechanism = Mechanism::EccKeyPairGen;
Expand Down Expand Up @@ -165,13 +174,11 @@ fn derive_key() {
];

// generate a key pair
let (public, private) = session
.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)
.unwrap();
let (public, private) =
session.generate_key_pair(&mechanism, &pub_key_template, &priv_key_template)?;

let ec_point_attribute = session
.get_attributes(public, &[AttributeType::EcPoint])
.unwrap()
.get_attributes(public, &[AttributeType::EcPoint])?
.remove(0);

let ec_point = if let Attribute::EcPoint(point) = ec_point_attribute {
Expand All @@ -185,29 +192,26 @@ fn derive_key() {

let params = Ecdh1DeriveParams {
kdf: EcKdfType::NULL,
shared_data_len: 0_usize.try_into().unwrap(),
shared_data_len: 0_usize.try_into()?,
shared_data: std::ptr::null(),
public_data_len: (*ec_point).len().try_into().unwrap(),
public_data_len: (*ec_point).len().try_into()?,
public_data: ec_point.as_ptr() as *const std::ffi::c_void,
};

let shared_secret = session
.derive_key(
&Mechanism::Ecdh1Derive(params),
private,
&[
Attribute::Class(ObjectClass::SECRET_KEY),
Attribute::KeyType(KeyType::GENERIC_SECRET),
Attribute::Sensitive(false.into()),
Attribute::Extractable(true.into()),
Attribute::Token(false.into()),
],
)
.unwrap();
let shared_secret = session.derive_key(
&Mechanism::Ecdh1Derive(params),
private,
&[
Attribute::Class(ObjectClass::SECRET_KEY),
Attribute::KeyType(KeyType::GENERIC_SECRET),
Attribute::Sensitive(false.into()),
Attribute::Extractable(true.into()),
Attribute::Token(false.into()),
],
)?;

let value_attribute = session
.get_attributes(shared_secret, &[AttributeType::Value])
.unwrap()
.get_attributes(shared_secret, &[AttributeType::Value])?
.remove(0);
let value = if let Attribute::Value(value) = value_attribute {
value
Expand All @@ -218,24 +222,26 @@ fn derive_key() {
assert_eq!(value.len(), 32);

// delete keys
session.destroy_object(public).unwrap();
session.destroy_object(private).unwrap();
session.destroy_object(public)?;
session.destroy_object(private)?;

Ok(())
}

#[test]
#[serial]
fn import_export() {
fn import_export() -> Result<()> {
let (pkcs11, slot) = init_pins();

// set flags
let mut flags = Flags::new();
let _ = flags.set_rw_session(true).set_serial_session(true);

// open a session
let session = pkcs11.open_session_no_callback(slot, flags).unwrap();
let session = pkcs11.open_session_no_callback(slot, flags)?;

// log in the session
session.login(UserType::User).unwrap();
session.login(UserType::User)?;

let public_exponent: Vec<u8> = vec![0x01, 0x00, 0x01];
let modulus = vec![0xFF; 1024];
Expand All @@ -252,14 +258,13 @@ fn import_export() {

{
// Intentionally forget the object handle to find it later
let _public_key = session.create_object(&template).unwrap();
let _public_key = session.create_object(&template)?;
}

let is_it_the_public_key = session.find_objects(&template).unwrap().remove(0);
let is_it_the_public_key = session.find_objects(&template)?.remove(0);

let attribute_info = session
.get_attribute_info(is_it_the_public_key, &[AttributeType::Modulus])
.unwrap()
.get_attribute_info(is_it_the_public_key, &[AttributeType::Modulus])?
.remove(0);

if let AttributeInfo::Available(size) = attribute_info {
Expand All @@ -269,8 +274,7 @@ fn import_export() {
};

let attr = session
.get_attributes(is_it_the_public_key, &[AttributeType::Modulus])
.unwrap()
.get_attributes(is_it_the_public_key, &[AttributeType::Modulus])?
.remove(0);

if let Attribute::Modulus(modulus_cmp) = attr {
Expand All @@ -280,15 +284,19 @@ fn import_export() {
}

// delete key
session.destroy_object(is_it_the_public_key).unwrap();
session.destroy_object(is_it_the_public_key)?;

Ok(())
}

#[test]
#[serial]
fn get_token_info() {
fn get_token_info() -> Result<()> {
let (pkcs11, slot) = init_pins();
let info = pkcs11.get_token_info(slot).unwrap();
let info = pkcs11.get_token_info(slot)?;
assert_eq!("SoftHSM project", info.get_manufacturer_id());

Ok(())
}

#[test]
Expand Down