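/* tcpclient.c - Measures the data rate of a TCP connection as a
   function of the socket buffer size.
   Compile with
      % gcc -o tcpclient tcpclient.c -lm
   And run with
      % ./tcpclient hostname
   The server on "hostname" must be listening on TCP port 7621.
   Results are printed to standard output and also written to log.dat.
***************************************************************/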

#include <sys/types.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>
#include <math.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Benchmark parameters.  Every expansion is parenthesized so the macros
   behave as single values inside larger expressions. */
#define SERV_TCP_PORT 7621      /* TCP port the client connects to */
#define MIN_SIZE (16*1024)      /* first socket buffer size tried, in bytes */
#define MAX_SIZE (1024*1024)    /* bytes sent per timed transfer; also the I/O buffer size */
#define NSIZES 6                /* buffer sizes tried: MIN_SIZE, 2*MIN_SIZE, ... (doubling) */
#define N_ITERATIONS 9          /* timed transfers per buffer size */

/* Signed number of microseconds from time1 to time2. */
long int time_elapsed(struct timeval time1, struct timeval time2);

/* Handler for SIGPIPE, which is raised when this process writes to a socket
   (or pipe) whose reading end has been closed.  Each signal is reported on
   standard output; on the tenth one the process gives up and exits with
   status 1.

   Output goes through write(2) because printf/perror are not
   async-signal-safe: a SIGPIPE caused by stdout itself being a closed pipe
   would otherwise re-enter stdio.
   NOTE(review): exit() is not async-signal-safe either; it is kept rather
   than _exit() so that buffered output (e.g. log.dat) is flushed. */
void sig_pipe (int n)
{
   static volatile sig_atomic_t signal_count = 0;
   static const char notice[] = "SIGPIPE Signal\n";
   static const char bail[] = "Exceeded count on pipe signals. Bailing out\n";

   (void)n;  /* only installed for SIGPIPE, so the number is not needed */
   (void)write(STDOUT_FILENO, notice, sizeof notice - 1);
   if (++signal_count == 10) {
     /* Not perror(): errno is not guaranteed to describe the failing
        write() while its signal is being handled. */
     (void)write(STDERR_FILENO, bail, sizeof bail - 1);
     exit(1);
   }
}

/* Create a TCP socket, request SO_SNDBUF and SO_RCVBUF of `bufsize` bytes,
   and connect it to `addr`.  Returns the connected descriptor, or -1 after
   reporting the failing call with perror() (the socket is closed). */
static int connect_with_bufsize(const struct sockaddr_in *addr, int bufsize)
{
    int s = socket(AF_INET, SOCK_STREAM, 0);

    if (s < 0) {
        perror("client: socket");
        return -1;
    }
    if (setsockopt(s, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof bufsize) < 0) {
        perror("client: setsockopt SO_SNDBUF");
        close(s);
        return -1;
    }
    if (setsockopt(s, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof bufsize) < 0) {
        perror("client: setsockopt SO_RCVBUF");
        close(s);
        return -1;
    }
    if (connect(s, (const struct sockaddr *)addr, sizeof *addr) < 0) {
        perror("client: can't connect to server");
        close(s);
        return -1;
    }
    return s;
}

/* Send exactly `total_bytes` bytes from the start of `buf` to `s`, asking
   for at most `chunk` bytes per write() call.  Short writes are continued
   and EINTR is retried.  Returns 0 on success, -1 (errno from write()) on
   failure.  `buf` must hold at least min(chunk, total_bytes) bytes. */
static int send_bytes(int s, const char *buf, int chunk, long total_bytes)
{
    long sent = 0;

    while (sent < total_bytes) {
        long remaining = total_bytes - sent;
        size_t request = (size_t)(remaining < chunk ? remaining : chunk);
        ssize_t n = write(s, buf, request);

        if (n < 0) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        sent += n;
    }
    return 0;
}

/* Block until the server sends something back or closes the connection
   (presumably an acknowledgement that the data arrived -- NOTE(review):
   confirm against the server).  EINTR is retried.  Returns the read()
   result: >0 bytes received, 0 on EOF, -1 on error. */
static ssize_t await_reply(int s, char *buf, size_t len)
{
    ssize_t n;

    do {
        n = read(s, buf, len);
    } while (n < 0 && errno == EINTR);
    return n;
}

/* Population mean and standard deviation of x[0..count-1]; count > 0. */
static void mean_and_stddev(const double *x, int count,
                            double *mean, double *stddev)
{
    double sum = 0.0;
    double sq = 0.0;

    for (int i = 0; i < count; ++i) {
        sum += x[i];
    }
    *mean = sum / count;
    for (int i = 0; i < count; ++i) {
        double d = x[i] - *mean;
        sq += d * d;
    }
    *stddev = sqrt(sq / count);
}

/* One timed transfer: connect with `bufsize`-byte socket buffers, send
   MAX_SIZE bytes, wait for the server's reply and close.  `buf` is the
   MAX_SIZE-byte I/O buffer.  On success stores the rate in Mbit/s in *mbps
   and returns 0; on failure prints a diagnostic and returns -1. */
static int measure_once(const struct sockaddr_in *addr, int bufsize,
                        char *buf, double *mbps)
{
    struct timeval time_start, time_end;
    size_t io_len = (size_t)(bufsize < MAX_SIZE ? bufsize : MAX_SIZE);
    int s = connect_with_bufsize(addr, bufsize);

    if (s < 0) {
        return -1;
    }
    if (gettimeofday(&time_start, NULL) != 0) {
        perror("client: gettimeofday");
        close(s);
        return -1;
    }
    if (send_bytes(s, buf, bufsize, MAX_SIZE) < 0) {
        perror("client: write");
        close(s);
        return -1;
    }
    if (await_reply(s, buf, io_len) < 0) {
        perror("client: read");
        close(s);
        return -1;
    }
    if (gettimeofday(&time_end, NULL) != 0) {
        perror("client: gettimeofday");
        close(s);
        return -1;
    }
    close(s);

    /* bits per microsecond == megabits per second */
    *mbps = MAX_SIZE * 8.0 / (double)time_elapsed(time_start, time_end);
    return 0;
}

/* Entry point: for NSIZES socket buffer sizes (MIN_SIZE doubling), run
   N_ITERATIONS timed MAX_SIZE-byte transfers to argv[1]:SERV_TCP_PORT and
   report each rate plus the per-size mean and standard deviation on stdout
   and in log.dat.  Returns 0 on success, nonzero on any failure. */
int main(int argc, char *argv[])
{
    struct sockaddr_in serv_addr;
    struct hostent *sp;
    static char buf[MAX_SIZE];
    double avgs[NSIZES], stddevs[NSIZES];
    FILE *log_fp;   /* not "log": that would shadow log() from <math.h> */
    int bufsize;
    int lcv;

    if (argc < 2) {
        fprintf(stderr, "usage: %s <server_name>\n", argv[0]);
        exit(1);
    }

    log_fp = fopen("log.dat", "w");
    if (log_fp == NULL) {
        perror("Cannot open log");
        exit(1);
    }

    /* Only an IPv4 address of the expected size can be copied into
       sin_addr safely. */
    sp = gethostbyname(argv[1]);
    if (sp == NULL || sp->h_addrtype != AF_INET ||
        sp->h_length != (int)sizeof serv_addr.sin_addr) {
        fprintf(stderr, "%s: host unknown.\n", argv[1]);
        fclose(log_fp);
        return EXIT_FAILURE;
    }

    if (signal(SIGPIPE, sig_pipe) == SIG_ERR) {
        perror("Unable to set up signal handler for SIGPIPE");
        exit(1);
    }

    /* The server address is the same for every connection. */
    memset(&serv_addr, 0, sizeof serv_addr);
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(SERV_TCP_PORT);
    memcpy(&serv_addr.sin_addr, sp->h_addr_list[0], sizeof serv_addr.sin_addr);

    for (lcv = 0, bufsize = MIN_SIZE; lcv < NSIZES; ++lcv, bufsize *= 2) {
        double bitRate[N_ITERATIONS];

        printf("bufsize = %d\n", bufsize);
        fprintf(log_fp, "bufsize = %d\n", bufsize);

        for (int iter = 0; iter < N_ITERATIONS; ++iter) {
            if (measure_once(&serv_addr, bufsize, buf, &bitRate[iter]) != 0) {
                fclose(log_fp);
                return EXIT_FAILURE;
            }
            fprintf(log_fp, "\tbandwidth is %f Mbps\n", bitRate[iter]);
            sleep(1);
        }

        mean_and_stddev(bitRate, N_ITERATIONS, &avgs[lcv], &stddevs[lcv]);
        printf("Average: %f Mbps, Standard Deviation: %f\n", avgs[lcv],
               stddevs[lcv]);
        fprintf(log_fp, "Average: %f Mbps, Standard Deviation: %f\n",
                avgs[lcv], stddevs[lcv]);
    }

    printf("TOTAL AVERAGES:\n");
    for (lcv = 0; lcv < NSIZES; ++lcv) {
        printf("\t%f", avgs[lcv]);
        fprintf(log_fp, "\t%f", avgs[lcv]);
    }
    printf("\n");
    fprintf(log_fp, "\n");

    if (fclose(log_fp) != 0) {
        perror("client: closing log.dat");
        return EXIT_FAILURE;
    }
    return 0;
}


/* Signed difference time2 - time1, expressed in microseconds. */
long int time_elapsed(struct timeval time1, struct timeval time2)
{
  time_t whole_seconds = time2.tv_sec - time1.tv_sec;
  long int micro_part = time2.tv_usec - time1.tv_usec;

  return whole_seconds * 1000000 + micro_part;
}

