/******************************************************************
 Copyright (c), 2014-2024, T&W ELECTRONICS(SHENTHEN) Co., Ltd.

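 File name: Dnshijack.c
 Description: answers DNS queries that the kernel forwards over netlink
              with a reply pointing at the LAN address read from /var/lan_ip.

 Modification history:
        1. Author: wuyouhui, wuyouhui@twsz.com
           Date: 20140418
           Change: Create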

******************************************************************/
 
#include <unistd.h>     
#include <stdio.h>     
#include <stdlib.h>  
#include <string.h>  

#include <sys/stat.h>    
#include <sys/socket.h>     
#include <sys/types.h>        
#include <asm/types.h> 
#include <arpa/inet.h>

#include <linux/netlink.h>     
#include <linux/socket.h>   
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/tcp.h>


#include <stddef.h>  
#include <errno.h>
#include "debug.h"

typedef u_int16_t uint16;
typedef u_int32_t uint32;

/******************************************************************************
*                                 MACRO                                      *
******************************************************************************/

/*
 * Wire-format helpers.  DNS fields frequently sit at odd offsets inside a
 * packet, so bytes are moved with memcpy instead of through a (uint16 *) or
 * (uint32 *) cast: the cast form performed unaligned accesses (a fault on
 * strict-alignment CPUs) and violated strict aliasing.  Every macro is
 * wrapped in do { } while (0) so it behaves as exactly one statement, even
 * as the body of an unbraced if/else.
 *
 * SET_UINT16 / SET_UINT32:
 *   read a big-endian field at *buff into num (host byte order) and advance
 *   *buff past the field.  buff is a pointer to the read cursor.
 * SET_UINT16_TO_N / SET_UINT32_TO_N:
 *   write val at buf in big-endian order, advance buf past it and add the
 *   field size to count.
 */
#define SET_UINT16(num, buff) \
	do { \
		uint16 v16_; \
		memcpy(&v16_, *(buff), sizeof(v16_)); \
		(num) = ntohs(v16_); \
		*(buff) += sizeof(v16_); \
	} while (0)

#define SET_UINT32(num, buff) \
	do { \
		uint32 v32_; \
		memcpy(&v32_, *(buff), sizeof(v32_)); \
		(num) = ntohl(v32_); \
		*(buff) += sizeof(v32_); \
	} while (0)

#define SET_UINT16_TO_N(buf, val, count) \
	do { \
		uint16 v16_ = htons(val); \
		memcpy((buf), &v16_, sizeof(v16_)); \
		(count) += 2; \
		(buf) += 2; \
	} while (0)

#define SET_UINT32_TO_N(buf, val, count) \
	do { \
		uint32 v32_ = htonl(val); \
		memcpy((buf), &v32_, sizeof(v32_)); \
		(count) += 4; \
		(buf) += 4; \
	} while (0)



#define MAX_MSG_LEN 4096  
#define NAME_SIZE 255
#define MAX_PACKET_SIZE 512

#define NUM_RRS 5
#define PORT 53
#define TCP  6
#define UDP  17


/******************************************************************************
*                                STRUCT                                      *
******************************************************************************/

/* Resource-record TYPE values (RFC 1035 section 3.2.2, plus AAAA). */
enum {
	A     = 1,    /* a host address */
	NS    = 2,    /* an authoritative name server */
	MD    = 3,    /* a mail destination (obsolete - use MX) */
	MF    = 4,    /* a mail forwarder (obsolete - use MX) */
	CNAME = 5,    /* the canonical name for an alias */
	SOA   = 6,    /* marks the start of a zone of authority */
	MB    = 7,    /* a mailbox domain name (experimental) */
	MG    = 8,    /* a mail group member (experimental) */
	MR    = 9,    /* a mail rename domain name (experimental) */
	NUL   = 10,   /* a null RR (experimental) */
	WKS   = 11,   /* a well known service description */
	PTR   = 12,   /* a domain name pointer */
	HINFO = 13,   /* host information */
	MINFO = 14,   /* mailbox or mail list information */
	MX    = 15,   /* mail exchange */
	TXT   = 16,   /* text strings */

	AAA   = 0x1c  /* IPv6 host address (the AAAA type, RFC 3596) */
};

/* CLASS values (RFC 1035 section 3.2.4). */
enum {
	IN = 1,  /* the Internet */
	CS = 2,  /* CSNET (obsolete) */
	CH = 3,  /* CHAOS */
	HS = 4   /* Hesiod */
};

/* OPCODE values (RFC 1035 section 4.1.1). */
enum {
	QUERY  = 0,  /* standard query */
	IQUERY = 1,  /* inverse query */
	STATUS = 2   /* server status request */
};


/*
 * One resource record or question entry, as filled in by dns_decode_rr().
 * Numeric fields are kept in host byte order (SET_UINT16 / SET_UINT32 convert
 * them while decoding).
 */
struct dns_rr{// decoded RR / question entry
  char name[NAME_SIZE];  /* dotted domain name from dns_decode_name(); reversed by dns_decode_reverse_name() for PTR */
  uint16 type;           /* RR TYPE, e.g. A or PTR */
  uint16 class;          /* RR CLASS, e.g. IN */
  uint32 ttl;            /* only decoded for non-question records */
  uint16 rdatalen;       /* RDLENGTH exactly as read from the packet */
  char data[NAME_SIZE];  /* raw RDATA bytes; NOTE(review): only NAME_SIZE bytes fit, so copies must be bounded by that, not by rdatalen */
};
/*****************************************************************************/
/*
 * DNS header flags word (RFC 1035 section 4.1.1).  `flags` is the whole
 * 16-bit value in host byte order; `f` exposes the individual fields.  The
 * bit-field declaration order is mirrored per bit-field endianness so that
 * `question` (QR) is always the most significant bit of `flags` and `rcode`
 * the four least significant bits, matching the wire layout.
 */
union header_flags {
  uint16 flags;
  struct {
#if defined(__LITTLE_ENDIAN_BITFIELD)
    unsigned short int rcode:4;           /* RCODE: response code */
    unsigned short int unused:3;          /* Z: reserved */
    unsigned short int recursion_avail:1; /* RA: recursion available */
    unsigned short int want_recursion:1;  /* RD: recursion desired */
    unsigned short int truncated:1;       /* TC: message truncated */
    unsigned short int authorative:1;     /* AA: authoritative answer */
    unsigned short int opcode:4;          /* OPCODE: QUERY, IQUERY or STATUS */
    unsigned short int question:1;        /* QR: despite the name, set to 1 in replies */
#elif  defined(__BIG_ENDIAN_BITFIELD)
    unsigned short int question:1;        /* QR: despite the name, set to 1 in replies */
    unsigned short int opcode:4;          /* OPCODE: QUERY, IQUERY or STATUS */
    unsigned short int authorative:1;     /* AA: authoritative answer */
    unsigned short int truncated:1;       /* TC: message truncated */
    unsigned short int want_recursion:1;  /* RD: recursion desired */
    unsigned short int recursion_avail:1; /* RA: recursion available */
    unsigned short int unused:3;          /* Z: reserved */
    unsigned short int rcode:4;           /* RCODE: response code */
#else
#error  "Adjust your <asm/byteorder.h> defines"
#endif

  } f;
};

/* Decoded copy of the fixed 12-byte DNS message header (RFC 1035 section
 * 4.1.1); all fields are in host byte order. */
struct dns_header_s{
  uint16 id;                 /* query ID; dns_construct_header() writes it back unchanged */
  union header_flags flags;
  uint16 qdcount;            /* number of question entries */
  uint16 ancount;            /* number of answer records */
  uint16 nscount;            /* number of authority records */
  uint16 arcount;            /* number of additional records */
};

/* Decoded DNS message: the header plus at most NUM_RRS questions and NUM_RRS
 * answers (entries beyond that are not decoded). */
struct dns_message{
  struct dns_header_s header;
  struct dns_rr question[NUM_RRS];
  struct dns_rr answer[NUM_RRS];
};

/* Everything known about one intercepted DNS request.  The reply is built in
 * place in original_buf by dns_construct_reply(). */
typedef struct dns_request_s{
  char cname[NAME_SIZE];  /* name of question[0]; netlink_get_recv_info() copies it only for A/AAAA questions, dns_construct_reply() encodes it as the PTR answer */
  char ip[20];            /* source IPv4 address as a dotted-quad string */
  int cache;              /* not referenced in the code visible here */
  int ttl;                /* not referenced in the code visible here */
  int time_pending; /* request age in seconds (not referenced in the code visible here) */
  int l4protocol;         /* IP protocol number of the request; only UDP is accepted */

  /* the actual dns request that was received, decoded */
  struct dns_message message;

  /* where the request came from */
  struct in_addr src_addr;  /* source IPv4 address, network byte order */
  int src_port;             /* source UDP port, host byte order */

  /* the original packet */
  char original_buf[MAX_PACKET_SIZE]; /* DNS payload of the request; the reply is written over it */
  int numread;              /* length of the DNS message in original_buf (request, then reply) */
  char *here;               /* write cursor used by dns_construct_reply() */
}dns_request_t;

/*
 * Payload of a netlink message received from the kernel (see netlink_recv()).
 * NOTE(review): this layout presumably has to match what the kernel-side
 * sender writes into the netlink payload; confirm before changing it.
 */
struct nlmsg{  
    size_t data_len;               /* number of valid bytes in data[] */  
    char saddr[6];                /* not read in this file; 6 bytes suggests a source MAC address (TODO confirm) */  
    unsigned char data[MAX_MSG_LEN];  /* captured packet, starting at its IPv4 header */  
};

/******************************************************************
 Description: open the UDP socket used to send DNS replies. It is bound
              to port PORT (53) on all addresses, so replies sent from it
              carry source port 53.
 Input: void
 Return: the socket descriptor on success, -1 on failure
 Note: binding to port 53 normally requires privileges, and fails if
       another process already owns the port.
******************************************************************/
int dns_sock_open(void)
{
	struct sockaddr_in sa;
	int sock;

	sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);	/* DNS replies go out as UDP datagrams */
	if (sock < 0) {
		printf("Could not create dns socket\n");
		return -1;
	}

	memset(&sa, 0, sizeof(sa));
	sa.sin_family = AF_INET;
	sa.sin_addr.s_addr = htonl(INADDR_ANY);
	sa.sin_port = htons(PORT);	/* dns service port is 53 */

	if (bind(sock, (struct sockaddr *)&sa, sizeof(sa)) < 0) {
		/* the old message had no newline, so it could sit unflushed in stdout */
		printf("dns_init: bind: Could not bind to port %d: %s\n", PORT, strerror(errno));
		close(sock);
		return -1;
	}

	return sock;
}


/******************************************************************
 Description: serialize m->message.header into the first 12 bytes of
              m->original_buf in network byte order.
 Input: m -> request filled in by netlink_get_recv_info(); its header
             fields are in host byte order
 Return: always 0
******************************************************************/
int dns_construct_header(dns_request_t *m)
{
  char *ptr = m->original_buf;
  /* Byte counter required by SET_UINT16_TO_N. It must start initialized:
     the old code incremented an uninitialized local (undefined behavior). */
  int written = 0;

  SET_UINT16_TO_N( ptr, m->message.header.id, written );
  SET_UINT16_TO_N( ptr, m->message.header.flags.flags, written );
  SET_UINT16_TO_N( ptr, m->message.header.qdcount, written );
  SET_UINT16_TO_N( ptr, m->message.header.ancount, written );
  SET_UINT16_TO_N( ptr, m->message.header.nscount, written );
  SET_UINT16_TO_N( ptr, m->message.header.arcount, written );

  (void)written;
  return 0;
}


/******************************************************************
 Description: encode a dotted name ("www.example.com") into DNS label
              format (\3www\7example\3com\0), the inverse of
              dns_decode_name().
 Input: name -> NUL-terminated dotted name
 Output: encoded_name -> receives the encoding; the caller must provide
                         room for strlen(name) + 2 bytes
 Return: number of bytes written to encoded_name, terminating zero
         label included
 Note: label lengths are not checked against the 63-byte DNS limit.
******************************************************************/
int dns_construct_name(char *name, char *encoded_name)
{
  const char *cursor = name;
  int written = 0;

  while (*cursor != '\0') {
    int label_len = 0;

    /* length of the label that ends at the next '.' or at the string end */
    while (cursor[label_len] != '\0' && cursor[label_len] != '.')
      label_len++;

    encoded_name[written++] = (char)label_len;
    memcpy(&encoded_name[written], cursor, label_len);
    written += label_len;

    if (cursor[label_len] == '\0')
      break;                 /* that was the last label */
    cursor += label_len + 1; /* step over the '.' */
  }

  encoded_name[written++] = '\0';  /* root label */
  return written;
}

/******************************************************************
 Description: helpers for dns_construct_reply().
******************************************************************/

/* Offset just past the first question entry (name, QTYPE, QCLASS) in
 * pkt[0..len), or -1 if that entry does not fit inside len bytes. */
static int dns_question_end(const char *pkt, int len)
{
  int off = 12; /* the question section follows the 12-byte header */

  for (;;) {
    unsigned char label;

    if (off >= len)
      return -1;
    label = (unsigned char)pkt[off];
    if (label == 0) {
      off += 1;             /* root label ends the name */
      break;
    }
    if ((label & 0xC0) == 0xC0) {
      off += 2;             /* a compression pointer ends the name */
      break;
    }
    off += 1 + label;
  }

  off += 4;                 /* QTYPE + QCLASS */
  return off <= len ? off : -1;
}

/* LAN address from /var/lan_ip, or 192.168.1.1 when the file is missing,
 * too short or does not parse as an IPv4 address. */
static struct in_addr dns_lan_addr(void)
{
  struct in_addr in;
  char sztmp[64] = { 0 };
  FILE *fp = fopen("/var/lan_ip", "r");
  int ok = 0;

  if (fp != NULL) {
    if (fgets(sztmp, sizeof(sztmp), fp) != NULL && strlen(sztmp) >= 8) {
      sztmp[strcspn(sztmp, "\r\n")] = '\0';  /* fgets keeps the newline */
      ok = inet_aton(sztmp, &in);
    }
    fclose(fp);
  }
  if (!ok)
    inet_aton("192.168.1.1", &in);
  return in;
}

/******************************************************************
 Description: turn the request held in m->original_buf into a reply, in
              place. The header and the first question are kept and
              everything after that question is dropped. For example, an
              EDNS OPT record in the additional section must not end up in
              front of the answer. For an A or PTR question one answer
              record is appended:
                A   -> the LAN address (see dns_lan_addr), TTL 0
                PTR -> m->cname encoded as a domain name, TTL 10000
              Any other question type, an answer that would not fit in
              original_buf, or an unparsable question gives a reply with
              ancount 0.
 Input: m -> request filled in by netlink_get_recv_info()
 Output: m->original_buf / m->numread hold the reply to send
 Return: none
 Note: netlink_get_recv_info() only copies the question name into
       m->cname for A/AAAA questions, so for a PTR question cname holds
       whatever an earlier A/AAAA query left there.
       NOTE(review): confirm that this is intended.
******************************************************************/
void dns_construct_reply( dns_request_t *m )
{
  int max = (int)sizeof(m->original_buf);
  int end = -1;
  int answered = 0;

  /* original_buf never holds more of the request than its own size */
  if (m->numread > max)
    m->numread = max;
  if (m->numread < 0)
    m->numread = 0;
  m->here = &m->original_buf[m->numread];

  if (m->message.header.qdcount > 0)
    end = dns_question_end(m->original_buf, m->numread);

  if (end > 0) {
    /* keep exactly the header and the first question */
    m->numread = end;
    m->here = &m->original_buf[m->numread];
    m->message.header.qdcount = 1;
    m->message.header.nscount = 0;
    m->message.header.arcount = 0;

    if (m->message.question[0].type == A) {
      /* 16 = name pointer, type, class, ttl, rdlength + 4-byte address */
      if (m->numread + 16 <= max) {
        struct in_addr in = dns_lan_addr();

        SET_UINT16_TO_N( m->here, 0xc00c, m->numread ); /* pointer to the question name */
        SET_UINT16_TO_N( m->here, A, m->numread );      /* type */
        SET_UINT16_TO_N( m->here, IN, m->numread );     /* class */
        SET_UINT32_TO_N( m->here, 0, m->numread );      /* ttl */
        SET_UINT16_TO_N( m->here, 4, m->numread );      /* datalen */
        memcpy( m->here, &in.s_addr, sizeof(in.s_addr) ); /* data */
        m->numread += sizeof(in.s_addr);
        answered = 1;
      }
    } else if (m->message.question[0].type == PTR) {
      /* cname might not be NUL-terminated; never run off its end */
      const char *nul = memchr(m->cname, '\0', sizeof(m->cname));

      /* 12 fixed answer bytes + encoded name (at most strlen + 2 bytes) */
      if (nul != NULL && m->numread + 14 + (int)(nul - m->cname) <= max) {
        int len;

        SET_UINT16_TO_N( m->here, 0xc00c, m->numread ); /* pointer to the question name */
        SET_UINT16_TO_N( m->here, PTR, m->numread );    /* type */
        SET_UINT16_TO_N( m->here, IN, m->numread );     /* class */
        SET_UINT32_TO_N( m->here, 10000, m->numread );  /* ttl */
        /* encode the name after the 2-byte datalen field, then fill in datalen */
        len = dns_construct_name( m->cname, m->here + 2 );
        SET_UINT16_TO_N( m->here, len, m->numread );    /* datalen */
        m->numread += len;
        answered = 1;
      }
    }
  }

  m->message.header.ancount = answered ? 1 : 0;
  m->message.header.flags.f.question = 1;  /* QR: this is a response */
  dns_construct_header( m );
}

/******************************************************************
 Description: send the reply held in m->original_buf (m->numread bytes)
              as one UDP datagram.
 Input: sock -> socket to send from (see dns_sock_open)
        in   -> destination IPv4 address
        port -> destination UDP port, host byte order
        m    -> request/reply buffer
 Return: the sendto() result, i.e. the number of bytes sent or a negative
         value on failure
******************************************************************/
int dns_write_packet(int sock, struct in_addr in, int port, dns_request_t *m)
{
	struct sockaddr_in dest;
	int sent;

	memset(&dest, 0, sizeof dest);
	dest.sin_family = AF_INET;
	dest.sin_port = htons(port);
	dest.sin_addr = in;

	sent = sendto(sock, m->original_buf, m->numread, 0,
	              (struct sockaddr *)&dest, sizeof dest);
	if (sent < 0)
		printf("dns_write_packet: sendto failed!\n");

	return sent;
}

/******************************************************************
 Description: decode a DNS wire-format name (\3www\7example\3com\0) into
              dotted form ("www.example.com").
 Output: name -> receives the NUL-terminated dotted name. The caller must
                 provide 255 bytes (NAME_SIZE); longer names are truncated
                 to 254 characters.
 Input/Output: buf -> cursor into the packet; advanced past the encoded
                      name, including its terminator
 Return: none
 Note: a root name (a lone zero byte) yields "". A compression pointer
       ends the name: its two bytes are consumed but its target is not
       followed. Reads from *buf are not bounded by the packet length.
******************************************************************/
void dns_decode_name(char *name, char **buf)
{
    enum { NAME_CAP = 255 };  /* size of the destination buffers (NAME_SIZE) */
    int k = 0;
    int j;
    unsigned char len;

    for (;;) {
        len = (unsigned char)*(*buf)++;
        if (len == 0)
            break;                    /* root label: end of the name */
        if ((len & 0xC0) == 0xC0) {
            (*buf)++;                 /* second byte of the compression pointer */
            break;
        }

        /* separator before every label but the first */
        if (k > 0 && k < NAME_CAP - 1)
            name[k++] = '.';

        /* always consume the whole label, but only store what fits */
        for (j = 0; j < len; j++) {
            char c = *(*buf)++;
            if (k < NAME_CAP - 1)
                name[k++] = c;
        }
    }

    /* the old code wrote name[k - 1] here, i.e. name[-1] for a root name */
    name[k] = '\0';
}

/******************************************************************
 Description: helper for dns_decode_reverse_name(): reverse the n bytes
              at s in place.
******************************************************************/
static void dns_reverse_span(char *s, size_t n)
{
  size_t lo = 0;
  size_t hi;

  if (n < 2)
    return;
  hi = n - 1;
  while (lo < hi) {
    char t = s[lo];
    s[lo++] = s[hi];
    s[hi--] = t;
  }
}

/******************************************************************
 Description: reverse the order of the first four '.'-separated tokens
              of name, in place. "3.1.168.192" becomes "192.168.1.3", and
              "3.1.168.192.in-addr.arpa" also becomes "192.168.1.3"
              because tokens after the fourth are dropped. Empty tokens
              are skipped, so "a..b" becomes "b.a". A name with no tokens
              becomes "" (the old strtok version passed NULL to strcpy
              here).
 Input/Output: name -> NUL-terminated string, rewritten in place
 Return: none
******************************************************************/
void dns_decode_reverse_name(char *name)
{
  size_t rd = 0;
  size_t wr = 0;
  size_t start;
  int tokens = 0;

  /* 1. Compact the first four tokens into "t0.t1.t2.t3". The write index
        never passes the read index, so this is safe in place. */
  while (name[rd] != '\0' && tokens < 4) {
    if (name[rd] == '.') {
      rd++;
      continue;
    }
    if (tokens > 0)
      name[wr++] = '.';
    while (name[rd] != '\0' && name[rd] != '.')
      name[wr++] = name[rd++];
    tokens++;
  }
  name[wr] = '\0';

  /* 2. Reverse the whole string: the token order flips, but each token
        comes out backwards... */
  dns_reverse_span(name, wr);

  /* 3. ...so reverse every token again to restore it. */
  start = 0;
  for (rd = 0; rd <= wr; rd++) {
    if (name[rd] == '.' || name[rd] == '\0') {
      dns_reverse_span(name + start, rd - start);
      start = rd + 1;
    }
  }
}

/******************************************************************
 Description: decode one resource record (or question entry) that starts
              at *buf into rr, and advance *buf past it.
 Output: rr -> decoded fields, host byte order. For PTR records the name
               is passed through dns_decode_reverse_name().
 Input/Output: buf -> cursor into the DNS message
 Input: is_question -> 1 for a question entry (name, type and class
                       only); any other value also decodes ttl, rdlength
                       and rdata
        header -> start of the DNS message; base for compression pointers
 Return: none
******************************************************************/
void dns_decode_rr(struct dns_rr *rr, char **buf, int is_question,char *header)
{
    unsigned char first = (unsigned char)**buf;

    /* Both top bits set: a compression pointer whose low 14 bits are an
       offset from the start of the message. The old code tested "either
       bit", used only the second byte, and sign-extended it. */
    if ((first & 0xC0) == 0xC0) {
        unsigned int offset = ((unsigned int)(first & 0x3F) << 8) |
                              (unsigned char)(*buf)[1];
        char *target = header + offset;

        *buf += 2;
        dns_decode_name( rr->name, &target );
    } else {
        /* ordinary decode name */
        dns_decode_name( rr->name, buf );
    }

    SET_UINT16( rr->type, buf );
    SET_UINT16( rr->class, buf );

    if( is_question != 1 ){
        size_t copy;

        SET_UINT32( rr->ttl, buf );
        SET_UINT16( rr->rdatalen, buf );

        /* rdlength comes from the packet: never copy more than rr->data
           holds, but still skip the full rdata so *buf stays in sync */
        copy = rr->rdatalen < sizeof(rr->data) ? rr->rdatalen : sizeof(rr->data);
        memcpy( rr->data, *buf, copy );
        *buf += rr->rdatalen;
    }

    if( rr->type == PTR ){ /* reverse lookup */
        dns_decode_reverse_name( rr->name );
    }
}

/* Convert a 1-based netlink multicast group number into the bit mask that
 * sockaddr_nl.nl_groups expects. Group 0 means "no group" and maps to 0. */
static int netlink_group_mask(int group)
{
	if (group == 0)
		return 0;

	return 1 << (group - 1);
}

/******************************************************************
 Description: open a NETLINK_NFLOG socket and bind it to multicast group
              5, through which the kernel delivers the intercepted DNS
              requests.
 Input: none
 Return: the socket descriptor on success, -1 on failure
 Note: NOTE(review): group 5 presumably has to match the group the kernel
       side sends on; confirm before changing it.
******************************************************************/
int netlink_open(void)
{
    struct sockaddr_nl saddr;
    int sockfd;

    sockfd = socket(AF_NETLINK, SOCK_RAW, NETLINK_NFLOG);
    /* socket() returns -1 on failure; the old "< -1" test could never fire */
    if (sockfd < 0)
    {
        LOG(LOG_INFO,"create socket failed!\n");
        return -1;
    }

    memset(&saddr, 0, sizeof(saddr));
    saddr.nl_family = AF_NETLINK;
    saddr.nl_pid = getpid();                  /* address this socket by our pid */
    saddr.nl_groups = netlink_group_mask(5);  /* group number 5 -> bit mask */

    if (bind(sockfd, (struct sockaddr *)&saddr, sizeof(saddr)) < 0)
    {
        LOG(LOG_INFO,"bind socket failed!\n");
        close(sockfd);
        return -1;
    }

    return sockfd;
}

/******************************************************************
: this function receive info from kernel
: NULL
: -1:failed 		0:success
ע:
******************************************************************/
int netlink_recv(int sockfd, struct nlmsg *pmsg)
{
    struct msghdr msg;
    struct iovec iov;
    struct nlmsghdr *nlh = NULL;
   
    int msglen = sizeof(*pmsg);
    int totlen = NLMSG_SPACE(sizeof(*pmsg));
    int ret = 0;

    nlh = malloc(totlen);
    if(!nlh)
    {
        printf("malloc call failed!\n");
        return -1;
    }

    memset(nlh, 0, totlen);
    iov.iov_base = (void*)nlh;
    iov.iov_len = totlen;

    memset(&msg, 0, sizeof(msg));
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;

    memcpy(NLMSG_DATA(nlh), pmsg, msglen);
    ret = recvmsg(sockfd, &msg, 0);
    if(ret < 0)
    {
        printf("recvmsg filed!\n");
        free(nlh);
        nlh = NULL;
        return -1;
    }

   // LOG(LOG_INFO, "[%s:%d]len=%d\n pmsg: %0x\n", __FUNCTION__, __LINE__, ret, NLMSG_DATA(nlh));
    memcpy(pmsg, NLMSG_DATA(nlh), msglen);

    if (NULL != nlh) //Bug: forget to free the nlh, memory leak
    {
	    free(nlh);
        nlh = NULL;
    }
    return 0;
}

/******************************************************************
: this function decode the info from the dns request packet which get from 
             kernel
: NULL
: -1:failed 		0:success
ע:
******************************************************************/
int  netlink_get_recv_info(struct dns_request_s *m, struct nlmsg *pmsg)
{
	struct iphdr *iph = NULL;
	struct udphdr *udph = NULL;
	struct tcphdr  *tcph = NULL;
	struct in_addr   stInAddr;

	unsigned char *header_start = NULL;
	unsigned char **buf = NULL;
	int i;

	if (pmsg->data_len <= 0)
	{
		LOG(LOG_INFO, "[%s:%d]:Receive data can not be NULL!\n", __FUNCTION__, __LINE__);
		return -1;
	}

	//get dns request packet's src ip address
	iph = (struct iphdr *) pmsg->data;
    	stInAddr.s_addr = iph->saddr;
	strncpy( m->ip, inet_ntoa(stInAddr), 20 );
	memcpy( (void *)&m->src_addr, (void *)&stInAddr, sizeof(struct in_addr));
    	//LOG(LOG_INFO, "m->ip=%s \n", m->ip);

	//dns request message must be UDP packet
	if (iph->protocol == UDP)
	{
		udph = (struct udphdr *)(((unsigned char *)iph) + (iph->ihl<<2)); //skip ip header
		m->src_port = ntohs(udph->source); //get the dns request packet's sorce port info
		m->l4protocol = UDP;
		header_start = (unsigned char *)(((unsigned char *)udph) + 8);
	}
	else
	{
		LOG(LOG_INFO, "[%s:%d]: It's not UDP packet, ignored!\n", __FUNCTION__, __LINE__);
		return -1;
	}
	//LOG(LOG_INFO, "m->src_port=%d\n", m->src_port);

	//get the layer 4's data (dns data) length, and then save as original data
	m->numread = (pmsg->data_len-(header_start-pmsg->data));
	memcpy(m->original_buf, header_start, m->numread<MAX_PACKET_SIZE ? m->numread : MAX_PACKET_SIZE);

	//init the dns request packet's dns header data
	buf = &header_start;
	SET_UINT16( m->message.header.id, buf );
    	SET_UINT16( m->message.header.flags.flags,buf );
    	SET_UINT16( m->message.header.qdcount, buf );
    	SET_UINT16( m->message.header.ancount, buf );
    	SET_UINT16( m->message.header.nscount, buf );
    	SET_UINT16( m->message.header.arcount, buf );
	//LOG(LOG_INFO,"m->header.id=%d\n m->header.flags.flags=%d\n m->header.qdcount=%d\n m->header.ancount=%d\n m->header.nscount=%d\n m->header.arcount=%d\n", 
	//	m->message.header.id, m->message.header.flags.flags, m->message.header.qdcount, m->message.header.ancount, m->message.header.nscount, m->message.header.arcount);

	/* decode all the question rrs */
    	for( i = 0; i < m->message.header.qdcount && i < NUM_RRS; i++){
        	dns_decode_rr( &m->message.question[i], buf, 1, header_start );
    	}  
    	/* decode all the answer rrs */
    	for( i = 0; i < m->message.header.ancount && i < NUM_RRS; i++){
        	dns_decode_rr( &m->message.answer[i], buf, 0, header_start );
    	}  

	if ( m->message.question[0].type == A || 
            m->message.question[0].type == AAA){ 
        strncpy( m->cname, m->message.question[0].name, NAME_SIZE ); // get dns request packet's domain name
    }
	//LOG(LOG_INFO, "m->cname=%s\n", m->cname);

	return 0;
}

/******************************************************************
 Description: close a socket opened by this program.
 Input: sockfd -> descriptor to close; negative values (e.g. the -1 used
                  for "never opened") are ignored
 Return: none
 Note: 0 is a valid descriptor, so the test is ">= 0". The old "> 0"
       test never closed descriptor 0.
******************************************************************/
void sock_close(int sockfd)
{
    if (sockfd >= 0)
        close(sockfd);
}

/*
 * Entry point: open the netlink and DNS sockets, then forever receive a
 * forwarded request, build the reply and send it back to the requester.
 * Only returns (with -1) when setup fails or netlink_recv() fails; both
 * sockets are closed on every exit path.
 */
int main(int argc, char* argv[])
{
    int sockfd = -1;    /* socket for netlink */
    int dns_sock = -1;  /* socket for dns */
    struct nlmsg msg;
    struct dns_request_s m;

    (void)argc;
    (void)argv;

    memset(&msg, 0, sizeof(msg));
    /* m carries state between requests (e.g. cname is only set for A/AAAA
       queries), so it must not start out as stack garbage */
    memset(&m, 0, sizeof(m));

    sockfd = netlink_open();
    if (sockfd < 0)
    {
        fprintf(stderr, "netlink_open failed!\n");
        goto out;
    }

    dns_sock = dns_sock_open();
    if (dns_sock < 0)
    {
        fprintf(stderr, "dns_sock open failed!\n");
        goto out;
    }

    for (;;)
    {
        /* get dns request info from netlink socket */
        if (netlink_recv(sockfd, &msg) < 0)
        {
            LOG(LOG_INFO, "[%s:%d]\n", __FUNCTION__, __LINE__);
            fprintf(stderr, "netlink_recv failed!\n");
            goto out;
        }

        /* init the info based on the received dns request packet */
        if (!netlink_get_recv_info(&m, &msg))
        {
            dns_construct_reply(&m);
            dns_write_packet(dns_sock, m.src_addr, m.src_port, &m);
        }
    }

out:
    sock_close(sockfd);
    sock_close(dns_sock);
    return -1;
}
