From 0ef80c26b1e4bef261f6cb25ba330394321f4d43 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:41:43 -0600 Subject: [PATCH 1/2] Add Enum module support in PyTorchFileRecorder Fixes #1431 --- burn-book/src/import/pytorch-model.md | 52 +++++ crates/burn-core/src/record/serde/de.rs | 146 ++++++++++++- .../tests/enum_module/enum_depthwise_false.pt | Bin 0 -> 1766 bytes .../tests/enum_module/enum_depthwise_true.pt | Bin 0 -> 2288 bytes .../tests/enum_module/export_weights.py | 60 ++++++ .../pytorch-tests/tests/enum_module/mod.rs | 194 ++++++++++++++++++ crates/burn-import/pytorch-tests/tests/mod.rs | 1 + 7 files changed, 445 insertions(+), 8 deletions(-) create mode 100644 crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_false.pt create mode 100644 crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_true.pt create mode 100755 crates/burn-import/pytorch-tests/tests/enum_module/export_weights.py create mode 100644 crates/burn-import/pytorch-tests/tests/enum_module/mod.rs diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 024a147c08..9bcc1f35a2 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -347,6 +347,58 @@ let record = PyTorchFileRecorder::::default() .expect("Should decode state successfully") ``` +### Models containing enum modules + +Burn supports models containing enum modules with new-type variants (tuple with one item). Importing +weights for such models is supported by using the `PyTorchFileRecorder` automatically. However, it +should be noted that since the source weights file does not contain the enum variant information, +the enum variant is picked based on the enum variant type. 
Let's consider the following example: + +```rust +#[derive(Module, Debug)] +pub enum Conv { + DwsConv(DwsConv), + Conv(Conv2d), +} + +#[derive(Module, Debug)] +pub struct DwsConv { + dconv: Conv2d, + pconv: Conv2d, +} + +#[derive(Module, Debug)] +pub struct Net { + conv: Conv, +} +``` + +If the source weights file contains weights for `DwsConv`, such as the following keys: + +```text +--- +Key: conv.dconv.bias +Shape: [2] +Dtype: F32 +--- +Key: conv.dconv.weight +Shape: [2, 1, 3, 3] +Dtype: F32 +--- +Key: conv.pconv.bias +Shape: [2] +Dtype: F32 +--- +Key: conv.pconv.weight +Shape: [2, 2, 1, 1] +Dtype: F32 +``` + +The weights will be imported into the `DwsConv` variant of the `Conv` enum module. + +If the variant types are identica, then the first variant is picked. Generally, it won't be a +problem since the variant types are usually different. + ## Current known issues 1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179). diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 5407dcdedf..b429e1cd2a 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -1,8 +1,10 @@ +use core::ptr; use std::collections::HashMap; use super::data::NestedValue; use super::{adapter::BurnModuleAdapter, error::Error}; +use serde::de::{EnumAccess, VariantAccess}; use serde::{ de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}, forward_to_deserialize_any, @@ -313,16 +315,65 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { unimplemented!("deserialize_tuple_struct is not implemented") } + /// Deserializes an enum by attempting to match its variants against the provided data. + /// + /// This function attempts to deserialize an enum by iterating over its possible variants + /// and trying to deserialize the data into each until one succeeds. 
We need to do this
+    /// because we don't have a way to know which variant to deserialize from the data.
+    ///
+    /// This is similar to Serde's
+    /// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged),
+    /// but it's on the deserializer side. Using `#[serde(untagged)]` on the enum will force
+    /// using `deserialize_any`, which is not what we want because we want to use methods, such
+    /// as `visit_struct`. Also, we do not wish to use auto-generated code for Deserialize just
+    /// for enums because it will affect other serialization and deserialization, such
+    /// as JSON and Bincode.
+    ///
+    /// # Safety
+    /// The `visitor` is duplicated with a bitwise `ptr::read` because `Visitor` has no `Clone`
+    /// bound and every probe consumes one. This relies on the visitor being a code-generated
+    /// unit struct with no state; only `visit_enum` is ever called on the copies.
+    ///
+    /// # Errors
+    /// Fails when the deserializer holds no value or when no variant accepts the data.
     fn deserialize_enum<V>(
         self,
         _name: &'static str,
-        _variants: &'static [&'static str],
-        _visitor: V,
+        variants: &'static [&'static str],
+        visitor: V,
     ) -> Result<V::Value, Self::Error>
     where
         V: Visitor<'de>,
     {
-        unimplemented!("deserialize_enum is not implemented")
+        fn clone_unsafely<T>(thing: &T) -> T {
+            // SAFETY: `thing` is a reference, so it is valid for reads, aligned and
+            // initialized. The bitwise copy is only acceptable because the visitor is
+            // expected to be a stateless unit struct with nothing to double-drop.
+            unsafe { ptr::read(thing) }
+        }
+
+        // Report a missing value as an error instead of panicking on `unwrap`.
+        let value = match self.value {
+            Some(value) => value,
+            None => return Err(de::Error::custom("Missing value for enum")),
+        };
+
+        // Try each variant in order
+        for &variant in variants {
+            // Each probe consumes a visitor, so hand it a fresh copy.
+            let cloned_visitor = clone_unsafely(&visitor);
+            let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
+                value.clone(),
+                variant.to_owned(),
+                self.default_for_missing_fields,
+            ));
+
+            if result.is_ok() {
+                return result;
+            }
+        }
+
+        Err(de::Error::custom("No variant match"))
     }
 
     fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
