/* $Id: socks.c,v 1.1 2007-01-05 23:56:49 gophi Exp $ */

#include <unistd.h>
#include <stdlib.h>
#include <signal.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <netdb.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/select.h>
#include <sys/types.h>

/*
 * Note: errno and h_errno come from <errno.h> and <netdb.h>. errno may be a
 * macro, so it must not be redeclared by hand here.
 */

/* Print "fatal[pid][line]: <printf-style message>" to stderr and exit(EXIT_FAILURE). */
#define die(...) do { fprintf(stderr, "fatal[%d][%d]: ", getpid(), __LINE__); fprintf(stderr, __VA_ARGS__); fprintf(stderr, "\n"); exit(EXIT_FAILURE); } while(0)
/* Print "debug[pid][line]: <printf-style message>" to stderr. */
#define debug(...) do { fprintf(stderr, "debug[%d][%d]: ", getpid(), __LINE__); fprintf(stderr, __VA_ARGS__); fprintf(stderr, "\n"); } while(0)

/*
 * Convert a TCP service name (looked up with getservbyname()) or a decimal
 * port number into a port in network byte order, ready for sin_port.
 *
 * Returns 0 when str is empty, is not a known service, and is not a plain
 * decimal number in the range 1..65535. Unlike atoi(), input such as "80x"
 * or "70000" is rejected instead of being silently misinterpreted.
 */
static in_port_t make_port(const char *str)
{
	const struct servent *ent;
	unsigned long num;
	char *end;

	if (!str || !*str)
		return 0;

	ent = getservbyname(str, "tcp");
	if (ent)
		return (in_port_t) ent->s_port;	/* already in network byte order */

	/* strtoul() would accept leading blanks and a sign; require a digit first. */
	if (*str < '0' || *str > '9')
		return 0;

	errno = 0;
	num = strtoul(str, &end, 10);
	if (errno == ERANGE || *end != '\0' || num == 0 || num > 65535)
		return 0;

	return htons((in_port_t) num);
}

/*
 * Resolve a host name or dotted-quad string to an IPv4 address in network
 * byte order, ready for sin_addr.s_addr.
 *
 * Dies if the name cannot be resolved or if the result is not an IPv4
 * address (anything else would not fit into a sockaddr_in).
 */
static in_addr_t make_addr(const char *str)
{
	struct hostent *hent;
	struct in_addr in_address;

	hent = gethostbyname(str);
	if (!hent)
		die("cannot resolve %s: %s", str, hstrerror(h_errno));

	/* Never copy a foreign-family or truncated address into in_address. */
	if (hent->h_addrtype != AF_INET
	    || hent->h_length != (int) sizeof(in_address)
	    || !hent->h_addr_list[0])
		die("cannot resolve %s: no IPv4 address", str);

	memcpy(&in_address, hent->h_addr_list[0], sizeof(in_address));

	return in_address.s_addr;
}

/*
 * Perform the SOCKS5 method negotiation on the connected proxy socket fd,
 * offering only method 0x00 ("no authentication required").
 *
 * Dies on I/O failure, or if the proxy does not answer with version 5 and
 * method 0x00. A method of 0xFF means "no acceptable methods".
 */
static void socks5_auth(int fd)
{
	unsigned char buf[] = "\x05\x01\x00";	/* ver 5, 1 method, method 0 = no auth */

	if (send(fd, buf, 3, 0) != 3)
		die("send() failed");

	/* TCP may deliver the 2-byte reply in pieces; wait for all of it. */
	if (recv(fd, buf, 2, MSG_WAITALL) != 2)
		die("recv() failed");

	/* Both fields must match; rejecting on either mismatch. */
	if (buf[0] != 0x05 || buf[1] != 0x00)
		die("socks5: auth failed: ver=%u reply=%u (see rfc1928)", buf[0], buf[1]);
}

/*
 * Send a SOCKS5 CONNECT request for the IPv4 address/port in dest over the
 * already-negotiated proxy socket fd, and consume the complete reply.
 *
 * On return the socket carries the tunnelled stream. Dies on I/O failure, if
 * the proxy reports an error (REP != 0), or if the reply uses an unknown
 * address type.
 */
