Skip to content

Commit

Permalink
Add Enum module support in PyTorchFileRecorder (#1436)
Browse files Browse the repository at this point in the history
* Add Enum module support in PyTorchFileRecorder

Fixes #1431

* Fix wording/typos per PR feedback
  • Loading branch information
antimora authored Mar 11, 2024
1 parent 9d4fbc5 commit 0138e16
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 8 deletions.
52 changes: 52 additions & 0 deletions burn-book/src/import/pytorch-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,58 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.expect("Should decode state successfully")
```

### Models containing enum modules

Burn supports models containing enum modules with new-type variants (tuple variants with exactly one
field). Importing weights for such models is automatically supported by the `PyTorchFileRecorder`.
Note, however, that the source weights file does not contain any enum variant information, so the
variant is chosen by type: each variant is tried in declaration order, and the first one whose type
can be deserialized from the weights is picked. Let's consider the following example:

```rust
#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
DwsConv(DwsConv<B>),
Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
dconv: Conv2d<B>,
pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv: Conv<B>,
}
```

If the source weights file contains weights for `DwsConv`, such as the following keys:

```text
---
Key: conv.dconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.dconv.weight
Shape: [2, 1, 3, 3]
Dtype: F32
---
Key: conv.pconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.pconv.weight
Shape: [2, 2, 1, 1]
Dtype: F32
```

The weights will be imported into the `DwsConv` variant of the `Conv` enum module.

If several variants have identical types, the first one declared is picked. In practice this is
rarely a problem, since the variant types are usually different.

## Current known issues

1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
146 changes: 138 additions & 8 deletions crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::ptr;
use std::collections::HashMap;

use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};

use serde::de::{EnumAccess, VariantAccess};
use serde::{
de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},
forward_to_deserialize_any,
Expand Down Expand Up @@ -313,16 +315,65 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
unimplemented!("deserialize_tuple_struct is not implemented")
}

/// Deserializes an enum by attempting to match its variants against the provided data.
///
/// This function attempts to deserialize an enum by iterating over its possible variants
/// and trying to deserialize the data into each until one succeeds. We need to do this
/// because we don't have a way to know which variant to deserialize from the data.
///
/// This is similar to Serde's
/// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged),
/// but it's on the deserializer side. Using `#[serde(untagged)]` on the enum will force
/// using `deserialize_any`, which is not what we want because we want to use methods, such
/// as `visit_struct`. Also, we do not wish to use auto-generated `Deserialize` code just
/// for enums because it would affect other serialization and deserialization formats, such
/// as JSON and Bincode.
///
/// # Safety
/// The function uses an unsafe block to clone the `visitor`. This is necessary because
/// the `Visitor` trait does not have a `Clone` implementation, and we need to clone it
/// as we are going to use it multiple times. The Visitor is a code generated unit struct
/// with no states or mutations, so it is safe to clone it in this case. We mainly care
/// about the `visit_enum` method, which is the only method that will be called on the
/// cloned visitor.
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
unimplemented!("deserialize_enum is not implemented")
/// Produces a bitwise duplicate of `thing` without requiring `T: Clone`.
///
/// The previous implementation copied the value into a destination pointer derived
/// from `ptr::null_mut()`, i.e. it wrote through a null pointer, which is undefined
/// behaviour for every `T` that is not zero-sized. Reading straight from the source
/// reference needs no scratch destination at all.
///
/// The returned value is an independent owner of the same bits. This helper must
/// therefore only be used with types for which a bitwise duplicate is harmless
/// (no owned heap data, no `Drop` side effects), such as stateless visitors.
fn clone_unsafely<T>(thing: &T) -> T {
    // SAFETY: `thing` is a reference, so it is non-null, properly aligned and points
    // to an initialized `T`; `ptr::read` only requires exactly that. The source is
    // left untouched, so the caller ends up with two bitwise-identical values.
    unsafe { ptr::read(thing as *const T) }
}

// Try each variant in order
for &variant in variants {
// clone visitor to avoid moving it
let cloned_visitor = clone_unsafely(&visitor);
let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
self.value.clone().unwrap(),
variant.to_owned(),
self.default_for_missing_fields,
));

if result.is_ok() {
return result;
}
}

Err(de::Error::custom("No variant match"))
}

fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -431,6 +482,82 @@ where
}
}

/// Enum access handed to a visitor's `visit_enum` while probing a single candidate variant.
///
/// `deserialize_enum` creates one instance per variant name it tries; the instance reports
/// that name as the variant identifier and then deserializes `value` as the variant payload.
struct ProbeEnumAccess<A: BurnModuleAdapter> {
/// Source data to deserialize as the payload of the candidate variant.
value: NestedValue,
/// Name of the variant currently being probed; returned as the variant identifier.
current_variant: String,
/// Forwarded unchanged to `NestedValueWrapper::new` when the payload is deserialized.
default_for_missing_fields: bool,
/// Carries the adapter type `A` without storing a value of it.
phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> ProbeEnumAccess<A> {
    /// Builds the probe for one candidate variant.
    ///
    /// * `value` - source data to deserialize as the variant payload.
    /// * `current_variant` - name of the variant being tried.
    /// * `default_for_missing_fields` - stored as-is and passed along when the payload
    ///   is deserialized.
    fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self {
        Self {
            phantom: std::marker::PhantomData,
            default_for_missing_fields,
            current_variant,
            value,
        }
    }
}

impl<'de, A: BurnModuleAdapter> EnumAccess<'de> for ProbeEnumAccess<A> {
    type Error = Error;
    type Variant = Self;

