Skip to content

Commit

Permalink
Expose a few options for model.Solve()
Browse files Browse the repository at this point in the history
  • Loading branch information
irfansharif committed Sep 18, 2021
1 parent 6b95fa8 commit 263d25e
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 185 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ go_library(
"intvar.go",
"linearexpr.go",
"model.go",
"options.go",
"result.go",
],
importpath = "github.com/irfansharif/solver",
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ z = false
$ make rewrite

# to run specific tests
$ bazel test :all internal/... --test_output=all \
$ bazel test ... --test_output=all \
--cache_test_results=no \
--test_arg='-test.v' \
--test_filter='Test.*'
Expand Down
2 changes: 2 additions & 0 deletions datadriven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func TestDatadriven(t *testing.T) {

model := solver.NewModel("") // instantiate a model

// Identifier scope, mapping identifiers to the types they were
// instantiated with.
itvM := make(map[string]solver.Interval)
varM := make(map[string]solver.IntVar)
litM := make(map[string]solver.Literal)
Expand Down
183 changes: 93 additions & 90 deletions internal/internal.go

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions internal/sat.i
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse,
%unignore operations_research::sat::CpSatHelper::ValidateModel;
%unignore operations_research::sat::CpSatHelper::VariableDomain; // unused

%feature("director") operations_research::sat::LogCallback; // unused
%unignore operations_research::sat::LogCallback;
%unignore operations_research::sat::LogCallback::~LogCallback;
%unignore operations_research::sat::LogCallback::NewMessage;

%feature("director") operations_research::sat::SolutionCallback;
%unignore operations_research::sat::SolutionCallback;
%unignore operations_research::sat::SolutionCallback::~SolutionCallback;
%unignore operations_research::sat::SolutionCallback::BestObjectiveBound;
%feature("director") operations_research::sat::LogCallback;
%unignore operations_research::sat::LogCallback;
%unignore operations_research::sat::LogCallback::~LogCallback;
%unignore operations_research::sat::LogCallback::NewMessage;
%feature("nodirector") operations_research::sat::SolutionCallback::BestObjectiveBound;
%unignore operations_research::sat::SolutionCallback::HasResponse;
%feature("nodirector") operations_research::sat::SolutionCallback::HasResponse;
Expand Down
92 changes: 46 additions & 46 deletions internal/wrapper.cc

Large diffs are not rendered by default.

65 changes: 23 additions & 42 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,41 +185,32 @@ func (m *Model) String() string {
// for all the variables/literals that were instantiated into it. It returns the
// optimal result if an objective function is declared. If not, it returns
// the first found result that satisfies the model.
func (m *Model) Solve() Result {
wrapper := internal.NewSolveWrapper()
defer func() {
internal.DeleteSolveWrapper(wrapper)
}()
//
// The solve process itself can be configured with various options.
func (m *Model) Solve(os ...Option) Result {
solver := internal.NewSolveWrapper()
defer func() { internal.DeleteSolveWrapper(solver) }()

var opts options
for _, o := range os {
o(&opts, solver)
}
if ok, err := opts.validate(); !ok {
panic(err)
}
if opts.solution != nil {
defer func() { internal.DeleteDirectorSolutionCallback(opts.solution.hook) }()
}

resp := wrapper.Solve(*m.pb)
return Result{pb: &resp}
}
solver.SetParameters(opts.params)
resp := solver.Solve(*m.pb)

// SolveAll returns all valid results that satisfy the model.
func (m *Model) SolveAll() []Result {
var results []Result
cb := &solutionCallback{
cb: func(r Result) {
results = append(results, r)
},
if opts.logger != nil {
for _, line := range strings.Split(resp.SolveLog, "\n") {
opts.logger.Print(line)
}
}
cb.director = internal.NewDirectorSolutionCallback(cb)
defer func() {
internal.DeleteDirectorSolutionCallback(cb.director)
}()

enumerate := true
params := pb.SatParameters{EnumerateAllSolutions: &enumerate}

wrapper := internal.NewSolveWrapper()
defer func() {
internal.DeleteSolveWrapper(wrapper)
}()

wrapper.AddSolutionCallback(cb.director)
wrapper.SetParameters(params)
wrapper.Solve(*m.pb)
return results
return Result{pb: &resp}
}

func (m *Model) name() string {
Expand Down Expand Up @@ -250,13 +241,3 @@ func (m *Model) toObjectiveProto(e LinearExpr) *pb.CpObjectiveProto {
Offset: float64(e.offset()),
}
}

type solutionCallback struct {
cb func(Result)
director internal.SolutionCallback
}

func (p *solutionCallback) OnSolutionCallback() {
proto := p.director.Response()
p.cb(Result{pb: &proto})
}
103 changes: 103 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2021 Irfan Sharif.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.

package solver

import (
"fmt"
"io"
"log"
"time"

"github.com/irfansharif/solver/internal"
"github.com/irfansharif/solver/internal/pb"
)

type Option func(o *options, s internal.SolveWrapper)

type options struct {
params pb.SatParameters
logger *log.Logger
solution *solutionCallback
}

func (o *options) validate() (bool, error) {
if o.params.GetEnumerateAllSolutions() && o.params.GetNumSearchWorkers() > 1 {
return false, fmt.Errorf("cannot enumerate with parallelism > 1")
}
return true, nil
}

// WithTimeout configures the solver with a hard time limit.
func WithTimeout(d time.Duration) Option {
return func(o *options, _ internal.SolveWrapper) {
seconds := d.Seconds()
o.params.MaxTimeInSeconds = &seconds
}
}

// WithLogger configures the solver to route its internal logging to the given
// io.Writer, using the given prefix.
func WithLogger(w io.Writer, prefix string) Option {
return func(o *options, s internal.SolveWrapper) {
logSearchProgress, logToResponse, logToStdout := true, true, false
o.params.LogSearchProgress = &logSearchProgress
o.params.LogToStdout = &logToStdout
o.params.LogToResponse = &logToResponse

// TODO(irfansharif): Right now we're simply logging to the response
// proto, which isn't being streamed during the search process and not
// super. OR-Tools v9.0 does support an experimental logger callback
// (looks identical to the solution callback), but that didn't work.
//
// Worth checking back on at some point.
// https://github.com/google/or-tools/issues/1903
o.logger = log.New(w, prefix, 0)
}
}

// WithParallelism configures the solver to use the given number of parallel
// workers during search. If the number provided is <= 1, there will be no
// parallelism.
func WithParallelism(parallelism int) Option {
return func(options *options, w internal.SolveWrapper) {
threads := int32(parallelism)
options.params.NumSearchWorkers = &threads
}
}

// WithEnumeration configures the solver to enumerate over all solutions without
// objective. This option is incompatible with a parallelism greater than 1.
func WithEnumeration(f func(Result)) Option {
return func(o *options, s internal.SolveWrapper) {
enumerate := true
o.params.EnumerateAllSolutions = &enumerate

o.solution = &solutionCallback{f: f}
o.solution.hook = internal.NewDirectorSolutionCallback(o.solution)
s.AddSolutionCallback(o.solution.hook)
}
}

// solutionCallback is used to hook into the underlying solver during its search
// process. It's invoked whenever a solution is found.
type solutionCallback struct {
f func(Result)
hook internal.SolutionCallback
}

func (p *solutionCallback) OnSolutionCallback() {
proto := p.hook.Response()
p.f(Result{pb: &proto})
}
32 changes: 30 additions & 2 deletions solver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package solver
import (
"fmt"
"math"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -272,13 +274,16 @@ func TestElement(t *testing.T) {
require.True(t, result.Value(target) == 10*result.Value(index))
}

func TestIterateThroughSolutions(t *testing.T) {
func TestEnumerateSolutions(t *testing.T) {
model := NewModel("")

var numVals int64 = 3
_ = model.NewIntVar(1, numVals, "x")

results := model.SolveAll()
var results []Result
_ = model.Solve(
WithEnumeration(func(r Result) { results = append(results, r) }),
)
require.Len(t, results, int(numVals))
}

Expand Down Expand Up @@ -429,3 +434,26 @@ func TestNonOverlappingIntervalsWithEnforcement(t *testing.T) {
}
}
}

func TestSolverOptions(t *testing.T) {
model := NewModel("")

A := model.NewLiteral("A")
B := model.NewLiteral("B")
C := model.NewLiteral("C")

model.AddConstraints(NewAllSameConstraint(A, B, C))

t.Log(model.String())
result := model.Solve(
WithLogger(os.Stdout, "[solver] "),
WithParallelism(4),
WithTimeout(time.Second),
)
require.True(t, result.Optimal(), "expected solver to find solution")

{
A, B, C := result.BooleanValue(A), result.BooleanValue(B), result.BooleanValue(C)
require.True(t, A == B && B == C)
}
}

0 comments on commit 263d25e

Please sign in to comment.