
optimizer state_dict misses some values #3597

Open

JacoCheung opened this issue Jan 22, 2025 · 0 comments
Hi fbgemm team, while using a torchrec EmbeddingCollection with the Adam optimizer, I found that ec.fused_optimizer.state_dict() returns nothing but momentum tensors. lr, weight decay, etc., which are normally part of a plain torch.optim.Adam state_dict, are gone.

See thread

I think the problem is here.
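
For comparison, a plain torch.optim.Adam exposes those hyperparameters through the param_groups entry of its state_dict. A minimal standalone sketch (no torchrec involved):

import torch

# A throwaway parameter, just to build a regular (non-fused) Adam optimizer.
param = torch.nn.Parameter(torch.randn(4, 2))
opt = torch.optim.Adam([param], lr=0.02)

sd = opt.state_dict()
print(sd.keys())                     # dict_keys(['state', 'param_groups'])
print(sd["param_groups"][0]["lr"])   # 0.02 -- lr, betas, eps, weight_decay all live here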


import os
import sys
sys.path.append(os.path.abspath('/home/scratch.junzhang_sw/workspace/torchrec'))
import torch
import torchrec
import torch.distributed as dist

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")
ebc = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        )
    ]
)

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embedding import EmbeddingCollectionSharder

sharder = EmbeddingCollectionSharder(
    use_index_dedup=True,
    # qcomm_codecs_registry=get_qcomm_codecs_registry(
    #     qcomms_config=QCommsConfig(
    #         forward_precision=CommType.FP16,
    #         backward_precision=CommType.BF16,
    #     )
    # ),
)
dp_rank = dist.get_rank()
model = torchrec.distributed.DistributedModelParallel(ebc, sharders=[sharder], device=torch.device("cuda"))
mb = torchrec.KeyedJaggedTensor(
    keys = ["product", "user"],
    values = torch.tensor([101, 201, 101, 404, 404, 606, 606, 606]).cuda(),
    lengths = torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int64).cuda(),
)
import pdb;pdb.set_trace()
ret = model(mb) # => this is awaitable
product = ret['product'] # implicitly call awaitable.wait()
# import pdb;pdb.set_trace()

The above model gives me an optimizer state like:

>>> model.fused_optimizer.state_dict()['state']['embeddings.product_table.weight'].keys()
dict_keys(['product_table.momentum1', 'product_table.exp_avg_sq'])
# only the `state` entries are present; there is no `param_groups` containing lr, beta1, etc.

That metadata is mandatory when I need to dump and reload the model.
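
For now I work around it by saving the hyperparameters myself, since they are just the kwargs I passed to apply_optimizer_in_backward. A rough sketch (the checkpoint path and dict layout below are my own, not a torchrec convention):

optimizer_kwargs = {"lr": 0.02}  # same dict passed to apply_optimizer_in_backward

checkpoint = {
    "model": model.state_dict(),
    "optim_state": model.fused_optimizer.state_dict(),
    "optim_hparams": optimizer_kwargs,  # saved manually because param_groups is missing
}
torch.save(checkpoint, "checkpoint.pt")

# On reload, the hyperparameters have to be re-applied via apply_optimizer_in_backward
# before loading the fused optimizer state back.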
