Add support for multiple IP/port pairs in swift_sockaddr. Also add the user space...
[swifty.git] / src / kernel / swift.c
index c54567f..b21c106 100644 (file)
@@ -1,4 +1,5 @@
 #include <linux/module.h>
+#include <linux/version.h>
 #include <net/sock.h>
 #include <net/protocol.h>
 #include <net/ip.h>
@@ -14,8 +15,8 @@ MODULE_LICENSE("GPL");
 struct swift_sock {
        struct inet_sock sock;
        /* swift socket speciffic data */
-       __be16 src;
-       __be16 dst;
+       uint8_t src;
+       uint8_t dst;
 };
 
 static struct swift_sock * sock_port_map[MAX_SWIFT_PORT];
@@ -30,7 +31,7 @@ static inline struct swifthdr * swift_hdr(const struct sk_buff * skb)
        return (struct swifthdr *) skb_transport_header(skb);
 }
 
-static inline __be16 get_next_free_port(void)
+static inline uint8_t get_next_free_port(void)
 {
        int i;
        for (i = MIN_SWIFT_PORT; i < MAX_SWIFT_PORT; i ++)
@@ -39,17 +40,17 @@ static inline __be16 get_next_free_port(void)
        return 0;
 }
 
-static inline void swift_unhash(__be16 port)
+static inline void swift_unhash(uint8_t port)
 {
        sock_port_map[port] = NULL;
 }
 
-static inline void swift_hash(__be16 port, struct swift_sock *ssh)
+static inline void swift_hash(uint8_t port, struct swift_sock *ssh)
 {
        sock_port_map[port] = ssh;
 }
 
-static inline struct swift_sock * swift_lookup(__be16 port)
+static inline struct swift_sock * swift_lookup(uint8_t port)
 {
        return sock_port_map[port];
 }
@@ -84,7 +85,7 @@ static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
        struct sockaddr_swift *swift_addr;
        struct swift_sock *ssk;
        int err;
