/*
** multi irc proxy
**
** (c) 2004 gophi
*/

#include <signal.h>
#include <syslog.h>
#include <stdio.h>
#include <errno.h>
#include <netdb.h>
#include <unistd.h>
#include <pwd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <stdarg.h>
#include <time.h>

#define AUTH_PORT 113
#define BIND_PORT 6667
#define MAX_FORKS 10
#define AUTH_TIMEOUT 10
#define LISTEN_QUEUE 10

struct user_type {
	char *user;	/* na podstawie tego user@host użytkownik będzie identyfikowany */
	char *host;
	char *vuser;	/* ten user@host będzie wysyłany do serwera */
	char *vhost;
	char *server;	/* ...do tego serwera */
	int port;	/* na ten port */
	int server_family;	/* tak samo dla vhosta i serwera */
} static users[] = {
	"gophi", "apcoh.org", "gophi", "chce.byc.jak.shog.pl", "warszawa6.irc.pl", 6661, AF_INET6,
	"yodar", "apcoh.org", "yodar", "yodar.gophi.net", "warszawa6.irc.pl", 6662, AF_INET6,
	"tree", "apcoh.org", "tree", "tree.gophi.net", "warszawa6.irc.pl", 6663, AF_INET6,
	"cinq", "apcoh.org", "cinq", "cinq.gophi.net", "warszawa6.irc.pl", 6664, AF_INET6,
	"regis", "apcoh.org", "regis", "regis.gophi.net", "krakow6.irc.pl", 6665, AF_INET6,
//	"jceel", "82.160.8.131", "jceel", "jceel.gophi.net", "warszawa6.irc.pl", 6664, AF_INET6,
//	"-", "80.53.192.2", "jceel", "jceel.gophi.net", "warszawa6.irc.pl", 6665, AF_INET6,
//	"jceel", "192.168.6.4", "jceel", "jceel.gophi.net", "warszawa6.irc.pl", 6664, AF_INET6,
	"jceel", "192.168.6.4", "jceel", "chce.byc.jak.gophi.net", "warszawa6.irc.pl", 6664, AF_INET6,
};

#define NUM_USERS (sizeof(users) / sizeof(struct user_type))

struct userdata_type {
	struct in_addr host;
	struct addrinfo *vhost, *server;
} userdata[NUM_USERS];

extern int errno;
volatile int num_forks = 0;
volatile char flag_alrm = 0;
int bind_fd;

void debug (const char *fmt, ...)
{
	va_list args;
	char buf[512];

	va_start (args, fmt);
	vsnprintf (buf, sizeof(buf), fmt, args);
	va_end (args);

	fprintf (stderr, "debug[%u]: %s\n", getpid(), buf);
}

void atomic_write (int fd, char *buf, int len)
{
	int pos = 0;
	int written;

	while (len) {
		written = write(fd, buf + pos, len);
		if (written < 0) {
			debug ("atomic_write(): write(): %m.");
			_exit (1);
		}
		pos += written;
		len -= written;
	}
}

void sockprintf (int fd, const char *fmt, ...)
{
	va_list args;
	char buf[512];
	int num;

	va_start (args, fmt);
	num = vsnprintf(buf, sizeof(buf), fmt, args);
	va_end (args);

	atomic_write (fd, buf, num);
}

void kill_conn (int fd, const char *fmt, ...)
{
	va_list args;
	char buf[512];

	va_start (args, fmt);
	vsnprintf (buf, sizeof(buf), fmt, args);
	va_end (args);

	fprintf (stderr, "debug[%u]: kill_conn: %s\r\n", getpid(), buf);
	sockprintf (fd, "ERROR :%s\r\n", buf);

	close (fd);
}

void send_notice (int fd, char *type, const char *fmt, ...)
{
	va_list args;
	char buf[512];

	va_start (args, fmt);
	vsnprintf (buf, sizeof(buf), fmt, args);
	va_end (args);

	fprintf (stderr, "debug[%u]: send_notice: %s\r\n", getpid(), buf);
	sockprintf (fd, ":irc.gophi.net NOTICE %s :*** %s\r\n", type, buf);
}

