diff --git a/solver/llbsolver/vertex.go b/solver/llbsolver/vertex.go index c2f659c5788e..c5e3870a9d43 100644 --- a/solver/llbsolver/vertex.go +++ b/solver/llbsolver/vertex.go @@ -155,9 +155,8 @@ func (dpc *detectPrunedCacheID) Load(op *pb.Op, md *pb.OpMetadata, opt *solver.V } func Load(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, opts ...LoadOpt) (solver.Edge, error) { - return loadLLB(ctx, def, polEngine, func(dgst digest.Digest, pbOp *pb.Op, load func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error) { - opMetadata := def.Metadata[string(dgst)] - vtx, err := newVertex(dgst, pbOp, opMetadata, load, opts...) + return loadLLB(ctx, def, polEngine, func(dgst digest.Digest, op *op, load func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error) { + vtx, err := newVertex(dgst, op.Op, op.Metadata, load, opts...) if err != nil { return nil, err } @@ -198,7 +197,7 @@ func newVertex(dgst digest.Digest, op *pb.Op, opMeta *pb.OpMetadata, load func(d return vtx, nil } -func recomputeDigests(ctx context.Context, all map[digest.Digest]*pb.Op, visited map[digest.Digest]digest.Digest, dgst digest.Digest) (digest.Digest, error) { +func recomputeDigests(ctx context.Context, all map[digest.Digest]*op, visited map[digest.Digest]digest.Digest, dgst digest.Digest) (digest.Digest, error) { if dgst, ok := visited[dgst]; ok { return dgst, nil } @@ -235,30 +234,38 @@ func recomputeDigests(ctx context.Context, all map[digest.Digest]*pb.Op, visited return newDgst, nil } +// op is a private wrapper around pb.Op that includes its metadata. +type op struct { + *pb.Op + Metadata *pb.OpMetadata +} + // loadLLB loads LLB. // fn is executed sequentially. -func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, fn func(digest.Digest, *pb.Op, func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error)) (solver.Edge, error) { +func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEvaluator, fn func(digest.Digest, *op, func(digest.Digest) (solver.Vertex, error)) (solver.Vertex, error)) (solver.Edge, error) { if len(def.Def) == 0 { return solver.Edge{}, errors.New("invalid empty definition") } - allOps := make(map[digest.Digest]*pb.Op) + allOps := make(map[digest.Digest]*op) var lastDgst digest.Digest for _, dt := range def.Def { - var op pb.Op - if err := op.UnmarshalVT(dt); err != nil { + var pbop pb.Op + if err := pbop.Unmarshal(dt); err != nil { return solver.Edge{}, errors.Wrap(err, "failed to parse llb proto op") } dgst := digest.FromBytes(dt) if polEngine != nil { - if _, err := polEngine.Evaluate(ctx, op.GetSource()); err != nil { + if _, err := polEngine.Evaluate(ctx, pbop.GetSource()); err != nil { return solver.Edge{}, errors.Wrap(err, "error evaluating the source policy") } } - - allOps[dgst] = &op + allOps[dgst] = &op{ + Op: &pbop, + Metadata: def.Metadata[string(dgst)], + } lastDgst = dgst } @@ -300,7 +307,7 @@ func loadLLB(ctx context.Context, def *pb.Definition, polEngine SourcePolicyEval return nil, errors.Errorf("invalid missing input digest %s", dgst) } - if err := opsutils.Validate(op); err != nil { + if err := opsutils.Validate(op.Op); err != nil { return nil, err } diff --git a/solver/llbsolver/vertex_test.go b/solver/llbsolver/vertex_test.go index 3123207221f2..7aa68c990f44 100644 --- a/solver/llbsolver/vertex_test.go +++ b/solver/llbsolver/vertex_test.go @@ -38,9 +38,9 @@ func TestRecomputeDigests(t *testing.T) { require.NoError(t, err) op2Digest := digest.FromBytes(op2Data) - all := map[digest.Digest]*pb.Op{ - newDigest: op1, - op2Digest: op2, + all := map[digest.Digest]*op{ + newDigest: {Op: op1}, + op2Digest: {Op: op2}, } visited := map[digest.Digest]digest.Digest{oldDigest: newDigest} @@ -48,10 +48,10 @@ func TestRecomputeDigests(t *testing.T) { require.NoError(t, err) require.Len(t, visited, 2) require.Len(t, all, 2) - assert.Equal(t, op1, all[newDigest]) + assert.Equal(t, op1, all[newDigest].Op) require.Equal(t, newDigest, visited[oldDigest]) - require.Equal(t, op1, all[newDigest]) - assert.Equal(t, op2, all[updated]) + require.Equal(t, op1, all[newDigest].Op) + assert.Equal(t, op2, all[updated].Op) require.Equal(t, newDigest, digest.Digest(op2.Inputs[0].Digest)) assert.NotEqual(t, op2Digest, updated) } @@ -88,14 +88,14 @@ func TestIngestDigest(t *testing.T) { // Read the definition from the test data and ensure it uses the // canonical digests after recompute. var lastDgst digest.Digest - all := map[digest.Digest]*pb.Op{} + all := map[digest.Digest]*op{} for _, in := range def.Def { - op := new(pb.Op) - err := op.Unmarshal(in) + opNew := new(pb.Op) + err := opNew.Unmarshal(in) require.NoError(t, err) lastDgst = digest.FromBytes(in) - all[lastDgst] = op + all[lastDgst] = &op{Op: opNew} } fmt.Println(all, lastDgst)