X-Git-Url: http://p2p-next.cs.pub.ro/gitweb/?a=blobdiff_plain;f=src%2Fkernel%2Fmptp.c;h=1ee2441b65b8a4c18cf3e8ee9664b5c6a84eb052;hb=f414cd9dde28bcf84ee86442ee338e48f024afc2;hp=6f1bddb24f648d8e3610584da186f0f4e16358a1;hpb=34a80b6097950ce3a769c65f70fc917c6c6c5db6;p=swifty.git diff --git a/src/kernel/mptp.c b/src/kernel/mptp.c index 6f1bddb..1ee2441 100644 --- a/src/kernel/mptp.c +++ b/src/kernel/mptp.c @@ -15,8 +15,8 @@ MODULE_LICENSE("GPL"); struct mptp_sock { struct inet_sock sock; /* mptp socket speciffic data */ - uint8_t src; - uint8_t dst; + uint16_t src; + uint16_t dst; }; static struct mptp_sock * sock_port_map[MAX_MPTP_PORT]; @@ -31,7 +31,7 @@ static inline struct mptphdr * mptp_hdr(const struct sk_buff * skb) return (struct mptphdr *) skb_transport_header(skb); } -static inline uint8_t get_next_free_port(void) +static inline uint16_t get_next_free_port(void) { int i; for (i = MIN_MPTP_PORT; i < MAX_MPTP_PORT; i ++) @@ -40,17 +40,17 @@ static inline uint8_t get_next_free_port(void) return 0; } -static inline void mptp_unhash(uint8_t port) +static inline void mptp_unhash(uint16_t port) { sock_port_map[port] = NULL; } -static inline void mptp_hash(uint8_t port, struct mptp_sock *ssh) +static inline void mptp_hash(uint16_t port, struct mptp_sock *ssh) { sock_port_map[port] = ssh; } -static inline struct mptp_sock * mptp_lookup(uint8_t port) +static inline struct mptp_sock * mptp_lookup(uint16_t port) { return sock_port_map[port]; } @@ -85,17 +85,20 @@ static int mptp_bind(struct socket *sock, struct sockaddr *addr, int addr_len) struct sockaddr_mptp *mptp_addr; struct mptp_sock *ssk; int err; - uint8_t port; + uint16_t port; - if (unlikely(addr_len < sizeof(struct sockaddr_mptp))) { - log_error("Invalid size for sockaddr\n"); + if (unlikely(addr_len < sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest))) { + log_error("Invalid size for sockaddr (%d)\n", addr_len); err = -EINVAL; goto out; } mptp_addr = (struct sockaddr_mptp *) addr; - port = mptp_addr->dests[0].port; + log_debug("Bind received port=%u (network order)\n", mptp_addr->dests[0].port); + port = ntohs(mptp_addr->dests[0].port); + if (port == 0) + port = get_next_free_port(); if (unlikely(port == 0 || port >= MAX_MPTP_PORT)) { log_error("Invalid value for sockaddr port (%u)\n", port); @@ -164,7 +167,7 @@ static int mptp_connect(struct socket *sock, struct sockaddr *addr, int addr_len err = -EINVAL; goto out; } - ssk->dst = mptp_addr->dests[0].port; + ssk->dst = ntohs(mptp_addr->dests[0].port); if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_MPTP_PORT)) { log_error("Invalid value for destination port(%u)\n", ssk->dst); err = -EINVAL; @@ -197,9 +200,9 @@ out: static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len) { int err; - uint8_t dport; + uint16_t dport; __be32 daddr; - uint8_t sport; + uint16_t sport; struct sk_buff * skb; struct sock * sk; struct inet_sock * isk; @@ -211,6 +214,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * int dests = 0; int i; struct sockaddr_mptp * mptp_addr = NULL; + int ret = 0; if (unlikely(sock == NULL)) { log_error("Sock is NULL\n"); @@ -241,10 +245,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * if (msg->msg_name) { mptp_addr = (struct sockaddr_mptp *) msg->msg_name; - if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) || - msg->msg_namelen < mptp_addr->count * sizeof(struct mptp_dest) || + if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) + mptp_addr->count * sizeof(struct mptp_dest) || mptp_addr->count <= 0)) { - log_error("Invalid size for msg_name\n"); + log_error("Invalid size for msg_name (size=%u, addr_count=%u)\n", msg->msg_namelen, mptp_addr->count); err = -EINVAL; goto out; } @@ -272,7 +275,7 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * struct iovec *iov = &msg->msg_iov[i]; char *payload; - dport = dest->port; + dport = ntohs(dest->port); if (unlikely(dport == 0 || dport >= MAX_MPTP_PORT)) { log_error("Invalid value for destination port(%u)\n", dport); err = -EINVAL; @@ -299,9 +302,9 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * log_debug("Reseted transport header\n"); shdr = (struct mptphdr *) skb_transport_header(skb); - shdr->dst = dport; - shdr->src = sport; - shdr->len = ntohs(len + sizeof(struct mptphdr)); + shdr->dst = htons(dport); + shdr->src = htons(sport); + shdr->len = htons(len + sizeof(struct mptphdr)); payload = skb_put(skb, len); log_debug("payload=%p\n", payload); @@ -335,13 +338,14 @@ static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * skb->local_df = 1; err = ip_queue_xmit(skb); - if (likely(!err)) + if (likely(!err)) { log_debug("Sent %u bytes on wire\n", len); - else + ret += len; + } else log_error("ip_queue_xmit failed\n"); } - return err; + return ret; out_free: kfree(skb); @@ -358,6 +362,7 @@ static int mptp_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * int err, copied; int i; struct sockaddr_mptp *ret_addr = (struct sockaddr_mptp *) msg->msg_name; + ret_addr->count = 0; log_debug("Trying to receive sock=%p sk=%p flags=%d\n", sock, sk, flags); @@ -384,6 +389,7 @@ static int mptp_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr * goto out_free; } log_debug("Received %d bytes\n", copied); + msg->msg_iov[i].iov_len = copied; sock_recv_ts_and_drops(msg, sk, skb); @@ -418,7 +424,7 @@ static int mptp_rcv(struct sk_buff *skb) struct mptphdr *shdr; struct mptp_sock *ssk; __be16 len; - uint8_t src, dst; + uint16_t src, dst; struct sockaddr_mptp * mptp_addr; int err; int addr_size = sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest); @@ -441,8 +447,8 @@ static int mptp_rcv(struct sk_buff *skb) goto drop; } - src = shdr->src; - dst = shdr->dst; + src = ntohs(shdr->src); + dst = ntohs(shdr->dst); if (unlikely(src == 0 || dst == 0 || src >= MAX_MPTP_PORT || dst >= MAX_MPTP_PORT)) { log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst); goto drop;