import React, {
    useCallback, useEffect, useMemo, useRef, useState,
} from 'react';
import PropTypes from 'prop-types';
import { getClassNames } from '../../utils';
import './MasonryGrid.css';

const DEFAULT_COLUMNS_COUNT = 1;

const MasonryGrid = (props) => {
    const {
        className, children, columnsBreakPoints, defaultColumnsCount,
    } = props;
    const masonryRef = useRef(null);
    const [columnsCount, setColumnsCount] = useState(defaultColumnsCount || DEFAULT_COLUMNS_COUNT);

    const breakPoints = useMemo(() => Object.keys(columnsBreakPoints)
        .sort((a, b) => a - b), [columnsBreakPoints]);
    const updateColumnsCount = useCallback(() => {
        const containerWidth = masonryRef.current.offsetWidth;
        let columnsCountNew = breakPoints.length > 0
            ? columnsBreakPoints[breakPoints[0]] : DEFAULT_COLUMNS_COUNT;

        breakPoints.forEach((breakPoint) => {
            if (breakPoint < containerWidth) {
                columnsCountNew = columnsBreakPoints[breakPoint];
            }
        });

        if (columnsCount && columnsCountNew !== columnsCount) {
            setColumnsCount(columnsCountNew);
        }
    }, [breakPoints, columnsCount, columnsBreakPoints, setColumnsCount]);

    useEffect(() => {
        updateColumnsCount();
        window.addEventListener('resize', updateColumnsCount);

        return () => window.removeEventListener('resize', updateColumnsCount);
    }, [updateColumnsCount]);

    const columns = useMemo(() => {
        const cols = Array.from({ length: columnsCount }, () => []);

        React.Children.forEach(children, (child, index) => {
            cols[index % columnsCount].push(child);
        });

        return cols;
    }, [children, columnsCount]);
    const renderColumn = useCallback((column) => column.map((item, i) => (
        <div key={`masonry_item_${i.toString()}`} className="masonry__item">
            {item}
        </div>
    )), []);

    return (
        <div ref={masonryRef} className={getClassNames('masonry', className)}>
            {columns.map((column, i) => (
                <div
                    key={`masonry_column_${i.toString()}`}
                    className="masonry__column"
                    style={{
                        display: 'flex',
                        flexDirection: 'column',
                        justifyContent: 'flex-start',
                        alignContent: 'stretch',
                        flex: 1,
                        width: 0,
                    }}
                >
                    {renderColumn(column)}
                </div>
            ))}
        </div>
    );
};

MasonryGrid.propTypes = {
    // `node` already accepts arrays of renderable nodes.
    children: PropTypes.node.isRequired,
    // Map of container width in px to column count. Object keys are always
    // strings and `shape` cannot describe arbitrary keys, so validate the
    // values with `objectOf`.
    columnsBreakPoints: PropTypes.objectOf(PropTypes.number),
    className: PropTypes.string,
    defaultColumnsCount: PropTypes.number,
};

// Default breakpoint map: container width (px) -> number of columns.
const DEFAULT_COLUMNS_BREAK_POINTS = {
    350: 1,
    750: 2,
    900: 3,
};

// A null `defaultColumnsCount` makes the component fall back to its own
// default initial column count.
// NOTE(review): React 18.3+ deprecates `defaultProps` on function components;
// consider moving these into destructuring defaults.
MasonryGrid.defaultProps = {
    columnsBreakPoints: DEFAULT_COLUMNS_BREAK_POINTS,
    className: null,
    defaultColumnsCount: null,
};

export default MasonryGrid;
