Batch transform sparse matrix on Scikit-learn model #1093

Closed
@ivankeller

Description

Reference: 0414058987

I reproduce here a question I submitted on Stack Overflow (https://stackoverflow.com/questions/58410583/batch-transform-sparse-matrix-with-aws-sagemaker-python-sdk):

I have successfully trained a Scikit-Learn LSVC model with AWS SageMaker.
I want to make batch predictions (a.k.a. batch transform) on a relatively big dataset: a scipy sparse matrix of shape 252772 x 185128. (The number of features is high because of one-hot-encoded bag-of-words and n-gram features.)

I struggle because of:

  • the size of the data

  • the format of the data

I did several experiments to check what was going on:

1. Predict locally on sample sparse matrix data

It works.
I deserialized the model artifact locally on a SageMaker notebook and predicted on a sample of the sparse matrix.
This was just to check that the model can predict on this kind of data.
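A minimal sketch of that local check, assuming the model.tar.gz artifact was downloaded and extracted to ./model and that a sample of the matrix was saved with scipy.sparse.save_npz (all paths here are illustrative):

import os
import joblib
from scipy import sparse

# Assumed: model.tar.gz from S3 was extracted to ./model
clf = joblib.load(os.path.join("model", "model.joblib"))

# Assumed: a sample of the sparse matrix was saved beforehand
X_sample = sparse.load_npz("X_sample.npz")

# LinearSVC predicts directly on scipy CSR input
predictions = clf.predict(X_sample)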

2. Batch Transform on sample CSV data

It works.
I launched a Batch Transform job on SageMaker to transform a small sample in dense CSV format: it works, but obviously does not scale.
The code is:

from sagemaker.sklearn.model import SKLearnModel

# Wrap the trained model artifact in a SKLearnModel.
sklearn_model = SKLearnModel(
    model_data=model_artifact_location_on_s3,
    entry_point='my_script.py',
    role=role,
    sagemaker_session=sagemaker_session)

# Request a single-instance batch transformer.
transformer = sklearn_model.transformer(
    instance_count=1,
    instance_type='ml.m4.xlarge',
    max_payload=100)

# Transform the CSV input without splitting it into records.
transformer.transform(
    data=batch_data,
    content_type='text/csv',
    split_type=None)

print('Waiting for transform job: ' + transformer.latest_transform_job.job_name)
transformer.wait()

where:

  • 'my_script.py' implements a simple model_fn to deserialize the model artifact:

import os
import joblib

def model_fn(model_dir):
    clf = joblib.load(os.path.join(model_dir, "model.joblib"))
    return clf

  • batch_data is the s3 path of the csv file (a sketch of how it might be produced follows this list).
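For completeness, a minimal sketch of how such a CSV sample might be produced and uploaded; X_sparse, bucket, and the key prefix are placeholders I introduce for illustration:

import numpy as np

# Densify a small slice only: the full 252772 x 185128 matrix would not
# fit in memory as a dense array.
np.savetxt("sample.csv", X_sparse[:100].toarray(), delimiter=",")

# Upload with the SageMaker session helper (bucket/prefix are illustrative).
batch_data = sagemaker_session.upload_data(
    path="sample.csv", bucket=bucket, key_prefix="batch-csv-input")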

3. Batch Transform of a sample dense numpy dataset

It works.
I prepared a sample of the data and saved it to S3 in Numpy .npy format. According to this documentation, the SageMaker Scikit-learn model server can deserialize NPY-formatted data (along with JSON and CSV data).
The only difference from the previous experiment (2) is the argument content_type='application/x-npy' in transformer.transform(...).
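A sketch of how that sample might have been prepared (same placeholders as above):

import io
import boto3
import numpy as np

# Serialize a dense sample in .npy format, in memory.
buf = io.BytesIO()
np.save(buf, X_sparse[:100].toarray())
buf.seek(0)

# Upload to S3 (bucket and key are illustrative).
boto3.resource('s3').Bucket(bucket).Object('batch-npy-input/sample.npy').upload_fileobj(buf)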

This solution does not scale either, and we would like to pass a scipy sparse matrix:

4. Batch Transform of a big sparse matrix

Here is the problem.
The SageMaker Python SDK does not support a sparse matrix format out of the box. Following this, I used write_spmatrix_to_sparse_tensor to write the data in protobuf format to S3. The function I used is:

import io
import boto3
from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor

def write_protobuf(X_sparse, bucket, prefix, obj):
    """Write sparse matrix to protobuf format at location bucket/prefix/obj."""
    buf = io.BytesIO()
    write_spmatrix_to_sparse_tensor(file=buf, array=X_sparse, labels=None)
    buf.seek(0)
    key = '{}/{}'.format(prefix, obj)
    boto3.resource('s3').Bucket(bucket).Object(key).upload_fileobj(buf)
    return 's3://{}/{}'.format(bucket, key)
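For reference, a hypothetical invocation (bucket, prefix, and object name are placeholders):

batch_data = write_protobuf(X_sparse, bucket, 'batch-protobuf-input', 'data.pbr')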

Then the code used for launching the batch transform job is:

# Same model and transformer setup as in experiment (2).
sklearn_model = SKLearnModel(
    model_data=model_artifact_location_on_s3,
    entry_point='my_script.py',
    role=role,
    sagemaker_session=sagemaker_session)

transformer = sklearn_model.transformer(
    instance_count=1,
    instance_type='ml.m4.xlarge',
    max_payload=100)

# Transform the protobuf input, splitting on RecordIO record boundaries.
transformer.transform(
    data=batch_data,
    content_type='application/x-recordio-protobuf',
    split_type='RecordIO')

print('Waiting for transform job: ' + transformer.latest_transform_job.job_name)
transformer.wait()

I get the following error:

sagemaker_containers._errors.ClientError: Content type application/x-recordio-protobuf is not supported by this framework.

Questions:
(Reference doc for Transformer: https://sagemaker.readthedocs.io/en/stable/transformer.html)

  • If content_type='application/x-recordio-protobuf' is not allowed, what should I use?
  • Is split_type='RecordIO' the proper setting in this context?
  • Should I provide an input_fn function in my script to deserialize the data? (A sketch of what I have in mind follows this list.)
  • Is there another better approach to tackle this problem?
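For illustration, here is the kind of input_fn I imagine would be needed. This is a sketch only: it assumes read_records from sagemaker.amazon.common can parse the payload, and that each record written by write_spmatrix_to_sparse_tensor holds one sparse row (values, keys, and a [n_cols] shape). I have not verified that this works inside the container:

import io
import scipy.sparse as sp
from sagemaker.amazon.common import read_records

def input_fn(input_data, content_type):
    """Hypothetical deserializer: rebuild a CSR matrix from
    recordio-protobuf records (untested sketch)."""
    if content_type == 'application/x-recordio-protobuf':
        data, indices, indptr = [], [], [0]
        num_cols = None
        for record in read_records(io.BytesIO(input_data)):
            # Assumed layout: one sparse row per record, with column
            # indices in `keys` and [n_cols] in `shape`.
            tensor = record.features["values"].float32_tensor
            data.extend(tensor.values)
            indices.extend(tensor.keys)
            indptr.append(len(indices))
            num_cols = tensor.shape[0]
        return sp.csr_matrix((data, indices, indptr),
                             shape=(len(indptr) - 1, num_cols))
    raise ValueError("Unsupported content type: {}".format(content_type))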
