It'll be reused by the IRC side of things too. Change-Id: I3d84f3fd5fca6a6d948f331143b14f096d10675d Reviewed-on: https://cl.tvl.fyi/c/depot/+/342 Reviewed-by: tazjin <mail@tazj.in>
		
			
				
	
	
		
			252 lines
		
	
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			252 lines
		
	
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package gerrit implements a watcher for Gerrit events.
 | |
| package gerrit
 | |
| 
 | |
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"code.tvl.fyi/fun/clbot/backoffutil"
	"code.tvl.fyi/fun/clbot/gerrit/gerritevents"
	log "github.com/golang/glog"
	"golang.org/x/crypto/ssh"
)
 | |
| 
 | |
// closer provides an embeddable implementation of Close which signals a main
// loop to stop (by closing stop) and then waits for the loop to acknowledge it
// has stopped (by closing stopped).
//
// A closer must be created with newCloser. Value copies of a closer share the
// same channels and the same stopOnce, so Close may be called on any copy, any
// number of times, from any number of goroutines.
type closer struct {
	stop    chan struct{}
	stopped chan struct{}

	// stopOnce guards close(stop). It is a pointer so that value copies of the
	// closer (for example when it is embedded in another struct) share it.
	stopOnce *sync.Once
}

// newCloser returns a closer with its channels and stopOnce initialised.
func newCloser() closer {
	return closer{
		stop:     make(chan struct{}),
		stopped:  make(chan struct{}),
		stopOnce: new(sync.Once),
	}
}

