From f25594e9622ddf9b6bb62e3fc3785cf38eca0c00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ken-H=C3=A5vard=20Lieng?= Date: Sun, 13 Jan 2019 05:10:11 +0100 Subject: [PATCH] Add casefolding to irc lib --- pkg/irc/case.go | 125 +++++++++++++++++++++++++++++++++++++++++++ pkg/irc/case_test.go | 27 ++++++++++ pkg/irc/conn.go | 4 +- pkg/irc/message.go | 4 +- 4 files changed, 156 insertions(+), 4 deletions(-) create mode 100644 pkg/irc/case.go create mode 100644 pkg/irc/case_test.go diff --git a/pkg/irc/case.go b/pkg/irc/case.go new file mode 100644 index 00000000..ef24c590 --- /dev/null +++ b/pkg/irc/case.go @@ -0,0 +1,125 @@ +package irc + +import ( + "unicode/utf8" +) + +const ( + // ASCII maps a-z as the lower case of A-Z + ASCII = "ascii" + // RFC1459 maps a-z and {, |, }, ~ as the lower case of A-Z and [, \, ], ^ + RFC1459 = "rfc1459" + // RFC1459Strict maps a-z and {, |, } as the lower case of A-Z and [, \, ] + RFC1459Strict = "strict-rfc1459" +) + +func (c *Client) Casefold(s string) string { + mapping := c.Support.Get("CASEMAPPING") + if mapping == "" { + mapping = RFC1459 + } + return Casefold(mapping, s) +} + +func (c *Client) EqualFold(s1, s2 string) bool { + mapping := c.Support.Get("CASEMAPPING") + if mapping == "" { + mapping = RFC1459 + } + return EqualFold(mapping, s1, s2) +} + +func Casefold(mapping, s string) string { + switch mapping { + case ASCII: + return toLower(s, 'Z') + case RFC1459: + return toLower(s, '^') + case RFC1459Strict: + return toLower(s, ']') + } + + return s +} + +func EqualFold(mapping, s1, s2 string) bool { + switch mapping { + case ASCII: + return equalFold(s1, s2, 'Z') + case RFC1459: + return equalFold(s1, s2, '^') + case RFC1459Strict: + return equalFold(s1, s2, ']') + } + + return s1 == s2 +} + +func toLower(s string, end byte) string { + hasUpper := false + for i := 0; i < len(s); i++ { + c := s[i] + if hasUpper = 'A' <= c && c <= end; hasUpper { + break + } + } + + if !hasUpper { + return s + } + + b := make([]byte, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + + // Skip Unicode characters + if c >= utf8.RuneSelf { + _, size := utf8.DecodeRuneInString(s[i:]) + for cEnd := i + size; i < cEnd; i++ { + b[i] = s[i] + } + i-- + continue + } + + if 'A' <= c && c <= end { + c += 32 + } + b[i] = c + } + return string(b) +} + +func equalFold(s1, s2 string, end rune) bool { + for s1 != "" && s2 != "" { + var r1, r2 rune + if s1[0] < utf8.RuneSelf { + r1, s1 = rune(s1[0]), s1[1:] + } else { + r, size := utf8.DecodeRuneInString(s1) + r1, s1 = r, s1[size:] + } + if s2[0] < utf8.RuneSelf { + r2, s2 = rune(s2[0]), s2[1:] + } else { + r, size := utf8.DecodeRuneInString(s2) + r2, s2 = r, s2[size:] + } + + if r1 == r2 { + continue + } + + if r2 < r1 { + r2, r1 = r1, r2 + } + + if 'A' <= r1 && r1 <= end && r2 == r1+32 { + continue + } + + return false + } + + return s1 == s2 +} diff --git a/pkg/irc/case_test.go b/pkg/irc/case_test.go new file mode 100644 index 00000000..3b8ec857 --- /dev/null +++ b/pkg/irc/case_test.go @@ -0,0 +1,27 @@ +package irc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCasefold(t *testing.T) { + assert.Equal(t, "caላke[^", Casefold(ASCII, "CaላkE[^")) + assert.Equal(t, "caላke{~", Casefold(RFC1459, "CaላkE[^")) + assert.Equal(t, "caላke{^", Casefold(RFC1459Strict, "CaላkE[^")) +} + +func TestEqualFold(t *testing.T) { + assert.True(t, EqualFold(ASCII, "caላke[^", "CaላkE[^")) + assert.False(t, EqualFold(ASCII, "caላke{~", "CaላkE[^")) + + assert.True(t, EqualFold(RFC1459, "caላke{~", "CaላkE[^")) + assert.False(t, EqualFold(RFC1459, "cላke[^", "CaላkE[^")) + + assert.True(t, EqualFold(RFC1459Strict, "caላke{^", "CaላkE[^")) + assert.False(t, EqualFold(RFC1459Strict, "caላke[~", "CaላkE[^")) + + assert.True(t, EqualFold(ASCII, "", "")) + assert.False(t, EqualFold(ASCII, "", " ")) +} diff --git a/pkg/irc/conn.go b/pkg/irc/conn.go index 1708dce9..20c9ac07 100644 --- a/pkg/irc/conn.go +++ b/pkg/irc/conn.go @@ -213,12 +213,12 @@ func (c *Client) recv() { go c.write("PONG :" + msg.LastParam()) case Join: - if msg.Nick == c.GetNick() { + if c.EqualFold(msg.Nick, c.GetNick()) { c.addChannel(msg.Params[0]) } case Nick: - if msg.Nick == c.GetNick() { + if c.EqualFold(msg.Nick, c.GetNick()) { c.setNick(msg.LastParam()) } diff --git a/pkg/irc/message.go b/pkg/irc/message.go index da109d41..2da60ff3 100644 --- a/pkg/irc/message.go +++ b/pkg/irc/message.go @@ -140,9 +140,9 @@ func (i *iSupport) Get(key string) string { func (i *iSupport) GetInt(key string) int { i.lock.Lock() - v := cast.ToInt(i.support[key]) + v := i.support[key] i.lock.Unlock() - return v + return cast.ToInt(v) } func splitParam(param string) (string, string) {