/*
 * Copyright (c) 2003-2012
 * Distributed Systems Software.  All rights reserved.
 * See the file LICENSE for redistribution information.
 */

#ifndef lint
static const char copyright[] =
"Copyright (c) 2003-2012\n\
Distributed Systems Software.  All rights reserved.";
static const char revid[] =
  "$Id: netlib.c 2586 2012-03-15 16:21:40Z brachman $";
#endif

#include <sys/ioctl.h>

#ifdef DSSLIB
#include "dsslib.h"
#else
#include "local.h"
#endif

#include <poll.h>

#if defined(DACS_OS_SOLARIS)
#include <sys/filio.h>
#endif

/*
 * Module name associated with this file's log messages; presumably
 * consulted by the log_msg()/log_err() macros -- confirm against their
 * definitions.
 */
static const char *log_module_name = "netlib";

/*
 * Disable Nagle's algorithm on TCP socket SD so that small writes are
 * sent immediately instead of being coalesced.
 * TCP_NODELAY is a TCP-level option and must be set at level IPPROTO_TCP;
 * at level SOL_SOCKET the same numeric value selects an unrelated socket
 * option (SO_DEBUG on Linux and the BSDs).
 * Return -1 on error (errno is set by setsockopt()), 0 otherwise.
 */
int
net_tcp_nodelay(int sd)
{
  int flag;

  flag = 1;
  if (setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)) == -1)
	return(-1);

  return(0);
}

/*
 * Turn on SO_REUSEADDR for socket SD.
 * Returns 0 on success, -1 if setsockopt() reports a failure.
 */
int
net_socket_reuseaddr(int sd)
{
  int on = 1;
  int st;

  st = setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));

  return((st == -1) ? -1 : 0);
}

/*
 * Put descriptor SD into blocking mode if BLOCK is non-zero, or into
 * non-blocking mode otherwise.
 * fcntl(F_SETFL) is tried first; if it fails, ioctl(FIONBIO) is tried as a
 * fallback.
 * Return 0 on success.  On error return -1 and, if ERRMSG is non-NULL,
 * set *ERRMSG to a description of the failure.
 */
static int
set_blocking_status(int sd, int block, char **errmsg)
{
  int flags;

  if (sd < 0) {
	if (errmsg != NULL)
	  *errmsg = "Bad file descriptor";
	return(-1);
  }

  if ((flags = fcntl(sd, F_GETFL)) == -1) {
	if (errmsg != NULL)
	  *errmsg = ds_xprintf("set_blocking_status fcntl(F_GETFL): %s",
						   strerror(errno));
	return(-1);
  }

  if (block)
	flags &= ~O_NONBLOCK;
  else
	flags |= O_NONBLOCK;

  if (fcntl(sd, F_SETFL, flags) == -1) {
	int en, ioctl_en, nbio;

	en = errno;
	/*
	 * FIONBIO takes a boolean (non-zero enables non-blocking I/O).  It must
	 * not be handed the fcntl() flag word, which is usually non-zero even
	 * after O_NONBLOCK has been cleared.
	 */
	nbio = !block;
	if (ioctl(sd, FIONBIO, &nbio) == -1) {
	  ioctl_en = errno;
	  if (errmsg != NULL) {
		char *fcntl_err;

		/* strerror() may reuse a static buffer; copy the first result. */
		fcntl_err = ds_xprintf("%s", strerror(en));
		*errmsg = ds_xprintf("set_blocking_status ioctl(FIONBIO): %s, fcntl(F_SETFL): %s",
							 strerror(ioctl_en), fcntl_err);
	  }
	  return(-1);
	}
  }

  return(0);
}

/*
 * Switch descriptor SD to non-blocking mode.
 * Yields 0 on success and -1 on failure; on failure a description of the
 * problem may be stored through ERRMSG.
 */
int
net_set_nonblocking(int sd, char **errmsg)
{
  const int want_blocking = 0;

  return(set_blocking_status(sd, want_blocking, errmsg));
}

