 /****************************************************************************
 *
 * Copyright (c) 2020 Broadcom. All rights reserved
 * The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
 *
 * Unless you and Broadcom execute a separate written software license
 * agreement governing use of this software, this software is licensed to
 * you under the terms of the GNU General Public License version 2 (the
 * "GPL"), available at [http://www.broadcom.com/licenses/GPLv2.php], with
 * the following added to such license:
 *
 * As a special exception, the copyright holders of this software give you
 * permission to link this software with independent modules, and to copy
 * and distribute the resulting executable under terms of your choice,
 * provided that you also meet, for each linked independent module, the
 * terms and conditions of the license of that module. An independent
 * module is a module which is not derived from this software. The special
 * exception does not apply to any modifications of the software.
 *
 * Notwithstanding the above, under no circumstances may you combine this
 * software in any way with any other Broadcom software provided under a
 * license other than the GPL, without Broadcom's express prior written
 * consent.
 *
 ****************************************************************************
 * Author: Jayesh Patel <jayeshp@broadcom.com>
 ****************************************************************************/

#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/kthread.h>
#include <linux/if_arp.h>
#include <linux/rtnetlink.h>
#include <linux/etherdevice.h>
#include <net/netlink.h>
#include <net/switchdev.h>
#include "proc_cmd.h"
#include "flowmgr.h"

/*
 * Fallback helpers, defined only when the included headers do not already
 * provide them:
 *  - NDA_RTA(r):     pointer to the first rtattr located after the
 *                    NLMSG-aligned struct ndmsg that r points to.
 *  - NDA_PAYLOAD(n): payload length of netlink message n beyond the
 *                    struct ndmsg header.
 */
#ifndef NDA_RTA
#define NDA_RTA(r) \
	((struct rtattr *)(((char *)(r)) + NLMSG_ALIGN(sizeof(struct ndmsg))))
#endif
#ifndef NDA_PAYLOAD
#define NDA_PAYLOAD(n)	NLMSG_PAYLOAD(n, sizeof(struct ndmsg))
#endif

/*
 * NOTE(review): this global is not referenced anywhere in the visible part of
 * this file; presumably it is used from another translation unit -- confirm
 * before changing its linkage.
 */
struct task_struct *kflowmgrd_task;

/*
 * Scratch description of one bridge FDB entry, filled in by rtnl_parse_fdb()
 * while decoding an RTM_NEWNEIGH/RTM_DELNEIGH message.
 */
struct flowmgr_br_fdb_entry
{
	struct net_device	*dev;		/* device looked up from ndm_ifindex */
	struct net_device	*br;		/* device looked up from NDA_MASTER, if present */
	unsigned char		addr[ETH_ALEN];	/* copied from NDA_LLADDR, if present */
	__u16			vlan_id;	/* copied from NDA_VLAN when non-zero */
	unsigned char		is_local:1,	/* set when the state has NUD_PERMANENT */
				is_static:1,	/* set when the state has NUD_NOARP */
				added_by_user:1,	/* never written in this file */
				added_by_external_learn:1;	/* never written in this file */
};


/*
 * Return the payload of a route attribute interpreted as a 16-bit value.
 * The caller is responsible for making sure the payload is large enough.
 */
static inline __u16 rta_getattr_u16(const struct rtattr *rta)
{
	const __u16 *payload = RTA_DATA(rta);

	return *payload;
}

/*
 * Return the payload of a route attribute interpreted as a 32-bit value.
 * The caller is responsible for making sure the payload is large enough.
 */
static inline __u32 rta_getattr_u32(const struct rtattr *rta)
{
	const __u32 *payload = RTA_DATA(rta);

	return *payload;
}

static void parse_rtattr(struct rtattr *tb[], int max, struct rtattr *rta, int len)
{
	memset(tb, 0, (max + 1) * sizeof(tb[0]));

	while (RTA_OK(rta, len)) {
		if (rta->rta_type <= max) {
			tb[rta->rta_type] = rta;
		}
		rta = RTA_NEXT(rta, len);
	}
	if (len) {
		pr_err("deficit %d, rta_len=%d!", len, rta->rta_len);
	}
}

