/* tst.c - Program for testing a concurrent TCP server with "concurrency"-many 
 *         concurrent connections. Each connection writes a request, 
 *         then reads until the connection is closed by the server. 
 *         Statistics are collected and printed out. We write as request
 *         a legal HTTP request asking for a file.
 *         This program works with web servers, but it can also test
 *         other kinds of TCP servers.
 * Usage is:
 *    % tst options [host]
 * where the options are:
 *   -n           total number of requests made (default is 1)
 *   -c           number of concurrent connections (default is 1)
 *   -t           time limit on run, in seconds (default is unlimited time)
 *   -p           port on server (default is 80)
 *   -f           file requested (default is "/"). Beware that if you want to
 *                retrieve the moo.html file of user smith you will have
 *                to say -f "~smith/moo.html"
 *   -v           display all that is received from the server [CAREFUL!]
 *   -h           usage information
 * The default host is the localhost. The program runs until 'requests' 
 * requests have been made, or the timeout 'tlimit' has expired.
 * Among other things, the program determines the time in milliseconds
 * until the connection is established (TimeToConnect), until some
 * part of the response becomes available (Latency), and until
 * the response is completed (ResponseTime).
 * Warning: Nothing is done to ensure a particular rate at which 
 *       requests are generated.
 * NOTE: THIS PROGRAM IS ABRIDGED FROM ab.c WHICH IS AVAILABLE
 *       AT WWW.APACHE.ORG AS PART OF DISTRIBUTION OF APACHE SERVER.
 */

#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <netdb.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/* ------------------- DEFINITIONS -------------------------- */

/* State of one connection slot while its request is in flight. */
struct connection {
    int fd;                     /* socket descriptor of this slot */
    int read;                   /* bytes received so far on this socket */
    struct timeval start;       /* taken just before connect() is issued */
    struct timeval connect;     /* taken when the socket became writable */
    struct timeval firstread;   /* last read attempt made before any data */
    struct timeval done;        /* taken when the connection is closed */
};

/* Statistics saved for one finished request.  All times are in
 * milliseconds, measured from the moment connect() was issued:
 *   read  - number of bytes received
 *   ctime - until the socket became writable (TimeToConnect)
 *   rtime - until the first read (Latency)
 *   time  - until the connection was closed (ResponseTime)
 */
struct data {
    int read, ctime, rtime, time;
};

#define MAXSTRING 1024          /* maximum size of various strings */

#define min(a,b) (((a)<(b))?(a):(b))
#define max(a,b) (((a)>(b))?(a):(b))

/* --------------------- GLOBALS ---------------------------- */

/* Run parameters, filled in from the command line by main(). */
int requests = 1;               /* total requests to perform (default 1) */
int concurrency = 1;            /* connections kept open at the same time */
int tlimit = 0;                 /* overall time limit in ms, 0 = no limit */
char hostname[MAXSTRING];       /* name of the server host */
int port = 80;                  /* server TCP port */
char filename[MAXSTRING] = "/"; /* document requested from the server */

/* Progress counters. */
int doclen = 0;                 /* length of the first non-empty response */
int totalread = 0;              /* bytes received over all connections */
int done = 0;                   /* requests finished (successfully or not) */
int bad = 0;                    /* finished requests counted as failed */
int started = 0;                /* requests written so far; the ones still */
                                /* in progress number started - done */
int writeallout = 0;            /* 1 = echo everything received to stdout */

/* Shared run state. */
struct timeval start;           /* when the benchmark began */
struct timeval endtime;         /* when the results were computed */
char request[512];              /* request text sent on every connection */
int reqlen;                     /* length of request */
struct connection *con;         /* one slot per concurrent connection */
struct data *stats;             /* one record per finished request */
fd_set readbits;                /* sockets waiting for response data */
fd_set writebits;               /* sockets waiting for connect to finish */
struct sockaddr_in server;      /* resolved server address */

/* --------------------------------------------------------- */

/* Print the command-line help text to stdout.
 * command is the program name shown in the synopsis (normally argv[0]).
 * Only a host may follow the options; the port is set with -p. */
void usage (const char * command)
{
  printf(
	 "Usage is:\n"
	 "   %s options [host]\n"
	 "where the options are:\n"
	 "   -n           total number of requests made\n"
	 "   -c           number of concurrent connections\n"
	 "   -t           time limit on run, in seconds\n"
	 "   -p           port on server, default is 80\n"
         "   -f           file requested, default is '/'\n"
         "   -v           display all that is sent from the server [CAREFULL!]\n"
	 "   -h           usage information\n"
	 "The default host is the localhost.\n"
         "NOTE: THIS PROGRAM IS ABRIDGED FROM ab.c \n"
	 "WHICH IS AVAILABLE FROM WWW.APACHE.ORG AS PART OF\n"
	 "THE APACHE SERVER DISTRIBUTION\n",
	 command
	 );
}

