diff --git a/docs/source/conf.py b/docs/source/conf.py index 6fec50e0de2..28244ad405f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,6 +16,11 @@ import sys import oneflow +sys.path.insert(0, os.path.abspath(".")) +CN_DOCS = os.getenv("CN_DOCS") +if CN_DOCS: + import zh + # -- Project information ----------------------------------------------------- project = u"OneFlow" diff --git a/docs/source/zh/__init__.py b/docs/source/zh/__init__.py new file mode 100644 index 00000000000..e457bc66b58 --- /dev/null +++ b/docs/source/zh/__init__.py @@ -0,0 +1,2 @@ +from .math_ops import * +from .activation import * diff --git a/docs/source/zh/activation.py b/docs/source/zh/activation.py new file mode 100644 index 00000000000..440309f0d5d --- /dev/null +++ b/docs/source/zh/activation.py @@ -0,0 +1,32 @@ +import oneflow +from oneflow.framework.docstr.utils import reset_docstr + +reset_docstr( + oneflow.nn.ReLU, + r"""ReLU(inplace=False) + + ReLU 激活函数,对张量中的每一个元素做 element-wise 运算,公式如下: + + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + + 参数: + inplace: 是否做 in-place 操作。 默认为 ``False`` + + 形状: + - Input: :math:`(N, *)` 其中 `*` 的意思是,可以指定任意维度 + - Output: :math:`(N, *)` 输入形状与输出形状一致 + + 示例: + + .. code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + >>> relu = flow.nn.ReLU() + >>> ndarr = np.asarray([1, -2, 3]) + >>> x = flow.Tensor(ndarr) + >>> relu(x) + tensor([1., 0., 3.], dtype=oneflow.float32) + + """, +) diff --git a/docs/source/zh/math_ops.py b/docs/source/zh/math_ops.py new file mode 100644 index 00000000000..98491bd7b5c --- /dev/null +++ b/docs/source/zh/math_ops.py @@ -0,0 +1,43 @@ +import oneflow +from oneflow.framework.docstr.utils import reset_docstr + +reset_docstr( + oneflow.add, + r"""add(input, other) + + 计算 `input` 和 `other` 的和。支持 element-wise、标量和广播形式的加法。 + 公式为: + + .. math:: + out = input + other + + 示例: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow as flow + + # element-wise 加法 + >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32) + >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) + >>> out = flow.add(x, y).numpy() + >>> out.shape + (2, 3) + + # 标量加法 + >>> x = 5 + >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) + >>> out = flow.add(x, y).numpy() + >>> out.shape + (2, 3) + + # 广播加法 + >>> x = flow.tensor(np.random.randn(1,1), dtype=flow.float32) + >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) + >>> out = flow.add(x, y).numpy() + >>> out.shape + (2, 3) + + """, +) diff --git a/oneflow/api/python/framework/doc.cpp b/oneflow/api/python/framework/doc.cpp index c351fac6dd3..71f0edc8242 100644 --- a/oneflow/api/python/framework/doc.cpp +++ b/oneflow/api/python/framework/doc.cpp @@ -49,6 +49,37 @@ py::object AddFunctionDoc(py::object f, const std::string& doc_string) { return f; } +py::object ReplaceDoc(py::object f, const std::string& doc_string) { + static std::vector all_doc_strings; + all_doc_strings.emplace_back(doc_string); + const char* doc_str = all_doc_strings.back().c_str(); + PyObject* obj = f.ptr(); + if (PyCFunction_Check(obj)) { + auto* f = (PyCFunctionObject*)obj; + if (!f->m_ml->ml_doc) { + THROW(RuntimeError) << "function " << f->m_ml->ml_name << " has not a docstring yet."; + } + f->m_ml->ml_doc = doc_str; + } else if (PyFunction_Check(obj)) { + auto* f = (PyFunctionObject*)obj; + if (f->func_doc == Py_None) { + THROW(RuntimeError) << "function " + << PyBytes_AsString( + PyUnicode_AsEncodedString(f->func_name, "utf-8", "~E~")) + << " has not a docstring yet."; + } + Py_DECREF(f->func_doc); + f->func_doc = PyUnicode_FromString(doc_str); + } else { + THROW(RuntimeError) << "function is " << Py_TYPE(obj)->tp_name << ", not a valid function."; + } + f.inc_ref(); + return f; +} + } // namespace oneflow -ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("add_doc", &oneflow::AddFunctionDoc); } +ONEFLOW_API_PYBIND11_MODULE("", m) { + m.def("add_doc", &oneflow::AddFunctionDoc); + m.def("reset_doc", &oneflow::ReplaceDoc); +} diff --git a/python/oneflow/framework/docstr/utils.py b/python/oneflow/framework/docstr/utils.py index 301610c8d1e..125058516ca 100644 --- a/python/oneflow/framework/docstr/utils.py +++ b/python/oneflow/framework/docstr/utils.py @@ -15,7 +15,28 @@ """ import oneflow._oneflow_internal +from doctest import DocTestParser, DebugRunner, DocTestRunner + + +def _test_docstr(docstr, verbose=True, optionflags=0, raise_on_error=True): + parser = DocTestParser() + if raise_on_error: + runner = DebugRunner(verbose=verbose, optionflags=optionflags) + else: + runner = DocTestRunner(verbose=verbose, optionflags=optionflags) + test = parser.get_doctest(docstr, {}, __name__, __file__, 0) + runner.run(test) def add_docstr(fun, docstr: str): return oneflow._oneflow_internal.add_doc(fun, docstr) + + +def reset_docstr(o, docstr): + _test_docstr(docstr) + if type(o) == type: + assert hasattr(o, "__doc__"), str(o) + " does not have a docstring!" + setattr(o, "__doc__", docstr) + return o + else: + return oneflow._oneflow_internal.reset_doc(o, docstr)