forked from toolshed/abra
		
	
		
			
				
	
	
		
			223 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			223 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package packet
 | |
| 
 | |
// This file implements the pushdown automaton (PDA) from PGPainless (Paul Schaub)
// to verify OpenPGP packet sequences. See Paul's blog post for more details:
// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/
 | |
| import (
 | |
| 	"fmt"
 | |
| 
 | |
| 	"github.com/ProtonMail/go-crypto/openpgp/errors"
 | |
| )
 | |
| 
 | |
| func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage {
 | |
| 	return errors.ErrMalformedMessage(fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol))
 | |
| }
 | |
| 
 | |
// InputSymbol is the input alphabet of the PDA. SequenceVerifier.Next
// consumes exactly one InputSymbol per call.
type InputSymbol uint8

// Input symbols. Values are assigned consecutively from zero, so the
// declaration order below is significant.
const (
	// LDSymbol moves the OpenPGPMessage state to LiteralMessage.
	LDSymbol InputSymbol = iota
	// SigSymbol pushes another MsgStackSymbol in OpenPGPMessage, and consumes
	// a pending OpsStackSymbol in the Literal/Compressed/Encrypted states.
	SigSymbol
	// OPSSymbol pushes an OpsStackSymbol underneath a new MsgStackSymbol.
	OPSSymbol
	// CompSymbol moves the OpenPGPMessage state to CompressedMessage.
	CompSymbol
	// ESKSymbol enters (or stays in) the ESKMessage state.
	ESKSymbol
	// EncSymbol moves to the EncryptedMessage state.
	EncSymbol
	// EOSSymbol reaches ValidMessage when EndStackSymbol is on top of the stack.
	EOSSymbol
	// UnknownSymbol is rejected by every transition function.
	UnknownSymbol
)
 | |
| 
 | |
// StackSymbol is the stack alphabet of the PDA.
type StackSymbol int8

// Stack symbols. Values are assigned consecutively from zero, so the
// declaration order below is significant.
const (
	// MsgStackSymbol must be on top of the stack in the OpenPGPMessage state;
	// a new verifier starts with it on top.
	MsgStackSymbol StackSymbol = iota
	// OpsStackSymbol is pushed for an OPSSymbol and consumed by a later
	// SigSymbol.
	OpsStackSymbol
	// KeyStackSymbol must be on top of the stack in the ESKMessage state.
	KeyStackSymbol
	// EndStackSymbol sits at the bottom of a new verifier's stack; EOSSymbol
	// is only accepted while it is on top.
	EndStackSymbol
	// EmptyStackSymbol is what popStack returns when the stack is empty.
	EmptyStackSymbol
)
 | |
| 
 | |
// State is a state of the PDA.
type State int8

// PDA states. Values are assigned consecutively from zero, so the
// declaration order below is significant.
const (
	// OpenPGPMessage is the start state; it requires MsgStackSymbol on top.
	OpenPGPMessage State = iota
	// ESKMessage is entered on ESKSymbol and requires KeyStackSymbol on top.
	ESKMessage
	// LiteralMessage is entered on LDSymbol.
	LiteralMessage
	// CompressedMessage is entered on CompSymbol.
	CompressedMessage
	// EncryptedMessage is entered on EncSymbol.
	EncryptedMessage
	// ValidMessage is reached on EOSSymbol and rejects any further input.
	ValidMessage
)
 | |
| 
 | |
| // transition represents a state transition in the PDA
 | |
| type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error)
 | |
| 
 | |
| // SequenceVerifier is a pushdown automata to verify
 | |
| // PGP messages packet sequences according to rfc4880.
 | |
| type SequenceVerifier struct {
 | |
| 	stack []StackSymbol
 | |
| 	state State
 | |
| }
 | |
| 
 | |
| // Next performs a state transition with the given input symbol.
 | |
| // If the transition fails a ErrMalformedMessage is returned.
 | |
