Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add usage snippets for Google Health AI models #1084

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
58 changes: 58 additions & 0 deletions packages/tasks/src/model-libraries-snippets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ export const bm25s = (model: ModelData): string[] => [
retriever = BM25HF.load_from_hub("${model.id}")`,
];

export const cxr_foundation = (model: ModelData): string[] => [
ndebuhr marked this conversation as resolved.
Show resolved Hide resolved
`# Install library
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove the install instructions completely from here and keep them in the model card (As they are currently and just start from from PIL import Image

!git clone https://github.com/Google-Health/cxr-foundation.git
import tensorflow as tf, sys
sys.path.append('cxr-foundation/python/')

# Install dependencies
major_version = tf.__version__.rsplit(".", 1)[0]
!pip install tensorflow-text=={major_version} pypng && pip install --no-deps pydicom hcls_imaging_ml_toolkit retrying

# Run inference
from PIL import Image
from clientside.clients import make_hugging_face_client

cxr_client = make_hugging_face_client('cxr_model')
!wget -nc -q https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png

print(cxr_client.get_image_embeddings_from_images([Image.open("Chest_Xray_PA_3-8-2010.png")]))`,
];

export const depth_anything_v2 = (model: ModelData): string[] => {
let encoder: string;
let features: string;
Expand Down Expand Up @@ -168,6 +188,44 @@ focallength_px = prediction["focallength_px"]`;
return [installSnippet, inferenceSnippet];
};

export const derm_foundation = (model: ModelData): string[] => [
ndebuhr marked this conversation as resolved.
Show resolved Hide resolved
`from PIL import Image
from io import BytesIO
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import requests

# Load test image from SCIN Dataset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this comment to make it more concise, it's quite evident that we're passing an image via the image variable name and the URL path imo.

# https://github.com/google-research-datasets/scin
IMAGE_URL = "https://storage.googleapis.com/dx-scin-public-data/dataset/images/3445096909671059178.png"
response = requests.get(IMAGE_URL, stream=True)
# Raise an exception if the request fails
response.raise_for_status()
# Load the image into a PIL Image object
image = Image.open(response.raw)

buf = BytesIO()
image.convert("RGB").save(buf, "PNG")
image_bytes = buf.getvalue()
# Format input
input_tensor = tf.train.Example(
features=tf.train.Features(
feature={
"image/encoded": tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image_bytes])
)
}
)
).SerializeToString()

# Load the model directly from Hugging Face Hub
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this comment

loaded_model = from_pretrained_keras("google/derm-foundation")

# Call inference
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this

infer = loaded_model.signatures["serving_default"]
output = infer(inputs=tf.constant([input_tensor]))`,
]

const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k";

const diffusers_default = (model: ModelData) => [
Expand Down
2 changes: 2 additions & 0 deletions packages/tasks/src/model-libraries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
prettyLabel: "CXR Foundation",
repoName: "cxr-foundation",
repoUrl: "https://github.com/google-health/cxr-foundation",
snippets: snippets.cxr_foundation,
filter: false,
countDownloads: `path:"precomputed_embeddings/embeddings.npz" OR path:"pax-elixr-b-text/saved_model.pb"`,
},
Expand Down Expand Up @@ -200,6 +201,7 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
prettyLabel: "Derm Foundation",
repoName: "derm-foundation",
repoUrl: "https://github.com/google-health/derm-foundation",
snippets: snippets.derm_foundation,
filter: false,
countDownloads: `path:"scin_dataset_precomputed_embeddings.npz" OR path:"saved_model.pb"`,
},
Expand Down