
Warnings when learning on tpu #20890


Open
intexcor opened this issue Jun 11, 2025 · 1 comment
Labels
bug Something isn't working run TPU ver: 2.5.x

Comments

@intexcor

Bug description

WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
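These warnings appear to come from code still calling the deprecated `torch_xla.core.xla_model` helpers rather than from the training script itself. The replacements the messages point to look roughly like this (a minimal sketch; the `ImportError` fallback is an assumption for machines without torch_xla installed):

```python
# Sketch of the migration the deprecation warnings suggest.
try:
    import torch_xla.runtime as xr

    world_size = xr.world_size()      # replaces torch_xla.core.xla_model.xrt_world_size()
    ordinal = xr.global_ordinal()     # replaces torch_xla.core.xla_model.get_ordinal()
except ImportError:
    # Single-process fallback when torch_xla is not installed (assumption).
    world_size, ordinal = 1, 0
```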

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os


os.environ["WANDB_API_KEY"] = "652be9a335ccff9372ec8e5b16946c34163f0ff5"
os.environ["HF_TOKEN"] = "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"


torch.set_float32_matmul_precision('high')


class ChatDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]

        chat = ("<|im_start|>user\n" + data["input"] + "<|im_end|>\n" +
                "<|im_start|>assistant\n<think>\n \n</think>\n" + data["output"] + "<|im_end|>\n")

        encoding = self.tokenizer(
            chat,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()

        labels = input_ids.clone()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class LanguageModelLightning(pl.LightningModule):
    def __init__(self, model_name, learning_rate=2e-5, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto"
        )
        self.model.train()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

        loss = outputs.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

        loss = outputs.loss
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        return optimizer


dataset = load_dataset('intexcp/russian-llm-training-dataset')

model = LanguageModelLightning("Qwen/Qwen3-0.6B")
#model = torch.compile(model)

train_dataset = ChatDataset(dataset["train"], model.tokenizer)
val_dataset = ChatDataset(dataset["test"], model.tokenizer)


train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=16,
    pin_memory=True
)


val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=16,
    pin_memory=True
)

wandb_logger = WandbLogger(
    project="IGen",
    name="IGen"
)

checkpoint_callback = ModelCheckpoint(
    dirpath="IGen/checkpoints",
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True
)

trainer = Trainer(
    max_epochs=2,
    precision="bf16-true",
    accelerator="auto",
    strategy="auto",
    devices="auto",
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=1,
    log_every_n_steps=50,
    enable_model_summary=True,
    enable_progress_bar=True,
)

trainer.fit(model, train_loader, val_loader)


model.model.save_pretrained("IGen/final_model")
model.tokenizer.save_pretrained("IGen/final_model")

Error messages and logs

WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
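Until the deprecated calls are removed upstream, the repeated messages can be silenced with a standard `logging` filter. This is a workaround sketch, not a fix; since the messages are prefixed `WARNING:root:`, they are emitted on the root logger, which is where the filter is attached:

```python
import logging

class XlaDeprecationFilter(logging.Filter):
    # Drops records containing the deprecation text shown in the log above.
    def filter(self, record: logging.LogRecord) -> bool:
        return "will be removed in release 2.7" not in record.getMessage()

# Attach to the root logger, matching the "WARNING:root:" prefix in the output.
logging.getLogger().addFilter(XlaDeprecationFilter())
```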


### Environment

<details>
  <summary>Current environment</summary>

- PyTorch Lightning Version (e.g., 2.5.0):
- PyTorch Version (e.g., 2.5):
- Python version (e.g., 3.12):
- OS (e.g., Linux):
- CUDA/cuDNN version:
- GPU models and configuration:
- How you installed Lightning (conda, pip, source):


</details>


### More info

_No response_
@intexcor intexcor added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jun 11, 2025
@Borda Borda added run TPU and removed needs triage Waiting to be triaged by maintainers labels Jun 11, 2025
@rittik9
Contributor

rittik9 commented Jun 13, 2025

#20872