| func (sv *SequenceVerifier) Next(input InputSymbol) error {
 | |
| 	for {
 | |
| 		stackSymbol := sv.popStack()
 | |
| 		transitionFunc := getTransition(sv.state)
 | |
| 		nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		if redo {
 | |
| 			sv.pushStack(stackSymbol)
 | |
| 		}
 | |
| 		for _, newStackSymbol := range newStackSymbols {
 | |
| 			sv.pushStack(newStackSymbol)
 | |
| 		}
 | |
| 		sv.state = nextState
 | |
| 		if !redo {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Valid returns true if RDA is in a valid state.
 | |
| func (sv *SequenceVerifier) Valid() bool {
 | |
| 	return sv.state == ValidMessage && len(sv.stack) == 0
 | |
| }
 | |
| 
 | |
| func (sv *SequenceVerifier) AssertValid() error {
 | |
| 	if !sv.Valid() {
 | |
| 		return errors.ErrMalformedMessage("invalid message")
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func NewSequenceVerifier() *SequenceVerifier {
 | |
| 	return &SequenceVerifier{
 | |
| 		stack: []StackSymbol{EndStackSymbol, MsgStackSymbol},
 | |
| 		state: OpenPGPMessage,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (sv *SequenceVerifier) popStack() StackSymbol {
 | |
| 	if len(sv.stack) == 0 {
 | |
| 		return EmptyStackSymbol
 | |
| 	}
 | |
| 	elemIndex := len(sv.stack) - 1
 | |
| 	stackSymbol := sv.stack[elemIndex]
 | |
| 	sv.stack = sv.stack[:elemIndex]
 | |
| 	return stackSymbol
 | |
| }
 | |
| 
 | |
| func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) {
 | |
| 	sv.stack = append(sv.stack, stackSymbol)
 | |
| }
 | |
| 
 | |
| func getTransition(from State) transition {
 | |
| 	switch from {
 | |
| 	case OpenPGPMessage:
 | |
| 		return fromOpenPGPMessage
 | |
| 	case LiteralMessage:
 | |
| 		return fromLiteralMessage
 | |
| 	case CompressedMessage:
 | |
| 		return fromCompressedMessage
 | |
| 	case EncryptedMessage:
 | |
| 		return fromEncryptedMessage
 | |
| 	case ESKMessage:
 | |
| 		return fromESKMessage
 | |
| 	case ValidMessage:
 | |
| 		return fromValidMessage
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // fromOpenPGPMessage is the transition for the state OpenPGPMessage.
 | |
| func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	if stackSymbol != MsgStackSymbol {
 | |
| 		return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
 | |
| 	}
 | |
| 	switch input {
 | |
| 	case LDSymbol:
 | |
| 		return LiteralMessage, nil, false, nil
 | |
| 	case SigSymbol:
 | |
| 		return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil
 | |
| 	case OPSSymbol:
 | |
| 		return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil
 | |
| 	case CompSymbol:
 | |
| 		return CompressedMessage, nil, false, nil
 | |
| 	case ESKSymbol:
 | |
| 		return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
 | |
| 	case EncSymbol:
 | |
| 		return EncryptedMessage, nil, false, nil
 | |
| 	}
 | |
| 	return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
 | |
| }
 | |
| 
 | |
| // fromESKMessage is the transition for the state ESKMessage.
 | |
| func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	if stackSymbol != KeyStackSymbol {
 | |
| 		return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
 | |
| 	}
 | |
| 	switch input {
 | |
| 	case ESKSymbol:
 | |
| 		return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
 | |
| 	case EncSymbol:
 | |
| 		return EncryptedMessage, nil, false, nil
 | |
| 	}
 | |
| 	return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
 | |
| }
 | |
| 
 | |
| // fromLiteralMessage is the transition for the state LiteralMessage.
 | |
| func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	switch input {
 | |
| 	case SigSymbol:
 | |
| 		if stackSymbol == OpsStackSymbol {
 | |
| 			return LiteralMessage, nil, false, nil
 | |
| 		}
 | |
| 	case EOSSymbol:
 | |
| 		if stackSymbol == EndStackSymbol {
 | |
| 			return ValidMessage, nil, false, nil
 | |
| 		}
 | |
| 	}
 | |
| 	return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol)
 | |
| }
 | |
| 
 | |
| // fromLiteralMessage is the transition for the state CompressedMessage.
 | |
| func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	switch input {
 | |
| 	case SigSymbol:
 | |
| 		if stackSymbol == OpsStackSymbol {
 | |
| 			return CompressedMessage, nil, false, nil
 | |
| 		}
 | |
| 	case EOSSymbol:
 | |
| 		if stackSymbol == EndStackSymbol {
 | |
| 			return ValidMessage, nil, false, nil
 | |
| 		}
 | |
| 	}
 | |
| 	return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
 | |
| }
 | |
| 
 | |
| // fromEncryptedMessage is the transition for the state EncryptedMessage.
 | |
| func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	switch input {
 | |
| 	case SigSymbol:
 | |
| 		if stackSymbol == OpsStackSymbol {
 | |
| 			return EncryptedMessage, nil, false, nil
 | |
| 		}
 | |
| 	case EOSSymbol:
 | |
| 		if stackSymbol == EndStackSymbol {
 | |
| 			return ValidMessage, nil, false, nil
 | |
| 		}
 | |
| 	}
 | |
| 	return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
 | |
| }
 | |
| 
 | |
| // fromValidMessage is the transition for the state ValidMessage.
 | |
| func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
 | |
| 	return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol)
 | |
| }
 |