Skip to content

Commit

Permalink
Timesfm docker code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688609435
  • Loading branch information
dstnluong-google authored and copybara-github committed Oct 29, 2024
1 parent b183a3d commit a5797cd
Show file tree
Hide file tree
Showing 3 changed files with 587 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
FROM nvidia/cuda:12.3.2-devel-ubuntu22.04

# Install basic libs
RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \
cmake \
curl \
wget \
sudo \
gnupg \
libsm6 \
libxext6 \
libxrender-dev \
lsb-release \
ca-certificates \
build-essential \
git \
software-properties-common \
cuda-toolkit \
libcudnn8 \
apt-transport-https

RUN apt install -y --no-install-recommends python3.10 \
python3.10-venv \
python3.10-dev \
python3-pip

Run apt-get autoremove -y

RUN pip install --upgrade pip
RUN pip install --upgrade --ignore-installed \
"jax[cuda12]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
numpy==1.26.4 \
paxml==1.4.0 \
praxis==1.4.0 \
jaxlib==0.4.26 \
pandas==2.1.4 \
einshape==1.0.0 \
utilsforecast==0.1.10 \
huggingface_hub[cli]==0.23.0 \
google-cloud-aiplatform[prediction]==1.51.0 \
fastapi==0.109.1 \
flask==3.0.3 \
smart_open[gcs]==7.0.4 \
protobuf==3.19.6 \
scikit-learn==1.0.2 \
timesfm==1.0.1

# Download license.
RUN wget https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/LICENSE

# Move scaffold.
COPY model_oss/timesfm/main.py /app/main.py
COPY model_oss/timesfm/predictor.py /app/predictor.py

WORKDIR ..

# Spin off inference server.
CMD ["python3", "/app/main.py"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Predict server for TimesFM."""

import json
import os
import flask
import predictor

# Create the flask app.
app = flask.Flask(__name__)
_OK_STATUS = 200
_INTERNAL_ERROR_STATUS = 500
_HOST = '0.0.0.0'

# Define the predictor and load the checkpoints.
predictor = predictor.TimesFMPredictor()
predictor.load(os.environ['AIP_STORAGE_URI'])


@app.route(os.environ['AIP_HEALTH_ROUTE'], methods=['GET'])
def health() -> flask.Response:
return flask.Response(status=_OK_STATUS)


@app.route(os.environ['AIP_PREDICT_ROUTE'], methods=['GET', 'POST'])
def predict() -> flask.Response:
"""Calls TimesFM for prediction.
Returns:
A `flask.Response` containing the prediction result in JSON.
"""
try:
body = flask.request.get_json(silent=True, force=True)
preprocessed_inputs = predictor.preprocess(body)
outputs = predictor.predict(preprocessed_inputs)
postprocessed_outputs = predictor.postprocess(outputs)
return flask.Response(
json.dumps(postprocessed_outputs),
status=_OK_STATUS,
mimetype='application/json',
)
except Exception as e: # pylint: disable=broad-exception-caught
return flask.Response(
json.dumps({'error': str(e)}),
status=_INTERNAL_ERROR_STATUS,
mimetype='application/json',
)


if __name__ == '__main__':
app.run(host=_HOST, port=os.environ['AIP_HTTP_PORT'])
Loading

0 comments on commit a5797cd

Please sign in to comment.