/*
 * Switch descriptor SD to blocking mode.
 * Yields 0 on success and -1 on failure; on failure a description of the
 * problem may be stored through ERRMSG.
 */
int
net_set_blocking(int sd, char **errmsg)
{
  const int want_blocking = 1;

  return(set_blocking_status(sd, want_blocking, errmsg));
}

/*
 * Write all BUFLEN bytes of BUF to descriptor SD, looping over short writes.
 * A write interrupted by a signal (EINTR) is restarted.  If SD is
 * non-blocking and not currently writable (EAGAIN/EWOULDBLOCK), poll() is
 * used to wait until it becomes writable instead of spinning on write().
 * Return 0 if all of BUF was written, -1 on error (errno is set).
 */
int
net_write(int sd, void *buf, size_t buflen)
{
  char *p;
  ssize_t n;
  size_t nrem;

  nrem = buflen;
  p = buf;
  while (1) {
	n = write(sd, (void *) p, nrem);
	if (n == -1) {
	  if (errno == EINTR)
		continue;
	  if (errno == EAGAIN || errno == EWOULDBLOCK) {
		struct pollfd pfd;

		/* Block until SD can accept more data, then retry the write. */
		pfd.fd = sd;
		pfd.events = POLLOUT;
		pfd.revents = 0;
		if (poll(&pfd, 1, -1) == -1 && errno != EINTR)
		  return(-1);
		continue;
	  }
	  return(-1);
	}

	if ((size_t) n == nrem)
	  break;

	nrem -= (size_t) n;
	p += n;
  }

  return(0);
}

/*
 * Send the NUL-terminated string MESG (excluding the terminator) on SD.
 * Result is whatever net_write() reports: 0 on success, -1 on error.
 */
int
net_write_str(int sd, char *mesg)
{

  return(net_write(sd, mesg, strlen(mesg)));
}

/*
 * Assuming that STR is a domain name optionally followed by
 * a colon and a port number or service name, extract the port number (or
 * service name) as PORTSTR and determine PORTNUM.
 * If an error does not occur, also set HOSTNAME to a newly allocated copy
 * of the part of STR before the colon; PORTSTR, when set, points into
 * that same allocation.
 * Any of these arguments can be NULL if their value is not needed.
 * Return 1 if a port number is found, 0 if it is not found, or -1 if it
 * is found but is invalid (-1 is also returned if STR is NULL or memory
 * cannot be allocated).
 * If no port component is present and PORTSTR is non-NULL, then PORTSTR
 * is set to NULL.
 * The first colon separates host from port, so a bare IPv6 address literal
 * is not supported.
 */
int
net_parse_hostname_port(char *str, char **hostname, char **portstr,
						in_port_t *portnum)
{
  int rc;
  char *p, *s;
  in_port_t aport;

  if (portstr != NULL)
	*portstr = NULL;

  if (str == NULL || (s = strdup(str)) == NULL)
	return(-1);

  rc = 0;
  if ((p = strchr(s, (int) ':')) != NULL) {
	/* Split at the colon; P then points at the port component. */
	*p++ = '\0';
	if ((aport = net_get_service_port(p, NULL, NULL)) == 0) {
	  free(s);
	  return(-1);
	}

	if (portnum != NULL)
	  *portnum = aport;

	if (portstr != NULL)
	  *portstr = p;

	rc = 1;
  }

  if (hostname != NULL)
	*hostname = s;
  else if (rc == 0 || portstr == NULL) {
	/* No pointer into S was handed to the caller, so release it. */
	free(s);
  }

  return(rc);
}

/*
 * Allocate and return an IPv4 socket address for HOSTNAME and PORT.
 * HOSTNAME is resolved with gethostbyname(); if that does not yield an
 * IPv4 address, HOSTNAME is parsed as a dotted-quad with inet_addr().
 * The result is allocated with malloc().
 * Return NULL (after logging) if HOSTNAME is NULL or cannot be resolved,
 * or if memory cannot be allocated.
 */
