diff --git a/README.rst b/README.rst index 274a3e4..a46da95 100644 --- a/README.rst +++ b/README.rst @@ -112,6 +112,8 @@ Below is the full listing of options:: this only triggers if there is only one star import in the file; this is skipped if there are any uses of `__all__` or `del` in the file + --populate-all populate `__all__` with unused import found in the + code. --remove-all-unused-imports remove all unused imports (not just those from the standard library) diff --git a/autoflake.py b/autoflake.py index 46d7a88..5787274 100755 --- a/autoflake.py +++ b/autoflake.py @@ -294,7 +294,8 @@ def break_up_import(line): def filter_code(source, additional_imports=None, expand_star_imports=False, remove_all_unused_imports=False, - remove_unused_variables=False): + remove_unused_variables=False, + populate_all=False): """Yield code with unused imports removed.""" imports = SAFE_IMPORTS if additional_imports: @@ -335,6 +336,10 @@ def filter_code(source, additional_imports=None, else: marked_variable_line_numbers = frozenset() + if populate_all: + marked_import_line_numbers = frozenset() + source = populate_all_with_modules(source, marked_unused_module) + sio = io.StringIO(source) previous_line = '' for line_number, line in enumerate(sio.readlines(), start=1): @@ -478,6 +483,43 @@ def filter_useless_pass(source): yield line +def populate_all_with_modules(source, marked_unused_module): + all_syntax = re.search('^__all__(.)+\]', source, flags=re.MULTILINE) + if all_syntax: + # If there are existing `__all__`, parse it and append to it + insert_position = all_syntax.span()[0] + end_position = all_syntax.span()[1] + all_modules = all_syntax.group().split('=')[1].strip() + all_modules = ast.literal_eval(all_modules) + else: + # If no existing `__all__`, always append in EOF + insert_position = len(source) + end_position = -1 + all_modules = [] + + for modules in marked_unused_module.values(): + # Get the imported name, `a.b.Foo` -> Foo + all_modules += [get_imported_name(name) for name in modules] + + new_all_syntax = '__all__ = ' + str(all_modules) + source = source[:insert_position] + new_all_syntax + source[end_position:] + return source + + +def get_imported_name(module): + """ + Return only imported name from pyflakes full module path + + Example: + - `a.b.Foo` -> `Foo` + - `a as b` -> b + """ + if '.' in module: + return module.split('.')[-1] + elif ' as ' in module: + return module.split(' as ')[-1] + return module + def get_indentation(line): """Return leading whitespace.""" if line.strip(): @@ -497,7 +539,8 @@ def get_line_ending(line): def fix_code(source, additional_imports=None, expand_star_imports=False, - remove_all_unused_imports=False, remove_unused_variables=False): + remove_all_unused_imports=False, remove_unused_variables=False, + populate_all=False): """Return code with all filtering run on it.""" if not source: return source @@ -515,9 +558,10 @@ def fix_code(source, additional_imports=None, expand_star_imports=False, additional_imports=additional_imports, expand_star_imports=expand_star_imports, remove_all_unused_imports=remove_all_unused_imports, - remove_unused_variables=remove_unused_variables)))) + remove_unused_variables=remove_unused_variables, + populate_all=populate_all)))) - if filtered_source == source: + if filtered_source == source or populate_all: break source = filtered_source @@ -537,7 +581,9 @@ def fix_file(filename, args, standard_out): additional_imports=args.imports.split(',') if args.imports else None, expand_star_imports=args.expand_star_imports, remove_all_unused_imports=args.remove_all_unused_imports, - remove_unused_variables=args.remove_unused_variables) + remove_unused_variables=args.remove_unused_variables, + populate_all=args.populate_all, + ) if original_source != filtered_source: if args.in_place: @@ -692,6 +738,9 @@ def _main(argv, standard_out, standard_error): 'one star import in the file; this is skipped if ' 'there are any uses of `__all__` or `del` in the ' 'file') + parser.add_argument('--populate-all', action='store_true', + help='populate `__all__` with unused import found in ' + 'the code.') parser.add_argument('--remove-all-unused-imports', action='store_true', help='remove all unused imports (not just those from ' 'the standard library)') diff --git a/test_autoflake.py b/test_autoflake.py index ca54190..e6e8e30 100755 --- a/test_autoflake.py +++ b/test_autoflake.py @@ -485,6 +485,49 @@ def foo(): """ self.assertEqual(line, ''.join(autoflake.filter_code(line))) + def test_filter_code_populate_all(self): + self.assertEqual(""" +import math +import sys +__all__ = ['math', 'sys'] +""", ''.join(autoflake.filter_code(""" +import math +import sys +""", populate_all=True))) + + def test_filter_code_populate_all_appending(self): + self.assertEqual(""" +import math +import sys +__all__ = ['math', 'sys'] +""", ''.join(autoflake.filter_code(""" +import math +import sys +__all__ = ['math'] +""", populate_all=True))) + + def test_filter_code_populate_all_ignore_comment(self): + self.assertEqual(""" +import math +import sys +# __all__ = ['math'] +__all__ = ['math', 'sys'] +""", ''.join(autoflake.filter_code(""" +import math +import sys +# __all__ = ['math'] +""", populate_all=True))) + + def test_filter_code_populate_all_from_import(self): + self.assertEqual(""" +from a.b import Foo +from a.c import Bar +__all__ = ['Foo', 'Bar'] +""", ''.join(autoflake.filter_code(""" +from a.b import Foo +from a.c import Bar +""", populate_all=True))) + def test_fix_code(self): self.assertEqual( """\