// Close signals the main loop to stop, then waits until it has stopped or ctx
// is done, whichever happens first.
//
// The stop signal is always sent, even when ctx is already done, so the main
// loop is still told to exit. Repeated and concurrent calls are safe; every
// call waits in the same way.
//
// It returns nil once the main loop has stopped, or ctx.Err() if ctx finished
// first.
func (c *closer) Close(ctx context.Context) error {
	c.stopOnce.Do(func() { close(c.stop) })

	// Prefer reporting success if the loop has already stopped, even when ctx
	// is also done.
	select {
	case <-c.stopped:
		return nil
	default:
	}

	select {
	case <-c.stopped:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
 | |
| 
 | |
// lineWriter is an io.Writer which splits its input on \n and sends each
// complete line (without the trailing newline) to out. Bytes after the last
// newline are buffered until a later Write completes the line.
type lineWriter struct {
	// buf holds the unterminated tail of the input seen so far. A Builder is
	// used so that a long line arriving in many chunks is accumulated in
	// amortised linear time, instead of the whole tail being re-copied and
	// re-split on every Write.
	buf strings.Builder
	out chan string
}

// Write accepts a slice of bytes containing zero or more newlines and sends
// every line it completes to w.out, in order. It always reports len(p), nil.
// If the contained channel is non-buffering or is full, this will block.
func (w *lineWriter) Write(p []byte) (int, error) {
	rest := string(p)
	for {
		i := strings.IndexByte(rest, '\n')
		if i < 0 {
			break
		}
		w.buf.WriteString(rest[:i])
		line := w.buf.String()
		w.buf.Reset()
		w.out <- line
		rest = rest[i+1:]
	}
	w.buf.WriteString(rest)
	return len(p), nil
}
 | |
| 
 | |
| // restartingClient is a simple SSH client that repeatedly connects to an SSH server, runs a command, and outputs the lines output by it on stdout onto a channel.
 | |
| type restartingClient struct {
 | |
| 	closer
 | |
| 
 | |
| 	network string
 | |
| 	addr    string
 | |
| 	cfg     *ssh.ClientConfig
 | |
| 
 | |
| 	exec     string
 | |
| 	output   chan string
 | |
| 	shutdown func()
 | |
| }
 | |
| 
 | |
// errStopConnect is returned by runOnce when the server sent a "goaway"
// request, i.e. the client has been asked not to reconnect.
var errStopConnect = errors.New("gerrit: told to stop reconnecting by remote server")
 | |
| 
 | |
| func (c *restartingClient) runOnce() error {
 | |
| 	netConn, err := net.Dial(c.network, c.addr)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("connecting to %v/%v: %w", c.network, c.addr, err)
 | |
| 	}
 | |
| 	defer netConn.Close()
 | |
| 
 | |
| 	sshConn, newCh, newReq, err := ssh.NewClientConn(netConn, c.addr, c.cfg)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("creating SSH connection to %v/%v: %w", c.network, c.addr, err)
 | |
| 	}
 | |
| 	defer sshConn.Close()
 | |
| 
 | |
| 	goAway := false
 | |
| 	passedThroughReqs := make(chan *ssh.Request)
 | |
| 	go func() {
 | |
| 		defer close(passedThroughReqs)
 | |
| 		for req := range newReq {
 | |
| 			if req.Type == "goaway" {
 | |
| 				goAway = true
 | |
| 				log.Warningf("remote end %v/%v told me to go away!", c.network, c.addr)
 | |
| 				sshConn.Close()
 | |
| 				netConn.Close()
 | |
| 			}
 | |
| 			passedThroughReqs <- req
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	cl := ssh.NewClient(sshConn, newCh, passedThroughReqs)
 | |
| 
 | |
| 	sess, err := cl.NewSession()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("NewSession on %v/%v: %w", c.network, c.addr, err)
 | |
| 	}
 | |
| 	defer sess.Close()
 | |
| 
 | |
| 	sess.Stdout = &lineWriter{out: c.output}
 | |
| 
 | |
| 	if err := sess.Start(c.exec); err != nil {
 | |
| 		return fmt.Errorf("Start(%q) on %v/%v: %w", c.exec, c.network, c.addr, err)
 | |
| 	}
 | |
| 
 | |
| 	log.Infof("connected to %v/%v", c.network, c.addr)
 | |
| 
 | |
| 	done := make(chan struct{})
 | |
| 	go func() {
 | |
| 		sess.Wait()
 | |
| 		close(done)
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		select {
 | |
| 		case <-c.stop:
 | |
| 			sess.Close()
 | |
| 		case <-done:
 | |
| 		}
 | |
| 		return
 | |
| 	}()
 | |
| 	<-done
 | |
| 
 | |
| 	if goAway {
 | |
| 		return errStopConnect
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *restartingClient) run() {
 | |
| 	defer close(c.stopped)
 | |
| 	bo := backoffutil.NewDefaultBackOff()
 | |
| 	for {
 | |
| 		timer := time.NewTimer(bo.NextBackOff())
 | |
| 		select {
 | |
| 		case <-c.stop:
 | |
| 			timer.Stop()
 | |
| 			return
 | |
| 		case <-timer.C:
 | |
| 			break
 | |
| 		}
 | |
| 		if err := c.runOnce(); err == errStopConnect {
 | |
| 			if c.shutdown != nil {
 | |
| 				c.shutdown()
 | |
| 				return
 | |
| 			}
 | |
| 		} else if err != nil {
 | |
| 			log.Errorf("SSH: %v", err)
 | |
| 		} else {
 | |
| 			bo.Reset()
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Output returns the channel on which each newline-delimited string output by the executed command's stdout can be received.
 | |
| func (c *restartingClient) Output() <-chan string {
 | |
| 	return c.output
 | |
| }
 | |
| 
 | |
| // dialRestartingClient creates a new restartingClient.
 | |
| func dialRestartingClient(network, addr string, config *ssh.ClientConfig, exec string, shutdown func()) (*restartingClient, error) {
 | |
| 	c := &restartingClient{
 | |
| 		closer:   newCloser(),
 | |
| 		network:  network,
 | |
| 		addr:     addr,
 | |
| 		cfg:      config,
 | |
| 		exec:     exec,
 | |
| 		output:   make(chan string),
 | |
| 		shutdown: shutdown,
 | |
| 	}
 | |
| 	go c.run()
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| // Watcher watches
 | |
| type Watcher struct {
 | |
| 	closer
 | |
| 	c *restartingClient
 | |
| 
 | |
| 	output chan gerritevents.Event
 | |
| }
 | |
| 
 | |
| // Close shuts down the SSH client connection, if any, and closes the output channel.
 | |
| // It blocks until shutdown is complete or until the context is cancelled, whichever comes first.
 | |
| func (w *Watcher) Close(ctx context.Context) {
 | |
| 	w.c.Close(ctx)
 | |
| 	w.closer.Close(ctx)
 | |
| }
 | |
| 
 | |
| func (w *Watcher) run() {
 | |
| 	defer close(w.stopped)
 | |
| 	defer close(w.output)
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-w.stop:
 | |
| 			return
 | |
| 		case o := <-w.c.Output():
 | |
| 			ev, err := gerritevents.Parse([]byte(o))
 | |
| 			if err != nil {
 | |
| 				log.Errorf("failed to parse event %v: %v", o, err)
 | |
| 				continue
 | |
| 			}
 | |
| 			w.output <- ev
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Events returns the channel upon which parsed Gerrit events can be received.
 | |
| func (w *Watcher) Events() <-chan gerritevents.Event {
 | |
| 	return w.output
 | |
| }
 | |
| 
 | |
| // New returns a running Watcher from which events can be read.
 | |
| // It will begin connecting to the provided address immediately.
 | |
| func New(ctx context.Context, network, addr string, cfg *ssh.ClientConfig) (*Watcher, error) {
 | |
| 	wc := newCloser()
 | |
| 	rc, err := dialRestartingClient(network, addr, cfg, "gerrit stream-events", func() {
 | |
| 		wc.Close(context.Background())
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("dialRestartingClient: %w", err)
 | |
| 	}
 | |
| 	w := &Watcher{
 | |
| 		closer: wc,
 | |
| 		c:      rc,
 | |
| 		output: make(chan gerritevents.Event),
 | |
| 	}
 | |
| 	go w.run()
 | |
| 	return w, nil
 | |
| }
 |