X-Git-Url: http://p2p-next.cs.pub.ro/gitweb/?a=blobdiff_plain;f=src%2Fkernel%2Fmptp.c;fp=src%2Fkernel%2Fmptp.c;h=94c526343fac0635367d6929673e7fe295603b7b;hb=dd7dd9d88c8eb825f100f4d8285aa4290e7e5a74;hp=05ed68b9e25ee33bc2848ce2bbff580d15a937ad;hpb=7b901a75a2e19d8d518bb3015744a87020d0d109;p=swifty.git diff --git a/src/kernel/mptp.c b/src/kernel/mptp.c index 05ed68b..94c5263 100644 --- a/src/kernel/mptp.c +++ b/src/kernel/mptp.c @@ -15,8 +15,8 @@ MODULE_LICENSE("GPL"); struct mptp_sock { struct inet_sock sock; /* mptp socket speciffic data */ - uint8_t src; - uint8_t dst; + uint16_t src; + uint16_t dst; }; static struct mptp_sock * sock_port_map[MAX_MPTP_PORT]; @@ -31,7 +31,7 @@ static inline struct mptphdr * mptp_hdr(const struct sk_buff * skb) return (struct mptphdr *) skb_transport_header(skb); } -static inline uint8_t get_next_free_port(void) +static inline uint16_t get_next_free_port(void) { int i; for (i = MIN_MPTP_PORT; i < MAX_MPTP_PORT; i ++) @@ -40,17 +40,17 @@ static inline uint8_t get_next_free_port(void) return 0; } -static inline void mptp_unhash(uint8_t port) +static inline void mptp_unhash(uint16_t port) { sock_port_map[port] = NULL; } -static inline void mptp_hash(uint8_t port, struct mptp_sock *ssh) +static inline void mptp_hash(uint16_t port, struct mptp_sock *ssh) { sock_port_map[port] = ssh; } -static inline struct mptp_sock * mptp_lookup(uint8_t port) +static inline struct mptp_sock * mptp_lookup(uint16_t port) { return sock_port_map[port]; } @@ -85,17 +85,18 @@ static int mptp_bind(struct socket *sock, struct sockaddr *addr, int addr_len) struct sockaddr_mptp *mptp_addr; struct mptp_sock *ssk; int err; - uint8_t port; + uint16_t port; - if (unlikely(addr_len < sizeof(struct sockaddr_mptp))) { - log_error("Invalid size for sockaddr\n"); + if (unlikely(addr_len < sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest))) { + log_error("Invalid size for sockaddr (%d)\n", addr_len); err = -EINVAL; goto out; } mptp_addr = (struct sockaddr_mptp *) addr; - port = mptp_addr->dests[0].port; + log_debug("Bind received port=%u (network order)\n", mptp_addr->dests[0].port); + port = ntohs(mptp_addr->dests[0].port); if (unlikely(port == 0 || port >= MAX_MPTP_PORT)) { log_error("Invalid value for sockaddr port (%u)\n", port); @@ -164,7 +165,7 @@ static int mptp_connect(struct socket *sock, struct sockaddr *addr, int addr_len err = -EINVAL; goto out; } - ssk->dst = mptp_addr->dests[0].port; + ssk->dst = ntohs(mptp_addr->dests[0].port); if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_MPTP_PORT)) { log_error("Invalid value for destination port(%u)\n", ssk->dst); err = -EINVAL; @@ -197,9 +198,9 @@ out: static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len) { int err; - uint8_t dport; + uint16_t dport; __be32 daddr; - uint8_t sport; + uint16_t sport; struct sk_buff * skb; struct sock * sk; struct inet_sock * isk; @@ -211,6 +212,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * int dests = 0; int i; struct sockaddr_mptp * mptp_addr = NULL; + int ret = 0; if (unlikely(sock == NULL)) { log_error("Sock is NULL\n"); @@ -241,10 +243,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * if (msg->msg_name) { mptp_addr = (struct sockaddr_mptp *) msg->msg_name; - if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) || - msg->msg_namelen < mptp_addr->count * sizeof(struct mptp_dest) || + if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) + mptp_addr->count * sizeof(struct mptp_dest) || mptp_addr->count <= 0)) { - log_error("Invalid size for msg_name\n"); + log_error("Invalid size for msg_name (size=%u, addr_count=%u)\n", msg->msg_namelen, mptp_addr->count); err = -EINVAL; goto out; } @@ -272,7 +273,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * struct iovec *iov = &msg->msg_iov[i]; char *payload; - dport = dest->port; + dport = ntohs(dest->port); if (unlikely(dport == 0 || dport >= MAX_MPTP_PORT)) { log_error("Invalid value for destination port(%u)\n", dport); err = -EINVAL; @@ -299,9 +300,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * log_debug("Reseted transport header\n"); shdr = (struct mptphdr *) skb_transport_header(skb); - shdr->dst = dport; - shdr->src = sport; - shdr->len = ntohs(len + sizeof(struct mptphdr)); + shdr->dst = htons(dport); + shdr->src = htons(sport); + shdr->len = htons(len + sizeof(struct mptphdr)); payload = skb_put(skb, len); log_debug("payload=%p\n", payload); @@ -335,13 +336,14 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * skb->local_df = 1; err = ip_queue_xmit(skb); - if (likely(!err)) + if (likely(!err)) { log_debug("Sent %u bytes on wire\n", len); - else + ret += len; + } else log_error("ip_queue_xmit failed\n"); } - return err; + return ret; out_free: kfree(skb); @@ -414,7 +416,7 @@ static int mptp_rcv(struct sk_buff *skb) struct mptphdr *shdr; struct mptp_sock *ssk; __be16 len; - uint8_t src, dst; + uint16_t src, dst; struct sockaddr_mptp * mptp_addr; int err; int addr_size = sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest); @@ -437,8 +439,8 @@ static int mptp_rcv(struct sk_buff *skb) goto drop; } - src = shdr->src; - dst = shdr->dst; + src = ntohs(shdr->src); + dst = ntohs(shdr->dst); if (unlikely(src == 0 || dst == 0 || src >= MAX_MPTP_PORT || dst >= MAX_MPTP_PORT)) { log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst); goto drop;