From 0e2fa9645e3ce18e5aeefef77d69dce5cbadd32f Mon Sep 17 00:00:00 2001 From: DeedleFake Date: Sun, 13 Oct 2019 15:48:16 -0400 Subject: [PATCH] multiple: use pointers instead of direct structs (#63) * proto: Use pointers to the message types. * p9: Fix API breakage. * p9: Missed one. --- client.go | 12 ++-- fs.go | 164 ++++++++++++++++++++++++------------------------- msg.go | 10 +-- proto/proto.go | 14 +++-- remote.go | 28 ++++----- 5 files changed, 117 insertions(+), 111 deletions(-) diff --git a/client.go b/client.go index 8d674af..7da4c37 100644 --- a/client.go +++ b/client.go @@ -50,7 +50,7 @@ func (c *Client) nextFID() uint32 { // allowed message size. A handshake must be performed before any // other request types may be sent. func (c *Client) Handshake(msize uint32) (uint32, error) { - rsp, err := c.Send(Tversion{ + rsp, err := c.Send(&Tversion{ Msize: msize, Version: Version, }) @@ -58,7 +58,7 @@ func (c *Client) Handshake(msize uint32) (uint32, error) { return 0, err } - version := rsp.(Rversion) + version := rsp.(*Rversion) if version.Version != Version { return 0, ErrUnsupportedVersion } @@ -73,7 +73,7 @@ func (c *Client) Handshake(msize uint32) (uint32, error) { func (c *Client) Auth(user, aname string) (*Remote, error) { fid := c.nextFID() - rsp, err := c.Send(Tauth{ + rsp, err := c.Send(&Tauth{ AFID: fid, Uname: user, Aname: aname, @@ -81,7 +81,7 @@ func (c *Client) Auth(user, aname string) (*Remote, error) { if err != nil { return nil, err } - rauth := rsp.(Rauth) + rauth := rsp.(*Rauth) return &Remote{ client: c, @@ -101,7 +101,7 @@ func (c *Client) Attach(afile *Remote, user, aname string) (*Remote, error) { afid = afile.fid } - rsp, err := c.Send(Tattach{ + rsp, err := c.Send(&Tattach{ FID: fid, AFID: afid, Uname: user, @@ -110,7 +110,7 @@ func (c *Client) Attach(afile *Remote, user, aname string) (*Remote, error) { if err != nil { return nil, err } - attach := rsp.(Rattach) + attach := rsp.(*Rattach) return &Remote{ client: c, diff --git a/fs.go b/fs.go index da51f1c..8af1ab0 100644 --- a/fs.go +++ b/fs.go @@ -221,9 +221,9 @@ func (h *fsHandler) largeCount(count uint32) bool { return IOHeaderSize+count > h.msize } -func (h *fsHandler) version(msg Tversion) interface{} { +func (h *fsHandler) version(msg *Tversion) interface{} { if msg.Version != Version { - return Rerror{ + return &Rerror{ Ename: ErrUnsupportedVersion.Error(), } } @@ -232,16 +232,16 @@ func (h *fsHandler) version(msg Tversion) interface{} { h.msize = msg.Msize } - return Rversion{ + return &Rversion{ Msize: h.msize, Version: Version, } } -func (h *fsHandler) auth(msg Tauth) interface{} { +func (h *fsHandler) auth(msg *Tauth) interface{} { file, err := h.fs.Auth(msg.Uname, msg.Aname) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -252,7 +252,7 @@ func (h *fsHandler) auth(msg Tauth) interface{} { f.file = file - return Rauth{ + return &Rauth{ AQID: QID{ Type: QTAuth, Path: h.getNextPath(), @@ -260,19 +260,19 @@ func (h *fsHandler) auth(msg Tauth) interface{} { } } -func (h *fsHandler) flush(msg Tflush) interface{} { +func (h *fsHandler) flush(msg *Tflush) interface{} { // TODO: Implement this. - return Rerror{ + return &Rerror{ Ename: "flush is not supported", } } -func (h *fsHandler) attach(msg Tattach) interface{} { +func (h *fsHandler) attach(msg *Tattach) interface{} { var afile File if msg.AFID != NoFID { tmp, ok := h.getFile(msg.AFID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "no such AFID", } } @@ -284,21 +284,21 @@ func (h *fsHandler) attach(msg Tattach) interface{} { attach, err := h.fs.Attach(afile, msg.Uname, msg.Aname) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } qid, err := h.getQID(msg.Aname, attach) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } file, ok := h.getFile(msg.FID, true) if ok { - return Rerror{ + return &Rerror{ Ename: "FID in use", } } @@ -308,15 +308,15 @@ func (h *fsHandler) attach(msg Tattach) interface{} { file.path = msg.Aname file.a = attach - return Rattach{ + return &Rattach{ QID: qid, } } -func (h *fsHandler) walk(msg Twalk) interface{} { +func (h *fsHandler) walk(msg *Twalk) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -333,12 +333,12 @@ func (h *fsHandler) walk(msg Twalk) interface{} { qid, err := h.getQID(next, a) if err != nil { if i == 0 { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } - return Rwalk{ + return &Rwalk{ WQID: qids, } } @@ -349,7 +349,7 @@ func (h *fsHandler) walk(msg Twalk) interface{} { file, ok = h.getFile(msg.NewFID, true) if ok { - return Rerror{ + return &Rerror{ Ename: "FID in use", } } @@ -359,15 +359,15 @@ func (h *fsHandler) walk(msg Twalk) interface{} { file.path = base file.a = a - return Rwalk{ + return &Rwalk{ WQID: qids, } } -func (h *fsHandler) open(msg Topen) interface{} { +func (h *fsHandler) open(msg *Topen) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -375,14 +375,14 @@ func (h *fsHandler) open(msg Topen) interface{} { defer file.Unlock() if file.file != nil { - return Rerror{ + return &Rerror{ Ename: "file already open", } } f, err := file.a.Open(file.path, msg.Mode) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -390,7 +390,7 @@ func (h *fsHandler) open(msg Topen) interface{} { qid, err := h.getQID(file.path, file.a) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -400,16 +400,16 @@ func (h *fsHandler) open(msg Topen) interface{} { iounit = unit.IOUnit() } - return Ropen{ + return &Ropen{ QID: qid, IOUnit: iounit, } } -func (h *fsHandler) create(msg Tcreate) interface{} { +func (h *fsHandler) create(msg *Tcreate) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -417,7 +417,7 @@ func (h *fsHandler) create(msg Tcreate) interface{} { defer file.Unlock() if file.file != nil { - return Rerror{ + return &Rerror{ Ename: "file already open", } } @@ -426,7 +426,7 @@ func (h *fsHandler) create(msg Tcreate) interface{} { f, err := file.a.Create(p, msg.Perm, msg.Mode) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -436,7 +436,7 @@ func (h *fsHandler) create(msg Tcreate) interface{} { qid, err := h.getQID(p, file.a) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -446,16 +446,16 @@ func (h *fsHandler) create(msg Tcreate) interface{} { iounit = unit.IOUnit() } - return Rcreate{ + return &Rcreate{ QID: qid, IOUnit: iounit, } } -func (h *fsHandler) read(msg Tread) interface{} { +func (h *fsHandler) read(msg *Tread) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -463,20 +463,20 @@ func (h *fsHandler) read(msg Tread) interface{} { defer file.Unlock() if file.file == nil { - return Rerror{ + return &Rerror{ Ename: "file not open", } } qid, err := h.getQID(file.path, file.a) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } if h.largeCount(msg.Count) { - return Rerror{ + return &Rerror{ Ename: "read too large", } } @@ -489,7 +489,7 @@ func (h *fsHandler) read(msg Tread) interface{} { if msg.Offset == 0 { dir, err := file.file.Readdir() if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -497,7 +497,7 @@ func (h *fsHandler) read(msg Tread) interface{} { for i := range dir { qid, err := h.getQID(path.Join(file.path, dir[i].EntryName), file.a) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -509,7 +509,7 @@ func (h *fsHandler) read(msg Tread) interface{} { file.dir.Reset() err = WriteDir(&file.dir, dir) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -524,7 +524,7 @@ func (h *fsHandler) read(msg Tread) interface{} { // issue. tmp, err := file.dir.Read(buf) if (err != nil) && (err != io.EOF) { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -533,22 +533,22 @@ func (h *fsHandler) read(msg Tread) interface{} { default: tmp, err := file.file.ReadAt(buf, int64(msg.Offset)) if (err != nil) && (err != io.EOF) { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } n = tmp } - return Rread{ + return &Rread{ Data: buf[:n], } } -func (h *fsHandler) write(msg Twrite) interface{} { +func (h *fsHandler) write(msg *Twrite) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -556,29 +556,29 @@ func (h *fsHandler) write(msg Twrite) interface{} { defer file.RUnlock() // full lock like read() does. if file.file == nil { - return Rerror{ + return &Rerror{ Ename: "file not open", } } n, err := file.file.WriteAt(msg.Data, int64(msg.Offset)) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } - return Rwrite{ + return &Rwrite{ Count: uint32(n), } } -func (h *fsHandler) clunk(msg Tclunk) interface{} { +func (h *fsHandler) clunk(msg *Tclunk) interface{} { defer h.fids.Delete(msg.FID) file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -586,50 +586,50 @@ func (h *fsHandler) clunk(msg Tclunk) interface{} { defer file.RUnlock() if file.file == nil { - return new(Rclunk) + return &Rclunk{} } err := file.file.Close() if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } - return new(Rclunk) + return &Rclunk{} } -func (h *fsHandler) remove(msg Tremove) interface{} { +func (h *fsHandler) remove(msg *Tremove) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } file.RLock() defer file.RUnlock() - rsp := h.clunk(Tclunk{ + rsp := h.clunk(&Tclunk{ FID: msg.FID, }) - if _, ok := rsp.(Rerror); ok { + if _, ok := rsp.(error); ok { return rsp } err := file.a.Remove(file.path) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } - return new(Rremove) + return &Rremove{} } -func (h *fsHandler) stat(msg Tstat) interface{} { +func (h *fsHandler) stat(msg *Tstat) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -638,29 +638,29 @@ func (h *fsHandler) stat(msg Tstat) interface{} { stat, err := file.a.Stat(file.path) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } qid, err := h.getQID(file.path, file.a) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } stat.Version = qid.Version stat.Path = qid.Path - return Rstat{ + return &Rstat{ Stat: stat.Stat(), } } -func (h *fsHandler) wstat(msg Twstat) interface{} { +func (h *fsHandler) wstat(msg *Twstat) interface{} { file, ok := h.getFile(msg.FID, false) if !ok { - return Rerror{ + return &Rerror{ Ename: "unknown FID", } } @@ -673,7 +673,7 @@ func (h *fsHandler) wstat(msg Twstat) interface{} { err := file.a.WriteStat(file.path, changes) if err != nil { - return Rerror{ + return &Rerror{ Ename: err.Error(), } } @@ -692,7 +692,7 @@ func (h *fsHandler) wstat(msg Twstat) interface{} { // file.path = next //} - return new(Rwstat) + return &Rwstat{} } func (h *fsHandler) HandleMessage(msg interface{}) (r interface{}) { @@ -703,47 +703,47 @@ func (h *fsHandler) HandleMessage(msg interface{}) (r interface{}) { debug.Log("%#v\n", msg) switch msg := msg.(type) { - case Tversion: + case *Tversion: return h.version(msg) - case Tauth: + case *Tauth: return h.auth(msg) - case Tflush: + case *Tflush: return h.flush(msg) - case Tattach: + case *Tattach: return h.attach(msg) - case Twalk: + case *Twalk: return h.walk(msg) - case Topen: + case *Topen: return h.open(msg) - case Tcreate: + case *Tcreate: return h.create(msg) - case Tread: + case *Tread: return h.read(msg) - case Twrite: + case *Twrite: return h.write(msg) - case Tclunk: + case *Tclunk: return h.clunk(msg) - case Tremove: + case *Tremove: return h.remove(msg) - case Tstat: + case *Tstat: return h.stat(msg) - case Twstat: + case *Twstat: return h.wstat(msg) default: - return Rerror{ + return &Rerror{ Ename: fmt.Sprintf("unexpected message type: %T", msg), } } diff --git a/msg.go b/msg.go index c74471c..4c1e334 100644 --- a/msg.go +++ b/msg.go @@ -81,14 +81,14 @@ type Tversion struct { Version string } -func (Tversion) P9NoTag() {} +func (*Tversion) P9NoTag() {} type Rversion struct { Msize uint32 Version string } -func (r Rversion) P9Msize() uint32 { +func (r *Rversion) P9Msize() uint32 { return r.Msize } @@ -119,7 +119,7 @@ type Rerror struct { Ename string } -func (msg Rerror) Error() string { +func (msg *Rerror) Error() string { return msg.Ename } @@ -204,7 +204,7 @@ type Rstat struct { Stat Stat } -func (stat Rstat) P9Encode() ([]byte, error) { +func (stat *Rstat) P9Encode() ([]byte, error) { var buf bytes.Buffer err := proto.Write(&buf, stat.Stat.size()+2) @@ -237,7 +237,7 @@ type Twstat struct { Stat Stat } -func (stat Twstat) P9Encode() ([]byte, error) { +func (stat *Twstat) P9Encode() ([]byte, error) { var buf bytes.Buffer err := proto.Write(&buf, stat.FID) diff --git a/proto/proto.go b/proto/proto.go index 17da10a..83e6c4b 100644 --- a/proto/proto.go +++ b/proto/proto.go @@ -42,13 +42,19 @@ type Proto struct { // NewProto builds a Proto from the given one-way mapping. func NewProto(mapping map[uint8]reflect.Type) Proto { + rmap := make(map[uint8]reflect.Type, len(mapping)) smap := make(map[reflect.Type]uint8, len(mapping)) - for k, v := range mapping { - smap[v] = k + for id, t := range mapping { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + rmap[id] = t + smap[t] = id } return Proto{ - rmap: mapping, + rmap: rmap, smap: smap, } } @@ -113,7 +119,7 @@ func (p Proto) Receive(r io.Reader, msize uint32) (msg interface{}, tag uint16, return nil, tag, util.Errorf("receive %v: %w", m.Type().Elem(), err) } - return m.Elem().Interface(), tag, err + return m.Interface(), tag, err } // Send writes a message to w with the given tag. It returns any diff --git a/remote.go b/remote.go index 1fcc39c..e7c91e4 100644 --- a/remote.go +++ b/remote.go @@ -38,7 +38,7 @@ func (file *Remote) walk(p string) (*Remote, error) { if w[0] != "/" { w = strings.Split(w[0], "/") } - rsp, err := file.client.Send(Twalk{ + rsp, err := file.client.Send(&Twalk{ FID: file.fid, NewFID: fid, Wname: w, @@ -46,7 +46,7 @@ func (file *Remote) walk(p string) (*Remote, error) { if err != nil { return nil, err } - walk := rsp.(Rwalk) + walk := rsp.(*Rwalk) qid := walk.WQID[len(walk.WQID)-1] if len(walk.WQID) != len(w) { @@ -76,14 +76,14 @@ func (file *Remote) Open(p string, mode uint8) (*Remote, error) { return nil, err } - rsp, err := file.client.Send(Topen{ + rsp, err := file.client.Send(&Topen{ FID: next.fid, Mode: mode, }) if err != nil { return nil, err } - open := rsp.(Ropen) + open := rsp.(*Ropen) next.qid = open.QID @@ -99,7 +99,7 @@ func (file *Remote) Create(p string, perm FileMode, mode uint8) (*Remote, error) return nil, err } - rsp, err := file.client.Send(Tcreate{ + rsp, err := file.client.Send(&Tcreate{ FID: next.fid, Name: name, Perm: perm, @@ -108,7 +108,7 @@ func (file *Remote) Create(p string, perm FileMode, mode uint8) (*Remote, error) if err != nil { return nil, err } - create := rsp.(Rcreate) + create := rsp.(*Rcreate) next.qid = create.QID @@ -119,7 +119,7 @@ func (file *Remote) Create(p string, perm FileMode, mode uint8) (*Remote, error) // "", it closes the current file, if open, and deletes it. func (file *Remote) Remove(p string) error { if p == "" { - _, err := file.client.Send(Tremove{ + _, err := file.client.Send(&Tremove{ FID: file.fid, }) return err @@ -192,7 +192,7 @@ func (file *Remote) maxBufSize() int { } func (file *Remote) readPart(buf []byte, off int64) (int, error) { - rsp, err := file.client.Send(Tread{ + rsp, err := file.client.Send(&Tread{ FID: file.fid, Offset: uint64(off), Count: uint32(len(buf)), @@ -200,7 +200,7 @@ func (file *Remote) readPart(buf []byte, off int64) (int, error) { if err != nil { return 0, err } - read := rsp.(Rread) + read := rsp.(*Rread) if len(read.Data) == 0 { return 0, io.EOF } @@ -252,7 +252,7 @@ func (file *Remote) Write(data []byte) (int, error) { } func (file *Remote) writePart(data []byte, off int64) (int, error) { - rsp, err := file.client.Send(Twrite{ + rsp, err := file.client.Send(&Twrite{ FID: file.fid, Offset: uint64(off), Data: data, @@ -260,7 +260,7 @@ func (file *Remote) writePart(data []byte, off int64) (int, error) { if err != nil { return 0, err } - write := rsp.(Rwrite) + write := rsp.(*Rwrite) if write.Count < uint32(len(data)) { return int(write.Count), io.EOF @@ -302,7 +302,7 @@ func (file *Remote) WriteAt(data []byte, off int64) (int, error) { // Close closes the file on the server. Further usage of the file will // produce errors. func (file *Remote) Close() error { - _, err := file.client.Send(Tclunk{ + _, err := file.client.Send(&Tclunk{ FID: file.fid, }) return err @@ -313,13 +313,13 @@ func (file *Remote) Close() error { // the current file. func (file *Remote) Stat(p string) (DirEntry, error) { if p == "" { - rsp, err := file.client.Send(Tstat{ + rsp, err := file.client.Send(&Tstat{ FID: file.fid, }) if err != nil { return DirEntry{}, err } - stat := rsp.(Rstat) + stat := rsp.(*Rstat) return stat.Stat.DirEntry(), nil }