Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(hlapi): ciphertext list decompress after safe_deser #2019

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

tmontaigu
Copy link
Contributor

@tmontaigu tmontaigu commented Jan 30, 2025

After a safe_serialize/safe_deserialize, the CompressedCiphertextList was on Cpu. As the get method looked at the device of the data and not the device of the server_key to know where computation needs to happen, it meant that in this case decompressing using Gpu was impossible, only Cpu was usable (as data was always onlu on Cpu)

The fix is twofold:

  • First, when deserializing, the data will use the current serverkey (if any) as a hint on where data should be placed
  • the get method now uses the server_key to know where computations needs to be done, which may incur a temporary copy/transfer on every call to get if the device is not correct.

The API to move data has also been added

Note that this was not the case when using regular serialize/deserialize as this would store the device, so that deserialize was able to restore into the same device (hence why the test which use serialie/deserialize did not fail). In hindsight, the ser/de impl should not save which device the data originated from

Fixes zama-ai/tfhe-rs-internal/issues/905


This change is Reviewable

After a safe_serialize/safe_deserialize, the CompressedCiphertextList
was on Cpu. As the `get` method looked at the device of the data
and not the device of the server_key to know where computation
needs to happen, it meant that in this case decompressing using Gpu
was impossible, only Cpu was usable (as data was always onlu on Cpu)

The fix is twofold:
* First, when deserializing, the data will use the current serverkey
  (if any) as a hint on where data should be placed
* the `get` method now uses the server_key to know where computations
  needs to be done, which may incur a temporary copy/transfer on every
  call to `get` if the device is not correct.

The API to move data has also been added

Note that this was not the case when using regular serialize/deserialize
as this would store the device, so that deserialize was able to restore
into the same device (hence why the test which use serialie/deserialize
did not fail). In hindsight, the ser/de impl should not save which
device the data originated from
@cla-bot cla-bot bot added the cla-signed label Jan 30, 2025
Copy link
Member

@IceTDrinker IceTDrinker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments/questions otherwise looks very good IMO

Reviewed 5 of 5 files at r1, all commit messages.
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @tmontaigu)


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 224 at r1 (raw file):

}

impl<'de> serde::Deserialize<'de> for InnerCompressedCiphertextList {

could this be simplified somewhat by having the Fake enum and a Deserialize Into bound ?

Not against manual serialization code, just that it may need a bit less manual wokr here

something like https://serde.rs/container-attrs.html#from perhaps and have a serializable intermediate struct and the into would do the conversion


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 735 at r1 (raw file):

                let mut serialized = vec![];
                safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16)
                    .expect("safe serialize succeeds");

here the message would be "compressed list safe serialize failed" as it would appear in the stdout right ?

others are just unwrap though so maybe just an unwrap here could be enough


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 796 at r1 (raw file):

                let mut serialized = vec![];
                safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16)
                    .expect("safe serialize succeeds");

same remark


tfhe/src/core_crypto/gpu/mod.rs line 739 at r1 (raw file):

            d_vec
        };
        streams.synchronize();

is this enough to ensure copy went through ? I'm too rusty on cuda primitives currently to remember

Copy link
Contributor Author

@tmontaigu tmontaigu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @IceTDrinker)


tfhe/src/core_crypto/gpu/mod.rs line 739 at r1 (raw file):

Previously, IceTDrinker (Arthur Meyre) wrote…

is this enough to ensure copy went through ? I'm too rusty on cuda primitives currently to remember

It should, that is what we do for all functions that call someting async


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 224 at r1 (raw file):

Previously, IceTDrinker (Arthur Meyre) wrote…

could this be simplified somewhat by having the Fake enum and a Deserialize Into bound ?

Not against manual serialization code, just that it may need a bit less manual wokr here

something like https://serde.rs/container-attrs.html#from perhaps and have a serializable intermediate struct and the into would do the conversion

Using #[serde(from = "FromType")] would probably work, it would just make us move the definition of fake, and the new+move_to_device into a impl of From,

I can try it if you think that's better


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 735 at r1 (raw file):

Previously, IceTDrinker (Arthur Meyre) wrote…

here the message would be "compressed list safe serialize failed" as it would appear in the stdout right ?

others are just unwrap though so maybe just an unwrap here could be enough

Yes, will change this one and the other, it was the original message from the bug report

Copy link
Member

@IceTDrinker IceTDrinker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks looks good !

Reviewable status: all files reviewed, 2 unresolved discussions (waiting on @tmontaigu)


tfhe/src/high_level_api/compressed_ciphertext_list.rs line 224 at r1 (raw file):

Previously, tmontaigu (tmontaigu) wrote…

Using #[serde(from = "FromType")] would probably work, it would just make us move the definition of fake, and the new+move_to_device into a impl of From,

I can try it if you think that's better

Not saying is better just wondering, if it's not worth it no need to do it

@tmontaigu
Copy link
Contributor Author

tfhe/src/high_level_api/compressed_ciphertext_list.rs line 224 at r1 (raw file):

Previously, IceTDrinker (Arthur Meyre) wrote…

Not saying is better just wondering, if it's not worth it no need to do it

I would say that to me, since this fake enum is only used in one place, it's better to hide it in the custom impl of Deserialize

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants