SSH port forwarding with Go

SSH port forwarding is a common way to reach services that cannot be exposed directly to the public internet. If you're running your own database server, you most likely have a strict firewall rule in place that only allows connections from known IP addresses, and there's a good chance that the only publicly exposed port on your machine(s) is 22/ssh.

Example

To start forwarding ports, you can use the ssh command:

ssh -Ng -L 5000:localhost:5432 user@myapp.com

That will start a server listening on local port 5000 and forward connections to localhost:5432 on the myapp.com machine. Flag descriptions (per the man page):

  • -N - Do not execute a remote command. This is useful for just forwarding ports.
  • -g - Allows remote hosts to connect to local forwarded ports.
  • -L - Specifies that the given port on the local (client) host is to be forwarded to the given host and port on the remote side. Format: [bind_address:]port:host:hostport
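
To make the -L format concrete, here is a small sketch of how such a spec could be split into a local listen address and a remote target in Go. The forwardSpec type and the parseForwardSpec function are hypothetical names made up for this illustration. The sketch only handles plain hostnames and IPv4 addresses (no bracketed IPv6), and defaulting the bind address to 127.0.0.1 is an assumption of the sketch, not something taken from the ssh source.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// forwardSpec is a hypothetical type holding the two halves of an -L argument.
type forwardSpec struct {
	LocalAddr  string // address the local listener binds to
	RemoteAddr string // address dialed from the SSH server side
}

// parseForwardSpec splits "[bind_address:]port:host:hostport".
// Assumption: hosts are plain names or IPv4 addresses, so splitting on ":" is safe.
func parseForwardSpec(spec string) (forwardSpec, error) {
	parts := strings.Split(spec, ":")
	switch len(parts) {
	case 3:
		// No bind address given: this sketch defaults to loopback.
		parts = append([]string{"127.0.0.1"}, parts...)
	case 4:
		// bind_address:port:host:hostport, nothing to add.
	default:
		return forwardSpec{}, fmt.Errorf("invalid forward spec %q", spec)
	}
	return forwardSpec{
		LocalAddr:  net.JoinHostPort(parts[0], parts[1]),
		RemoteAddr: net.JoinHostPort(parts[2], parts[3]),
	}, nil
}

func main() {
	spec, err := parseForwardSpec("5000:localhost:5432")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	// prints: 127.0.0.1:5000 -> localhost:5432
	fmt.Println(spec.LocalAddr, "->", spec.RemoteAddr)
}
```

The two resulting addresses correspond to the listen address and the target address that a programmatic forwarder needs.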

The example above is modeled for use with PostgreSQL. Here's how you can start a standard psql console through the forwarded port:

psql postgres://user:password@127.0.0.1:5000/database

Implementation

Go's standard library has plenty of packages, but unfortunately it does not include one for SSH. There's a "third-party" package, golang.org/x/crypto/ssh, maintained by the Go team (docs):

go get golang.org/x/crypto/ssh

Implementing SSH port forwarding programmatically takes a few steps:

  • Establish an SSH connection with the remote server using public key or password authentication
  • Open a connection to the target ip:port through the SSH connection
  • Start a local server on a port
  • Accept local connections and forward data to the remote connection

Here's simplified code that does just that:

package main

import (
  "io"
  "io/ioutil"
  "log"
  "net"
  "os"

  "golang.org/x/crypto/ssh"
)

// Get default location of a private key
func privateKeyPath() string {
  return os.Getenv("HOME") + "/.ssh/id_rsa"
}

// Get private key for ssh authentication
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
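  // Report a missing or unreadable key file instead of parsing an empty buffer
  buff, err := ioutil.ReadFile(keyPath)
  if err != nil {
    return nil, err
  }
  return ssh.ParsePrivateKey(buff)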
}

// Get ssh client config for our connection
// SSH config will use 2 authentication strategies: by key and by password
func makeSshConfig(user, password string) (*ssh.ClientConfig, error) {
  key, err := parsePrivateKey(privateKeyPath())
  if err != nil {
    return nil, err
  }

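  config := ssh.ClientConfig{
    User: user,
    Auth: []ssh.AuthMethod{
      ssh.PublicKeys(key),
      ssh.Password(password),
    },
    // Current versions of the package reject a config without a host key callback.
    // InsecureIgnoreHostKey skips host key verification: acceptable for a demo,
    // not for production use.
    HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  }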

  return &config, nil
}

// Handle local client connections and tunnel data to the remote server
// Will use io.Copy - http://golang.org/pkg/io/#Copy
func handleClient(client net.Conn, remote net.Conn) {
  defer client.Close()
  // Buffered so the second copier can send and exit after we stop receiving
  chDone := make(chan bool, 2)

  // Start remote -> local data transfer
  go func() {
    _, err := io.Copy(client, remote)
    if err != nil {
      log.Println("error while copy remote->local:", err)
    }
    chDone <- true
  }()

  // Start local -> remote data transfer
  go func() {
    _, err := io.Copy(remote, client)
    if err != nil {
      log.Println("error while copy local->remote:", err)
    }
    chDone <- true
  }()

  <-chDone
}

func main() {
  // Connection settings
  sshAddr := "remote_ip:22"
  localAddr := "127.0.0.1:5000"
  remoteAddr := "127.0.0.1:5432"

  // Build SSH client configuration
  cfg, err := makeSshConfig("user", "password")
  if err != nil {
    log.Fatalln(err)
  }

  // Establish connection with SSH server
  conn, err := ssh.Dial("tcp", sshAddr, cfg)
  if err != nil {
    log.Fatalln(err)
  }
  defer conn.Close()

  // Establish connection with remote server
  remote, err := conn.Dial("tcp", remoteAddr)
  if err != nil {
    log.Fatalln(err)
  }

  // Start local server to forward traffic to remote connection
  local, err := net.Listen("tcp", localAddr)
  if err != nil {
    log.Fatalln(err)
  }
  defer local.Close()

  // Handle incoming connections
  for {
    client, err := local.Accept()
    if err != nil {
      log.Fatalln(err)
    }

    handleClient(client, remote)
  }
}

The code above does not need much explanation except for the io.Copy(dst, src) call, which does all the magic: it copies from src to dst until either EOF is reached on src or an error occurs.
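
As a quick illustration of that behavior, here is a minimal standalone sketch that uses only the standard library. The copyString helper is a made-up name for this example: it copies a string reader into an in-memory buffer and shows that reaching EOF on src ends the copy without being reported as an error.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// copyString drains a reader over s into an in-memory buffer with io.Copy.
// It returns the number of bytes copied, the copied data, and any error.
func copyString(s string) (int64, string, error) {
	src := strings.NewReader(s)
	var dst bytes.Buffer
	// io.Copy returns once src reports EOF; EOF itself is not returned as an error.
	n, err := io.Copy(&dst, src)
	return n, dst.String(), err
}

func main() {
	n, data, err := copyString("hello tunnel")
	// prints: 12 hello tunnel <nil>
	fmt.Println(n, data, err)
}
```

In the forwarder, src and dst are network connections instead, so each io.Copy call blocks until its source side is closed or fails.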

The example works, but there are a few issues with it: a) it does not handle concurrency well and b) it is not as stable as using the ssh command directly. I have to dig into the problem a bit more to fully understand what's happening. To clarify, the concurrency issue only appears when using io.Copy over an SSH connection; using it as a plain local port forwarder works just fine.
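
One thing visible in the listing is that main dials a single remote connection before the accept loop and calls handleClient synchronously, so every client shares one remote stream and clients are served one at a time. Below is a minimal sketch of an alternative, not a verified fix for the instability described above: each accepted client gets its own remote connection and its own goroutine. The names forward, dialFunc, serveEcho, exchange and demo are hypothetical. To keep the sketch runnable without an SSH server, it dials a local echo server through net.Dial, which is an assumption of the demo.

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net"
)

// dialFunc matches both net.Dial and the Dial method of *ssh.Client.
type dialFunc func(network, addr string) (net.Conn, error)

// forward gives every accepted client its own remote connection and its
// own goroutine. It returns once the listener is closed or fails.
func forward(local net.Listener, dial dialFunc, remoteAddr string) {
	for {
		client, err := local.Accept()
		if err != nil {
			return
		}
		go func() {
			defer client.Close()
			remote, err := dial("tcp", remoteAddr)
			if err != nil {
				log.Println("dial remote:", err)
				return
			}
			defer remote.Close()
			done := make(chan struct{}, 2) // buffered: neither copier blocks on send
			go func() { io.Copy(remote, client); done <- struct{}{} }()
			go func() { io.Copy(client, remote); done <- struct{}{} }()
			<-done // one side finished; the deferred Close calls stop the other
		}()
	}
}

// serveEcho is a stand-in for the remote service: it echoes bytes back.
func serveEcho(ln net.Listener) {
	for {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		go func() { defer c.Close(); io.Copy(c, c) }()
	}
}

// exchange writes msg to conn and reads back the same number of bytes.
func exchange(conn net.Conn, msg string) (string, error) {
	if _, err := conn.Write([]byte(msg)); err != nil {
		return "", err
	}
	buf := make([]byte, len(msg))
	_, err := io.ReadFull(conn, buf)
	return string(buf), err
}

// demo keeps two clients open at once through a single forwarder.
func demo() (string, string, error) {
	echo, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", "", err
	}
	defer echo.Close()
	go serveEcho(echo)

	local, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return "", "", err
	}
	defer local.Close()
	go forward(local, net.Dial, echo.Addr().String())

	a, err := net.Dial("tcp", local.Addr().String())
	if err != nil {
		return "", "", err
	}
	defer a.Close()
	b, err := net.Dial("tcp", local.Addr().String())
	if err != nil {
		return "", "", err
	}
	defer b.Close()

	first, err := exchange(a, "first")
	if err != nil {
		return "", "", err
	}
	second, err := exchange(b, "second")
	return first, second, err
}

func main() {
	fmt.Println(demo())
}
```

With the SSH client from the listing, the Dial method of the established connection (conn.Dial) has the same signature as dialFunc, so it could be passed to forward in place of net.Dial. Whether that also resolves the stability issue would still need testing against a real server.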