void hnd_term (int signo)
{
	debug ("odebrano sygnał %s, kończenie pracy.", strsignal(signo));

	if (bind_fd)
		close (bind_fd);

	_exit (0);
}

void hnd_chld (int signo)
{
	debug ("odebrano sygnał %s, odczytywanie stanu procesu potomnego.", strsignal(signo));

	while (waitpid(-1, NULL, WNOHANG) > 0);

	if (num_forks)
		num_forks--;

	signal (SIGCHLD, hnd_chld);

	debug ("obsługa sygnału %s zakończona pomyślnie.", strsignal(signo));
}

void hnd_segv (int signo)
{
	debug ("odebrano sygnał %s, zrzucanie obrazu procesu na dysk.", strsignal(signo));

	signal (SIGSEGV, SIG_DFL);
	raise (SIGSEGV);
}

void hnd_alrm (int signo)
{
	debug ("odebrano sygnał %s, koniec czasu oczekiwania.", strsignal(signo));
	flag_alrm = 1;
}

void resolve_host4 (char *hostname, struct in_addr *addr)
{
	struct hostent *hent;

	debug ("translacja nazwy %s.", hostname);

	hent = gethostbyname(hostname);
	if (!hent) {
		debug ("gethostbyname(): %m.");
		_exit (1);
	}

	memcpy ((char *) addr, hent->h_addr, sizeof(struct in_addr));

	debug ("%s ma adres %s.", hostname, inet_ntoa(*addr));
}

void resolve_host (char *hostname, struct addrinfo **addr, int port, int family)
{
	struct addrinfo hints;
	char strport[10];
	char ntop[NI_MAXHOST];
	int err;

	debug ("translacja nazwy %s.", hostname);

	bzero (&hints, sizeof(hints));
	hints.ai_family = family;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;
	sprintf (strport, "%u", port);

	err = getaddrinfo(hostname, strport, &hints, addr);
	if (err) {
		debug ("getaddrinfo(): %s.", gai_strerror(err));
		_exit (1);
	}

	if (getnameinfo((*addr)->ai_addr, (*addr)->ai_addrlen, ntop, sizeof(ntop), NULL, (size_t) NULL, NI_NUMERICHOST)) {
		debug ("getnameinfo(): %m.");
		_exit (1);
	}

	debug ("%s ma adres %s.", hostname, ntop);
}

void load_userdata (void)
{
	int i;

	debug ("ładowanie danych dla użytkowników.");

	for (i = 0; i < NUM_USERS; i++) {
		debug ("ładowanie danych dla użytkownika %s@%s.", users[i].user, users[i].host);
		resolve_host4 (users[i].host, &userdata[i].host);
		resolve_host (users[i].vhost, &userdata[i].vhost, 0, users[i].server_family);
		resolve_host (users[i].server, &userdata[i].server, users[i].port, users[i].server_family);
	}

	debug ("dane dla użytkowników załadowane.");
}

