module Linguist
  # Naive Bayes language classifier.
  #
  # Training samples are split into tokens with Linguist::Tokenizer (which
  # must already be loaded, e.g. via `require 'linguist/tokenizer'`).  The
  # classifier keeps per-language token counts and ranks the known languages
  # for unseen data.
  class Classifier
    def initialize
      # language => { token => number of occurrences in that language }
      @tokens          = Hash.new { |h, k| h[k] = Hash.new(0) }
      # language => total number of tokens seen for that language
      @language_tokens = Hash.new(0)
      # language => number of training samples for that language
      @languages       = Hash.new(0)
      @languages_total = 0
      @tokens_total    = 0
    end

    # Record one training sample.
    #
    # language - Object identifying the language; used as a Hash key.
    # data     - Sample contents handed to Tokenizer.new.
    def train(language, data)
      tokens = Tokenizer.new(data).tokens

      tokens.each do |token|
        @tokens[language][token] += 1
        @language_tokens[language] += 1
        @tokens_total += 1
      end
      @languages[language] += 1
      @languages_total += 1
    end

    # Rank every trained language for the given data.
    #
    # data - Contents handed to Tokenizer.new.
    #
    # Returns an Array of [language, score] pairs, highest score first.  An
    # untrained classifier returns an empty Array.
    def classify(data)
      tokens = Tokenizer.new(data).tokens

      # Rank in log space: the raw product of many small probabilities
      # underflows to 0.0 for realistic inputs, which would turn the ordering
      # into a set of meaningless ties.
      ranked = @languages.keys.map do |language|
        [language, log_score(tokens, language)]
      end
      ranked = ranked.sort_by { |_language, log| -log }

      # Scores are reported on the probability scale; very small ones may
      # still round to 0.0, but the order above is unaffected by that.
      ranked.map { |language, log| [language, Math.exp(log)] }
    end

    # Product of the individual token probabilities for a language.
    #
    # Returns a Float (1.0 when tokens is empty).
    def tokens_probability(tokens, language)
      tokens.inject(1.0) do |product, token|
        product * token_probability(token, language)
      end
    end

    # Probability of a token given a language.
    #
    # A token seen in the language gets its share of that language's tokens.
    # A token never seen in the language gets 1 / (all tokens ever trained).
    #
    # Returns a Float.
    def token_probability(token, language)
      # key? guard: a plain @tokens[language] read would insert an empty
      # table for a language that has no tokens.
      count = @tokens.key?(language) ? @tokens[language][token] : 0

      if count == 0
        1 / @tokens_total.to_f
      else
        # Divide by the language's token count, not by its number of
        # training samples, so the result stays within 0..1.
        count.to_f / @language_tokens[language].to_f
      end
    end

    # Share of the training samples that belong to a language.
    #
    # Returns a Float.
    def language_probability(language)
      @languages[language].to_f / @languages_total.to_f
    end

    private

    # Natural log of tokens_probability * language_probability, accumulated
    # as a sum of logs so that it does not underflow.
    def log_score(tokens, language)
      tokens.inject(Math.log(language_probability(language))) do |sum, token|
        sum + Math.log(token_probability(token, language))
      end
    end
  end
end
# End-to-end check of Linguist::Classifier against sample fixture files.
class TestClassifier < Test::Unit::TestCase
  include Linguist

  # Absolute path of the "fixtures" directory next to this test file.
  def fixtures_path
    File.expand_path("../fixtures", __FILE__)
  end

  # Contents of the fixture at the given path relative to fixtures_path.
  def fixture(name)
    path = File.join(fixtures_path, name)
    File.read(path)
  end

  # Trains on one Ruby sample and two Objective-C samples, then expects a
  # further Objective-C file to be ranked as Objective-C.
  def test_truth
    samples = [
      ["Ruby", "ruby/foo.rb"],
      ["Objective-C", "objective-c/Foo.h"],
      ["Objective-C", "objective-c/Foo.m"]
    ]

    classifier = Classifier.new
    samples.each do |language_name, path|
      classifier.train(Language[language_name], fixture(path))
    end

    ranking = classifier.classify(fixture("objective-c/hello.m"))
    best = ranking.first
    assert_equal Language["Objective-C"], best[0]
  end
end