Codebase list golang-github-denisenkom-go-mssqldb / HEAD buf.go
HEAD

Tree @HEAD (Download .tar.gz)

buf.go @HEADraw · history · blame

package mssql

import (
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"io"
	"net"
)

// packetType is the one-byte packet type code carried in the first byte
// of every packet header.
type packetType uint8

// header is the fixed 8-byte prefix of every packet. readNextPacket decodes
// it from the transport in big-endian byte order, field by field in the
// order declared here:
//
//   - PacketType: packet type code (byte 0).
//   - Status: zero while more packets follow; any non-zero value marks the
//     last packet of a message (byte 1).
//   - Size: total packet length in bytes, header included (bytes 2-3).
//   - Spid: written as zero by BeginPacket (bytes 4-5).
//   - PacketNo: packet sequence number; starts at 1 and is incremented
//     after each flushed packet (byte 6).
//   - Pad: written as zero by BeginPacket, where it is labelled "window"
//     (byte 7).
type header struct {
	PacketType packetType
	Status     uint8
	Size       uint16
	Spid       uint16
	PacketNo   uint8
	Pad        uint8
}

// tdsBuffer reads and writes TDS packets of data to the transport.
// The write and read buffers are separate to make sending attn signals
// possible without locks. Currently attn signals are only sent during
// reads, not writes.
type tdsBuffer struct {
	// transport is the underlying stream that packets are written to and
	// read from.
	transport io.ReadWriteCloser

	// Write fields. wbuf holds the outgoing packet (8-byte header followed
	// by payload) and wpos is the offset in wbuf where the next payload
	// byte is stored; at flush time wpos is also the packet's total size.
	wbuf []byte
	wpos uint16

	// Read fields. rbuf holds the most recently received packet, rpos is
	// the offset of the next unread byte and rsize is the packet size taken
	// from its header, so rpos == rsize means the packet is exhausted.
	// final is set when the packet header's status is non-zero, and
	// packet_type is the type code copied from that header.
	rbuf        []byte
	rpos        uint16
	rsize       uint16
	final       bool
	packet_type packetType

	// afterFirst is assigned to right after tdsBuffer is created and
	// before the first use. It is executed after the first packet is
	// written and then removed.
	afterFirst func()
}

// newTdsBuffer returns a tdsBuffer bound to transport whose read and write
// buffers are each bufsize bytes long.
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
	return &tdsBuffer{
		transport: transport,
		wbuf:      make([]byte, bufsize),
		rbuf:      make([]byte, bufsize),
		// The read position starts just past the 8-byte packet header;
		// the write position is left at its zero value.
		rpos: 8,
	}
}

// checkBadConn maps transport-level failures to driver.ErrBadConn.
// io.EOF and any error implementing net.Error are replaced with
// driver.ErrBadConn; every other value, including nil, is returned
// unchanged.
func checkBadConn(err error) error {
	if err == io.EOF {
		return driver.ErrBadConn
	}
	if _, isNetErr := err.(net.Error); isNetErr {
		return driver.ErrBadConn
	}
	return err
}

// ResizeBuffer makes both the read and the write buffer exactly packetsizei
// bytes long. A buffer that already has that length is left alone;
// otherwise a new buffer is allocated and as much of the old contents as
// fits is copied into it.
func (rw *tdsBuffer) ResizeBuffer(packetsizei int) {
	resized := func(old []byte) []byte {
		if len(old) == packetsizei {
			return old
		}
		fresh := make([]byte, packetsizei)
		copy(fresh, old)
		return fresh
	}
	rw.rbuf = resized(rw.rbuf)
	rw.wbuf = resized(rw.wbuf)
}

// PackageSize reports the length of the write buffer in bytes.
func (w *tdsBuffer) PackageSize() uint32 {
	size := len(w.wbuf)
	return uint32(size)
}