/* Print a diagnostic on stderr and terminate with a failure status.
 * When errno is non-zero, perror() is used so the system error text is
 * appended to s; otherwise s is printed by itself.  Does not return. */
static void err(const char *s)
{
    if (errno)
	perror(s);
    else
	fprintf(stderr, "%s\n", s);
    /* Use a fixed failure status.  errno may be 0 here (for example after
       argument-validation errors), and exit statuses are truncated to
       8 bits, so exit(errno) could report success to the caller. */
    exit(EXIT_FAILURE);
}

/* Return a - b in whole milliseconds, truncated toward zero.
 * The difference is formed in microseconds before dividing.  Truncating
 * the microsecond part on its own can be off by one millisecond
 * (e.g. 1.000000 s - 0.999999 s would yield 1 ms instead of 0). */
static int timedif(struct timeval a, struct timeval b)
{
  long long us = (long long) (a.tv_sec - b.tv_sec) * 1000000LL
               + (a.tv_usec - b.tv_usec);

  return (int) (us / 1000);
}

/* Open a non-blocking TCP socket in c and start connecting it to 'server'.
 * On success the socket is added to 'writebits', so the main loop sends
 * the request once the connection completes.  An immediate connect()
 * failure closes the socket and is retried; after three consecutive
 * failures the whole test is aborted through err(). */
static void start_connection(struct connection * c)
{
  int cbad;     /* number of consecutive failures in connect */

  for (cbad = 0; cbad < 3 ; ++cbad ) {
    int temp = 1;
    c->read = 0;
    if ((c->fd = socket(AF_INET, SOCK_STREAM, 0)) < 0)
      err("socket");
    /* FD_SET/FD_ISSET on a descriptor >= FD_SETSIZE is undefined
       behavior, so a select()-based loop cannot handle it. */
    if (c->fd >= FD_SETSIZE) {
      close(c->fd);
      errno = 0;
      err("\nTest aborted: too many open descriptors for select()\n\n");
    }
    if (ioctl(c->fd, FIONBIO, &temp) < 0) /* make socket non blocking */
      err("ioctl");
    gettimeofday(&c->start, 0);
    /* 0 = connected at once.  EINPROGRESS = completing asynchronously.
       EINTR on a non-blocking connect also leaves the connection
       completing asynchronously. */
    if ((connect(c->fd, (struct sockaddr *) & server, sizeof(server)) == 0)
        || (errno == EINPROGRESS) || (errno == EINTR)) {
      FD_SET(c->fd, &writebits);
      return;
    }
    close(c->fd);
  }
  err("\nTest aborted after 3 connection failures\n\n");
}

/* Called when c's socket becomes writable, i.e. its non-blocking connect
 * has finished.  Sends the global request in a single write() (it is small
 * enough to fit in the socket buffer), half-closes the sending side and
 * moves the socket from 'writebits' to 'readbits' so the main loop waits
 * for the response.  If all 'requests' have already been started, nothing
 * is sent and the socket is just dropped from 'writebits'. */
static void write_request(struct connection * c)
{
    /* Clear this first.  A connected socket stays writable, so leaving it
       in writebits would make select() return immediately on every pass. */
    FD_CLR(c->fd, &writebits);
    if (started >= requests)
        return;
    started++;
    gettimeofday(&c->connect, 0);
    if (write(c->fd, request, reqlen) != reqlen) {
        /* Not fatal here: the socket still moves to readbits below, so a
           failed connect or write is expected to surface there as a read
           error or EOF, and the request is recorded with what was read. */
    }
    FD_SET(c->fd, &readbits);
    shutdown(c->fd, SHUT_WR);   /* signal the end of the request */
}

/* Read whatever data is available on c and add it to the byte counters.
 * With -v the bytes are echoed verbatim to stdout.
 * Returns 0 if and only if the connection should be terminated: the
 * server closed it (EOF), or read() failed with a real error.  A read that
 * would block, or one interrupted by a signal, keeps the connection open. */