static void fdb_print_flags(struct seq_file *s, unsigned int flags)
{
	if (flags & NTF_SELF)
		pr_seq(s, "%s ", "self");

	if (flags & NTF_ROUTER)
		pr_seq(s, "%s ", "router");

	if (flags & NTF_EXT_LEARNED)
		pr_seq(s, "%s ", "extern_learn");
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0))
	if (flags & NTF_OFFLOADED)
		pr_seq(s, "%s ", "offload");
#endif
	if (flags & NTF_MASTER)
		pr_seq(s, "%s ", "master");
}

/*
 * Translate a neighbour state bitmask into a printable keyword and record the
 * matching property in @fdb.
 *
 * The bits are tested in priority order: NUD_PERMANENT (marks the entry
 * local), NUD_NOARP (marks it static), NUD_STALE, NUD_REACHABLE (empty
 * string). Any other value is formatted as "state=0x..." into a function-local
 * static buffer, so that result is overwritten by the next such call.
 */
static const char *fdb_get_state_str(unsigned int s,  struct flowmgr_br_fdb_entry *fdb)
{
	static char unknown_state[32];
	const char *label;

	if (s & NUD_PERMANENT) {
		fdb->is_local = 1;
		label = "permanent";
	} else if (s & NUD_NOARP) {
		fdb->is_static = 1;
		label = "static";
	} else if (s & NUD_STALE) {
		label = "stale";
	} else if (s & NUD_REACHABLE) {
		label = "";
	} else {
		sprintf(unknown_state, "state=%#x", s);
		label = unknown_state;
	}

	return label;
}

static int rtnl_parse_fdb(struct seq_file *s, struct nlmsghdr *n, struct flowmgr_br_fdb_entry *fdb)
{
	struct ndmsg *r = nlmsg_data(n);
	int len = n->nlmsg_len;
	struct rtattr *tb[NDA_MAX+1];
	__u16 vid = 0;

	if (n->nlmsg_type != RTM_NEWNEIGH && n->nlmsg_type != RTM_DELNEIGH) {
		pr_err("Not RTM_NEWNEIGH: %08x %08x %08x\n",
		       n->nlmsg_len, n->nlmsg_type, n->nlmsg_flags);
		return 0;
	}

	len -= NLMSG_LENGTH(sizeof(*r));
	if (len < 0) {
		pr_err("BUG: wrong nlmsg len %d\n", len);
		return -1;
	}

	if (r->ndm_family != AF_BRIDGE)
		return 0;

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 0))
	if (!(r->ndm_flags & NTF_OFFLOADED))
		return 0;
#endif

	parse_rtattr(tb, NDA_MAX, NDA_RTA(r),
		     n->nlmsg_len - NLMSG_LENGTH(sizeof(*r)));

	if (tb[NDA_VLAN])
		vid = rta_getattr_u16(tb[NDA_VLAN]);

	if (n->nlmsg_type == RTM_DELNEIGH)
		pr_seq(s, "Deleted ");

	if (tb[NDA_LLADDR] && (RTA_PAYLOAD(tb[NDA_LLADDR]) == ETH_ALEN)) {
		const char *lladdr;

		lladdr = RTA_DATA(tb[NDA_LLADDR]);
		memcpy(fdb->addr, lladdr, ETH_ALEN);
		pr_seq(s, "%pM ", fdb->addr);
	}

	fdb->dev = __dev_get_by_index(&init_net, r->ndm_ifindex);
	pr_seq(s, "dev %s ", fdb->dev->name);

	if (vid) {
		fdb->vlan_id = vid;
		pr_seq(s, "vlan %hu ", vid);
	}

	fdb_print_flags(s, r->ndm_flags);

	if (tb[NDA_MASTER]) {
		fdb->br = __dev_get_by_index(&init_net,
			rta_getattr_u32(tb[NDA_MASTER]));
		pr_seq(s, "master %s ",
		       fdb->br->name);
	}

	pr_seq(s, "%s\n", fdb_get_state_str(r->ndm_state, fdb));
	return 0;
}

static struct socket *rtnl_open(void)
{
	struct socket *socket = NULL;
	struct sockaddr_nl nladdr = {};

	/* create socket */
	if (sock_create(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE, &socket)) {
		pr_err("%s: Failed to create socket.\n", __func__);
		return NULL;
	}

	nladdr.nl_family = AF_NETLINK;