static void socks5_relay(int fd, struct sockaddr_in *dest)
{
	unsigned char req[10];
	unsigned char rep[4];			/* VER, REP, RSV, ATYP */
	unsigned char bnd[255 + 2];		/* longest BND.ADDR (domain) + BND.PORT */
	size_t bnd_len = 0;

	req[0] = 0x05;	/* socks 5 */
	req[1] = 0x01;	/* connect */
	req[2] = 0x00;	/* reserved */
	req[3] = 0x01;	/* ipv4 */

	/* sin_addr and sin_port are already in network byte order. */
	memcpy(req + 4, &dest->sin_addr.s_addr, 4);
	memcpy(req + 8, &dest->sin_port, 2);

	if (send(fd, req, sizeof(req), 0) != (ssize_t) sizeof(req))
		die("send() failed");

	if (recv(fd, rep, sizeof(rep), MSG_WAITALL) != (ssize_t) sizeof(rep))
		die("recv() failed");

	if (rep[0] != 0x05 || rep[1] != 0x00)
		die("socks5: relay failed: ver=%u reply=%u res=%u atyp=%u (see rfc1928)", rep[0], rep[1], rep[2], rep[3]);

	/* The length of the bound address that follows depends on ATYP. */
	switch (rep[3]) {
	case 0x01:	/* ipv4 address + port */
		bnd_len = 4 + 2;
		break;
	case 0x04:	/* ipv6 address + port */
		bnd_len = 16 + 2;
		break;
	case 0x03:	/* length-prefixed domain name + port */
		if (recv(fd, bnd, 1, MSG_WAITALL) != 1)
			die("recv() failed");
		bnd_len = (size_t) bnd[0] + 2;
		break;
	default:
		die("socks5: relay failed: ver=%u reply=%u res=%u atyp=%u (see rfc1928)", rep[0], rep[1], rep[2], rep[3]);
	}

	if (recv(fd, bnd, bnd_len, MSG_WAITALL) != (ssize_t) bnd_len)
		die("recv() failed");
}

/*
 * Forward one chunk (at most 1024 bytes) of readable data from sfd to dfd.
 *
 * Returns 0 when the peer on sfd has closed the connection (recv() == 0).
 * Otherwise returns 1, including when recv() was interrupted or would block,
 * in which case nothing is forwarded. Dies on any other recv()/send() error.
 */
static int tunnel(int sfd, int dfd)
{
	char buf[1024];
	const char *p = buf;
	ssize_t rs;

	rs = recv(sfd, buf, sizeof(buf), 0);
	if (rs == -1) {
		if (errno == EAGAIN || errno == EINTR)
			return 1;
		/* Must not fall through: -1 would reach send() as a huge length. */
		die("recv() error: %s", strerror(errno));
	}
	if (rs == 0)
		return 0;

	/* send() may accept only part of the chunk; loop until all is written. */
	while (rs > 0) {
		ssize_t srs = send(dfd, p, (size_t) rs, 0);

		if (srs == -1 && (errno == EAGAIN || errno == EINTR))
			continue;
		else if (srs == -1)
			die("send() error: %s", strerror(errno));
		else if (srs == 0)
			die("send() socket closed");

		p += srs;
		rs -= srs;
	}

	return 1;
}

/*
 * Relay data in both directions between the client socket cfd and the proxy
 * socket pfd, using select() to wait for readable data on either side.
 *
 * Never returns normally: the loop runs until either side closes, or until an
 * I/O error occurs. In both cases die() terminates the process.
 */
static void make_tunnel(int cfd, int pfd)
{
	const int nfds = ((cfd > pfd) ? cfd : pfd) + 1;

	/* FD_SET on a descriptor >= FD_SETSIZE is undefined behaviour. */
	if (cfd >= FD_SETSIZE || pfd >= FD_SETSIZE)
		die("descriptor too large for select()");

	for (;;) {
		fd_set fds;
		int rs;

		FD_ZERO(&fds);
		FD_SET(cfd, &fds);
		FD_SET(pfd, &fds);

		rs = select(nfds, &fds, NULL, NULL, NULL);
		if (rs == -1) {
			if (errno == EAGAIN || errno == EINTR)
				continue;
			/* fds is unspecified after a failed select(); don't inspect it. */
			die("select() error: %s", strerror(errno));
		}
		if (rs == 0)
			continue;

		if (FD_ISSET(cfd, &fds) && !tunnel(cfd, pfd))
			die("client closed connection");

		if (FD_ISSET(pfd, &fds) && !tunnel(pfd, cfd))
			die("proxy closed connection");
	}
}