static int read_connection(struct connection * c)
{
  static char buffer[8192];
  ssize_t r;

  /* Refreshed on every call until data arrives, so it ends up holding the
     time of the read that returned the first bytes (or EOF). */
  if (c->read == 0)
    gettimeofday(&c->firstread, 0);
  r = read(c->fd, buffer, sizeof(buffer));
  if (r > 0) {
    /* fwrite, not printf("%s"): the response may contain NUL bytes */
    if (writeallout == 1)
      fwrite(buffer, 1, (size_t) r, stdout);
    c->read += (int) r;
    totalread += (int) r;
    return 1;
  }
  if (r == 0)
    return 0;                   /* server closed the connection */
  /* EAGAIN/EWOULDBLOCK = operation would block; EINTR = try again later */
  return errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR;
}

/* close down connection and save stats */
static void close_connection(struct connection * c)
{
  struct data s;

  /* save out times */
  gettimeofday(&c->done, 0);
  s.read = c->read;
  s.ctime = timedif(c->connect, c->start);
  s.rtime = timedif(c->firstread, c->start);
  s.time = timedif(c->done, c->start);
  stats[done++] = s;

  if ((doclen == 0) && (c->read > 0)) 
    doclen = c->read;
  if (doclen == 0 || (c->read != doclen))
    bad++;

  FD_CLR(c->fd, &readbits);
  FD_CLR(c->fd, &writebits);
  close(c->fd);
  return;
}

/* run the tests */
static void test(void)
{
  struct timeval timeout, now;
  fd_set sel_read, sel_except, sel_write;
  int i;
  struct hostent *he;
  int sbad = 0;             /* number of failures in select */
  
  printf("\nBenchmarking %s\n\n", hostname);
  
  /* get server information */
  if (!(he = gethostbyname(hostname)))
      err("bad hostname");
  server.sin_family = he->h_addrtype;
  server.sin_port = htons(port);
  server.sin_addr.s_addr = ((unsigned long *) (he->h_addr_list[0]))[0];
  
  con = malloc(concurrency * sizeof(struct connection));
  memset(con, 0, concurrency * sizeof(struct connection));
  stats = malloc(requests * sizeof(struct data));
  FD_ZERO(&readbits);
  FD_ZERO(&writebits);
  
  /* lets start */
  gettimeofday(&start, 0);
  
  /* start 'concurrency' requests */
  for (i = 0; i < concurrency; i++)
    start_connection(&con[i]);
  
  while (done < requests) {
    int n;
    /* setup bit arrays */
    memcpy(&sel_except, &readbits, sizeof(readbits));
    memcpy(&sel_read, &readbits, sizeof(readbits));
    memcpy(&sel_write, &writebits, sizeof(readbits));
    
    /* check for time limit expiration */
    gettimeofday(&now, 0);
    if (tlimit && timedif(now, start) > tlimit) 
      requests = done;	/* so stats are correct */

    /* Timeout of 30 seconds. */
    timeout.tv_sec = 30;
    timeout.tv_usec = 0;
    if ((n = select(FD_SETSIZE, &sel_read, &sel_write, &sel_except, &timeout)) == 0) {
      requests = done;
      printf("\nServer timed out\n\n");
      return;
    } 
    if (n  <  0)
      err("select");

    for (i = 0; i < concurrency; i++) {
      int s = con[i].fd;
      if (FD_ISSET(s, &sel_except)) {
	if (sbad++ > 10) 
	  err("\nTest aborted after 10 select failures\n\n");
	start_connection(&con[i]);
	continue;
      }
      if (FD_ISSET(s, &sel_write))
	write_request(&con[i]);
      if (FD_ISSET(s, &sel_read))
	if (!read_connection(&con[i])) {
	  close_connection(&con[i]);
	  start_connection(&con[i]);
} } } }

/* Print the run summary.  This covers the totals, the throughput, and the
 * min/avg/max (in ms) of TimeToConnect, Latency and ResponseTime over
 * stats[0 .. requests-1]. */