	/* bind to incomming port */
	if (socket->ops->bind(socket, (struct sockaddr *)&nladdr,
			      sizeof(nladdr))) {
		pr_err("%s: Failed to bind socket to port %d.\n",
		       __func__, nladdr.nl_groups);
		sock_release(socket);
		return NULL;
	}

	/* check sk */
	if (socket->sk == NULL) {
		pr_err("%s: socket->sk == NULL\n", __func__);
		sock_release(socket);
		return NULL;
	}
	return socket;
}

/* Release a socket obtained from rtnl_open(); a NULL socket is ignored. */
static void rtnl_close(struct socket *socket)
{
	if (!socket)
		return;

	sock_release(socket);
}

static int rtnl_send(struct socket *socket)
{
	int ret = 0;
	struct ifinfomsg ifm = {};
	struct msghdr msg = {};
	struct  nlmsghdr nlh;
	struct kvec iov[2] = { { &nlh, sizeof(nlh) }, {&ifm, sizeof(ifm)} };
	int iov_size;

	ifm.ifi_family = AF_BRIDGE;
	iov_size = sizeof(nlh) + sizeof(ifm);

	nlh.nlmsg_len = NLMSG_LENGTH(sizeof(ifm));
	nlh.nlmsg_type = RTM_GETNEIGH;
	nlh.nlmsg_flags = NLM_F_DUMP|NLM_F_REQUEST;
	nlh.nlmsg_seq = 1;

	ret = kernel_sendmsg(socket, &msg, iov, 2, iov_size);

	return ret;
}

/*
 * Interpret the status carried by an NLMSG_DONE message.
 *
 * Returns 0 when the dump finished without error, -1 when the message is too
 * short to carry a status or the status is -ENOENT/-EOPNOTSUPP, and the
 * (negative) status itself for any other error.
 */
static int rtnl_dump_done(struct nlmsghdr *h)
{
	int status;

	/* Validate the length before touching the payload. */
	if (h->nlmsg_len < NLMSG_LENGTH(sizeof(int))) {
		pr_err("DONE truncated\n");
		return -1;
	}

	status = *(int *)nlmsg_data(h);
	if (status >= 0)
		return 0;

	switch (status) {
	case -ENOENT:
	case -EOPNOTSUPP:
		return -1;
	case -EMSGSIZE:
		pr_err("Error: Buffer too small for object.\n");
		break;
	default:
		pr_err("RTNETLINK answers: error %d\n", status);
		break;
	}

	return status;
}

static int rtnl_dump_error(struct nlmsghdr *h)
{

	if (h->nlmsg_len < NLMSG_LENGTH(sizeof(struct nlmsgerr))) {
		pr_err("ERROR truncated\n");
		return -1;
	} else {
		const struct nlmsgerr *err = (struct nlmsgerr *)nlmsg_data(h);
		pr_err("ERROR %d\n", -err->error);
		return -err->error;
	}
}

static int rtnl_recv(struct seq_file *s, struct socket *socket)
{
	int ret = 0;
	struct msghdr msg = {};
	unsigned char *recvbuf;
	size_t recvbuf_size = 4096;
	int recvlen;
	struct kvec iov;
	struct flowmgr_br_fdb_entry fdb;

	/* allocate buffer memory */
	recvbuf = kmalloc(recvbuf_size, GFP_KERNEL);
	if (!recvbuf) {
		pr_err("%s: Failed to alloc recvbuf.\n", __func__);
		ret = -ENOMEM;
		goto fail;
	}

	iov.iov_base = recvbuf;
	iov.iov_len = recvbuf_size;

	recvlen = kernel_recvmsg(socket, &msg, &iov, 1,
				 recvbuf_size, 0);
	if (recvlen > 0) {
		struct nlmsghdr *h = (struct nlmsghdr *)recvbuf;
		int found_done = 0;

		while (NLMSG_OK(h, recvlen)) {
			if (h->nlmsg_type == NLMSG_DONE) {
				ret = rtnl_dump_done(h);
				if (ret < 0) {
					goto fail;
				}

				found_done = 1;
				break;
			}

			if (h->nlmsg_type == NLMSG_ERROR) {
				ret = rtnl_dump_error(h);
				goto fail;
			}

			rtnl_parse_fdb(s, h, &fdb);
			h = NLMSG_NEXT(h, recvlen);
		}
	}
fail:
	/* free recvbuf */
	kfree(recvbuf);

	return ret;
}