    /// Reports the variant being probed as the enum's variant identifier.
    ///
    /// The seed is fed a string deserializer over `current_variant`; on success the
    /// decoded identifier is returned together with `self`, which then serves as the
    /// `VariantAccess` for the payload. Any seed error is returned unchanged.
    fn variant_seed<V: DeserializeSeed<'de>>(
        self,
        seed: V,
    ) -> Result<(V::Value, Self::Variant), Self::Error> {
        let name_deserializer = self.current_variant.clone().into_deserializer();
        match seed.deserialize(name_deserializer) {
            Ok(identifier) => Ok((identifier, self)),
            Err(err) => Err(err),
        }
    }
}

impl<'de, A: BurnModuleAdapter> VariantAccess<'de> for ProbeEnumAccess<A> {
    type Error = Error;

    /// Deserializes the payload of a new-type variant from the stored nested value.
    ///
    /// The value is wrapped in a `NestedValueWrapper` (carrying the
    /// `default_for_missing_fields` flag) and handed to `seed`; errors are propagated.
    fn newtype_variant_seed<T: DeserializeSeed<'de>>(
        self,
        seed: T,
    ) -> Result<T::Value, Self::Error> {
        let Self {
            value,
            default_for_missing_fields,
            ..
        } = self;
        let wrapper = NestedValueWrapper::<A>::new(value, default_for_missing_fields);
        let payload = seed.deserialize(wrapper.into_deserializer())?;
        Ok(payload)
    }

    /// Unit variants are not supported.
    ///
    /// # Panics
    /// Always panics.
    fn unit_variant(self) -> Result<(), Self::Error> {
        unimplemented!("unit variant is not implemented because it is not used in the burn module")
    }

    /// Tuple variants are not supported.
    ///
    /// # Panics
    /// Always panics.
    fn tuple_variant<V: Visitor<'de>>(
        self,
        _len: usize,
        _visitor: V,
    ) -> Result<V::Value, Self::Error> {
        unimplemented!("tuple variant is not implemented because it is not used in the burn module")
    }

    /// Struct variants are not supported.
    ///
    /// # Panics
    /// Always panics.
    fn struct_variant<V: Visitor<'de>>(
        self,
        _fields: &'static [&'static str],
        _visitor: V,
    ) -> Result<V::Value, Self::Error> {
        unimplemented!("struct variant is not implemented because it is not used in the burn module")
    }
}

/// A wrapper for the nested value data structure with a burn module adapter.
struct NestedValueWrapper<A: BurnModuleAdapter> {
value: NestedValue,
Expand Down Expand Up @@ -601,11 +728,14 @@ impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
where
V: Visitor<'de>,
{
panic!(
"Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct",
self.originator_field_name.unwrap_or("UNKNOWN".to_string()),
name,
);
// Report missing source values as an error rather than panicking; the field name
// falls back to "UNKNOWN" when the originator field name is not set
Err(Error::Other(
format!(
"Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct",
self.originator_field_name.unwrap_or("UNKNOWN".to_string()),
name,
)
))
}

fn deserialize_tuple_struct<V>(
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
import torch
from torch import nn, Tensor

class DwsConv(nn.Module):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        """Create the two stages.

        Args:
            in_channels: Channels of the input; also the depthwise stage's output channels.
            out_channels: Channels produced by the pointwise stage.
            kernel_size: Kernel size of the depthwise stage.
        """
        super().__init__()
        # Depthwise stage: one group per input channel, channel count unchanged.
        self.dconv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
        )
        # Pointwise stage: 1x1 kernel in a single group, maps to `out_channels`.
        self.pconv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            groups=1,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the depthwise stage, then the pointwise stage."""
        return self.pconv(self.dconv(x))


class Model(nn.Module):
    """Single-layer model whose `conv` is either a plain `nn.Conv2d` or a `DwsConv`."""

    def __init__(self, depthwise: bool = False) -> None:
        """Select the layer type for `conv`.

        Args:
            depthwise: When true, `conv` is `DwsConv(2, 2, 3)`; otherwise `nn.Conv2d(2, 2, 3)`.
        """
        super().__init__()
        if depthwise:
            self.conv = DwsConv(2, 2, 3)
        else:
            self.conv = nn.Conv2d(2, 2, 3)

    def forward(self, x: Tensor) -> Tensor:
        """Run the input through the selected convolution."""
        result = self.conv(x)
        return result


def main():
    """Save state dicts for both `Model` flavours and print a sample forward pass for each.

    Writes `enum_depthwise_false.pt` and `enum_depthwise_true.pt` using relative paths,
    then prints the shared random input and each model's output.

    The order of the seeded operations (seed, first model construction, input sampling,
    second model construction) is kept exactly as before so the generated weights and
    input values do not change.
    """
    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    torch.save(model.state_dict(), "enum_depthwise_false.pt")

    # Named `sample` rather than `input` to avoid shadowing the builtin.
    sample = torch.rand(1, 2, 5, 5)

    # The original calls passed "{}" as a literal plus a second positional argument, so
    # the braces were printed verbatim; f-strings interpolate the values as intended.
    print("Depthwise is False")
    print(f"Input shape: {sample.shape}")
    print(f"Input: {sample}")
    output = model(sample)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")

    print("Depthwise is True")
    model = Model(depthwise=True).to(torch.device("cpu"))
    torch.save(model.state_dict(), "enum_depthwise_true.pt")

    print(f"Input shape: {sample.shape}")
    print(f"Input: {sample}")
    output = model(sample)
    print(f"Output: {output}")
    print(f"Output Shape: {output.shape}")


if __name__ == '__main__':
main()
Loading

0 comments on commit 0138e16

Please sign in to comment.