Skip to content

Support WebAssembly #280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
363 changes: 0 additions & 363 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,353 +537,6 @@ func (c *Client) Kill() {
c.l.Unlock()
}

// Start the underlying subprocess, communicating with it to negotiate
// a port for RPC connections, and returning the address to connect via RPC.
//
// This method is safe to call multiple times. Subsequent calls have no effect.
// Once a client has been started once, it cannot be started again, even if
// it was killed.
func (c *Client) Start() (addr net.Addr, err error) {
c.l.Lock()
defer c.l.Unlock()

if c.address != nil {
return c.address, nil
}

// If one of cmd or reattach isn't set, then it is an error. We wrap
// this in a {} for scoping reasons, and hopeful that the escape
// analysis will pop the stack here.
{
var mutuallyExclusiveOptions int
if c.config.Cmd != nil {
mutuallyExclusiveOptions += 1
}
if c.config.Reattach != nil {
mutuallyExclusiveOptions += 1
}
if c.config.RunnerFunc != nil {
mutuallyExclusiveOptions += 1
}
if mutuallyExclusiveOptions != 1 {
return nil, fmt.Errorf("exactly one of Cmd, or Reattach, or RunnerFunc must be set")
}

if c.config.SecureConfig != nil && c.config.Reattach != nil {
return nil, ErrSecureConfigAndReattach
}
}

if c.config.Reattach != nil {
return c.reattach()
}

if c.config.VersionedPlugins == nil {
c.config.VersionedPlugins = make(map[int]PluginSet)
}

// handle all plugins as versioned, using the handshake config as the default.
version := int(c.config.ProtocolVersion)

// Make sure we're not overwriting a real version 0. If ProtocolVersion was
// non-zero, then we have to just assume the user made sure that
// VersionedPlugins doesn't conflict.
if _, ok := c.config.VersionedPlugins[version]; !ok && c.config.Plugins != nil {
c.config.VersionedPlugins[version] = c.config.Plugins
}

var versionStrings []string
for v := range c.config.VersionedPlugins {
versionStrings = append(versionStrings, strconv.Itoa(v))
}

env := []string{
fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue),
fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort),
fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort),
fmt.Sprintf("PLUGIN_PROTOCOL_VERSIONS=%s", strings.Join(versionStrings, ",")),
}

cmd := c.config.Cmd
if cmd == nil {
// It's only possible to get here if RunnerFunc is non-nil, but we'll
// still use cmd as a spec to populate metadata for the external
// implementation to consume.
cmd = exec.Command("")
}
if !c.config.SkipHostEnv {
cmd.Env = append(cmd.Env, os.Environ()...)
}
cmd.Env = append(cmd.Env, env...)
cmd.Stdin = os.Stdin

if c.config.SecureConfig != nil {
if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil {
return nil, fmt.Errorf("error verifying checksum: %s", err)
} else if !ok {
return nil, ErrChecksumsDoNotMatch
}
}

// Setup a temporary certificate for client/server mtls, and send the public
// certificate to the plugin.
if c.config.AutoMTLS {
c.logger.Info("configuring client automatic mTLS")
certPEM, keyPEM, err := generateCert()
if err != nil {
c.logger.Error("failed to generate client certificate", "error", err)
return nil, err
}
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
c.logger.Error("failed to parse client certificate", "error", err)
return nil, err
}

cmd.Env = append(cmd.Env, fmt.Sprintf("PLUGIN_CLIENT_CERT=%s", certPEM))

c.config.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
ServerName: "localhost",
}
}

if c.config.UnixSocketConfig != nil {
c.unixSocketCfg = *c.config.UnixSocketConfig
}

if c.unixSocketCfg.Group != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketGroup, c.unixSocketCfg.Group))
}

var runner runner.Runner
switch {
case c.config.RunnerFunc != nil:
c.unixSocketCfg.socketDir, err = os.MkdirTemp(c.unixSocketCfg.TempDir, "plugin-dir")
if err != nil {
return nil, err
}
// os.MkdirTemp creates folders with 0o700, so if we have a group
// configured we need to make it group-writable.
if c.unixSocketCfg.Group != "" {
err = setGroupWritable(c.unixSocketCfg.socketDir, c.unixSocketCfg.Group, 0o770)
if err != nil {
return nil, err
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.socketDir))
c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.socketDir)

runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.socketDir)
if err != nil {
return nil, err
}
default:
runner, err = cmdrunner.NewCmdRunner(c.logger, cmd)
if err != nil {
return nil, err
}

}

c.runner = runner
startCtx, startCtxCancel := context.WithTimeout(context.Background(), c.config.StartTimeout)
defer startCtxCancel()
err = runner.Start(startCtx)
if err != nil {
return nil, err
}

// Make sure the command is properly cleaned up if there is an error
defer func() {
rErr := recover()

if err != nil || rErr != nil {
runner.Kill(context.Background())
}

if rErr != nil {
panic(rErr)
}
}()

// Create a context for when we kill
c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())

// Start goroutine that logs the stderr
c.clientWaitGroup.Add(1)
c.stderrWaitGroup.Add(1)
// logStderr calls Done()
go c.logStderr(runner.Name(), runner.Stderr())