@@ -431,6 +482,82 @@ where
     }
 }
 
+struct ProbeEnumAccess<A: BurnModuleAdapter> {
+    value: NestedValue,
+    current_variant: String,
+    default_for_missing_fields: bool,
+    phantom: std::marker::PhantomData<A>,
+}
+
+impl<A: BurnModuleAdapter> ProbeEnumAccess<A> {
+    fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self {
+        ProbeEnumAccess {
+            value,
+            current_variant,
+            default_for_missing_fields,
+            phantom: std::marker::PhantomData,
+        }
+    }
+}
+
+impl<'de, A> EnumAccess<'de> for ProbeEnumAccess<A>
+where
+    A: BurnModuleAdapter,
+{
+    type Error = Error;
+    type Variant = Self;
+
+    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
+    where
+        V: DeserializeSeed<'de>,
+    {
+        seed.deserialize(self.current_variant.clone().into_deserializer())
+            .map(|v| (v, self))
+    }
+}
+
+impl<'de, A> VariantAccess<'de> for ProbeEnumAccess<A>
+where
+    A: BurnModuleAdapter,
+{
+    type Error = Error;
+
+    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
+    where
+        T: DeserializeSeed<'de>,
+    {
+        let value = seed.deserialize(
+            NestedValueWrapper::<A>::new(self.value, self.default_for_missing_fields)
+                .into_deserializer(),
+        )?;
+        Ok(value)
+    }
+
+    fn unit_variant(self) -> Result<(), Self::Error> {
+        unimplemented!("unit variant is not implemented because it is not used in the burn module")
+    }
+
+    fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        unimplemented!("tuple variant is not implemented because it is not used in the burn module")
+    }
+
+    fn struct_variant<V>(
self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + unimplemented!( + "struct variant is not implemented because it is not used in the burn module" + ) + } +} + /// A wrapper for the nested value data structure with a burn module adapter. struct NestedValueWrapper { value: NestedValue, @@ -601,11 +728,14 @@ impl<'de> serde::Deserializer<'de> for DefaultDeserializer { where V: Visitor<'de>, { - panic!( - "Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct", - self.originator_field_name.unwrap_or("UNKNOWN".to_string()), - name, - ); + // Return an error if the originator field name is not set + Err(Error::Other( + format!( + "Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct", + self.originator_field_name.unwrap_or("UNKNOWN".to_string()), + name, + ) + )) } fn deserialize_tuple_struct( diff --git a/crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_false.pt b/crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_false.pt new file mode 100644 index 0000000000000000000000000000000000000000..27e893591fb9065c19f43b79e1d22d420642face GIT binary patch literal 1766 zcma)7T}V?=9KYLi>dep5GCxu(o0{{^O`YVu(1;OU#Hj?spmBG1@?N}dp4~I0NH9<# z=piIBv@fFGk_rpTyz`~#B`C`t>Ltt9(w4UZOXPb^FRLFyL-?1{eI{Fb)B3p zjiNF$sei1TN~eG?IN-hXmSUUBr67Y0C<9u>~$mHQ~Uw< zP@@6rj7C=m-lpJ{d>?>mku3&??B`KWqkyEJ$W4fQS^p%@Ip4DJwunOz)ZdY)X!L!R+R^e(ee!TxEnc&=jjrUeT3W|`TJI0d+%978ex6`X{`9ghIuEd2-%f`{ zE`ZRD?^D8(C#J7vYk!>PjcGwT@Boa^8d)TG7m0{j5_j4OaBN9)|TSl3Z z^G0jTVXxf@aWwuG13B^1t-e!Q%t|~pa~*X!IkPF8d<6s~UY5@DWE8k1uftyTe<9`; znzq?!y0yvLY;A08X|h^dX_{_sv|7v-v)M{F+bmYvW+5tT?@X;2u|ncbPD8|QO|IC^ zDJL)DY}+}O2o>816>%`rxGj@HZ8csZR%|Ji#$gp>`AM-Njg$x!TO08}m3a821~MlC z#iqI}PSN#LK+(cVM2byET^v#&o^VMeB`*{wr`I2i0#xYgx1h)*BDQ&bL&iO}XJmk4 a!~4ibf*Ol1Ag`TVdib9n?;`6XYyScY48SP> literal 0 HcmV?d00001 diff --git 
a/crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_true.pt b/crates/burn-import/pytorch-tests/tests/enum_module/enum_depthwise_true.pt new file mode 100644 index 0000000000000000000000000000000000000000..e9488a6ef1681d1c4d03853b687ceb245431f795 GIT binary patch literal 2288 zcmbW3UuaWT9LG#qoc-!Qc zd(yJ8FhQy~Mu$F#4>}4(eBXQXR@j3$^`TG(qWB^VL{MSE1kbrQ#(QIKS9iim!uj+0 zeSg36`<;7J*wu!R*Nfh31IUe}`N~|jAeD%`sAy7_lq=FiAxCn7(pfc`aH7+v&pGq? zqN+-Hq7>&f?qs<5xIlB<#z~%F0Jan^0SLS9VXfw8ONwt_Gr;I@k+A|Kih6i`X#SE1y z;K92Ic)!NThn?7gUD!pS!-qTQ$AujPxJxpVwu8XVH0;WBTJ_B+ISoFNsd2eYHsAtu z${y?_&_$uxxdh$Vu>`vf%pL+Cr{R-#ZkEh$_EEkjd7KROt%!eLn-;_5@jX+oe16HODkP~fW|=o1xfCXm7& z@0Zv9{9?q35bG)1tQ$m~EwgUsdmB1jXnLjP;!q^VfJ6W-sIUu!218RAN)}J%cK9q#a%t|@FO|C9`;N#9lva5M!Tp6 z*0<8(vyX$AOdQ-+9Q-%^)6Zki^!V2zzxVyO|NUs9rQLerBoqE^8Lp?4$T5%a{Wkjz zEepiDy~kr{q^0G6p1hh=R&uKHLyq0VS*4In3~xgS9}&Z$XqXQPf)Ej=B4IJi3!!L8 z6sIO5q2OdB%15UnlToH}?R>|ek1Aw6g&j<6DwzNP literal 0 HcmV?d00001 diff --git a/crates/burn-import/pytorch-tests/tests/enum_module/export_weights.py b/crates/burn-import/pytorch-tests/tests/enum_module/export_weights.py new file mode 100755 index 0000000000..7a22623326 --- /dev/null +++ b/crates/burn-import/pytorch-tests/tests/enum_module/export_weights.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +import torch +from torch import nn, Tensor + +class DwsConv(nn.Module): + """Depthwise separable convolution.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None: + super().__init__() + # Depthwise conv + self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels) + # Pointwise conv + self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1) + + def forward(self, x: Tensor) -> Tensor: + x = self.dconv(x) + return self.pconv(x) + + +class Model(nn.Module): + def __init__(self, depthwise: bool = False) -> None: + super().__init__() + self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3) + + def forward(self, x: 
Tensor) -> Tensor:
+        return self.conv(x)
+
+
+def main():
+    """Export state dicts for both `Model` variants and print reference tensors."""
+    torch.set_printoptions(precision=8)
+    torch.manual_seed(1)
+
+    model = Model().to(torch.device("cpu"))
+
+    torch.save(model.state_dict(), "enum_depthwise_false.pt")
+
+    input = torch.rand(1, 2, 5, 5)
+
+    print("Depthwise is False")
+    print(f"Input shape: {input.shape}")
+    print(f"Input: {input}")
+    output = model(input)
+    print(f"Output: {output}")
+    print(f"Output Shape: {output.shape}")
+
+
+    print("Depthwise is True")
+    model = Model(depthwise=True).to(torch.device("cpu"))
+    torch.save(model.state_dict(), "enum_depthwise_true.pt")
+
+    print(f"Input shape: {input.shape}")
+    print(f"Input: {input}")
+    output = model(input)
+    print(f"Output: {output}")
+    print(f"Output Shape: {output.shape}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/crates/burn-import/pytorch-tests/tests/enum_module/mod.rs b/crates/burn-import/pytorch-tests/tests/enum_module/mod.rs
new file mode 100644
index 0000000000..ad0eeb06f7
--- /dev/null
+++ b/crates/burn-import/pytorch-tests/tests/enum_module/mod.rs
@@ -0,0 +1,194 @@
+use burn::{
+    module::Module,
+    nn::conv::{Conv2d, Conv2dConfig},
+    tensor::{backend::Backend, Tensor},
+};
+
+#[derive(Module, Debug)]
+pub enum Conv<B: Backend> {
+    DwsConv(DwsConv<B>),
+    Conv(Conv2d<B>),
+}
+
+#[derive(Module, Debug)]
+pub struct DwsConv<B: Backend> {
+    dconv: Conv2d<B>,
+    pconv: Conv2d<B>,
+}
+
+#[derive(Module, Debug)]
+pub struct Net<B: Backend> {
+    conv: Conv<B>,
+}
+
+impl<B: Backend> Net<B> {
+    /// Create a new model from the given record.
+ pub fn new_with(record: NetRecord) -> Self { + let conv = match record.conv { + ConvRecord::DwsConv(dws_conv) => { + let dconv = Conv2dConfig::new([2, 2], [3, 3]) + .with_groups(2) + .init_with(dws_conv.dconv); + let pconv = Conv2dConfig::new([2, 2], [1, 1]) + .with_groups(1) + .init_with(dws_conv.pconv); + Conv::DwsConv(DwsConv { dconv, pconv }) + } + ConvRecord::Conv(conv) => { + let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]); + Conv::Conv(conv2d_config.init_with(conv)) + } + }; + Net { conv } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + match &self.conv { + Conv::DwsConv(dws_conv) => { + let x = dws_conv.dconv.forward(x); + dws_conv.pconv.forward(x) + } + Conv::Conv(conv) => conv.forward(x), + } + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; + + use super::*; + + #[test] + fn depthwise_false() { + let device = Default::default(); + let load_args = + LoadArgs::new("tests/enum_module/enum_depthwise_false.pt".into()).with_debug_print(); + + let record = PyTorchFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + let input = Tensor::::from_data( + [[ + [ + [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], + [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], + [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], + [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], + [ + 0.804_481_1, + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + ], + ], + [ + [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], + [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], + [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], + [ + 0.03694397, + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + ], + 
[0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.35449377, -0.02832414, 0.490_976_1], + [0.29709217, 0.332_586_3, 0.30594018], + [0.18101373, 0.30932188, 0.30558896], + ], + [ + [-0.17683622, -0.13244139, -0.05608707], + [0.23467252, -0.07038684, 0.255_044_1], + [-0.241_931_3, -0.20476191, -0.14468731], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } + + #[test] + fn depthwise_true() { + let device = Default::default(); + let load_args = + LoadArgs::new("tests/enum_module/enum_depthwise_true.pt".into()).with_debug_print(); + + let record = PyTorchFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + + let input = Tensor::::from_data( + [[ + [ + [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], + [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], + [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], + [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], + [ + 0.804_481_1, + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + ], + ], + [ + [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], + [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], + [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], + [ + 0.03694397, + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + ], + [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.77874625, 0.859_017_6, 0.834_283_5], + [0.773_056_4, 0.73817325, 0.78292674], + [0.710_775_2, 0.747_187_2, 0.733_264_4], + ], + [ + [-0.44891885, -0.49027523, -0.394_170_7], + [-0.43836114, -0.33961445, -0.387_311_5], + [-0.581_134_3, -0.34197026, 
-0.535_035_7], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } +} diff --git a/crates/burn-import/pytorch-tests/tests/mod.rs b/crates/burn-import/pytorch-tests/tests/mod.rs index bcf31a8278..b414a554d1 100644 --- a/crates/burn-import/pytorch-tests/tests/mod.rs +++ b/crates/burn-import/pytorch-tests/tests/mod.rs @@ -8,6 +8,7 @@ mod conv2d; mod conv_transpose1d; mod conv_transpose2d; mod embedding; +mod enum_module; mod group_norm; mod integer; mod key_remap; From 6662ed906f4a03ecfd84283ee4cbc0a9d472c0d9 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Mon, 11 Mar 2024 10:51:47 -0500 Subject: [PATCH 2/2] Fix wording/typos per PR feedback --- burn-book/src/import/pytorch-model.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 9bcc1f35a2..5051bbfc56 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -350,9 +350,9 @@ let record = PyTorchFileRecorder::::default() ### Models containing enum modules Burn supports models containing enum modules with new-type variants (tuple with one item). Importing -weights for such models is supported by using the `PyTorchFileRecorder` automatically. However, it -should be noted that since the source weights file does not contain the enum variant information, -the enum variant is picked based on the enum variant type. Let's consider the following example: +weights for such models is automatically supported by the PyTorchFileRecorder. However, it should be +noted that since the source weights file does not contain the enum variant information, the enum +variant is picked based on the enum variant type. Let's consider the following example: ```rust #[derive(Module, Debug)] @@ -396,7 +396,7 @@ Dtype: F32 The weights will be imported into the `DwsConv` variant of the `Conv` enum module. 
-If the variant types are identica, then the first variant is picked. Generally, it won't be a +If the variant types are identical, then the first variant is picked. Generally, it won't be a problem since the variant types are usually different. ## Current known issues