static void output_results(void)
{
    int timetaken;
    int k;
    int totalcon = 0, totalfread = 0, total = 0;
    int mincon, mintot, minfread;
    int maxcon, maxtot, maxfread;

    gettimeofday(&endtime, 0);
    timetaken = timedif(endtime, start);
    if (timetaken < 1)
	timetaken = 1;		/* at least one ms: it is used as a divisor */
    printf("\n\n\nServer Hostname:        %s\n", hostname);
    printf("Server Port:            %d\n\n", port);
    printf("Concurrency Level:      %d\n", concurrency);
    if (requests == 0) {
      printf("There are no requests\n\n");
      return;
    }

    printf("Document Length:        %d bytes\n\n", doclen);
    printf("Time taken for tests:   %d.%03d seconds\n",
	   timetaken / 1000, timetaken % 1000);
    printf("Completed requests:      %d\n", done);
    printf("Failed requests:        %d\n", bad);
    printf("Total transferred:      %d bytes\n", totalread);
    printf("Requests per second:    %.2f\n", 1000.0*(float)(done)/timetaken);
    printf("Transfer rate:          %.2f kb/s received\n",
            (float) (totalread) / timetaken);

    /* work out connection times.  Seed the extremes from the first sample:
       fixed sentinels such as 999999 ms would cap the reported minimum
       of a very slow run. */
    mincon = maxcon = stats[0].ctime;
    minfread = maxfread = stats[0].rtime;
    mintot = maxtot = stats[0].time;
    for (k = 0; k < requests; k++) {
      struct data s = stats[k];
      mincon = min(mincon, s.ctime);
      minfread = min(minfread, s.rtime);
      mintot = min(mintot, s.time);
      maxcon = max(maxcon, s.ctime);
      maxfread = max(maxfread, s.rtime);
      maxtot = max(maxtot, s.time);
      totalcon += s.ctime;
      totalfread += s.rtime;
      total += s.time;
    }
    printf("               min   avg   max\n");
    printf("TimeToConnect:%5d %5d %5d\n", mincon, totalcon / requests, maxcon);
    printf("Latency:      %5d %5d %5d\n", minfread, totalfread / requests, 
	   maxfread);
    printf("ResponseTime: %5d %5d %5d\n", mintot, total / requests, maxtot);
}

/* ------------------------------------------------------- */

/* getopt() state variables defined by the C library. */
extern char *optarg;
extern int optind;
extern int opterr;

/* Entry point.  Parses the command line (see usage()), builds the HTTP
 * request sent on every connection, runs the benchmark and prints the
 * statistics.  Invalid arguments terminate the program through err(). */
int main(int argc, char **argv)
{
  int c;
  int nset = 0;           /* non-zero once -n was given explicitly */
  const char *path;       /* filename without its leading '/' */
  int len;

  /* determine what are the command-line args */
  optind = 1;
  while ((c = getopt(argc, argv, "n:c:t:p:f:vh")) != -1) {
    switch (c) {
    case 'n':
      requests = atoi(optarg);
      if (requests <= 0) 
	err("Invalid number of requests\n");
      nset = 1;
      break;
    case 'c':
      concurrency = atoi(optarg);
      if (concurrency <= 0)
	err("Invalid number of concurrent connections\n");
      break;
    case 't':
      tlimit = 1000*atoi(optarg);
      if (tlimit < 0)
	err("Invalid time limit\n");
      break;
    case 'p':
      port = atoi(optarg);
      if (port <= 0 || port > 65535)
	err("Invalid port\n");
      break;
    case 'f':
      if (strlen(optarg) >= sizeof(filename))
	err("File name too long\n");
      strcpy(filename, optarg);
      break;
    case 'v':
      writeallout = 1;
      break;
    case 'h':
      usage(argv[0]);
      return 0;
    default:
      /* getopt() has already reported the offending option on stderr */
      usage(argv[0]);
      return EXIT_FAILURE;
    }
  }

  /* With -t but without -n, run until the time limit instead of stopping
     after the default single request.  10000 also sizes the stats array. */
  if (tlimit && !nset)
    requests = 10000;

  if (optind == argc) { 
    if (gethostname(hostname, sizeof(hostname)) < 0) {
      perror("gethostname fails");
      exit(1);
    }
    /* POSIX does not guarantee NUL-termination when the name is truncated */
    hostname[sizeof(hostname) - 1] = '\0';
  } else {
    if (strlen(argv[optind]) >= sizeof(hostname))
      err("Host name too long\n");
    strcpy(hostname, argv[optind]);
  }

  /* setup request message.  The format already supplies the leading '/',
     so skip one in filename; otherwise the default "/" would become "//". */
  path = (filename[0] == '/') ? filename + 1 : filename;
  len = snprintf(request, sizeof(request), "GET /%s HTTP/1.0\r\n"
	  "User-Agent: TCPServerTester\r\n"
	  "Host: %s\r\n"
	  "Accept: */*\r\n"
	  "\r\n",
	  path, hostname);
  if (len < 0 || (size_t) len >= sizeof(request))
    err("Request too long\n");
  reqlen = len;
  
  test();
  output_results();  
  exit(0);
}
