5ac64edc28a24c91b5a1bc14d01498a6e0e66ec4
[swifty.git] / src / kernel / swift.c
1 #include <linux/module.h>
2 #include <net/sock.h>
3 #include <net/protocol.h>
4 #include <net/ip.h>
5 #include <net/route.h>
6
7 #include "swift.h"
8 #include "debug.h"
9
10 MODULE_DESCRIPTION("Swift Transport Protocol");
11 MODULE_AUTHOR("Adrian Bondrescu/Cornel Mercan");
12 MODULE_LICENSE("GPL");
13
14 struct swift_sock {
15         struct inet_sock sock;
16         /* swift socket speciffic data */
17         __be16 src;
18         __be16 dst;
19 };
20
21 static struct swift_sock * sock_port_map[MAX_SWIFT_PORT];
22
23 static inline struct swift_sock * swift_sk(struct sock * sock)
24 {
25         return (struct swift_sock *)(sock);
26 }
27
28 static inline struct swifthdr * swift_hdr(const struct sk_buff * skb)
29 {
30         return (struct swifthdr *) skb_transport_header(skb);
31 }
32
33 static inline __be16 get_next_free_port(void)
34 {
35         int i;
36         for (i = MIN_SWIFT_PORT; i < MAX_SWIFT_PORT; i ++)
37                 if (sock_port_map[i] == NULL)
38                         return i;
39         return 0;
40 }
41
42 static inline void swift_unhash(__be16 port)
43 {
44         sock_port_map[port] = NULL;
45 }
46
47 static inline void swift_hash(__be16 port, struct swift_sock *ssh)
48 {
49         sock_port_map[port] = ssh;
50 }
51
52 static inline struct swift_sock * swift_lookup(__be16 port)
53 {
54         return sock_port_map[port];
55 }
56
57 static int swift_release(struct socket *sock)
58 {
59         struct sock *sk = sock->sk;
60         struct swift_sock * ssk = swift_sk(sk);
61
62         if (!sk)
63                 return 0;
64
65         swift_unhash(ssk->src);
66         
67         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
68
69         synchronize_net();
70
71         sock_orphan(sk);
72         sock->sk = NULL;
73
74         skb_queue_purge(&sk->sk_receive_queue);
75
76         log_debug("swift_release sock=%p\n", sk);
77         sock_put(sk);
78
79         return 0;
80 }
81
82 static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
83 {
84         struct sockaddr_swift *swift_addr;
85         struct swift_sock *ssk;
86         int err;
87         __be16 port;
88
89         err = -EINVAL;
90         if (addr_len < sizeof(struct sockaddr_swift)) {
91                 log_error("Invalid size for sockaddr\n");
92                 goto out;
93         }
94
95         swift_addr = (struct sockaddr_swift *) addr;
96
97         err = -EINVAL;
98         if (swift_addr->sin_family != AF_INET) {
99                 log_error("Invalid family for sockaddr\n");
100                 goto out;
101         }
102
103         port = ntohs(swift_addr->sin_port);
104
105         err = -EINVAL;
106         if (port == 0 || port >= MAX_SWIFT_PORT) {
107                 log_error("Invalid value for sockaddr port (%u)\n", port);
108                 goto out;
109         }
110         
111         err = -EADDRINUSE;
112         if (swift_lookup(port) != NULL) {
113                 log_error("Port %u already in use\n", port);
114                 goto out;
115         }
116
117         ssk = swift_sk(sock->sk);
118         ssk->src = port;
119
120         swift_hash(port, ssk);
121
122         log_debug("Socket %p bound to port %u\n", ssk, port);
123         
124         return 0;
125
126 out:
127         return -EINVAL;
128 }
129
130 static int swift_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags)
131 {
132         int err;
133         struct sock * sk; 
134         struct inet_sock * isk;
135         struct swift_sock * ssk;
136
137         log_debug("swift_connect\n");
138
139         err = -EINVAL;
140         if (sock == NULL) {
141                 log_error("Sock is NULL\n");
142                 goto out;
143         }
144         sk = sock->sk;
145
146         err = -EINVAL;
147         if (sk == NULL) {
148                 log_error("Sock->sk is NULL\n");
149                 goto out;
150         }
151         
152         isk = inet_sk(sk);
153         ssk = swift_sk(sk);
154
155         if (ssk->src != 0) {
156                 log_error("ssk->src is not NULL\n");
157                 goto out;
158         }
159         
160         err = -EINVAL;
161         if (addr) {
162                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) addr;
163                 
164                 err = -EINVAL;
165                 if (addr_len < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET) {
166                         log_error("Invalid size or address family\n");
167                         goto out;
168                 }
169                 ssk->dst = ntohs(swift_addr->sin_port);
170                 if (ssk->dst == 0 || ssk->dst >= MAX_SWIFT_PORT) {
171                         log_error("Invalid value for destination port(%u)\n", ssk->dst);
172                         goto out;
173                 }       
174         
175                 isk->inet_daddr = swift_addr->sin_addr.s_addr;
176                 log_debug("Received from user space destination port=%u and address=%u\n", ssk->dst, isk->inet_daddr);
177         } else {
178                 log_error("Invalid swift_addr (NULL)\n");
179                 goto out;
180         }
181         
182         err = -ENOMEM;
183         ssk->src = get_next_free_port();
184         if (ssk->src == 0) {
185                 log_error("No free ports\n");
186                 goto out;
187         }
188         
189         swift_hash(ssk->src, ssk);
190
191         return 0;
192
193 out:
194         return err;
195
196 }
197
198 static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
199 {
200         int err;
201         __be16 dport;
202         __be32 daddr;
203         __be16 sport;
204         struct sk_buff * skb;
205         struct sock * sk; 
206         struct inet_sock * isk;
207         struct swift_sock * ssk;
208         struct swifthdr * shdr;
209         int connected = 0;
210         int totlen;
211         struct rtable * rt = NULL;
212         
213         err = -EINVAL;
214         if (sock == NULL) {
215                 log_error("Sock is NULL\n");
216                 goto out;
217         }
218         sk = sock->sk;
219
220         err = -EINVAL;
221         if (sk == NULL) {
222                 log_error("Sock->sk is NULL\n");
223                 goto out;
224         }
225
226         isk = inet_sk(sk);
227         ssk = swift_sk(sk);
228
229         sport = ssk->src;
230         if (sport == 0) {
231                 err = -ENOMEM;
232                 sport = get_next_free_port();
233                 if (sport == 0) {
234                         log_error("No free ports\n");
235                         goto out;
236                 }
237         }
238
239         if (msg->msg_name) {
240                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) msg->msg_name;
241                 
242                 err = -EINVAL;
243                 if (msg->msg_namelen < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET) {
244                         log_error("Invalid size or address family\n");
245                         goto out;
246                 }
247                 
248                 dport = ntohs(swift_addr->sin_port);
249                 if (dport == 0 || dport >= MAX_SWIFT_PORT) {
250                         log_error("Invalid value for destination port(%u)\n", dport);
251                         goto out;
252                 }       
253
254                 daddr = swift_addr->sin_addr.s_addr;
255                 log_debug("Received from user space destination port=%u and address=%u\n", dport, daddr);
256         } else {
257                 err = -EDESTADDRREQ;
258                 if (!ssk->dst || !isk->inet_daddr) {
259                         log_error("No destination port/address\n");
260                         goto out;
261                 }
262                 dport = ssk->dst;
263                 daddr = isk->inet_daddr;
264
265                 log_debug("Got from socket destination port=%u and address=%u\n", dport, daddr);
266                 connected = 1;
267         }
268
269         totlen = len + sizeof(struct swifthdr) + sizeof(struct iphdr);
270         skb = sock_alloc_send_skb(sk, totlen, msg->msg_flags & MSG_DONTWAIT, &err);
271         if (!skb) {
272                 log_error("sock_alloc_send_skb failed\n");
273                 goto out;
274         }
275         log_debug("Allocated %u bytes for skb (payload size=%u)\n", totlen, len);
276
277         skb_reset_network_header(skb);
278         skb_reserve(skb, sizeof(struct iphdr));
279         log_debug("Reseted network header\n");
280         skb_reset_transport_header(skb);
281         skb_put(skb, sizeof(struct swifthdr));
282         log_debug("Reseted transport header\n");
283
284         shdr = (struct swifthdr *) skb_transport_header(skb);
285         shdr->dst = ntohs(dport);
286         shdr->src = ntohs(sport);
287         shdr->len = ntohs(len + sizeof(struct swifthdr));
288
289         log_debug("payload=%p\n", skb_put(skb, len));
290
291         err = skb_copy_datagram_from_iovec(skb, sizeof(struct swifthdr), msg->msg_iov, 0, len);
292         if (err) {
293                 log_error("skb_copy_datagram_from_iovec failed\n");
294                 goto out_free;
295         }
296         log_debug("Copied %u bytes into the skb\n", len);
297
298         if (connected)
299                 rt = (struct rtable *) __sk_dst_check(sk, 0);
300
301         if (rt == NULL) {
302                 struct flowi fl = { .fl4_dst = daddr,
303                                     .proto = sk->sk_protocol,
304                                     .flags = inet_sk_flowi_flags(sk),
305                                   };
306                 err = ip_route_output_flow(sock_net(sk), &rt, &fl, sk, 0);
307                 if (err) {
308                         log_error("Route lookup failed\n");
309                         goto out_free;
310                 }
311                 sk_dst_set(sk, dst_clone(&rt->dst));
312         }
313         
314         err = ip_queue_xmit(skb);
315         if (!err)
316                 log_debug("Sent %u bytes on wire\n", len);
317         else
318                 log_error("ip_queue_xmit failed\n");
319
320         return err;
321
322 out_free:
323         kfree(skb);
324
325 out:
326         return err;
327 }
328
329 static int swift_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len, int flags)
330 {
331         struct sk_buff *skb;
332         struct sockaddr_swift *swift_addr;
333         struct sock * sk = sock->sk;
334         int err, copied;
335
336         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
337         if (!skb) {
338                 log_error("skb_recv_datagram\n");
339                 goto out;
340         }
341
342         log_debug("Received skb %p\n", skb);
343
344         swift_addr = (struct sockaddr_swift *) skb->cb;
345         msg->msg_namelen = sizeof(struct sockaddr_swift);
346
347         copied = skb->len;
348         if (copied > len) {
349                 copied = len;
350                 msg->msg_flags |= MSG_TRUNC;
351         }
352
353         err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
354         if (err) {
355                 log_error("skb_copy_datagram_iovec\n");
356                 goto out_free;
357         }
358
359         sock_recv_ts_and_drops(msg, sk, skb);
360
361         if (msg->msg_name)
362                 memcpy(msg->msg_name, swift_addr, msg->msg_namelen);
363         
364         err = copied;
365
366 out_free:
367         skb_free_datagram(sk, skb);
368
369 out:
370         return err;
371 }
372
373 static int swift_rcv(struct sk_buff *skb)
374 {
375         struct swifthdr *shdr;
376         struct swift_sock *ssk;
377         __be16 len;
378         __be16 src, dst;
379         struct sockaddr_swift * swift_addr;
380         int err;
381
382         if (!pskb_may_pull(skb, sizeof(struct swifthdr))) {
383                 log_error("Insufficient space for header\n");
384                 goto drop;
385         }
386         
387         shdr = (struct swifthdr *) skb->data;
388         len = ntohs(shdr->len);
389
390         if (skb->len < len) {
391                 log_error("Malformed packet (packet_len=%u, skb_len=%u)\n", len, skb->len);
392                 goto drop;
393         }
394
395         if (len < sizeof(struct swifthdr)) {
396                 log_error("Malformed packet (packet_len=%u sizeof(swifthdr)=%u\n", len, sizeof(struct swifthdr));
397                 goto drop;
398         }
399         
400         src = ntohs(shdr->src);
401         dst = ntohs(shdr->dst);
402         if (src == 0 || dst == 0 || src >= MAX_SWIFT_PORT || dst >= MAX_SWIFT_PORT) {
403                 log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
404                 goto drop;
405         }
406
407         skb_pull(skb, sizeof(struct swifthdr));
408         len -= sizeof(struct swifthdr);
409
410         pskb_trim(skb, len);
411
412         log_debug("Received %u bytes from from port=%u to port=%u\n", len - sizeof(struct swifthdr), src, dst);
413
414         ssk = swift_lookup(dst); 
415         if (ssk == NULL) {
416                 log_error("Swift lookup failed for port %u\n", dst);
417                 goto drop;
418         }
419
420         BUILD_BUG_ON(sizeof(struct sockaddr_swift) > sizeof(skb->cb));
421         
422         swift_addr = (struct sockaddr_swift *) skb->cb;
423         swift_addr->sin_family = AF_INET;
424         swift_addr->sin_port = shdr->src;
425         swift_addr->sin_addr.s_addr = ip_hdr(skb)->saddr;
426
427         log_debug("Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->sin_addr.s_addr);
428
429         err = ip_queue_rcv_skb((struct sock *) &ssk->sock, skb);
430         if (err) {
431                 log_error("ip_queu_rcv_skb\n");
432                 consume_skb(skb);
433         }
434         return NET_RX_SUCCESS;
435
436 drop:
437         kfree(skb);
438         return NET_RX_DROP;
439 }
440
441 static struct proto swift_prot = {
442         .obj_size = sizeof(struct swift_sock),
443         .owner    = THIS_MODULE,
444         .name     = "SWIFT",
445 };
446
447 static const struct proto_ops swift_ops = {
448         .family     = PF_INET,
449         .owner      = THIS_MODULE,
450         .release    = swift_release,
451         .bind       = swift_bind,
452         .connect    = swift_connect,
453         .socketpair = sock_no_socketpair,
454         .accept     = sock_no_accept,
455         .getname    = sock_no_getname,
456         .poll       = datagram_poll,
457         .ioctl      = sock_no_ioctl,
458         .listen     = sock_no_listen,
459         .shutdown   = sock_no_shutdown,
460         .setsockopt = sock_no_setsockopt,
461         .getsockopt = sock_no_getsockopt,
462         .sendmsg    = swift_sendmsg,
463         .recvmsg    = swift_recvmsg,
464         .mmap       = sock_no_mmap,
465         .sendpage   = sock_no_sendpage,
466 };
467
468 static const struct net_protocol swift_protocol = {
469         .handler   = swift_rcv,
470         .no_policy = 1,
471         .netns_ok  = 1,
472 };
473
474 static struct inet_protosw swift_protosw = {
475         .type     = SOCK_DGRAM,
476         .protocol = IPPROTO_SWIFT,
477         .prot     = &swift_prot,
478         .ops      = &swift_ops,
479         .no_check = 0,
480 };
481
482 static int __init swift_init(void)
483 {
484         int rc;
485
486         rc = proto_register(&swift_prot, 1);
487         if (rc) {
488                 log_error("Error registering swift protocol\n");
489                 goto out;
490         }
491
492         rc = inet_add_protocol(&swift_protocol, IPPROTO_SWIFT);
493         if (rc) {
494                 log_error("Error adding swift protocol\n");
495                 goto out_unregister;
496         }
497
498         inet_register_protosw(&swift_protosw);
499         log_debug("Swift entered\n");
500
501         return 0;
502
503 out_unregister:
504         proto_unregister(&swift_prot);
505
506 out:
507         return rc;
508 }
509
510 static void __exit swift_exit(void)
511 {
512         inet_unregister_protosw(&swift_protosw);
513
514         inet_del_protocol(&swift_protocol, IPPROTO_SWIFT);
515
516         proto_unregister(&swift_prot);
517
518         log_debug("Swift exited\n");
519 }
520
521 module_init(swift_init);
522 module_exit(swift_exit);