diff --git a/src/tensor.rs b/src/tensor.rs
index d0f9e68f..65dcbc74 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -40,10 +40,10 @@
 //! // allocate memory
 //! let native = Native::new();
 //! let device = native.new_device(native.hardwares()).unwrap();
-//! let shared_data = &mut SharedTensor::<i32>::new(&device, &5).unwrap();
+//! let shared_data = &mut SharedTensor::<i32>::new(&5).unwrap();
 //! // fill memory with some numbers
-//! let local_data = [0, 1, 2, 3, 4];
-//! let data = shared_data.get_mut(&device).unwrap().as_mut_native().unwrap();
+//! let mut mem = shared_data.write_only(&device).unwrap().as_mut_native().unwrap();
+//! mem.as_mut_slice::<i32>().clone_from_slice(&[0, 1, 2, 3, 4]);
 //! # }
 //! ```
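The doc change above captures the new two-step pattern: a `SharedTensor` is now constructed from a shape alone, and memory for a concrete device is only allocated on first access through `write_only`, which skips any synchronization because the contents will be overwritten anyway. Below is a minimal, self-contained sketch of that flow; the crate name and `use` paths are assumptions inferred from the module layout in this diff, not taken from it.

```rust
extern crate collenchyma;

use collenchyma::framework::IFramework;
use collenchyma::frameworks::Native;
use collenchyma::tensor::SharedTensor;

fn main() {
    // Pick the native (host CPU) framework and one of its devices.
    let native = Native::new();
    let device = native.new_device(native.hardwares()).unwrap();

    // The tensor is created from its shape alone; no device argument.
    let mut shared_data = SharedTensor::<i32>::new(&5).unwrap();

    // `write_only` hands back (possibly uninitialized) memory for
    // `device` without syncing from other devices.
    let mem = shared_data.write_only(&device).unwrap().as_mut_native().unwrap();
    mem.as_mut_slice::<i32>().clone_from_slice(&[0, 1, 2, 3, 4]);
}
```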
diff --git a/tests/framework_cuda_specs.rs b/tests/framework_cuda_specs.rs
index 0c853dcb..b52ea360 100644
--- a/tests/framework_cuda_specs.rs
+++ b/tests/framework_cuda_specs.rs
@@ -48,7 +48,8 @@ mod framework_cuda_spec {
         let cuda = Cuda::new();
         let device = cuda.new_device(&cuda.hardwares()[0..1]).unwrap();
         for _ in 0..256 {
-            let _ = &mut SharedTensor::<f32>::new(&device, &vec![256, 1024, 128]).unwrap();
+            let x = &mut SharedTensor::<f32>::new(&vec![256, 1024, 128]).unwrap();
+            x.write_only(&device).unwrap();
         }
     }

diff --git a/tests/shared_memory_specs.rs b/tests/shared_memory_specs.rs
index 56ada6a9..23175bf0 100644
--- a/tests/shared_memory_specs.rs
+++ b/tests/shared_memory_specs.rs
@@ -23,9 +23,9 @@ mod shared_memory_spec {
     fn it_creates_new_shared_memory_for_native() {
         let ntv = Native::new();
         let cpu = ntv.new_device(ntv.hardwares()).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&cpu, &10).unwrap();
-        match shared_data.get(&cpu).unwrap() {
-            &MemoryType::Native(ref dat) => {
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&cpu).unwrap() {
+            &mut MemoryType::Native(ref dat) => {
                 let data = dat.as_slice::<f32>();
                 assert_eq!(10, data.len());
             },
@@ -39,10 +39,11 @@
     fn it_creates_new_shared_memory_for_cuda() {
         let ntv = Cuda::new();
         let device = ntv.new_device(&ntv.hardwares()[0..1]).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&device, &10).unwrap();
-        match shared_data.get(&device) {
-            Some(&MemoryType::Cuda(_)) => assert!(true),
-            _ => assert!(false),
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&device) {
+            Ok(&mut MemoryType::Cuda(_)) => {},
+            #[cfg(any(feature = "cuda", feature = "opencl"))]
+            _ => assert!(false)
         }
     }

@@ -51,9 +52,9 @@
     fn it_creates_new_shared_memory_for_opencl() {
         let ntv = OpenCL::new();
         let device = ntv.new_device(&ntv.hardwares()[0..1]).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&device, &10).unwrap();
-        match shared_data.get(&device) {
-            Some(&MemoryType::OpenCL(_)) => assert!(true),
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&device) {
+            Ok(&mut MemoryType::OpenCL(_)) => {},
             _ => assert!(false),
         }
     }
@@ -65,20 +66,22 @@
         let nt = Native::new();
         let cu_device = cu.new_device(&cu.hardwares()[0..1]).unwrap();
         let nt_device = nt.new_device(nt.hardwares()).unwrap();
-        let mem = &mut SharedTensor::<f64>::new(&nt_device, &3).unwrap();
-        write_to_memory(mem.get_mut(&nt_device).unwrap(), &[1, 2, 3]);
-        mem.add_device(&cu_device).unwrap();
-        match mem.sync(&cu_device) {
+        let mem = &mut SharedTensor::<f64>::new(&3).unwrap();
+        write_to_memory(mem.write_only(&nt_device).unwrap(),
+                        &[1.0f64, 2.0, 123.456]);
+        match mem.read(&cu_device) {
             Ok(_) => assert!(true),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
             }
         }
-        // It has not successfully synced to the device.
+        // It has successfully synced to the device.
         // Not the other way around.
-        match mem.sync(&nt_device) {
-            Ok(_) => assert!(true),
+        mem.drop_device(&nt_device).unwrap();
+        match mem.read(&nt_device) {
+            Ok(m) => assert_eq!(m.as_native().unwrap().as_slice::<f64>(),
+                                [1.0, 2.0, 123.456]),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
@@ -93,10 +96,10 @@
         let nt = Native::new();
         let cl_device = cl.new_device(&cl.hardwares()[0..1]).unwrap();
         let nt_device = nt.new_device(nt.hardwares()).unwrap();
-        let mem = &mut SharedTensor::<f64>::new(&nt_device, &3).unwrap();
-        write_to_memory(mem.get_mut(&nt_device).unwrap(), &[1, 2, 3]);
-        mem.add_device(&cl_device).unwrap();
-        match mem.sync(&cl_device) {
+        let mem = &mut SharedTensor::<f64>::new(&3).unwrap();
+        write_to_memory(mem.write_only(&nt_device).unwrap(),
+                        &[1.0f64, 2.0, 123.456]);
+        match mem.read(&cl_device) {
             Ok(_) => assert!(true),
             Err(err) => {
                 println!("{:?}", err);
@@ -105,8 +108,10 @@
         }
         // It has not successfully synced to the device.
         // Not the other way around.
-        match mem.sync(&nt_device) {
-            Ok(_) => assert!(true),
+        mem.drop_device(&nt_device);
+        match mem.read(&nt_device) {
+            Ok(m) => assert_eq!(m.as_native().unwrap().as_slice::<f64>(),
+                                [1.0, 2.0, 123.456]),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
@@ -114,27 +119,15 @@
         }
     }

-    #[test]
-    fn it_has_correct_latest_device() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
-        assert_eq!(&cpu_dev, shared_data.latest_device());
-    }
-
     #[test]
     fn it_reshapes_correctly() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let mut shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
+        let mut shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
         assert!(shared_data.reshape(&vec![5, 2]).is_ok());
     }

     #[test]
     fn it_returns_err_for_invalid_size_reshape() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let mut shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
+        let mut shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
         assert!(shared_data.reshape(&vec![10, 2]).is_err());
     }
 }
diff --git a/tests/tensor_specs.rs b/tests/tensor_specs.rs
index 8590fc3d..61f7c666 100644
--- a/tests/tensor_specs.rs
+++ b/tests/tensor_specs.rs
@@ -31,8 +31,7 @@ mod tensor_spec {

     #[test]
     fn it_resizes_tensor() {
-        let native = Backend::<Native>::default().unwrap();
-        let mut tensor = SharedTensor::<f32>::new(native.device(), &(10, 20, 30)).unwrap();
+        let mut tensor = SharedTensor::<f32>::new(&(10, 20, 30)).unwrap();
         assert_eq!(tensor.desc(), &[10, 20, 30]);
         tensor.resize(&(2, 3, 4, 5)).unwrap();
         assert_eq!(tensor.desc(), &[2, 3, 4, 5]);
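The updated sync tests show how the old explicit `add_device`/`sync` calls collapse into `read` and `drop_device`: `read` lazily syncs the latest copy to the requested device, and `drop_device` deallocates one device's copy without losing the data. The sketch below strings the tested calls together end to end; as with the earlier sketch, the crate name and `use` paths are assumptions inferred from this diff, and running it requires the `cuda` feature and CUDA hardware.

```rust
extern crate collenchyma;

use collenchyma::framework::IFramework;
use collenchyma::frameworks::{Cuda, Native};
use collenchyma::tensor::SharedTensor;

fn main() {
    let cu = Cuda::new();
    let nt = Native::new();
    let cu_device = cu.new_device(&cu.hardwares()[0..1]).unwrap();
    let nt_device = nt.new_device(nt.hardwares()).unwrap();

    let mut mem = SharedTensor::<f64>::new(&3).unwrap();

    // Fill the tensor on the host.
    mem.write_only(&nt_device).unwrap()
        .as_mut_native().unwrap()
        .as_mut_slice::<f64>()
        .clone_from_slice(&[1.0, 2.0, 123.456]);

    // `read` transparently syncs the latest copy to the CUDA device.
    mem.read(&cu_device).unwrap();

    // Dropping the native copy leaves the CUDA copy as the only
    // up-to-date one; a later `read` on the native device syncs back.
    mem.drop_device(&nt_device).unwrap();
    let host = mem.read(&nt_device).unwrap().as_native().unwrap();
    assert_eq!(host.as_slice::<f64>(), [1.0, 2.0, 123.456]);
}
```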