// flush sends the packet currently held in the write buffer to the
// transport and prepares the buffer for the next packet of the same
// message. It returns the transport's write error, if any; in that case
// the buffer state is left as it was.
func (w *tdsBuffer) flush() (err error) {
	// Record the total packet length (header included) in header bytes 2-3.
	binary.BigEndian.PutUint16(w.wbuf[2:], w.wpos)

	// Hand the finished packet to the transport.
	_, err = w.transport.Write(w.wbuf[:w.wpos])
	if err != nil {
		return err
	}

	// The afterFirst hook runs once, after the first successful write.
	if hook := w.afterFirst; hook != nil {
		hook()
		w.afterFirst = nil
	}

	// Advance the packet number and restart the payload right after the
	// 8-byte header.
	w.wbuf[6]++
	w.wpos = 8
	return nil
}

// Write appends p to the current packet. Whenever the write buffer fills
// up, the packet is flushed to the transport and the remainder of p goes
// into the next packet. It returns the number of bytes taken from p and
// the first flush error encountered, if any.
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
	for {
		n := copy(w.wbuf[w.wpos:], p)
		w.wpos += uint16(n)
		total += n
		p = p[n:]
		if len(p) == 0 {
			return total, nil
		}
		// The buffer is full but data remains: send this packet and
		// continue filling the next one.
		if err = w.flush(); err != nil {
			return total, err
		}
	}
}

// WriteByte appends a single byte to the current packet, flushing the
// packet first when the write buffer is already full. It returns the
// flush error, if any, in which case b is not stored.
func (w *tdsBuffer) WriteByte(b byte) error {
	if full := len(w.wbuf) == int(w.wpos); full {
		err := w.flush()
		if err != nil {
			return err
		}
	}
	w.wbuf[w.wpos] = b
	w.wpos++
	return nil
}

// BeginPacket starts a new outgoing message of the given type by
// initialising the header bytes in the write buffer and positioning the
// write cursor just after the 8-byte header. The size bytes (2-3) are not
// touched here; flush fills them in.
func (w *tdsBuffer) BeginPacket(packet_type packetType) {
	const (
		statusIncomplete = 0 // more packets follow
		firstPacketID    = 1
	)
	w.wbuf[0] = byte(packet_type)
	w.wbuf[1] = statusIncomplete
	w.wbuf[4], w.wbuf[5] = 0, 0 // spid
	w.wbuf[6] = firstPacketID
	w.wbuf[7] = 0 // window
	w.wpos = 8
}

// FinishPacket marks the packet in the write buffer as the last one of
// the message and flushes it, returning the flush error, if any.
func (w *tdsBuffer) FinishPacket() error {
	const statusLast = 1 // header status: this is the last packet
	w.wbuf[1] = statusLast
	return w.flush()
}

// readNextPacket reads one complete packet from the transport into the
// read buffer and updates the read state (position, size, final flag and
// packet type) from its header. Transport errors are passed through
// checkBadConn. A header whose size exceeds the read buffer or is smaller
// than the header itself is rejected; the read state is left unchanged on
// any error.
func (r *tdsBuffer) readNextPacket() error {
	var hdr header
	if err := binary.Read(r.transport, binary.BigEndian, &hdr); err != nil {
		return checkBadConn(err)
	}

	hdrLen := uint16(binary.Size(hdr))
	switch {
	case int(hdr.Size) > len(r.rbuf):
		return errors.New("Invalid packet size, it is longer than buffer size")
	case int(hdr.Size) < int(hdrLen):
		return errors.New("Invalid packet size, it is shorter than header size")
	}

	// The payload is stored at the same offsets it occupies in the packet,
	// i.e. directly after the space reserved for the header.
	if _, err := io.ReadFull(r.transport, r.rbuf[hdrLen:hdr.Size]); err != nil {
		return checkBadConn(err)
	}

	r.packet_type = hdr.PacketType
	r.final = hdr.Status != 0
	r.rsize = hdr.Size
	r.rpos = hdrLen
	return nil
}

// BeginRead reads the next packet from the transport and returns its
// packet type. On failure it returns a zero packet type together with the
// error from readNextPacket.
func (r *tdsBuffer) BeginRead() (packetType, error) {
	if err := r.readNextPacket(); err != nil {
		return 0, err
	}
	return r.packet_type, nil
}