/*
 * Dump the bridge FDB entries to @s: open a route netlink socket, send the
 * dump request and, if sending succeeded, print the reply. The socket is
 * closed afterwards.
 *
 * Always returns 0; failures of the individual steps are not reported to the
 * caller.
 */
int flowmgr_arl_show(struct seq_file *s)
{
	struct socket *sock = rtnl_open();

	if (sock) {
		if (rtnl_send(sock) >= 0)
			rtnl_recv(s, sock);
		rtnl_close(sock);
	}

	return 0;
}

#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 0))
static int
netlink_monitor(void *ptr)
{
	int ret = 0;
	struct msghdr msg;
	unsigned char *recvbuf;
	size_t recvbuf_size = 4096;
	int recvlen;
	struct socket *socket = NULL;
	struct sockaddr_nl addr;
	DECLARE_COMPLETION_ONSTACK(wait);

	/* allocate buffer memory */
	recvbuf = kmalloc(recvbuf_size, GFP_KERNEL);
	if (!recvbuf) {
		pr_err("%s: Failed to alloc recvbuf.\n", __func__);
		ret = -ENOMEM;
		goto fail;
	}

	/* make daemon */
	allow_signal(SIGTERM);

	/* create socket */
	if (sock_create(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE, &socket)) {
		pr_err("%s: Failed to create socket.\n", __func__);
		ret = -EIO;
		goto fail;
	}

	/* Zeroing addr */
	memset(&addr, 0, sizeof(addr));

	addr.nl_family = AF_NETLINK;
	addr.nl_groups = 1<<(RTNLGRP_NEIGH-1);

	/* bind to incomming port */
	if (socket->ops->bind(socket, (struct sockaddr *)&addr,
			      sizeof(addr))) {
		pr_err("%s: Failed to bind socket to port %d.\n",
		       __func__, addr.nl_groups);
		ret = -EINVAL;
		goto fail;
	}

	/* check sk */
	if (socket->sk == NULL) {
		pr_err("%s: socket->sk == NULL\n", __func__);
		ret = -EIO;
		goto fail;
	}

	/* build receive message */
	msg.msg_name = NULL;
	msg.msg_namelen = 0;
	msg.msg_control = NULL;
	msg.msg_controllen = 0;

	/* read loop */
	while (!signal_pending(current) && !kthread_should_stop()) {
		struct kvec iov = {
			.iov_base = recvbuf,
			.iov_len = recvbuf_size,
		};
		struct  nlmsghdr *nlh;
		struct ndmsg *ndm;
		struct nlattr *nla;
		unsigned char *lladdr;
		struct net_device *dev;

		recvlen = kernel_recvmsg(socket, &msg, &iov, 1,
					 recvbuf_size, 0);
		if (recvlen > 0) {
			/* cast the received buffer */
			nlh = (struct nlmsghdr *)recvbuf;
			ndm = nlmsg_data(nlh);
			nla = (struct nlattr *)(ndm + 1);
			lladdr = nla_data(nla);
			dev = __dev_get_by_index(&init_net, ndm->ndm_ifindex);

			if (!dev)
				continue;

			if (netif_is_bridge_master(dev))
				continue;

			if (is_multicast_ether_addr(lladdr))
				continue;

			if (ndm->ndm_family != AF_BRIDGE)
				continue;

			/* We are just intrested in Neigh information */
			if (nlh->nlmsg_type == RTM_NEWNEIGH) {
				pr_debug("Bridge New Neigh: %s %pM %d\n",
					 dev->name,
					 lladdr,
					 ndm->ndm_state);
				/* Call Switch Dev API  */
				rtnl_lock();
				switchdev_port_fdb_add(NULL, NULL, dev, lladdr, 0, 0);
				rtnl_unlock();
			} else if (nlh->nlmsg_type == RTM_DELNEIGH) {
				pr_debug("Bridge Del Neigh: %s %pM %d\n",
					 dev->name,
					 lladdr,
					 ndm->ndm_state);
				/* Call Switch Dev API  */
				rtnl_lock();
				switchdev_port_fdb_del(NULL, NULL, dev, lladdr, 0);
				rtnl_unlock();
			}
		} else {
		}
	}

fail:
	/* free recvbuf */
	kfree(recvbuf);

	/* close socket */
	if (socket)
		sock_release(socket);

	return ret;
}

