Move MPTP ports to 16bit. Until now they were on 8bit.
authorAdrian Bondrescu <adi.bondrescu@gmail.com>
Sun, 24 Jun 2012 12:15:45 +0000 (15:15 +0300)
committerAdrian Bondrescu <adi.bondrescu@gmail.com>
Sun, 24 Jun 2012 12:15:45 +0000 (15:15 +0300)
src/kernel/debug.h
src/kernel/mptp.c
src/kernel/mptp.h

index d886812..68b8642 100644 (file)
@@ -1,11 +1,11 @@
 #ifndef _MPTP_DEBUG_H
 #define _MPTP_DEBUG_H
 
-#define log_error(...) printk(KERN_ERR "SWIF-ERROR : " __VA_ARGS__)
+#define log_error(...) printk(KERN_ERR "MPTP-ERROR : " __VA_ARGS__)
 
 #if 0
 
-#define log_debug(...) printk(KERN_DEBUG "SWIF-DEBUG : " __VA_ARGS__)
+#define log_debug(...) printk(KERN_DEBUG "MPTP-DEBUG : " __VA_ARGS__)
 
 #else
 
index 05ed68b..94c5263 100644 (file)
@@ -15,8 +15,8 @@ MODULE_LICENSE("GPL");
 struct mptp_sock {
        struct inet_sock sock;
        /* mptp socket speciffic data */
-       uint8_t src;
-       uint8_t dst;
+       uint16_t src;
+       uint16_t dst;
 };
 
 static struct mptp_sock * sock_port_map[MAX_MPTP_PORT];
@@ -31,7 +31,7 @@ static inline struct mptphdr * mptp_hdr(const struct sk_buff * skb)
        return (struct mptphdr *) skb_transport_header(skb);
 }
 
-static inline uint8_t get_next_free_port(void)
+static inline uint16_t get_next_free_port(void)
 {
        int i;
        for (i = MIN_MPTP_PORT; i < MAX_MPTP_PORT; i ++)
@@ -40,17 +40,17 @@ static inline uint8_t get_next_free_port(void)
        return 0;
 }
 
-static inline void mptp_unhash(uint8_t port)
+static inline void mptp_unhash(uint16_t port)
 {
        sock_port_map[port] = NULL;
 }
 
-static inline void mptp_hash(uint8_t port, struct mptp_sock *ssh)
+static inline void mptp_hash(uint16_t port, struct mptp_sock *ssh)
 {
        sock_port_map[port] = ssh;
 }
 
-static inline struct mptp_sock * mptp_lookup(uint8_t port)
+static inline struct mptp_sock * mptp_lookup(uint16_t port)
 {
        return sock_port_map[port];
 }
@@ -85,17 +85,18 @@ static int mptp_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
        struct sockaddr_mptp *mptp_addr;
        struct mptp_sock *ssk;
        int err;
-       uint8_t port;
+       uint16_t port;
 
-       if (unlikely(addr_len < sizeof(struct sockaddr_mptp))) {
-               log_error("Invalid size for sockaddr\n");
+       if (unlikely(addr_len < sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest))) {
+               log_error("Invalid size for sockaddr (%d)\n", addr_len);
                err = -EINVAL;
                goto out;
        }
 
        mptp_addr = (struct sockaddr_mptp *) addr;
 
