It'll be reused by the IRC side of things too. Change-Id: I3d84f3fd5fca6a6d948f331143b14f096d10675d Reviewed-on: https://cl.tvl.fyi/c/depot/+/342 Reviewed-by: tazjin <mail@tazj.in>
		
			
				
	
	
		
			252 lines
		
	
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			252 lines
		
	
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Package gerrit implements a watcher for Gerrit events.
 | 
						|
package gerrit
 | 
						|
 | 
						|
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"code.tvl.fyi/fun/clbot/backoffutil"
	"code.tvl.fyi/fun/clbot/gerrit/gerritevents"
	log "github.com/golang/glog"
	"golang.org/x/crypto/ssh"
)
 | 
						|
 | 
						|
// closer provides an embeddable implementation of Close which awaits a main
// loop acknowledging it has stopped.
//
// The owning main loop is expected to watch stop and to close stopped when it
// exits. closers are embedded and copied by value; all copies share the same
// channels.
type closer struct {
	stop    chan struct{} // closed by Close to ask the main loop to exit
	stopped chan struct{} // closed by the main loop once it has exited
}

// newCloser returns a closer with both channels initialised.
func newCloser() closer {
	return closer{
		stop:    make(chan struct{}),
		stopped: make(chan struct{}),
	}
}

// stopMu serialises the check-then-close of closer.stop so that concurrent
// Close calls cannot both try to close the channel (which would panic). It is
// package-level rather than a closer field because closers are copied by value
// and every copy must coordinate on the same lock.
var stopMu sync.Mutex

