# TestFinder class, define set of tests to run.
#
# Copyright (c) 2020-2021 Virtuozzo International GmbH
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import glob
import re
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, List, Iterator, Set


@contextmanager
def chdir(path: Optional[str] = None) -> Iterator[None]:
    if path is None:
        yield
        return

    saved_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(saved_dir)


class TestFinder:
    def __init__(self, test_dir: Optional[str] = None) -> None:
        self.groups = defaultdict(set)

        with chdir(test_dir):
            self.all_tests = glob.glob('[0-9][0-9][0-9]')
            self.all_tests += [f for f in glob.iglob('tests/*')
                               if not f.endswith('.out') and
                               os.path.isfile(f + '.out')]

            for t in self.all_tests:
                with open(t, encoding="utf-8") as f:
                    for line in f:
                        if line.startswith('# group: '):
                            for g in line.split()[2:]:
                                self.groups[g].add(t)
                            break

    def add_group_file(self, fname: str) -> None:
        with open(fname, encoding="utf-8") as f:
            for line in f:
                line = line.strip()

                if (not line) or line[0] == '#':
                    continue

                words = line.split()
                test_file = self.parse_test_name(words[0])
                groups = words[1:]

                for g in groups:
                    self.groups[g].add(test_file)

    def parse_test_name(self, name: str) -> str:
        if '/' in name:
            raise ValueError('Paths are unsupported for test selection, '
                             f'requiring "{name}" is wrong')

        if re.fullmatch(r'\d+', name):
            # Numbered tests are old naming convention. We should convert them
            # to three-digit-length, like 1 --> 001.
            name = f'{int(name):03}'
        else:
            # Named tests all should be in tests/ subdirectory
            name = os.path.join('tests', name)

        if name not in self.all_tests:
            raise ValueError(f'Test "{name}" is not found')

        return name

    def find_tests(self, groups: Optional[List[str]] = None,
                   exclude_groups: Optional[List[str]] = None,
                   tests: Optional[List[str]] = None,
                   start_from: Optional[str] = None) -> List[str]:
        """Find tests

        Algorithm:

        1. a. if some @groups specified
             a.1 Take all tests from @groups
             a.2 Drop tests, which are in at least one of @exclude_groups or in
                 'disabled' group (if 'disabled' is not listed in @groups)
             a.3 Add tests from @tests (don't exclude anything from them)

           b. else, if some @tests specified:
             b.1 exclude_groups must be not specified, so just take @tests

           c. else (only @exclude_groups list is non-empty):
             c.1 Take all tests
             c.2 Drop tests, which are in at least one of @exclude_groups or in
                 'disabled' group

        2. sort

        3. If start_from specified, drop tests from first one to @start_from
           (not inclusive)
        """
        if groups is None:
            groups = []
        if exclude_groups is None:
            exclude_groups = []
        if tests is None:
            tests = []

        res: Set[str] = set()
        if groups:
            # Some groups specified. exclude_groups supported, additionally
            # selecting some individual tests supported as well.
            res.update(*(self.groups[g] for g in groups))
        elif tests:
            # Some individual tests specified, but no groups. In this case
            # we don't support exclude_groups.
            if exclude_groups:
                raise ValueError("Can't exclude from individually specified "
                                 "tests.")
        else:
            # No tests no groups: start from all tests, exclude_groups
            # supported.
            res.update(self.all_tests)

        if 'disabled' not in groups and 'disabled' not in exclude_groups:
            # Don't want to modify function argument, so create new list.
            exclude_groups = exclude_groups + ['disabled']

        res = res.difference(*(self.groups[g] for g in exclude_groups))

        # We want to add @tests. But for compatibility with old test names,
        # we should convert any number < 100 to number padded by
        # leading zeroes, like 1 -> 001 and 23 -> 023.
        for t in tests:
            res.add(self.parse_test_name(t))

        sequence = sorted(res)

        if start_from is not None:
            del sequence[:sequence.index(self.parse_test_name(start_from))]

        return sequence