struct sockaddr_in *
net_make_sockaddr(char *hostname, in_port_t port)
{
  struct hostent *hp;
  struct sockaddr_in *addr;

  if (hostname == NULL) {
	log_msg((LOG_ERROR_LEVEL, "No hostname available?"));
	return(NULL);
  }

  if ((addr = (struct sockaddr_in *) malloc(sizeof(*addr))) == NULL) {
	log_err((LOG_ERROR_LEVEL, "malloc"));
	return(NULL);
  }
  memset(addr, 0, sizeof(*addr));
  addr->sin_family = AF_INET;
  addr->sin_port = htons((u_short) port);

  /*
   * Only accept a resolved address that really is IPv4, so that the copy
   * can never overrun sin_addr.
   */
  hp = gethostbyname(hostname);
  if (hp != NULL && hp->h_addrtype == AF_INET
	  && hp->h_length == (int) sizeof(addr->sin_addr)
	  && hp->h_addr_list[0] != NULL)
	memcpy((void *) &addr->sin_addr, hp->h_addr_list[0],
		   sizeof(addr->sin_addr));
  else {
	in_addr_t numeric_addr;

	if ((numeric_addr = inet_addr(hostname)) == INADDR_NONE) {
	  log_err((LOG_ERROR_LEVEL, "gethostbyname : Bad hostname: \"%s\"\n",
			   hostname));
	  endhostent();
	  free(addr);
	  return(NULL);
	}
	memcpy((void *) &addr->sin_addr, &numeric_addr, sizeof(numeric_addr));
  }

  endhostent();

  return(addr);
}

/*
 * Connect to HOSTNAME:PORT through the external SSL program SSL_PROG, which
 * is run by filterthru() with the argument vector
 *
 *   SSL_PROG [-ccf SSL_PROG_CLIENT_CRT] [-vt peer -caf SSL_PROG_CA_CRT]
 *            [words of SSL_PROG_ARGS] HOSTNAME:PORT
 *
 * The -vt/-caf arguments are present only if CT is NET_SSL_VERIFY, in which
 * case SSL_PROG_CA_CRT must be non-NULL.  SSL_PROG_ARGS, if non-NULL, is
 * split into words by mkargv().
 * READ_FD and WRITE_FD are handed to filterthru(); presumably they receive
 * descriptors for talking to the filter -- confirm against filterthru().
 * Return 0 on success, -1 on error.
 */
