diff --git a/lib/linguist/classifier.rb b/lib/linguist/classifier.rb index ceaace21..c690b6b9 100644 --- a/lib/linguist/classifier.rb +++ b/lib/linguist/classifier.rb @@ -74,7 +74,8 @@ module Linguist # Public: Guess language of data. # - # data - Array of tokens or String data to analyze. + # data - Array of tokens or String data to analyze. + # languages - Array of Languages to restrict to. # # Examples # @@ -83,12 +84,14 @@ module Linguist # # Returns sorted Array of result pairs. Each pair contains the # Language and a Float score. - def classify(tokens) + def classify(tokens, languages = @languages.keys) tokens = Tokenizer.new(tokens).tokens if tokens.is_a?(String) scores = {} - @languages.keys.each do |language| - scores[language] = tokens_probability(tokens, language) * language_probability(language) + languages.each do |language| + language_name = language.is_a?(Language) ? language.name : language + scores[language_name] = tokens_probability(tokens, language_name) * + language_probability(language_name) end scores.sort { |a, b| b[1] <=> a[1] }.map { |score| [Language[score[0]], score[1]] } diff --git a/test/test_classifier.rb b/test/test_classifier.rb index 9280d150..bac2efa1 100644 --- a/test/test_classifier.rb +++ b/test/test_classifier.rb @@ -30,6 +30,19 @@ class TestClassifier < Test::Unit::TestCase assert_equal Language["Objective-C"], results.first[0] end + def test_restricted_classify + classifier = Classifier.new + classifier.train Language["Ruby"], fixture("ruby/foo.rb") + classifier.train Language["Objective-C"], fixture("objective-c/Foo.h") + classifier.train Language["Objective-C"], fixture("objective-c/Foo.m") + + results = classifier.classify(fixture("objective-c/hello.m"), [Language["Objective-C"]]) + assert_equal Language["Objective-C"], results.first[0] + + results = classifier.classify(fixture("objective-c/hello.m"), [Language["Ruby"]]) + assert_equal Language["Ruby"], results.first[0] + end + def test_instance_classify_empty results = Classifier.instance.classify("") assert results.first[1] < 0.5, results.first.inspect