-       port = mptp_addr->dests[0].port;
+       log_debug("Bind received port=%u (network order)\n", mptp_addr->dests[0].port);
+       port = ntohs(mptp_addr->dests[0].port);
 
        if (unlikely(port == 0 || port >= MAX_MPTP_PORT)) {
                log_error("Invalid value for sockaddr port (%u)\n", port);
@@ -164,7 +165,7 @@ static int mptp_connect(struct socket *sock, struct sockaddr *addr, int addr_len
                        err = -EINVAL;
                        goto out;
                }
-               ssk->dst = mptp_addr->dests[0].port;
+               ssk->dst = ntohs(mptp_addr->dests[0].port);
                if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_MPTP_PORT)) {
                        log_error("Invalid value for destination port(%u)\n", ssk->dst);
                        err = -EINVAL;
@@ -197,9 +198,9 @@ out:
 static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
 {
        int err;
-       uint8_t dport;
+       uint16_t dport;
     __be32 daddr;
-    uint8_t sport;
+    uint16_t sport;
     struct sk_buff * skb;
     struct sock * sk; 
     struct inet_sock * isk;
@@ -211,6 +212,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *
     int dests = 0;
     int i;
     struct sockaddr_mptp * mptp_addr = NULL;
+       int ret = 0;
 
     if (unlikely(sock == NULL)) {
         log_error("Sock is NULL\n");
@@ -241,10 +243,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *
     if (msg->msg_name) {
         mptp_addr = (struct sockaddr_mptp *) msg->msg_name;
 
-        if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) || 
-                     msg->msg_namelen < mptp_addr->count * sizeof(struct mptp_dest) || 
+        if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) + mptp_addr->count * sizeof(struct mptp_dest) || 
                      mptp_addr->count <= 0)) {
-            log_error("Invalid size for msg_name\n");
+            log_error("Invalid size for msg_name (size=%u, addr_count=%u)\n", msg->msg_namelen, mptp_addr->count);
             err = -EINVAL;
             goto out;
         }
@@ -272,7 +273,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *
         struct iovec *iov = &msg->msg_iov[i];
         char *payload;
 
-        dport = dest->port;
+        dport = ntohs(dest->port);
         if (unlikely(dport == 0 || dport >= MAX_MPTP_PORT)) {
             log_error("Invalid value for destination port(%u)\n", dport);
             err = -EINVAL;
@@ -299,9 +300,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *
         log_debug("Reseted transport header\n");
 
         shdr = (struct mptphdr *) skb_transport_header(skb);
-        shdr->dst = dport;
-        shdr->src = sport;
-        shdr->len = ntohs(len + sizeof(struct mptphdr));
+        shdr->dst = htons(dport);
+        shdr->src = htons(sport);
+        shdr->len = htons(len + sizeof(struct mptphdr));
 
         payload = skb_put(skb, len);
         log_debug("payload=%p\n", payload);
@@ -335,13 +336,14 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *
 
         skb->local_df = 1;
         err = ip_queue_xmit(skb);
-        if (likely(!err))
+        if (likely(!err)) {
             log_debug("Sent %u bytes on wire\n", len);
-        else
+                       ret += len;
+               } else
             log_error("ip_queue_xmit failed\n");
     }
 
-       return err;
+       return ret;
 
 out_free:
        kfree(skb);
@@ -414,7 +416,7 @@ static int mptp_rcv(struct sk_buff *skb)
        struct mptphdr *shdr;
        struct mptp_sock *ssk;
        __be16 len;
-       uint8_t src, dst;
+       uint16_t src, dst;
        struct sockaddr_mptp * mptp_addr;
        int err;
        int addr_size = sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest);
@@ -437,8 +439,8 @@ static int mptp_rcv(struct sk_buff *skb)
                goto drop;
        }
        
-       src = shdr->src;
-       dst = shdr->dst;
+       src = ntohs(shdr->src);
+       dst = ntohs(shdr->dst);
        if (unlikely(src == 0 || dst == 0 || src >= MAX_MPTP_PORT || dst >= MAX_MPTP_PORT)) {
                log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
                goto drop;
index 59c5b17..51e5b0c 100644 (file)
@@ -4,7 +4,7 @@
 #define IPPROTO_MPTP 137
 
 #define MIN_MPTP_PORT 1
-#define MAX_MPTP_PORT 256
+#define MAX_MPTP_PORT 65536
 
 #ifndef __KERNEL__
 #include <inttypes.h>
@@ -12,7 +12,7 @@
 
 struct mptp_dest {
     uint32_t addr;
-    uint8_t port;
+    uint16_t port;
 };
 
 struct sockaddr_mptp {
@@ -22,8 +22,8 @@ struct sockaddr_mptp {
 
 #ifdef __KERNEL__
 struct mptphdr {
-       uint8_t src;
-       uint8_t dst;
+       uint16_t src;
+       uint16_t dst;
        __be16 len;
 };
 #endif