// Close asks the main loop to stop, then waits until it has stopped or ctx is
// done, whichever happens first. It returns ctx.Err() if ctx finished first.
//
// Close may be called any number of times and from any goroutine: the stop
// signal is sent exactly once, and every call waits for the main loop. The
// stop signal is sent even if ctx is already done.
func (c *closer) Close(ctx context.Context) error {
	stopMu.Lock()
	select {
	case <-c.stop:
		// Already signalled by an earlier Close.
	default:
		close(c.stop)
	}
	stopMu.Unlock()

	// Prefer reporting success when the loop has already exited, even if ctx
	// is also done (a single select would choose between them at random).
	select {
	case <-c.stopped:
		return nil
	default:
	}

	select {
	case <-c.stopped:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
 | 
						|
 | 
						|
// lineWriter is an io.Writer that splits everything written to it on '\n'
// and sends each complete line, without its trailing newline, to out. Bytes
// after the last newline are kept in buf until a later Write completes them.
type lineWriter struct {
	buf string      // unterminated tail of the data written so far
	out chan string // receives each complete line
}

// Write appends p to the pending data and sends every newly completed line
// on w.out. It always reports len(p) bytes written and a nil error.
//
// Sends are unconditional, so if w.out is unbuffered or full, Write blocks
// until a receiver takes the line.
func (w *lineWriter) Write(p []byte) (n int, err error) {
	w.buf += string(p)
	for {
		nl := strings.IndexByte(w.buf, '\n')
		if nl < 0 {
			return len(p), nil
		}
		line := w.buf[:nl]
		w.buf = w.buf[nl+1:]
		w.out <- line
	}
}
 | 
						|
 | 
						|
// restartingClient is a simple SSH client that repeatedly connects to an SSH server, runs a command, and outputs the lines output by it on stdout onto a channel.
 | 
						|
type restartingClient struct {
 | 
						|
	closer
 | 
						|
 | 
						|
	network string
 | 
						|
	addr    string
 | 
						|
	cfg     *ssh.ClientConfig
 | 
						|
 | 
						|
	exec     string
 | 
						|
	output   chan string
 | 
						|
	shutdown func()
 | 
						|
}
 | 
						|
 | 
						|
// errStopConnect is returned by runOnce when the remote server sends a
// "goaway" request, i.e. asks the client not to reconnect.
var errStopConnect = errors.New("gerrit: told to stop reconnecting by remote server")
 | 
						|
 | 
						|
func (c *restartingClient) runOnce() error {
 | 
						|
	netConn, err := net.Dial(c.network, c.addr)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("connecting to %v/%v: %w", c.network, c.addr, err)
 | 
						|
	}
 | 
						|
	defer netConn.Close()
 | 
						|
 | 
						|
	sshConn, newCh, newReq, err := ssh.NewClientConn(netConn, c.addr, c.cfg)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("creating SSH connection to %v/%v: %w", c.network, c.addr, err)
 | 
						|
	}
 | 
						|
	defer sshConn.Close()
 | 
						|
 | 
						|
	goAway := false
 | 
						|
	passedThroughReqs := make(chan *ssh.Request)
 | 
						|
	go func() {
 | 
						|
		defer close(passedThroughReqs)
 | 
						|
		for req := range newReq {
 | 
						|
			if req.Type == "goaway" {
 | 
						|
				goAway = true
 | 
						|
				log.Warningf("remote end %v/%v told me to go away!", c.network, c.addr)
 | 
						|
				sshConn.Close()
 | 
						|
				netConn.Close()
 | 
						|
			}
 | 
						|
			passedThroughReqs <- req
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	cl := ssh.NewClient(sshConn, newCh, passedThroughReqs)
 | 
						|
 | 
						|
	sess, err := cl.NewSession()
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("NewSession on %v/%v: %w", c.network, c.addr, err)
 | 
						|
	}
 | 
						|
	defer sess.Close()
 | 
						|
 | 
						|
	sess.Stdout = &lineWriter{out: c.output}
 | 
						|
 | 
						|
	if err := sess.Start(c.exec); err != nil {
 | 
						|
		return fmt.Errorf("Start(%q) on %v/%v: %w", c.exec, c.network, c.addr, err)
 | 
						|
	}
 | 
						|
 | 
						|
	log.Infof("connected to %v/%v", c.network, c.addr)
 | 
						|
 | 
						|
	done := make(chan struct{})
 | 
						|
	go func() {
 | 
						|
		sess.Wait()
 | 
						|
		close(done)
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		select {
 | 
						|
		case <-c.stop:
 | 
						|
			sess.Close()
 | 
						|
		case <-done:
 | 
						|
		}
 | 
						|
		return
 | 
						|
	}()
 | 
						|
	<-done
 | 
						|
 | 
						|
	if goAway {
 | 
						|
		return errStopConnect
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *restartingClient) run() {
 | 
						|
	defer close(c.stopped)
 | 
						|
	bo := backoffutil.NewDefaultBackOff()
 | 
						|
	for {
 | 
						|
		timer := time.NewTimer(bo.NextBackOff())
 | 
						|
		select {
 | 
						|
		case <-c.stop:
 | 
						|
			timer.Stop()
 | 
						|
			return
 | 
						|
		case <-timer.C:
 | 
						|
			break
 | 
						|
		}
 | 
						|
		if err := c.runOnce(); err == errStopConnect {
 | 
						|
			if c.shutdown != nil {
 | 
						|
				c.shutdown()
 | 
						|
				return
 | 
						|
			}
 | 
						|
		} else if err != nil {
 | 
						|
			log.Errorf("SSH: %v", err)
 | 
						|
		} else {
 | 
						|
			bo.Reset()
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Output returns the channel on which each newline-delimited string output by the executed command's stdout can be received.
 | 
						|
func (c *restartingClient) Output() <-chan string {
 | 
						|
	return c.output
 | 
						|
}
 | 
						|
 | 
						|
// dialRestartingClient creates a new restartingClient.
 | 
						|
func dialRestartingClient(network, addr string, config *ssh.ClientConfig, exec string, shutdown func()) (*restartingClient, error) {
 | 
						|
	c := &restartingClient{
 | 
						|
		closer:   newCloser(),
 | 
						|
		network:  network,
 | 
						|
		addr:     addr,
 | 
						|
		cfg:      config,
 | 
						|
		exec:     exec,
 | 
						|
		output:   make(chan string),
 | 
						|
		shutdown: shutdown,
 | 
						|
	}
 | 
						|
	go c.run()
 | 
						|
	return c, nil
 | 
						|
}
 | 
						|
 | 
						|
// Watcher watches
 | 
						|
type Watcher struct {
 | 
						|
	closer
 | 
						|
	c *restartingClient
 | 
						|
 | 
						|
	output chan gerritevents.Event
 | 
						|
}
 | 
						|
 | 
						|
// Close shuts down the SSH client connection, if any, and closes the output channel.
 | 
						|
// It blocks until shutdown is complete or until the context is cancelled, whichever comes first.
 | 
						|
func (w *Watcher) Close(ctx context.Context) {
 | 
						|
	w.c.Close(ctx)
 | 
						|
	w.closer.Close(ctx)
 | 
						|
}
 | 
						|
 | 
						|
func (w *Watcher) run() {
 | 
						|
	defer close(w.stopped)
 | 
						|
	defer close(w.output)
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-w.stop:
 | 
						|
			return
 | 
						|
		case o := <-w.c.Output():
 | 
						|
			ev, err := gerritevents.Parse([]byte(o))
 | 
						|
			if err != nil {
 | 
						|
				log.Errorf("failed to parse event %v: %v", o, err)
 | 
						|
				continue
 | 
						|
			}
 | 
						|
			w.output <- ev
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Events returns the channel upon which parsed Gerrit events can be received.
 | 
						|
func (w *Watcher) Events() <-chan gerritevents.Event {
 | 
						|
	return w.output
 | 
						|
}
 | 
						|
 | 
						|
// New returns a running Watcher from which events can be read.
 | 
						|
// It will begin connecting to the provided address immediately.
 | 
						|
func New(ctx context.Context, network, addr string, cfg *ssh.ClientConfig) (*Watcher, error) {
 | 
						|
	wc := newCloser()
 | 
						|
	rc, err := dialRestartingClient(network, addr, cfg, "gerrit stream-events", func() {
 | 
						|
		wc.Close(context.Background())
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("dialRestartingClient: %w", err)
 | 
						|
	}
 | 
						|
	w := &Watcher{
 | 
						|
		closer: wc,
 | 
						|
		c:      rc,
 | 
						|
		output: make(chan gerritevents.Event),
 | 
						|
	}
 | 
						|
	go w.run()
 | 
						|
	return w, nil
 | 
						|
}
 |