Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve from_fn and add com_ptr_from_fn macro #23

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ winapi = { version = "0.3", features = [
"mmdeviceapi",
"processthreadsapi",
"setupapi",
"shobjidl_core",
"std",
"synchapi",
"unknwnbase",
Expand Down
75 changes: 41 additions & 34 deletions src/bstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,55 @@
use crate::wide::{FromWide, ToWide};
use std::{
alloc::{handle_alloc_error, Layout},
ptr::{self, NonNull},
convert::TryInto,
ffi::{OsStr, OsString},
path::PathBuf,
slice::from_raw_parts,
};
use winapi::{
shared::wtypes::BSTR,
um::oleauto::{
SysAllocStringByteLen, SysAllocStringLen, SysFreeString, SysStringByteLen, SysStringLen,
shared::{
wtypes::BSTR,
winerror::HRESULT,
},
um::{
oleauto::{SysAllocStringByteLen, SysAllocStringLen, SysFreeString, SysStringByteLen, SysStringLen},
winnt::WCHAR,
}
};
#[derive(Debug)]
pub struct BStr(BSTR);
pub struct BStr(NonNull<WCHAR>);
impl BStr {
pub unsafe fn new(s: BSTR) -> Option<BStr> {
NonNull::new(s).map(BStr)
}
pub unsafe fn from_raw(s: BSTR) -> BStr {
BStr(s)
BStr(NonNull::new(s).expect("ptr should not be null"))
}
pub unsafe fn from_fn<F>(fun: F) -> Result<BStr, HRESULT>
where
F: FnOnce(&mut BSTR) -> HRESULT
{
let mut ptr: BSTR = ptr::null_mut();
let res = fun(&mut ptr);
let bstr = BStr::new(ptr);
match res {
0 => Ok(bstr.expect("fun must set bstr to a value")),
res => {
if bstr.is_some() {
log_if_feature!("BStr::from_fn had an initialized BSTR pointer despite the function returning an error");
}
Err(res)
}
}
}
pub fn from_wide(s: &[u16]) -> BStr {
unsafe {
let ptr = SysAllocStringLen(s.as_ptr(), s.len().try_into().unwrap());
if ptr.is_null() {
handle_alloc_error(Layout::array::<u16>(s.len()).unwrap())
}
BStr(ptr)
BStr(NonNull::new_unchecked(ptr))
}
}
pub fn from_bytes(s: &[u8]) -> BStr {
Expand All @@ -38,48 +63,30 @@ impl BStr {
if ptr.is_null() {
handle_alloc_error(Layout::array::<u8>(s.len()).unwrap())
}
BStr(ptr)
BStr(NonNull::new_unchecked(ptr))
}
}
pub fn len(&self) -> usize {
unsafe { SysStringLen(self.0) as usize }
unsafe { SysStringLen(self.0.as_ptr()) as usize }
}
pub fn byte_len(&self) -> usize {
unsafe { SysStringByteLen(self.0) as usize }
}
pub fn is_null(&self) -> bool {
self.0.is_null()
unsafe { SysStringByteLen(self.0.as_ptr()) as usize }
}
pub fn as_ptr(&self) -> BSTR {
self.0
self.0.as_ptr()
}
pub fn as_wide(&self) -> &[u16] {
if self.0.is_null() {
&[]
} else {
unsafe { from_raw_parts(self.0, self.len()) }
}
unsafe { from_raw_parts(self.0.as_ptr(), self.len()) }
}
pub fn as_wide_null(&self) -> &[u16] {
if self.0.is_null() {
&[0]
} else {
unsafe { from_raw_parts(self.0, self.len() + 1) }
}
unsafe { from_raw_parts(self.0.as_ptr(), self.len() + 1) }
}
pub fn as_bytes(&self) -> &[u8] {
if self.0.is_null() {
&[]
} else {
unsafe { from_raw_parts(self.0.cast(), self.byte_len()) }
}
unsafe { from_raw_parts(self.0.as_ptr().cast(), self.byte_len()) }
}
pub fn as_bytes_null(&self) -> &[u8] {
if self.0.is_null() {
&[0]
} else {
unsafe { from_raw_parts(self.0.cast(), self.byte_len() + 1) }
}
// TODO: BECAUSE CHARS ARE WCHARS, SHOULD THIS BE +2 INSTEAD OF +1?
unsafe { from_raw_parts(self.0.as_ptr().cast(), self.byte_len() + 1) }
}
pub fn to_string(&self) -> Option<String> {
let os: OsString = self.into();
Expand All @@ -98,7 +105,7 @@ impl Clone for BStr {
}
impl Drop for BStr {
fn drop(&mut self) {
unsafe { SysFreeString(self.0) };
unsafe { SysFreeString(self.0.as_ptr()) };
}
}
impl<T> From<T> for BStr
Expand Down
91 changes: 73 additions & 18 deletions src/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,64 @@ use std::mem::forget;
use std::ops::Deref;
use std::ptr::{null_mut, NonNull};
use winapi::um::unknwnbase::IUnknown;
use winapi::shared::guiddef::GUID;
use winapi::shared::winerror::HRESULT;
use winapi::Interface;

/// Simplifies the common pattern of calling a function to initialize multiple `ComPtr`s.
///
/// This macro is a generalization of [`ComPtr::from_fn`][from_fn] to functions
/// that output multiple COM objects. It returns `Result<($(ComPtr<_>,)+), HRESULT>`, where the
/// `Ok` tuple contains the same number of `ComPtr` objects as the number of GUID/pointer pairs
/// passed to this macro.
///
/// See [`ComPtr::from_fn`][from_fn] for details on the exact semantics.
///
/// [from_fn]: crate::com::ComPtr::from_fn
#[macro_export]
macro_rules! com_ptr_from_fn {
(
|$(($guid:pat, $ptr:ident)),+ $(,)?| $init:expr
) => {{
use winapi::Interface;
use winapi::shared::guiddef::GUID;

// hack to get the GUID through type inference
struct GuidHack<T: Interface>(*mut T);
impl<T: Interface> GuidHack<T> {
fn guid(&self) -> GUID {
T::uuidof()
}
}

$(
let mut $ptr = GuidHack(std::ptr::null_mut());
)+

let result: winapi::shared::winerror::HRESULT = {
$(
let $guid = &$ptr.guid();
let $ptr = &mut *(&mut ($ptr.0) as *mut *mut _ as *mut *mut _);
)+
(|| $init)()
};
$(
let $ptr = $crate::com::ComPtr::new($ptr.0);
)+
match result {
0 => Ok(($(
$ptr.expect(concat!("`", stringify!($expr), "` must set `", stringify!($ptr), "` to a value")),
)+)),
res => {
$(if $ptr.is_some() {
$crate::log_if_feature!("ComPtr::from_fn had an initialized COM pointer despite the function returning an error");
})+
Err(res)
}
}
}};
}

// ComPtr to wrap COM interfaces sanely
#[repr(transparent)]
pub struct ComPtr<T>(NonNull<T>);
Expand All @@ -33,28 +89,27 @@ impl<T> ComPtr<T> {
{
ComPtr(NonNull::new(ptr).expect("ptr should not be null"))
}
/// Simplifies the common pattern of calling a function to initialize a ComPtr.
/// Simplifies the common pattern of calling a function to initialize a `ComPtr`.
///
/// `fun` gets passed `T`'s `GUID` and a mutable reference to a null pointer. If `fun` returns
/// `S_OK`, it _must_ initialize the pointer to a non-null value.
///
/// If `fun` *doesn't* return `S_OK` but still initializes the pointer, this function will
/// assume that the pointer was initialized to a valid COM object and will call `Release` on
/// it. If the `log` feature is enabled, it will emit a warning when that happens.
///
/// May leak the COM pointer if the function panics after initializing the pointer.
/// The pointer provided to the function starts as a null pointer.
/// If the pointer is initialized to a non-null value, it will be interpreted as a valid COM
/// pointer, even if the function returns an error in which case it will be released by
/// `from_fn` and a warning logged if logging is enabled.
pub unsafe fn from_fn<F, E>(fun: F) -> Result<Option<ComPtr<T>>, E>
///
/// If you're calling a COM function that generates multiple COM objects, use the
/// [`com_ptr_from_fn!`](../macro.com_ptr_from_fn.html) macro.
pub unsafe fn from_fn<F, P>(fun: F) -> Result<ComPtr<T>, HRESULT>
where
T: Interface,
F: FnOnce(&mut *mut T) -> Result<(), E>,
F: FnOnce(&GUID, &mut *mut P) -> HRESULT
{
let mut ptr = null_mut();
let res = fun(&mut ptr);
let com = ComPtr::new(ptr);
match res {
Ok(()) => Ok(com),
Err(err) => {
if com.is_some() {
#[cfg(feature = "log")] log::warn!("ComPtr::from_fn had an initialized COM pointer despite the function returning an error")
}
Err(err)
}
match com_ptr_from_fn!(|(guid, ptr)| fun(guid, ptr)) {
Ok((p,)) => Ok(p),
Err(e) => Err(e),
}
}
/// Casts up the inheritance chain
Expand Down
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@
#![allow(clippy::missing_safety_doc, clippy::len_without_is_empty)]
extern crate winapi;

#[doc(hidden)]
#[macro_export]
#[cfg(feature = "log")]
macro_rules! log_if_feature {
($($args:tt)*) => {log::warn!($($args)*)};
}

#[doc(hidden)]
#[macro_export]
#[cfg(not(feature = "log"))]
macro_rules! log_if_feature {
($($args:tt)*) => {};
}

// pub mod apc;
pub mod bstr;
pub mod com;
Expand Down
45 changes: 45 additions & 0 deletions tests/com_ptr_from_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
// All files in the project carrying such notice may not be copied, modified, or distributed
// except according to those terms.

extern crate wio;
extern crate winapi;

use wio::{com_ptr_from_fn, com::ComPtr};
use winapi::{
Class,
um::{
combaseapi,
shobjidl_core::{ShellItem, TaskbarList, IShellItem, ITaskbarList},
},
};
use std::ptr::null_mut;

#[test]
fn test_multi_com_ptr() {
unsafe {
let _: Result<(ComPtr<IShellItem>, ComPtr<ITaskbarList>), _> = com_ptr_from_fn!(
|(shell_guid, shell_ptr), (taskbar_guid, taskbar_ptr)| {
let hr = combaseapi::CoCreateInstance(
&ShellItem::uuidof(),
null_mut(),
combaseapi::CLSCTX_ALL,
shell_guid,
shell_ptr,
);
if hr != 0 {
return hr;
}
combaseapi::CoCreateInstance(
&TaskbarList::uuidof(),
null_mut(),
combaseapi::CLSCTX_ALL,
taskbar_guid,
taskbar_ptr,
)
}
);
}
}