/*
 * Name: Simple Example of log for bad packets
 * Date: Fri Feb 18 21:42:57 2000
 * Author: pIGpEN [pigpen@s0ftpj.org, deadhead@sikurezza.org]
 *
 * SoftProject Digital Security for Y2K (www.s0ftpj.org)
 * Sikurezza.org Italian Security MailingList (www.sikurezza.org)
 * 
 * COFFEE-WARE LICENSE - This source code is like "THE BEER-WARE LICENSE" by
 * Poul-Henning Kamp <phk@FreeBSD.ORG> but you can give me in return a coffee.
 * 
 * Tested on: FreeBSD 4.0-19990705-CURRENT FreeBSD 4.0-19990705-CURRENT #6 i386
 */

/*
 * The *stat structures only expose aggregate counters.  This module also
 * logs the source address of each packet that bumps one of the udpstat
 * error counters handled in udp_input().  It is only an example.
 *
 * Feb 10 10:11:13 storpio /kernel UDP: udps_badsum 4 from xxx.xxx.xxx.xxx
 *
 * The log line gives the source of a packet with a bad checksum together
 * with the current value of udpstat.udps_badsum.
 *
 * A cool idea would be to log the bad packets themselves (at least their
 * IP headers) via a pseudo-device (a bit like ipl) or via procfs.
 */
 
/*
 * Use a Makefile for kld....
 */
 
#include <sys/param.h>
#include <sys/systm.h>
#include <sys/malloc.h>
#include <sys/mbuf.h>
#include <sys/kernel.h>
#include <sys/proc.h>
#include <sys/socket.h>
#include <sys/socketvar.h>
#include <sys/sysctl.h>
#include <sys/syslog.h>
#include <sys/protosw.h>
#include <net/if.h>
#include <net/route.h>
#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/in_pcb.h>
#include <netinet/ip.h>
#include <netinet/ip_var.h>
#include <netinet/ip_icmp.h>
#include <netinet/udp.h>
#include <netinet/udp_var.h>


/*
 * Kernel objects used by this module.
 */
/* protocol switch table; s_load() rewrites the pr_input of its UDP entry */
extern struct 	protosw 	inetsw[];
/* UDP counters; new_udp_input() increments them and logs some of them */
extern struct 	udpstat 	udpstat;
/* when non-zero, datagrams for which no PCB is found are logged */
extern int  	log_in_vain;
/* list of UDP PCBs, walked for multicast/broadcast delivery */
extern struct 	inpcbhead 	udb;
/* sockaddr filled with the sender's port/address before sbappendaddr() */
extern struct 	sockaddr_in 	udp_in;

/* replacement UDP input routine installed by s_load() */
static void  	new_udp_input		__P((register struct mbuf *, int ));
/* module event handler referenced from s_mod_1 */
static int  	s_load	 		__P((struct module *, int, void *));
/* called before icmp_error(); a negative return makes the caller drop */
int		badport_bandlim		__P((int));

/*
 * Saving the original udp_input() in a function pointer is not necessary:
 * it is not static, so we can reference it directly when unloading.
 */

/*
 * s_load() - module event handler.
 *
 * On MOD_LOAD the UDP slot of inetsw[] is pointed at new_udp_input(); on
 * MOD_UNLOAD it is pointed back at udp_input().  Both swaps are done between
 * splnet() and splx().  Any other event is ignored.
 *
 *	module, arg	unused
 *	cmd		module event being delivered
 *
 * Always returns 0.
 */
static int
s_load(struct module *module, int cmd, void *arg)
{
	int ipl;

	/* Nothing to do for events other than load/unload. */
	if (cmd != MOD_LOAD && cmd != MOD_UNLOAD)
		return 0;

	ipl = splnet();
	if (cmd == MOD_LOAD)
		inetsw[ip_protox[IPPROTO_UDP]].pr_input = new_udp_input;
	else
		inetsw[ip_protox[IPPROTO_UDP]].pr_input = udp_input;
	splx(ipl);

	return 0;
}

/*
 * Module description: the module is named "udp_mod" and its events are
 * delivered to s_load(); no extra data is passed to the handler.
 */
static moduledata_t s_mod_1 = {
	        "udp_mod",	/* module name */
	        s_load,		/* event handler */
	        0		/* no extra argument for the handler */
};

/* Register the module at SI_SUB_PSEUDO with SI_ORDER_ANY. */
DECLARE_MODULE(udp_mod, s_mod_1, SI_SUB_PSEUDO, SI_ORDER_ANY);