/*
 * Thread entry point: run the netlink monitor and report success to the
 * kthread core whatever the monitor returned.
 */
static inline
int kflowmgr_arl_task(void *ptr)
{
	(void)netlink_monitor(ptr);

	return 0;
}
#endif
/*
 * seq_file start callback: hand out the start token at position 0 only, so
 * the show callback runs once per read of the file.
 */
static void *flowmgr_arl_seq_start(struct seq_file *seq, loff_t *pos)
{
	return *pos ? NULL : SEQ_START_TOKEN;
}

/*
 * seq_file next callback: advance the position and end the iteration, since
 * there is only a single record.
 */
static void *flowmgr_arl_seq_next(struct seq_file *seq, void *v,
				      loff_t *pos)
{
	++*pos;

	return NULL;
}

/* seq_file stop callback: nothing to release, the iterator holds no state. */
static void flowmgr_arl_seq_stop(struct seq_file *seq, void *v)
{
}

/*
 * seq_file show callback: dump the FDB entries into @seq.
 *
 * Returns -1 when called without a record, 0 otherwise (the result of
 * flowmgr_arl_show() is not propagated).
 */
static int flowmgr_arl_seq_show(struct seq_file *seq, void *v)
{
	if (v) {
		flowmgr_arl_show(seq);
		return 0;
	}

	return -1;
}

/*
 * Single-record seq_file iterator: start yields a token only at position 0,
 * next always terminates, and show prints the whole FDB dump.
 */
static const struct seq_operations flowmgr_arl_seq_ops = {
	.start	= flowmgr_arl_seq_start,
	.next	= flowmgr_arl_seq_next,
	.stop	= flowmgr_arl_seq_stop,
	.show	= flowmgr_arl_seq_show,
};

/* proc open callback: attach the ARL seq_file iterator to @file. */
static int flowmgr_arl_seq_open(struct inode *inode, struct file *file)
{
	return seq_open(file, &flowmgr_arl_seq_ops);
}

/*
 * Operations for the "arl" proc entry registered in flowmgr_arl_init():
 * open installs the seq_file iterator, the rest are the stock seq_file
 * helpers.
 */
static const struct proc_ops flowmgr_proc_arl_fops = {
	.proc_open     = flowmgr_arl_seq_open,
	.proc_read     = seq_read,
	.proc_lseek    = seq_lseek,
	.proc_release  = seq_release,
};

#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 0))
/* Monitor thread created in flowmgr_arl_init() and stopped in flowmgr_arl_exit(). */
static struct task_struct *kflowmgr_arl;
#endif

/*
 * Set up the ARL support: on kernels older than 5.4 start the netlink
 * monitor thread, then register the "arl" proc entry under flowmgr.proc_dir.
 *
 * kthread_create() reports failure with an ERR_PTR() value, never NULL, so
 * the result is tested with IS_ERR() and reset to NULL on failure rather
 * than being woken up.
 *
 * Always returns 0; failures are logged and the remaining setup continues.
 */
int flowmgr_arl_init(void)
{
#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 0))
	kflowmgr_arl = kthread_create(kflowmgr_arl_task, NULL, "%s", "kflowmgr_arl");
	if (IS_ERR(kflowmgr_arl)) {
		pr_err("%s: Failed to create kflowmgr_arl thread (%ld).\n",
		       __func__, PTR_ERR(kflowmgr_arl));
		kflowmgr_arl = NULL;
	} else {
		wake_up_process(kflowmgr_arl);
	}
#endif
	if (!proc_create_data("arl", S_IRUGO, flowmgr.proc_dir,
			      &flowmgr_proc_arl_fops, NULL))
		pr_err("%s: Failed to create proc entry \"arl\".\n", __func__);

	return 0;
}

/*
 * Tear down the ARL support: on kernels older than 5.4 stop the monitor
 * thread if one was actually created, then remove the "arl" proc entry.
 */
void flowmgr_arl_exit(void)
{
#if (LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 0))
	/* Thread creation may have failed, leaving NULL or an ERR_PTR value. */
	if (!IS_ERR_OR_NULL(kflowmgr_arl)) {
		kthread_stop(kflowmgr_arl);
		kflowmgr_arl = NULL;
	}
#endif
	remove_proc_entry("arl", flowmgr.proc_dir);
}
