style: format

borisdayma committed Feb 10, 2023
1 parent 8fd8b63 commit b89d6b7
Showing 2 changed files with 26 additions and 27 deletions.
6 changes: 3 additions & 3 deletions tools/inference/inference_pipeline.ipynb
@@ -488,7 +488,7 @@
"import wandb\n",
"\n",
"# Initialize a W&B run.\n",
"project = 'dalle-mini-tables-colab'\n",
"project = \"dalle-mini-tables-colab\"\n",
"run = wandb.init(project=project)\n",
"\n",
"# Initialize an empty W&B Tables.\n",
@@ -500,10 +500,10 @@
" # If CLIP scores exist, sort the Images\n",
" if logits is not None:\n",
" idxs = logits[i].argsort()[::-1]\n",
" tmp_imgs = images[i::len(prompts)]\n",
" tmp_imgs = images[i :: len(prompts)]\n",
" tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
" else:\n",
" tmp_imgs = images[i::len(prompts)]\n",
" tmp_imgs = images[i :: len(prompts)]\n",
"\n",
" # Add the data to the table.\n",
" gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
47 changes: 23 additions & 24 deletions tools/train/embeddings_retrain_preparation.ipynb
@@ -433,6 +433,7 @@
 ],
 "source": [
 "import json\n",
+"\n",
 "tree = jax.tree_map(lambda x: x.shape, params)\n",
 "print(json.dumps(tree, indent=2))"
 ]
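
This cell prints the shape of every parameter leaf. A self-contained sketch of the same idiom on a toy pytree (the nesting and shapes below are invented; jax.tree_map was the current spelling at the time, later moved to jax.tree_util.tree_map):

import json

import jax
import jax.numpy as jnp

# Toy stand-in for the real model params.
params = {"model": {"decoder": {"embed_positions": {"embedding": jnp.zeros((257, 1024))}}}}

# Map every array leaf to its shape; json serializes the shape tuples as lists.
tree = jax.tree_map(lambda x: x.shape, params)
print(json.dumps(tree, indent=2))
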
@@ -455,9 +456,9 @@
"metadata": {},
"outputs": [],
"source": [
"del params['lm_head']\n",
"for layer in ['embed_positions', 'embed_tokens', 'final_ln', 'layernorm_embedding']:\n",
" del params['model']['decoder'][layer]"
"del params[\"lm_head\"]\n",
"for layer in [\"embed_positions\", \"embed_tokens\", \"final_ln\", \"layernorm_embedding\"]:\n",
" del params[\"model\"][\"decoder\"][layer]"
]
},
{
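
These deletions drop the LM head and the decoder's embedding-related weights so they can be re-initialized before retraining, while the remaining decoder layers are kept. Once the params are in plain (unfrozen) dict form, del is all that is needed. A toy sketch with invented shapes, not the model's real ones:

import jax.numpy as jnp

params = {
    "lm_head": {"kernel": jnp.zeros((1024, 16384))},
    "model": {
        "decoder": {
            "embed_positions": {"embedding": jnp.zeros((257, 1024))},
            "embed_tokens": {"embedding": jnp.zeros((16384, 1024))},
            "final_ln": {"scale": jnp.ones(1024)},
            "layernorm_embedding": {"scale": jnp.ones(1024)},
            "layers": {},  # kept and reused as-is
        }
    },
}

del params["lm_head"]
for layer in ["embed_positions", "embed_tokens", "final_ln", "layernorm_embedding"]:
    del params["model"]["decoder"][layer]

assert list(params["model"]["decoder"]) == ["layers"]
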
@@ -839,7 +840,7 @@
 }
 ],
 "source": [
-"params_reinit['model']['decoder']['embed_positions']"
+"params_reinit[\"model\"][\"decoder\"][\"embed_positions\"]"
 ]
 },
 {
@@ -860,7 +861,7 @@
 }
 ],
 "source": [
-"embedding_new = params_reinit['model']['decoder']['embed_positions']['embedding']\n",
+"embedding_new = params_reinit[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
 "embedding_new.min(), embedding_new.max()"
 ]
 },
@@ -893,7 +894,7 @@
 }
 ],
 "source": [
-"params_original['model']['decoder']['embed_positions']"
+"params_original[\"model\"][\"decoder\"][\"embed_positions\"]"
 ]
 },
 {
@@ -914,7 +915,7 @@
 }
 ],
 "source": [
-"embedding_original = params_original['model']['decoder']['embed_positions']['embedding']\n",
+"embedding_original = params_original[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
 "embedding_original.min(), embedding_new.max()"
 ]
 },
@@ -924,9 +925,7 @@
"metadata": {},
"outputs": [],
"source": [
"assert(\n",
" jnp.allclose(embedding_new, embedding_original).item() == False\n",
")"
"assert jnp.allclose(embedding_new, embedding_original).item() == False"
]
},
{
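
The assert verifies that the re-initialized positional embedding really differs from the original. Black's change is mechanical: assert is a statement, not a function, so the wrapping parentheses were redundant and the check folds onto one line. A runnable sketch with invented arrays (assert not ... would be the more idiomatic spelling, but the cell's == False comparison is kept):

import jax.numpy as jnp

embedding_original = jnp.zeros((4, 8))
embedding_new = jnp.ones((4, 8))  # stand-in for freshly re-initialized weights

# .item() converts the 0-d JAX boolean array into a plain Python bool.
assert jnp.allclose(embedding_new, embedding_original).item() == False
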
@@ -935,11 +934,9 @@
"metadata": {},
"outputs": [],
"source": [
"lm_head_original = params_original['lm_head']['kernel']\n",
"lm_head_reinit = params_reinit['lm_head']['kernel']\n",
"assert(\n",
" jnp.allclose(lm_head_reinit, lm_head_original).item() == False\n",
")"
"lm_head_original = params_original[\"lm_head\"][\"kernel\"]\n",
"lm_head_reinit = params_reinit[\"lm_head\"][\"kernel\"]\n",
"assert jnp.allclose(lm_head_reinit, lm_head_original).item() == False"
]
},
{
@@ -948,12 +945,14 @@
"metadata": {},
"outputs": [],
"source": [
"assert(\n",
" jnp.allclose(\n",
" params_reinit['model']['encoder']['layers']['FlaxBartEncoderLayers']['FlaxBartAttention_0']['k_proj']['kernel'],\n",
" params_original['model']['encoder']['layers']['FlaxBartEncoderLayers']['FlaxBartAttention_0']['k_proj']['kernel']\n",
" ).item()\n",
")"
"assert jnp.allclose(\n",
" params_reinit[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
" \"FlaxBartAttention_0\"\n",
" ][\"k_proj\"][\"kernel\"],\n",
" params_original[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
" \"FlaxBartAttention_0\"\n",
" ][\"k_proj\"][\"kernel\"],\n",
").item()"
]
},
{
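
This cell spot-checks a single encoder kernel to confirm the encoder weights were not touched by the re-initialization. As an alternative, a whole subtree can be checked at once rather than one leaf; a sketch under toy pytrees (tree_all and tree_map are standard jax.tree_util functions):

import jax
import jax.numpy as jnp

encoder_reinit = {"k_proj": {"kernel": jnp.ones((8, 8))}, "q_proj": {"kernel": jnp.zeros((8, 8))}}
encoder_original = {"k_proj": {"kernel": jnp.ones((8, 8))}, "q_proj": {"kernel": jnp.zeros((8, 8))}}

# Compare leaves pairwise, then reduce the resulting tree of bools to one bool.
assert jax.tree_util.tree_all(
    jax.tree_map(lambda a, b: jnp.allclose(a, b).item(), encoder_reinit, encoder_original)
)
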
@@ -1077,9 +1076,9 @@
 ],
 "source": [
 "wandb.init(\n",
-"    entity = 'dalle-mini',\n",
-"    project = 'dalle-mini',\n",
-"    job_type = 'Seq2Seq',\n",
+"    entity=\"dalle-mini\",\n",
+"    project=\"dalle-mini\",\n",
+"    job_type=\"Seq2Seq\",\n",
 ")"
 ]
 },
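
Black also normalizes keyword arguments here, removing the spaces around = inside the call. The call itself is standard wandb usage; a minimal runnable form (run.finish() is added for completeness and is not part of the notebook cell):

import wandb

run = wandb.init(
    entity="dalle-mini",
    project="dalle-mini",
    job_type="Seq2Seq",
)
run.finish()  # close the run once logging is done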