int
net_connect_to_server_ssl(char *hostname, in_port_t port,
						  Net_connection_type ct,
						  char *ssl_prog, char *ssl_prog_args,
						  char *ssl_prog_client_crt,
						  char *ssl_prog_ca_crt,
						  int *read_fd, int *write_fd)
{
  int argc, i, rc;
  char **argv, *remote;
  Dsvec *v;

  if (hostname == NULL) {
	log_msg((LOG_ERROR_LEVEL, "No hostname given"));
	return(-1);
  }
  if (port == 0) {
	log_msg((LOG_ERROR_LEVEL, "No port given"));
	return(-1);
  }
  if (ssl_prog == NULL) {
	/* A NULL argv[0] would terminate the vector immediately. */
	log_msg((LOG_ERROR_LEVEL, "No SSL program given"));
	return(-1);
  }
  if (ct == NET_SSL_VERIFY && ssl_prog_ca_crt == NULL) {
	/*
	 * The argument vector is NULL-terminated, so adding a NULL CA file
	 * would silently cut it short and drop HOSTNAME:PORT.
	 */
	log_msg((LOG_ERROR_LEVEL, "No CA certificate given for peer verification"));
	return(-1);
  }

  argc = 0;
  argv = NULL;
  if (ssl_prog_args != NULL) {
	static Mkargv conf = { 0, 0, " ", NULL, NULL };

	if ((argc = mkargv(ssl_prog_args, &conf, &argv)) == -1) {
	  log_msg((LOG_ERROR_LEVEL, "Invalid ssl_prog_args argument"));
	  return(-1);
	}
  }

  v = dsvec_init(NULL, sizeof(char *));
  dsvec_add_ptr(v, ssl_prog);

  if (ssl_prog_client_crt != NULL) {
	dsvec_add_ptr(v, (void *) "-ccf");
	dsvec_add_ptr(v, (void *) ssl_prog_client_crt);
  }

  if (ct == NET_SSL_VERIFY) {
	dsvec_add_ptr(v, (void *) "-vt");
	dsvec_add_ptr(v, (void *) "peer");
	dsvec_add_ptr(v, (void *) "-caf");
	dsvec_add_ptr(v, (void *) ssl_prog_ca_crt);
  }

  for (i = 0; i < argc; i++)
	dsvec_add_ptr(v, argv[i]);

  remote = ds_xprintf("%s:%u", hostname, (unsigned int) port);
  dsvec_add_ptr(v, remote);
  /* Terminate the vector that is passed to filterthru() as a char **. */
  dsvec_add_ptr(v, NULL);

  rc = filterthru((char **) dsvec_base(v), NULL,
				  read_fd, write_fd, NULL, NULL);
  if (rc == -1) {
	log_msg((LOG_ERROR_LEVEL, "SSL filter failed"));
	return(-1);
  }

  return(0);
}

/*
 * Establish a TCP connection to HOSTNAME:PORT returning 0 if successful
 * (setting READ_FD and/or WRITE_FD to its read/write descriptors),
 * -1 otherwise.
 * HOSTNAME is looked up with gethostbyname(), falling back to inet_addr()
 * for a dotted-quad string; only IPv4 is supported.
 * If both READ_FD and WRITE_FD are non-NULL, *WRITE_FD is a dup() of
 * *READ_FD so that each can be closed independently.  At least one of
 * them must be non-NULL.  On failure no descriptor is left open.
 */
int
net_connect_to_server(char *hostname, in_port_t port,
					  int *read_fd, int *write_fd)
{
  int flag, sd, wsd;
  struct sockaddr_in addr;
  struct hostent *hp;
  in_addr_t numeric_addr;

  if (hostname == NULL) {
	log_msg((LOG_ERROR_LEVEL, "No hostname given"));
	return(-1);
  }
  if (port == 0) {
	log_msg((LOG_ERROR_LEVEL, "No port given"));
	return(-1);
  }
  if (read_fd == NULL && write_fd == NULL) {
	/* Checked up front so a connection is not made only to be leaked. */
	log_msg((LOG_ERROR_LEVEL, "No descriptor result requested"));
	return(-1);
  }

  if ((sd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
	log_err((LOG_ERROR_LEVEL, "socket"));
	return(-1);
  }

  /* TCP_NODELAY is a TCP-level option; it is not valid at SOL_SOCKET. */
  flag = 1;
  if (setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)) == -1) {
	log_err((LOG_ERROR_LEVEL, "setsockopt"));
	close(sd);
	return(-1);
  }

  flag = 1;
  if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)) == -1) {
	log_err((LOG_ERROR_LEVEL, "setsockopt2"));
	close(sd);
	return(-1);
  }

  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port = htons((u_short) port);

  /*
   * Only accept a resolved address that really is IPv4, so that the copy
   * can never overrun sin_addr.
   */
  hp = gethostbyname(hostname);
  if (hp != NULL && hp->h_addrtype == AF_INET
	  && hp->h_length == (int) sizeof(addr.sin_addr)
	  && hp->h_addr_list[0] != NULL)
	memcpy((char *) &addr.sin_addr, hp->h_addr_list[0],
		   sizeof(addr.sin_addr));
  else if ((numeric_addr = inet_addr(hostname)) != INADDR_NONE)
	memcpy((char *) &addr.sin_addr, &numeric_addr, sizeof(numeric_addr));
  else {
	log_msg((LOG_ERROR_LEVEL, "Unknown hostname: \"%s\"", hostname));
	close(sd);
	return(-1);
  }

  log_msg((LOG_DEBUG_LEVEL, "Connecting to %s:%u", hostname,
		   (unsigned int) port));

  if (connect(sd, (struct sockaddr *) &addr, sizeof(addr)) == -1) {
	log_err((LOG_ERROR_LEVEL, "connect"));
	close(sd);
	return(-1);
  }

  if (read_fd != NULL && write_fd != NULL) {
	/* Only publish the descriptors once both exist. */
	if ((wsd = dup(sd)) == -1) {
	  log_err((LOG_ERROR_LEVEL, "dup"));
	  close(sd);
	  return(-1);
	}
	*read_fd = sd;
	*write_fd = wsd;
  }
  else if (read_fd != NULL)
	*read_fd = sd;
  else
	*write_fd = sd;

  return(0);
}

