Commit

refactor/tests: fix tests after breaking change autumnai#37 [SKIP_CHANGELOG]

alexandermorozov committed Apr 18, 2016
1 parent 1fbc433 commit b87ac1c
Showing 4 changed files with 36 additions and 43 deletions.
6 changes: 3 additions & 3 deletions src/tensor.rs
@@ -40,10 +40,10 @@
 //! // allocate memory
 //! let native = Native::new();
 //! let device = native.new_device(native.hardwares()).unwrap();
-//! let shared_data = &mut SharedTensor::<i32>::new(&device, &5).unwrap();
+//! let shared_data = &mut SharedTensor::<i32>::new(&5).unwrap();
 //! // fill memory with some numbers
-//! let local_data = [0, 1, 2, 3, 4];
-//! let data = shared_data.get_mut(&device).unwrap().as_mut_native().unwrap();
+//! let mut mem = shared_data.write_only(&device).unwrap().as_mut_native().unwrap();
+//! mem.as_mut_slice::<i32>().clone_from_slice(&[0, 1, 2, 3, 4]);
 //! # }
 //! ```
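
Read as a whole, the migrated doc example shows the new two-step pattern: construct by shape only, then bind memory to a device at first access. A minimal sketch of the full example, assuming collenchyma's prelude re-exports and the native backend (anything outside the diff lines above is an assumption):

```rust
extern crate collenchyma as co;
use co::prelude::*;

fn main() {
    // Allocate: the constructor now takes only a shape descriptor;
    // no device memory exists until some device first accesses it.
    let native = Native::new();
    let device = native.new_device(native.hardwares()).unwrap();
    let shared_data = &mut SharedTensor::<i32>::new(&5).unwrap();
    // Fill: `write_only` returns writable memory for `device` without
    // syncing in any previous contents (there are none yet anyway).
    let mem = shared_data.write_only(&device).unwrap().as_mut_native().unwrap();
    mem.as_mut_slice::<i32>().clone_from_slice(&[0, 1, 2, 3, 4]);
}
```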

3 changes: 2 additions & 1 deletion tests/framework_cuda_specs.rs
@@ -48,7 +48,8 @@ mod framework_cuda_spec {
         let cuda = Cuda::new();
         let device = cuda.new_device(&cuda.hardwares()[0..1]).unwrap();
         for _ in 0..256 {
-            let _ = &mut SharedTensor::<f32>::new(&device, &vec![256, 1024, 128]).unwrap();
+            let x = &mut SharedTensor::<f32>::new(&vec![256, 1024, 128]).unwrap();
+            x.write_only(&device).unwrap();
         }
     }
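
The added `write_only` call is what keeps this allocation-stress test meaningful: with the device argument gone from the constructor, creating a `SharedTensor` no longer allocates device memory by itself. A sketch of the loop's behavior under that reading (the byte math is ours, not the commit's):

```rust
// Assumes a CUDA-enabled build, as in the test above.
let cuda = Cuda::new();
let device = cuda.new_device(&cuda.hardwares()[0..1]).unwrap();
for _ in 0..256 {
    // Cheap: stores the shape, touches no device memory yet.
    let mut x = SharedTensor::<f32>::new(&vec![256, 1024, 128]).unwrap();
    // This allocates the buffer on the GPU: 256 * 1024 * 128 f32s = 128 MiB.
    x.write_only(&device).unwrap();
    // Dropping `x` at the end of the iteration frees that buffer again.
}
```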

67 changes: 30 additions & 37 deletions tests/shared_memory_specs.rs
@@ -23,9 +23,9 @@ mod shared_memory_spec {
     fn it_creates_new_shared_memory_for_native() {
         let ntv = Native::new();
         let cpu = ntv.new_device(ntv.hardwares()).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&cpu, &10).unwrap();
-        match shared_data.get(&cpu).unwrap() {
-            &MemoryType::Native(ref dat) => {
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&cpu).unwrap() {
+            &mut MemoryType::Native(ref dat) => {
                 let data = dat.as_slice::<f32>();
                 assert_eq!(10, data.len());
             },
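
Note the arm pattern: `write_only` yields `&mut MemoryType` after the `unwrap`, where the old `get` yielded `&MemoryType`, hence `&mut MemoryType::Native(...)`. The same access reads more directly through the `as_mut_native` cast used elsewhere in this commit; a tiny sketch, same assumptions as above:

```rust
// `cpu` is the native device from the test above.
let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
// Skip the match: cast the freshly writable memory to native directly.
let dat = shared_data.write_only(&cpu).unwrap().as_mut_native().unwrap();
assert_eq!(10, dat.as_slice::<f32>().len());
```
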
@@ -39,10 +39,11 @@
     fn it_creates_new_shared_memory_for_cuda() {
         let ntv = Cuda::new();
         let device = ntv.new_device(&ntv.hardwares()[0..1]).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&device, &10).unwrap();
-        match shared_data.get(&device) {
-            Some(&MemoryType::Cuda(_)) => assert!(true),
-            _ => assert!(false),
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&device) {
+            Ok(&mut MemoryType::Cuda(_)) => {},
+            #[cfg(any(feature = "cuda", feature = "opencl"))]
+            _ => assert!(false)
         }
     }

@@ -51,9 +51,9 @@
     fn it_creates_new_shared_memory_for_opencl() {
         let ntv = OpenCL::new();
         let device = ntv.new_device(&ntv.hardwares()[0..1]).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&device, &10).unwrap();
-        match shared_data.get(&device) {
-            Some(&MemoryType::OpenCL(_)) => assert!(true),
+        let shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
+        match shared_data.write_only(&device) {
+            Ok(&mut MemoryType::OpenCL(_)) => {},
             _ => assert!(false),
         }
     }
@@ -65,20 +66,22 @@
         let nt = Native::new();
         let cu_device = cu.new_device(&cu.hardwares()[0..1]).unwrap();
         let nt_device = nt.new_device(nt.hardwares()).unwrap();
-        let mem = &mut SharedTensor::<f64>::new(&nt_device, &3).unwrap();
-        write_to_memory(mem.get_mut(&nt_device).unwrap(), &[1, 2, 3]);
-        mem.add_device(&cu_device).unwrap();
-        match mem.sync(&cu_device) {
+        let mem = &mut SharedTensor::<f64>::new(&3).unwrap();
+        write_to_memory(mem.write_only(&nt_device).unwrap(),
+                        &[1.0f64, 2.0, 123.456]);
+        match mem.read(&cu_device) {
             Ok(_) => assert!(true),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
             }
         }
-        // It has not successfully synced to the device.
+        // It has successfully synced to the device.
         // Not the other way around.
-        match mem.sync(&nt_device) {
-            Ok(_) => assert!(true),
+        mem.drop_device(&nt_device).unwrap();
+        match mem.read(&nt_device) {
+            Ok(m) => assert_eq!(m.as_native().unwrap().as_slice::<f64>(),
+                                [1.0, 2.0, 123.456]),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
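
This hunk carries the core of the breaking change: the explicit `add_device`/`sync` bookkeeping is gone, replaced by access-driven transfers. `write_only` leaves the native copy as the only valid one, `read(&cu_device)` copies to CUDA on demand, and once `drop_device(&nt_device)` discards the native copy, the final native `read` must restore the data from CUDA, which the `assert_eq!` verifies. A condensed sketch of that lifecycle (`write_to_memory` is this file's test helper; devices as in the test):

```rust
let mem = &mut SharedTensor::<f64>::new(&3).unwrap();
// After this write the data is valid on the native device only.
write_to_memory(mem.write_only(&nt_device).unwrap(), &[1.0f64, 2.0, 123.456]);
// First CUDA read triggers a native -> CUDA transfer.
mem.read(&cu_device).unwrap();
// Discard the native copy; the CUDA copy remains valid.
mem.drop_device(&nt_device).unwrap();
// Reading natively again forces the transfer back, CUDA -> native.
let m = mem.read(&nt_device).unwrap();
assert_eq!(m.as_native().unwrap().as_slice::<f64>(), [1.0, 2.0, 123.456]);
```
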
@@ -93,10 +96,10 @@
         let nt = Native::new();
         let cl_device = cl.new_device(&cl.hardwares()[0..1]).unwrap();
         let nt_device = nt.new_device(nt.hardwares()).unwrap();
-        let mem = &mut SharedTensor::<f64>::new(&nt_device, &3).unwrap();
-        write_to_memory(mem.get_mut(&nt_device).unwrap(), &[1, 2, 3]);
-        mem.add_device(&cl_device).unwrap();
-        match mem.sync(&cl_device) {
+        let mem = &mut SharedTensor::<f64>::new(&3).unwrap();
+        write_to_memory(mem.write_only(&nt_device).unwrap(),
+                        &[1.0f64, 2.0, 123.456]);
+        match mem.read(&cl_device) {
             Ok(_) => assert!(true),
             Err(err) => {
                 println!("{:?}", err);
@@ -105,36 +108,26 @@
         }
     // It has not successfully synced to the device.
     // Not the other way around.
-        match mem.sync(&nt_device) {
-            Ok(_) => assert!(true),
+        mem.drop_device(&nt_device);
+        match mem.read(&nt_device) {
+            Ok(m) => assert_eq!(m.as_native().unwrap().as_slice::<f64>(),
+                                [1.0, 2.0, 123.456]),
             Err(err) => {
                 println!("{:?}", err);
                 assert!(false);
             }
         }
     }

-    #[test]
-    fn it_has_correct_latest_device() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
-        assert_eq!(&cpu_dev, shared_data.latest_device());
-    }
-
     #[test]
     fn it_reshapes_correctly() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let mut shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
+        let mut shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
         assert!(shared_data.reshape(&vec![5, 2]).is_ok());
     }

     #[test]
     fn it_returns_err_for_invalid_size_reshape() {
-        let ntv = Native::new();
-        let cpu_dev = ntv.new_device(ntv.hardwares()).unwrap();
-        let mut shared_data = &mut SharedTensor::<f32>::new(&cpu_dev, &10).unwrap();
+        let mut shared_data = &mut SharedTensor::<f32>::new(&10).unwrap();
         assert!(shared_data.reshape(&vec![10, 2]).is_err());
     }
 }
3 changes: 1 addition & 2 deletions tests/tensor_specs.rs
@@ -31,8 +31,7 @@ mod tensor_spec {

     #[test]
     fn it_resizes_tensor() {
-        let native = Backend::<Native>::default().unwrap();
-        let mut tensor = SharedTensor::<f32>::new(native.device(), &(10, 20, 30)).unwrap();
+        let mut tensor = SharedTensor::<f32>::new(&(10, 20, 30)).unwrap();
         assert_eq!(tensor.desc(), &[10, 20, 30]);
         tensor.resize(&(2, 3, 4, 5)).unwrap();
         assert_eq!(tensor.desc(), &[2, 3, 4, 5]);
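
Seen next to the reshape tests above, the two operations split cleanly: `reshape` succeeds only when the element count is unchanged (10 into [5, 2] passes, 10 into [10, 2] errors), while `resize`, as this test shows, may change the total count (6,000 elements down to 120). A short sketch contrasting them, under the same API assumptions; what `resize` does to existing device copies is not shown in this diff:

```rust
let mut t = SharedTensor::<f32>::new(&(10, 20, 30)).unwrap();
// 10 * 20 * 30 = 6000 elements either way: reshape is fine.
t.reshape(&vec![100, 60]).unwrap();
// 7 * 7 = 49 != 6000 elements: reshape refuses.
assert!(t.reshape(&vec![7, 7]).is_err());
// resize may change the element count: 2 * 3 * 4 * 5 = 120.
t.resize(&(2, 3, 4, 5)).unwrap();
assert_eq!(t.desc(), &[2, 3, 4, 5]);
```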