// ReadByte returns the next payload byte of the current message, reading
// further packets from the transport when the current one is exhausted.
// It returns io.EOF once the final packet has been fully consumed, or the
// error reported by readNextPacket.
func (r *tdsBuffer) ReadByte() (res byte, err error) {
	// Loop instead of testing once: readNextPacket accepts a packet that
	// consists of a header only, which leaves rpos == rsize immediately
	// after it is loaded. Indexing rbuf in that state would hand back a
	// stale byte from an earlier packet, or panic when the buffer is no
	// larger than the header.
	for r.rpos == r.rsize {
		if r.final {
			return 0, io.EOF
		}
		if err = r.readNextPacket(); err != nil {
			return 0, err
		}
	}
	res = r.rbuf[r.rpos]
	r.rpos++
	return res, nil
}

// byte returns the next payload byte. Any error from ReadByte is handed
// to badStreamPanic (presumably panics; defined elsewhere in the package).
func (r *tdsBuffer) byte() byte {
	value, err := r.ReadByte()
	if err == nil {
		return value
	}
	badStreamPanic(err)
	return value
}

// ReadFull fills buf completely from the packet stream. If that fails,
// the error is passed through checkBadConn and then handed to
// badStreamPanic (presumably panics; defined elsewhere in the package).
func (r *tdsBuffer) ReadFull(buf []byte) {
	if _, err := io.ReadFull(r, buf); err != nil {
		badStreamPanic(checkBadConn(err))
	}
}

// uint64 reads eight bytes from the stream and decodes them as a
// little-endian unsigned 64-bit integer.
func (r *tdsBuffer) uint64() uint64 {
	var raw [8]byte
	data := raw[:]
	r.ReadFull(data)
	return binary.LittleEndian.Uint64(data)
}

// int32 reads a little-endian 32-bit value from the stream and returns
// it reinterpreted as a signed integer.
func (r *tdsBuffer) int32() int32 {
	unsigned := r.uint32()
	return int32(unsigned)
}

// uint32 reads four bytes from the stream and decodes them as a
// little-endian unsigned 32-bit integer.
func (r *tdsBuffer) uint32() uint32 {
	var raw [4]byte
	data := raw[:]
	r.ReadFull(data)
	return binary.LittleEndian.Uint32(data)
}

// uint16 reads two bytes from the stream and decodes them as a
// little-endian unsigned 16-bit integer.
func (r *tdsBuffer) uint16() uint16 {
	var raw [2]byte
	data := raw[:]
	r.ReadFull(data)
	return binary.LittleEndian.Uint16(data)
}

// BVarChar reads a string prefixed by a one-byte character count; the
// characters themselves are read and decoded by readUcs2.
func (r *tdsBuffer) BVarChar() string {
	return r.readUcs2(int(r.byte()))
}

// UsVarChar reads a string prefixed by a two-byte little-endian character
// count; the characters themselves are read and decoded by readUcs2.
func (r *tdsBuffer) UsVarChar() string {
	return r.readUcs2(int(r.uint16()))
}

// readUcs2 reads numchars two-byte characters from the stream and converts
// them to a string with ucs22str. A conversion error is handed to
// badStreamPanic (presumably panics; defined elsewhere in the package).
func (r *tdsBuffer) readUcs2(numchars int) string {
	raw := make([]byte, 2*numchars)
	r.ReadFull(raw)
	text, err := ucs22str(raw)
	if err == nil {
		return text
	}
	badStreamPanic(err)
	return text
}

// Read copies payload bytes of the current message into buf, reading
// further packets from the transport when the current one is exhausted.
// A single call never crosses a packet boundary, so it may return fewer
// bytes than len(buf). It returns io.EOF once the final packet has been
// fully consumed, or the error reported by readNextPacket.
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
	// Loop instead of testing once: a header-only packet leaves
	// rpos == rsize right after it is loaded. Skipping such packets here
	// avoids returning (0, nil), which the io.Reader contract discourages,
	// and reports io.EOF right away when the empty packet is the final one.
	for r.rpos == r.rsize {
		if r.final {
			return 0, io.EOF
		}
		if err = r.readNextPacket(); err != nil {
			return 0, err
		}
	}
	copied = copy(buf, r.rbuf[r.rpos:r.rsize])
	r.rpos += uint16(copied)
	return copied, nil
}