/*
 * Serve one accepted client connection cfd.
 *
 * Opens a TCP connection to the SOCKS5 proxy, negotiates "no auth", and asks
 * the proxy to CONNECT to dest. It then hands both sockets to make_tunnel()
 * to relay data. Any failure along the way is fatal (die()).
 */
static void handle(int cfd, struct sockaddr_in *proxy, struct sockaddr_in *dest)
{
	int proxy_fd = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);

	if (proxy_fd < 0)
		die("socket(): %s", strerror(errno));

	debug("connecting proxy");
	if (connect(proxy_fd, (struct sockaddr *) proxy, sizeof *proxy) != 0)
		die("connect(): %s", strerror(errno));

	/* SOCKS5 handshake: method negotiation, then the CONNECT request. */
	debug("socks5 authentication");
	socks5_auth(proxy_fd);
	debug("socks5 relay");
	socks5_relay(proxy_fd, dest);

	debug("starting relay");
	make_tunnel(cfd, proxy_fd);
}

/*
 * Usage: socks localport proxyhost proxyport desthost destport
 *
 * Listens on 127.0.0.1:localport. For every accepted connection it forks a
 * child, and the child tunnels the connection to desthost:destport through
 * the SOCKS5 proxy at proxyhost:proxyport (see handle()).
 */
int main(int ac, char * const av[])
{
	struct sockaddr_in proxy, dest, local;
	int fd, so;

	if (ac < 6)
		die("usage: socks localport proxyhost proxyport desthost destport");

	/* Zero whole structs so sin_zero and any platform-specific fields are not stack garbage. */
	memset(&local, 0, sizeof(local));
	memset(&proxy, 0, sizeof(proxy));
	memset(&dest, 0, sizeof(dest));

	local.sin_addr.s_addr = inet_addr("127.0.0.1");
	local.sin_port = make_port(av[1]);
	local.sin_family = AF_INET;
	if (local.sin_port == 0)
		die("invalid local port: %s", av[1]);

	proxy.sin_addr.s_addr = make_addr(av[2]);
	proxy.sin_port = make_port(av[3]);
	proxy.sin_family = AF_INET;
	if (proxy.sin_port == 0)
		die("invalid proxy port: %s", av[3]);

	dest.sin_addr.s_addr = make_addr(av[4]);
	dest.sin_port = make_port(av[5]);
	dest.sin_family = AF_INET;
	if (dest.sin_port == 0)
		die("invalid destination port: %s", av[5]);

	debug("local: host=%s port=%u", inet_ntoa(local.sin_addr), ntohs(local.sin_port));
	debug("proxy: host=%s port=%u", inet_ntoa(proxy.sin_addr), ntohs(proxy.sin_port));
	debug("dest:  host=%s port=%u", inet_ntoa(dest.sin_addr), ntohs(dest.sin_port));

	fd = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (fd == -1)
		die("socket(): %s", strerror(errno));

	so = 1;
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &so, sizeof(so)) == -1)
		die("setsockopt(SO_REUSEADDR): %s", strerror(errno));

	if (bind(fd, (struct sockaddr *) &local, sizeof(local)) == -1)
		die("bind(): %s", strerror(errno));

	if (listen(fd, 10) == -1)
		die("listen(): %s", strerror(errno));

	/* Let the kernel reap exited children so they don't linger as zombies. */
	signal(SIGCHLD, SIG_IGN);

	debug("entering listen loop");

	for (;;) {
		int cfd;
		pid_t pid;

		cfd = accept(fd, NULL, NULL);
		/* ECONNABORTED: a client reset before accept(); not a server failure. */
		if (cfd == -1 && (errno == EAGAIN || errno == EINTR || errno == ECONNABORTED))
			continue;
		else if (cfd == -1)
			die("accept(): %s", strerror(errno));

		if ((pid = fork()) == -1)
			die("fork(): %s", strerror(errno));
		else if (!pid) {
			/* child: serve this client only */
			close(fd);
			handle(cfd, &proxy, &dest);
			close(cfd);
			exit(EXIT_SUCCESS);
		}

		/* parent: the child owns cfd now */
		close(cfd);
	}

	return 0;
}
