Added likely/unlikely logic
[swifty.git] / src / kernel / swift.c
1 #include <linux/module.h>
2 #include <net/sock.h>
3 #include <net/protocol.h>
4 #include <net/ip.h>
5 #include <net/route.h>
6
7 #include "swift.h"
8 #include "debug.h"
9
10 MODULE_DESCRIPTION("Swift Transport Protocol");
11 MODULE_AUTHOR("Adrian Bondrescu/Cornel Mercan");
12 MODULE_LICENSE("GPL");
13
14 struct swift_sock {
15         struct inet_sock sock;
16         /* swift socket speciffic data */
17         __be16 src;
18         __be16 dst;
19 };
20
21 static struct swift_sock * sock_port_map[MAX_SWIFT_PORT];
22
23 static inline struct swift_sock * swift_sk(struct sock * sock)
24 {
25         return (struct swift_sock *)(sock);
26 }
27
28 static inline struct swifthdr * swift_hdr(const struct sk_buff * skb)
29 {
30         return (struct swifthdr *) skb_transport_header(skb);
31 }
32
33 static inline __be16 get_next_free_port(void)
34 {
35         int i;
36         for (i = MIN_SWIFT_PORT; i < MAX_SWIFT_PORT; i ++)
37                 if (sock_port_map[i] == NULL)
38                         return i;
39         return 0;
40 }
41
42 static inline void swift_unhash(__be16 port)
43 {
44         sock_port_map[port] = NULL;
45 }
46
47 static inline void swift_hash(__be16 port, struct swift_sock *ssh)
48 {
49         sock_port_map[port] = ssh;
50 }
51
52 static inline struct swift_sock * swift_lookup(__be16 port)
53 {
54         return sock_port_map[port];
55 }
56
57 static int swift_release(struct socket *sock)
58 {
59         struct sock *sk = sock->sk;
60         struct swift_sock * ssk = swift_sk(sk);
61
62         if (unlikely(!sk))
63                 return 0;
64
65         swift_unhash(ssk->src);
66         
67         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
68
69         synchronize_net();
70
71         sock_orphan(sk);
72         sock->sk = NULL;
73
74         skb_queue_purge(&sk->sk_receive_queue);
75
76         log_debug("swift_release sock=%p\n", sk);
77         sock_put(sk);
78
79         return 0;
80 }
81
82 static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
83 {
84         struct sockaddr_swift *swift_addr;
85         struct swift_sock *ssk;
86         int err;
87         __be16 port;
88
89         if (unlikely(addr_len < sizeof(struct sockaddr_swift))) {
90                 log_error("Invalid size for sockaddr\n");
91                 err = -EINVAL;
92                 goto out;
93         }
94
95         swift_addr = (struct sockaddr_swift *) addr;
96
97         if (unlikely(swift_addr->sin_family != AF_INET)) {
98                 log_error("Invalid family for sockaddr\n");
99                 err = -EINVAL;
100                 goto out;
101         }
102
103         port = ntohs(swift_addr->sin_port);
104
105         if (unlikely(port == 0 || port >= MAX_SWIFT_PORT)) {
106                 log_error("Invalid value for sockaddr port (%u)\n", port);
107                 err = -EINVAL;
108                 goto out;
109         }
110         
111         if (unlikely(swift_lookup(port) != NULL)) {
112                 log_error("Port %u already in use\n", port);
113                 err = -EADDRINUSE;
114                 goto out;
115         }
116
117         ssk = swift_sk(sock->sk);
118         ssk->src = port;
119
120         swift_hash(port, ssk);
121
122         log_debug("Socket %p bound to port %u\n", ssk, port);
123         
124         return 0;
125
126 out:
127         return err;
128 }
129
130 static int swift_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags)
131 {
132         int err;
133         struct sock * sk; 
134         struct inet_sock * isk;
135         struct swift_sock * ssk;
136
137         log_debug("swift_connect\n");
138
139         if (unlikely(sock == NULL)) {
140                 log_error("Sock is NULL\n");
141                 err = -EINVAL;
142                 goto out;
143         }
144         sk = sock->sk;
145
146         if (unlikely(sk == NULL)) {
147                 log_error("Sock->sk is NULL\n");
148                 err = -EINVAL;
149                 goto out;
150         }
151
152         isk = inet_sk(sk);
153         ssk = swift_sk(sk);
154
155         if (unlikely(ssk->src != 0)) {
156                 log_error("ssk->src is not NULL\n");
157                 err = -EINVAL;
158                 goto out;
159         }
160         
161         if (likely(addr)) {
162                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) addr;
163                 
164                 if (unlikely(addr_len < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET)) {
165                         log_error("Invalid size or address family\n");
166                         err = -EINVAL;
167                         goto out;
168                 }
169                 ssk->dst = ntohs(swift_addr->sin_port);
170                 if (unlikely(ssk->dst == 0 || ssk->dst >= MAX_SWIFT_PORT)) {
171                         log_error("Invalid value for destination port(%u)\n", ssk->dst);
172                         err = -EINVAL;
173                         goto out;
174                 }       
175         
176                 isk->inet_daddr = swift_addr->sin_addr.s_addr;
177                 log_debug("Received from user space destination port=%u and address=%u\n", ssk->dst, isk->inet_daddr);
178         } else {
179                 log_error("Invalid swift_addr (NULL)\n");
180                 err = -EINVAL;
181                 goto out;
182         }
183         
184         ssk->src = get_next_free_port();
185         if (unlikely(ssk->src == 0)) {
186                 log_error("No free ports\n");
187                 err = -ENOMEM;
188                 goto out;
189         }
190         
191         swift_hash(ssk->src, ssk);
192
193         return 0;
194
195 out:
196         return err;
197 }
198
199 static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
200 {
201         int err;
202         __be16 dport;
203         __be32 daddr;
204         __be16 sport;
205         struct sk_buff * skb;
206         struct sock * sk; 
207         struct inet_sock * isk;
208         struct swift_sock * ssk;
209         struct swifthdr * shdr;
210         int connected = 0;
211         int totlen;
212         struct rtable * rt = NULL;
213         
214         if (unlikely(sock == NULL)) {
215                 log_error("Sock is NULL\n");
216                 err = -EINVAL;
217                 goto out;
218         }
219         sk = sock->sk;
220
221         if (unlikely(sk == NULL)) {
222                 log_error("Sock->sk is NULL\n");
223                 err = -EINVAL;
224                 goto out;
225         }
226
227         isk = inet_sk(sk);
228         ssk = swift_sk(sk);
229
230         sport = ssk->src;
231         if (sport == 0) {
232                 sport = get_next_free_port();
233                 if (unlikely(sport == 0)) {
234                         log_error("No free ports\n");
235                         err = -ENOMEM;
236                         goto out;
237                 }
238         }
239
240         if (msg->msg_name) {
241                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) msg->msg_name;
242                 
243                 if (unlikely(msg->msg_namelen < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET)) {
244                         log_error("Invalid size or address family\n");
245                         err = -EINVAL;
246                         goto out;
247                 }
248                 
249                 dport = ntohs(swift_addr->sin_port);
250                 if (unlikely(dport == 0 || dport >= MAX_SWIFT_PORT)) {
251                         log_error("Invalid value for destination port(%u)\n", dport);
252                         err = -EINVAL;
253                         goto out;
254                 }       
255
256                 daddr = swift_addr->sin_addr.s_addr;
257                 log_debug("Received from user space destination port=%u and address=%u\n", dport, daddr);
258         } else {
259                 if (unlikely(!ssk->dst || !isk->inet_daddr)) {
260                         log_error("No destination port/address\n");
261                         err = -EDESTADDRREQ;
262                         goto out;
263                 }
264                 dport = ssk->dst;
265                 daddr = isk->inet_daddr;
266
267                 log_debug("Got from socket destination port=%u and address=%u\n", dport, daddr);
268                 connected = 1;
269         }
270
271         totlen = len + sizeof(struct swifthdr) + sizeof(struct iphdr);
272         skb = sock_alloc_send_skb(sk, totlen, msg->msg_flags & MSG_DONTWAIT, &err);
273         if (unlikely(!skb)) {
274                 log_error("sock_alloc_send_skb failed\n");
275                 goto out;
276         }
277         log_debug("Allocated %u bytes for skb (payload size=%u)\n", totlen, len);
278
279         skb_reset_network_header(skb);
280         skb_reserve(skb, sizeof(struct iphdr));
281         log_debug("Reseted network header\n");
282         skb_reset_transport_header(skb);
283         skb_put(skb, sizeof(struct swifthdr));
284         log_debug("Reseted transport header\n");
285
286         shdr = (struct swifthdr *) skb_transport_header(skb);
287         shdr->dst = ntohs(dport);
288         shdr->src = ntohs(sport);
289         shdr->len = ntohs(len + sizeof(struct swifthdr));
290
291         log_debug("payload=%p\n", skb_put(skb, len));
292
293         err = skb_copy_datagram_from_iovec(skb, sizeof(struct swifthdr), msg->msg_iov, 0, len);
294         if (unlikely(err)) {
295                 log_error("skb_copy_datagram_from_iovec failed\n");
296                 goto out_free;
297         }
298         log_debug("Copied %u bytes into the skb\n", len);
299
300         if (connected)
301                 rt = (struct rtable *) __sk_dst_check(sk, 0);
302
303         if (rt == NULL) {
304                 struct flowi fl = { .fl4_dst = daddr,
305                                     .proto = sk->sk_protocol,
306                                     .flags = inet_sk_flowi_flags(sk),
307                                   };
308                 err = ip_route_output_flow(sock_net(sk), &rt, &fl, sk, 0);
309                 if (unlikely(err)) {
310                         log_error("Route lookup failed\n");
311                         goto out_free;
312                 }
313                 sk_dst_set(sk, dst_clone(&rt->dst));
314         }
315         
316         err = ip_queue_xmit(skb);
317         if (likely(!err))
318                 log_debug("Sent %u bytes on wire\n", len);
319         else
320                 log_error("ip_queue_xmit failed\n");
321
322         return err;
323
324 out_free:
325         kfree(skb);
326
327 out:
328         return err;
329 }
330
331 static int swift_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len, int flags)
332 {
333         struct sk_buff *skb;
334         struct sockaddr_swift *swift_addr;
335         struct sock * sk = sock->sk;
336         int err, copied;
337
338         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
339         if (unlikely(!skb)) {
340                 log_error("skb_recv_datagram\n");
341                 goto out;
342         }
343
344         log_debug("Received skb %p\n", skb);
345
346         swift_addr = (struct sockaddr_swift *) skb->cb;
347         msg->msg_namelen = sizeof(struct sockaddr_swift);
348
349         copied = skb->len;
350         if (copied > len) {
351                 copied = len;
352                 msg->msg_flags |= MSG_TRUNC;
353         }
354
355         err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
356         if (unlikely(err)) {
357                 log_error("skb_copy_datagram_iovec\n");
358                 goto out_free;
359         }
360
361         sock_recv_ts_and_drops(msg, sk, skb);
362
363         if (msg->msg_name)
364                 memcpy(msg->msg_name, swift_addr, msg->msg_namelen);
365         
366         err = copied;
367
368 out_free:
369         skb_free_datagram(sk, skb);
370
371 out:
372         return err;
373 }
374
375 static int swift_rcv(struct sk_buff *skb)
376 {
377         struct swifthdr *shdr;
378         struct swift_sock *ssk;
379         __be16 len;
380         __be16 src, dst;
381         struct sockaddr_swift * swift_addr;
382         int err;
383
384         if (unlikely(!pskb_may_pull(skb, sizeof(struct swifthdr)))) {
385                 log_error("Insufficient space for header\n");
386                 goto drop;
387         }
388         
389         shdr = (struct swifthdr *) skb->data;
390         len = ntohs(shdr->len);
391
392         if (unlikely(skb->len < len)) {
393                 log_error("Malformed packet (packet_len=%u, skb_len=%u)\n", len, skb->len);
394                 goto drop;
395         }
396
397         if (unlikely(len < sizeof(struct swifthdr))) {
398                 log_error("Malformed packet (packet_len=%u sizeof(swifthdr)=%u\n", len, sizeof(struct swifthdr));
399                 goto drop;
400         }
401         
402         src = ntohs(shdr->src);
403         dst = ntohs(shdr->dst);
404         if (unlikely(src == 0 || dst == 0 || src >= MAX_SWIFT_PORT || dst >= MAX_SWIFT_PORT)) {
405                 log_error("Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
406                 goto drop;
407         }
408
409         skb_pull(skb, sizeof(struct swifthdr));
410         len -= sizeof(struct swifthdr);
411
412         pskb_trim(skb, len);
413
414         log_debug("Received %u bytes from from port=%u to port=%u\n", len - sizeof(struct swifthdr), src, dst);
415
416         ssk = swift_lookup(dst); 
417         if (ssk == NULL) {
418                 log_error("Swift lookup failed for port %u\n", dst);
419                 goto drop;
420         }
421
422         BUILD_BUG_ON(sizeof(struct sockaddr_swift) > sizeof(skb->cb));
423         
424         swift_addr = (struct sockaddr_swift *) skb->cb;
425         swift_addr->sin_family = AF_INET;
426         swift_addr->sin_port = shdr->src;
427         swift_addr->sin_addr.s_addr = ip_hdr(skb)->saddr;
428
429         log_debug("Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->sin_addr.s_addr);
430
431         err = ip_queue_rcv_skb((struct sock *) &ssk->sock, skb);
432         if (unlikely(err)) {
433                 log_error("ip_queu_rcv_skb\n");
434                 consume_skb(skb);
435         }
436         return NET_RX_SUCCESS;
437
438 drop:
439         kfree(skb);
440         return NET_RX_DROP;
441 }
442
443 static struct proto swift_prot = {
444         .obj_size = sizeof(struct swift_sock),
445         .owner    = THIS_MODULE,
446         .name     = "SWIFT",
447 };
448
449 static const struct proto_ops swift_ops = {
450         .family     = PF_INET,
451         .owner      = THIS_MODULE,
452         .release    = swift_release,
453         .bind       = swift_bind,
454         .connect    = swift_connect,
455         .socketpair = sock_no_socketpair,
456         .accept     = sock_no_accept,
457         .getname    = sock_no_getname,
458         .poll       = datagram_poll,
459         .ioctl      = sock_no_ioctl,
460         .listen     = sock_no_listen,
461         .shutdown   = sock_no_shutdown,
462         .setsockopt = sock_no_setsockopt,
463         .getsockopt = sock_no_getsockopt,
464         .sendmsg    = swift_sendmsg,
465         .recvmsg    = swift_recvmsg,
466         .mmap       = sock_no_mmap,
467         .sendpage   = sock_no_sendpage,
468 };
469
470 static const struct net_protocol swift_protocol = {
471         .handler   = swift_rcv,
472         .no_policy = 1,
473         .netns_ok  = 1,
474 };
475
476 static struct inet_protosw swift_protosw = {
477         .type     = SOCK_DGRAM,
478         .protocol = IPPROTO_SWIFT,
479         .prot     = &swift_prot,
480         .ops      = &swift_ops,
481         .no_check = 0,
482 };
483
484 static int __init swift_init(void)
485 {
486         int rc;
487
488         rc = proto_register(&swift_prot, 1);
489         if (unlikely(rc)) {
490                 log_error("Error registering swift protocol\n");
491                 goto out;
492         }
493
494         rc = inet_add_protocol(&swift_protocol, IPPROTO_SWIFT);
495         if (unlikely(rc)) {
496                 log_error("Error adding swift protocol\n");
497                 goto out_unregister;
498         }
499
500         inet_register_protosw(&swift_protosw);
501         log_debug("Swift entered\n");
502
503         return 0;
504
505 out_unregister:
506         proto_unregister(&swift_prot);
507
508 out:
509         return rc;
510 }
511
512 static void __exit swift_exit(void)
513 {
514         inet_unregister_protosw(&swift_protosw);
515
516         inet_del_protocol(&swift_protocol, IPPROTO_SWIFT);
517
518         proto_unregister(&swift_prot);
519
520         log_debug("Swift exited\n");
521 }
522
523 module_init(swift_init);
524 module_exit(swift_exit);