-
Notifications
You must be signed in to change notification settings - Fork 705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
predict sql submit predict job as maxcompute udf script #605
Changes from 16 commits
9174479
6a47ac4
f64f50e
a7b1a92
a61d51f
d0aa22a
bb9e0c3
661dc68
cd001e8
46ff066
f54b04b
742c2dd
99ef0c4
90c298b
7771500
da4afeb
76554a2
2119731
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ import ( | |
"fmt" | ||
"io/ioutil" | ||
"os" | ||
"os/exec" | ||
"path/filepath" | ||
"strconv" | ||
"strings" | ||
|
@@ -28,6 +29,9 @@ import ( | |
"sqlflow.org/gomaxcompute" | ||
) | ||
|
||
var alpsTrainTemplate = template.Must(template.New("alps_train").Parse(alpsTrainTemplateText)) | ||
var alpsPredTemplate = template.Must(template.New("alps_predict").Parse(alpsPredTemplateText)) | ||
|
||
type alpsFiller struct { | ||
// Training or Predicting | ||
IsTraining bool | ||
|
@@ -52,13 +56,21 @@ type alpsFiller struct { | |
TrainClause *resolvedTrainClause | ||
ExitOnSubmit bool | ||
|
||
// Predict | ||
PredictUDF string | ||
|
||
// Feature map | ||
FeatureMapTable string | ||
FeatureMapPartition string | ||
|
||
// ODPS | ||
OdpsConf *gomaxcompute.Config | ||
EngineCode string | ||
|
||
// Credential | ||
UserID string | ||
OSSID string | ||
OSSKey string | ||
} | ||
|
||
type alpsFeatureColumn interface { | ||
|
@@ -255,13 +267,37 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra | |
ExitOnSubmit: exitOnSubmit}, nil | ||
} | ||
|
||
func newALPSPredictFiller(pr *extendedSelect) (*alpsFiller, error) { | ||
return nil, fmt.Errorf("alps predict not supported") | ||
func newALPSPredictFiller(pr *extendedSelect, session *pb.Session) (*alpsFiller, error) { | ||
var ossID, ossKey *expr | ||
var ok bool | ||
if ossID, ok = pr.predAttrs["OSS_ID"]; !ok { | ||
return nil, fmt.Errorf("the ALPS Predict job should specify OSS_ID") | ||
} | ||
if ossKey, ok = pr.predAttrs["OSS_KEY"]; !ok { | ||
return nil, fmt.Errorf("the ALPS Predict job should specify OSS_KEY") | ||
} | ||
modelDir := fmt.Sprintf("oss://arks-model/%s/%s.tar.gz", session.UserId, pr.predictClause.model) | ||
|
||
return &alpsFiller{ | ||
IsTraining: true, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why predict should set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
PredictInputTable: pr.tables[0], | ||
PredictOutputTable: pr.predictClause.into, | ||
PredictUDF: strings.Join(pr.fields.Strings(), " "), | ||
ModelDir: modelDir, | ||
UserID: session.UserId, | ||
OSSID: ossID.String(), | ||
OSSKey: ossKey.String(), | ||
}, nil | ||
} | ||
|
||
func submitALPS(w *PipeWriter, cwd string, filler *alpsFiller) error { | ||
func alpsTrain(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb.Session, ds *trainAndValDataset) error { | ||
var program bytes.Buffer | ||
if err := alpsTemplate.Execute(&program, filler); err != nil { | ||
filler, err := newALPSTrainFiller(pr, db, session, ds) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if err = alpsTrainTemplate.Execute(&program, filler); err != nil { | ||
return fmt.Errorf("submitALPS: failed executing template: %v", err) | ||
} | ||
code := program.String() | ||
|
@@ -293,24 +329,53 @@ pip install http://091349.oss-cn-hangzhou-zmf.aliyuncs.com/alps/sqlflow/alps-2.0 | |
if e := cmd.Run(); e != nil { | ||
return fmt.Errorf("code %v failed %v", code, e) | ||
} | ||
// TODO(uuleon): save model to DB if train | ||
// TODO(uuleon): save model to DB | ||
return nil | ||
} | ||
|
||
func alpsTrain(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb.Session, ds *trainAndValDataset) error { | ||
f, err := newALPSTrainFiller(pr, db, session, ds) | ||
func alpsPred(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb.Session) error { | ||
var program bytes.Buffer | ||
filler, err := newALPSPredictFiller(pr, session) | ||
if err != nil { | ||
return err | ||
} | ||
return submitALPS(w, cwd, f) | ||
} | ||
if err = alpsPredTemplate.Execute(&program, filler); err != nil { | ||
return fmt.Errorf("submitALPS: failed executing template: %v", err) | ||
} | ||
|
||
func alpsPred(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb.Session) error { | ||
f, err := newALPSPredictFiller(pr) | ||
fname := "alps_pre.odps" | ||
filepath := filepath.Join(cwd, fname) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. defer clean up this file? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
f, err := os.Create(filepath) | ||
if err != nil { | ||
return err | ||
return fmt.Errorf("Create ODPS script failed %v", err) | ||
} | ||
f.WriteString(program.String()) | ||
f.Close() | ||
cw := &logChanWriter{wr: w} | ||
_, ok := db.Driver().(*gomaxcompute.Driver) | ||
if !ok { | ||
return fmt.Errorf("Alps Predict Job only supports Maxcompute database driver") | ||
} | ||
cfg, err := gomaxcompute.ParseDSN(db.dataSourceName) | ||
if err != nil { | ||
return fmt.Errorf("Parse Maxcompute DSN failed: %v", err) | ||
} | ||
// FIXME(Yancey1989): using https proto. | ||
fixedEndpoint := strings.Replace(cfg.Endpoint, "https://", "http://", 0) | ||
// TODO(Yancey1989): submit the Maxcompute UDF script using gomaxcompute driver. | ||
cmd := exec.Command("odpscmd", | ||
"-u", cfg.AccessID, | ||
"-p", cfg.AccessKey, | ||
fmt.Sprintf("--endpoint=%s", fixedEndpoint), | ||
fmt.Sprintf("--project=%s", cfg.Project), | ||
"-s", filepath) | ||
cmd.Dir = cwd | ||
cmd.Stdout = cw | ||
cmd.Stderr = cw | ||
if e := cmd.Run(); e != nil { | ||
return fmt.Errorf("submit ODPS script %s failed %v", program.String(), e) | ||
} | ||
return submitALPS(w, cwd, f) | ||
return nil | ||
} | ||
|
||
func (nc *numericColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) { | ||
|
@@ -425,134 +490,6 @@ func generateAlpsFeatureColumnCode(fcs []featureColumn, metadata *metadata) ([]s | |
return codes, nil | ||
} | ||
|
||
const alpsTemplateText = ` | ||
# coding: utf-8 | ||
# Copyright (c) Antfin, Inc. All rights reserved. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
|
||
import tensorflow as tf | ||
|
||
from alps.conf.closure import Closure | ||
from alps.framework.train.training import build_run_config | ||
from alps.framework.exporter import ExportStrategy | ||
from alps.framework.exporter.arks_exporter import ArksExporter | ||
from alps.client.base import run_experiment, submit_experiment | ||
from alps.framework.engine import LocalEngine, YarnEngine, ResourceConf | ||
from alps.framework.column.column import DenseColumn, SparseColumn, GroupedSparseColumn | ||
from alps.framework.exporter.compare_fn import best_auc_fn | ||
from alps.io import DatasetX | ||
from alps.io.base import OdpsConf, FeatureMap | ||
from alps.framework.experiment import EstimatorBuilder, Experiment, TrainConf, EvalConf, RuntimeConf | ||
from alps.io.reader.odps_reader import OdpsReader | ||
|
||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # for debug usage. | ||
#tf.logging.set_verbosity(tf.logging.INFO) | ||
|
||
class SQLFlowEstimatorBuilder(EstimatorBuilder): | ||
def _build(self, experiment, run_config): | ||
{{if ne .FeatureMapTable ""}} | ||
feature_columns = [] | ||
{{.FeatureColumnCode}} | ||
{{end}} | ||
{{if ne .ImportCode ""}} | ||
{{.ImportCode}} | ||
{{end}} | ||
return {{.ModelCreatorCode}} | ||
|
||
if __name__ == "__main__": | ||
odpsConf=OdpsConf( | ||
accessid="{{.OdpsConf.AccessID}}", | ||
accesskey="{{.OdpsConf.AccessKey}}", | ||
endpoint="{{.OdpsConf.Endpoint}}" | ||
) | ||
|
||
trainDs = DatasetX( | ||
num_epochs={{.TrainClause.Epoch}}, | ||
batch_size={{.TrainClause.BatchSize}}, | ||
shuffle="{{.TrainClause.EnableShuffle}}" == "true", | ||
shuffle_buffer_size={{.TrainClause.ShuffleBufferSize}}, | ||
{{if .TrainClause.EnableCache}} | ||
cache_file={{.TrainClause.CachePath}}, | ||
{{end}} | ||
reader=OdpsReader( | ||
odps=odpsConf, | ||
project="{{.OdpsConf.Project}}", | ||
table="{{.TrainInputTable}}", | ||
# FIXME(typhoonzero): add field_names back if needed. | ||
# field_names={{.Fields}}, | ||
features={{.X}}, | ||
labels={{.Y}}, | ||
{{if ne .FeatureMapTable ""}} | ||
feature_map=FeatureMap(table="{{.FeatureMapTable}}", | ||
{{if ne .FeatureMapPartition ""}} | ||
partition="{{.FeatureMapPartition}}" | ||
{{end}} | ||
), | ||
flatten_group=True | ||
{{end}} | ||
), | ||
drop_remainder="{{.TrainClause.DropRemainder}}" == "true" | ||
) | ||
|
||
evalDs = DatasetX( | ||
num_epochs=1, | ||
batch_size={{.TrainClause.BatchSize}}, | ||
reader=OdpsReader( | ||
odps=odpsConf, | ||
project="{{.OdpsConf.Project}}", | ||
table="{{.EvalInputTable}}", | ||
# FIXME(typhoonzero): add field_names back if needed. | ||
# field_names={{.Fields}}, | ||
features={{.X}}, | ||
labels={{.Y}}, | ||
flatten_group=True | ||
) | ||
) | ||
|
||
export_path = "{{.ModelDir}}" | ||
{{if ne .ScratchDir ""}} | ||
runtime_conf = RuntimeConf(model_dir="{{.ScratchDir}}") | ||
{{else}} | ||
runtime_conf = None | ||
{{end}} | ||
experiment = Experiment( | ||
user="shangchun.sun", # TODO(joyyoj) pai will check user name be a valid user, removed later. | ||
engine={{.EngineCode}}, | ||
train=TrainConf(input=trainDs, | ||
{{if (ne .TrainClause.MaxSteps -1)}} | ||
max_steps={{.TrainClause.MaxSteps}}, | ||
{{end}} | ||
), | ||
eval=EvalConf(input=evalDs, | ||
# FIXME(typhoonzero): Support configure metrics | ||
metrics_set=['accuracy'], | ||
{{if (ne .TrainClause.EvalSteps -1)}} | ||
steps={{.TrainClause.EvalSteps}}, | ||
{{end}} | ||
start_delay_secs={{.TrainClause.EvalStartDelay}}, | ||
throttle_secs={{.TrainClause.EvalThrottle}}, | ||
), | ||
# FIXME(typhoonzero): Use ExportStrategy.BEST when possible. | ||
exporter=ArksExporter(deploy_path=export_path, strategy=ExportStrategy.LATEST, compare_fn=Closure(best_auc_fn)), | ||
runtime = runtime_conf, | ||
model_builder=SQLFlowEstimatorBuilder()) | ||
|
||
if isinstance(experiment.engine, LocalEngine): | ||
run_experiment(experiment) | ||
else: | ||
if "{{.ExitOnSubmit}}" == "false": | ||
submit_experiment(experiment) | ||
else: | ||
submit_experiment(experiment, exit_on_submit=True) | ||
` | ||
|
||
var alpsTemplate = template.Must(template.New("alps").Parse(alpsTemplateText)) | ||
|
||
type metadata struct { | ||
odpsConfig *gomaxcompute.Config | ||
table string | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure
odpscmd
a public available program.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied the link from Aliyun, I think it's a public program.