2463077aecbfc1a7a4c993a472314946b55320d6
[swifty.git] / src / kernel / mptp.c
1 #include <linux/module.h>
2 #include <linux/version.h>
3 #include <net/sock.h>
4 #include <net/protocol.h>
5 #include <net/ip.h>
6 #include <net/route.h>
7
8 #include "mptp.h"
9 #include "debug.h"
10
11 MODULE_DESCRIPTION("Multi-Party Transport Protocol");
12 MODULE_AUTHOR("Adrian Bondrescu/Cornel Mercan");
13 MODULE_LICENSE("GPL");
14
15 struct mptp_sock {
16         struct inet_sock sock;
17         /* mptp socket speciffic data */
18         uint16_t src;
19         uint16_t dst;
20 };
21
22 static struct mptp_sock *sock_port_map[MAX_MPTP_PORT];
23
24 static inline struct mptp_sock *mptp_sk(struct sock *sock)
25 {
26         return (struct mptp_sock *)(sock);
27 }
28
29 static inline struct mptphdr *mptp_hdr(const struct sk_buff *skb)
30 {
31         return (struct mptphdr *)skb_transport_header(skb);
32 }
33
34 static inline uint16_t get_next_free_port(void)
35 {
36         int i;
37         for (i = MIN_MPTP_PORT; i < MAX_MPTP_PORT; i++)
38                 if (sock_port_map[i] == NULL)
39                         return i;
40         return 0;
41 }
42
43 static inline void mptp_unhash(uint16_t port)
44 {
45         sock_port_map[port] = NULL;
46 }
47
48 static inline void mptp_hash(uint16_t port, struct mptp_sock *ssh)
49 {
50         sock_port_map[port] = ssh;
51 }
52
53 static inline struct mptp_sock *mptp_lookup(uint16_t port)
54 {
55         return sock_port_map[port];
56 }
57
58 static int mptp_release(struct socket *sock)
59 {
60         struct sock *sk = sock->sk;
61         struct mptp_sock *ssk = mptp_sk(sk);
62
63         if (unlikely(!sk))
64                 return 0;
65
66         mptp_unhash(ssk->src);
67
68         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
69
70         synchronize_net();
71
72         sock_orphan(sk);
73         sock->sk = NULL;
74
75         skb_queue_purge(&sk->sk_receive_queue);
76
77         log_debug("mptp_release sock=%p\n", sk);
78         sock_put(sk);
79
80         return 0;
81 }
82
83 static int mptp_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
84 {
85         struct sockaddr_mptp *mptp_addr;
86         struct mptp_sock *ssk;
87         int err;
88         uint16_t port;
89
90         if (unlikely
91             (addr_len <
92              sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest))) {
93                 log_error("Invalid size for sockaddr (%d)\n", addr_len);
94                 err = -EINVAL;
95                 goto out;
96         }
97
98         mptp_addr = (struct sockaddr_mptp *)addr;
99
100         log_debug("Bind received port=%u (network order)\n",
101                   mptp_addr->dests[0].port);
102         port = ntohs(mptp_addr->dests[0].port);
103         if (port == 0)
104                 port = get_next_free_port();
105
106         if (unlikely(port == 0 || port >= MAX_MPTP_PORT)) {
107                 log_error("Invalid value for sockaddr port (%u)\n", port);
108                 err = -EINVAL;
109                 goto out;
110         }
111
112         if (unlikely(mptp_lookup(port) != NULL)) {
113                 log_error("Port %u already in use\n", port);
114                 err = -EADDRINUSE;
115                 goto out;
116         }
117
118         ssk = mptp_sk(sock->sk);
119         sock->sk->sk_rcvbuf = 10 * 1024 * 1024;
120         ssk->src = port;
121
122         mptp_hash(port, ssk);
123
124         log_debug("Socket %p bound to port %u\n", ssk, port);
125
126         return 0;
127
128  out:
129         return err;
130 }
131
132 static int mptp_connect(struct socket *sock, struct sockaddr *addr,
133                         int addr_len, int flags)
134 {
135         int err;
136         struct sock *sk;
137         struct inet_sock *isk;
138         struct mptp_sock *ssk;
139
140         log_debug("mptp_connect\n");
141
142         if (unlikely(sock == NULL)) {
143                 log_error("Sock is NULL\n");
144                 err = -EINVAL;
145                 goto out;
146         }
147         sk = sock->sk;
148
149         if (unlikely(sk == NULL)) {
150                 log_error("Sock->sk is NULL\n");
151                 err = -EINVAL;
152                 goto out;
153         }
154
155         isk = inet_sk(sk);
156         ssk = mptp_sk(sk);
157
158         if (unlikely(ssk->src != 0)) {
159                 log_error("ssk->src is not NULL\n");
160                 err = -EINVAL;
161                 goto out;
162         }
163
164         if (likely(addr)) {
165                 struct sockaddr_mptp *mptp_addr = (struct sockaddr_mptp *)addr;
166
167                 if (unlikely(addr_len < sizeof(*mptp_addr) ||
168                              addr_len <
169                              mptp_addr->count * sizeof(struct mptp_dest)
170                              || mptp_addr->count <= 0)) {
171                         log_error("Invalid size or address family\n");
172                         err = -EINVAL;
173                         goto out;
174                 }
175                 ssk->dst = ntohs(mptp_addr->dests[0].port);
176                 if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_MPTP_PORT)) {
177                         log_error("Invalid value for destination port(%u)\n",
178                                   ssk->dst);
179                         err = -EINVAL;
180                         goto out;
181                 }
182
183                 isk->inet_daddr = mptp_addr->dests[0].addr;
184                 log_debug
185                     ("Received from user space destination port=%u and address=%u\n",
186                      ssk->dst, isk->inet_daddr);
187         } else {
188                 log_error("Invalid mptp_addr (NULL)\n");
189                 err = -EINVAL;
190                 goto out;
191         }
192
193         ssk->src = get_next_free_port();
194         if (unlikely(ssk->src == 0)) {
195                 log_error("No free ports\n");
196                 err = -ENOMEM;
197                 goto out;
198         }
199
200         mptp_hash(ssk->src, ssk);
201
202         return 0;
203
204  out:
205         return err;
206 }
207
208 static int mptp_sendmsg(struct kiocb *iocb, struct socket *sock,
209                         struct msghdr *msg, size_t len)
210 {
211         int err;
212         uint16_t dport;
213         __be32 daddr;
214         uint16_t sport;
215         struct sk_buff *skb;
216         struct sock *sk;
217         struct inet_sock *isk;
218         struct mptp_sock *ssk;
219         struct mptphdr *shdr;
220         int connected = 0;
221         int totlen;
222         struct rtable *rt = NULL;
223         int dests = 0;
224         int i;
225         struct sockaddr_mptp *mptp_addr = NULL;
226         int ret = 0;
227
228     if (unlikely(sock == NULL)) {
229         log_error("Sock is NULL\n");
230         err = -EINVAL;
231         goto out;
232     }
233     sk = sock->sk;
234
235     if (unlikely(sk == NULL)) {
236         log_error("Sock->sk is NULL\n");
237         err = -EINVAL;
238         goto out;
239     }
240
241     isk = inet_sk(sk);
242     ssk = mptp_sk(sk);
243
244     sport = ssk->src;
245     if (sport == 0) {
246         sport = get_next_free_port();
247         if (unlikely(sport == 0)) {
248             log_error("No free ports\n");
249             err = -ENOMEM;
250             goto out;
251         }
252     }
253
254     if (msg->msg_name) {
255         mptp_addr = (struct sockaddr_mptp *) msg->msg_name;
256
257         if (unlikely(msg->msg_namelen < sizeof(*mptp_addr) + mptp_addr->count * sizeof(struct mptp_dest) || 
258                      mptp_addr->count <= 0)) {
259             log_error("Invalid size for msg_name (size=%u, addr_count=%u)\n", msg->msg_namelen, mptp_addr->count);
260             err = -EINVAL;
261             goto out;
262         }
263
264         dests = mptp_addr->count;
265     } else {
266         BUG();
267         if (unlikely(!ssk->dst || !isk->inet_daddr)) {
268             log_error("No destination port/address\n");
269             err = -EDESTADDRREQ;
270             goto out;
271         }
272         dport = ssk->dst;
273         daddr = isk->inet_daddr;
274
275         log_debug("Got from socket destination port=%u and address=%u\n", dport, daddr);
276         connected = 1;
277     }
278
279     if (msg->msg_iovlen < dests)
280         dests = msg->msg_iovlen;
281
282     for (i = 0; i < dests; i++) {
283         struct mptp_dest *dest = &mptp_addr->dests[i];
284         struct iovec *iov = &msg->msg_iov[i];
285         char *payload;
286
287         dport = ntohs(dest->port);
288         if (unlikely(dport == 0 || dport >= MAX_MPTP_PORT)) {
289             log_error("Invalid value for destination port(%u)\n", dport);
290             err = -EINVAL;
291             goto out;
292         }       
293
294         daddr = dest->addr;
295         log_debug("Received from user space destination port=%u and address=%u\n", dport, daddr);
296
297         len = iov->iov_len;
298         totlen = len + sizeof(struct mptphdr) + sizeof(struct iphdr);
299         skb = sock_alloc_send_skb(sk, totlen, msg->msg_flags & MSG_DONTWAIT, &err);
300         if (unlikely(!skb)) {
301             log_error("sock_alloc_send_skb failed\n");
302             goto out;
303         }
304         log_debug("Allocated %u bytes for skb (payload size=%u)\n", totlen, len);
305
306         skb_reset_network_header(skb);
307         skb_reserve(skb, sizeof(struct iphdr));
308         log_debug("Reseted network header\n");
309         skb_reset_transport_header(skb);
310         skb_put(skb, sizeof(struct mptphdr));
311         log_debug("Reseted transport header\n");
312
313         shdr = (struct mptphdr *) skb_transport_header(skb);
314         shdr->dst = htons(dport);
315         shdr->src = htons(sport);
316         shdr->len = htons(len + sizeof(struct mptphdr));
317
318         payload = skb_put(skb, len);
319         log_debug("payload=%p\n", payload);
320
321         err = skb_copy_datagram_from_iovec(skb, sizeof(struct mptphdr), iov, 0, len);
322         if (unlikely(err)) {
323             log_error("skb_copy_datagram_from_iovec failed\n");
324             goto out_free;
325         }
326         log_debug("Copied %u bytes into the skb\n", len);
327
328         if (connected)
329             rt = (struct rtable *) __sk_dst_check(sk, 0);
330
331         if (rt == NULL) {
332             struct flowi fl = { .fl4_dst = daddr,
333                 .proto = sk->sk_protocol,
334                 .flags = inet_sk_flowi_flags(sk),
335             };
336             err = ip_route_output_flow(sock_net(sk), &rt, &fl, sk, 0);
337             if (unlikely(err)) {
338                 log_error("Route lookup failed\n");
339                 goto out_free;
340             }
341 #if LINUX_VERSION_CODE < KERNEL_VERSION(2, 6, 36)
342                         sk_dst_set(sk, dst_clone(&rt->u.dst));
343 #else
344                         sk_dst_set(sk, dst_clone(&rt->dst));
345 #endif
346                 }
347
348         skb->local_df = 1;
349         err = ip_queue_xmit(skb);
350         if (likely(!err)) {
351             log_debug("Sent %u bytes on wire\n", len);
352                         ret += len;
353                         dest->bytes = len;
354                 } else {
355                         log_error("ip_queue_xmit failed\n");
356                         dest->bytes = -1;
357                 }
358         }
359
360         return ret;
361
362  out_free:
363         kfree(skb);
364
365  out:
366         return err;
367 }
368
369 static int mptp_recvmsg(struct kiocb *iocb, struct socket *sock,
370                         struct msghdr *msg, size_t len, int flags)
371 {
372         struct sk_buff *skb;
373         struct sockaddr_mptp *mptp_addr;
374         struct sock *sk = sock->sk;
375         int err, copied;
376         int i;
377         struct sockaddr_mptp *ret_addr = (struct sockaddr_mptp *)msg->msg_name;
378         ret_addr->count = 0;
379
380         log_debug("Trying to receive sock=%p sk=%p flags=%d\n", sock, sk,
381                   flags);
382
383         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
384         if (unlikely(!skb)) {
385                 log_error("skb_recv_datagram failed with %d\n", err);
386                 goto out;
387         }
388
389         for (i = 0; i < msg->msg_iovlen; i++) {
390                 log_debug("Received skb %p\n", skb);
391
392                 mptp_addr = (struct sockaddr_mptp *)skb->cb;
393
394                 copied = skb->len;
395                 if (copied > msg->msg_iov[i].iov_len) {
396                         copied = msg->msg_iov[i].iov_len;
397                         msg->msg_flags |= MSG_TRUNC;
398                 }
399
400                 err = skb_copy_datagram_iovec(skb, 0, &msg->msg_iov[i], copied);
401                 if (unlikely(err)) {
402                         log_error("skb_copy_datagram_iovec\n");
403                         goto out_free;
404                 }
405                 log_debug("Received %d bytes\n", copied);
406
407                 sock_recv_ts_and_drops(msg, sk, skb);
408
409                 if (ret_addr) {
410                         memcpy(&ret_addr->dests[i], &mptp_addr->dests[0],
411                                sizeof(ret_addr->dests[i]));
412                         ret_addr->dests[i].bytes = copied;
413                 }
414
415                 err = copied;
416
417  out_free:
418                 skb_free_datagram(sk, skb);
419
420                 if (i == msg->msg_iovlen - 1)
421                         break;
422
423                 skb = skb_recv_datagram(sk, flags, 1, &err);
424                 if (likely(err == -EAGAIN)) {
425                         log_debug("No more skbs in the queue, returning...\n");
426                         err = copied;
427                         break;
428                 }
429         }
430
431         ret_addr->count = i + 1;
432         msg->msg_namelen =
433             sizeof(struct sockaddr_mptp) + (i + 1) * sizeof(struct mptp_dest);
434
435  out:
436         return err;
437 }
438
439 static int mptp_rcv(struct sk_buff *skb)
440 {
441         struct mptphdr *shdr;
442         struct mptp_sock *ssk;
443         __be16 len;
444         uint16_t src, dst;
445         struct sockaddr_mptp *mptp_addr;
446         int err;
447         int addr_size = sizeof(struct sockaddr_mptp) + sizeof(struct mptp_dest);
448
449         if (unlikely(!pskb_may_pull(skb, sizeof(struct mptphdr)))) {
450                 log_error("Insufficient space for header\n");
451                 goto drop;
452         }
453
454         shdr = (struct mptphdr *)skb->data;
455         len = ntohs(shdr->len);
456
457         if (unlikely(skb->len < len)) {
458                 log_error("Malformed packet (packet_len=%u, skb_len=%u)\n", len,
459                           skb->len);
460                 goto drop;
461         }
462
463         if (unlikely(len < sizeof(struct mptphdr))) {
464                 log_error
465                     ("Malformed packet (packet_len=%u sizeof(mptphdr)=%u\n",
466                      len, sizeof(struct mptphdr));
467                 goto drop;
468         }
469
470         src = ntohs(shdr->src);
471         dst = ntohs(shdr->dst);
472         if (unlikely
473             (src == 0 || dst == 0 || src >= MAX_MPTP_PORT
474              || dst >= MAX_MPTP_PORT)) {
475                 log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src,
476                           shdr->dst);
477                 goto drop;
478         }
479
480         skb_pull(skb, sizeof(struct mptphdr));
481         len -= sizeof(struct mptphdr);
482
483         pskb_trim(skb, len);
484
485         log_debug("Received %u bytes from from port=%u to port=%u\n",
486                   len - sizeof(struct mptphdr), src, dst);
487
488         ssk = mptp_lookup(dst);
489         if (ssk == NULL) {
490                 log_error("MPTP lookup failed for port %u\n", dst);
491                 goto drop;
492         }
493
494         BUG_ON(addr_size > sizeof(skb->cb));
495
496         mptp_addr = (struct sockaddr_mptp *)skb->cb;
497         mptp_addr->dests[0].port = shdr->src;
498         mptp_addr->dests[0].addr = ip_hdr(skb)->saddr;
499
500         log_debug("Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src),
501                   mptp_addr->dests[0].addr);
502
503 #if LINUX_VERSION_CODE < KERNEL_VERSION(2, 6, 36)
504         err = ip_queue_rcv_skb((struct sock *)&ssk->sock, skb);
505 #else
506         err = sock_queue_rcv_skb((struct sock *)&ssk->sock, skb);
507 #endif
508         if (unlikely(err)) {
509                 log_error("ip_queue_rcv_skb failed with %d\n", err);
510                 consume_skb(skb);
511         }
512         return NET_RX_SUCCESS;
513
514  drop:
515         kfree(skb);
516         return NET_RX_DROP;
517 }
518
519 static struct proto mptp_prot = {
520         .obj_size = sizeof(struct mptp_sock),
521         .owner = THIS_MODULE,
522         .name = "MPTP",
523 };
524
525 static const struct proto_ops mptp_ops = {
526         .family = PF_INET,
527         .owner = THIS_MODULE,
528         .release = mptp_release,
529         .bind = mptp_bind,
530         .connect = mptp_connect,
531         .socketpair = sock_no_socketpair,
532         .accept = sock_no_accept,
533         .getname = sock_no_getname,
534         .poll = datagram_poll,
535         .ioctl = sock_no_ioctl,
536         .listen = sock_no_listen,
537         .shutdown = sock_no_shutdown,
538         .setsockopt = sock_no_setsockopt,
539         .getsockopt = sock_no_getsockopt,
540         .sendmsg = mptp_sendmsg,
541         .recvmsg = mptp_recvmsg,
542         .mmap = sock_no_mmap,
543         .sendpage = sock_no_sendpage,
544 };
545
546 static const struct net_protocol mptp_protocol = {
547         .handler = mptp_rcv,
548         .no_policy = 1,
549         .netns_ok = 1,
550 };
551
552 static struct inet_protosw mptp_protosw = {
553         .type = SOCK_DGRAM,
554         .protocol = IPPROTO_MPTP,
555         .prot = &mptp_prot,
556         .ops = &mptp_ops,
557         .no_check = 0,
558 };
559
560 static int __init mptp_init(void)
561 {
562         int rc;
563
564         rc = proto_register(&mptp_prot, 1);
565         if (unlikely(rc)) {
566                 log_error("Error registering mptp protocol\n");
567                 goto out;
568         }
569
570         rc = inet_add_protocol(&mptp_protocol, IPPROTO_MPTP);
571         if (unlikely(rc)) {
572                 log_error("Error adding mptp protocol\n");
573                 goto out_unregister;
574         }
575
576         inet_register_protosw(&mptp_protosw);
577         log_debug("MPTP entered\n");
578
579         return 0;
580
581  out_unregister:
582         proto_unregister(&mptp_prot);
583
584  out:
585         return rc;
586 }
587
588 static void __exit mptp_exit(void)
589 {
590         inet_unregister_protosw(&mptp_protosw);
591
592         inet_del_protocol(&mptp_protocol, IPPROTO_MPTP);
593
594         proto_unregister(&mptp_prot);
595
596         log_debug("MPTP exited\n");
597 }
598
599 module_init(mptp_init);
600 module_exit(mptp_exit);