-       __be16 port;
+       uint8_t port;
 
        if (unlikely(addr_len < sizeof(struct sockaddr_swift))) {
                log_error("Invalid size for sockaddr\n");
@@ -94,13 +95,7 @@ static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 
        swift_addr = (struct sockaddr_swift *) addr;
 
-       if (unlikely(swift_addr->sin_family != AF_INET)) {
-               log_error("Invalid family for sockaddr\n");
-               err = -EINVAL;
-               goto out;
-       }
-
-       port = ntohs(swift_addr->sin_port);
+       port = swift_addr->dests[0].port;
 
        if (unlikely(port == 0 || port >= MAX_SWIFT_PORT)) {
                log_error("Invalid value for sockaddr port (%u)\n", port);
@@ -161,19 +156,21 @@ static int swift_connect(struct socket *sock, struct sockaddr *addr, int addr_le
        if (likely(addr)) {
                struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) addr;
                
-               if (unlikely(addr_len < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET)) {
+        if (unlikely(addr_len < sizeof(*swift_addr) || 
+                     addr_len < swift_addr->count * sizeof(struct swift_dest) || 
+                     swift_addr->count <= 0)) {
                        log_error("Invalid size or address family\n");
                        err = -EINVAL;
                        goto out;
                }
-               ssk->dst = ntohs(swift_addr->sin_port);
+               ssk->dst = swift_addr->dests[0].port;
                if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_SWIFT_PORT)) {
                        log_error("Invalid value for destination port(%u)\n", ssk->dst);
                        err = -EINVAL;
                        goto out;
                }       
        
-               isk->inet_daddr = swift_addr->sin_addr.s_addr;
+               isk->inet_daddr = swift_addr->dests[0].addr;
                log_debug("Received from user space destination port=%u and address=%u\n", ssk->dst, isk->inet_daddr);
        } else {
                log_error("Invalid swift_addr (NULL)\n");
@@ -199,9 +196,9 @@ out:
 static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
 {
        int err;
-       __be16 dport;
+       uint8_t dport;
        __be32 daddr;
-       __be16 sport;
+       uint8_t sport;
        struct sk_buff * skb;
        struct sock * sk; 
        struct inet_sock * isk;
@@ -240,20 +237,22 @@ static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr
        if (msg->msg_name) {
                struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) msg->msg_name;
                
-               if (unlikely(msg->msg_namelen < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET)) {
-                       log_error("Invalid size or address family\n");
+        if (unlikely(msg->msg_namelen < sizeof(*swift_addr) || 
+                     msg->msg_namelen < swift_addr->count * sizeof(struct swift_dest) || 
+                     swift_addr->count <= 0)) {
+                       log_error("Invalid size for msg_name\n");
                        err = -EINVAL;
                        goto out;
                }
                
-               dport = ntohs(swift_addr->sin_port);
+               dport = swift_addr->dests[0].port;
                if (unlikely(dport == 0 || dport >= MAX_SWIFT_PORT)) {
                        log_error("Invalid value for destination port(%u)\n", dport);
                        err = -EINVAL;
                        goto out;
                }       
 
-               daddr = swift_addr->sin_addr.s_addr;
+               daddr = swift_addr->dests[0].addr;
                log_debug("Received from user space destination port=%u and address=%u\n", dport, daddr);
        } else {
                if (unlikely(!ssk->dst || !isk->inet_daddr)) {
@@ -284,8 +283,8 @@ static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr
        log_debug("Reseted transport header\n");
 
        shdr = (struct swifthdr *) skb_transport_header(skb);
-       shdr->dst = ntohs(dport);
-       shdr->src = ntohs(sport);
+       shdr->dst = dport;
+       shdr->src = sport;
        shdr->len = ntohs(len + sizeof(struct swifthdr));
 
        log_debug("payload=%p\n", skb_put(skb, len));
@@ -310,7 +309,11 @@ static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr
                        log_error("Route lookup failed\n");
                        goto out_free;
                }
+#if LINUX_VERSION_CODE < KERNEL_VERSION(2, 6, 36)
+               sk_dst_set(sk, dst_clone(&rt->u.dst));
+#else
                sk_dst_set(sk, dst_clone(&rt->dst));
+#endif
        }
        
        err = ip_queue_xmit(skb);
@@ -335,6 +338,8 @@ static int swift_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr
        struct sock * sk = sock->sk;
        int err, copied;
 
+    log_debug("Trying to receive sock=%p sk=%p flags=%d\n", sock, sk, flags);
+
        skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
        if (unlikely(!skb)) {
                log_error("skb_recv_datagram\n");
@@ -377,7 +382,7 @@ static int swift_rcv(struct sk_buff *skb)
        struct swifthdr *shdr;
        struct swift_sock *ssk;
        __be16 len;
-       __be16 src, dst;
+       uint8_t src, dst;
        struct sockaddr_swift * swift_addr;
        int err;
 
@@ -399,8 +404,8 @@ static int swift_rcv(struct sk_buff *skb)
                goto drop;
        }
        
-       src = ntohs(shdr->src);
-       dst = ntohs(shdr->dst);
+       src = shdr->src;
+       dst = shdr->dst;
        if (unlikely(src == 0 || dst == 0 || src >= MAX_SWIFT_PORT || dst >= MAX_SWIFT_PORT)) {
                log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
                goto drop;
@@ -422,11 +427,10 @@ static int swift_rcv(struct sk_buff *skb)
        BUILD_BUG_ON(sizeof(struct sockaddr_swift) > sizeof(skb->cb));
        
        swift_addr = (struct sockaddr_swift *) skb->cb;
-       swift_addr->sin_family = AF_INET;
-       swift_addr->sin_port = shdr->src;
-       swift_addr->sin_addr.s_addr = ip_hdr(skb)->saddr;
+       swift_addr->dests[0].port = shdr->src;
+       swift_addr->dests[0].addr = ip_hdr(skb)->saddr;
 
-       log_debug("Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->sin_addr.s_addr);
+       log_debug("Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->dests[0].addr);
 
        err = ip_queue_rcv_skb((struct sock *) &ssk->sock, skb);
        if (unlikely(err)) {