/*
 * Test if descriptor SD is ready for reading.
 * If TIMEOUT is NULL, block until it is ready, otherwise wait for an interval
 * of at most TIMEOUT.
 * Return -1 if an error occurs, 0 if the interval expires without SD becoming
 * ready to read, or 1 if SD is ready.
 * A select() interrupted by a signal is retried.  POSIX leaves it
 * unspecified whether select() updates TIMEOUT, so on some systems an
 * interruption may restart the full interval.
 */
int
net_input_or_timeout(int sd, struct timeval *timeout)
{
  int n;
  fd_set fdset;

  /* FD_SET() on a descriptor outside [0, FD_SETSIZE) is undefined. */
  if (sd < 0 || sd >= FD_SETSIZE) {
	log_msg((LOG_ERROR_LEVEL, "Descriptor %d cannot be used with select", sd));
	return(-1);
  }

  for (;;) {
	/* Rebuild the set each time; select() may have altered it. */
	FD_ZERO(&fdset);
	FD_SET(sd, &fdset);

	if ((n = select(sd + 1, &fdset, NULL, NULL, timeout)) != -1)
	  break;

	/*
	 * Inspect errno before logging anything, since logging may change it.
	 * XXX Ideally compute how much time remains before retrying.
	 */
	if (errno != EINTR) {
	  log_err((LOG_ERROR_LEVEL, "select"));
	  return(-1);
	}
  }

  if (n == 0) {
	log_msg((LOG_ERROR_LEVEL, "Timeout"));
	return(0);
  }

  return(1);
}

/*
 * Render socket address SA (SALEN bytes long) as "host:port", using the
 * numeric forms of both host and service.
 * Yields a string built by ds_xprintf(), or NULL if getnameinfo() fails.
 */
char *
net_sockaddr_name(struct sockaddr *sa, socklen_t salen)
{
  char host[128], serv[16];
  int st;

  st = getnameinfo(sa, salen, host, sizeof(host), serv, sizeof(serv),
				   NI_NUMERICHOST | NI_NUMERICSERV);
  if (st != 0)
	return(NULL);

  return(ds_xprintf("%s:%s", host, serv));
}

/*
 * Return the local address associated with SD, or NULL.
 * The address is stored in an ALLOC()ed struct sockaddr_storage, so the
 * buffer is large enough for any address family (e.g. an IPv6 address is
 * not truncated).
 */
struct sockaddr *
net_socket_laddr(int sd)
{
  struct sockaddr_storage *ss;
  socklen_t namelen;

  ss = ALLOC(struct sockaddr_storage);
  namelen = sizeof(struct sockaddr_storage);
  if (getsockname(sd, (struct sockaddr *) ss, &namelen) == -1) {
    log_err((LOG_ERROR_LEVEL, "getsockname"));
    return(NULL);
  }

  return((struct sockaddr *) ss);
}