void bind_socket (void)
{
	struct sockaddr_in addr;
	struct in_addr in_address;
	int rs;

	debug ("rozpoczynanie czekania na połączenia.");

	bind_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (!bind_fd) {
		debug ("socket(): %m.");
		_exit (1);
	}

	rs = 1;
	if (setsockopt(bind_fd, SOL_SOCKET, SO_REUSEADDR, (char *) &rs, sizeof(rs))) {
		debug ("setsockopt(): %m.");
		_exit (1);
	}

	addr.sin_addr.s_addr = INADDR_ANY;
	addr.sin_port = htons(BIND_PORT);
	addr.sin_family = AF_INET;

	if (bind(bind_fd, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
		debug ("bind(): %m.");
		_exit (1);
	}

	if (listen(bind_fd, LISTEN_QUEUE) < 0) {
		debug ("listen(): %m.");
		_exit (1);
	}


	debug ("czekanie na połączenia rozpoczęte.");
}

char is_user (struct in_addr *addr)
{
	int i;

	for (i = 0; i < NUM_USERS; i++)
		if (!memcmp(&userdata[i].host, addr, sizeof(struct in_addr)))
			return 1;

	return 0;
}

void read_line (int fd, char *buf, int buf_len)
{
	int i = 0, num_read;
	char ch, done = 0;

	while (!done) {
		num_read = read(fd, &ch, 1);
		if ((!num_read) || (num_read < 0 && errno == EINTR && flag_alrm))
			done = 1;
		else switch (ch) {
			case 0x0D:
				break;
			case 0x0A:
				done = 1;
				break;
			default:
				if (i < buf_len - 2)
					buf[i++] = ch;
				else
					done = 1;
				break;
		}
	}

	buf[i] = (char) NULL;
}

int get_auth (int fd, struct sockaddr_in *src_addr)
{
	struct sockaddr_in addr;
	int auth_fd;
	char buf[40], code[sizeof(buf)], user[sizeof(buf)];
	int user_num;

	for (user_num = 0; user_num < NUM_USERS; user_num++)
		if (src_addr->sin_addr.s_addr == userdata[user_num].host.s_addr && (!strcmp("-", users[user_num].user))) {
			send_notice (fd, "AUTH", "Twój adres IP ma uprawnienia do łączenia bez autoryzacji.");
			return user_num;
		}

	send_notice (fd, "AUTH", "Czas oczekiwania na połączenie z serwerem autoryzacji: %u sekund.", AUTH_TIMEOUT);

	alarm (AUTH_TIMEOUT);
	signal (SIGALRM, hnd_alrm);

	auth_fd = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (auth_fd <= 0) {
		kill_conn (fd, "Wewnętrzny błąd socket(): %m.");
		_exit (1);
	}

	addr.sin_addr = src_addr->sin_addr;
	addr.sin_port = htons(AUTH_PORT);
	addr.sin_family = AF_INET;

	snprintf (buf, sizeof(buf), "%u , %u\r\n", ntohs(src_addr->sin_port), BIND_PORT);

	send_notice (fd, "AUTH", "Port klienta: %u, port serwera: %u.", ntohs(src_addr->sin_port), BIND_PORT);
	send_notice (fd, "AUTH", "Łączenie z serwerem autoryzacji...");

	if (connect(auth_fd, (struct sockaddr *) &addr, sizeof(addr)) < 0)
		if (errno != EINTR) {
			kill_conn (fd, "Wewnętrzny błąd connect(): %m.");
			_exit (1);
		} else if (flag_alrm) {
			kill_conn (fd, "Przekroczony czas oczekiwania na połączenie z serwerem autoryzacji.");
			_exit (1);
		} else {
			kill_conn (fd, "Łączenie zostało przerwane.");
			_exit (1);
		}

	send_notice (fd, "AUTH", "Wysyłanie zapytania.");
	atomic_write (auth_fd, buf, strlen(buf));
	send_notice (fd, "AUTH", "Oczekiwanie na odpowiedź.");
	read_line (auth_fd, buf, sizeof(buf));
	close (auth_fd);

	if (flag_alrm) {
		kill_conn (fd, "Przekroczony czas oczekiwania na odpowiedź serwera autoryzacji.");
		_exit (1);
	}

	alarm (0);
	signal (SIGALRM, SIG_DFL);

	send_notice (fd, "AUTH", "Przetwarzanie odpowiedzi.");
	sscanf (buf, "%*d , %*d : %s : %*s : %s", code, user);

	if (strcasecmp(code, "userid")) {
		kill_conn (fd, "Serwer autoryzacji zwrócił kod %s: %s.", code, user);
		_exit (1);
	}

	send_notice (fd, "AUTH", "Kod odpowiedzi: %s, użytkownik: %s.", code, user);
	send_notice (fd, "AUTH", "Wyszukiwanie użytkownika w tablicy autoryzacji...");

	for (user_num = 0; user_num < NUM_USERS; user_num++)
		if (src_addr->sin_addr.s_addr == userdata[user_num].host.s_addr && (!strcmp(user, users[user_num].user))) {
			send_notice (fd, "AUTH", "Zostałeś poprawnie zidentyfikowany.");
			return user_num;
		}

	send_notice (fd, "AUTH", "Ten serwer nie ma dla ciebie linii autoryzacji.");
	kill_conn (fd, "Ten serwer nie ma dla ciebie linii autoryzacji.");
	_exit (1);

	/* NOTREACHED */

	return 0;
}

void drop_privs (int fd, char *user)
{
	struct passwd *pw;

	debug ("zrzucanie uprawnień na użytkownika %s.", user);
	send_notice (fd, "AUTH", "Przygotowywanie lokalnego serwera autoryzacji.");

	if (getuid() != 0) {
		kill_conn (fd, "Nie można zrzucić uprawnień. Serwer musi chodzić z uid=0.");
		_exit (1);
	}

	pw = getpwnam(user);
	if (!pw) {
		kill_conn (fd, "Nie ma w systemie użytkownika %s.", user);
		_exit (1);
	}

	send_notice (fd, "AUTH", "Lokalny identyfikator użytkownika: %u.", pw->pw_uid);
	send_notice (fd, "AUTH", "Lokalna grupa użytkownika: %u.", pw->pw_gid);

	if (initgroups(pw->pw_name, pw->pw_gid)) {
		kill_conn (fd, "Nie można zainicjować grup: %m.");
		_exit (1);
	}

	if (setuid(pw->pw_uid)) {
		kill_conn (fd, "Nie można zrzucić uprawnień: %m.");
		_exit (1);
	}

	send_notice (fd, "AUTH", "Lokalny serwer autoryzacji gotowy.");
}

int connect_user (int fd, int user)
{
	int srv_fd;
	struct addrinfo *ai;
	char ntop[NI_MAXHOST];
	char strport[NI_MAXSERV];
	char have_conn = 0;

	snprintf (strport, sizeof(strport), "%u", users[user].port);

	send_notice (fd, "*", "Ustalanie wszystkich adresów serwera %s.", users[user].server);

	for (ai = userdata[user].server; ai; ai = ai->ai_next) {
		if (ai->ai_family != users[user].server_family)
			continue;

		send_notice (fd, "*", "Pobieranie informacji na temat serwera.");

		if (getnameinfo(ai->ai_addr, ai->ai_addrlen, ntop, sizeof(ntop), strport, sizeof(strport), NI_NUMERICHOST | NI_NUMERICSERV)) {
			kill_conn (fd, "Błąd wewnętrzny getnameinfo(): %m.");
			_exit (1);
		}

		srv_fd = socket(ai->ai_family, SOCK_STREAM, IPPROTO_TCP);
		if (srv_fd < 0) {
			kill_conn (fd, "Błąd wewnętrzny socket(): %m.");
			_exit (1);
		}

		send_notice (fd, "*", "Ustawianie wirtualnego hosta %s.", users[user].vhost);

		if (bind(srv_fd, userdata[user].vhost->ai_addr, userdata[user].vhost->ai_addrlen) < 0) {
			kill_conn (fd, "Błąd wewnętrzny bind(): %m.");
			_exit (1);
		}

		send_notice (fd, "*", "Próba połączenia z %s [%s] na porcie %s.", users[user].server, ntop, strport);

		if (connect(srv_fd, ai->ai_addr, ai->ai_addrlen) >= 0) {
			send_notice (fd, "*", "Połączenie udane, przetwarzanie.");
			have_conn = 1;
			break;
		}

		send_notice (fd, "*", "Połączenie nie udało się z powodu %m.");
		close (srv_fd);
	}

	if (!have_conn) {
		kill_conn (fd, "Nie udało się połączyć z żadnym z adresów przypisanych do serwera.");
		_exit (1);
	}

	return srv_fd;
}

char do_pass (int src_fd, int dst_fd)
{
	int num;
	char buf[16384];

	num = read(src_fd, buf, sizeof(buf));
	if (!num)
		return 1;

	atomic_write (dst_fd, buf, num);

	return 0;
}

void pass_data (int cli_fd, int srv_fd)
{
	int max_fd;
	fd_set rdfds;
	int rs;
	char done = 0;

	max_fd = (cli_fd > srv_fd) ? cli_fd + 1 : srv_fd + 1;

	send_notice (cli_fd, "*", "Deskryptor połączenia klienta %u, deskryptor połączenia serwera %u.", cli_fd, srv_fd);
	send_notice (cli_fd, "*", "Rozpoczęto tunelowanie.");

	while (!done) {
		FD_ZERO (&rdfds);
		FD_SET (cli_fd, &rdfds);
		FD_SET (srv_fd, &rdfds);
		rs = select(max_fd, &rdfds, NULL, NULL, NULL);
		if (!rs)
			continue;
		else if (rs < 0) {
			kill_conn (cli_fd, "Zamykanie połączenia z powodu błędu %m.");
			_exit (1);
		}

		if (FD_ISSET(cli_fd, &rdfds))
			done |= do_pass(cli_fd, srv_fd);

		if (FD_ISSET(srv_fd, &rdfds))
			done |= do_pass(srv_fd, cli_fd);
	}

	send_notice (cli_fd, "*", "Połączenie zakończone.");

	close (srv_fd);
	close (cli_fd);
}

void hnd_conn (int fd, struct sockaddr_in *addr)
{
	char *ipaddr;
	unsigned short port;
	int user_num;
	int dest_fd;

	ipaddr = inet_ntoa(addr->sin_addr);
	port = ntohs(addr->sin_port);

	debug ("obsługiwanie połączenia spod adresu %s:%u.", ipaddr, port);
	send_notice (fd, "*", "Multiproxy (C) 2004 gophi@linux.net.pl.");
	send_notice (fd, "*", "Odebrano połączenie spod adresu %s:%u.", ipaddr, port);
	send_notice (fd, "AUTH", "Poczekaj na przetworzenie swojego połączenia.");
	user_num = get_auth(fd, addr);

	debug ("użytkownik %s@%s został poprawnie autoryzowany.", users[user_num].user, users[user_num].host);

	drop_privs (fd, users[user_num].vuser);
	dest_fd = connect_user(fd, user_num);

	pass_data (fd, dest_fd);
}

void daemonize (void)
{
	debug ("przechodzenie w tło.");

	debug ("tworzenie kopii obrazu procesu.");

	switch (fork()) {
		case 0:
			break;
		case -1:
			debug ("fork(): %m");
			_exit (1);
		default:
			_exit (0);
	}

	debug ("ustawianie identyfikatora grupy sesji.");
	if (setsid() < 0) {
		debug ("setsid(): %m");
		_exit (1);
	}

	debug ("przechodzenie do głównego katalogu.");
	if (chdir("/")) {
		debug ("chdir(): %m");
		_exit (1);
	}

	debug ("przechodzenie w tło poprawne.");
}

int main (int argc, char **argv)
{
	struct sockaddr_in conn_addr;
	socklen_t addr_sz = sizeof(struct sockaddr_in);
	char *ipaddr;
	int conn_fd;
	pid_t pid;

	debug ("uruchamianie mproxy.");
	load_userdata();
	bind_socket();

	signal (SIGINT, hnd_term);
	signal (SIGTERM, hnd_term);
	signal (SIGCHLD, hnd_chld);
	signal (SIGSEGV, hnd_segv);

	daemonize();

	debug ("mproxy jest gotowe do pracy.");

	while (!0) {
		conn_fd = accept(bind_fd, (struct sockaddr *) &conn_addr, &addr_sz);
		ipaddr = inet_ntoa(conn_addr.sin_addr);
		debug ("połączenie spod %s.", ipaddr);

		if (!is_user(&conn_addr.sin_addr)) {
			kill_conn (conn_fd, "Nie znaleziono użytkownika pasującego do hostmaski.");
			continue;
		}

		if (num_forks >= MAX_FORKS) {
			kill_conn (conn_fd, "Tymczasowe przeciążenie serwera, spróbuj później.");
			continue;
		}

		num_forks++;

		pid = fork();
		switch (pid) {
			case 0:
				debug ("obsługiwanie połączenia.");
				close (bind_fd);
				bind_fd = 0;
				hnd_conn (conn_fd, &conn_addr);
				_exit (0);
			case -1:
				debug ("fork(): %m.");
				_exit (1);
			default:
				debug ("wywołano proces potomny [%u] do obsługi połączenia.", pid);
				close (conn_fd);
		}
	}
}
