forked from toolshed/abra
		
	
		
			
				
	
	
		
			365 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			365 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2024 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| // Helper code for parsing a protocol buffer
 | |
| 
 | |
| package protolazy
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 
 | |
| 	"google.golang.org/protobuf/encoding/protowire"
 | |
| )
 | |
| 
 | |
// BufferReader reads an encoded protobuf held in Buf, tracking in Pos the
// index of the next byte to be consumed.
type BufferReader struct {
	Buf []byte // encoded bytes being read
	Pos int    // index into Buf of the next unread byte
}
 | |
| 
 | |
| // NewBufferReader creates a new BufferRead from a protobuf
 | |
| func NewBufferReader(buf []byte) BufferReader {
 | |
| 	return BufferReader{Buf: buf, Pos: 0}
 | |
| }
 | |
| 
 | |
// Sentinel errors reported by BufferReader methods.
var (
	errOutOfBounds = errors.New("protobuf decoding: out of bounds")
	errOverflow    = errors.New("proto: integer overflow") // varint wider than 64 (or, for DecodeVarint32, 32) bits
)
 | |
| 
 | |
| func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
 | |
| 	i := b.Pos
 | |
| 	l := len(b.Buf)
 | |
| 
 | |
| 	for shift := uint(0); shift < 64; shift += 7 {
 | |
| 		if i >= l {
 | |
| 			err = io.ErrUnexpectedEOF
 | |
| 			return
 | |
| 		}
 | |
| 		v := b.Buf[i]
 | |
| 		i++
 | |
| 		x |= (uint64(v) & 0x7F) << shift
 | |
| 		if v < 0x80 {
 | |
| 			b.Pos = i
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// The number is too large to represent in a 64-bit value.
 | |
| 	err = errOverflow
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // decodeVarint decodes a varint at the current position
 | |
| func (b *BufferReader) DecodeVarint() (x uint64, err error) {
 | |
| 	i := b.Pos
 | |
| 	buf := b.Buf
 | |
| 
 | |
| 	if i >= len(buf) {
 | |
| 		return 0, io.ErrUnexpectedEOF
 | |
| 	} else if buf[i] < 0x80 {
 | |
| 		b.Pos++
 | |
| 		return uint64(buf[i]), nil
 | |
| 	} else if len(buf)-i < 10 {
 | |
| 		return b.DecodeVarintSlow()
 | |
| 	}
 | |
| 
 | |
| 	var v uint64
 | |
| 	// we already checked the first byte
 | |
| 	x = uint64(buf[i]) & 127
 | |
| 	i++
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 7
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 14
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 21
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 28
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 35
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 42
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 49
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 56
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint64(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 63
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	return 0, errOverflow
 | |
| 
 | |
| done:
 | |
| 	b.Pos = i
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // decodeVarint32 decodes a varint32 at the current position
 | |
| func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
 | |
| 	i := b.Pos
 | |
| 	buf := b.Buf
 | |
| 
 | |
| 	if i >= len(buf) {
 | |
| 		return 0, io.ErrUnexpectedEOF
 | |
| 	} else if buf[i] < 0x80 {
 | |
| 		b.Pos++
 | |
| 		return uint32(buf[i]), nil
 | |
| 	} else if len(buf)-i < 5 {
 | |
| 		v, err := b.DecodeVarintSlow()
 | |
| 		return uint32(v), err
 | |
| 	}
 | |
| 
 | |
| 	var v uint32
 | |
| 	// we already checked the first byte
 | |
| 	x = uint32(buf[i]) & 127
 | |
| 	i++
 | |
| 
 | |
| 	v = uint32(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 7
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint32(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 14
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint32(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 21
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	v = uint32(buf[i])
 | |
| 	i++
 | |
| 	x |= (v & 127) << 28
 | |
| 	if v < 128 {
 | |
| 		goto done
 | |
| 	}
 | |
| 
 | |
| 	return 0, errOverflow
 | |
| 
 | |
| done:
 | |
| 	b.Pos = i
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // skipValue skips a value in the protobuf, based on the specified tag
 | |
| func (b *BufferReader) SkipValue(tag uint32) (err error) {
 | |
| 	wireType := tag & 0x7
 | |
| 	switch protowire.Type(wireType) {
 | |
| 	case protowire.VarintType:
 | |
| 		err = b.SkipVarint()
 | |
| 	case protowire.Fixed64Type:
 | |
| 		err = b.SkipFixed64()
 | |
| 	case protowire.BytesType:
 | |
| 		var n uint32
 | |
| 		n, err = b.DecodeVarint32()
 | |
| 		if err == nil {
 | |
| 			err = b.Skip(int(n))
 | |
| 		}
 | |
| 	case protowire.StartGroupType:
 | |
| 		err = b.SkipGroup(tag)
 | |
| 	case protowire.Fixed32Type:
 | |
| 		err = b.SkipFixed32()
 | |
| 	default:
 | |
| 		err = fmt.Errorf("Unexpected wire type (%d)", wireType)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // skipGroup skips a group with the specified tag.  It executes efficiently using a tag stack
 | |
| func (b *BufferReader) SkipGroup(tag uint32) (err error) {
 | |
| 	tagStack := make([]uint32, 0, 16)
 | |
| 	tagStack = append(tagStack, tag)
 | |
| 	var n uint32
 | |
| 	for len(tagStack) > 0 {
 | |
| 		tag, err = b.DecodeVarint32()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		switch protowire.Type(tag & 0x7) {
 | |
| 		case protowire.VarintType:
 | |
| 			err = b.SkipVarint()
 | |
| 		case protowire.Fixed64Type:
 | |
| 			err = b.Skip(8)
 | |
| 		case protowire.BytesType:
 | |
| 			n, err = b.DecodeVarint32()
 | |
| 			if err == nil {
 | |
| 				err = b.Skip(int(n))
 | |
| 			}
 | |
| 		case protowire.StartGroupType:
 | |
| 			tagStack = append(tagStack, tag)
 | |
| 		case protowire.Fixed32Type:
 | |
| 			err = b.SkipFixed32()
 | |
| 		case protowire.EndGroupType:
 | |
| 			if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
 | |
| 				tagStack = tagStack[:len(tagStack)-1]
 | |
| 			} else {
 | |
| 				err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
 | |
| 					protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
 | |
| 			}
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // skipVarint effiently skips a varint
 | |
| func (b *BufferReader) SkipVarint() (err error) {
 | |
| 	i := b.Pos
 | |
| 
 | |
| 	if len(b.Buf)-i < 10 {
 | |
| 		// Use DecodeVarintSlow() to check for buffer overflow, but ignore result
 | |
| 		if _, err := b.DecodeVarintSlow(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	i++
 | |
| 
 | |
| 	if b.Buf[i] < 0x80 {
 | |
| 		goto out
 | |
| 	}
 | |
| 	return errOverflow
 | |
| 
 | |
| out:
 | |
| 	b.Pos = i + 1
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // skip skips the specified number of bytes
 | |
| func (b *BufferReader) Skip(n int) (err error) {
 | |
| 	if len(b.Buf) < b.Pos+n {
 | |
| 		return io.ErrUnexpectedEOF
 | |
| 	}
 | |
| 	b.Pos += n
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // skipFixed64 skips a fixed64
 | |
| func (b *BufferReader) SkipFixed64() (err error) {
 | |
| 	return b.Skip(8)
 | |
| }
 | |
| 
 | |
| // skipFixed32 skips a fixed32
 | |
| func (b *BufferReader) SkipFixed32() (err error) {
 | |
| 	return b.Skip(4)
 | |
| }
 | |
| 
 | |
| // skipBytes skips a set of bytes
 | |
| func (b *BufferReader) SkipBytes() (err error) {
 | |
| 	n, err := b.DecodeVarint32()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return b.Skip(int(n))
 | |
| }
 | |
| 
 | |
| // Done returns whether we are at the end of the protobuf
 | |
| func (b *BufferReader) Done() bool {
 | |
| 	return b.Pos == len(b.Buf)
 | |
| }
 | |
| 
 | |
| // Remaining returns how many bytes remain
 | |
| func (b *BufferReader) Remaining() int {
 | |
| 	return len(b.Buf) - b.Pos
 | |
| }
 |