/*
 * Return the remote (peer) address associated with SD, or NULL.
 * The address is stored in an ALLOC()ed struct sockaddr_storage, so the
 * buffer is large enough for any address family (e.g. an IPv6 peer is
 * not truncated).
 */
struct sockaddr *
net_socket_raddr(int sd)
{
  struct sockaddr_storage *ss;
  socklen_t namelen;

  ss = ALLOC(struct sockaddr_storage);
  namelen = sizeof(struct sockaddr_storage);
  if (getpeername(sd, (struct sockaddr *) ss, &namelen) == -1) {
    log_err((LOG_ERROR_LEVEL, "getpeername"));
    return(NULL);
  }

  return((struct sockaddr *) ss);
}

/*
 * Return the local address of SD formatted as "host:port" by
 * net_sockaddr_name(), or NULL on error.
 * The address is fetched into a struct sockaddr_storage and its actual
 * length is passed on, so IPv6 sockets are handled as well as IPv4.
 */
char *
net_socket_lname(int sd)
{
  struct sockaddr_storage addr;
  socklen_t namelen;

  namelen = sizeof(addr);
  if (getsockname(sd, (struct sockaddr *) &addr, &namelen) == -1) {
    log_err((LOG_ERROR_LEVEL, "getsockname"));
    return(NULL);
  }

  return(net_sockaddr_name((struct sockaddr *) &addr, namelen));
}

/*
 * Return the remote (peer) address of SD formatted as "host:port" by
 * net_sockaddr_name(), or NULL on error.
 * The address is fetched into a struct sockaddr_storage and its actual
 * length is passed on, so IPv6 peers are handled as well as IPv4.
 */
char *
net_socket_rname(int sd)
{
  struct sockaddr_storage addr;
  socklen_t namelen;

  namelen = sizeof(addr);
  if (getpeername(sd, (struct sockaddr *) &addr, &namelen) == -1) {
    log_err((LOG_ERROR_LEVEL, "getpeername"));
    return(NULL);
  }

  return(net_sockaddr_name((struct sockaddr *) &addr, namelen));
}

/*
 * PORTNAME is either a port number, a service name, or NULL.
 * Return the corresponding numeric port number in host byte order, or zero
 * if an error occurs (zero is an invalid port number in this context).
 * The protocol is implicitly TCP.
 * If PORTNAME is NULL, use DEFAULT_PORTNAME instead.
 * Whenever zero is returned and ERRMSG is non-NULL, *ERRMSG is set to a
 * description of the problem.
 */
in_port_t
net_get_service_port(char *portname, char *default_portname, char **errmsg)
{
  in_port_t port;
  struct servent *serv;

  if (portname == NULL) {
    if ((portname = default_portname) == NULL) {
      if (errmsg != NULL)
        *errmsg = "No port specified";
      return(0);
    }
  }

  if (is_digit_string(portname)) {
    /* Must be a port number; "0" is rejected like any other bad value. */
    port = 0;
    if (strnum(portname, STRNUM_IN_PORT_T, &port) == -1 || port == 0) {
      if (errmsg != NULL)
        *errmsg = "Invalid port number";
      return(0);
    }
    return(port);
  }

  /* Must be a service name, or invalid. */
  if ((serv = getservbyname(portname, "tcp")) == NULL) {
    if (errmsg != NULL)
      *errmsg = ds_xprintf("Can't find TCP service \"%s\"", portname);
    endservent();
    return(0);
  }
  /* s_port holds the port in network byte order; convert to host order. */
  port = ntohs((in_port_t) serv->s_port);
  endservent();

  return(port);
}

/*
 * Wait for a connection on listening socket SD (for at most TIMEOUT, or
 * indefinitely if TIMEOUT is NULL; see net_input_or_timeout()) and accept
 * it.
 * If FROM is non-NULL, up to sizeof(struct sockaddr) bytes of the peer's
 * address are copied to it; FROM may be NULL if the address is not needed.
 * Return the newly connected descriptor, or -1 on error or timeout.
 */