c.clientWaitGroup.Add(1)
go func() {
// ensure the context is cancelled when we're done
defer c.ctxCancel()

defer c.clientWaitGroup.Done()

// wait to finish reading from stderr since the stderr pipe reader
// will be closed by the subsequent call to cmd.Wait().
c.stderrWaitGroup.Wait()

// Wait for the command to end.
err := runner.Wait(context.Background())
if err != nil {
c.logger.Error("plugin process exited", "plugin", runner.Name(), "id", runner.ID(), "error", err.Error())
} else {
// Log and make sure to flush the logs right away
c.logger.Info("plugin process exited", "plugin", runner.Name(), "id", runner.ID())
}

os.Stderr.Sync()

// Set that we exited, which takes a lock
c.l.Lock()
defer c.l.Unlock()
c.exited = true
}()

// Start a goroutine that is going to be reading the lines
// out of stdout
linesCh := make(chan string)
c.clientWaitGroup.Add(1)
go func() {
defer c.clientWaitGroup.Done()
defer close(linesCh)

scanner := bufio.NewScanner(runner.Stdout())
for scanner.Scan() {
linesCh <- scanner.Text()
}
if scanner.Err() != nil {
c.logger.Error("error encountered while scanning stdout", "error", scanner.Err())
}
}()

// Make sure after we exit we read the lines from stdout forever
// so they don't block since it is a pipe.
// The scanner goroutine above will close this, but track it with a wait
// group for completeness.
c.clientWaitGroup.Add(1)
defer func() {
go func() {
defer c.clientWaitGroup.Done()
for range linesCh {
}
}()
}()

// Some channels for the next step
timeout := time.After(c.config.StartTimeout)

// Start looking for the address
c.logger.Debug("waiting for RPC address", "plugin", runner.Name())
select {
case <-timeout:
err = errors.New("timeout while waiting for plugin to start")
case <-c.doneCtx.Done():
err = errors.New("plugin exited before we could connect")
case line, ok := <-linesCh:
// Trim the line and split by "|" in order to get the parts of
// the output.
line = strings.TrimSpace(line)
parts := strings.SplitN(line, "|", 6)
if len(parts) < 4 {
errText := fmt.Sprintf("Unrecognized remote plugin message: %s", line)
if !ok {
errText += "\n" + "Failed to read any lines from plugin's stdout"
}
additionalNotes := runner.Diagnose(context.Background())
if additionalNotes != "" {
errText += "\n" + additionalNotes
}
err = errors.New(errText)
return
}

// Check the core protocol. Wrapped in a {} for scoping.
{
var coreProtocol int
coreProtocol, err = strconv.Atoi(parts[0])
if err != nil {
err = fmt.Errorf("Error parsing core protocol version: %s", err)
return
}

if coreProtocol != CoreProtocolVersion {
err = fmt.Errorf("Incompatible core API version with plugin. "+
"Plugin version: %s, Core version: %d\n\n"+
"To fix this, the plugin usually only needs to be recompiled.\n"+
"Please report this to the plugin author.", parts[0], CoreProtocolVersion)
return
}
}

// Test the API version
version, pluginSet, err := c.checkProtoVersion(parts[1])
if err != nil {
return addr, err
}

// set the Plugins value to the compatible set, so the version
// doesn't need to be passed through to the ClientProtocol
// implementation.
c.config.Plugins = pluginSet
c.negotiatedVersion = version
c.logger.Debug("using plugin", "version", version)

network, address, err := runner.PluginToHost(parts[2], parts[3])
if err != nil {
return addr, err
}

switch network {
case "tcp":
addr, err = net.ResolveTCPAddr("tcp", address)
case "unix":
addr, err = net.ResolveUnixAddr("unix", address)
default:
err = fmt.Errorf("Unknown address type: %s", address)
}

// If we have a server type, then record that. We default to net/rpc
// for backwards compatibility.
c.protocol = ProtocolNetRPC
if len(parts) >= 5 {
c.protocol = Protocol(parts[4])
}

found := false
for _, p := range c.config.AllowedProtocols {
if p == c.protocol {
found = true
break
}
}
if !found {
err = fmt.Errorf("Unsupported plugin protocol %q. Supported: %v",
c.protocol, c.config.AllowedProtocols)
return addr, err
}

// See if we have a TLS certificate from the server.
// Checking if the length is > 50 rules out catching the unused "extra"
// data returned from some older implementations.
if len(parts) >= 6 && len(parts[5]) > 50 {
err := c.loadServerCert(parts[5])
if err != nil {
return nil, fmt.Errorf("error parsing server cert: %s", err)
}
}
}

c.address = addr
return
}

// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
// server, and load it as the RootCA and ClientCA for the client TLSConfig.
func (c *Client) loadServerCert(cert string) error {
Expand Down Expand Up @@ -1042,22 +695,6 @@ func (c *Client) Protocol() Protocol {
return c.protocol
}

func netAddrDialer(addr net.Addr) func(string, time.Duration) (net.Conn, error) {
return func(_ string, _ time.Duration) (net.Conn, error) {
// Connect to the client
conn, err := net.Dial(addr.Network(), addr.String())
if err != nil {
return nil, err
}
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
}

return conn, nil
}
}

// dialer is compatible with grpc.WithDialer and creates the connection
// to the plugin.
func (c *Client) dialer(_ string, timeout time.Duration) (net.Conn, error) {
Expand Down
Loading