Skip to content

Commit

Permalink
Replace shark_turbine with iree.turbine (nod-ai#870)
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre authored Oct 22, 2024
1 parent db5f1b6 commit 541572a
Show file tree
Hide file tree
Showing 22 changed files with 42 additions and 42 deletions.
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/resnet_18.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.turbine.aot import *
from iree.compiler.ir import Context
import iree.runtime as rt
from turbine_models.custom_models.sd_inference import utils
import shark_turbine.ops.iree as ops
import iree.turbine.ops.iree as ops
import argparse

parser = argparse.ArgumentParser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import numpy as np
from tqdm.auto import tqdm
from shark_turbine.ops.iree import trace_tensor
from iree.turbine.ops.iree import trace_tensor

torch.random.manual_seed(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import torch
from typing import Any, Callable, Dict, List, Optional, Union
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
import iree.turbine.ops.iree as ops
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from iree.compiler.ir import Context
import iree.runtime as ireert
import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
from turbine_models.custom_models.sd3_inference.text_encoder_impls import (
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sd3_inference/sd3_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from turbine_models.custom_models.sd_inference import utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch, math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
from shark_turbine import ops
from iree.turbine import ops

#################################################################################################
### Core/Utility
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import re

from iree.compiler.ir import Context
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import List

import torch
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
import iree.turbine.ops.iree as ops
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from iree.compiler.ir import Context
import iree.runtime as ireert
import numpy as np
Expand Down
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sdxl_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass

from turbine_models.custom_models.sd_inference import utils
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context

from shark_turbine.aot import *
import shark_turbine.ops as ops
from iree.turbine.aot import *
import iree.turbine.ops as ops

from turbine_models.custom_models.sd_inference import utils
from turbine_models.custom_models.sd_inference.schedulers import get_scheduler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import numpy as np
from tqdm.auto import tqdm
from shark_turbine.ops.iree import trace_tensor
from iree.turbine.ops.iree import trace_tensor

torch.random.manual_seed(0)

Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass


from turbine_models.custom_models.sd_inference import utils
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from turbine_models.custom_models.sd_inference import utils
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils import _pytree as pytree
from shark_turbine.aot import *
from iree.turbine.aot import *
from iree.compiler.ir import Context
from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import (
enable_llama_pos_shift_attention,
Expand Down Expand Up @@ -458,7 +458,7 @@ def evict_kvcache_space(self):
# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant
from iree.turbine.transforms.quantization import mm_group_quant

mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
import shark_turbine.aot as aot
import iree.turbine.aot as aot
from turbine_models.turbine_tank import turbine_tank
import os
import re
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import os
import numpy as np
from iree.compiler.ir import Context
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
from turbine_models.custom_models.pipeline_base import (
PipelineComponent,
TurbinePipelineBase,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass

model_metadata_forward = {
"model_name": "TestModel2xLinear",
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/tests/stateless_llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tempfile

os.environ["TORCH_LOGS"] = "dynamic"
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models import llm_runner

from turbine_models.gen_external_params.gen_external_params import (
Expand Down

0 comments on commit 541572a

Please sign in to comment.