int
net_accept_or_timeout(int sd, struct sockaddr *from, struct timeval *timeout)
{
  int new_sd;
  char *sname;
  socklen_t fromlen;
  struct sockaddr_storage peer;

  sname = net_socket_lname(sd);
  log_msg((LOG_TRACE_LEVEL, "Waiting to accept connection to %s",
		   (sname == NULL) ? "???" : sname));

  if (net_input_or_timeout(sd, timeout) != 1)
	return(-1);

  /*
   * Accept into a buffer big enough for any address family: accept()
   * reports the address's full length even when it had to truncate it,
   * so formatting a truncated copy with that length would read past it.
   */
  fromlen = sizeof(peer);
  if ((new_sd = accept(sd, (struct sockaddr *) &peer, &fromlen)) == -1) {
	log_err((LOG_ERROR_LEVEL, "accept"));
	return(-1);
  }

  if (from != NULL)
	memcpy(from, &peer, (fromlen < sizeof(struct sockaddr))
		   ? (size_t) fromlen : sizeof(struct sockaddr));

  if ((sname = net_sockaddr_name((struct sockaddr *) &peer, fromlen)) != NULL)
	log_msg((LOG_TRACE_LEVEL, "Got new connection from %s", sname));
  else
	log_msg((LOG_TRACE_LEVEL, "Got new connection from ???"));

  return(new_sd);
}

/*
 * Create a listening TCP socket bound to HOSTNAME:PORT (an IPv4 address
 * built by net_make_sockaddr()), with TCP_NODELAY and SO_REUSEADDR set.
 * The bound address is re-read with getsockname() so that a port chosen by
 * the system is reported.
 * On success set *SDP to the listening descriptor and *NAMEP to the bound
 * address, and return 0.
 * On failure return -1; the socket and address allocated here are released.
 */
int
net_make_server_socket(char *hostname, in_port_t port,
					   int *sdp, struct sockaddr_in **namep)
{
  int flag, sd;
  char *sname;
  struct sockaddr_in *my_addr;
  socklen_t namelen;

  my_addr = NULL;
  if ((sd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
	log_err((LOG_ERROR_LEVEL, "socket"));
	return(-1);
  }

  /* TCP_NODELAY is a TCP-level option; it is not valid at SOL_SOCKET. */
  flag = 1;
  if (setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)) == -1) {
	log_err((LOG_ERROR_LEVEL, "setsockopt"));
	goto fail;
  }

  flag = 1;
  if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag)) == -1) {
	log_err((LOG_ERROR_LEVEL, "setsockopt2"));
	goto fail;
  }

  if ((my_addr = net_make_sockaddr(hostname, port)) == NULL)
	goto fail;

  if (bind(sd, (const struct sockaddr *) my_addr, sizeof(struct sockaddr_in))
	  == -1) {
	log_err((LOG_ERROR_LEVEL, "bind"));
	goto fail;
  }

  if (listen(sd, 0) == -1) {
	log_err((LOG_ERROR_LEVEL, "listen"));
	goto fail;
  }

  /* This is necessary to find out which local port has been assigned. */
  namelen = sizeof(struct sockaddr_in);
  if (getsockname(sd, (struct sockaddr *) my_addr, &namelen) == -1) {
	log_err((LOG_ERROR_LEVEL, "getsockname"));
	goto fail;
  }

  if ((sname = net_sockaddr_name((struct sockaddr *) my_addr,
								 sizeof(struct sockaddr_in))) != NULL)
	log_msg((LOG_TRACE_LEVEL, "Created server socket at %s", sname));

  *sdp = sd;
  *namep = my_addr;

  return(0);

 fail:
  /* my_addr is NULL or came from net_make_sockaddr(), which uses malloc(). */
  close(sd);
  free(my_addr);
  return(-1);
}
