diff --git a/notebooks/deepdive.ipynb b/notebooks/deepdive.ipynb index 3ac51d73..69f06f04 100644 --- a/notebooks/deepdive.ipynb +++ b/notebooks/deepdive.ipynb @@ -553,8 +553,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel ran was not connected to run, andthus could not disconnect from it.\n", - " warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel run was not connected to ran, andthus could not disconnect from it.\n", " warn(\n" ] @@ -922,7 +920,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -951,6 +949,86 @@ ")()" ] }, + { + "cell_type": "markdown", + "id": "bd95ba27-e439-45cc-87d8-db587dc3b78c", + "metadata": {}, + "source": [ + "## Edge cases\n", + "\n", + "If output labels aren't provided, we try to scrape them from the source code for the function -- but this has limitations, like that the source code needs to be available (no lambda functions!) and that there's a single return value. \n", + "\n", + "If explicit output labels _are_ provided, we _still_ try to scrape them from the function source code just to make sure that everything lines up nicely. However, there are a couple of edge cases where you may want to tell the workflow code that you're serious about your labels and that it should just use them without any validation.\n", + "\n", + "(Failing to find the source code to compare with only triggers a warning, so in-memory functions are still OK as long as you provide output labels.)\n", + "\n", + "Turning off this validation comes with some responsibility that your labels make sense and will work. 
Let's look at a couple of examples:\n", + "\n", + "(1) You might want to return a single tuple, and break it apart into channels" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "1a43985b-98d7-4c56-b8fe-e6598298b44b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x0': 7, 'x1': 10.14}" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@as_function_node(\"x0\", \"x1\", validate_output_labels=False)\n", + "def ReturnsTuple(x: int) -> tuple[int, float]:\n", + " x = (x, x + 3.14)\n", + " return x\n", + "\n", + "from_tuple = ReturnsTuple(x=7, run_after_init=True)\n", + "from_tuple.outputs.to_value_dict()" + ] + }, + { + "cell_type": "markdown", + "id": "cca66b86-763c-4082-aca1-b19fd7edcc3a", + "metadata": {}, + "source": [ + "(2) To handle multiple return branches -- just be careful that the branches return the same number and type of values, or you may wind up with strange results." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "ab3ad9e6-2a5e-4b0f-82e3-9e7208970d22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True False\n" + ] + } + ], + "source": [ + "\n", + "@as_function_node(\"bool\", validate_output_labels=False)\n", + "def MultipleBranches(x):\n", + " if x < 10:\n", + " return True\n", + " else:\n", + " return False\n", + "\n", + "switch = MultipleBranches()\n", + "print(switch(3), switch(13))" + ] + }, { "cell_type": "markdown", "id": "5dc12164-b663-405b-872f-756996f628bd", @@ -978,7 +1056,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "id": "1cd000bd-9b24-4c39-9cac-70a3291d0660", "metadata": {}, "outputs": [], @@ -1005,7 +1083,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 34, "id": "7964df3c-55af-4c25-afc5-9e07accb606a", "metadata": {}, "outputs": [ @@ -1052,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 
33, + "execution_count": 35, "id": "809178a5-2e6b-471d-89ef-0797db47c5ad", "metadata": {}, "outputs": [ @@ -1106,7 +1184,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "id": "52c48d19-10a2-4c48-ae81-eceea4129a60", "metadata": {}, "outputs": [ @@ -1116,7 +1194,7 @@ "{'ay': 3, 'a + b + 2': 7}" ] }, - "execution_count": 34, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1144,7 +1222,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 37, "id": "bb35ba3e-602d-4c9c-b046-32da9401dd1c", "metadata": {}, "outputs": [ @@ -1154,7 +1232,7 @@ "(7, 3)" ] }, - "execution_count": 35, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1173,7 +1251,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 38, "id": "2b0d2c85-9049-417b-8739-8a8432a1efbe", "metadata": {}, "outputs": [ @@ -1515,10 +1593,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 36, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1545,14 +1623,14 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 39, "id": "ae500d5e-e55b-432c-8b5f-d5892193cdf5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "41321dda7d2945c281e8bd60ca8721eb", + "model_id": "a8d34d23d0594e9a892c2d00c3733bf1", "version_major": 2, "version_minor": 0 }, @@ -1571,7 +1649,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ec78b63fef4545afa899deea3f836d71", + "model_id": "96ca190eaee147f3b6268748a635acc1", "version_major": 2, "version_minor": 0 }, @@ -1585,10 +1663,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 37, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" }, @@ -1631,7 +1709,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 40, "id": "2114d0c3-cdad-43c7-9ffa-50c36d56d18f", "metadata": 
{}, "outputs": [ @@ -1845,10 +1923,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 38, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -1885,7 +1963,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 41, "id": "c71a8308-f8a1-4041-bea0-1c841e072a6d", "metadata": {}, "outputs": [], @@ -1895,7 +1973,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 42, "id": "2b9bb21a-73cd-444e-84a9-100e202aa422", "metadata": {}, "outputs": [ @@ -1905,6 +1983,8 @@ "text": [ "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io.py:404: UserWarning: The keyword 'type_hint' was not found among input labels. If you are trying to update a class instance keyword, please use attribute assignment directly instead of calling this method\n", " warnings.warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to x, andthus could not disconnect from it.\n", + " warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel run was not connected to ran, andthus could not disconnect from it.\n", " warn(\n" ] @@ -1915,7 +1995,7 @@ "13" ] }, - "execution_count": 40, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1967,23 +2047,23 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 43, "id": "3668f9a9-adca-48a4-84ea-13add965897c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'plus_three': 103}" + "{'add_two': 102, 'add_three': 103}" ] }, - "execution_count": 41, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "@Workflow.wrap.as_macro_node(\"plus_three\")\n", + "@Workflow.wrap.as_macro_node()\n", "def AddThree(macro, x: int = 0):\n", " \"\"\"\n", " The function decorator `as_macro_node` expects the decorated function \n", @@ -1997,10 +2077,11 @@ " 
macro.add_two = AddOne(macro.add_one)\n", " macro.add_three = AddOne(macro.add_two)\n", " macro.outputs_map = {\"add_two__result\": \"intermediate\"}\n", - " # return macro.add_three.outputs.result\n", " # We need to return something like output channels, but since AddOne has \n", " # only a single output channel, we can return it directly.\n", - " return macro.add_three\n", + " # We also return an intermediate value that would not normally be \n", + " # exposed if this were a workflow since it's connected to other child channels\n", + " return macro.add_two, macro.add_three\n", " \n", "macro = AddThree()\n", "macro(x=100)\n", @@ -2021,12 +2102,12 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 44, "id": "9aaeeec0-5f88-4c94-a6cc-45b56d2f0111", "metadata": {}, "outputs": [], "source": [ - "@Workflow.wrap.as_macro_node(\"structure\", \"energy\")\n", + "@Workflow.wrap.as_macro_node()\n", "def LammpsMinimize(macro, element: str, crystalstructure: str, lattice_guess: float | int):\n", " macro.structure = macro.create.pyiron_atomistics.Bulk(\n", " name=element,\n", @@ -2035,15 +2116,58 @@ " )\n", " macro.engine = macro.create.pyiron_atomistics.Lammps(structure=macro.structure)\n", " macro.calc = macro.create.pyiron_atomistics.CalcMin(job=macro.engine, pressure=0)\n", - " return macro.structure, macro.calc.outputs.energy_pot" + " energy = macro.calc.outputs.energy_pot\n", + " return macro.structure, energy" ] }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 45, + "id": "26a080dc-acaf-45bb-9935-7a42ff8d9552", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'structure': None, 'energy': None}" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LammpsMinimize.preview_outputs()" + ] + }, + { + "cell_type": "markdown", + "id": "4dfe9c0c-e9e7-4d5f-ad34-e19fd0382670", + "metadata": {}, + "source": [ + "Note that we didn't include any output labels, 
but they still come out looking OK. Here, we're exploiting a shortcut that the `macro.` (or whatever your `self`-like variable is called) gets left-stripped off the output label, since it will be very common to return children of the macro. However, other \".\" are not permissible, so for the energy we create and return a well-named local variable." + ] + }, + { + "cell_type": "code", + "execution_count": 46, "id": "a832e552-b3cc-411a-a258-ef21574fc439", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to name, andthus could not disconnect from it.\n", + " warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to crystalstructure, andthus could not disconnect from it.\n", + " warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to a, andthus could not disconnect from it.\n", + " warn(\n" + ] + } + ], "source": [ "wf = Workflow(\"phase_preference\")\n", "wf.element = wf.create.standard.UserInput()\n", @@ -2083,7 +2207,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 47, "id": "b764a447-236f-4cb7-952a-7cba4855087d", "metadata": {}, "outputs": [ @@ -3041,10 +3165,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 44, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -3055,7 +3179,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 48, "id": "b51bef25-86c5-4d57-80c1-ab733e703caf", "metadata": {}, "outputs": [ @@ -3069,7 +3193,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "19cd5b32478b4915858a38554da273c6", + "model_id": "f9b77a718b1b4077ac2e758d30c38ba6", "version_major": 2, 
"version_minor": 0 }, @@ -3090,7 +3214,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cbae5f07d9b640ce9cd3096078d24f72", + "model_id": "6c257b32d747497c892a58aa415bca7b", "version_major": 2, "version_minor": 0 }, @@ -3116,7 +3240,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 49, "id": "091e2386-0081-436c-a736-23d019bd9b91", "metadata": {}, "outputs": [ @@ -3138,7 +3262,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e0209e5b1ae843a680f3c144618325df", + "model_id": "15e6378664df4929829581c69baac4d7", "version_major": 2, "version_minor": 0 }, @@ -3159,7 +3283,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "386603d888d445f6a218b26052c43fb9", + "model_id": "84c434ce4fba48bd86bb4d255f5261f3", "version_major": 2, "version_minor": 0 }, @@ -3197,7 +3321,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 50, "id": "4cdffdca-48d3-4486-9045-48102c7e5f31", "metadata": {}, "outputs": [ @@ -3229,7 +3353,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 51, "id": "ed4a3a22-fc3a-44c9-9d4f-c65bc1288889", "metadata": {}, "outputs": [ @@ -3251,7 +3375,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2b4d04a785b54ad4958c8c30e4bc8d2f", + "model_id": "b7b5833fddcd4320bf509f699a2b9d39", "version_major": 2, "version_minor": 0 }, @@ -3272,7 +3396,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "301e1fee2e854c4ab889749f5d555368", + "model_id": "ae6530375bc446cbaf029850dd9f3e44", "version_major": 2, "version_minor": 0 }, @@ -3299,7 +3423,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 52, "id": "5a985cbf-c308-4369-9223-b8a37edb8ab1", "metadata": {}, "outputs": [ @@ -3321,7 +3445,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5f06ed3367e8425a91888f5daa074f53", + "model_id": 
"ef0fd622423c4e2590cdd813e25134fd", "version_major": 2, "version_minor": 0 }, @@ -3342,7 +3466,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f0220640d2f34e618764b165b6d2ba8c", + "model_id": "4bd16ff817754d67820801fa49810a5e", "version_major": 2, "version_minor": 0 }, @@ -3397,7 +3521,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 53, "id": "aa575249-b209-4e0c-9ea6-a82bc69dc833", "metadata": {}, "outputs": [ @@ -3406,7 +3530,7 @@ "output_type": "stream", "text": [ "None 1\n", - " NOT_DATA\n" + " NOT_DATA\n" ] } ], @@ -3433,7 +3557,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 54, "id": "c1b7b4e9-1c76-470c-ba6e-a58ea3f611f6", "metadata": {}, "outputs": [ @@ -3465,7 +3589,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 55, "id": "7e98058b-a791-4cb1-ae2c-864ad7e56cee", "metadata": {}, "outputs": [], @@ -3483,7 +3607,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 56, "id": "0d1b4005-488e-492f-adcb-8ad7235e4fe3", "metadata": {}, "outputs": [ @@ -3492,7 +3616,7 @@ "output_type": "stream", "text": [ "None 1\n", - " NOT_DATA\n", + " NOT_DATA\n", "Finally 5\n", "b (Add):\n", "Inputs ['obj', 'other']\n", @@ -3531,7 +3655,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 57, "id": "d03ca074-35a0-4e0d-9377-d4eaa5521f85", "metadata": {}, "outputs": [], @@ -3550,7 +3674,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 58, "id": "a7c07aa0-84fc-4f43-aa4f-6498c0837d76", "metadata": {}, "outputs": [ @@ -3558,7 +3682,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "6.026229030976538\n" + "6.01592511900526\n" ] } ], @@ -3582,7 +3706,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 59, "id": "b062ab5f-9b98-4843-8925-b93bf4c173f8", "metadata": {}, "outputs": [ @@ -3590,7 +3714,7 @@ "name": "stdout", "output_type": "stream", "text": [ - 
"2.334164628002327\n" + "2.5806497349985875\n" ] } ], @@ -3681,7 +3805,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 60, "id": "c8196054-aff3-4d39-a872-b428d329dac9", "metadata": {}, "outputs": [], @@ -3691,7 +3815,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 61, "id": "ffd741a3-b086-4ed0-9a62-76143a3705b2", "metadata": {}, "outputs": [], @@ -3708,7 +3832,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 62, "id": "3a22c622-f8c1-449b-a910-c52beb6a09c3", "metadata": {}, "outputs": [ @@ -3739,7 +3863,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 63, "id": "0999d3e8-3a5a-451d-8667-a01dae7c1193", "metadata": {}, "outputs": [], @@ -3773,7 +3897,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 64, "id": "0b373764-b389-4c24-8086-f3d33a4f7fd7", "metadata": {}, "outputs": [ @@ -3781,10 +3905,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/macro.py:682: UserWarning: Could not find the source code to validate BulkForA5 output labels\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io_preview.py:237: OutputLabelsNotValidated: Could not find the source code to validate BulkForA5 output labels against the number of returned values -- proceeding without validation\n", " warnings.warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io.py:404: UserWarning: The keyword 'type_hint' was not found among input labels. 
If you are trying to update a class instance keyword, please use attribute assignment directly instead of calling this method\n", " warnings.warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io_preview.py:237: OutputLabelsNotValidated: Could not find the source code to validate __many_to_list output labels against the number of returned values -- proceeding without validation\n", + " warnings.warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io.py:404: UserWarning: The keyword 'a' was not found among input labels. If you are trying to update a class instance keyword, please use attribute assignment directly instead of calling this method\n", " warnings.warn(\n" ] @@ -3799,7 +3925,7 @@ " 17.230249999999995]" ] }, - "execution_count": 61, + "execution_count": 64, "metadata": {}, "output_type": "execute_result" } @@ -3846,7 +3972,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 65, "id": "0dd04b4c-e3e7-4072-ad34-58f2c1e4f596", "metadata": {}, "outputs": [ @@ -3854,8 +3980,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/macro.py:682: UserWarning: Could not find the source code to validate AddWhileLessThan_m758651476798023903 output labels\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io_preview.py:237: OutputLabelsNotValidated: Could not find the source code to validate AddWhileLessThan_122886440957850675 output labels against the number of returned values -- proceeding without validation\n", " warnings.warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to a, andthus could not disconnect from it.\n", + " warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to b, andthus could not disconnect from it.\n", + " warn(\n", + 
"/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to other, andthus could not disconnect from it.\n", + " warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel run was not connected to true, andthus could not disconnect from it.\n", " warn(\n", "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel run was not connected to ran, andthus could not disconnect from it.\n", @@ -3903,7 +4035,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 66, "id": "2dfb967b-41ac-4463-b606-3e315e617f2a", "metadata": {}, "outputs": [ @@ -3927,26 +4059,27 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 67, "id": "2e87f858-b327-4f6b-9237-c8a557f29aeb", "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/macro.py:682: UserWarning: Could not find the source code to validate RandomWhileGreaterThan_m697380346905229202 output labels\n", - " warnings.warn(\n" + "0.588 > 0.2\n", + "0.144 <= 0.2\n", + "Finally 0.144\n" ] }, { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "0.247 > 0.2\n", - "0.361 > 0.2\n", - "0.097 <= 0.2\n", - "Finally 0.097\n" + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io_preview.py:237: OutputLabelsNotValidated: Could not find the source code to validate RandomWhileGreaterThan_m2124235887652166674 output labels against the number of returned values -- proceeding without validation\n", + " warnings.warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to threshold, andthus could not disconnect from it.\n", + " warn(\n" ] } ], diff --git a/notebooks/quickstart.ipynb b/notebooks/quickstart.ipynb index 
73e42c5c..d90639ce 100644 --- a/notebooks/quickstart.ipynb +++ b/notebooks/quickstart.ipynb @@ -123,7 +123,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:171: UserWarning: The channel run was not connected to ran, andthus could not disconnect from it.\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel run was not connected to ran, andthus could not disconnect from it.\n", " warn(\n" ] }, @@ -293,214 +293,214 @@ "\n", "Inputs\n", "\n", - "\n", - "clustermy_workflowplot\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "plot: Scatter\n", - "\n", - "\n", - "clustermy_workflowplotInputs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Inputs\n", - "\n", - "\n", - "clustermy_workflowplotOutputs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Outputs\n", - "\n", "\n", "clustermy_workflowOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", "clustermy_workflowarange\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "arange: Arange\n", "\n", "\n", "clustermy_workflowarangeInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Inputs\n", "\n", "\n", "clustermy_workflowarangeOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Outputs\n", "\n", "\n", "clustermy_workflowarange__length_Subtract_1\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "arange__length_Subtract_1: Subtract\n", "\n", - "\n", - "clustermy_workflowarange__length_Subtract_1Outputs\n", + "\n", + "clustermy_workflowarange__length_Subtract_1Inputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Inputs\n", "\n", - "\n", - "clustermy_workflowarange__length_Subtract_1Inputs\n", + "\n", + "clustermy_workflowarange__length_Subtract_1Outputs\n", "\n", 
- "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Outputs\n", "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_None\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "arange__arange_Slice_None_arange__length_Subtract_1__sub_None: Slice\n", "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "Inputs\n", "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice: GetItem\n", + "\n", + "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice: GetItem\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2: Power\n", + "\n", + "\n", + 
"clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputs\n", + "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2: Power\n", + "\n", + "Inputs\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", - "\n", - "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputs\n", + "\n", + "clustermy_workflowplot\n", "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "plot: Scatter\n", + "\n", + "\n", + "clustermy_workflowplotInputs\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Inputs\n", + "\n", + "\n", + "clustermy_workflowplotOutputs\n", + "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Outputs\n", "\n", "\n", "\n", @@ -511,8 +511,8 @@ "\n", "\n", "clustermy_workflowOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", @@ -606,21 +606,21 @@ "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsother\n", - "\n", - "other\n", + "\n", + "other\n", "\n", "\n", "\n", "clustermy_workflowInputsarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2__other->clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsother\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustermy_workflowOutputsplot__fig\n", - "\n", - "plot__fig\n", + "\n", + "plot__fig\n", "\n", "\n", "\n", @@ -650,15 +650,15 @@ "\n", "\n", 
"clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsobj\n", - "\n", - "obj\n", + "\n", + "obj\n", "\n", "\n", "\n", "clustermy_workflowarangeOutputsarange->clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsobj\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", @@ -726,8 +726,8 @@ "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", @@ -739,148 +739,148 @@ "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsslice\n", - "\n", - "slice\n", + "\n", + "slice\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsitem\n", - "\n", - "item\n", + "\n", + "item\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsslice->clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsitem\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem\n", - "\n", - "getitem\n", + 
"\n", + "getitem\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsobj\n", - "\n", - "obj\n", + "\n", + "obj\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem->clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsobj\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustermy_workflowplotInputsx\n", - "\n", - "x: Union\n", + "\n", + "x: Union\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem->clustermy_workflowplotInputsx\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", "\n", "clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputspow\n", - "\n", - "pow\n", + "\n", + "pow\n", "\n", "\n", "\n", "clustermy_workflowplotInputsy\n", - "\n", - "y: Union\n", + "\n", + "y: Union\n", "\n", "\n", "\n", 
"clustermy_workflowarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputspow->clustermy_workflowplotInputsy\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustermy_workflowplotInputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", "\n", "clustermy_workflowplotOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", "\n", "clustermy_workflowplotInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", "\n", "clustermy_workflowplotOutputsfig\n", - "\n", - "fig\n", + "\n", + "fig\n", "\n", "\n", "\n", "clustermy_workflowplotOutputsfig->clustermy_workflowOutputsplot__fig\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 9, @@ -919,7 +919,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -983,45 +983,12 @@ "\n", "We don't yet have an automated tool for converting workflows into macros, but we can create them by decorating a function that takes a macro instance and builds its graph, so we can just copy-and-paste our workflow above into a decorated function! \n", "\n", - "We can also give our macro prettier IO names. This can be done with \"maps\" (which are also available on the workflows):" + "Just like a function node, the IO of a macro is defined by the signature and return values of the function we're decorating. 
Just remember to include a `self`-like argument for the macro instance itself as the first argument, and (usually) to only return single-output nodes or output channels in the `return` statement:" ] }, { "cell_type": "code", "execution_count": 12, - "id": "f67312c0-7028-4569-8b3a-d9e2fe88df48", - "metadata": {}, - "outputs": [], - "source": [ - "@Workflow.wrap.as_macro_node()\n", - "def MySquarePlot(macro):\n", - " macro.arange = Arange()\n", - " macro.plot = macro.create.plotting.Scatter(\n", - " x=macro.arange.outputs.arange[:macro.arange.outputs.length -1],\n", - " y=macro.arange.outputs.arange[:macro.arange.outputs.length -1]**2\n", - " )\n", - " macro.inputs_map = {\"arange__n\": \"n\"}\n", - " macro.outputs_map = {\n", - " \"arange__arange\": \"x\",\n", - " \"arange__length\": \"n\",\n", - " \"plot__fig\": \"fig\"\n", - " }\n", - " # Note that we also forced regularly hidden IO to be exposed!\n", - " # We can also hide IO that's usually exposed by mapping to `None`\n", - " # but that would be a lot of typing in this case" - ] - }, - { - "cell_type": "markdown", - "id": "e260929f-2d13-486c-b547-f5d8e2f0a330", - "metadata": {}, - "source": [ - "Or we can use a more function-node-like defintion of our macro with args and/or kwargs, and return values and output labels. 
The \"maps\" above _always take precedence_ so you still have full control over your macro-level IO, but using this format switches us over to an \"whitelist\" paradigm that automatically turns off all the other IO, which can make it easier to keep things tidy:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, "id": "996c9e9a-ba0e-458a-9e54-331974073cca", "metadata": {}, "outputs": [], @@ -1038,19 +1005,29 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "b43f7a86-4579-4476-89a9-9d7c5942c3fb", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/io.py:404: UserWarning: The keyword 'type_hint' was not found among input labels. If you are trying to update a class instance keyword, please use attribute assignment directly instead of calling this method\n", + " warnings.warn(\n", + "/Users/huber/work/pyiron/pyiron_workflow/pyiron_workflow/channels.py:176: UserWarning: The channel user_input was not connected to n, andthus could not disconnect from it.\n", + " warn(\n" + ] + }, { "data": { "text/plain": [ "{'square_plot__n': 10,\n", - " 'square_plot__fig': ,\n", - " 'plus_one_square_plot__fig': }" + " 'square_plot__fig': ,\n", + " 'plus_one_square_plot__fig': }" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, @@ -1087,7 +1064,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "370b4c4b-8a95-4a2a-8255-1574763606bb", "metadata": {}, "outputs": [ @@ -1100,12 +1077,12 @@ "\n", "\n", - "\n", - "\n", + "\n", + "\n", "clustersquare_plot\n", - "\n", - "square_plot: MySquarePlot\n", + "\n", + "square_plot: MySquarePlot\n", "\n", "clustersquare_plotInputs\n", "\n", @@ -1120,27 +1097,27 @@ "\n", "clustersquare_plotOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", - 
"clustersquare_plotn\n", + "clustersquare_plotarange\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "n: UserInput\n", + "\n", + "arange: Arange\n", "\n", "\n", - "clustersquare_plotnInputs\n", + "clustersquare_plotarangeInputs\n", "\n", "\n", "\n", @@ -1151,213 +1128,180 @@ "Inputs\n", "\n", "\n", - "clustersquare_plotnOutputs\n", + "clustersquare_plotarangeOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", - "clustersquare_plotarange\n", + "clustersquare_plotarange__length_Subtract_1\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "arange: Arange\n", + "\n", + "arange__length_Subtract_1: Subtract\n", "\n", "\n", - "clustersquare_plotarangeInputs\n", + "clustersquare_plotarange__length_Subtract_1Inputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", "\n", - "clustersquare_plotarangeOutputs\n", + "clustersquare_plotarange__length_Subtract_1Outputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", - "clustersquare_plotarange__length_Subtract_1\n", + "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_None\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "arange__length_Subtract_1: Subtract\n", + "\n", + "arange__arange_Slice_None_arange__length_Subtract_1__sub_None: Slice\n", "\n", "\n", - "clustersquare_plotarange__length_Subtract_1Inputs\n", + "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", "\n", - "clustersquare_plotarange__length_Subtract_1Outputs\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Outputs\n", - "\n", - "\n", - "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_None\n", - "\n", - "\n", - "\n", - "\n", - 
"\n", - "\n", - "\n", - "arange__arange_Slice_None_arange__length_Subtract_1__sub_None: Slice\n", - "\n", - "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputs\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "Outputs\n", - "\n", - "\n", - "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputs\n", - "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Outputs\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice: GetItem\n", + "\n", + "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice: GetItem\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2: Power\n", + "\n", + "arange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2: Power\n", "\n", - "\n", + "\n", 
"clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", - "\n", + "\n", "clustersquare_plotplot\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", - "plot: Scatter\n", + "\n", + "plot: Scatter\n", "\n", - "\n", + "\n", "clustersquare_plotplotInputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Inputs\n", + "\n", + "Inputs\n", "\n", - "\n", + "\n", "clustersquare_plotplotOutputs\n", "\n", - "\n", + "\n", "\n", "\n", "\n", "\n", - "\n", - "Outputs\n", + "\n", + "Outputs\n", "\n", "\n", "\n", @@ -1368,8 +1312,8 @@ "\n", "\n", "clustersquare_plotOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", @@ -1381,434 +1325,389 @@ "\n", "\n", "clustersquare_plotInputsn\n", - "\n", - "n: int\n", + "\n", + "n\n", "\n", - "\n", + "\n", "\n", - "clustersquare_plotnInputsuser_input\n", - "\n", - "user_input: int\n", + "clustersquare_plotarangeInputsn\n", + "\n", + "n: int\n", "\n", - "\n", - "\n", - "clustersquare_plotInputsn->clustersquare_plotnInputsuser_input\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "clustersquare_plotInputsn->clustersquare_plotarangeInputsn\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", "clustersquare_plotOutputsx\n", - "\n", - "x: ndarray\n", + "\n", + "x: ndarray\n", "\n", "\n", "\n", "clustersquare_plotOutputsn\n", - "\n", - "n: int\n", + "\n", + "n: int\n", "\n", "\n", "\n", "clustersquare_plotOutputsfig\n", - "\n", - "fig\n", + "\n", + "fig\n", "\n", - "\n", + "\n", "\n", - "clustersquare_plotnInputsrun\n", + "clustersquare_plotarangeInputsrun\n", "\n", "run\n", "\n", - 
"\n", + "\n", "\n", - "clustersquare_plotnOutputsran\n", - "\n", - "ran\n", + "clustersquare_plotarangeOutputsran\n", + "\n", + "ran\n", "\n", - "\n", - "\n", + "\n", + "\n", "\n", - "clustersquare_plotnInputsaccumulate_and_run\n", + "clustersquare_plotarangeInputsaccumulate_and_run\n", "\n", "accumulate_and_run\n", "\n", - "\n", - "\n", - "clustersquare_plotarangeInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", - "\n", - "\n", - "\n", - "clustersquare_plotnOutputsran->clustersquare_plotarangeInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clustersquare_plotnOutputsuser_input\n", - "\n", - "user_input\n", - "\n", - "\n", - "\n", - "clustersquare_plotarangeInputsn\n", - "\n", - "n: int\n", - "\n", - "\n", - "\n", - "clustersquare_plotnOutputsuser_input->clustersquare_plotarangeInputsn\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clustersquare_plotarangeInputsrun\n", - "\n", - "run\n", - "\n", - "\n", - "\n", - "clustersquare_plotarangeOutputsran\n", - "\n", - "ran\n", - "\n", - "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Inputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputsran->clustersquare_plotarange__length_Subtract_1Inputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputsran->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputsarange\n", - "\n", - "arange: ndarray\n", + "\n", + 
"arange: ndarray\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputsarange->clustersquare_plotOutputsx\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsobj\n", - "\n", - "obj\n", + "\n", + "obj\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputsarange->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsobj\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputslength\n", - "\n", - "length: int\n", + "\n", + "length: int\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputslength->clustersquare_plotOutputsn\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Inputsobj\n", - "\n", - "obj\n", + "\n", + "obj\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarangeOutputslength->clustersquare_plotarange__length_Subtract_1Inputsobj\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Inputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Outputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Inputsother\n", - "\n", - "other\n", + "\n", + "other\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Outputsran->clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", 
"\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Outputssub\n", - "\n", - "sub\n", + "\n", + "sub\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsstop\n", - "\n", - "stop\n", + "\n", + "stop\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__length_Subtract_1Outputssub->clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsstop\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsstart\n", - "\n", - "start\n", + "\n", + "start\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneInputsstep\n", - "\n", - "step\n", + "\n", + "step\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsran->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsslice\n", - "\n", - "slice\n", + "\n", + "slice\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsitem\n", - "\n", - "item\n", + "\n", + "item\n", "\n", "\n", - "\n", + "\n", 
"clustersquare_plotarange__arange_Slice_None_arange__length_Subtract_1__sub_NoneOutputsslice->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsitem\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceInputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsran->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotInputsaccumulate_and_run\n", - "\n", - "accumulate_and_run\n", + "\n", + "accumulate_and_run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsran->clustersquare_plotplotInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem\n", - "\n", - "getitem\n", + "\n", + "getitem\n", "\n", "\n", - "\n", + "\n", 
"clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsobj\n", - "\n", - "obj\n", + "\n", + "obj\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem->clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsobj\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotInputsx\n", - "\n", - "x: Union\n", + "\n", + "x: Union\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__sliceOutputsgetitem->clustersquare_plotplotInputsx\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Inputsother\n", - "\n", - "other\n", + "\n", + "other\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputsran->clustersquare_plotplotInputsaccumulate_and_run\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputspow\n", - "\n", - "pow\n", + "\n", + "pow\n", "\n", "\n", - "\n", + 
"\n", "clustersquare_plotplotInputsy\n", - "\n", - "y: Union\n", + "\n", + "y: Union\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotarange__arange_GetItem_arange__arange_Slice_None_arange__length_Subtract_1__sub_None__slice__getitem_Power_2Outputspow->clustersquare_plotplotInputsy\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotInputsrun\n", - "\n", - "run\n", + "\n", + "run\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotOutputsran\n", - "\n", - "ran\n", + "\n", + "ran\n", "\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotOutputsfig\n", - "\n", - "fig\n", + "\n", + "fig\n", "\n", "\n", - "\n", + "\n", "clustersquare_plotplotOutputsfig->clustersquare_plotOutputsfig\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 15, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } diff --git a/pyiron_workflow/function.py b/pyiron_workflow/function.py index fc1e1f9f..7a5684fe 100644 --- a/pyiron_workflow/function.py +++ b/pyiron_workflow/function.py @@ -1,22 +1,19 @@ from __future__ import annotations from abc import ABC, abstractmethod -import inspect -import warnings -from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING -from pyiron_workflow.channels import InputData, NOT_DATA +from pyiron_workflow.channels import InputData from pyiron_workflow.injection import OutputDataWithInjection from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.node import Node -from pyiron_workflow.output_parser import ParseOutput +from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory from pyiron_workflow.snippets.colors import SeabornColors if TYPE_CHECKING: from pyiron_workflow.composite import Composite -class Function(Node, ABC): +class Function(DecoratedNode, ABC): """ Function nodes wrap an 
arbitrary python function. @@ -301,8 +298,6 @@ class Function(Node, ABC): guaranteed. """ - _provided_output_labels: tuple[str] | None = None - def __init__( self, *args, @@ -333,73 +328,14 @@ def node_function(*args, **kwargs) -> callable: """What the node _does_.""" @classmethod - def _type_hints(cls) -> dict: - """The result of :func:`typing.get_type_hints` on the :meth:`node_function`.""" - return get_type_hints(cls.node_function) - - @classmethod - def preview_output_channels(cls) -> dict[str, Any]: - """ - Gives a class-level peek at the expected output channels. - - Returns: - dict[str, tuple[Any, Any]]: The channel name and its corresponding type - hint. - """ - labels = cls._get_output_labels() - try: - type_hints = cls._type_hints()["return"] - if len(labels) > 1: - type_hints = get_args(type_hints) - if not isinstance(type_hints, tuple): - raise TypeError( - f"With multiple return labels expected to get a tuple of type " - f"hints, but got type {type(type_hints)}" - ) - if len(type_hints) != len(labels): - raise ValueError( - f"Expected type hints and return labels to have matching " - f"lengths, but got {len(type_hints)} hints and " - f"{len(labels)} labels: {type_hints}, {labels}" - ) - else: - # If there's only one hint, wrap it in a tuple, so we can zip it with - # *return_labels and iterate over both at once - type_hints = (type_hints,) - except KeyError: # If there are no return hints - type_hints = [None] * len(labels) - # Note that this nicely differs from `NoneType`, which is the hint when - # `None` is actually the hint! - return {label: hint for label, hint in zip(labels, type_hints)} + def _io_defining_function(cls) -> callable: + return cls.node_function @classmethod - def _get_output_labels(cls): - """ - Return output labels provided on the class if not None, else scrape them from - :meth:`node_function`. - - Note: When the user explicitly provides output channels, they are taking - responsibility that these are correct, e.g. 
in terms of quantity, order, etc. - """ - if cls._provided_output_labels is None: - return cls._scrape_output_labels() - else: - return cls._provided_output_labels - - @classmethod - def _scrape_output_labels(cls): - """ - Inspect :meth:`node_function` to scrape out strings representing the - returned values. - - _Only_ works for functions with a single `return` expression in their body. - - It will return expressions and function calls just fine, thus good practice is - to create well-named variables and return those so that the output labels stay - dot-accessible. - """ - parsed_outputs = ParseOutput(cls.node_function).output - return [None] if parsed_outputs is None else parsed_outputs + def preview_outputs(cls) -> dict[str, Any]: + preview = super(Function, cls).preview_outputs() + return preview if len(preview) > 0 else {"None": type(None)} + # If clause facilitates functions with no return value @property def outputs(self) -> Outputs: @@ -414,49 +350,9 @@ def _build_output_channels(self): owner=self, type_hint=hint, ) - for label, hint in self.preview_output_channels().items() + for label, hint in self.preview_outputs().items() ] - @classmethod - def preview_input_channels(cls) -> dict[str, tuple[Any, Any]]: - """ - Gives a class-level peek at the expected input channels. - - Returns: - dict[str, tuple[Any, Any]]: The channel name and a tuple of its - corresponding type hint and default value. - """ - type_hints = cls._type_hints() - scraped: dict[str, tuple[Any, Any]] = {} - for label, param in cls._input_args().items(): - if label in cls._init_keywords(): - # We allow users to parse arbitrary kwargs as channel initialization - # So don't let them choose bad channel names - raise ValueError( - f"The Input channel name {label} is not valid. 
Please choose a " - f"name _not_ among {cls._init_keywords()}" - ) - - try: - type_hint = type_hints[label] - except KeyError: - type_hint = None - - default = ( - NOT_DATA if param.default is inspect.Parameter.empty else param.default - ) - - scraped[label] = (type_hint, default) - return scraped - - @classmethod - def _input_args(cls): - return inspect.signature(cls.node_function).parameters - - @classmethod - def _init_keywords(cls): - return list(inspect.signature(cls.__init__).parameters.keys()) - @property def inputs(self) -> Inputs: if self._inputs is None: @@ -471,7 +367,7 @@ def _build_input_channels(self): default=default, type_hint=type_hint, ) - for label, (type_hint, default) in self.preview_input_channels().items() + for label, (type_hint, default) in self.preview_inputs().items() ] @property @@ -495,7 +391,7 @@ def process_run_result(self, function_output: Any | tuple) -> Any | tuple: return function_output def _convert_input_args_and_kwargs_to_input_kwargs(self, *args, **kwargs): - reverse_keys = list(self._input_args().keys())[::-1] + reverse_keys = list(self._get_input_args().keys())[::-1] if len(args) > len(reverse_keys): raise ValueError( f"Received {len(args)} positional arguments, but the node {self.label}" @@ -559,6 +455,9 @@ def color(self) -> str: return SeabornColors.green +as_function_node = decorated_node_decorator_factory(Function, Function.node_function) + + def function_node( node_function: callable, *args, @@ -569,6 +468,7 @@ def function_node( storage_backend: Optional[Literal["h5io", "tinybase"]] = None, save_after_run: bool = False, output_labels: Optional[str | tuple[str]] = None, + validate_output_labels: bool = True, **kwargs, ): """ @@ -604,7 +504,9 @@ def function_node( elif isinstance(output_labels, str): output_labels = (output_labels,) - return as_function_node(*output_labels)(node_function)( + return as_function_node( + *output_labels, validate_output_labels=validate_output_labels + )(node_function)( *args, label=label, 
parent=parent, @@ -614,67 +516,3 @@ def function_node( save_after_run=save_after_run, **kwargs, ) - - -def as_function_node(*output_labels: str): - """ - A decorator for dynamically creating node classes from functions. - - Decorates a function. - Returns a `Function` subclass whose name is the camel-case version of the function - node, and whose signature is modified to exclude the node function and output labels - (which are explicitly defined in the process of using the decorator). - - Args: - *output_labels (str): A name for each return value of the node function OR an - empty tuple. When empty, scrapes output labels automatically from the - source code of the wrapped function. This can be useful when returned - values are not well named, e.g. to make the output channel dot-accessible - if it would otherwise have a label that requires item-string-based access. - Additionally, specifying a _single_ label for a wrapped function that - returns a tuple of values ensures that a _single_ output channel (holding - the tuple) is created, instead of one channel for each return value. The - default approach of extracting labels from the function source code also - requires that the function body contain _at most_ one `return` expression, - so providing explicit labels can be used to circumvent this - (at your own risk), or to circumvent un-inspectable source code (e.g. a - function that exists only in memory). - """ - output_labels = None if len(output_labels) == 0 else output_labels - - # One really subtle thing is that we manually parse the function type hints right - # here and include these as a class-level attribute. - # This is because on (de)(cloud)pickling a function node, somehow the node function - # method attached to it gets its `__globals__` attribute changed; it retains stuff - # _inside_ the function, but loses imports it used from the _outside_ -- i.e. type - # hints! 
I (@liamhuber) don't deeply understand _why_ (de)pickling is modifying the - # __globals__ in this way, but the result is that type hints cannot be parsed after - # the change. - # The final piece of the puzzle here is that because the node function is a _class_ - # level attribute, if you (de)pickle a node, _new_ instances of that node wind up - # having their node function's `__globals__` trimmed down in this way! - # So to keep the type hint parsing working, we snag and interpret all the type hints - # at wrapping time, when we are guaranteed to have all the globals available, and - # also slap them on as a class-level attribute. These get safely packed and returned - # when (de)pickling so we can keep processing type hints without trouble. - def as_node(node_function: callable): - node_class = type( - node_function.__name__, - (Function,), # Define parentage - { - "node_function": staticmethod(node_function), - "_provided_output_labels": output_labels, - "__module__": node_function.__module__, - }, - ) - try: - node_class.preview_output_channels() - except ValueError as e: - raise ValueError( - f"Failed to create a new {Function.__name__} child class " - f"dynamically from {node_function.__name__} -- probably due to a " - f"mismatch among output labels, returned values, and return type hints." - ) from e - return node_class - - return as_node diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index cc962b2f..bc37e528 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -7,9 +7,9 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod from typing import Any +import warnings from pyiron_workflow.channels import ( Channel, diff --git a/pyiron_workflow/io_preview.py b/pyiron_workflow/io_preview.py new file mode 100644 index 00000000..89dbf9f4 --- /dev/null +++ b/pyiron_workflow/io_preview.py @@ -0,0 +1,402 @@ +""" +Mixin classes for classes which offer previews of input and output at the _class_ level. 
+ +The intent is for mixing with :class:`pyiron_workflow.node.Node`, and for the inputs +and outputs to be IO channels there, but in principle this should function just fine +independently. + +These previews need to be available at the class level so that suggestion menus and +ontologies can know how mixin classes relate to the rest of the world via input and +output without first having to instantiate them. +""" + +from __future__ import annotations + +import inspect +import warnings +from abc import ABC, abstractmethod +from textwrap import dedent +from types import FunctionType +from typing import Any, get_args, get_type_hints + +from pyiron_workflow.channels import NOT_DATA +from pyiron_workflow.node import Node +from pyiron_workflow.output_parser import ParseOutput +from pyiron_workflow.snippets.dotdict import DotDict + + +class HasIOPreview(ABC): + """ + An interface mixin guaranteeing the class-level availability of input and output + previews. + + E.g. for :class:`pyiron_workflow.node.Node` that have input and output channels. + """ + + @classmethod + @abstractmethod + def preview_inputs(cls) -> dict[str, tuple[Any, Any]]: + """ + Gives a class-level peek at the expected inputs. + + Returns: + dict[str, tuple[Any, Any]]: The input name and a tuple of its + corresponding type hint and default value. + """ + + @classmethod + @abstractmethod + def preview_outputs(cls) -> dict[str, Any]: + """ + Gives a class-level peek at the expected outputs. + + Returns: + dict[str, tuple[Any, Any]]: The output name and its corresponding type hint. + """ + + @classmethod + def preview_io(cls) -> DotDict[str:dict]: + return DotDict( + {"inputs": cls.preview_inputs(), "outputs": cls.preview_outputs()} + ) + + +class ScrapesIO(HasIOPreview, ABC): + """ + A mixin class for scraping IO channel information from a specific class method's + signature and returns. 
+ + Requires that the (static and class) method :meth:`_io_defining_function` be + specified in child classes, as well as :meth:`_io_defining_function_uses_self`. + Optionally, :attr:`_output_labels` can be overridden at the class level to avoid + scraping the return signature for channel labels altogether. + + Since scraping returns is only possible when the function source code is available, + this can be bypassed by manually specifying the class attribute + :attr:`_output_labels`. + + Attributes: + _output_labels (): + _validate_output_labels (bool): Whether to + _io_defining_function_uses_self (bool): Whether the signature of the IO + defining function starts with self. When true, the first argument in the + :meth:`_io_defining_function` is ignored. (Default is False, use the entire + signature for specifying input.) + + Warning: + There are a number of class features which, for computational efficiency, get + calculated at first call and any subsequent calls return that initial value + (including on other instances, since these are class properties); these + depend on the :meth:`_io_defining_function` and its signature, which should + thus be left static from the time of class definition onwards. 
+ """ + + @classmethod + @abstractmethod + def _io_defining_function(cls) -> callable: + """Must return a static class method.""" + + _output_labels: tuple[str] | None = None # None: scrape them + _validate_output_labels: bool = True # True: validate against source code + _io_defining_function_uses_self: bool = False # False: use entire signature + + __type_hints = None + __input_args = None + __init_keywords = None + __input_preview = None + __output_preview = None + + @classmethod + def preview_inputs(cls) -> dict[str, tuple[Any, Any]]: + if cls.__input_preview is None: + cls.__input_preview = cls._build_input_preview() + return cls.__input_preview + + @classmethod + def preview_outputs(cls) -> dict[str, Any]: + """ + Gives a class-level peek at the expected output channels. + + Returns: + dict[str, tuple[Any, Any]]: The channel name and its corresponding type + hint. + """ + if cls.__output_preview is None: + if cls._validate_output_labels: + cls._validate() # Validate output on first call + cls.__output_preview = cls._build_output_preview() + return cls.__output_preview + + @classmethod + def _build_input_preview(cls): + type_hints = cls._get_type_hints() + scraped: dict[str, tuple[Any, Any]] = {} + for i, (label, value) in enumerate(cls._get_input_args().items()): + if cls._io_defining_function_uses_self and i == 0: + continue # Skip the macro argument itself, it's like `self` here + elif label in cls._get_init_keywords(): + # We allow users to parse arbitrary kwargs as channel initialization + # So don't let them choose bad channel names + raise ValueError( + f"Trying to build input preview for {cls.__name__}, encountered an " + f"argument name that conflicts with __init__: {label}. 
Please " + f"choose a name _not_ among {cls._get_init_keywords()}" + ) + + try: + type_hint = type_hints[label] + except KeyError: + type_hint = None + + default = ( + NOT_DATA if value.default is inspect.Parameter.empty else value.default + ) + + scraped[label] = (type_hint, default) + return scraped + + @classmethod + def _build_output_preview(cls): + labels = cls._get_output_labels() + if labels is None: + labels = [] + try: + type_hints = cls._get_type_hints()["return"] + if len(labels) > 1: + type_hints = get_args(type_hints) + if not isinstance(type_hints, tuple): + raise TypeError( + f"With multiple return labels expected to get a tuple of type " + f"hints, but {cls.__name__} got type {type(type_hints)}" + ) + if len(type_hints) != len(labels): + raise ValueError( + f"Expected type hints and return labels to have matching " + f"lengths, but {cls.__name__} got {len(type_hints)} hints and " + f"{len(labels)} labels: {type_hints}, {labels}" + ) + else: + # If there's only one hint, wrap it in a tuple, so we can zip it with + # *return_labels and iterate over both at once + type_hints = (type_hints,) + except KeyError: # If there are no return hints + type_hints = [None] * len(labels) + # Note that this nicely differs from `NoneType`, which is the hint when + # `None` is actually the hint! + return {label: hint for label, hint in zip(labels, type_hints)} + + @classmethod + def _get_output_labels(cls): + """ + Return output labels provided for the class, scraping them from the io-defining + function if they are not already available. 
+ """ + if cls._output_labels is None: + cls._output_labels = cls._scrape_output_labels() + return cls._output_labels + + @classmethod + def _get_type_hints(cls) -> dict: + """ + The result of :func:`typing.get_type_hints` on the io-defining function + """ + if cls.__type_hints is None: + cls.__type_hints = get_type_hints(cls._io_defining_function()) + return cls.__type_hints + + @classmethod + def _get_input_args(cls): + if cls.__input_args is None: + cls.__input_args = inspect.signature(cls._io_defining_function()).parameters + return cls.__input_args + + @classmethod + def _get_init_keywords(cls): + if cls.__init_keywords is None: + cls.__init_keywords = list( + inspect.signature(cls.__init__).parameters.keys() + ) + return cls.__init_keywords + + @classmethod + def _scrape_output_labels(cls): + """ + Inspect :meth:`node_function` to scrape out strings representing the + returned values. + + _Only_ works for functions with a single `return` expression in their body. + + It will return expressions and function calls just fine, thus good practice is + to create well-named variables and return those so that the output labels stay + dot-accessible. + """ + return ParseOutput(cls._io_defining_function()).output + + @classmethod + def _validate(cls): + """ + Ensure that output_labels, if provided, are commensurate with graph creator + return values, if provided, and return them as a tuple. 
+ """ + try: + cls._validate_degeneracy() + cls._validate_return_count() + except OSError: + warnings.warn( + f"Could not find the source code to validate {cls.__name__} output " + f"labels against the number of returned values -- proceeding without " + f"validation", + OutputLabelsNotValidated, + ) + + @classmethod + def _validate_degeneracy(cls): + output_labels = cls._get_output_labels() + if output_labels is not None and len(set(output_labels)) != len(output_labels): + raise ValueError( + f"{cls.__name__} must not have degenerate output labels: " + f"{output_labels}" + ) + + @classmethod + def _validate_return_count(cls): + output_labels = cls._get_output_labels() + graph_creator_returns = ParseOutput(cls._io_defining_function()).output + if graph_creator_returns is not None or output_labels is not None: + error_suffix = ( + f"but {cls.__name__} got return values: {graph_creator_returns} and " + f"labels: {output_labels}. If this intentional, you can bypass output " + f"validation making sure the class attribute `_validate_output_labels` " + f"is False." + ) + try: + if len(output_labels) != len(graph_creator_returns): + raise ValueError( + "The number of return values must exactly match the number of " + "output labels provided, " + error_suffix + ) + except TypeError: + raise TypeError( + f"Output labels and return values must either both or neither be " + f"present, " + error_suffix + ) + + +class OutputLabelsNotValidated(Warning): + pass + + +class StaticNode(Node, HasIOPreview, ABC): + """A node whose IO specification is available at the class level.""" + + +class DecoratedNode(StaticNode, ScrapesIO, ABC): + """ + A static node whose IO is defined by a function's information (and maybe output + labels). 
+ """ + + +def decorated_node_decorator_factory( + parent_class: type[DecoratedNode], + io_static_method: callable, + decorator_docstring_additions: str = "", + **parent_class_attr_overrides, +): + """ + A decorator factory for building decorators to dynamically create new subclasses + of some subclass of :class:`DecoratedNode` using the function they decorate. + + New classes get their class name and module set using the decorated function's + name and module. + + Args: + parent_class (type[DecoratedNode]): The base class for the new node class. + io_static_method: The static method on the :param:`parent_class` which will + store the io-defining function the resulting decorator will decorate. + :param:`parent_class` must override :meth:`_io_defining_function` inherited + from :class:`DecoratedNode` to return this method. This allows + :param:`parent_class` classes to have unique names for their io-defining + functions. + decorator_docstring_additions (str): Any extra text to add between the main + body of the docstring and the arguments. + **parent_class_attr_overrides: Any additional attributes to pass to the new, + dynamically created class created by the resulting decorator. + + Returns: + (callable): A decorator that takes creates a new subclass of + :param:`parent_class` that uses the wrapped function as the return value of + :meth:`_io_defining_function` for the :class:`DecoratedNode` mixin. 
+ """ + if getattr(parent_class, io_static_method.__name__) is not io_static_method: + raise ValueError( + f"{io_static_method.__name__} is not a method on {parent_class}" + ) + if not isinstance(io_static_method, FunctionType): + raise TypeError(f"{io_static_method.__name__} should be a static method") + + def as_decorated_node_decorator( + *output_labels: str, + validate_output_labels: bool = True, + ): + output_labels = None if len(output_labels) == 0 else output_labels + + def as_decorated_node(io_defining_function: callable): + if not callable(io_defining_function): + raise AttributeError( + f"Tried to create a new child class of {parent_class.__name__}, " + f"but got {io_defining_function} instead of a callable." + ) + + decorated_node_class = type( + io_defining_function.__name__, + (parent_class,), # Define parentage + { + io_static_method.__name__: staticmethod(io_defining_function), + "__module__": io_defining_function.__module__, + "_output_labels": output_labels, + "_validate_output_labels": validate_output_labels, + **parent_class_attr_overrides, + }, + ) + decorated_node_class.preview_io() # Construct everything + return decorated_node_class + + return as_decorated_node + + as_decorated_node_decorator.__doc__ = dedent( + f""" + A decorator for dynamically creating `{parent_class.__name__}` sub-classes by + wrapping a function as the `{io_static_method.__name__}`. + + The returned subclass uses the wrapped function (and optionally any provided + :param:`output_labels`) to specify its IO. + + {decorator_docstring_additions} + + Args: + *output_labels (str): A name for each return value of the graph creating + function. When empty, scrapes output labels automatically from the + source code of the wrapped function. This can be useful when returned + values are not well named, e.g. to make the output channel + dot-accessible if it would otherwise have a label that requires + item-string-based access. 
Additionally, specifying a _single_ label for + a wrapped function that returns a tuple of values ensures that a + _single_ output channel (holding the tuple) is created, instead of one + channel for each return value. The default approach of extracting + labels from the function source code also requires that the function + body contain _at most_ one `return` expression, so providing explicit + labels can be used to circumvent this (at your own risk). (Default is + empty, try to scrape labels from the source code of the wrapped + function.) + validate_output_labels (bool): Whether to compare the provided output labels + (if any) against the source code (if available). (Default is True.) + + Returns: + (callable[[callable], type[{parent_class.__name__}]]): A decorator that + transforms a function into a child class of `{parent_class.__name__}` + using the decorated function as + `{parent_class.__name__}.{io_static_method.__name__}`. + """ + ) + return as_decorated_node_decorator diff --git a/pyiron_workflow/macro.py b/pyiron_workflow/macro.py index 3066c81e..563c7000 100644 --- a/pyiron_workflow/macro.py +++ b/pyiron_workflow/macro.py @@ -6,22 +6,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -import inspect import re -from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING +from typing import Literal, Optional, TYPE_CHECKING import warnings -from pyiron_workflow.channels import InputData, OutputData, NOT_DATA +from pyiron_workflow.channels import InputData, OutputData from pyiron_workflow.composite import Composite from pyiron_workflow.has_interface_mixins import HasChannel from pyiron_workflow.io import Outputs, Inputs -from pyiron_workflow.output_parser import ParseOutput +from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory if TYPE_CHECKING: from pyiron_workflow.channels import Channel -class Macro(Composite, ABC): +class Macro(Composite, DecoratedNode, ABC): """ A macro 
is a composite node that holds a graph with a fixed interface, like a pre-populated workflow that is the same every time you instantiate it. @@ -191,7 +190,7 @@ class Macro(Composite, ABC): the same graph is always created. >>> class AddThreeMacro(Macro): - ... _provided_output_labels = ["three"] + ... _output_labels = ["three"] ... ... @staticmethod ... def graph_creator(macro, x): @@ -245,8 +244,6 @@ class Macro(Composite, ABC): """ - _provided_output_labels: tuple[str] | None = None - def __init__( self, label: Optional[str] = None, @@ -287,11 +284,7 @@ def __init__( if returned_has_channel_objects is None else returned_has_channel_objects ), - ( - () - if self._provided_output_labels is None - else self._provided_output_labels - ), + (() if self._output_labels is None else self._output_labels), ) ) ) @@ -307,105 +300,21 @@ def graph_creator(self, *args, **kwargs) -> callable: """Build the graph the node will run.""" @classmethod - def _validate_output_labels(cls) -> tuple[str]: - """ - Ensure that output_labels, if provided, are commensurate with graph creator - return values, if provided, and return them as a tuple. - """ - graph_creator_returns = ParseOutput(cls.graph_creator).output - output_labels = cls._get_output_labels() - if output_labels is not None and len(set(output_labels)) != len(output_labels): - raise ValueError( - f"{cls.__name__} must not have degenerate output labels: " - f"{output_labels}" - ) - if graph_creator_returns is not None or output_labels is not None: - error_suffix = ( - f"but {cls.__name__} macro class got return values: " - f"{graph_creator_returns} and labels: {output_labels}." 
- ) - try: - if len(output_labels) != len(graph_creator_returns): - raise ValueError( - "The number of return values in the graph creator must exactly " - "match the number of output labels provided, " + error_suffix - ) - except TypeError: - raise TypeError( - f"Output labels and graph creator return values must either both " - f"or neither be present, " + error_suffix - ) - - @classmethod - def _type_hints(cls): - """The result of :func:`typing.get_type_hints` on the :meth:`graph_creator`.""" - return get_type_hints(cls.graph_creator) - - @classmethod - def preview_output_channels(cls) -> dict[str, Any]: - """ - Gives a class-level peek at the expected output channels. + def _io_defining_function(cls) -> callable: + return cls.graph_creator - Returns: - dict[str, tuple[Any, Any]]: The channel name and its corresponding type - hint. - """ - labels = cls._get_output_labels() - try: - type_hints = cls._type_hints()["return"] - if len(labels) > 1: - type_hints = get_args(type_hints) - if not isinstance(type_hints, tuple): - raise TypeError( - f"With multiple return labels expected to get a tuple of type " - f"hints, but got type {type(type_hints)}" - ) - if len(type_hints) != len(labels): - raise ValueError( - f"Expected type hints and return labels to have matching " - f"lengths, but got {len(type_hints)} hints and " - f"{len(labels)} labels: {type_hints}, {labels}" - ) - else: - # If there's only one hint, wrap it in a tuple, so we can zip it with - # *return_labels and iterate over both at once - type_hints = (type_hints,) - except KeyError: # If there are no return hints - type_hints = [None] * len(labels) - # Note that this nicely differs from `NoneType`, which is the hint when - # `None` is actually the hint! - return {label: hint for label, hint in zip(labels, type_hints)} - - @classmethod - def _get_output_labels(cls): - """ - Return output labels provided on the class if not None. 
- """ - if cls._provided_output_labels is None: - cls._scrape_output_labels() - return cls._provided_output_labels + _io_defining_function_uses_self = True @classmethod def _scrape_output_labels(cls): - """ - Inspect :meth:`node_function` to scrape out strings representing the - returned values. - - _Only_ works for functions with a single `return` expression in their body. + scraped_labels = super(Macro, cls)._scrape_output_labels() - It will return expressions and function calls just fine, thus good practice is - to create well-named variables and return those so that the output labels stay - dot-accessible. - """ - parsed_outputs = ParseOutput(cls.graph_creator).output - if parsed_outputs is None: - cls._provided_output_labels = None - else: - self_argument = list(cls._input_args().keys())[0] + if scraped_labels is not None: + # Strip off the first argument, e.g. self.foo just becomes foo + self_argument = list(cls._get_input_args().keys())[0] cleaned_labels = [ - # Strip off the first argument, e.g. self.foo just becomes foo - re.sub(r"^" + re.escape(f"{self_argument}."), "", p) - for p in parsed_outputs + re.sub(r"^" + re.escape(f"{self_argument}."), "", label) + for label in scraped_labels ] if any("." in label for label in cleaned_labels): raise ValueError( @@ -413,51 +322,9 @@ def _scrape_output_labels(cls): f"one of {cleaned_labels} still contains a '.' -- please provide " f"explicit labels" ) - cls._provided_output_labels = cleaned_labels - - @classmethod - def preview_input_channels(cls) -> dict[str, tuple[Any, Any]]: - """ - Gives a class-level peek at the expected input channels. - - Returns: - dict[str, tuple[Any, Any]]: The channel name and a tuple of its - corresponding type hint and default value. 
- """ - type_hints = cls._type_hints() - scraped: dict[str, tuple[Any, Any]] = {} - for i, (label, value) in enumerate(cls._input_args().items()): - if i == 0: - continue # Skip the macro argument itself, it's like `self` here - elif label in cls._init_keywords(): - # We allow users to parse arbitrary kwargs as channel initialization - # So don't let them choose bad channel names - raise ValueError( - f"The Input channel name {label} is not valid. Please choose a " - f"name _not_ among {cls._init_keywords()}" - ) - - try: - type_hint = type_hints[label] - except KeyError: - type_hint = None - - default = ( - value.default - if value.default is not inspect.Parameter.empty - else NOT_DATA - ) - - scraped[label] = (type_hint, default) - return scraped - - @classmethod - def _input_args(cls): - return inspect.signature(cls.graph_creator).parameters - - @classmethod - def _init_keywords(cls): - return list(inspect.signature(cls.__init__).parameters.keys()) + return cleaned_labels + else: + return scraped_labels def _prepopulate_ui_nodes_from_graph_creator_signature( self, storage_backend: Literal["h5io", "tinybase"] @@ -470,7 +337,7 @@ def _prepopulate_ui_nodes_from_graph_creator_signature( type_hint=type_hint, storage_backend=storage_backend, ) - for label, (type_hint, default) in self.preview_input_channels().items() + for label, (type_hint, default) in self.preview_inputs().items() ) def _get_linking_channel( @@ -661,6 +528,15 @@ def __setstate__(self, state): self.children[child].outputs[child_out].value_receiver = self.outputs[out] +as_macro_node = decorated_node_decorator_factory( + Macro, + Macro.graph_creator, + decorator_docstring_additions="The first argument in the wrapped function is " + "`self`-like and will receive the macro instance " + "itself, and thus is ignored in the IO.", +) + + def macro_node( graph_creator, label: Optional[str] = None, @@ -671,6 +547,7 @@ def macro_node( save_after_run: bool = False, strict_naming: bool = True, output_labels: 
Optional[str | list[str] | tuple[str]] = None, + validate_output_labels: bool = True, **kwargs, ): """ @@ -678,6 +555,24 @@ def macro_node( :func:`graph_creator` and returns an instance of that. Quacks like a :class:`Composite` for the sake of creating and registering nodes. + + Beyond the standard :class:`Macro`, initialization allows the args... + + Args: + graph_creator (callable): The function defining macro's graph. + output_labels (Optional[str | list[str] | tuple[str]]): A name for each return + value of the node function OR a single label. (Default is None, which + scrapes output labels automatically from the source code of the wrapped + function.) This can be useful when returned values are not well named, e.g. + to make the output channel dot-accessible if it would otherwise have a label + that requires item-string-based access. Additionally, specifying a _single_ + label for a wrapped function that returns a tuple of values ensures that a + _single_ output channel (holding the tuple) is created, instead of one + channel for each return value. The default approach of extracting labels + from the function source code also requires that the function body contain + _at most_ one `return` expression, so providing explicit labels can be used + to circumvent this (at your own risk), or to circumvent un-inspectable + source code (e.g. a function that exists only in memory). 
""" if not callable(graph_creator): # `function_node` quacks like a class, even though it's a function and @@ -692,7 +587,9 @@ def macro_node( elif isinstance(output_labels, str): output_labels = (output_labels,) - return as_macro_node(*output_labels)(graph_creator)( + return as_macro_node(*output_labels, validate_output_labels=validate_output_labels)( + graph_creator + )( label=label, parent=parent, overwrite_save=overwrite_save, @@ -702,42 +599,3 @@ def macro_node( strict_naming=strict_naming, **kwargs, ) - - -def as_macro_node(*output_labels): - """ - A decorator for dynamically creating macro classes from graph-creating functions. - - Decorates a function. - Returns a :class:`Macro` subclass whose name is the camel-case version of the - graph-creating function, and whose signature is modified to exclude this function - and provided kwargs. - - Optionally takes output labels as args in case the node function uses the - like-a-function interface to define its IO. (The number of output labels must match - number of channel-like objects returned by the graph creating function _exactly_.) - - Optionally takes any keyword arguments of :class:`Macro`. 
- """ - output_labels = None if len(output_labels) == 0 else output_labels - - def as_node(graph_creator: callable[[Macro, ...], Optional[tuple[HasChannel]]]): - node_class = type( - graph_creator.__name__, - (Macro,), # Define parentage - { - "graph_creator": staticmethod(graph_creator), - "_provided_output_labels": output_labels, - "__module__": graph_creator.__module__, - }, - ) - try: - node_class._validate_output_labels() - except OSError: - warnings.warn( - f"Could not find the source code to validate {node_class.__name__} " - f"output labels" - ) - return node_class - - return as_node diff --git a/pyiron_workflow/meta.py b/pyiron_workflow/meta.py index 13d8c679..28824e07 100644 --- a/pyiron_workflow/meta.py +++ b/pyiron_workflow/meta.py @@ -107,8 +107,8 @@ def for_loop( :param:`length` - Provide enter and exit magic methods so we can `for` or `with` this fancy-like """ - input_preview = loop_body_class.preview_input_channels() - output_preview = loop_body_class.preview_output_channels() + input_preview = loop_body_class.preview_inputs() + output_preview = loop_body_class.preview_outputs() # Ensure `iterate_on` is in the input iterate_on = [iterate_on] if isinstance(iterate_on, str) else iterate_on diff --git a/pyiron_workflow/node_library/atomistics/calculator.py b/pyiron_workflow/node_library/atomistics/calculator.py index 47d20073..c5c6c313 100644 --- a/pyiron_workflow/node_library/atomistics/calculator.py +++ b/pyiron_workflow/node_library/atomistics/calculator.py @@ -12,7 +12,7 @@ def Emt(): @as_function_node("calculator") def Abinit( - label="abinit_evcurve", + ase_label="abinit_evcurve", nbands=32, ecut=10 * Ry, kpts=(3, 3, 3), @@ -22,7 +22,7 @@ def Abinit( from ase.calculators.abinit import Abinit return Abinit( - label=label, + label=ase_label, nbands=nbands, ecut=ecut, kpts=kpts, @@ -57,7 +57,7 @@ def QuantumEspresso( @as_function_node("calculator") def Siesta( - label="siesta", + ase_label="siesta", xc="PBE", mesh_cutoff=200 * Ry, energy_shift=0.01 
* Ry, @@ -70,7 +70,7 @@ def Siesta( from ase.calculators.siesta import Siesta return Siesta( - label=label, + label=ase_label, xc=xc, mesh_cutoff=mesh_cutoff, energy_shift=energy_shift, diff --git a/pyiron_workflow/node_library/pyiron_atomistics.py b/pyiron_workflow/node_library/pyiron_atomistics.py index a3403eb4..6e2b2c72 100644 --- a/pyiron_workflow/node_library/pyiron_atomistics.py +++ b/pyiron_workflow/node_library/pyiron_atomistics.py @@ -100,6 +100,7 @@ def _run_and_remove_job(job, modifier: Optional[callable] = None, **modifier_kwa "total_displacements", "unwrapped_positions", "volume", + validate_output_labels=False, ) def CalcStatic( job: AtomisticGenericJob, @@ -122,6 +123,7 @@ def CalcStatic( "total_displacements", "unwrapped_positions", "volume", + validate_output_labels=False, ) def CalcMd( job: AtomisticGenericJob, @@ -168,6 +170,7 @@ def calc_md(job, n_ionic_steps, n_print, temperature, pressure): "total_displacements", "unwrapped_positions", "volume", + validate_output_labels=False, ) def CalcMin( job: AtomisticGenericJob, diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index ebaa8383..3c484fb2 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -125,7 +125,11 @@ def test_label_choices(self): self.assertListEqual(n.outputs.labels, ["sum_plus_one"]) with self.subTest("Allow forcing _one_ output channel"): - n = function_node(returns_multiple, output_labels="its_a_tuple") + n = function_node( + returns_multiple, + output_labels="its_a_tuple", + validate_output_labels=False, + ) self.assertListEqual(n.outputs.labels, ["its_a_tuple"]) with self.subTest("Fail on multiple return values"): @@ -135,7 +139,18 @@ def test_label_choices(self): function_node(multiple_branches) with self.subTest("Override output label scraping"): - switch = function_node(multiple_branches, output_labels="bool") + with self.assertRaises( + ValueError, + msg="Multiple return branches can't be parsed" + ): + switch = 
function_node(multiple_branches, output_labels="bool") + self.assertListEqual(switch.outputs.labels, ["bool"]) + + switch = function_node( + multiple_branches, + output_labels="bool", + validate_output_labels=False + ) self.assertListEqual(switch.outputs.labels, ["bool"]) def test_default_label(self): @@ -159,72 +174,6 @@ def bilinear(x, y): "use at the class level" ) - def test_preview_output_channels(self): - @as_function_node() - def Foo(x): - return x - - self.assertDictEqual( - {"x": None}, - Foo.preview_output_channels(), - msg="Should parse without label or hint." - ) - - @as_function_node("y") - def Foo(x) -> None: - return x - - self.assertDictEqual( - {"y": type(None)}, - Foo.preview_output_channels(), - msg="Should parse with label and hint." - ) - - with self.assertRaises( - ValueError, - msg="Should fail when scraping incommensurate hints and returns" - ): - @as_function_node() - def Foo(x) -> int: - y, z = 5.0, 5 - return x, y, z - - with self.assertRaises( - ValueError, - msg="Should fail when provided labels are incommensurate with hints" - ): - @as_function_node("xo", "yo", "zo") - def Foo(x) -> int: - y, z = 5.0, 5 - return x, y, z - - @as_function_node("xo", "yo") - def Foo(x) -> tuple[int, float]: - y, z = 5.0, 5 - return x - - self.assertDictEqual( - {"xo": int, "yo": float}, - Foo.preview_output_channels(), - msg="The user carries extra responsibility if they specify return values " - "-- we don't even try scraping the returned stuff and it's up to them " - "to make sure everything is commensurate! This is necessary so that " - "source code scraping can get bypassed sometimes (e.g. 
for dynamically " - "generated code that is only in memory and thus not inspectable)" - ) - - def test_preview_input_channels(self): - @as_function_node() - def Foo(x, y: int = 42): - return x + y - - self.assertDictEqual( - {"x": (None, NOT_DATA), "y": (int, 42)}, - Foo.preview_input_channels(), - msg="Input specifications should be available at the class level, with or " - "without type hints and/or defaults provided." - ) - def test_statuses(self): n = function_node(plus_one) self.assertTrue(n.ready) @@ -242,19 +191,6 @@ def test_statuses(self): self.assertFalse(n.running) self.assertTrue(n.failed) - def test_protected_name(self): - @as_function_node() - def Selfish(self, x): - return x - - n = Selfish() - with self.assertRaises( - ValueError, - msg="When we try to build inputs, we should run into the fact that inputs " - "can't overlap with __init__ signature terms" - ): - n.inputs - def test_call(self): node = function_node(no_default, output_labels="output") @@ -574,6 +510,22 @@ def returns_foo() -> Foo: ): single_output._some_nonexistant_private_var + def test_void_return(self): + """Test extensions to the `ScrapesIO` mixin.""" + + @as_function_node() + def NoReturn(x): + y = x + 1 + + self.assertDictEqual( + {"None": type(None)}, + NoReturn.preview_outputs(), + msg="Functions without a return value should be permissible, although it " + "is not interesting" + ) + # Honestly, functions with no return should probably be made illegal to + # encourage functional setups... 
+ if __name__ == '__main__': unittest.main() diff --git a/tests/unit/test_io_preview.py b/tests/unit/test_io_preview.py new file mode 100644 index 00000000..fa366028 --- /dev/null +++ b/tests/unit/test_io_preview.py @@ -0,0 +1,169 @@ +from abc import ABC, abstractmethod +from textwrap import dedent +import unittest + +from pyiron_workflow.channels import NOT_DATA +from pyiron_workflow.io_preview import ( + ScrapesIO, decorated_node_decorator_factory, OutputLabelsNotValidated +) + + +class ScraperParent(ScrapesIO, ABC): + + @staticmethod + @abstractmethod + def io_function(*args, **kwargs): + pass + + @classmethod + def _io_defining_function(cls): + return cls.io_function + + +as_scraper = decorated_node_decorator_factory( + ScraperParent, ScraperParent.io_function +) + + +class TestIOPreview(unittest.TestCase): + # FROM FUNCTION + def test_void(self): + @as_scraper() + def AbsenceOfIOIsPermissible(): + nothing = None + + def test_preview_inputs(self): + @as_scraper() + def Mixed(x, y: int = 42): + """Has (un)hinted and with(out)-default input""" + return x + y + + self.assertDictEqual( + {"x": (None, NOT_DATA), "y": (int, 42)}, + Mixed.preview_inputs(), + msg="Input specifications should be available at the class level, with or " + "without type hints and/or defaults provided." + ) + + with self.subTest("Protected"): + with self.assertRaises( + ValueError, + msg="Inputs must not overlap with __init__ signature terms" + ): + @as_scraper() + def Selfish(self, x): + return x + + def test_preview_outputs(self): + + with self.subTest("Plain"): + @as_scraper() + def Return(x): + return x + + self.assertDictEqual( + {"x": None}, + Return.preview_outputs(), + msg="Should parse without label or hint." + ) + + with self.subTest("Labeled"): + @as_scraper("y") + def LabeledReturn(x) -> None: + return x + + self.assertDictEqual( + {"y": type(None)}, + LabeledReturn.preview_outputs(), + msg="Should parse with label and hint." 
+ ) + + with self.subTest("Hint-return count mismatch"): + with self.assertRaises( + ValueError, + msg="Should fail when scraping incommensurate hints and returns" + ): + @as_scraper() + def HintMismatchesScraped(x) -> int: + y, z = 5.0, 5 + return x, y, z + + with self.assertRaises( + ValueError, + msg="Should fail when provided labels are incommensurate with hints" + ): + @as_scraper("xo", "yo", "zo") + def HintMismatchesProvided(x) -> int: + y, z = 5.0, 5 + return x, y, z + + with self.subTest("Provided-scraped mismatch"): + with self.assertRaises( + ValueError, + msg="The nuber of labels -- if explicitly provided -- must be commensurate " + "with the number of returned items" + ): + @as_scraper("xo", "yo") + def LabelsMismatchScraped(x) -> tuple[int, float]: + y, z = 5.0, 5 + return x + + @as_scraper("x0", "x1", validate_output_labels=False) + def IgnoreScraping(x) -> tuple[int, float]: + x = (5, 5.5) + return x + + self.assertDictEqual( + {"x0": int, "x1": float}, + IgnoreScraping.preview_outputs(), + msg="Returned tuples can be received by force" + ) + + with self.subTest("Multiple returns"): + with self.assertRaises( + ValueError, + msg="Branched returns cannot be scraped and will fail on validation" + ): + @as_scraper("truth") + def Branched(x) -> bool: + if x <= 0: + return False + else: + return True + + @as_scraper("truth", validate_output_labels=False) + def Branched(x) -> bool: + if x <= 0: + return False + else: + return True + self.assertDictEqual( + {"truth": bool}, + Branched.preview_outputs(), + msg="We can force-override this at our own risk." + ) + + with self.subTest("Uninspectable function"): + def _uninspectable(): + template = dedent(f""" + def __source_code_not_available(x): + return x + """) + exec(template) + return locals()["__source_code_not_available"] + + f = _uninspectable() + + with self.assertRaises( + OSError, + msg="If the source code cannot be inspected for output labels, they " + "_must_ be provided." 
+ ): + as_scraper()(f) + + with self.assertWarns( + OutputLabelsNotValidated, + msg="If provided labels cannot be validated against the source code, " + "a warning should be issued" + ): + as_scraper("y")(f) diff --git a/tests/unit/test_macro.py b/tests/unit/test_macro.py index 6631d084..27ef1785 100644 --- a/tests/unit/test_macro.py +++ b/tests/unit/test_macro.py @@ -164,7 +164,7 @@ def test_creation_from_decorator(self): def test_creation_from_subclass(self): class MyMacro(Macro): - _provided_output_labels = ("three__result",) + _output_labels = ("three__result",) @staticmethod def graph_creator(self, one__x): @@ -418,99 +418,6 @@ def fail_at_zero(x): msg="Original connections should get restored on upstream failure" ) - def test_output_labels_vs_return_values(self): - def no_return(macro): - macro.foo = macro.create.standard.UserInput() - - macro_node(no_return) # Neither is fine - - @as_macro_node("some_return") - def LabelsAndReturnsMatch(macro): - macro.foo = macro.create.standard.UserInput() - return macro.foo - - LabelsAndReturnsMatch() # Both is fine - - @as_macro_node() - def OutputScrapedFromCleanReturn(macro): - macro.foo = macro.create.standard.UserInput() - my_out = macro.foo - return my_out - - self.assertListEqual( - ["my_out"], - list(OutputScrapedFromCleanReturn.preview_output_channels().keys()), - msg="Output labels should get scraped from code, just like for functions" - ) - - @as_macro_node() - def OutputScrapedFromFilteredReturn(macro): - macro.foo = macro.create.standard.UserInput() - return macro.foo - - self.assertListEqual( - ["foo"], - list(OutputScrapedFromFilteredReturn.preview_output_channels().keys()), - msg="The first, self-like argument, should get stripped from output labels" - ) - - with self.assertRaises( - ValueError, - msg="Return values shouldn't have extra dots" - ): - @as_macro_node() - def ReturnHasDot(macro): - macro.foo = macro.create.standard.UserInput() - return macro.foo.outputs.user_input - - with self.assertRaises( - 
ValueError, - msg="The number of output labels and return values must match" - ): - @as_macro_node("some_return", "nonexistent") - def MissingReturn(macro): - macro.foo = macro.create.standard.UserInput() - return macro.foo - - with self.assertRaises( - TypeError, - msg="Return values must be there if output labels are" - ): - @as_macro_node("some_label") - def MissingReturn(macro): - macro.foo = macro.create.standard.UserInput() - - with self.assertRaises( - ValueError, - msg="Degenerate output labels should not be allowed" - ): - @as_macro_node() - def DegenerateOutput(macro): - macro.foo = macro.create.standard.UserInput() - macro.bar = macro.create.standard.UserInput(macro.foo) - bar = macro.foo - return bar, macro.bar - - def test_functionlike_io_parsing(self): - """ - Check that various aspects of the IO are parsing from the function signature - and returns, and labels - """ - - @as_macro_node("lout", "n_plus_2") - def LikeAFunction(macro, lin: list, n: int = 2): - macro.plus_two = n + 2 - macro.sliced_list = lin[n:macro.plus_two] - macro.double_fork = 2 * n - # ^ This is vestigial, just to show we don't need to blacklist it - # Test returning both a single value node and an output channel, - # even though here we could just use the node both times - return macro.sliced_list, macro.plus_two.channel - - macro = LikeAFunction(n=1, lin=[1, 2, 3, 4, 5, 6]) - self.assertListEqual(["lin", "n"], macro.inputs.labels) - self.assertDictEqual({"n_plus_2": 3, "lout": [2, 3]}, macro()) - def test_efficient_signature_interface(self): with self.subTest("Forked input"): @as_macro_node("output") @@ -632,12 +539,28 @@ def test_storage_for_modified_macros(self): finally: macro.storage.delete() - def test_wrong_return(self): + def test_output_label_stripping(self): + """Test extensions to the `ScrapesIO` mixin.""" + + @as_macro_node() + def OutputScrapedFromFilteredReturn(macro): + macro.foo = macro.create.standard.UserInput() + return macro.foo + + self.assertListEqual( + 
["foo"], + list(OutputScrapedFromFilteredReturn.preview_outputs().keys()), + msg="The first, self-like argument, should get stripped from output labels" + ) + with self.assertRaises( - TypeError, - msg="Macro returning object without channel did not raise an error" + ValueError, + msg="Return values with extra dots are not permissible as scraped labels" ): - macro_node(wrong_return_macro) + @as_macro_node() + def ReturnHasDot(macro): + macro.foo = macro.create.standard.UserInput() + return macro.foo.outputs.user_input if __name__ == '__main__':