static void 
new_udp_input(m, iphlen)
	register struct mbuf *m;
	int iphlen;
{
	register struct ip *ip;
	register struct udphdr *uh;
	register struct inpcb *inp;
	struct mbuf *opts = 0;
	int len;
	struct ip save_ip;

	udpstat.udps_ipackets++;

	if (iphlen > sizeof (struct ip)) {
		ip_stripoptions(m, (struct mbuf *)0);
		iphlen = sizeof(struct ip);
	}

	ip = mtod(m, struct ip *);
	if (m->m_len < iphlen + sizeof(struct udphdr)) {
		if ((m = m_pullup(m, iphlen + sizeof(struct udphdr))) == 0) {
			udpstat.udps_hdrops++;
			log(LOG_INFO, 
			   "UDP: udps_hdrops %ld from %s\n",
			    udpstat.udps_hdrops,
			    inet_ntoa(ip->ip_src));
			return;
		}
		ip = mtod(m, struct ip *);
	}
	uh = (struct udphdr *)((caddr_t)ip + iphlen);

	len = ntohs((u_short)uh->uh_ulen);
	if (ip->ip_len != len) {
		if (len > ip->ip_len || len < sizeof(struct udphdr)) {
			udpstat.udps_badlen++;
			log(LOG_INFO,
			    "UDP: udps_badlen %ld from %s\n",
			    udpstat.udps_badlen,
			    inet_ntoa(ip->ip_src));
			goto bad;
		}
		m_adj(m, len - ip->ip_len);
		/* ip->ip_len = len; */
	}
	save_ip = *ip;

	if (uh->uh_sum) {
		bzero(((struct ipovly *)ip)->ih_x1, 9);
		((struct ipovly *)ip)->ih_len = uh->uh_ulen;
		uh->uh_sum = in_cksum(m, len + sizeof (struct ip));
		if (uh->uh_sum) {
			udpstat.udps_badsum++;
			log(LOG_INFO,
			    "UDP: udps_badsum %ld from %s\n",
			    udpstat.udps_badsum,
			    inet_ntoa(ip->ip_src));
			m_freem(m);
			return;
		}
	}

	if (IN_MULTICAST(ntohl(ip->ip_dst.s_addr)) ||
	    in_broadcast(ip->ip_dst, m->m_pkthdr.rcvif)) {
		struct inpcb *last;
		
		udp_in.sin_port = uh->uh_sport;
		udp_in.sin_addr = ip->ip_src;
		m->m_len -= sizeof (struct udpiphdr);
		m->m_data += sizeof (struct udpiphdr);

		last = NULL;
		for (inp = udb.lh_first; inp != NULL; 
				inp = inp->inp_list.le_next) {
			if (inp->inp_lport != uh->uh_dport)
				continue;
			if (inp->inp_laddr.s_addr != INADDR_ANY) {
				if (inp->inp_laddr.s_addr !=
				    ip->ip_dst.s_addr)
					continue;
			}
			if (inp->inp_faddr.s_addr != INADDR_ANY) {
				if (inp->inp_faddr.s_addr !=
				    ip->ip_src.s_addr ||
				    inp->inp_fport != uh->uh_sport)
					continue;
			}

			if (last != NULL) {
				struct mbuf *n;

				if ((n = m_copy(m, 0, M_COPYALL)) != NULL) {
					if (last->inp_flags & INP_CONTROLOPTS
					    || last->inp_socket->so_options & 
					    SO_TIMESTAMP)
						ip_savecontrol(last, &opts, ip, n);
					if (sbappendaddr(&last->inp_socket->so_rcv,
						(struct sockaddr *)&udp_in,
						n, opts) == 0) {
						m_freem(n);
						if (opts)
						    m_freem(opts);
						udpstat.udps_fullsock++;
					} else
						sorwakeup(last->inp_socket);
					opts = 0;
				}
			}
			last = inp;
			if ((last->inp_socket->so_options&(SO_REUSEPORT |
							SO_REUSEADDR)) == 0)
				break;
		}

		if (last == NULL) {
			udpstat.udps_noportbcast++;
			goto bad;
		}
		if (last->inp_flags & INP_CONTROLOPTS
		    || last->inp_socket->so_options & SO_TIMESTAMP)
			ip_savecontrol(last, &opts, ip, m);
		if (sbappendaddr(&last->inp_socket->so_rcv,
		     (struct sockaddr *)&udp_in,
		     m, opts) == 0) {
			udpstat.udps_fullsock++;
			goto bad;
		}
		sorwakeup(last->inp_socket);
		return;
	}
	inp = in_pcblookup_hash(&udbinfo, ip->ip_src, uh->uh_sport,
	    ip->ip_dst, uh->uh_dport, 1);
	if (inp == NULL) {
		if (log_in_vain) {
			char buf[4*sizeof "123"];

			strcpy(buf, inet_ntoa(ip->ip_dst));
			log(LOG_INFO,
			    "Connection attempt to UDP %s:%d from %s:%d\n",
			    buf, ntohs(uh->uh_dport), inet_ntoa(ip->ip_src),
			    ntohs(uh->uh_sport));
		}
		udpstat.udps_noport++;
		if (m->m_flags & (M_BCAST | M_MCAST)) {
			udpstat.udps_noportbcast++;
			goto bad;
		}
		*ip = save_ip;

		if (badport_bandlim(0) < 0)
			goto bad;

		icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_PORT, 0, 0);
		return;
	}

	udp_in.sin_port = uh->uh_sport;
	udp_in.sin_addr = ip->ip_src;
	if (inp->inp_flags & INP_CONTROLOPTS
	    || inp->inp_socket->so_options & SO_TIMESTAMP)
		ip_savecontrol(inp, &opts, ip, m);
	iphlen += sizeof(struct udphdr);
	m->m_len -= iphlen;
	m->m_pkthdr.len -= iphlen;
	m->m_data += iphlen;
	if (sbappendaddr(&inp->inp_socket->so_rcv, (struct sockaddr *)&udp_in,
	    m, opts) == 0) {
		udpstat.udps_fullsock++;
		goto bad;
	}
	sorwakeup(inp->inp_socket);
	return;
bad:
	m_freem(m);
	if (opts)
		m_freem(opts);
}
