Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[CRC][No.34] add docstring for TensorFunction/Method/Layer (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
RedContritio authored Aug 8, 2023
1 parent 78cf992 commit 707ac5e
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def main_info(self) -> dict[str, Any]:


class TensorFunctionVariable(FunctionVariable):
"""
TensorFunctionVariable is a subclass of FunctionVariable used to wrap a method of a tensor.
Args:
method_name (str): The name of the tensor method to be wrapped.
graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

def __init__(
self, method_name: str, graph: FunctionGraph, tracker: Tracker
):
Expand All @@ -230,6 +239,17 @@ def main_info(self) -> dict[str, Any]:


class MethodVariable(CallableVariable):
"""
MethodVariable is a subclass of CallableVariable used to wrap a method variable.
Args:
bound_instance (VariableBase): The instance of the method.
fn (VariableBase): The method to be wrapped.
graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
tracker(Tracker): The Tracker object that tracks the information of this variable.
method_name (str): The name of the method to be wrapped.
"""

def __init__(
self,
bound_instance: VariableBase,
Expand Down Expand Up @@ -312,6 +332,15 @@ def main_info(self) -> dict[str, Any]:


class LayerVariable(CallableVariable):
"""
LayerVariable is a subclass of CallableVariable used to wrap a layer.
Args:
layer (paddle.nn.Layer): The layer to be wrapped.
graph(FunctionGraph): The FunctionGraph object that this variable is associated with.
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

def __init__(
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker
):
Expand Down

0 comments on commit 707ac5e

Please sign in to comment.