Remove initial test file
[swifty.git] / src / kernel / swift.c
1 #include <linux/module.h>
2 #include <net/sock.h>
3 #include <net/protocol.h>
4 #include <net/ip.h>
5 #include <net/route.h>
6
7 #include "swift.h"
8
9 MODULE_DESCRIPTION("Swift Transport Protocol");
10 MODULE_AUTHOR("Adrian Bondrescu/Cornel Mercan");
11 MODULE_LICENSE("GPL");
12
13 struct swift_sock {
14         struct inet_sock sock;
15         /* swift socket speciffic data */
16         __be16 src;
17         __be16 dst;
18         __be16 len;
19 };
20
21 static struct swift_sock * sock_port_map[MAX_SWIFT_PORT];
22
23 static inline struct swift_sock * swift_sk(struct sock * sock)
24 {
25         return (struct swift_sock *)(sock);
26 }
27
28 static inline struct swifthdr * swift_hdr(const struct sk_buff * skb)
29 {
30         return (struct swifthdr *) skb_transport_header(skb);
31 }
32
33 static inline __be16 get_next_free_port(void)
34 {
35         int i;
36         for (i = MIN_SWIFT_PORT; i < MAX_SWIFT_PORT; i ++)
37                 if (sock_port_map[i] == NULL)
38                         return i;
39         return 0;
40 }
41
42 static inline void swift_unhash(__be16 port)
43 {
44         sock_port_map[port] = NULL;
45 }
46
47 static inline void swift_hash(__be16 port, struct swift_sock *ssh)
48 {
49         sock_port_map[port] = ssh;
50 }
51
52 static inline struct swift_sock * swift_lookup(__be16 port)
53 {
54         return sock_port_map[port];
55 }
56
57 static int swift_release(struct socket *sock)
58 {
59         struct sock *sk = sock->sk;
60         struct swift_sock * ssk = swift_sk(sk);
61
62         if (!sk)
63                 return 0;
64
65         swift_unhash(ssk->src);
66         
67         sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
68
69         synchronize_net();
70
71         sock_orphan(sk);
72         sock->sk = NULL;
73
74         skb_queue_purge(&sk->sk_receive_queue);
75
76         printk(KERN_DEBUG "swift_release sock=%p\n", sk);
77         sock_put(sk);
78
79         return 0;
80 }
81
82 static int swift_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
83 {
84         struct sockaddr_swift *swift_addr;
85         struct swift_sock *ssk;
86         int err;
87         __be16 port;
88
89         err = -EINVAL;
90         if (addr_len < sizeof(struct sockaddr_swift)) {
91                 printk(KERN_ERR "Invalid size for sockaddr\n");
92                 goto out;
93         }
94
95         swift_addr = (struct sockaddr_swift *) addr;
96
97         err = -EINVAL;
98         if (swift_addr->sin_family != AF_INET) {
99                 printk(KERN_ERR "Invalid family for sockaddr\n");
100                 goto out;
101         }
102
103         port = ntohs(swift_addr->sin_port);
104
105         err = -EINVAL;
106         if (port == 0 || port >= MAX_SWIFT_PORT) {
107                 printk(KERN_ERR "Invalid value for sockaddr port (%u)\n", port);
108                 goto out;
109         }
110         
111         err = -EADDRINUSE;
112         if (swift_lookup(port) != NULL) {
113                 printk(KERN_ERR "Port %u already in use\n", port);
114                 goto out;
115         }
116
117         ssk = swift_sk(sock->sk);
118         ssk->src = port;
119
120         swift_hash(port, ssk);
121
122         printk(KERN_DEBUG "Socket %p bound to port %u\n", ssk, port);
123         
124         return 0;
125
126 out:
127         return -EINVAL;
128 }
129
130 static int swift_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags)
131 {
132         printk(KERN_DEBUG "swift_connect\n");
133         return 0;
134 }
135
136 static int swift_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len)
137 {
138         int err;
139         __be16 dport;
140         __be32 daddr;
141         __be16 sport;
142         struct sk_buff * skb;
143         struct sock * sk; 
144         struct inet_sock * isk;
145         struct swift_sock * ssk;
146         struct swifthdr * shdr;
147         int connected = 0;
148         int totlen;
149         struct rtable * rt = NULL;
150         
151         err = -EINVAL;
152         if (sock == NULL) {
153                 printk(KERN_ERR "Sock is NULL\n");
154                 goto out;
155         }
156         sk = sock->sk;
157
158         err = -EINVAL;
159         if (sk == NULL) {
160                 printk(KERN_ERR "Sock->sk is NULL\n");
161                 goto out;
162         }
163
164         err = -ENOMEM;
165         sport = get_next_free_port();
166         if (sport == 0) {
167                 printk(KERN_ERR "No free ports\n");
168                 goto out;
169         }
170
171         isk = inet_sk(sk);
172         ssk = swift_sk(sk);
173
174         if (msg->msg_name) {
175                 struct sockaddr_swift * swift_addr = (struct sockaddr_swift *) msg->msg_name;
176                 
177                 err = -EINVAL;
178                 if (msg->msg_namelen < sizeof(*swift_addr) || swift_addr->sin_family != AF_INET) {
179                         printk(KERN_ERR "Invalid size or address family\n");
180                         goto out;
181                 }
182                 
183                 dport = ntohs(swift_addr->sin_port);
184                 if (dport == 0 || dport >= MAX_SWIFT_PORT) {
185                         printk(KERN_ERR "Invalid value for destination port(%u)\n", dport);
186                         goto out;
187                 }       
188
189                 daddr = swift_addr->sin_addr.s_addr;
190                 printk(KERN_DEBUG "Received from user space destination port=%u and address=%u\n", dport, daddr);
191         } else {
192                 err = -EDESTADDRREQ;
193                 if (!ssk->dst || !isk->inet_daddr) {
194                         printk(KERN_ERR "No destination port/address\n");
195                         goto out;
196                 }
197                 dport = ssk->dst;
198                 daddr = isk->inet_daddr;
199
200                 printk(KERN_DEBUG "Got from socket destination port=%u and address=%u\n", dport, daddr);
201                 connected = 1;
202         }
203
204         totlen = len + sizeof(struct swifthdr) + sizeof(struct iphdr);
205         skb = sock_alloc_send_skb(sk, totlen, msg->msg_flags & MSG_DONTWAIT, &err);
206         if (!skb) {
207                 printk(KERN_ERR "sock_alloc_send_skb failed\n");
208                 goto out;
209         }
210         printk(KERN_DEBUG "Allocated %u bytes for skb (payload size=%u)\n", totlen, len);
211
212         skb_reset_network_header(skb);
213         skb_reserve(skb, sizeof(struct iphdr));
214         printk(KERN_DEBUG "Reseted network header\n");
215         skb_reset_transport_header(skb);
216         skb_put(skb, sizeof(struct swifthdr));
217         printk(KERN_DEBUG "Reseted transport header\n");
218
219         shdr = (struct swifthdr *) skb_transport_header(skb);
220         shdr->dst = ntohs(dport);
221         shdr->src = ntohs(sport);
222         shdr->len = ntohs(len + sizeof(struct swifthdr));
223
224         printk(KERN_DEBUG "payload=%p\n", skb_put(skb, len));
225
226         err = skb_copy_datagram_from_iovec(skb, sizeof(struct swifthdr), msg->msg_iov, 0, len);
227         if (err) {
228                 printk(KERN_ERR "skb_copy_datagram_from_iovec failed\n");
229                 goto out_free;
230         }
231         printk(KERN_DEBUG "Copied %u bytes into the skb\n", len);
232
233         if (connected)
234                 rt = (struct rtable *) __sk_dst_check(sk, 0);
235
236         if (rt == NULL) {
237                 struct flowi fl = { .fl4_dst = daddr,
238                                     .proto = sk->sk_protocol,
239                                     .flags = inet_sk_flowi_flags(sk),
240                                   };
241                 err = ip_route_output_flow(sock_net(sk), &rt, &fl, sk, 0);
242                 if (err) {
243                         printk(KERN_ERR "Route lookup failed\n");
244                         goto out_free;
245                 }
246                 sk_dst_set(sk, dst_clone(&rt->dst));
247         }
248         
249         err = ip_queue_xmit(skb);
250         if (!err)
251                 printk(KERN_DEBUG "Sent %u bytes on wire\n", len);
252         else
253                 printk(KERN_ERR "ip_queue_xmit failed\n");
254
255         return err;
256
257 out_free:
258         kfree(skb);
259
260 out:
261         return err;
262 }
263
264 static int swift_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg, size_t len, int flags)
265 {
266         struct sk_buff *skb;
267         struct sockaddr_swift *swift_addr;
268         struct sock * sk = sock->sk;
269         int err, copied;
270
271         skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &err);
272         if (!skb) {
273                 printk(KERN_ERR "skb_recv_datagram\n");
274                 goto out;
275         }
276
277         printk(KERN_DEBUG "Received skb %p\n", skb);
278
279         swift_addr = (struct sockaddr_swift *) skb->cb;
280         msg->msg_namelen = sizeof(struct sockaddr_swift);
281
282         copied = skb->len;
283         if (copied > len) {
284                 copied = len;
285                 msg->msg_flags |= MSG_TRUNC;
286         }
287
288         err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
289         if (err) {
290                 printk(KERN_ERR "skb_copy_datagram_iovec\n");
291                 goto out_free;
292         }
293
294         sock_recv_ts_and_drops(msg, sk, skb);
295
296         if (msg->msg_name)
297                 memcpy(msg->msg_name, swift_addr, msg->msg_namelen);
298         
299         err = copied;
300
301 out_free:
302         skb_free_datagram(sk, skb);
303
304 out:
305         return err;
306 }
307
308 static int swift_rcv(struct sk_buff *skb)
309 {
310         struct swifthdr *shdr;
311         struct swift_sock *ssk;
312         __be16 len;
313         __be16 src, dst;
314         struct sockaddr_swift * swift_addr;
315         int err;
316
317         if (!pskb_may_pull(skb, sizeof(struct swifthdr))) {
318                 printk(KERN_ERR "Insufficient space for header\n");
319                 goto drop;
320         }
321         
322         shdr = (struct swifthdr *) skb->data;
323         len = ntohs(shdr->len);
324
325         if (skb->len < len) {
326                 printk(KERN_ERR "Malformed packet (packet_len=%u, skb_len=%u)\n", len, skb->len);
327                 goto drop;
328         }
329
330         if (len < sizeof(struct swifthdr)) {
331                 printk(KERN_ERR "Malformed packet (packet_len=%u sizeof(swifthdr)=%u\n", len, sizeof(struct swifthdr));
332                 goto drop;
333         }
334         
335         src = ntohs(shdr->src);
336         dst = ntohs(shdr->dst);
337         if (src == 0 || dst == 0 || src >= MAX_SWIFT_PORT || dst >= MAX_SWIFT_PORT) {
338                 printk(KERN_ERR "Malformed packet (src=%u, dst=%u)\n", shdr->src, shdr->dst);
339                 goto drop;
340         }
341
342         skb_pull(skb, sizeof(struct swifthdr));
343         len -= sizeof(struct swifthdr);
344
345         pskb_trim(skb, len);
346
347         printk(KERN_DEBUG "Received %u bytes from from port=%u to port=%u\n", len - sizeof(struct swifthdr), src, dst);
348
349         ssk = swift_lookup(dst); 
350         if (ssk == NULL) {
351                 printk(KERN_ERR "Swift lookup failed for port %u\n", dst);
352                 goto drop;
353         }
354
355         BUILD_BUG_ON(sizeof(struct sockaddr_swift) > sizeof(skb->cb));
356         
357         swift_addr = (struct sockaddr_swift *) skb->cb;
358         swift_addr->sin_family = AF_INET;
359         swift_addr->sin_port = shdr->src;
360         swift_addr->sin_addr.s_addr = ip_hdr(skb)->saddr;
361
362         printk(KERN_DEBUG "Setting sin_port=%u, sin_addr=%u\n", ntohs(shdr->src), swift_addr->sin_addr.s_addr);
363
364         err = ip_queue_rcv_skb((struct sock *) &ssk->sock, skb);
365         if (err) {
366                 printk(KERN_ERR "ip_queu_rcv_skb\n");
367                 consume_skb(skb);
368         }
369         return NET_RX_SUCCESS;
370
371 drop:
372         kfree(skb);
373         return NET_RX_DROP;
374 }
375
376 static struct proto swift_prot = {
377         .obj_size = sizeof(struct swift_sock),
378         .owner    = THIS_MODULE,
379         .name     = "SWIFT",
380 };
381
382 static const struct proto_ops swift_ops = {
383         .family     = PF_INET,
384         .owner      = THIS_MODULE,
385         .release    = swift_release,
386         .bind       = swift_bind,
387         .connect    = swift_connect,
388         .socketpair = sock_no_socketpair,
389         .accept     = sock_no_accept,
390         .getname    = sock_no_getname,
391         .poll       = datagram_poll,
392         .ioctl      = sock_no_ioctl,
393         .listen     = sock_no_listen,
394         .shutdown   = sock_no_shutdown,
395         .setsockopt = sock_no_setsockopt,
396         .getsockopt = sock_no_getsockopt,
397         .sendmsg    = swift_sendmsg,
398         .recvmsg    = swift_recvmsg,
399         .mmap       = sock_no_mmap,
400         .sendpage   = sock_no_sendpage,
401 };
402
403 static const struct net_protocol swift_protocol = {
404         .handler   = swift_rcv,
405         .no_policy = 1,
406         .netns_ok  = 1,
407 };
408
409 static struct inet_protosw swift_protosw = {
410         .type     = SOCK_DGRAM,
411         .protocol = IPPROTO_SWIFT,
412         .prot     = &swift_prot,
413         .ops      = &swift_ops,
414         .no_check = 0,
415 };
416
417 static int __init swift_init(void)
418 {
419         int rc;
420
421         rc = proto_register(&swift_prot, 1);
422         if (rc) {
423                 printk(KERN_ERR "Error registering swift protocol\n");
424                 goto out;
425         }
426
427         rc = inet_add_protocol(&swift_protocol, IPPROTO_SWIFT);
428         if (rc) {
429                 printk(KERN_ERR "Error adding swift protocol\n");
430                 goto out_unregister;
431         }
432
433         inet_register_protosw(&swift_protosw);
434         printk(KERN_DEBUG "Swift entered\n");
435
436         return 0;
437
438 out_unregister:
439         proto_unregister(&swift_prot);
440
441 out:
442         return rc;
443 }
444
445 static void __exit swift_exit(void)
446 {
447         inet_unregister_protosw(&swift_protosw);
448
449         inet_del_protocol(&swift_protocol, IPPROTO_SWIFT);
450
451         proto_unregister(&swift_prot);
452
453         printk(KERN_DEBUG "Swift exited\n");
454 }
455
456 module_init(swift_init);
457 module_exit(swift_exit);